# 0. Import required libraries (modules)

In [4]:
import torch
# Print the installed PyTorch version string (the recorded cell output is "2.8.0+cpu").
print(torch.__version__)

2.8.0+cpu


In [5]:
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
# Math
import math
# Huggingface libraries
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
# Pathlib
from pathlib import Path
# Typing
from typing import  Any
# Progress bars for loops
from tqdm import tqdm
# Importing library of warnings
import warnings


## 1. Input Embedding


English sentence:
The animal didn't cross the street because it was too tired


**Tokens**

The, animal, didn't, cross, the, street, because, it, was, too, tired


**Vocabulary** (lower-cased, so "The" and "the" share one ID)

the=0, animal=1, didn't=2, cross=3, street=4, because=5, it=6, was=7, too=8, tired=9

**Token IDs:** 0, 1, 2, 3, 0, 4, 5, 6, 7, 8, 9

In [6]:
class InputEmbedding(nn.Module):
    """Look up a learned vector for every token id and scale it by sqrt(d_model).

    Args:
        d_model: Size of each embedding vector.
        vocab_size: Number of distinct token ids the table can hold.
    """

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return embeddings for integer ids ``x``, with a trailing axis of size d_model."""
        # Scaling by sqrt(d_model) follows "Attention Is All You Need", section 3.4.
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

## Positional Encoding

In [7]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos / 10000^(2i/d_model))`` and odd columns
    hold ``cos(pos / 10000^(2i/d_model))``.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        seq_len: Number of positions precomputed; inputs must be no longer.
        dropout: Dropout probability applied after adding the encodings.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Table of encodings, one row per position: (seq_len, d_model).
        pe = torch.zeros(seq_len, d_model)
        # Positions 0..seq_len-1 as a column vector: (seq_len, 1).
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch axis so the table broadcasts over the batch: (1, seq_len, d_model).
        pe = pe.unsqueeze(0)
        # A buffer is saved in state_dict and follows .to(device), but is not a parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add encodings for the first ``x.shape[1]`` positions to ``x``.

        ``x`` is expected to be (batch, seq, d_model) with seq <= seq_len.
        """
        x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
        return self.dropout(x)

## Multi-Head Attention (self-attention)

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value are each passed through their own learned linear
    projection and split into ``h`` heads of size ``d_k = d_model // h``. The
    heads attend independently, are concatenated, and then pass through an
    output projection.

    After every forward pass, the attention weights (after softmax and dropout)
    are stored on ``self.scores`` with shape (batch, h, q_len, k_len).

    Args:
        d_model: Model (feature) dimension; must be divisible by ``h``.
        h: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"
        self.d_k = d_model // h
        # Projections are created in q, k, v, o order, so seeded initialisation is reproducible.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Any = None, dropout: Any = None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Positions where ``mask == 0`` get a logit of -1e9, which effectively
        excludes them. ``mask`` must be broadcastable to the logits' shape.

        Returns:
            A tuple ``(output, weights)``.
        """
        scale = math.sqrt(query.size(-1))
        logits = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)
        weights = logits.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, len, d_model) into (batch, h, len, d_k)."""
        return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Any = None):
        """Attend from ``query`` over ``key``/``value``; returns (batch, q_len, d_model)."""
        q_heads, k_heads, v_heads = (
            self._split_heads(projection(source))
            for projection, source in ((self.w_q, query), (self.w_k, key), (self.w_v, value))
        )
        context, self.scores = MultiHeadAttention.attention(q_heads, k_heads, v_heads, mask, self.dropout)
        # Concatenate the heads back into one feature axis: (batch, q_len, h * d_k).
        batch = context.shape[0]
        merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.w_o(merged)

## Layer Normalisation

In [9]:
class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + beta``, where ``mean`` and
    ``std`` are taken over the last axis. ``std`` is the unbiased
    (Bessel-corrected) estimate returned by ``Tensor.std``, and ``eps`` is added
    to the std rather than the variance. Results therefore differ slightly from
    ``torch.nn.LayerNorm``.

    Args:
        eps: Small constant added to the denominator to avoid division by zero.
        features: Length of the learnable ``alpha``/``beta`` vectors. The default
            of 1 shares one scalar gain and bias across all features (the
            original behaviour). Pass the model dimension to get a per-feature
            affine transform, as in standard layer normalisation.
    """

    def __init__(self, eps: float = 1e-6, features: int = 1):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised over its last axis, with the same shape."""
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.beta
    

## Feed Forward Network


This consists of two linear transformations with a ReLU activation in between (the implementation below also applies dropout after the ReLU).

$$\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)\,W_2 + b_2$$

The input and output dimensionality are both $d_{\text{model}} = 512$.

In [10]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: linear2(dropout(relu(linear1(x)))).

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (inner) feature size.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to d_ff, apply ReLU and dropout, then project back to d_model."""
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

## Residual Connection

In [11]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: returns ``x + dropout(sublayer(norm(x)))``.

    Args:
        size: Model dimension. Currently unused; the LayerNorm is built with
            its default arguments.
        dropout: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, size: int, dropout: float):
        super().__init__()
        self.norm = LayerNorm()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer: Any) -> torch.Tensor:
        """Normalise ``x``, run it through ``sublayer`` (a callable) and add the result back to ``x``."""
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

## Encoder

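The encoder is a stack of N = 6 identical layers.

Each layer has two sub-layers: multi-head self-attention, followed by a position-wise feed-forward network.

Every sub-layer produces outputs of dimension d_model = 512.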

In [13]:
class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer, then a feed-forward sub-layer.

    Each sub-layer is wrapped in a ResidualConnection.

    Args:
        self_attn_block: Attention module called as ``(q, k, v, mask)``; its
            ``d_model`` attribute sets the residual connections' size.
        feed_forward_block: Module applied position-wise in the second sub-layer.
        dropout: Dropout probability passed to each ResidualConnection.
    """

    def __init__(self, self_attn_block: MultiHeadAttention, feed_forward_block: FeedForward, dropout: float):
        super().__init__()
        self.self_attn_block = self_attn_block
        self.feed_forward_block = feed_forward_block
        # ResidualConnection takes (size, dropout); the model size comes from the attention block.
        d_model = self_attn_block.d_model
        self.residual_connection = nn.ModuleList(
            [ResidualConnection(d_model, dropout) for _ in range(2)]
        )

    def forward(self, x: torch.Tensor, src_mask: Any) -> torch.Tensor:
        """Apply masked self-attention, then the feed-forward network, each with a residual connection."""
        # Self-attention: queries, keys and values all come from the same input.
        x = self.residual_connection[0](x, lambda h: self.self_attn_block(h, h, h, src_mask))
        return self.residual_connection[1](x, self.feed_forward_block)

In [12]:
class Encoder(nn.Module):
    """Run the input through each encoder layer in turn, then a final LayerNorm.

    Args:
        layers: Encoder layers, each called as ``layer(x, mask)``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNorm()

    def forward(self, x, mask):
        """Apply every layer with the same ``mask`` and normalise the result."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)