In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MaskedSelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention with an optional mask.

    Queries, keys and values are produced from the same input by three
    bias-free linear layers of shape ``d_model -> d_model``.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        row: First of the two dimensions swapped to transpose the keys.
        column: Second of the two dimensions swapped to transpose the keys;
            the size of the keys along this dimension is also what the
            similarities are scaled by (its square root). With the defaults
            ``row=0, column=1`` the input is expected to be a 2-D matrix;
            pass e.g. ``row=1, column=2`` for an input with a leading batch
            dimension.
    """

    def __init__(self, d_model, row=0, column=1):
        super().__init__()

        self.wq = torch.nn.Linear(d_model, d_model, bias=False)
        self.wk = torch.nn.Linear(d_model, d_model, bias=False)
        self.wv = torch.nn.Linear(d_model, d_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)
        self.row_dim = row
        self.col_dim = column

    def forward(self, input_embedding, mask=None):
        """Compute the attention output for ``input_embedding``.

        Args:
            input_embedding: Tensor whose last dimension has size ``d_model``.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix. Positions that are ``True`` are hidden: their
                similarity is replaced by the most negative finite value of
                the working dtype, so they receive (practically) zero weight
                after the softmax.

        Returns:
            The softmax-weighted combination of the values, same shape as
            ``input_embedding``.
        """
        q = self.wq(input_embedding)
        k = self.wk(input_embedding)
        v = self.wv(input_embedding)

        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # A plain Python float is enough here; no need to wrap it in a tensor.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Use the dtype's own minimum instead of a hard-coded -1e9: the
            # latter is not representable in float16, while the result is the
            # same for float32/float64 (masked weights underflow to zero and
            # a fully masked row still yields uniform weights, not NaN).
            fill_value = torch.finfo(scaled_sims.dtype).min
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=fill_value)

        sims_score = self.softmax(scaled_sims)

        return torch.matmul(sims_score, v)

In [3]:
## set the seed for the random number generator so the layer weights are reproducible
torch.manual_seed(42)

maskedSelfAttention = MaskedSelfAttention(d_model=2)

## create a matrix of token encodings (one row per token)...
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

## build the mask from the number of tokens instead of hard-coding its size:
## True marks the positions above the diagonal, which get masked out
mask_dim = encodings_matrix.size(0)
mask = torch.tril(torch.ones(mask_dim, mask_dim))
mask = mask == 0
print(mask)

output = maskedSelfAttention(encodings_matrix, mask)
print(output)

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)


In [6]:
## the (transposed) weight matrices that create the queries, keys and values;
## only the last expression of a cell is displayed, so they are given names
## here -- put one of the names on its own in a cell to inspect it
query_weights = maskedSelfAttention.wq.weight.transpose(0, 1)
key_weights = maskedSelfAttention.wk.weight.transpose(0, 1)
value_weights = maskedSelfAttention.wv.weight.transpose(0, 1)

## calculate the queries, keys and values once and reuse them below
q = maskedSelfAttention.wq(encodings_matrix)
k = maskedSelfAttention.wk(encodings_matrix)
v = maskedSelfAttention.wv(encodings_matrix)

## similarity of every query with every key, scaled by the square root of
## the number of key columns (rather than a hard-coded 2)
sims = torch.matmul(q, k.transpose(dim0=0, dim1=1))
scaled_sims = sims / (k.size(1) ** 0.5)

## hide the positions where the mask is True before applying the softmax
masked_scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

attention_percents = F.softmax(masked_scaled_sims, dim=1)
torch.matmul(attention_percents, v)


tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)