# Coding Attention in PyTorch
This notebook implements self-attention, encoder-decoder attention, masked self-attention, and multi-head attention in PyTorch.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Attention class

In [2]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Args:
        dmodel: Size of each token encoding. The query, key and value
            projections are bias-free linear layers mapping dmodel -> dmodel.
        drow: One of the two dimensions of the projected keys that are swapped
            before the query/key matrix product.
        dcol: The other swapped dimension; the softmax is also taken over it.

    The defaults (drow=0, dcol=1) fit unbatched inputs of shape
    (seq_len, dmodel). For batched inputs of shape (batch, seq_len, dmodel)
    pass drow=1, dcol=2.
    """

    def __init__(self, dmodel=2, drow=0, dcol=1):
        super().__init__()

        # Creation order matters for reproducibility: each layer consumes
        # random numbers from the global generator when it is initialised.
        self.wq = nn.Linear(dmodel, dmodel, bias=False)
        self.wk = nn.Linear(dmodel, dmodel, bias=False)
        self.wv = nn.Linear(dmodel, dmodel, bias=False)

        self.drow = drow
        self.dcol = dcol

    def forward(self, encodingq, encodingk, encodingv, mask=None):
        """Project the inputs and return the attention-weighted values.

        Args:
            encodingq: Encodings used to build the queries.
            encodingk: Encodings used to build the keys.
            encodingv: Encodings used to build the values.
            mask: Optional boolean tensor broadcastable to the similarity
                matrix; positions that are True are excluded from attention.

        Returns:
            The product of the softmaxed similarities and the projected values.
        """
        q = self.wq(encodingq)
        k = self.wk(encodingk)
        v = self.wv(encodingv)

        # A plain Python float keeps the full precision of the scores' dtype
        # (a float32 scale tensor would degrade float64 inputs) and avoids
        # allocating a tensor on every call.
        scale = k.size(-1) ** 0.5
        sims = torch.matmul(q, k.transpose(self.drow, self.dcol)) / scale

        if mask is not None:
            # Use the most negative finite value of the scores' dtype: a fixed
            # -1e9 cannot be represented in half precision. A fully masked row
            # still yields a uniform distribution rather than NaN.
            sims = sims.masked_fill(mask, torch.finfo(sims.dtype).min)

        attention_weights = F.softmax(sims, dim=self.dcol)

        return torch.matmul(attention_weights, v)


In [3]:
# Seed the global generator before building the module, so that the randomly
# initialised projection weights are the same on every run.
torch.manual_seed(42)

attention = Attention()

# Token encodings as a 3 x 2 matrix (the 2 columns match the default dmodel).
encodings_matrix = torch.tensor(
    [
        [1.16, 0.23],
        [0.57, 1.36],
        [4.41, -2.16],
    ]
)

### Calculate self-attention

In [4]:
# Self-attention: the same encodings supply the queries, keys and values.
output = attention(
    encodingq=encodings_matrix,
    encodingk=encodings_matrix,
    encodingv=encodings_matrix,
)
print(output)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)


### Calculate encoder-decoder attention

In [5]:
# Encoder side: three rows, assumed to be already-processed source sentence
# embeddings.
encoder_outputs = torch.tensor(
    [
        [1.0, 2.0],  # Encoded word 1
        [3.0, 4.0],  # Encoded word 2
        [5.0, 6.0],  # Encoded word 3
    ]
)

# Decoder side: only two rows, so the query sequence length differs from the
# key/value sequence length.
decoder_states = torch.tensor(
    [
        [0.5, 0.5],  # Current decoder state 1
        [1.5, 1.5],  # Current decoder state 2
    ]
)

# Queries come from the decoder; keys and values both come from the encoder.
enc_dec_attention_output = attention(
    encodingq=decoder_states,
    encodingk=encoder_outputs,
    encodingv=encoder_outputs,
)

print("Encoder outputs shape:", encoder_outputs.shape)
print("Decoder states shape:", decoder_states.shape)
print("Encoder-Decoder attention output shape:", enc_dec_attention_output.shape)
print(enc_dec_attention_output)

Encoder outputs shape: torch.Size([3, 2])
Decoder states shape: torch.Size([2, 2])
Encoder-Decoder attention output shape: torch.Size([2, 2])
tensor([[-0.2030,  2.3877],
        [-0.1989,  2.4166]], grad_fn=<MmBackward0>)


### Calculate masked self-attention

In [6]:
# Causal mask: True strictly above the diagonal, i.e. at the positions that
# are to be hidden from attention (each row only keeps columns up to its own).
n_tokens = encodings_matrix.size(0)
mask = torch.ones(n_tokens, n_tokens).triu(diagonal=1).bool()

output = attention(
    encodingq=encodings_matrix,
    encodingk=encodings_matrix,
    encodingv=encodings_matrix,
    mask=mask,
)
print(output)

tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)


## Multi-Head Attention

In [7]:
class MultiHeadAttention(nn.Module):
    """Runs several independent Attention heads and concatenates their outputs.

    Args:
        dmodel: Encoding size handed to every head.
        drow: Passed through to every head's ``drow``.
        dcol: Passed through to every head's ``dcol``.
        nhead: Number of heads to create; must be at least 1.

    Raises:
        ValueError: If ``nhead`` is smaller than 1.
    """

    def __init__(self, dmodel=2, drow=0, dcol=1, nhead=1):
        super().__init__()

        # Fail fast: with no heads there would be nothing to concatenate in
        # forward().
        if nhead < 1:
            raise ValueError(f"nhead must be at least 1, got {nhead}")

        self.nhead = nhead

        self.heads = nn.ModuleList([Attention(dmodel, drow, dcol) for _ in range(self.nhead)])

    def forward(self, encodingq, encodingk, encodingv, mask=None):
        """Apply every head to the same inputs and join the results.

        Args:
            encodingq: Encodings used to build the queries.
            encodingk: Encodings used to build the keys.
            encodingv: Encodings used to build the values.
            mask: Optional mask, forwarded unchanged to every head.

        Returns:
            The head outputs concatenated along the last dimension.
        """
        # The mask has to reach each head, otherwise masked positions would
        # silently take part in attention.
        head_outputs = [
            head(encodingq, encodingk, encodingv, mask) for head in self.heads
        ]

        return torch.cat(head_outputs, dim=-1)

Verify that we can correctly calculate attention with a single head.

In [8]:
# Re-seed before building the module so that its freshly initialised weights
# are reproducible.
torch.manual_seed(42)

singlehead = MultiHeadAttention(nhead=1, dmodel=2, drow=0, dcol=1)
output = singlehead(
    encodingq=encodings_matrix,
    encodingk=encodings_matrix,
    encodingv=encodings_matrix,
)

print(output)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)


Calculate attention with multiple heads.

In [9]:
# Re-seed so the two heads are initialised reproducibly.
torch.manual_seed(42)

multihead = MultiHeadAttention(nhead=2, dmodel=2, drow=0, dcol=1)

# Each head's output is concatenated along the last dimension.
multihead_output = multihead(
    encodingq=encodings_matrix,
    encodingk=encodings_matrix,
    encodingv=encodings_matrix,
)
print(multihead_output)

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)
