In [1]:
import math
import torch
import torch.nn as nn

In [2]:
attn_shape = (1, 5, 5)
subsequent_mask = torch.tril(torch.ones(attn_shape)).type(
        torch.uint8
    )
subsequent_mask

tensor([[[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1]]], dtype=torch.uint8)

In [30]:
ids = torch.LongTensor([1, 3, 58, 0, 0])

In [33]:
mask = (ids != 0).type(torch.uint8)
# mask = mask.unsqueeze(-2).repeat_interleave(mask.shape[-1], dim=-2)
print(mask.shape)
print((mask & subsequent_mask)[0].shape)

torch.Size([5])
torch.Size([5, 5])


In [3]:
shape = (8, 29, 512)
query = torch.rand((8, 12, 512))
key = torch.rand(shape)
print(f"query.shape: {query.shape}")

query.shape: torch.Size([8, 12, 512])


In [4]:
num_heads = 8

query_r = query.view(shape[0], -1, num_heads, int(shape[2] / num_heads)).transpose(1, 2)
key_r = key.view(shape[0], -1, num_heads, int(shape[2] / num_heads)).transpose(1, 2)
print(f"query.shape: {query_r.shape}")
print(f"key.shape: {key_r.shape}")

query.shape: torch.Size([8, 8, 12, 64])
key.shape: torch.Size([8, 8, 29, 64])


In [5]:
scores = torch.matmul(query_r, key_r.transpose(-2, -1)) / math.sqrt(query_r.size(-1))

print(f"scores.shape: {scores.shape}")

scores.shape: torch.Size([8, 8, 12, 29])


In [8]:
mask_len = query.size(-2)
ones = torch.ones((1, mask_len, mask_len)).to(query.device)
mask = torch.triu(ones, diagonal=1)
print(f"mask.shape: {mask.shape}")

mask.shape: torch.Size([1, 29, 29])


In [9]:
scores = scores.masked_fill(mask == 0, -1e9)
print(scores[0][0])

tensor([[-1.0000e+09,  2.5214e+00,  2.0869e+00,  2.0053e+00,  2.0156e+00,
          2.0673e+00,  1.8962e+00,  2.2199e+00,  1.9403e+00,  2.4110e+00,
          2.0144e+00,  2.1573e+00,  2.2714e+00,  1.9761e+00,  2.2125e+00,
          2.0822e+00,  2.2391e+00,  2.2730e+00,  2.1010e+00,  2.0375e+00,
          1.8319e+00,  1.7566e+00,  2.2187e+00,  2.3758e+00,  2.0882e+00,
          2.2855e+00,  1.9832e+00,  2.3045e+00,  2.0299e+00],
        [-1.0000e+09, -1.0000e+09,  1.9065e+00,  2.0377e+00,  1.9656e+00,
          1.9613e+00,  1.9116e+00,  2.2074e+00,  1.9005e+00,  2.2270e+00,
          1.8258e+00,  1.9982e+00,  2.0762e+00,  1.9272e+00,  2.1602e+00,
          2.1139e+00,  2.2131e+00,  2.2959e+00,  2.0665e+00,  1.7861e+00,
          1.7493e+00,  1.6235e+00,  2.0782e+00,  2.1826e+00,  2.1153e+00,
          2.3264e+00,  1.9567e+00,  2.1610e+00,  2.1155e+00],
        [-1.0000e+09, -1.0000e+09, -1.0000e+09,  2.0674e+00,  2.0131e+00,
          1.9405e+00,  1.7704e+00,  2.1368e+00,  1.8895e+00,  