In [3]:
import numpy as np
def mask_consecutive_channels(X, max_mask=8, block_size=128):
    """
    Zero a random run of consecutive channels in the first and last channel blocks.

    For every batch item, the first ``block_size`` channels and the last
    ``block_size`` channels are handled separately: for each block a run length
    ``M`` is drawn uniformly from ``[0, max_mask]`` and, when ``M > 0``, ``M``
    adjacent channels at a random position inside that block are set to 0
    across the whole second axis.

    NOTE(review): the original cell was an unfinished stub (it built the
    electrode-grid maps but never masked anything and returned ``None``); this
    implements the behaviour its docstring describes — confirm against intent.

    Args:
        X (np.ndarray): 3-D array ``(batch, ?, channels)``; axis 1 is presumably
            time — TODO confirm. ``X`` itself is not modified.
        max_mask (int): inclusive upper bound on the run length per block.
        block_size (int): number of channels per block (128 by default).

    Returns:
        np.ndarray: a masked copy of ``X``.

    Raises:
        ValueError: if ``max_mask`` is outside ``[0, block_size]`` or ``X`` has
            fewer than ``block_size`` channels.
    """
    X_masked = X.copy()  # numpy copy; the torch equivalent would be .clone()
    batch_size, _, num_channels = X_masked.shape

    if not 0 <= max_mask <= block_size:
        raise ValueError(f"max_mask must be in [0, {block_size}], got {max_mask}")
    if num_channels < block_size:
        raise ValueError(
            f"expected at least {block_size} channels, got {num_channels}"
        )

    # Offsets of the first and last blocks (disjoint when num_channels >= 2 * block_size).
    block_starts = (0, num_channels - block_size)

    for b in range(batch_size):
        for block_start in block_starts:
            M = np.random.randint(0, max_mask + 1)
            if M == 0:
                continue
            # The run must fit entirely inside its block.
            start = block_start + np.random.randint(0, block_size - M + 1)
            X_masked[b, :, start:start + M] = 0

    return X_masked

In [7]:
# Electrode-grid -> channel maps: entry [row, col] is the channel index used
# for the electrode at that grid position (this is how mask_electrodes reads them).
area_6v_superior = np.array([
    [62,  51,  43,  35,  94,  87,  79,  78],
    [60,  53,  41,  33,  95,  86,  77,  76],
    [63,  54,  47,  44,  93,  84,  75,  74],
    [58,  55,  48,  40,  92,  85,  73,  72],
    [59,  45,  46,  38,  91,  82,  71,  70],
    [61,  49,  42,  36,  90,  83,  69,  68],
    [56,  52,  39,  34,  89,  81,  67,  66],
    [57,  50,  37,  32,  88,  80,  65,  64]
])

area_6v_inferior = np.array([
    [125, 126, 112, 103,  31,  28,  11,  8],
    [123, 124, 110, 102,  29,  26,   9,  5],
    [121, 122, 109, 101,  27,  19,  18,  4],
    [119, 120, 108, 100,  25,  15,  12,  6],
    [117, 118, 107,  99,  23,  13,  10,  3],
    [115, 116, 106,  97,  21,  20,   7,  2],
    [113, 114, 105,  98,  17,  24,  14,  0],
    [127, 111, 104,  96,  30,  22,  16,  1]
])

# Sanity check: the two 8x8 grids together must cover channels 0..127 exactly once,
# otherwise a typo would silently mask the wrong (or a duplicated) channel.
assert np.array_equal(
    np.sort(np.concatenate([area_6v_superior.ravel(), area_6v_inferior.ravel()])),
    np.arange(128),
), "electrode maps must be a permutation of channels 0..127"

area_6v_superior[4, 4]

91

