In [107]:
import torch

def sorted_indices_by_euclidean_distance_torch(coords: torch.Tensor):
    """
    coords: LongTensor of shape (N,2) with (row, col) pairs.
    Returns: dict {i: LongTensor([i, j1, j2, ...])} with all N indices
             sorted by distance from coords[i].
    """
    N = coords.shape[0]
    dist_dict = {}
    # Convert to float for distance computation
    coords_f = coords.float()
    for i in range(N):
        # compute vector differences
        diff = coords_f - coords_f[i:i+1]         # shape (N,2)
        dists = diff.norm(dim=1)                 # shape (N,)
        order = torch.argsort(dists)             # indices sorted by distance
        dist_dict[i] = order                     # LongTensor of length N
    return dist_dict

def build_distance_dict_torch(shape=(8, 8), traversal=None):
    """
    shape: tuple (H, W)
    traversal: optional list or tensor of length H*W giving (row, col) coords
               in the order your 1D array uses. If None, assumes standard
               row-major flatten.
    Returns: dict from flat-index -> LongTensor of neighbors sorted by distance.
    """
    H, W = shape
    N = H * W

    if traversal is None:
        # build row‐major coords
        # coords[i] = (i//W, i%W)
        rows = torch.arange(N, dtype=torch.long) // W
        cols = torch.arange(N, dtype=torch.long) % W
        coords = torch.stack([rows, cols], dim=1)  # (N,2)
    else:
        coords = torch.tensor(traversal, dtype=torch.long)
        assert coords.shape == (N, 2), "Traversal must be (H*W, 2)"

    return sorted_indices_by_euclidean_distance_torch(coords)


In [110]:
dist_dict = build_distance_dict_torch((8,8))
savePath = '/home3/skaasyap/willett/outputs/'
torch.save(dist_dict, f'{savePath}dist_dict.pt')

In [106]:
import torch

import torch
import random

def channel_specaugment_masks(
    x,            # tensor [B, T, D]
    num_masks, max_channels_to_mask,
    dist_dict,
    num_channels=64,
    features_per_channel=2
):
    B, T, D = x.shape
    device = x.device
    masks = torch.zeros(B, D, dtype=torch.bool, device=device)

    # build a [B, num_channels] of uniform weights
    weights = torch.ones(B, num_channels, device=device)

    # now sample *per-row*:
    # starts1: [B, N], starts2: [B, M]
    starts1 = torch.multinomial(weights, num_masks, replacement=False)
    starts2 = torch.multinomial(weights, num_masks, replacement=False)
    
    # widths per mask, per sample
    widths1 = torch.randint(0, max_channels_to_mask+1, (B, num_masks), device=device)
    widths2 = torch.randint(0, max_channels_to_mask+1, (B, num_masks), device=device)
    

    # precompute feature-block offsets
    off1 = [feat * num_channels for feat in range(features_per_channel)]
    off2 = [features_per_channel * num_channels + feat * num_channels
            for feat in range(features_per_channel)]
    

    for b in range(B):
        # electrode 1
        for start_ch, w in zip(starts1[b], widths1[b]):
            w = int(w)
            if w == 0: 
                continue
            nearest = dist_dict[int(start_ch.item())][:w]
            idxs = torch.tensor(nearest, dtype=torch.long, device=device)
            for base in off1:
                masks[b, base + idxs] = True

        # electrode 2
        for start_ch, w in zip(starts2[b], widths2[b]):
            w = int(w)
            if w == 0:
                continue
            nearest = dist_dict[int(start_ch.item())][:w]
            idxs = torch.tensor(nearest, dtype=torch.long, device=device)
            for base in off2:
                masks[b, base + idxs] = True

    # broadcast mask over time
    return masks.unsqueeze(1).expand(-1, T, -1)


In [105]:
torch.manual_seed(0)
random.seed(0)

B, T, D = 5, 10, 256
x = torch.randn(B, T, D)

mask = channel_specaugment_masks(x, 2, 5, dist_dict)

[0, 64] [128, 192]


  idxs = torch.tensor(nearest, dtype=torch.long, device=device)
  idxs = torch.tensor(nearest, dtype=torch.long, device=device)


In [78]:
torch.where(mask[0,0])[0].shape

torch.Size([20])

In [34]:
device = 'cpu'
B, P, C = 10, 100, 4
mask = torch.zeros(B, P, C, dtype=torch.bool, device=device)
mask.shape


torch.Size([10, 100, 4])

In [47]:
Z, B, N = 5, 10, 8
hi = torch.Tensor([random.randint(0, Z) for _ in range(B*N)]).reshape(10,8)
for row in hi:
    print(row.shape)

torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])
torch.Size([8])


In [28]:
channel_specaugment_indices(3, 2, dist_dict=dist_dict)

[42] [18, 58] [0] [0, 1]


tensor([186, 250])

In [23]:
dist_dict[18]

tensor([18, 26, 19, 10, 17, 25,  9, 11, 27,  2, 20, 34, 16, 12, 24, 28,  8, 33,
        35,  3,  1,  0, 32, 36,  4, 21, 42, 43, 29, 41, 13, 44, 40, 37,  5, 22,
        50, 30, 14, 51, 49, 45, 38, 52,  6, 48, 23, 58, 53, 46, 15, 31, 57, 59,
         7, 56, 39, 60, 54, 47, 61, 55, 62, 63])

In [29]:
250-186

64