In [1]:
import torch

def sorted_indices_by_euclidean_distance_torch(coords: torch.Tensor):
    """
    For every point, order all points by Euclidean distance from it.

    coords: integer tensor of shape (N, 2) holding (row, col) pairs.
    Returns: dict {i: LongTensor of length N}; entry i lists all N indices
             sorted by distance from coords[i]. Ties are broken by ascending
             index (stable sort), so the order is deterministic across runs
             and backends. When coordinates are unique, entry i starts with
             i itself (distance 0).
    """
    # Distances need floating point arithmetic.
    coords_f = coords.float()
    dist_dict = {}
    for i in range(coords_f.shape[0]):
        # Euclidean distance from point i to every point, shape (N,)
        dists = (coords_f - coords_f[i:i + 1]).norm(dim=1)
        # On a grid many neighbours are equidistant (e.g. four at distance 1).
        # A stable sort keeps them ordered by index instead of leaving the
        # tie order to the backend, so prefix slices like order[:w] used by
        # callers select the same channels every time.
        dist_dict[i] = torch.sort(dists, stable=True).indices
    return dist_dict

def build_distance_dict_torch(shape=(8, 8), traversal=None):
    """
    Build a nearest-neighbour lookup for the cells of an H x W grid.

    shape: tuple (H, W)
    traversal: optional list or tensor of length H*W giving (row, col) coords
               in the order your 1D array uses. If None, assumes standard
               row-major flatten.
    Returns: dict from flat-index -> LongTensor of all flat indices sorted by
             Euclidean distance from that cell.
    """
    H, W = shape
    N = H * W

    if traversal is None:
        # Row-major layout: flat index i sits at (i // W, i % W).
        flat = torch.arange(N, dtype=torch.long)
        coords = torch.stack([flat // W, flat % W], dim=1)  # (N, 2)
    else:
        # as_tensor accepts both lists and tensors; torch.tensor() on an
        # existing tensor emits a copy-construct UserWarning.
        coords = torch.as_tensor(traversal, dtype=torch.long)
        assert coords.shape == (N, 2), "Traversal must be (H*W, 2)"

    return sorted_indices_by_euclidean_distance_torch(coords)


In [2]:
from pathlib import Path

dist_dict = build_distance_dict_torch((8, 8))

# NOTE(review): absolute, user-specific output directory; consider moving it to
# a config cell or an environment variable so the notebook runs elsewhere.
savePath = '/home3/skaasyap/willett/outputs/'
# Create the directory if it is missing so a fresh checkout does not fail on save.
Path(savePath).mkdir(parents=True, exist_ok=True)
torch.save(dist_dict, Path(savePath) / 'dist_dict.pt')

In [5]:
import torch

import torch
import random

def channel_specaugment_masks(
    x,            # tensor [B, T, D]
    num_masks, max_channels_to_mask,
    dist_dict,
    num_channels=64,
    features_per_channel=2
):
    """
    SpecAugment-style masking of spatially neighbouring channels.

    For each sample and each of the two channel groups ("electrodes"),
    `num_masks` distinct start channels are drawn. Each start masks its first
    `w` entries of `dist_dict[start]`, with w drawn uniformly from
    [0, max_channels_to_mask]. A masked channel is zeroed in every feature
    block of its group and at every time step.

    Args:
        x: tensor [B, T, D]. The feature axis is indexed as follows:
            group 1 uses block offsets f * num_channels for f < features_per_channel;
            group 2 uses the next features_per_channel blocks.
            D is therefore expected to be >= 2 * features_per_channel * num_channels.
        num_masks: masks per group per sample. Must be <= num_channels, since
            start channels are sampled without replacement.
        max_channels_to_mask: inclusive upper bound on channels per mask.
        dist_dict: {channel: indices ordered by distance from that channel},
            e.g. from build_distance_dict_torch. Tensor or list values are accepted.
        num_channels: channels per group.
        features_per_channel: feature blocks per group.

    Returns:
        X_masked: copy of x with masked entries set to 0.
        masks: bool tensor [B, T, D] (an expanded view), True where masked.
    """
    B, T, D = x.shape
    device = x.device
    channel_mask = torch.zeros(B, D, dtype=torch.bool, device=device)

    # Uniform weights: each row samples distinct start channels.
    weights = torch.ones(B, num_channels, device=device)

    # RNG draws are kept in the original order so a fixed seed gives the same masks.
    starts1 = torch.multinomial(weights, num_masks, replacement=False)
    starts2 = torch.multinomial(weights, num_masks, replacement=False)
    widths1 = torch.randint(0, max_channels_to_mask + 1, (B, num_masks), device=device)
    widths2 = torch.randint(0, max_channels_to_mask + 1, (B, num_masks), device=device)

    # Start offset of each feature block on the D axis, per channel group.
    off1 = [feat * num_channels for feat in range(features_per_channel)]
    off2 = [(features_per_channel + feat) * num_channels
            for feat in range(features_per_channel)]

    groups = ((starts1, widths1, off1), (starts2, widths2, off2))
    for b in range(B):
        for starts, widths, offsets in groups:
            for start_ch, w in zip(starts[b].tolist(), widths[b].tolist()):
                if w == 0:
                    continue
                # as_tensor avoids the copy-construct UserWarning that
                # torch.tensor() emits when dist_dict holds tensors.
                idxs = torch.as_tensor(dist_dict[start_ch][:w],
                                       dtype=torch.long, device=device)
                for base in offsets:
                    channel_mask[b, base + idxs] = True

    # Broadcast the per-feature mask over time.
    masks = channel_mask.unsqueeze(1).expand(-1, T, -1)
    X_masked = x.masked_fill(masks, 0)
    return X_masked, masks


In [11]:
torch.manual_seed(0)
random.seed(0)

B, T, D = 5, 10, 256
x = torch.randn(B, T, D)

X_masked, masks = channel_specaugment_masks(x, 20, 5, dist_dict)

# Sanity checks: masked entries are zero, everything else is untouched,
# and each feature's mask is constant over time.
assert X_masked.shape == masks.shape == x.shape
assert bool((X_masked[masks] == 0).all())
assert torch.equal(X_masked[~masks], x[~masks])
assert bool((masks == masks[:, :1, :]).all())

  idxs = torch.tensor(nearest, dtype=torch.long, device=device)
  idxs = torch.tensor(nearest, dtype=torch.long, device=device)


In [12]:
X_masked.size()

torch.Size([5, 10, 256])

In [23]:
X_masked[0][:, 64 + 2]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [22]:
masks[0, 0]

tensor([False, False,  True, False,  True, False, False, False,  True, False,
        False, False,  True, False, False, False,  True, False,  True,  True,
         True, False, False,  True,  True,  True, False,  True,  True, False,
         True,  True, False, False,  True,  True,  True,  True,  True,  True,
        False, False, False,  True, False, False,  True, False, False, False,
        False,  True, False, False,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False, False,  True, False,  True, False,
        False, False,  True, False, False, False,  True, False, False, False,
         True, False,  True,  True,  True, False, False,  True,  True,  True,
        False,  True,  True, False,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False, False, False,  True, False, False,
         True, False, False, False, False,  True, False, False,  True,  True,
        False, False,  True,  True,  True,  True,  True,  True, 

In [1]:
import torch

class TimeMasker:
    """Randomly zeroes contiguous spans of time steps within each sequence's valid length."""

    def __init__(self, max_mask_pct=0.5, num_masks=1):
        # Upper bound on one span's length, as a fraction of the valid length.
        self.max_mask_pct = max_mask_pct
        # Number of (possibly overlapping) spans drawn per sequence.
        self.num_masks = num_masks

    def apply_time_masking(self, X, X_len):
        """
        Vectorized time masking; only the per-batch index lists use a Python loop.

        For each sequence, `num_masks` spans are drawn. A span's length t is
        uniform in [0, floor(max_mask_pct * X_len[b])]. Its start t0 is uniform
        in [0, X_len[b] - t], so spans never extend past the valid length.

        Args:
            X: (B, P, D) input tensor
            X_len: (B,) valid lengths in timepoints (tensor or sequence of ints)

        Returns:
            X_masked: (B, P, D) copy of X with masked time steps set to 0
            mask: (B, P) boolean mask, True where values were masked
            masked_indices: list of B 1D LongTensors with the masked time indices
            unmasked_indices: list of B 1D LongTensors with all other time indices
                (the complement of mask, including any padding beyond X_len)
        """
        B, P, _ = X.shape
        device = X.device

        # Put lengths on X's device so the arithmetic below never mixes devices.
        valid_lens = torch.as_tensor(X_len, device=device)
        max_mask_lens = (self.max_mask_pct * valid_lens).long()  # (B,)

        # One row per (sequence, mask) pair; the rows for sequence b are contiguous.
        B_rep = B * self.num_masks
        valid_lens_rep = valid_lens.repeat_interleave(self.num_masks)
        max_mask_lens_rep = max_mask_lens.repeat_interleave(self.num_masks)

        t = (torch.rand(B_rep, device=device) * (max_mask_lens_rep + 1).float()).floor().long()
        max_start = (valid_lens_rep - t + 1).clamp(min=1)
        t0 = (torch.rand(B_rep, device=device) * max_start.float()).floor().long()

        positions = torch.arange(P, device=device).unsqueeze(0)  # (1, P)
        mask_chunks = (positions >= t0.unsqueeze(1)) & (positions < (t0 + t).unsqueeze(1))  # (B_rep, P)

        # Union of the num_masks spans belonging to each sequence.
        mask = mask_chunks.reshape(B, self.num_masks, P).any(dim=1)

        X_masked = X.masked_fill(mask.unsqueeze(-1), 0)

        masked_indices = [row.nonzero(as_tuple=True)[0] for row in mask]
        # Parenthesize the negation: `~row.nonzero(...)[0]` would bitwise-NOT
        # the index values (giving -(i+1)) instead of inverting the mask.
        unmasked_indices = [(~row).nonzero(as_tuple=True)[0] for row in mask]

        return X_masked, mask, masked_indices, unmasked_indices

# Test setup
torch.manual_seed(0)  # reproducible mask draw

B, T, F = 1, 100, 256
X = torch.randn(B, T, F)
X_len = torch.tensor([50])  # Only first 50 are "valid"

masker = TimeMasker(max_mask_pct=0.5, num_masks=1)
X_masked, mask, masked_idx, unmasked_idx = masker.apply_time_masking(X, X_len)

# Print some quick diagnostics
print("Input shape:", X.shape)
print("Masked output shape:", X_masked.shape)
print("Mask shape:", mask.shape)
print("Masked indices:", masked_idx)
print("Mask sum:", mask.sum().item())

# Invariants: a single mask covers at most floor(0.5 * 50) = 25 steps,
# never touches the padding beyond X_len, and zeroes what it covers.
assert mask.sum().item() <= 25
assert not bool(mask[:, 50:].any())
assert bool((X_masked[mask] == 0).all())


Input shape: torch.Size([1, 100, 256])
Masked output shape: torch.Size([1, 100, 256])
Mask shape: torch.Size([1, 100])
Masked indices: [tensor([28, 29, 30, 31, 32, 33, 34])]
Mask sum: 7


  from .autonotebook import tqdm as notebook_tqdm
