Codes were extracted from the exerpt by Arjun Sarkar

https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

This is a notebook which pretty much explains the basics of transformers and what lies underneath the hood. 

Tbh , you should think of transformers like a cell with many mechanisms inside. 

The preclusion for this is that you must have at least SOME knowledge on the basics of neural networks because this notebook wouldn't cover everything

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim 
import torch.utils.data as data
import math
import copy


### Multi - Head Attention Mechanism

In short, the multi-head attention mechanism is the heart of the entire transformer model.

By Copilot:

Multi-head attention is a mechanism in Transformer models that allows the model to focus on different parts of the input sequence simultaneously. It improves the model's ability to capture various aspects of the data by running multiple attention mechanisms in parallel. Each "head" in this context refers to a separate attention mechanism, with each focusing on different parts of the input to gather diverse information. The outputs of these heads are then combined and processed further. This approach enhances the model's ability to understand complex relationships in the data, leading to better performance on tasks like translation, text summarization, and more.

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, heads):
        super(MultiHeadAttention,self).__init__()
        assert d_model % heads == 0, "d model must be divisible by num heads"

        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads

        self.W_q = nn.Linear(d_model, d_model) # Query
        self.W_k = nn.Linear(d_model, d_model) # Key
        self.W_v = nn.Linear(d_model, d_model) # Values
        self.W_o = nn.Linear(d_model, d_model)
    
    def scaled_dot_product_attention(self, Q, K, V, mask = None):
        attn_scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k) # Attention score is calculated before being reshaped to 2 x 1 matrix 
        if mask  is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = torch.softmax(attn_scores, dim = -1) 
        output = torch.matmul(attn_probs, V) # 
        return output
    

    def split_heads(self,x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.heads, self.d_k).transpose(1,2)

    def combine_heads(self,x): # Area where the output is reshaped and combines output from all the heads
        print(x.size())
        batch_size,_, seq_length,d_k = x.size()
        return x.transpose(1,2).contiguous().view(batch_size, seq_length, self.d_model)
    
    def forward(self, Q, K, V, mask = None): # => allows model to focus on some aspets of i/p sequencing
        
        Q = self.split_heads(self.W_q(Q)) #Splits the heads
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q,K,V, mask) # Attention score is calculated 
        output = self.W_o(self.combine_heads(attn_output))
        return output

FeedForward

In [3]:
class PositionalFF(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionalFF, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

Positional Encoding (For injecting position information of each token into input seq)

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype = torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

    

Encoder : Multi-Head Attention => Position-wise FFN => 2 Layer Normalisation 

In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionalFF(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask):
        attn_output = self.self_attn(x,x,x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

Decoder: Self Attention => Cross Attention => FF => Linear => SoftMax

In [6]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionalFF(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x,x,x, tgt_mask) # Calculates the self attention scores
        x = self.norm1(x + self.dropout(attn_output)) # Normalisation
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask) #X taken from attn_output 1 & Encoder Output 1 taken from Encoder => producing the output
        x = self.norm2(x + self.dropout(attn_output)) #Normalisation
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

Transformer

In [7]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt): # Creates binary mask for src & Target seq to ignore padding tokens => preventing decoder from attending to future tokens
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)

        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask
    
    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)
        
        dec_output = tgt_embedded

        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output
    

Training Time!

In [12]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size,tgt_vocab_size,d_model,num_heads,num_layers,d_ff,max_seq_length,dropout)

src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))

criterion = nn.CrossEntropyLoss(ignore_index =0)
optimizer = optim.Adam(transformer.parameters(), lr = 0.0001, betas = (0.9,0.98), eps = 1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch + 1} , Loss : {loss.item()}")

torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
Epoch: 1 , Loss : 8.694376945495605
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 100, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])
torch.Size([64, 8, 99, 64])


KeyboardInterrupt: 

Testing The Model... I forgot to test-train-split the data

In [None]:
def test_model(model, test_loader, criterion):
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient computation
        for src_data, tgt_data in test_loader:
            output = model(src_data, tgt_data[:, :-1])
            loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
            total_loss += loss.item()
            _, predicted = torch.max(output.data, -1)
            correct = (predicted == tgt_data[:, 1:]).float()  # Assuming your target is categorical
            total_correct += correct.sum().item()
            total_samples += tgt_data.size(0)

    avg_loss = total_loss / len(test_loader)
    accuracy = total_correct / total_samples * 100
    print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

# Assuming you have a model, test_loader, and criterion defined
test_model(transformer, test_loader, criterion)

tensor([[[[7.6205e-01, 9.9327e-01, 5.0873e-01,  ..., 7.6584e-01,
           7.2699e-03, 6.6235e-01],
          [1.7281e-01, 6.1966e-01, 4.8188e-01,  ..., 6.8336e-01,
           2.1861e-01, 5.6384e-01],
          [3.7274e-01, 7.9061e-01, 9.2151e-01,  ..., 4.0698e-01,
           1.3234e-01, 3.4084e-01],
          ...,
          [7.7362e-01, 1.3355e-02, 8.7557e-01,  ..., 1.1168e-01,
           6.1797e-01, 5.5582e-01],
          [9.2714e-01, 1.6533e-01, 1.5321e-01,  ..., 2.7326e-01,
           3.5379e-01, 3.8937e-01],
          [5.9837e-01, 1.7890e-01, 1.7711e-01,  ..., 2.5047e-01,
           7.7111e-01, 9.6134e-01]],

         [[1.0337e-01, 2.9498e-02, 7.0527e-01,  ..., 5.9516e-01,
           1.1848e-01, 8.2698e-01],
          [9.2683e-01, 1.6115e-01, 6.4840e-01,  ..., 9.4216e-01,
           6.1256e-01, 4.0105e-01],
          [1.3863e-01, 8.3464e-01, 6.8435e-01,  ..., 7.5664e-01,
           8.2617e-01, 6.5980e-01],
          ...,
          [3.8531e-01, 4.0938e-02, 3.9644e-01,  ..., 3.7407