In [1]:
import torch

def sorted_indices_by_euclidean_distance_torch(coords: torch.Tensor):
    """
    coords: LongTensor of shape (N,2) with (row, col) pairs.
    Returns: dict {i: LongTensor([i, j1, j2, ...])} with all N indices
             sorted by distance from coords[i].
    """
    N = coords.shape[0]
    dist_dict = {}
    # Convert to float for distance computation
    coords_f = coords.float()
    for i in range(N):
        # compute vector differences
        diff = coords_f - coords_f[i:i+1]         # shape (N,2)
        dists = diff.norm(dim=1)                 # shape (N,)
        order = torch.argsort(dists)             # indices sorted by distance
        dist_dict[i] = order                     # LongTensor of length N
    return dist_dict

def build_distance_dict_torch(shape=(8, 8), traversal=None):
    """
    shape: tuple (H, W)
    traversal: optional list or tensor of length H*W giving (row, col) coords
               in the order your 1D array uses. If None, assumes standard
               row-major flatten.
    Returns: dict from flat-index -> LongTensor of neighbors sorted by distance.
    """
    H, W = shape
    N = H * W

    if traversal is None:
        # build row‐major coords
        # coords[i] = (i//W, i%W)
        rows = torch.arange(N, dtype=torch.long) // W
        cols = torch.arange(N, dtype=torch.long) % W
        coords = torch.stack([rows, cols], dim=1)  # (N,2)
    else:
        coords = torch.tensor(traversal, dtype=torch.long)
        assert coords.shape == (N, 2), "Traversal must be (H*W, 2)"

    return sorted_indices_by_euclidean_distance_torch(coords)


In [2]:
dist_dict = build_distance_dict_torch((8,8))
savePath = '/home3/skaasyap/willett/outputs/'
torch.save(dist_dict, f'{savePath}dist_dict.pt')

In [5]:
import torch

import torch
import random

def channel_specaugment_masks(
    x,            # tensor [B, T, D]
    num_masks, max_channels_to_mask,
    dist_dict,
    num_channels=64,
    features_per_channel=2
):
    B, T, D = x.shape
    device = x.device
    masks = torch.zeros(B, D, dtype=torch.bool, device=device)

    # build a [B, num_channels] of uniform weights
    weights = torch.ones(B, num_channels, device=device)

    # now sample *per-row*:
    # starts1: [B, N], starts2: [B, M]
    starts1 = torch.multinomial(weights, num_masks, replacement=False)
    starts2 = torch.multinomial(weights, num_masks, replacement=False)
    
    # widths per mask, per sample
    widths1 = torch.randint(0, max_channels_to_mask+1, (B, num_masks), device=device)
    widths2 = torch.randint(0, max_channels_to_mask+1, (B, num_masks), device=device)
    
    # precompute feature-block offsets
    off1 = [feat * num_channels for feat in range(features_per_channel)]
    off2 = [features_per_channel * num_channels + feat * num_channels
            for feat in range(features_per_channel)]
    

    for b in range(B):
        # electrode 1
        for start_ch, w in zip(starts1[b], widths1[b]):
            w = int(w)
            if w == 0: 
                continue
            nearest = dist_dict[int(start_ch.item())][:w]
            idxs = torch.tensor(nearest, dtype=torch.long, device=device)
            for base in off1:
                masks[b, base + idxs] = True

        # electrode 2
        for start_ch, w in zip(starts2[b], widths2[b]):
            w = int(w)
            if w == 0:
                continue
            nearest = dist_dict[int(start_ch.item())][:w]
            idxs = torch.tensor(nearest, dtype=torch.long, device=device)
            for base in off2:
                masks[b, base + idxs] = True

    # broadcast mask over time
    masks = masks.unsqueeze(1).expand(-1, T, -1)
    X_masked = x.clone()
    X_masked[masks] = 0
    return X_masked, masks


In [11]:
torch.manual_seed(0)
random.seed(0)

B, T, D = 5, 10, 256
x = torch.randn(B, T, D)

X_masked, masks = channel_specaugment_masks(x, 20, 5, dist_dict)

  idxs = torch.tensor(nearest, dtype=torch.long, device=device)
  idxs = torch.tensor(nearest, dtype=torch.long, device=device)


In [12]:
X_masked.shape

torch.Size([5, 10, 256])

In [23]:
X_masked[0, :, 2+64]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [22]:
masks[0, 0, :]

tensor([False, False,  True, False,  True, False, False, False,  True, False,
        False, False,  True, False, False, False,  True, False,  True,  True,
         True, False, False,  True,  True,  True, False,  True,  True, False,
         True,  True, False, False,  True,  True,  True,  True,  True,  True,
        False, False, False,  True, False, False,  True, False, False, False,
        False,  True, False, False,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False, False,  True, False,  True, False,
        False, False,  True, False, False, False,  True, False, False, False,
         True, False,  True,  True,  True, False, False,  True,  True,  True,
        False,  True,  True, False,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False, False, False,  True, False, False,
         True, False, False, False, False,  True, False, False,  True,  True,
        False, False,  True,  True,  True,  True,  True,  True, 