In [107]:
import math
from torch import nn
import torch
from torch.nn.functional import softmax

In [257]:
# Causal mask for a batch of 8 sequences of length 10: True marks the future
# positions (strictly above the diagonal) that must be hidden.
# Use an explicit shape: `scores` is only created in a later cell, so reading
# `scores.shape` here fails on a fresh kernel (Restart & Run All).
attn_mask = torch.tril(torch.ones(8, 10, 10)) == 0
attn_mask[0]

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [100]:
# Causal mask for a 10-token sequence: invert the lower triangle (positions a
# token may see) so that True marks the future positions to hide.
mask = ~torch.tril(torch.ones(10, 10, dtype=torch.bool))
mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [101]:
# Seed first so these random toy scores (and every output derived from them in
# later cells) are reproducible on Restart & Run All.
torch.manual_seed(0)
scores = torch.rand(8, 10, 10)  # (batch, query position, key position)
scores[0, 0]

tensor([0.9756, 0.6373, 0.1157, 0.6557, 0.3390, 0.7457, 0.1824, 0.6997, 0.0744,
        0.5098])

In [102]:
# Hide future positions: every True entry of `mask` becomes -inf, which turns
# into exactly 0 weight after softmax.
scores = scores.masked_fill(mask, float("-inf"))
scores[0, 0]

tensor([0.9756,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
          -inf])

In [110]:
# Scale by sqrt(10), the last-dimension size of these toy scores.
# NOTE(review): not idempotent — re-running this cell divides `scores` again.
scores = torch.div(scores, math.sqrt(10))

In [111]:
# Attention weights for the first batch element; each row sums to 1 over the last dim.
scores.softmax(dim=-1)[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4839, 0.5161, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3550, 0.3220, 0.3230, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2569, 0.2532, 0.2565, 0.2334, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2004, 0.1777, 0.1813, 0.2284, 0.2123, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1785, 0.1769, 0.1576, 0.1699, 0.1455, 0.1716, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1413, 0.1264, 0.1630, 0.1274, 0.1551, 0.1339, 0.1529, 0.0000, 0.0000,
         0.0000],
        [0.1234, 0.1062, 0.1086, 0.1301, 0.1389, 0.1317, 0.1355, 0.1257, 0.0000,
         0.0000],
        [0.1038, 0.1070, 0.1077, 0.1278, 0.1197, 0.1017, 0.1019, 0.1252, 0.1051,
         0.0000],
        [0.1150, 0.0986, 0.1093, 0.1014, 0.0903, 0.0888, 0.1039, 0.0864, 0.1037,
         0.1026]])

In [138]:
def get_attn_mask(shape, device=None):
    """Return a boolean causal mask that is True strictly above the main diagonal.

    The triangle is taken over the last two dimensions of ``shape``, so leading
    (batch) dimensions are allowed. True entries mark future positions that
    should be filled with ``-inf`` before the softmax.

    Args:
        shape: tuple giving the mask shape; must have at least two dimensions.
        device: optional torch device for the mask (e.g. ``scores.device``), so it
            can be combined with scores on the same device without a transfer.
            Defaults to PyTorch's default device, matching the previous behavior.

    Returns:
        A ``torch.bool`` tensor of the given shape.
    """
    # Build directly in bool and invert the lower triangle; equivalent to
    # `torch.tril(torch.ones(shape)) == 0` without the float intermediate.
    return ~torch.tril(torch.ones(shape, dtype=torch.bool, device=device))

get_attn_mask(shape=(10, 10))

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])

In [283]:
import math

class SingleHeadAttention(nn.Module):
    """Causal (masked) single-head scaled dot-product self-attention.

    Projects the input into key/query/value spaces of size ``attention_dim``,
    scores every query against every key, hides future positions so each
    position attends only to itself and earlier positions, and returns the
    softmax-weighted sum of the values.

    Args:
        embedding_dim: size of the last dimension of the input ``X``.
        attention_dim: size of the key/query/value projections and of the output.

    Note:
        ``__init__`` calls ``torch.manual_seed(0)`` so the projection weights are
        reproducible. This also resets PyTorch's *global* RNG for all code
        that runs afterwards.
    """

    def __init__(self, embedding_dim, attention_dim):
        super().__init__()
        # Deterministic weights for reproducible outputs (side effect: reseeds
        # the global torch RNG).
        torch.manual_seed(0)
        self.attention_dim = attention_dim

        self.key_proj = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.query_proj = nn.Linear(embedding_dim, attention_dim, bias=False)
        self.value_proj = nn.Linear(embedding_dim, attention_dim, bias=False)

    def forward(self, X):
        """Apply causal self-attention to ``X``.

        Args:
            X: tensor whose last two dims are (seq_len, embedding_dim). Leading
               batch dims are optional: both (B, T, E) and unbatched (T, E) work.

        Returns:
            Tensor of shape (..., seq_len, attention_dim), rounded to 4 decimals.
        """
        key = self.key_proj(X)
        query = self.query_proj(X)
        value = self.value_proj(X)

        # Swap only the last two dims so batched and unbatched inputs both work.
        scores = query @ key.transpose(-2, -1)
        # Standard scaled dot-product attention: divide by sqrt(d_k).
        scores = scores / math.sqrt(self.attention_dim)

        # Causal mask, True strictly above the diagonal (future positions). A single
        # (T_q, T_k) mask broadcasts over any batch dims, and creating it on
        # scores.device avoids a CPU/GPU device mismatch in masked_fill.
        attn_mask = torch.triu(
            torch.ones(scores.shape[-2:], dtype=torch.bool, device=scores.device),
            diagonal=1,
        )
        # -inf scores become exactly 0 weight after the softmax.
        scores = torch.masked_fill(scores, attn_mask, float("-inf"))

        # Row-wise attention weights over the key positions.
        weights = softmax(scores, dim=-1)

        # Weighted sum of the values.
        out = weights @ value

        # NOTE: rounding is for reproducible, comparable output; torch.round has a
        # zero gradient, so do not train through this module as-is.
        return torch.round(out, decimals=4)

In [284]:
embedding_dim, attention_dim = 1024, 100
# Constructing the module also reseeds torch's global RNG (manual_seed(0) in __init__).
self_attention = SingleHeadAttention(embedding_dim=embedding_dim, attention_dim=attention_dim)

In [285]:
# Batch of 8 sequences of 10 tokens. Tie the feature size to `embedding_dim`
# rather than repeating the magic number 1024 from the config cell.
input_ = torch.randn(8, 10, embedding_dim)

In [286]:
out = self_attention(input_)
# Expected (batch, seq_len, attention_dim).
out.size()

torch.Size([8, 10, 100])

In [287]:
# Tiny hand-written example: 2 sequences x 2 tokens x 2 features.
embedding_dim, attention_dim = 2, 3
embedded = torch.tensor(
    [
        [[-1.4381, 0.1232], [-0.1080, 0.3458]],
        [[0.1929, -0.8567], [-0.1160, 1.2547]],
    ]
)

In [288]:
embedded.size()

torch.Size([2, 2, 2])

In [289]:
self_attention = SingleHeadAttention(embedding_dim, attention_dim)
out = self_attention(embedded)
# Bare last expression: Jupyter renders it as the cell output, no print() needed.
out[0]

tensor([[ 0.9138,  0.4224, -0.3497],
        [ 0.4183,  0.2337, -0.1193]], grad_fn=<SelectBackward0>)
