In [9]:
import torch
import torch.nn.functional as F

N, H, T, T_M = 16, 12, 203, 128

attention_mask = torch.zeros(N, 1, 1, T)
attention_mask[:,:,:,:170] = 1
print(attention_mask)
zero_one_attention_mask = (attention_mask > -1).float()
print(zero_one_attention_mask)

tensor([[[[1., 1., 1.,  ..., 0., 0., 0.]]],


        [[[1., 1., 1.,  ..., 0., 0., 0.]]],


        [[[1., 1., 1.,  ..., 0., 0., 0.]]],


        ...,


        [[[1., 1., 1.,  ..., 0., 0., 0.]]],


        [[[1., 1., 1.,  ..., 0., 0., 0.]]],


        [[[1., 1., 1.,  ..., 0., 0., 0.]]]])
tensor([[[[1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.]]],


        ...,


        [[[1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.]]],


        [[[1., 1., 1.,  ..., 1., 1., 1.]]]])


In [10]:
def grid_sample_bf16(input, grid, mode='nearest', align_corners=False, padding_mode='zeros'):
    input_dtype = input.dtype
    op_dtype = torch.float32 if torch.get_autocast_gpu_dtype() == torch.bfloat16 else input_dtype
    if op_dtype != input_dtype:
        input = input.to(op_dtype)
        grid = grid.to(op_dtype)
    y = F.grid_sample(
        input=input,
        grid=grid,
        mode=mode,
        align_corners=align_corners,
        padding_mode='zeros',
    )
    if y.dtype != input_dtype:
        y = y.to(input_dtype)
    return y

In [11]:
def resize_from_m_to_t(x, masked_fill_value, target_width=None):
    N, H, T1, T_M = x.shape
    if target_width is not None:
        T2 = target_width
    else:
        T2 = T1
    # with timer("resize"):
        # with timer("resize.grid"):

    token_index_x = zero_one_attention_mask.view(N, 1, T2)
    if masked_fill_value is not None:
        # token_index_x = torch.roll(token_index_x, shifts=(1,), dims=(-1)).cumsum(-1) + ((1.0 - zero_one_attention_mask) * 2).view(N, 1, T2)
        # token_index_x = (token_index_x / ((zero_one_attention_mask.sum(-1) + 2).view(N, 1, 1) + 1e-8) * 2 - 1).expand(N, T1, T2)
        mask = token_index_x
        mask_cs = mask.cumsum(-1)
        token_length = (mask_cs[:, :, -1].unsqueeze(-1) - 1) + 3 * (mask_cs[:, :, -1].unsqueeze(-1)/T_M)
        token_index_x = torch.clamp(((((mask_cs - 1) + (1 - mask) * 5000)) / (token_length + 1e-8)) * 2 - 1, -1, 1)
        token_index_x = token_index_x.expand(N, T1, T2)
        # print(token_index_x[0,0,:], mask_cs[0,0,-1])
    else:
        token_index_x = token_index_x.cumsum(-1)
        token_index_x = (token_index_x / ((zero_one_attention_mask.sum(-1) - 1).view(N, 1, 1) + 1e-8) * 2 - 1).expand(N, T1, T2)
            
    token_index_y = (
        torch.arange(T1, dtype=torch.long, device=token_index_x.device)\
            .view(1, T1, 1) / T1 * 2 - 1)\
            .expand(N, T1, T2) #type: torch.Tensor
    token_index = torch.cat([
        token_index_x.unsqueeze(-1),
        token_index_y.unsqueeze(-1)
    ], dim=-1)
        
        # print('ti', token_index[0,-1,:,:])
        

    grid_input = F.pad(F.pad(x, pad=(0, 2), value=0), pad=(0, 1), value=masked_fill_value) if masked_fill_value is not None else x
    if grid_input.dtype != x.dtype:
        grid_input = grid_input.to(x.dtype)
    if token_index.dtype != x.dtype:
        token_index = token_index.to(x.dtype)
    # print(grid_input.shape, grid_input[0,0,0,-5:])
    
    return grid_sample_bf16(
        input=grid_input,
        grid=token_index,
        mode='nearest',
        align_corners=True,
        padding_mode='border'
    )

