In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [75]:
class SelfAttention(nn.Module):
    def __init__(
            self,
            d_model=2,
            row_dim=0,
            col_dim=1
    ):
        super().__init__()
        
        self.d_model = d_model
        self.row_dim = row_dim
        self.col_dim = col_dim
        
        
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)


    def forward(self, token_embeddings, mask=None):

        Q = self.W_q(token_embeddings)
        K = self.W_k(token_embeddings)
        V = self.W_v(token_embeddings)

        # the dot products between each token and others
        similarities = torch.matmul(Q, K.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_similarites = similarities / (self.d_model ** 0.5)

        if mask is not None:
            scaled_similarites = scaled_similarites.masked_fill(mask=mask, value=1e9)
            
        # taking softmax as percentages of how important each word to each one
        attention_percentages = F.softmax(scaled_similarites)

        # Scalling the values by their associated percentages
        attention_values = torch.matmul(attention_percentages, V)


        return attention_values


In [79]:
model = SelfAttention(d_model=2)

encodings_matrix = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)
mask = torch.tril(torch.ones((3, 3)))
mask = mask == 0

model(encodings_matrix, mask)

  attention_percentages = F.softmax(scaled_similarites)


tensor([[1.7596, 1.4775],
        [3.8695, 2.4246],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

In [59]:
model.W_k.weight

Parameter containing:
tensor([[ 0.0900,  0.4665,  0.0631],
        [-0.1821,  0.1551, -0.1566],
        [ 0.2430,  0.5155,  0.3337]], requires_grad=True)