In [83]:
import torch
import numpy as np

In [107]:
batch_size = 2
feature_size = 2
num_tokens = 4
token_size = 6

In [108]:
x = torch.randn((batch_size, feature_size, num_tokens, token_size))
print(x)
print(x.shape)

tensor([[[[ 0.3415, -0.7421, -1.6642, -0.6188,  0.5851,  0.7946],
          [ 0.3862,  0.3465, -0.4487,  1.0918, -0.4493, -0.5910],
          [ 1.0110,  0.1336,  0.8510,  1.9779,  0.5185, -0.9421],
          [-2.2669,  0.4826,  1.2216,  1.4022,  0.7849,  0.7682]],

         [[-0.0357, -1.2210,  1.1843,  0.2753, -0.8004,  0.9002],
          [-0.2157,  0.2856, -1.0466,  0.2261,  0.9820, -0.3331],
          [-0.9773,  1.3335,  1.7221,  1.0311,  0.9494,  0.2304],
          [ 0.1028, -1.3259,  2.0360,  0.5324, -1.2760,  0.2598]]],


        [[[-0.1043,  0.7626, -0.1383,  0.3262,  1.4163,  1.1930],
          [-0.9546,  1.0084,  1.4956, -1.1866, -0.7160,  0.0618],
          [ 0.8694, -0.8502, -0.5088,  0.7030, -0.3161, -0.5047],
          [ 0.1013,  1.6498, -0.0852, -0.3387, -1.7625, -0.4069]],

         [[ 1.1139,  0.9033, -0.1297,  0.9186, -0.0377,  0.4204],
          [ 0.4495,  0.7723, -0.4204, -0.6152,  0.8331,  1.0382],
          [ 0.7052,  0.4141,  0.5190, -0.5778,  0.9022,  1.5894],
  

In [109]:
def generate_random_masks_mini_tokens(mask_ratio, num_tokens, batch_size):
    """ Generate random masks for the mini tokens """
    def single_sample_mask():
        idx = np.random.permutation(num_tokens)[: int(mask_ratio * num_tokens)]
        mask = np.zeros(num_tokens)
        mask[idx] = 1
        return mask
    masks_list = [single_sample_mask() for _ in range(batch_size)]
    masks = np.stack(masks_list, axis=0)  # (num_samples, num_tokens)
    return masks

In [161]:
rand_mask = generate_random_masks_mini_tokens(0.5, num_tokens, batch_size)
print(rand_mask, rand_mask.shape)

[[1. 0. 0. 1.]
 [0. 1. 1. 0.]] (2, 4)


In [162]:
rand_mask_1 = np.expand_dims(rand_mask, axis=2)
rand_mask_1.shape

(2, 4, 1)

In [163]:
rand_mask_2 = np.repeat(rand_mask_1, token_size, axis=2)
rand_mask_2.shape

(2, 4, 6)

In [164]:
rand_mask_3 = np.expand_dims(rand_mask_2, axis=1)
rand_mask_3 = np.repeat(rand_mask_3, feature_size, axis=1)
rand_mask_3.shape

(2, 2, 4, 6)

In [165]:
x * rand_mask_3

tensor([[[[ 0.3415, -0.7421, -1.6642, -0.6188,  0.5851,  0.7946],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-2.2669,  0.4826,  1.2216,  1.4022,  0.7849,  0.7682]],

         [[-0.0357, -1.2210,  1.1843,  0.2753, -0.8004,  0.9002],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1028, -1.3259,  2.0360,  0.5324, -1.2760,  0.2598]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.9546,  1.0084,  1.4956, -1.1866, -0.7160,  0.0618],
          [ 0.8694, -0.8502, -0.5088,  0.7030, -0.3161, -0.5047],
          [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.4495,  0.7723, -0.4204, -0.6152,  0.8331,  1.0382],
          [ 0.7052,  0.4141,  0.5190, -0.5778,  0.9022,  1.5894],
  

In [166]:
def convert_mask_to_shape(rand_mask):
    rand_mask_1 = np.expand_dims(rand_mask, axis=2)
    rand_mask_2 = np.repeat(rand_mask_1, token_size, axis=2)
    rand_mask_3 = np.expand_dims(rand_mask_2, axis=1)
    rand_mask_4 = np.repeat(rand_mask_3, feature_size, axis=1)
    return rand_mask_4

In [167]:
x * convert_mask_to_shape(rand_mask)

tensor([[[[ 0.3415, -0.7421, -1.6642, -0.6188,  0.5851,  0.7946],
          [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000],
          [-2.2669,  0.4826,  1.2216,  1.4022,  0.7849,  0.7682]],

         [[-0.0357, -1.2210,  1.1843,  0.2753, -0.8004,  0.9002],
          [-0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.1028, -1.3259,  2.0360,  0.5324, -1.2760,  0.2598]]],


        [[[-0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000],
          [-0.9546,  1.0084,  1.4956, -1.1866, -0.7160,  0.0618],
          [ 0.8694, -0.8502, -0.5088,  0.7030, -0.3161, -0.5047],
          [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.4495,  0.7723, -0.4204, -0.6152,  0.8331,  1.0382],
          [ 0.7052,  0.4141,  0.5190, -0.5778,  0.9022,  1.5894],
  

In [168]:
x.shape

torch.Size([2, 2, 4, 6])

In [169]:
def get_view_of_fixed_mask(x, mask, keep_tok, token_size, device):
    """ Return just the masked section """
    mask = torch.from_numpy(mask).to(device)
    
    idx = torch.argsort(mask, dim=1) # batch_size, keep_tok
    ids_restore = torch.argsort(idx, dim=1)
    print("idx:", idx)
    print("ids_restore:", ids_restore)
    idx = idx[:,:keep_tok]
    idx = torch.unsqueeze(idx, dim=1) # batch_size, feature_size, keep_tok
    idx = torch.unsqueeze(idx, dim=3) # batch_size, feature_size, keep_tok, token_size
    idx = idx.repeat(1, feature_size, 1, token_size)
    return torch.gather(x, dim=2, index=idx), ids_restore

In [155]:
def get_view_of_fixed_mask(x, mask, keep_tok):
    """ Return just the masked section """
    mask = torch.from_numpy(mask)
    print(f"mask:{mask.shape}") 
    
    idx = torch.argsort(mask, dim=1)[:,:keep_tok] # batch_size, keep_tok
    ids_restore = torch.argsort(idx, dim=1)
    print("idx:", idx)
    print("ids_restore:", ids_restore)
    idx = torch.unsqueeze(idx, dim=1) # batch_size, feature_size, keep_tok, token_size
    idx = torch.unsqueeze(idx, dim=3) # batch_size, feature_size, keep_tok, token_size
    idx = idx.repeat(1, feature_size, 1, token_size)
    # batch_size, feature_size, keep_tok, 1
    print(f"idx:{idx.shape}")
    print(f"x:{x.shape}")
    return torch.gather(x, dim=2, index=idx), ids_restore

In [170]:
get_view_of_fixed_mask(x, rand_mask, 2, token_size, 'cpu')

idx: tensor([[1, 2, 0, 3],
        [0, 3, 1, 2]])
ids_restore: tensor([[2, 0, 1, 3],
        [0, 2, 3, 1]])


(tensor([[[[ 0.3862,  0.3465, -0.4487,  1.0918, -0.4493, -0.5910],
           [ 1.0110,  0.1336,  0.8510,  1.9779,  0.5185, -0.9421]],
 
          [[-0.2157,  0.2856, -1.0466,  0.2261,  0.9820, -0.3331],
           [-0.9773,  1.3335,  1.7221,  1.0311,  0.9494,  0.2304]]],
 
 
         [[[-0.1043,  0.7626, -0.1383,  0.3262,  1.4163,  1.1930],
           [ 0.1013,  1.6498, -0.0852, -0.3387, -1.7625, -0.4069]],
 
          [[ 1.1139,  0.9033, -0.1297,  0.9186, -0.0377,  0.4204],
           [ 0.9845,  0.8140,  1.2143,  1.4363, -0.5735,  0.4860]]]]),
 tensor([[2, 0, 1, 3],
         [0, 2, 3, 1]]))

In [154]:
get_view_of_fixed_mask(x, rand_mask, 2)

mask:torch.Size([2, 4])
idx:torch.Size([2, 2, 2, 6])
x:torch.Size([2, 2, 4, 6])


tensor([[[[ 0.3415, -0.7421, -1.6642, -0.6188,  0.5851,  0.7946],
          [ 0.3862,  0.3465, -0.4487,  1.0918, -0.4493, -0.5910]],

         [[-0.0357, -1.2210,  1.1843,  0.2753, -0.8004,  0.9002],
          [-0.2157,  0.2856, -1.0466,  0.2261,  0.9820, -0.3331]]],


        [[[-0.1043,  0.7626, -0.1383,  0.3262,  1.4163,  1.1930],
          [-0.9546,  1.0084,  1.4956, -1.1866, -0.7160,  0.0618]],

         [[ 1.1139,  0.9033, -0.1297,  0.9186, -0.0377,  0.4204],
          [ 0.4495,  0.7723, -0.4204, -0.6152,  0.8331,  1.0382]]]])