In [12]:
def softmax_bf16(input, dim=-1):
    input_dtype = input.dtype
    op_dtype = torch.float32 if torch.get_autocast_gpu_dtype() in [torch.bfloat16, torch.float16] else input_dtype
    if op_dtype != input_dtype:
        input = input.to(op_dtype)
    y = torch.softmax(input, dim=-1)
    if y.dtype != input_dtype:
        y = y.to(input_dtype)
    return y

In [13]:
import numpy as np
import cv2
from matplotlib import cm
from PIL import Image

ZOOM = 1

def convert_to_colormap(arr: np.ndarray):
    T, T = arr.shape
    arr_min, arr_max = np.min(arr), np.max(arr)
    normalized = (arr - arr_min) / (arr_max - arr_min + 1e-12)
    colormapped = cm.gist_earth(normalized)
    gamma = 0.2
    colormapped = (colormapped / np.max(colormapped)) ** gamma
    im = Image.fromarray((colormapped*255).astype(np.uint8))
    arr = np.asarray(im)[:, :, :3]
    arr = cv2.resize(arr, None, fx=ZOOM, fy=ZOOM, interpolation=cv2.INTER_NEAREST)
    border = np.ones((arr.shape[0]+2, arr.shape[1]+2, arr.shape[2]), dtype=np.uint8)
    border = border * 255
    border[1:-1, 1:-1, :] = arr
    return border

In [15]:
# visualize one tensor
import os
os.makedirs(f"./plots/poc/test_resizing/", exist_ok=True)

def visualize_one_tensor(tensor, name):
    img = convert_to_colormap(tensor.numpy())
    path = f"./plots/poc/test_resizing/{name}.png"
    cv2.imwrite(path, img)
    print('processed', path)

In [27]:
# estimated_attention_score = torch.randn(N, H, T, T_M)
estimated_attention_score = (torch.arange(T*T_M)/(T*T_M)).view(1, 1, T, T_M).expand(N, H, T, T_M)
estimated_attention_probs = softmax_bf16(estimated_attention_score, -1)
estimated_attention_probs_resized = resize_from_m_to_t(estimated_attention_probs.float(), masked_fill_value=0)

In [28]:
visualize_one_tensor(estimated_attention_score[0, 0, : , : ], "est_score")

processed ./plots/poc/test_resizing/est_score.png


In [29]:
visualize_one_tensor(estimated_attention_probs[0, 0, : , : ], "est_probs")

processed ./plots/poc/test_resizing/est_probs.png


In [30]:
visualize_one_tensor(estimated_attention_probs_resized[0, 0, : , : ], "est_probs_resized")

processed ./plots/poc/test_resizing/est_probs_resized.png


In [36]:
b = torch.arange(120).view(2, 3, 4, 5).float()
c = b.mean(-2, keepdim=True)
print(b.shape)
print(b)

