In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SelfAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0,col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_k = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_v = nn.Linear(in_features = d_model,out_features = d_model, bias = False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        sims = torch.matmul(q, k.transpose(dim0 = self.row_dim, dim1 = self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        attention_percents = F.softmax(scaled_sims, dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)
        return attention_scores


In [3]:
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

## create a basic self-attention ojbect
selfAttention = SelfAttention(d_model=2, row_dim=0,col_dim=1)

## calculate basic attention for the token encodings
selfAttention(encodings_matrix)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [4]:
selfAttention.W_q.weight.transpose(0, 1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [5]:
selfAttention.W_k.weight.transpose(0, 1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [6]:
selfAttention.W_v.weight.transpose(0, 1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [7]:
selfAttention.W_q(encodings_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [8]:
selfAttention.W_k(encodings_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [9]:
class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_k = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_v = nn.Linear(in_features = d_model,out_features = d_model, bias = False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings, mask = None):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        sims = torch.matmul(q, k.transpose(dim0 = self.row_dim, dim1 = self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask,value = -1e9)
        
        attention_percents = F.softmax(scaled_sims, dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)
        return attention_scores


In [10]:
maskedSelfAttention = MaskedSelfAttention(d_model=2,
                               row_dim=0,
                               col_dim=1)

In [11]:
mask = torch.tril(torch.ones(3, 3))
mask=mask==0

In [12]:
maskedSelfAttention(encodings_matrix, mask)

tensor([[-0.3970, -0.2253],
        [-0.3488,  0.1166],
        [-0.7190, -0.8447]], grad_fn=<MmBackward0>)

maskedSelfAttention(encodings_matrix, mask)

In [13]:
class Attention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_k = nn.Linear(in_features = d_model,out_features = d_model, bias = False)
        self.W_v = nn.Linear(in_features = d_model,out_features = d_model, bias = False)

        self.row_dim = row_dim
        self.col_dim = col_dim
    def forward(self, encodings_q, encodings_v, encodings_k, mask = None):
        #different encodings here
        q = self.W_q(encodings_q)
        v = self.W_v(encodings_v)
        k = self.W_k(encodings_k)

        sims = torch.matmul(q, k.transpose(dim0 = self.row_dim, dim1 = self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask,value = -1e9)
        
        attention_percents = F.softmax(scaled_sims, dim = self.col_dim)
        attention_scores = torch.matmul(attention_percents , v)
        return attention_scores

In [14]:
## create matrices of token encodings...
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
attention = Attention(d_model=2,
                      row_dim=0,
                      col_dim=1)

## calculate encoder-decoder attention


In [15]:
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=2,row_dim=0,col_dim=1,num_heads = 1):
        super().__init__()
        self.heads = nn.ModuleList([Attention(d_model,row_dim,col_dim) for _ in range(num_heads)])
        self.col_dim = col_dim

    def forward(self, encodings_q, encodings_k, encodings_v):

        return torch.cat([head(encodings_q, encodings_k, encodings_v) for head in self.heads], dim = self.col_dim)

In [21]:
## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=1)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)