# Understanding the Transformer Architecture for Language Translation

This notebook implements a neural machine translation system based on the paper "Attention Is All You Need" (Vaswani et al., 2017). The Transformer architecture introduced in this paper revolutionized natural language processing by eliminating the need for recurrent or convolutional neural networks, instead relying entirely on attention mechanisms to capture relationships between words.

Our implementation focuses on three key innovations from the paper:

1. **Multi-Head Self-Attention**: Allowing the model to simultaneously attend to information from different representation subspaces
2. **Encoder-Decoder Architecture**: Processing the input sequence and generating the output sequence using stacked attention layers
3. **Positional Encoding**: Incorporating sequence order information without recurrence

Through this project, we'll:
- Implement the core components of the Transformer architecture
- Train a model for English-French translation

## Disclaimer

The purpose of this notebook is to implement the Transformer architecture for language translation, explain how the model works, and walk through the PyTorch implementation. It is not designed for training your own GPT model (although it could be modified to do so). If you want to train the full model, please see the training details in the following repository: https://github.com/bluehood/Transformer-Translation.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.nn import functional as F

from tqdm import tqdm
import json
import csv
import re
from pathlib import Path
import numpy as np
import math
from collections import defaultdict

# Transformer Architecture
The Transformer consists of an encoder that processes the input sequence and a decoder that generates the translation.

The Transformer architecture represents a fundamental shift in how we approach sequence-to-sequence tasks like translation. Instead of processing text one word at a time, as recurrent models do, it looks at the entire sequence at once.

We'll give a brief overview of the structure of the Transformer. For a more detailed discussion, please see our implementation of a GPT: https://colab.research.google.com/github/bluehood/GPT-Implementation/blob/main/GPT_Implementation.ipynb.



## Tokenisation
The words in our languages need to be converted to a numerical representation before they are fed into the model. The most natural choice is word-level tokenisation, although it is not the most performant one (see the GPT implementation for a more thorough discussion).

We define a basic word-level tokeniser which will be used for both languages.

In [None]:
def basic_tokenize(text):
    # Convert to lowercase for consistency
    text = text.lower()

    # Add spaces around punctuation so they become separate tokens
    text = re.sub(r'([.,!?;])', r' \1 ', text)
    text = re.sub(r'(["\'])', r' \1 ', text)

    # Remove non-alphanumeric characters (except allowed punctuation)
    text = re.sub(r'[^a-z0-9.,!?;\'\" ]', ' ', text)

    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()

    # Split into individual tokens
    tokens = text.split()
    return tokens

This function performs several crucial preprocessing steps:

- Converting to lowercase: This reduces vocabulary size by treating "Word" and "word" as the same token
- Handling punctuation: By adding spaces around punctuation marks, we treat them as separate tokens, which helps the model learn their grammatical significance
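- Cleaning text: Removing unusual characters helps standardize the input. Note that the pattern as written keeps only `a-z` and `0-9`, so accented letters such as `é` are also replaced by spaces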
- Splitting into tokens: The final step creates a list of individual words and punctuation marks
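
As a quick illustration, the same steps can be run on a short sentence. The block below repeats the tokeniser in condensed form so that it runs on its own; the sample sentence is made up for this example.

```python
import re

def basic_tokenize(text):
    # Condensed version of the tokeniser above so this example runs on its own
    text = text.lower()
    text = re.sub(r'([.,!?;])', r' \1 ', text)        # isolate sentence punctuation
    text = re.sub(r'(["\'])', r' \1 ', text)          # isolate quotes and apostrophes
    text = re.sub(r'[^a-z0-9.,!?;\'\" ]', ' ', text)  # replace everything else with a space
    return re.sub(r'\s+', ' ', text).strip().split()

tokens = basic_tokenize("Hello, World! It's 9am.")
print(tokens)
# ['hello', ',', 'world', '!', 'it', "'", 's', '9am', '.']
```

Because apostrophes become tokens of their own, a contraction such as "it's" is split into three tokens.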

Once this is done we need to create the vocabulary, which is a mapping between words and their integer representations in the model.

In [None]:
def create_vocabulary(tokens, min_frequency=2):
    # Count token frequencies
    token_counts = defaultdict(int)
    for token in tokens:
        token_counts[token] += 1

    # Start with special tokens
    vocab = {
        '<PAD>': 0,  # Used for padding shorter sequences
        '<UNK>': 1,  # Used for unknown words
        '<START>': 2,  # Marks sequence start
        '<END>': 3,   # Marks sequence end
    }

    # Add frequent tokens to vocabulary
    token_idx = len(vocab)
    for token, count in token_counts.items():
        if count >= min_frequency:
            vocab[token] = token_idx
            token_idx += 1

    return vocab

This function:
- Counts how often each token appears in the training data
- Adds special tokens that serve specific purposes
- Includes only tokens that appear at least min_frequency times, helping reduce vocabulary size and prevent overfitting on rare words
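
A small sketch of how such a vocabulary might be used to turn tokens into ids. `build_vocab` restates the function above with `Counter` for brevity, and `encode` is a hypothetical helper that is not part of the notebook; the toy corpus is made up.

```python
from collections import Counter

SPECIALS = ['<PAD>', '<UNK>', '<START>', '<END>']

def build_vocab(tokens, min_frequency=2):
    # Same idea as create_vocabulary above: special tokens first, then frequent tokens
    counts = Counter(tokens)
    vocab = {tok: i for i, tok in enumerate(SPECIALS)}
    for tok, count in counts.items():
        if count >= min_frequency:
            vocab[tok] = len(vocab)
    return vocab

def encode(tokens, vocab):
    # Hypothetical helper: wrap in <START>/<END> and map unknown tokens to <UNK>
    unk = vocab['<UNK>']
    return [vocab['<START>']] + [vocab.get(t, unk) for t in tokens] + [vocab['<END>']]

corpus = 'the cat sat on the mat the cat'.split()
vocab = build_vocab(corpus)
ids = encode('the dog sat'.split(), vocab)
print(vocab)  # {'<PAD>': 0, '<UNK>': 1, '<START>': 2, '<END>': 3, 'the': 4, 'cat': 5}
print(ids)    # [2, 4, 1, 1, 3]
```

Both "dog" (never seen) and "sat" (seen only once, below `min_frequency`) map to `<UNK>`.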


## Embedding Layer
Once we have our tokens converted to indices, the embedding layer transforms these indices into dense vectors that capture semantic meaning. This is implemented in the larger `Transformer` class which will be introduced as code later:

```python
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Create separate embedding layers for source and target languages
        self.src_tok_emb = nn.Embedding(config.src_vocab_size, config.n_embd)
        self.tgt_tok_emb = nn.Embedding(config.tgt_vocab_size, config.n_embd)
        
        # Create positional embeddings
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
```
Firstly, each token index is converted into a vector of size `n_embd`. These vectors are learned during training and end up capturing semantic relationships - similar words will have similar embeddings.

Since the transformer has no inherent way of understanding word order, we add positional information with the positional embeddings.
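
To make the addition concrete, here is a dependency-free sketch in plain Python. The two lookup tables and their values are made up for illustration; in the model both are learned tensors.

```python
# Toy lookup tables (values are made up; in the model both are learned)
tok_emb = {4: [0.5, -0.2], 5: [0.1, 0.9]}        # token id -> vector
pos_emb = [[0.0, 0.1], [0.2, 0.0], [-0.1, 0.3]]  # position -> vector

def embed(token_ids):
    # Add the position vector to the token vector, element by element
    return [
        [t + p for t, p in zip(tok_emb[tok], pos_emb[pos])]
        for pos, tok in enumerate(token_ids)
    ]

out = embed([4, 5, 4])
print(out[0])  # token 4 at position 0
print(out[2])  # token 4 at position 2: same token, different vector
```

The same token id yields a different input vector depending on where it appears, which is how word order becomes visible to the attention layers.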



## Multi-Head Attention
The transformer uses attention to weigh the importance of different words when processing each word in a sequence. The attention mechanism is defined in our `MultiHeadAttention` class.

- Each word is projected into three different vectors: query, key, and value
- The query vector of each word is compared with the key vectors of all words to determine attention weights
- These weights are used to create a weighted sum of the value vectors
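
The three steps above can be sketched for a single head in plain Python. This is an illustrative sketch only: the query, key and value vectors are given directly as made-up numbers, whereas in the model they come from learned linear projections.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    # Single head: every query is scored against every key
    d = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        w = softmax(scores)
        weights.append(w)
        # Weighted sum of the value vectors, one output component at a time
        outputs.append([sum(wj * v[c] for wj, v in zip(w, values))
                        for c in range(len(values[0]))])
    return outputs, weights

# Toy vectors for a three-word sequence (made-up numbers)
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, w = attention(q, k, v)
print([round(x, 3) for x in w[0]])   # attention weights of the first word
print([round(x, 3) for x in out[0]])  # its output vector
```

Each row of weights sums to one, so every output is a weighted average of the value vectors.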

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.num_heads == 0

        self.num_heads = config.num_heads
        self.head_size = config.n_embd // config.num_heads
        self.n_embd = config.n_embd

        self.q_proj = nn.Linear(config.n_embd, config.n_embd)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)

        self.dropout = nn.Dropout(config.dropout)

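    def forward(self, q, k=None, v=None, mask=None, is_causal=False):
        batch_size = q.size(0)

        # Self-attention: keys and values default to the query input
        if k is None:
            k = q
        if v is None:
            v = q

        # Project the inputs to queries, keys and values
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Split into heads: (B, T, n_embd) -> (B, num_heads, T, head_size)
        q = q.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)

        # Scaled dot-product scores: (B, num_heads, T_q, T_k)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)

        # Causal mask: hide keys that lie after the query position
        if is_causal:
            seq_len = q.size(-2)
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device), diagonal=1)
            scores.masked_fill_(causal_mask, float('-inf'))

        # Optional mask: True marks positions that may be attended to
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            scores.masked_fill_(~mask, float('-inf'))

        # Normalise the scores into attention weights and apply dropout
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of the value vectors
        out = attn @ v

        # Merge heads: (B, num_heads, T_q, head_size) -> (B, T_q, n_embd)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_embd)
        out = self.out_proj(out)

        return out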

## The Encoder: Understanding the Input

The encoder's job is to process the input sentence and create a rich representation that captures both the meaning of each word and its relationships with other words. It does this through multiple layers, each containing two sub-components: a multi-head self-attention layer and a feed-forward layer. Layer normalisation is applied before each sub-component, and once more after the final layer.

In [None]:
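class FeedForward(nn.Module):
    # NOTE: FeedForward is used below but not defined in this notebook. This is an
    # assumed position-wise network (4x expansion, GELU); the repository may differ.
    def __init__(self, config):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()  # Required before registering sub-modules
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ffwd = FeedForward(config)

    def forward(self, x, mask=None, is_causal=False):
        # Self-attention followed by feed-forward processing, each with a residual connection
        # (is_causal is accepted because the decoder calls these blocks with causal masking)
        x = x + self.attn(self.ln1(x), mask=mask, is_causal=is_causal)
        x = x + self.ffwd(self.ln2(x))
        return x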

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return self.ln_f(x)

Each encoder block:

- First applies self-attention, allowing each word to gather information from other relevant words
- Then processes this information through a feed-forward network
- Uses residual connections (the `x + ...` parts) to maintain a smooth information flow
- Applies layer normalization (`ln1` and `ln2`) to stabilise the learning process

Typically we have several Encoder blocks applied in sequence; in the original paper, six encoder blocks were used.
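
The pattern "normalise, apply a sub-layer, add the result back" can be sketched without PyTorch. This is an illustrative sketch only: `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network, and the learned scale and shift of `nn.LayerNorm` are omitted.

```python
import math

def layer_norm(x, eps=1e-5):
    # Normalise one vector to zero mean and unit variance (no learned scale/shift here)
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]

def pre_norm_residual(x, sublayer):
    # x + sublayer(layer_norm(x)): the pattern used twice inside Block.forward
    return [xi + yi for xi, yi in zip(x, sublayer(layer_norm(x)))]

def toy_sublayer(h):
    # Hypothetical stand-in for attention or the feed-forward network
    return [0.5 * hi for hi in h]

x = [1.0, 2.0, 3.0, 4.0]
normed = layer_norm(x)
y = pre_norm_residual(x, toy_sublayer)
print([round(v, 3) for v in normed])
print([round(v, 3) for v in y])
```

Because the input `x` is added back unchanged, each block only has to learn a correction to its input rather than a full replacement for it.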

## The Decoder: Generating the Translation

The decoder has the complex task of generating the translation one word at a time. It needs to:

- Look at the translation it has generated so far
- Consider the entire input sentence
- Decide what word comes next

This is implemented through the following classes:

In [None]:
class CrossAttentionBlock(nn.Module):
    """Transformer block with cross-attention to encoder output"""
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.self_attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.cross_attn = MultiHeadAttention(config)
        self.ln3 = nn.LayerNorm(config.n_embd)
        self.ffwd = FeedForward(config)

    def forward(self, x, enc_out, self_mask=None, cross_mask=None):
        # Self attention with causal masking
        x = x + self.self_attn(
            self.ln1(x),
            mask=self_mask,
            is_causal=True
        )

        # Cross attention to encoder output
        x = x + self.cross_attn(
            q=self.ln2(x),
            k=enc_out,
            v=enc_out,
            mask=cross_mask
        )

        # Feed forward
        x = x + self.ffwd(self.ln3(x))
        return x

class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Pre-cross attention transformer blocks
        self.pre_blocks = nn.ModuleList([
            Block(config) for _ in range(config.n_pre_cross_layer)
        ])
        # Cross attention blocks
        self.cross_blocks = nn.ModuleList([
            CrossAttentionBlock(config) for _ in range(config.n_cross_layer)
        ])
        self.ln_f = nn.LayerNorm(config.n_embd)

    def forward(self, x, enc_out, padding_mask=None, cross_mask=None):
        # First run through pre-cross attention blocks with causal masking
        for block in self.pre_blocks:
            x = block(x, padding_mask, is_causal=True)

        # Then through cross attention blocks
        # Cross attention blocks still use causal masking for self-attention
        for block in self.cross_blocks:
            x = block(x, enc_out, padding_mask, cross_mask)

        return self.ln_f(x)

Each cross-attention block in the decoder combines three components:

- Masked self-attention to look at the previously generated words. We use masking to ensure that the model cannot peek ahead during training and cheat by seeing the words it is supposed to predict
- Cross-attention to incorporate the input sequence into its predictions
- Feed-forward processing to combine this information
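
The effect of the causal mask can be shown in plain Python. This is a minimal sketch with made-up scores; the mask has the same shape as the one produced by `torch.triu(..., diagonal=1)` in `MultiHeadAttention`.

```python
import math

def causal_mask(seq_len):
    # True marks a "future" position that must be hidden (column j > row i)
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, hidden):
    # Hidden positions get -inf, so exp(-inf) = 0 and they receive zero weight
    masked = [float('-inf') if h else s for s, h in zip(scores, hidden)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(4)
for row in mask:
    print(''.join('x' if hidden else '.' for hidden in row))
# .xxx
# ..xx
# ...x
# ....

# Second word (row 1) may only attend to the first two positions
weights = masked_softmax([0.3, 1.2, -0.5, 2.0], mask[1])
print(weights)  # the last two weights are exactly 0.0
```

Even though the hidden position with score 2.0 would otherwise dominate, it contributes nothing once masked.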

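In this implementation, the decoder first runs a stack of self-attention-only blocks (`pre_blocks`) and only then incorporates the encoder's representation of the input sequence through the cross-attention blocks (`cross_blocks`). In the original paper, by contrast, every decoder layer includes cross-attention.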

## Putting it Together
We can now define the whole model architecture in the `Transformer` class:

In [None]:
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Source and target embeddings
        self.src_tok_emb = nn.Embedding(config.src_vocab_size, config.n_embd)
        self.tgt_tok_emb = nn.Embedding(config.tgt_vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.dropout)

        # Encoder and Decoder
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        # Output projection
        self.head = nn.Linear(config.n_embd, config.tgt_vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None):
        B, T = src_ids.size()

        # Source embedding
        src_emb = self.src_tok_emb(src_ids)
        src_pos = self.pos_emb[:, :T, :]
        x = self.drop(src_emb + src_pos)

        # Encode
        encoder_out = self.encoder(x, src_mask)

        # Target embedding
        tgt_emb = self.tgt_tok_emb(tgt_ids)
        tgt_pos = self.pos_emb[:, :tgt_ids.size(1), :]
        y = self.drop(tgt_emb + tgt_pos)

        # Decode
        y = self.decoder(y, encoder_out, tgt_mask, src_mask)

        # Project to vocabulary
        logits = self.head(y)

        return logits

    @torch.no_grad()  # Inference only: no gradients need to be tracked
    def generate(self, src_ids, max_new_tokens, temperature=1.0, top_k=None):
        """Generate translation tokens autoregressively"""
        self.eval()
        B, T = src_ids.size()

        # Create source padding mask (1 for tokens, 0 for padding)
        src_mask = (src_ids != 0).unsqueeze(1).unsqueeze(2).to(dtype=torch.bool)

        # First encode the source sequence
        src_emb = self.src_tok_emb(src_ids)
        pos_emb = self.pos_emb[:, :T, :]
        x = self.drop(src_emb + pos_emb)
        encoder_out = self.encoder(x, src_mask)

        # Initialize target sequence with START token
        tgt_ids = torch.full((B, 1), fill_value=2, dtype=torch.long, device=src_ids.device)  # Assume 2 is START token

        for _ in range(max_new_tokens):
            # Cut off if sequence is too long
            if tgt_ids.size(1) > self.config.block_size:
                break

            # Create target padding mask (1 for tokens, 0 for padding)
            tgt_mask = (tgt_ids != 0).unsqueeze(1).unsqueeze(2).to(dtype=torch.bool)

            # Get embeddings for target sequence
            tgt_emb = self.tgt_tok_emb(tgt_ids)
            pos_emb = self.pos_emb[:, :tgt_ids.size(1), :]
            y = self.drop(tgt_emb + pos_emb)

            # Decode
            y = self.decoder(y, encoder_out, tgt_mask, src_mask)

            # Project to vocabulary
            logits = self.head(y)

            # Only take the last token's logits
            logits = logits[:, -1, :] / temperature

            # Optional top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token to sequence
            tgt_ids = torch.cat((tgt_ids, next_token), dim=1)

            # Stop if we hit the EOS token (assume 3 is EOS token)
            if (next_token == 3).any():
                break

        return tgt_ids

class TransformerConfig:
    """Configuration class to store the configuration of a `Transformer`."""
    def __init__(
        self,
        src_vocab_size=50257,
        tgt_vocab_size=50257,
        block_size=1024,
        n_layer=12,  # Number of encoder layers
        n_pre_cross_layer=6,  # Number of decoder layers before cross attention
        n_cross_layer=6,  # Number of decoder layers with cross attention
        n_embd=768,
        num_heads=12,
        dropout=0.1
    ):
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_pre_cross_layer = n_pre_cross_layer
        self.n_cross_layer = n_cross_layer
        self.n_embd = n_embd
        self.num_heads = num_heads
        self.dropout = dropout
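
The sampling step inside `generate` (temperature scaling followed by optional top-k filtering) can be illustrated without PyTorch. This is a sketch under assumptions: `sample_next` is a hypothetical helper, the logits are made up, and Python's `random.choices` stands in for `torch.multinomial`.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    # Temperature: values below 1 sharpen the distribution, values above 1 flatten it
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # Keep the k largest logits; everything below the k-th largest is dropped
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= kth else float('-inf') for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Draw one index according to the probabilities
    token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token, probs

logits = [2.0, 1.0, 0.5, -1.0]   # made-up scores for a four-word vocabulary
rng = random.Random(0)           # fixed seed so the example is repeatable
token, probs = sample_next(logits, temperature=0.8, top_k=2, rng=rng)
print(token, [round(p, 3) for p in probs])
```

With `top_k=2`, only the two highest-scoring tokens keep a non-zero probability, so the sampled token is always one of them.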

# Training
We used the following dataset of English to French translations: https://www.kaggle.com/datasets/devicharith/language-translation-englishfrench.

## Preparing the Data
Before we can train the model, we need to prepare our data in a way that the model can process. We define the `TranslationDataset` to tokenise the dataset and return the English to French pairs at a given index in the dataset:

In [None]:
class TranslationDataset(Dataset):
    def __init__(self, src_texts, tgt_texts, src_tokenizer, tgt_tokenizer, max_length=128):
        self.src_texts = src_texts
        self.tgt_texts = tgt_texts
        self.max_length = max_length

        # Pre-tokenize all texts
        print("Pre-tokenizing source texts...")
        self.src_encoded = tokenize_and_pad(src_texts, src_tokenizer, max_length)
        print("Pre-tokenizing target texts...")
        self.tgt_encoded = tokenize_and_pad(tgt_texts, tgt_tokenizer, max_length)

    def __len__(self):
        return len(self.src_texts)

    def __getitem__(self, idx):
        return {
            'src_ids': self.src_encoded[idx],
            'tgt_ids': self.tgt_encoded[idx],
            'src_text': self.src_texts[idx],
            'tgt_text': self.tgt_texts[idx]
        }
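
The padding and truncation performed by `tokenize_and_pad` (called in the constructor above and defined together with the training code) can be sketched on plain lists. `pad_or_truncate` is a hypothetical helper for illustration, and the token ids are made up.

```python
PAD, START, END = 0, 2, 3   # ids of the special tokens in the vocabulary above

def pad_or_truncate(token_ids, max_length):
    # Wrap in <START>/<END>, cut to max_length, then right-pad with <PAD>
    ids = [START] + token_ids + [END]
    ids = ids[:max_length]
    return ids + [PAD] * (max_length - len(ids))

short = pad_or_truncate([4, 5], max_length=6)
long = pad_or_truncate([4, 5, 6, 7, 8, 9], max_length=6)
print(short)  # [2, 4, 5, 3, 0, 0]
print(long)   # [2, 4, 5, 6, 7, 8]  (the <END> token is cut off)
```

Every sequence ends up with the same length, which is what allows a batch to be stacked into a single tensor. Sentences longer than `max_length` lose their `<END>` token.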

Training is implemented by the `train_model` function. We'll discuss a few important steps before defining this function:

### Learning Rate Management
Just as a language learner might need to adjust their learning pace, our model uses a learning rate scheduler:

```python
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, verbose=True
)
```

This scheduler watches the model's performance and adjusts the learning rate when improvement slows down. If the model stops improving for a while (defined by the patience parameter), the learning rate is reduced by half, allowing for finer adjustments to the model's parameters.
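
The core logic can be sketched in a few lines of plain Python. This is a simplified illustration, not the library's implementation: `plateau_schedule` is a hypothetical function that ignores the threshold, cooldown and minimum-learning-rate options of the real scheduler, and the validation losses are made up.

```python
def plateau_schedule(val_losses, lr, factor=0.5, patience=2):
    # Simplified: ignores the threshold, cooldown and min_lr options of the real scheduler
    best = float('inf')
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best, bad_epochs = loss, 0   # improvement: reset the counter
        else:
            bad_epochs += 1
        if bad_epochs > patience:        # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

losses = [1.00, 0.90, 0.95, 0.95, 0.95, 0.95]   # made-up validation losses
print(plateau_schedule(losses, lr=3e-4))
# [0.0003, 0.0003, 0.0003, 0.0003, 0.00015, 0.00015]
```

With `patience=2`, the learning rate is halved on the third consecutive epoch that fails to beat the best loss so far.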

### Validation and Model Selection
During training, we regularly check how well the model is doing on unseen data:
```python
if avg_val_loss < best_val_loss:
    best_val_loss = avg_val_loss
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': best_val_loss,
    }, './models/best_model.pt')
```

This is like giving the language learner a test on material they haven't specifically studied. We save the version of the model that performs best on these validation tests, as this indicates good generalization to new sentences.

### Training Configuration
Our model uses specific hyperparameters chosen to balance training speed and performance:
```python
config = TransformerConfig(
    src_vocab_size=len(eng_tokenizer),
    tgt_vocab_size=len(fr_tokenizer),
    block_size=128,
    n_layer=6,
    n_pre_cross_layer=3,
    n_cross_layer=3,
    n_embd=256,
    num_heads=8,
    dropout=0.1
)
```
These parameters define:

- The size of our vocabularies for both languages
- The maximum sequence length (block_size)
- The model's architecture (number of layers and their sizes)
- Regularisation strength (dropout)
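
A few quantities follow directly from these values. The short sketch below only does the arithmetic; it does not touch the model.

```python
n_embd, num_heads, block_size = 256, 8, 128   # values from the config above

# MultiHeadAttention asserts that the embedding splits evenly across the heads
assert n_embd % num_heads == 0
head_size = n_embd // num_heads
print(head_size)   # 32

# Learned positional table: one n_embd-sized vector per position
pos_params = block_size * n_embd
print(pos_params)  # 32768
```

Each of the 8 heads therefore works in a 32-dimensional subspace, and the positional embedding alone contributes 32,768 learned parameters.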

### Preventing Overfitting
The training process includes several mechanisms that regularise and stabilise training, so that the model learns to translate rather than memorising the training data:

- Dropout is applied throughout the model:
```python
self.drop = nn.Dropout(config.dropout)
```
- The training data is reshuffled at the start of each epoch:
```python
reshuffle_training_data(train_dataset)
```

- Gradient clipping prevents extreme parameter updates:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
```
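
Clipping by global norm can be sketched in plain Python. `clip_by_global_norm` is a hypothetical helper that operates on a flat list of made-up gradient values rather than on model parameters.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    # Treat all gradients as one long vector and measure its L2 norm
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale down only when the norm exceeds max_norm; never scale up
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

grads = [3.0, 4.0]   # made-up gradient values with norm 5
clipped, norm = clip_by_global_norm(grads)
print(norm)                            # 5.0
print([round(g, 3) for g in clipped])  # [0.6, 0.8]
```

The direction of the gradient is preserved; only its length is capped at `max_norm`.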

The full training loop can be implemented as:

In [None]:
def tokenize_and_pad(texts, tokenizer, max_length):
    """Tokenize and pad a list of texts in one batch operation"""
    encoded_texts = []
    for text in texts:
        # Encode text with the same tokeniser that was used to build the vocabulary
        encoded = [tokenizer['<START>']] + \
                 [tokenizer.get(token, tokenizer['<UNK>'])
                  for token in basic_tokenize(text)] + \
                 [tokenizer['<END>']]

        # Truncate if necessary
        encoded = encoded[:max_length]

        # Pad sequence
        encoded += [tokenizer['<PAD>']] * (max_length - len(encoded))
        encoded_texts.append(encoded)

    return torch.tensor(encoded_texts)

def create_masks(src_ids, tgt_ids):
    # Source mask (padding mask)
    src_mask = (src_ids != 0).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, src_len)

    # Target mask (combination of padding and subsequent mask)
    tgt_pad_mask = (tgt_ids != 0).unsqueeze(1).unsqueeze(2)  # (B, 1, 1, tgt_len)

    tgt_len = tgt_ids.size(1)
    subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt_ids.device, dtype=torch.bool))
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1)  # (1, 1, tgt_len, tgt_len)

    tgt_mask = tgt_pad_mask & subsequent_mask

    return src_mask.to(torch.bool), tgt_mask.to(torch.bool)

def load_tokenizer(path):
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_dataset(path):
    english_sentences = []
    french_sentences = []

    with open(path, 'r', encoding='utf-8') as file:
        csv_reader = csv.reader(file)
        next(csv_reader)  # Skip header
        for row in csv_reader:
            if len(row) == 2:
                english_sentences.append(row[0].lower())  # Convert to lowercase
                french_sentences.append(row[1].lower())  # Convert to lowercase

    return english_sentences, french_sentences

def reshuffle_training_data(dataset):
    """Reshuffle the training data while keeping pairs aligned"""
    indices = list(range(len(dataset.src_texts)))
    np.random.shuffle(indices)

    # Reorder all dataset attributes using the shuffled indices
    dataset.src_texts = [dataset.src_texts[i] for i in indices]
    dataset.tgt_texts = [dataset.tgt_texts[i] for i in indices]
    dataset.src_encoded = dataset.src_encoded[indices]
    dataset.tgt_encoded = dataset.tgt_encoded[indices]

def train_model(model, train_dataloader, val_dataloader, train_dataset, num_epochs, device, learning_rate=3e-4):
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, verbose=True
    )

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Reshuffle training data at the start of each epoch
        reshuffle_training_data(train_dataset)

        model.train()
        total_loss = 0
        train_pbar = tqdm(train_dataloader, desc=f'Epoch [{epoch+1}/{num_epochs}]', leave=False)

        for batch in train_pbar:
            # Move batch to device
            src_ids = batch['src_ids'].to(device)  # [batch_size, seq_len]
            tgt_ids = batch['tgt_ids'].to(device)  # [batch_size, seq_len]

            # Create attention masks
            src_mask = (src_ids != 0).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, src_len]
            tgt_mask = (tgt_ids != 0).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, tgt_len]

            # Forward pass
            logits = model(
                src_ids=src_ids,
                tgt_ids=tgt_ids[:, :-1],  # Remove last token from target input
                src_mask=src_mask,
                tgt_mask=tgt_mask[:, :, :, :-1]  # Adjust mask accordingly
            )

            # Calculate loss
            loss = F.cross_entropy(
                logits.contiguous().view(-1, logits.size(-1)),
                tgt_ids[:, 1:].contiguous().view(-1),  # Labels: target ids one step ahead (drops <START>)
                ignore_index=0  # Ignore padding token
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            train_pbar.set_postfix({'loss': loss.item()})

        avg_train_loss = total_loss / len(train_dataloader)

        # Validation
        model.eval()
        total_val_loss = 0

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc='Validation', leave=False):
                src_ids = batch['src_ids'].to(device)
                tgt_ids = batch['tgt_ids'].to(device)

                src_mask = (src_ids != 0).unsqueeze(1).unsqueeze(2)
                tgt_mask = (tgt_ids != 0).unsqueeze(1).unsqueeze(2)

                logits = model(
                    src_ids=src_ids,
                    tgt_ids=tgt_ids[:, :-1],
                    src_mask=src_mask,
                    tgt_mask=tgt_mask[:, :, :, :-1]
                )

                loss = F.cross_entropy(
                    logits.contiguous().view(-1, logits.size(-1)),
                    tgt_ids[:, 1:].contiguous().view(-1),
                    ignore_index=0
                )

                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)

        # Update learning rate
        scheduler.step(avg_val_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}] | Train Loss = {avg_train_loss:.4f} | Validation Loss = {avg_val_loss:.4f}')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, './models/best_model.pt')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load tokenizers
eng_tokenizer = load_tokenizer('./models/english_tokeniser.json')
fr_tokenizer = load_tokenizer('./models/french_tokeniser.json')

# Load dataset
eng_sentences, fr_sentences = load_dataset('./data/eng_french.csv')

# Create shuffled indices for the full dataset
indices = list(range(len(eng_sentences)))
np.random.shuffle(indices)

# Use the shuffled indices to reorder both sentence lists simultaneously
eng_sentences = [eng_sentences[i] for i in indices]
fr_sentences = [fr_sentences[i] for i in indices]

# Create train/val split
split_idx = int(len(eng_sentences) * 0.9)

train_eng = eng_sentences[:split_idx]
train_fr = fr_sentences[:split_idx]
val_eng = eng_sentences[split_idx:]
val_fr = fr_sentences[split_idx:]

# Create datasets
train_dataset = TranslationDataset(
    train_eng, train_fr, eng_tokenizer, fr_tokenizer
)
val_dataset = TranslationDataset(
    val_eng, val_fr, eng_tokenizer, fr_tokenizer
)

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4
)

# Initialize model
config = TransformerConfig(
    src_vocab_size=len(eng_tokenizer),
    tgt_vocab_size=len(fr_tokenizer),
    block_size=128,
    n_layer=6,
    n_pre_cross_layer=3,
    n_cross_layer=3,
    n_embd=256,
    num_heads=8,
    dropout=0.1
)

model = Transformer(config).to(device)

# Train the model
train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    train_dataset=train_dataset,  # Pass the training dataset for reshuffling
    num_epochs=20,
    device=device,
    learning_rate=3e-4
)
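
The input/label shift used inside `train_model` (`tgt_ids[:, :-1]` as decoder input, `tgt_ids[:, 1:]` as labels) can be illustrated on a single padded sequence. This is a plain-Python sketch with made-up token ids; `shift_for_teacher_forcing` is a hypothetical helper.

```python
PAD = 0

def shift_for_teacher_forcing(tgt_ids):
    # The decoder sees everything except the last position ...
    decoder_input = tgt_ids[:-1]
    # ... and at each position is asked to predict the token one step ahead
    labels = tgt_ids[1:]
    return decoder_input, labels

tgt = [2, 17, 42, 3, 0, 0]   # <START> w1 w2 <END> <PAD> <PAD> (made-up ids)
dec_in, labels = shift_for_teacher_forcing(tgt)
for seen, target in zip(dec_in, labels):
    note = '' if target != PAD else '  (ignored by the loss)'
    print(f'input {seen:>2} -> label {target:>2}{note}')
```

Positions whose label is `<PAD>` are skipped by `ignore_index=0` in the cross-entropy loss, so padding does not influence training.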