torch.Size([2, 3, 4, 5])
tensor([[[[  0.,   1.,   2.,   3.,   4.],
          [  5.,   6.,   7.,   8.,   9.],
          [ 10.,  11.,  12.,  13.,  14.],
          [ 15.,  16.,  17.,  18.,  19.]],

         [[ 20.,  21.,  22.,  23.,  24.],
          [ 25.,  26.,  27.,  28.,  29.],
          [ 30.,  31.,  32.,  33.,  34.],
          [ 35.,  36.,  37.,  38.,  39.]],

         [[ 40.,  41.,  42.,  43.,  44.],
          [ 45.,  46.,  47.,  48.,  49.],
          [ 50.,  51.,  52.,  53.,  54.],
          [ 55.,  56.,  57.,  58.,  59.]]],


        [[[ 60.,  61.,  62.,  63.,  64.],
          [ 65.,  66.,  67.,  68.,  69.],
          [ 70.,  71.,  72.,  73.,  74.],
          [ 75.,  76.,  77.,  78.,  79.]],

         [[ 80.,  81.,  82.,  83.,  84.],
          [ 85.,  86.,  87.,  88.,  89.],
          [ 90.,  91.,  92.,  93.,  94.],
          [ 95.,  96.,  97.,  98.,  99.]],

         [[100., 101., 102., 103., 104.],
          [105., 106., 107., 108., 109.],
          [110., 111., 112., 113., 114.

In [37]:
print(c.shape)
print(c)

torch.Size([2, 3, 1, 5])
tensor([[[[  7.5000,   8.5000,   9.5000,  10.5000,  11.5000]],

         [[ 27.5000,  28.5000,  29.5000,  30.5000,  31.5000]],

         [[ 47.5000,  48.5000,  49.5000,  50.5000,  51.5000]]],


        [[[ 67.5000,  68.5000,  69.5000,  70.5000,  71.5000]],

         [[ 87.5000,  88.5000,  89.5000,  90.5000,  91.5000]],

         [[107.5000, 108.5000, 109.5000, 110.5000, 111.5000]]]])


In [38]:
a= torch.randn(N, H, T, T_M)
b = a*(attention_mask.transpose(-1,-2)>-1)

In [39]:
a

tensor([[[[ 2.2026e+00,  7.5384e-01, -1.8393e-01,  ...,  3.3712e-01,
           -1.2784e+00,  4.1680e-01],
          [ 9.0825e-01,  8.2721e-01, -1.6723e-01,  ..., -3.7292e-02,
            8.3455e-01,  7.1869e-01],
          [ 3.8644e-01,  3.5422e-01, -4.9646e-01,  ..., -2.1055e+00,
           -7.1569e-03, -1.0145e-01],
          ...,
          [-4.9632e-01, -1.6668e+00, -3.4051e-01,  ..., -9.1249e-01,
           -1.3586e+00, -8.1768e-01],
          [ 7.5356e-01,  1.3950e-01,  5.8772e-01,  ...,  8.2595e-01,
            6.2263e-01,  1.2877e+00],
          [ 6.9680e-01, -5.9088e-01, -7.3453e-01,  ..., -2.5833e-01,
           -1.1082e+00, -2.9648e-01]],

         [[-1.9430e-01,  8.2447e-01,  2.2134e+00,  ..., -1.9011e+00,
           -1.4680e+00,  1.1100e+00],
          [ 1.0054e+00,  1.1121e+00, -4.6547e-01,  ..., -1.7525e+00,
            1.4339e-02, -5.7059e-01],
          [ 4.5010e-01,  1.8084e-01, -2.8183e-01,  ..., -4.7081e-01,
           -1.8265e-01,  1.4009e+00],
          ...,
     

In [40]:
b

tensor([[[[ 2.2026e+00,  7.5384e-01, -1.8393e-01,  ...,  3.3712e-01,
           -1.2784e+00,  4.1680e-01],
          [ 9.0825e-01,  8.2721e-01, -1.6723e-01,  ..., -3.7292e-02,
            8.3455e-01,  7.1869e-01],
          [ 3.8644e-01,  3.5422e-01, -4.9646e-01,  ..., -2.1055e+00,
           -7.1569e-03, -1.0145e-01],
          ...,
          [-4.9632e-01, -1.6668e+00, -3.4051e-01,  ..., -9.1249e-01,
           -1.3586e+00, -8.1768e-01],
          [ 7.5356e-01,  1.3950e-01,  5.8772e-01,  ...,  8.2595e-01,
            6.2263e-01,  1.2877e+00],
          [ 6.9680e-01, -5.9088e-01, -7.3453e-01,  ..., -2.5833e-01,
           -1.1082e+00, -2.9648e-01]],

         [[-1.9430e-01,  8.2447e-01,  2.2134e+00,  ..., -1.9011e+00,
           -1.4680e+00,  1.1100e+00],
          [ 1.0054e+00,  1.1121e+00, -4.6547e-01,  ..., -1.7525e+00,
            1.4339e-02, -5.7059e-01],
          [ 4.5010e-01,  1.8084e-01, -2.8183e-01,  ..., -4.7081e-01,
           -1.8265e-01,  1.4009e+00],
          ...,
     

In [41]:
b[:,:,:,:]=0

In [42]:
b

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

In [43]:
a

tensor([[[[ 2.2026e+00,  7.5384e-01, -1.8393e-01,  ...,  3.3712e-01,
           -1.2784e+00,  4.1680e-01],
          [ 9.0825e-01,  8.2721e-01, -1.6723e-01,  ..., -3.7292e-02,
            8.3455e-01,  7.1869e-01],
          [ 3.8644e-01,  3.5422e-01, -4.9646e-01,  ..., -2.1055e+00,
           -7.1569e-03, -1.0145e-01],
          ...,
          [-4.9632e-01, -1.6668e+00, -3.4051e-01,  ..., -9.1249e-01,
           -1.3586e+00, -8.1768e-01],
          [ 7.5356e-01,  1.3950e-01,  5.8772e-01,  ...,  8.2595e-01,
            6.2263e-01,  1.2877e+00],
          [ 6.9680e-01, -5.9088e-01, -7.3453e-01,  ..., -2.5833e-01,
           -1.1082e+00, -2.9648e-01]],

         [[-1.9430e-01,  8.2447e-01,  2.2134e+00,  ..., -1.9011e+00,
           -1.4680e+00,  1.1100e+00],
          [ 1.0054e+00,  1.1121e+00, -4.6547e-01,  ..., -1.7525e+00,
            1.4339e-02, -5.7059e-01],
          [ 4.5010e-01,  1.8084e-01, -2.8183e-01,  ..., -4.7081e-01,
           -1.8265e-01,  1.4009e+00],
          ...,
     

In [48]:
attention_mask = torch.zeros(N, 1, 1, T)
attention_mask[:10,:,:,170:] = -1000
attention_mask[10:16,:,:,140:] = -1000
token_length = (attention_mask > -1).long().sum(-1).view(N, -1)
attention_mask

tensor([[[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        ...,


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]]])

In [53]:
token_length.shape

torch.Size([16, 1])

In [54]:
token_length

tensor([[170],
        [170],
        [170],
        [170],
        [170],
        [170],
        [170],
        [170],
        [170],
        [170],
        [140],
        [140],
        [140],
        [140],
        [140],
        [140]])

In [52]:
large_inx = torch.randn(N, 1, 12)
large_inx.shape

torch.Size([16, 1, 12])

In [None]:
large_inx.expand(token_length)

In [55]:
a = torch.ones(N, H, T, T)
print(a)
b = a + attention_mask

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 

In [63]:
print(b)
c = torch.softmax(b, dim=-1)
print(c)

tensor([[[[   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          ...,
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.]],

         [[   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          ...,
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.]],

         [[   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          [   1.,    1.,    1.,  ..., -999., -999., -999.],
          ...,
          [   1.,    1.,    1.,  ..., -999., -999.,

In [64]:
d = torch.ones(N, H, T, 128)
e = torch.matmul(c,d)
print(e)

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1

In [65]:
e.shape

torch.Size([16, 12, 203, 128])

In [81]:
a = torch.arange(120).view(2,3,4,5)

b = a.permute(0,2,1,3).contiguous().view(2,4, 3*5)

print(a.shape)
print(b.shape)

torch.Size([2, 3, 4, 5])
torch.Size([2, 4, 15])


In [82]:
a

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])

In [83]:
b

tensor([[[  0,   1,   2,   3,   4,  20,  21,  22,  23,  24,  40,  41,  42,  43,
           44],
         [  5,   6,   7,   8,   9,  25,  26,  27,  28,  29,  45,  46,  47,  48,
           49],
         [ 10,  11,  12,  13,  14,  30,  31,  32,  33,  34,  50,  51,  52,  53,
           54],
         [ 15,  16,  17,  18,  19,  35,  36,  37,  38,  39,  55,  56,  57,  58,
           59]],

        [[ 60,  61,  62,  63,  64,  80,  81,  82,  83,  84, 100, 101, 102, 103,
          104],
         [ 65,  66,  67,  68,  69,  85,  86,  87,  88,  89, 105, 106, 107, 108,
          109],
         [ 70,  71,  72,  73,  74,  90,  91,  92,  93,  94, 110, 111, 112, 113,
          114],
         [ 75,  76,  77,  78,  79,  95,  96,  97,  98,  99, 115, 116, 117, 118,
          119]]])

In [84]:
print(a.data_ptr())
print(b.data_ptr())


94232273023552
94232269433216


In [85]:
a is b

False

In [86]:
b[:,:,:]=0

In [87]:
a

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])

In [88]:
b

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [89]:
c = a.permute(0,2,1,3).contiguous()
print(c)

tensor([[[[  0,   1,   2,   3,   4],
          [ 20,  21,  22,  23,  24],
          [ 40,  41,  42,  43,  44]],

         [[  5,   6,   7,   8,   9],
          [ 25,  26,  27,  28,  29],
          [ 45,  46,  47,  48,  49]],

         [[ 10,  11,  12,  13,  14],
          [ 30,  31,  32,  33,  34],
          [ 50,  51,  52,  53,  54]],

         [[ 15,  16,  17,  18,  19],
          [ 35,  36,  37,  38,  39],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 80,  81,  82,  83,  84],
          [100, 101, 102, 103, 104]],

         [[ 65,  66,  67,  68,  69],
          [ 85,  86,  87,  88,  89],
          [105, 106, 107, 108, 109]],

         [[ 70,  71,  72,  73,  74],
          [ 90,  91,  92,  93,  94],
          [110, 111, 112, 113, 114]],

         [[ 75,  76,  77,  78,  79],
          [ 95,  96,  97,  98,  99],
          [115, 116, 117, 118, 119]]]])


In [98]:
print(a.data_ptr() == c.data_ptr())
print(b.data_ptr() == c.data_ptr())


False
False


In [93]:
c[:,:,:,:]=1

In [94]:
a

tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])

In [95]:
c

tensor([[[[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]]],


        [[[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]],

         [[1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1]]]])

In [96]:
b

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [99]:
a.shape

torch.Size([2, 3, 4, 5])

In [100]:
d = a.view(2,3,4*5)
print(d)

tensor([[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
           14,  15,  16,  17,  18,  19],
         [ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,
           34,  35,  36,  37,  38,  39],
         [ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,
           54,  55,  56,  57,  58,  59]],

        [[ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
           74,  75,  76,  77,  78,  79],
         [ 80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,
           94,  95,  96,  97,  98,  99],
         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
          114, 115, 116, 117, 118, 119]]])


In [101]:
d[:,:,:]=0

In [102]:
a

tensor([[[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]],


        [[[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]],

         [[0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]]]])

In [103]:
d

tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]])

In [104]:
a is d

False

In [105]:
a.data_ptr() == d.data_ptr()

True

In [111]:
p = torch.tensor([3,2,3,1,1])

torch.all(p>0).item()

True

In [113]:
l = torch.randn(2,3,4,5)

l1 = l.permute(0,2,1,3).contiguous().view(2,4,3*5)
l2 = l.permute(0,2,1,3).contiguous().view(2,4,3*5)



In [114]:
l1.data_ptr() == l2.data_ptr()

False

In [115]:
l1 is l2

False

In [117]:
attention_mask

tensor([[[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        ...,


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]],


        [[[    0.,     0.,     0.,  ..., -1000., -1000., -1000.]]]])

In [122]:
(attention_mask.transpose(-1, -2) < -1).int().shape

torch.Size([16, 1, 203, 1])

In [123]:
(attention_mask.transpose(-1, -2) < -1).int()

tensor([[[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]],


        [[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]],


        [[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]],


        ...,


        [[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]],


        [[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]],


        [[[0],
          [0],
          [0],
          ...,
          [1],
          [1],
          [1]]]], dtype=torch.int32)

In [118]:
a = torch.ones(N, H, T, T_M)

a.masked_fill_(
    mask=attention_mask.transpose(-1, -2) < -1,
    value=0
)

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

In [119]:
a.shape

torch.Size([16, 12, 203, 128])

In [None]:
p = torch.ones(2,3)

p.masked_fill_(dim=-1,i =)

In [125]:
p = torch.zeros(5)
p

tensor([0., 0., 0., 0., 0.])