In [2]:
import torch
import torch


class DummySpecAug:
    def __init__(self, mask_token, patch_height=1, max_mask_pct=0.2):
        self.mask_token = mask_token
        self.patch_height = patch_height
        self.max_mask_pct = max_mask_pct

    def apply_specaugment_mask(self, X, X_len, num_masks=2):
        """
        Fully vectorized SpecAugment-style time masking (no loops).
        Args:
            X: (B, P, D)
            X_len: (B,) in timepoints
        Returns:
            X_masked: (B, P, D), mask: (B, P)
        """
        B, P, D = X.shape
        device = X.device

        valid_lens = (X_len // self.patch_height).to(device)
        max_mask_lens = (self.max_mask_pct * valid_lens).clamp(min=1).long()

        B_rep = B * num_masks
        valid_lens_rep = valid_lens.repeat_interleave(num_masks)
        max_mask_lens_rep = max_mask_lens.repeat_interleave(num_masks)

        t = (torch.rand(B_rep, device=device) * max_mask_lens_rep.float()).floor().long() + 1
        max_start = (valid_lens_rep - t + 1).clamp(min=1)
        t0 = (torch.rand(B_rep, device=device) * max_start.float()).floor().long()

        # Build flattened mask indices
        arange = torch.arange(P, device=device).unsqueeze(0)  # (1, P)
        t0_exp = t0.unsqueeze(1)                              # (B_rep, 1)
        t1_exp = (t0 + t).unsqueeze(1)                        # (B_rep, 1)
        mask_chunks = (arange >= t0_exp) & (arange < t1_exp)  # (B_rep, P)

        # Now gather the flat indices to write into (B, P)
        batch_idx = torch.arange(B, device=device).repeat_interleave(num_masks)  # (B_rep,)
        patch_idx = mask_chunks.nonzero(as_tuple=False)  # (N_masked, 2): [mask_row, patch_col]
        b_indices = batch_idx[patch_idx[:, 0]]           # Map B_rep index → actual batch index
        p_indices = patch_idx[:, 1]                      # patch index

        # Set those positions to True in the full mask
        mask = torch.zeros(B, P, dtype=torch.bool, device=device)
        mask[b_indices, p_indices] = True

        # Apply mask token
        X_masked = X.clone()
        X_masked[mask] = self.mask_token

        return X_masked, mask


# --------------------------
# Create dummy input
B, P, D = 2, 12, 4  # batch, patches, dim
X = torch.randn(B, P, D)
X_len = torch.tensor([12, 10])  # valid lengths (in timepoints)

# Define mask token and module
mask_token = torch.tensor([-99.0] * D)
specaug = DummySpecAug(mask_token=mask_token, patch_height=1, max_mask_pct=0.3)

# Apply
X_masked, mask = specaug.apply_specaugment_mask(X, X_len, num_masks=2)

# --------------------------
# Inspect the result
for b in range(B):
    print(f"\nSample {b}:")
    print("Original:")
    print(X[b])
    print("Masked:")
    print(X_masked[b])
    print("Mask positions:", mask[b].nonzero(as_tuple=True)[0].tolist())


IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [4], [4, 12]