In [98]:
def mask_electrodes(X, max_mask_size):
    """
    Zero a random spatial cluster of electrodes independently for each sample.

    For every batch item an integer ``M`` is drawn uniformly from
    ``[0, max_mask_size]``. When ``M > 0``, ``return_mask_electrodes_optimized(M)``
    supplies ``M`` (row, col) positions on the 8x8 grid. The channels at those
    positions in both the superior and inferior maps are zeroed, together with
    the same channel indices shifted by 128.

    Args:
        X: 3-D torch tensor or numpy array ``(batch, ?, channels)``; presumably
            axis 1 is time and there are at least 256 channels — TODO confirm.
            ``X`` itself is not modified.
        max_mask_size (int): inclusive upper bound on ``M``.

    Returns:
        tuple: ``(X_masked, masked_channels_all)`` where ``X_masked`` is a masked
        copy of ``X`` and ``masked_channels_all`` holds the channel indices zeroed
        for the *last* batch item that drew ``M > 0`` (an empty int array if no
        item was masked).
    """
    # Work on a copy: torch tensors use .clone(), numpy arrays use .copy().
    X = X.clone() if hasattr(X, "clone") else X.copy()

    batch_size, _, _ = X.shape

    area_6v_superior = np.array([
        [62,  51,  43,  35,  94,  87,  79,  78],
        [60,  53,  41,  33,  95,  86,  77,  76],
        [63,  54,  47,  44,  93,  84,  75,  74],
        [58,  55,  48,  40,  92,  85,  73,  72],
        [59,  45,  46,  38,  91,  82,  71,  70],
        [61,  49,  42,  36,  90,  83,  69,  68],
        [56,  52,  39,  34,  89,  81,  67,  66],
        [57,  50,  37,  32,  88,  80,  65,  64]
    ])

    area_6v_inferior = np.array([
        [125, 126, 112, 103,  31,  28,  11,  8],
        [123, 124, 110, 102,  29,  26,   9,  5],
        [121, 122, 109, 101,  27,  19,  18,  4],
        [119, 120, 108, 100,  25,  15,  12,  6],
        [117, 118, 107,  99,  23,  13,  10,  3],
        [115, 116, 106,  97,  21,  20,   7,  2],
        [113, 114, 105,  98,  17,  24,  14,  0],
        [127, 111, 104,  96,  30,  22,  16,  1]
    ])

    # Bound before the loop so the return is valid even when every sample draws M == 0
    # (previously this raised UnboundLocalError).
    masked_channels_all = np.array([], dtype=np.int64)

    for b in range(batch_size):
        M = np.random.randint(0, max_mask_size + 1)
        if M == 0:
            continue

        masked_indices = return_mask_electrodes_optimized(M)
        rows, cols = np.asarray(masked_indices).T  # each of shape (M,)
        masked_channels = np.concatenate(
            (area_6v_superior[rows, cols], area_6v_inferior[rows, cols])
        )
        # The same channel set is masked again 128 channels further along.
        masked_channels_all = np.concatenate((masked_channels, masked_channels + 128))
        X[b, :, masked_channels_all] = 0

    return X, masked_channels_all
        

def return_mask_electrodes_optimized(M, grid_size=8):
    """
    Return the ``M`` grid positions nearest to a randomly chosen centre cell.

    A centre is drawn uniformly from the ``grid_size x grid_size`` grid. Every
    cell, the centre included (distance 0), is ranked by Euclidean distance to
    it, and cells at equal distance are ordered randomly.

    Args:
        M (int): number of positions to return. Values above ``grid_size**2``
            return every cell; values <= 0 return an empty list.
        grid_size (int): side length of the square grid (default 8).

    Returns:
        list[tuple[int, int]]: ``(row, col)`` pairs in order of increasing
        distance from the centre, so the first pair is the centre when M >= 1.
    """
    n_cells = grid_size ** 2
    rows, cols = np.divmod(np.arange(n_cells), grid_size)

    center_idx = np.random.randint(n_cells)
    distances = np.hypot(rows - rows[center_idx], cols - cols[center_idx])

    # np.lexsort treats the LAST key as primary: sort by distance, break ties
    # with a random key. Draw order (centre first, then tiebreak) is unchanged,
    # so seeded runs reproduce the previous results.
    order = np.lexsort((np.random.random(n_cells), distances))

    # Clamp at 0: a negative slice bound would otherwise return nearly the whole grid.
    return [(int(rows[i]), int(cols[i])) for i in order[:max(M, 0)]]

In [97]:
import torch

np.random.seed(0)  # fixed seed so the random mask is reproducible on Restart & Run All
X = torch.ones((1, 10, 256))
# mask_electrodes returns (masked_copy, masked_channel_indices); unpack it so
# X_masked is the masked tensor rather than the whole tuple.
X_masked, masked_channels = mask_electrodes(X, 5)

In [88]:
# NOTE(review): this reads the original input X; mask_electrodes works on a clone,
# so X is unchanged here and this always shows the unmasked values.
X[..., 15]

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])