In [1]:
import torch

In [2]:
# Simulate the inputs for the attention-guided masking walkthrough.
SEED = 0  # Fixed seed so Restart & Run All reproduces the same scores and samples
torch.manual_seed(SEED)  # Seeds the global RNG used by torch.rand here and by later sampling

B = 1  # Batch size
T = 16  # Number of frames
N = 14 * 14  # Number of tokens per frame
mask_ratio = 0.75  # Fraction of each frame's tokens that will be masked

BT = B * T  # One row per frame, flattened across the batch
# Simulated attention map: one score in [0, 1) for each token of each frame
attn = torch.rand(BT, N)
print(f"Example : Attention Map for the tokens of frame 1: \n{attn[0]}")

# Step 1 : Calculate the number of visible patches
N_vis = N - int(N * mask_ratio)  # Tokens left visible per frame once the mask ratio is applied
print(f"\nNumber of visible tokens per frame: {N_vis}")

Example : Attention Map for the tokens of frame 1: 
tensor([0.6404, 0.2407, 0.1934, 0.7004, 0.0069, 0.3592, 0.4818, 0.6008, 0.9668,
        0.1750, 0.6840, 0.6035, 0.8331, 0.1818, 0.7058, 0.1386, 0.7779, 0.7988,
        0.0787, 0.4564, 0.0965, 0.7411, 0.7778, 0.1418, 0.1889, 0.4742, 0.2425,
        0.8073, 0.6988, 0.4472, 0.2829, 0.2378, 0.2497, 0.9875, 0.6660, 0.6724,
        0.7923, 0.0616, 0.1201, 0.4395, 0.1286, 0.2702, 0.8954, 0.7531, 0.5376,
        0.7369, 0.1449, 0.9912, 0.9562, 0.8112, 0.4442, 0.5831, 0.4216, 0.1040,
        0.2171, 0.1679, 0.3076, 0.9386, 0.9061, 0.6639, 0.8172, 0.9311, 0.6379,
        0.2627, 0.2324, 0.9860, 0.1023, 0.0500, 0.0907, 0.6694, 0.5931, 0.4810,
        0.2235, 0.6389, 0.6386, 0.5263, 0.3426, 0.2549, 0.2487, 0.7174, 0.2559,
        0.8263, 0.7175, 0.2918, 0.8234, 0.8742, 0.2673, 0.9093, 0.3831, 0.7890,
        0.0109, 0.5015, 0.7419, 0.2233, 0.7751, 0.7962, 0.2987, 0.2038, 0.2066,
        0.3175, 0.5388, 0.1610, 0.6415, 0.8088, 0.4262, 0.7893, 0.08

In [3]:
# Step 2 : Draw, for every frame, an ordering of all N token indices weighted by attention.
# Sampling is without replacement, so (provided every score is non-zero) each row is a
# permutation of the token indices, with high-score tokens tending to come first.
importance = torch.multinomial(attn, num_samples=N, replacement=False)
print(f"Sampled token indices shape: (shape: {importance.shape}):\nImportance: {importance}")
print(f"\nExample: Importance for the tokens of frame 1: \n{importance[0]}")

Sampled token indices shape: (shape: torch.Size([16, 196])):
Importance: tensor([[120,  76, 185,  ..., 168,  46,  90],
        [ 88, 183, 189,  ..., 113, 154, 165],
        [152, 127, 184,  ..., 183,  85,  16],
        ...,
        [154, 112,  77,  ..., 103,  26, 192],
        [119,  76, 111,  ..., 147,  37,  20],
        [161,  72,   6,  ..., 163, 107, 112]])

Example: Importance for the tokens of frame 1: 
tensor([120,  76, 185, 110,   5,  94, 121,  82,  45, 176, 128, 192,  48, 105,
         12, 191, 136, 152,  10, 144, 135,  35,  93,   3, 154, 182, 167, 107,
          8, 143,  49,  86,  22,   0, 195, 133, 118,  84,  79,  71, 147,  28,
        172,  69,  61,  62,  34,  33,  83, 104,  63, 119,  58, 181, 126, 194,
         53, 180, 148,  54,  36,  65, 103,  41, 116,  89, 157,  56,  92,  51,
          7, 134,  17, 156, 123,  87,  16, 186, 165, 102,  57,  30, 140, 174,
        150,  42, 179,  43, 161,  59,  50,  77, 138,  72, 171, 122, 187, 170,
         21,  74, 146,  81, 190,  44, 101,

In [4]:
# Step 3: Start from a fully masked state -- True means "masked" for every token of every frame
bool_masked_pos = torch.full((BT, N), True, dtype=torch.bool)
print(f"Initialized mask: {bool_masked_pos}")

Initialized mask: tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])


In [19]:
# Step 4: Mark visible tokens in the mask
# Reset first: the indexed write below mutates the mask in place, so re-running this cell
# after re-sampling `importance` would otherwise keep stale visible entries from earlier runs.
bool_masked_pos.fill_(True)
pos1 = torch.arange(BT).view(-1, 1).repeat(1, N_vis)  # Row (frame) index, repeated once per visible token
print(f"We define a tensor filled with 0 for frame nb°0 having N_vis elements:\n", pos1[0])
pos2 = importance[:, :N_vis]  # Column (token) indices: the first N_vis sampled tokens of each frame
print(f"We define a tensor filled with the first N_vis elemenets in importance:\n",pos2[0])
bool_masked_pos[pos1, pos2] = False  # Set visible tokens to unmasked (False)
# Sanity check: every frame must end up with exactly N_vis visible (False) tokens
assert (~bool_masked_pos).sum(dim=1).eq(N_vis).all(), "each frame must keep exactly N_vis visible tokens"
# The token indices listed in pos2[0] are exactly the positions shown as False (visible) below
print(f"The visible & masked tokens in Frame 0: \n", bool_masked_pos[0])

We define a tensor filled with 0 for frame nb°0 having N_vis elements:
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0])
We define a tensor filled with the first N_vis elemenets in importance:
 tensor([120,  76, 185, 110,   5,  94, 121,  82,  45, 176, 128, 192,  48, 105,
         12, 191, 136, 152,  10, 144, 135,  35,  93,   3, 154, 182, 167, 107,
          8, 143,  49,  86,  22,   0, 195, 133, 118,  84,  79,  71, 147,  28,
        172,  69,  61,  62,  34,  33,  83])
The visible & masked tokens in Frame 0: 
 tensor([False,  True,  True, False,  True, False,  True,  True, False,  True,
        False,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True, False, False, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  