# Transformer Machine Translation Training on Google Colab

This notebook trains a Transformer model for bilingual machine translation (English to Italian) using the OPUS Books dataset.

## 1. Install and Import Required Libraries

In [1]:
# Install required packages for Colab
!pip install torch datasets tokenizers tqdm tensorboard torchmetrics -q

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import math
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# Check device. torch.cuda.get_device_name only accepts CUDA devices, so on a
# CPU-only runtime fall back to printing the device itself instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_label = torch.cuda.get_device_name(device) if device.type == 'cuda' else str(device)
print(f'Using device: {device_label}')

Using device: Tesla T4


## 2. Configure Training Parameters

In [2]:
# Configuration: every tunable hyper-parameter and path used by the notebook.
config = dict(
    batch_size=16,
    num_epochs=20,
    lr=10**-4,
    seq_len=350,            # fixed length every encoder/decoder tensor is padded to
    d_model=512,
    lang_src='en',
    lang_tgt='it',
    model_folder='./weights',
    model_basename='tmodel_',
    preload=None,           # epoch string such as '05' to resume from that checkpoint
    tokenizer_file='./tokenizer_{0}.json',
    experiment_name='runs/tmodel',
)

def get_weights_file_path(config, epoch: str):
    """Return the checkpoint path ``<model_folder>/<model_basename><epoch>.pt`` as a string."""
    checkpoint_name = f"{config['model_basename']}{epoch}.pt"
    return str(Path(config['model_folder']).joinpath(checkpoint_name))

def latest_weights_file_path(config):
    """Return the most recent checkpoint in ``config['model_folder']``, or None if there is none.

    Looks in the same folder ``get_weights_file_path`` writes to, for files whose name starts
    with ``config['model_basename']``. Files are ordered by name length, then name. For the
    zero-padded numeric epochs used here, that makes 'tmodel_100.pt' sort after
    'tmodel_99.pt', which a plain lexicographic sort would not do.

    Returns:
        The path of the latest matching checkpoint as a string, or ``None`` when the folder
        is missing or contains no matching file.
    """
    # Only keys that this config actually defines are read.
    model_folder = Path(config['model_folder'])
    weights_files = sorted(
        model_folder.glob(f"{config['model_basename']}*"),
        key=lambda p: (len(p.name), p.name),
    )
    if not weights_files:
        return None
    return str(weights_files[-1])

# Ensure the checkpoint folder exists up front so later saves never fail on a missing path.
Path(config['model_folder']).mkdir(exist_ok=True, parents=True)

print("Config:", config)

Config: {'batch_size': 16, 'num_epochs': 20, 'lr': 0.0001, 'seq_len': 350, 'd_model': 512, 'lang_src': 'en', 'lang_tgt': 'it', 'model_folder': './weights', 'model_basename': 'tmodel_', 'preload': None, 'tokenizer_file': './tokenizer_{0}.json', 'experiment_name': 'runs/tmodel'}


## 3. Build and Load Tokenizers

In [3]:
def get_all_sentences(ds, lang):
    """Lazily yield ``item['translation'][lang]`` for every item of ``ds``, in order."""
    yield from (pair['translation'][lang] for pair in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Return a word-level tokenizer for ``lang``, training and caching it on first use.

    The cache file is ``config['tokenizer_file'].format(lang)``. If it exists, it is loaded
    unchanged. Otherwise a WordLevel tokenizer is trained on every ``lang`` sentence in ``ds``
    and saved there. It uses whitespace pre-tokenization, ``[UNK]`` as the unknown token, and
    the special tokens [UNK], [PAD], [SOS], [EOS].
    """
    cache_path = Path(config['tokenizer_file'].format(lang))

    if cache_path.exists():
        print(f'Loading tokenizer for {lang} from {cache_path}')
        return Tokenizer.from_file(str(cache_path))

    print(f'Building tokenizer for {lang}...')
    fresh = Tokenizer(WordLevel(unk_token="[UNK]"))
    fresh.pre_tokenizer = Whitespace()
    word_trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])
    fresh.train_from_iterator(get_all_sentences(ds, lang), trainer=word_trainer)
    fresh.save(str(cache_path))
    print(f'Tokenizer for {lang} saved at {cache_path}')
    return fresh

# Load the OPUS Books training split for the configured language pair.
print('Loading dataset...')
ds = load_dataset(
    'opus_books',
    f'{config["lang_src"]}-{config["lang_tgt"]}',
    split='train',
)
print(f'Dataset loaded: {len(ds)} samples')

# Tokenizers are cached on disk, so a re-run of this cell only loads them.
src_tokenizer, tgt_tokenizer = (
    get_or_build_tokenizer(config, ds, config[lang_key])
    for lang_key in ('lang_src', 'lang_tgt')
)

print(f'Source vocabulary size: {src_tokenizer.get_vocab_size()}')
print(f'Target vocabulary size: {tgt_tokenizer.get_vocab_size()}')

Loading dataset...


README.md: 0.00B [00:00, ?B/s]

en-it/train-00000-of-00001.parquet:   0%|          | 0.00/5.73M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/32332 [00:00<?, ? examples/s]

Dataset loaded: 32332 samples
Building tokenizer for en...
Tokenizer for en saved at tokenizer_en.json
Building tokenizer for it...
Tokenizer for it saved at tokenizer_it.json
Source vocabulary size: 25138
Target vocabulary size: 30000


## 4. Create Bilingual Dataset Class

In [4]:
def causal_mask(seq_len):
    """Return a boolean mask of shape (1, seq_len, seq_len) that is True on and below the diagonal.

    Entry [0, i, j] is True exactly when j <= i. The decoder therefore only attends to the
    current and earlier positions, never to future tokens.
    """
    lower_triangle = torch.ones(1, seq_len, seq_len).tril()
    return lower_triangle.bool()

class BilingualDataset(Dataset):
    """Map-style dataset that turns translation pairs into fixed-length model tensors.

    Each ``ds[idx]`` is read as ``{'translation': {src_lang: str, tgt_lang: str}}``. Both
    sentences are tokenized, wrapped in special tokens and right-padded to ``seq_len``:

    * ``encoder_input``: [SOS] src [EOS] [PAD]...   (source-tokenizer ids)
    * ``decoder_input``: [SOS] tgt [PAD]...         (target-tokenizer ids)
    * ``labels``:        tgt [EOS] [PAD]...         (decoder_input shifted left by one)

    Raises:
        ValueError: in ``__init__`` if a tokenizer lacks one of [SOS]/[EOS]/[PAD]; in
            ``__getitem__`` if a sentence does not fit in ``seq_len``.
    """

    def __init__(self, ds, src_tokenizer, tgt_tokenizer, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Target-side special tokens. The attribute names are unchanged, so existing users
        # of .sos_token / .eos_token / .pad_token still work.
        self.sos_token = self._special_token(tgt_tokenizer, "[SOS]")
        self.eos_token = self._special_token(tgt_tokenizer, "[EOS]")
        self.pad_token = self._special_token(tgt_tokenizer, "[PAD]")
        # The encoder input is expressed in the SOURCE vocabulary, so its markers must come
        # from the source tokenizer. They only coincide with the target ids by accident of
        # identical special-token training order.
        self.src_sos_token = self._special_token(src_tokenizer, "[SOS]")
        self.src_eos_token = self._special_token(src_tokenizer, "[EOS]")
        self.src_pad_token = self._special_token(src_tokenizer, "[PAD]")

    @staticmethod
    def _special_token(tokenizer, token):
        """Return ``token``'s id in ``tokenizer`` as a 1-element int64 tensor."""
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"Tokenizer has no {token} token")
        return torch.tensor([token_id], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_tgt_pair = self.ds[idx]
        src_text = src_tgt_pair['translation'][self.src_lang]
        tgt_text = src_tgt_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.src_tokenizer.encode(src_text).ids
        dec_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids

        # The encoder gets both [SOS] and [EOS]. decoder_input gets only [SOS] and labels only [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError(f"Sequence length {self.seq_len} is too small for the given text")

        src_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        tgt_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
        enc_padding = torch.full((enc_num_padding_tokens,), self.src_pad_token.item(), dtype=torch.int64)
        dec_padding = torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64)

        encoder_input = torch.cat([self.src_sos_token, src_ids, self.src_eos_token, enc_padding])
        decoder_input = torch.cat([self.sos_token, tgt_ids, dec_padding])
        labels = torch.cat([tgt_ids, self.eos_token, dec_padding])

        assert encoder_input.shape[0] == self.seq_len
        assert decoder_input.shape[0] == self.seq_len
        assert labels.shape[0] == self.seq_len

        # (1, 1, seq_len): 1 for real source tokens, 0 for padding.
        encoder_mask = (encoder_input != self.src_pad_token).unsqueeze(0).unsqueeze(0).int()
        # (1, 1, seq_len) & (1, seq_len, seq_len) broadcasts to (1, seq_len, seq_len):
        # non-padding keys that are not in the future.
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.shape[0])

        return {
            "encoder_input": encoder_input,  # (seq_len,)
            "decoder_input": decoder_input,  # (seq_len,)
            "encoder_mask": encoder_mask,    # (1, 1, seq_len)
            "decoder_mask": decoder_mask,    # (1, seq_len, seq_len)
            "labels": labels,                # (seq_len,)
            "src_text": src_text,
            "tgt_text": tgt_text
        }

# Create and split dataset. The split uses a seeded generator, so it is identical across
# kernel restarts. An unseeded random_split draws a different train/val partition each run.
print('Creating datasets...')
train_ds_size = int(len(ds) * 0.9)
val_ds_size = len(ds) - train_ds_size
train_ds_raw, val_ds_raw = random_split(
    ds,
    [train_ds_size, val_ds_size],
    generator=torch.Generator().manual_seed(config.get('split_seed', 42)),
)

train_ds = BilingualDataset(train_ds_raw, src_tokenizer, tgt_tokenizer, config['lang_src'], config['lang_tgt'], config['seq_len'])
val_ds = BilingualDataset(val_ds_raw, src_tokenizer, tgt_tokenizer, config['lang_src'], config['lang_tgt'], config['seq_len'])

# Calculate max lengths (in tokens, before special tokens are added)
max_len_src, max_len_tgt = 0, 0
for item in ds:
    src_ids = src_tokenizer.encode(item['translation'][config['lang_src']]).ids
    tgt_ids = tgt_tokenizer.encode(item['translation'][config['lang_tgt']]).ids
    max_len_src = max(max_len_src, len(src_ids))
    max_len_tgt = max(max_len_tgt, len(tgt_ids))

print(f'Max length of source sentence: {max_len_src}')
print(f'Max length of target sentence: {max_len_tgt}')
print(f'Training dataset size: {len(train_ds)}')
print(f'Validation dataset size: {len(val_ds)}')

# Fail fast here rather than mid-epoch. The source needs room for [SOS]+[EOS], and each
# target tensor for one special token.
assert max_len_src + 2 <= config['seq_len'] and max_len_tgt + 1 <= config['seq_len'], \
    f"seq_len={config['seq_len']} too small for max lengths src={max_len_src}, tgt={max_len_tgt}"

Creating datasets...
Max length of source sentence: 309
Max length of target sentence: 274
Training dataset size: 29098
Validation dataset size: 3234


## 5. Build Transformer Model Architecture

In [5]:
class InputEmbeddings(nn.Module):
    """Token-id embedding lookup scaled by sqrt(d_model), as in "Attention Is All You Need"."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scaling keeps embedding magnitudes comparable to the positional encodings added next.
        vectors = self.embedding(x)
        return vectors * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Precomputes PE(pos, 2i) = sin(pos / base^(2i / d_model)) and
    PE(pos, 2i + 1) = cos(pos / base^(2i / d_model)) for pos < max_seq_len.
    ``base`` defaults to 10000, the value from "Attention Is All You Need".

    Args:
        d_model: embedding size (odd sizes are supported).
        max_seq_len: number of positions the table is precomputed for.
        dropout: dropout probability applied after the addition.
        base: wavelength base of the sinusoids.
    """

    def __init__(self, d_model: int, max_seq_len: int, dropout: float, base: float = 10000.0):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_seq_len, d_model)
        positions = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # (max_seq_len, 1)
        # 1 / base^(2i / d_model) for each even column 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(base) / d_model))

        pe[:, 0::2] = torch.sin(positions * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(positions * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)

        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Dim 1 of x is the sequence axis; only the first x.shape[1] table rows are added.
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)

class LayerNormalization(nn.Module):
    """Normalize over the last dimension, then apply a learnable gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + bias``. The mean and the unbiased std
    (``Tensor.std`` default) are taken over the last dimension.

    Args:
        eps: added to the std to avoid division by zero.
        features: size of ``alpha``/``bias``. The default 1 shares one scalar gain and bias
            across all features (the shape existing checkpoints were saved with). Passing the
            model width gives the usual per-feature affine parameters.
    """

    def __init__(self, eps: float = 10**-6, features: int = 1):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # multiplicative gain
        self.bias = nn.Parameter(torch.zeros(features))  # additive offset

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward net: Linear(d_model, d_ff) -> ReLU -> Dropout -> Linear(d_ff, d_model)."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The block works as follows:

    * Queries, keys and values are projected with ``w_q``/``w_k``/``w_v``.
    * Each projection is split into ``h`` heads of width ``d_k = d_model // h``, and every head
      attends independently.
    * The heads are concatenated and passed through ``w_o``.

    The latest attention weights are kept on ``self.attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout):
        """Return ``(softmax(Q K^T / sqrt(d_k)) V, weights)``.

        Scores where ``mask == 0`` are set to -1e9 before the softmax. ``dropout`` (if given)
        is applied to the weights, and the returned weights are the post-dropout ones.
        """
        scale = math.sqrt(query.shape[-1])
        scores = torch.matmul(query, key.transpose(-1, -2)) / scale
        if mask is not None:
            scores.masked_fill_(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value), weights

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
        batch, seq = projected.shape[0], projected.shape[1]
        return projected.view(batch, seq, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        attended, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = attended.transpose(1, 2).contiguous().view(attended.shape[0], -1, self.h * self.d_k)
        return self.w_o(merged)

class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, dropout: float):
        super().__init__()
        self.norm = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normalized = self.norm(x)
        update = self.dropout(sublayer(normalized))
        return x + update

class EncoderBlock(nn.Module):
    """One encoder layer: residual self-attention followed by a residual feed-forward block."""

    def __init__(self, self_attention_block, feed_forward_block, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout), ResidualConnection(dropout)])

    def forward(self, x, src_mask):
        attn_residual, ff_residual = self.residual_connections
        # Self-attention: queries, keys and values all come from the same sequence.
        x = attn_residual(x, lambda h: self.self_attention_block(h, h, h, src_mask))
        return ff_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
    """Apply encoder layers in order, then a final layer normalization.

    ``layers`` should be an ``nn.ModuleList`` so that the layers' parameters are registered.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer, with each stage wrapped in a residual connection.

    The stages run in this order: masked self-attention, cross-attention over the encoder
    output, then feed-forward.
    """

    def __init__(self, self_attention_block, cross_attention_block, feed_forward_block, dropout: float):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout), ResidualConnection(dropout), ResidualConnection(dropout)]
        )

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_res, cross_res, ff_res = self.residual_connections
        x = self_res(x, lambda h: self.self_attention_block(h, h, h, tgt_mask))
        # Cross-attention: queries from the decoder, keys/values from the encoder output.
        x = cross_res(x, lambda h: self.cross_attention_block(h, encoder_output, encoder_output, src_mask))
        return ff_res(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Apply decoder layers in order against the encoder output, then a final layer normalization.

    ``layers`` should be an ``nn.ModuleList`` so that the layers' parameters are registered.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Map d_model features to vocabulary log-probabilities (linear layer, then log-softmax)."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.projection(x)
        return logits.log_softmax(dim=-1)

class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built sub-modules.

    The three methods are:

    * ``encode``: embeds the source, adds positional encoding and runs the encoder.
    * ``decode``: does the same for the target, then runs the decoder against the encoder output.
    * ``project``: applies ``projection_layer`` to decoder states.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        return self.decoder(self.tgt_pos(self.tgt_embed(tgt)), encoder_output, src_mask, tgt_mask)

    def project(self, decoder_output):
        return self.projection_layer(decoder_output)

def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model=512, N=6, h=8, dropout=0.1, d_ff=2048):
    """Assemble an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size: source vocabulary size (source embedding rows).
        tgt_vocab_size: target vocabulary size (target embedding rows and projection outputs).
        src_seq_len: number of positions in the source positional-encoding table.
        tgt_seq_len: number of positions in the target positional-encoding table.
        d_model: model width.
        N: number of encoder layers and of decoder layers.
        h: attention heads per attention block (d_model must be divisible by h).
        dropout: dropout probability used throughout.
        d_ff: hidden width of each feed-forward block.

    Returns:
        The assembled ``Transformer``.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Sub-modules are created in a fixed order. nn.Linear / nn.Embedding draw from the global
    # RNG when constructed, so reordering would change the initial biases for a given seed.
    encoder_layers = [
        EncoderBlock(
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]
    decoder_layers = [
        DecoderBlock(
            MultiHeadAttention(d_model, h, dropout),  # masked self-attention
            MultiHeadAttention(d_model, h, dropout),  # cross-attention over the encoder output
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    transformer = Transformer(
        Encoder(nn.ModuleList(encoder_layers)),
        Decoder(nn.ModuleList(decoder_layers)),
        src_embed,
        tgt_embed,
        src_pos,
        tgt_pos,
        ProjectionLayer(d_model, tgt_vocab_size),
    )

    # Xavier-uniform for every parameter with more than one dimension; 1-D parameters
    # (biases, LayerNormalization gain/bias) keep their construction-time values.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

# Build the model for the tokenizer vocabularies and move it to the selected device.
print('Building Transformer model...')
model = build_transformer(
    src_vocab_size=src_tokenizer.get_vocab_size(),
    tgt_vocab_size=tgt_tokenizer.get_vocab_size(),
    src_seq_len=config['seq_len'],
    tgt_seq_len=config['seq_len'],
    d_model=config['d_model'],
).to(device)

# Total number of scalar parameters across all tensors.
total_params = sum(map(torch.numel, model.parameters()))
print(f'Model created with {total_params:,} parameters')

Building Transformer model...
Model created with 87,728,496 parameters


## 6. Create Data Loaders

In [6]:
# Training batches are reshuffled every epoch. Validation keeps a fixed order and one
# sentence per batch, because validate() asserts a batch size of 1.
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=config['batch_size'])
val_dataloader = DataLoader(val_ds, shuffle=False, batch_size=1)

print(f'Train batches: {len(train_dataloader)}')
print(f'Validation batches: {len(val_dataloader)}')

Train batches: 1819
Validation batches: 3234


In [7]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a Transformer sized from ``config`` (``seq_len`` for both sides, and ``d_model``)."""
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, config['d_model'])


In [8]:

def get_ds(config):
    """Load OPUS Books, build/load tokenizers, and return train/validation DataLoaders.

    The 90/10 split uses a generator seeded from ``config.get('split_seed', 42)``. Every call
    therefore produces the same partition. This matters when training is resumed with
    ``config['preload']``: an unseeded split would re-partition the data and leak previously
    held-out sentences into training.

    Returns:
        ``(train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer)``. The validation
        loader uses batch size 1, as required by ``validate``.

    Raises:
        ValueError: if ``config['seq_len']`` cannot hold the longest sentence plus its
            special tokens. This fails fast instead of mid-epoch.
    """
    ds_raw = load_dataset('opus_books', f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')

    src_tokenizer = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tgt_tokenizer = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    train_ds_size = int(len(ds_raw) * 0.9)
    val_ds_size = len(ds_raw) - train_ds_size
    split_generator = torch.Generator().manual_seed(config.get('split_seed', 42))
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], generator=split_generator)

    train_ds = BilingualDataset(train_ds_raw, src_tokenizer, tgt_tokenizer, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, src_tokenizer, tgt_tokenizer, config['lang_src'], config['lang_tgt'], config['seq_len'])

    max_len_src, max_len_tgt = 0, 0
    for item in ds_raw:
        src_ids = src_tokenizer.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tgt_tokenizer.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    # The source needs [SOS]+[EOS]; decoder input and labels each need one special token.
    required_len = max(max_len_src + 2, max_len_tgt + 1)
    if required_len > config['seq_len']:
        raise ValueError(f"seq_len={config['seq_len']} is too small: need at least {required_len} tokens")

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer


## 7. Implement Greedy Decoding for Validation

In [9]:
import torchmetrics
def greedy_decode(model, src, src_mask, tgt_tokenizer, max_len, device):
    """Translate one source sentence by greedy decoding.

    Starting from [SOS], the function decodes the current prefix, appends the highest-scoring
    next token, and repeats. It stops once [EOS] is produced or the output holds ``max_len``
    tokens.

    Args:
        model: object with ``encode``, ``decode`` and ``project`` methods.
        src: source token ids. The (1, 1) start prefix assumes a batch of one sentence.
        src_mask: encoder mask for ``src``.
        tgt_tokenizer: used only to look up the [SOS] and [EOS] ids.
        max_len: maximum output length, counting the leading [SOS].
        device: device for the generated tokens and the causal mask.

    Returns:
        1-D tensor of token ids. It starts with [SOS] and ends with [EOS] if one was produced.
    """
    sos_idx = tgt_tokenizer.token_to_id('[SOS]')
    eos_idx = tgt_tokenizer.token_to_id('[EOS]')

    # The source never changes between steps, so encode it exactly once.
    encoder_output = model.encode(src, src_mask)
    generated = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(device)

    while generated.size(1) != max_len:
        step_mask = causal_mask(generated.size(1)).type_as(src_mask).to(device)
        decoded = model.decode(generated, encoder_output, src_mask, step_mask)
        # Score only the newest position.
        scores = model.project(decoded[:, -1])
        _, next_word = torch.max(scores, dim=1)
        next_token = torch.empty(1, 1).type_as(src).fill_(next_word.item()).to(device)
        generated = torch.cat([generated, next_token], dim=1)
        if next_word == eos_idx:
            break

    return generated.squeeze(0)


def _log_text_metrics(writer, predicted, expected, global_step):
    """Log character error rate, word error rate and BLEU of ``predicted`` vs ``expected``."""
    metrics = (
        ('validation cer', torchmetrics.CharErrorRate(), expected),
        ('validation wer', torchmetrics.WordErrorRate(), expected),
        # BLEUScore's documented target format is one list of reference translations per
        # prediction. Passing bare strings relies on version-specific coercion.
        ('validation BLEU', torchmetrics.BLEUScore(), [[reference] for reference in expected]),
    )
    for tag, metric, target in metrics:
        writer.add_scalar(tag, metric(predicted, target), global_step)
    writer.flush()


def validate(model, val_dataloader, tgt_tokenizer, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode validation sentences, print them, and optionally log text metrics.

    Args:
        model: the Transformer. It is switched to eval mode and not switched back.
        val_dataloader: must yield batches of size 1 (asserted).
        tgt_tokenizer: used for the special-token ids and to decode predictions.
        max_len: maximum generated length passed to ``greedy_decode``.
        device: device the inputs are moved to.
        print_msg: callable used for every output line.
        global_step: x-axis value for the logged scalars.
        writer: TensorBoard ``SummaryWriter``, or ``None`` to skip metric logging.
        num_examples: stop after this many sentences. A non-positive value never matches,
            so the whole loader is used.
    """
    model.eval()
    source_texts, expected, predicted = [], [], []
    console_width = 80

    with torch.no_grad():
        for count, batch in enumerate(val_dataloader, start=1):
            enc_input = batch['encoder_input'].to(device)  # (1, seq_len)
            enc_mask = batch['encoder_mask'].to(device)    # (1, 1, 1, seq_len)

            assert enc_input.size(0) == 1, "Validation batch size must be 1"

            model_out = greedy_decode(model, enc_input, enc_mask, tgt_tokenizer, max_len, device)  # (generated_len,)

            src_text = batch['src_text'][0]
            tgt_text = batch['tgt_text'][0]
            model_out_text = tgt_tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(src_text)
            expected.append(tgt_text)
            predicted.append(model_out_text)

            print_msg('-' * console_width)
            print_msg(f'{"SOURCE: ":>12s}{src_text}')
            print_msg(f'{"TARGET: ":>12s}{tgt_text}')
            print_msg(f'{"PREDICTED: ":>12s}{model_out_text}')

            if count == num_examples:
                break

    # Skip metrics for an empty loader: there is nothing meaningful to score.
    if writer is not None and predicted:
        _log_text_metrics(writer, predicted, expected, global_step)


print('Validation functions defined')

Validation functions defined


## 8. Train the Model

In [None]:

def train_model(config):
    """Train the Transformer described by ``config``, validating and checkpointing each epoch.

    Resuming: if ``config['preload']`` is set (an epoch string such as '05'), the matching
    checkpoint is read from ``get_weights_file_path``. The model weights, optimizer state,
    epoch counter and global step are all restored before training continues.

    Each epoch ends with a short greedy-decoding validation pass and a checkpoint containing
    ``epoch``, ``model_state_dict``, ``optimizer_state_dict`` and ``global_step``.

    Returns:
        The trained model. Callers that ignore the result are unaffected.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, src_tokenizer, tgt_tokenizer = get_ds(config)
    tgt_vocab_size = tgt_tokenizer.get_vocab_size()
    model = get_model(config, src_tokenizer.get_vocab_size(), tgt_vocab_size).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0

    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f'Preloading model {model_filename}')
        # map_location lets a GPU-saved checkpoint load on a CPU runtime, and vice versa.
        state = torch.load(model_filename, map_location=device)
        # The weights must be restored too. Loading only the optimizer state would resume
        # from a freshly initialised model.
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Labels are target-vocabulary ids, so the padding id to ignore comes from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    writer = SummaryWriter(config['experiment_name'])
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f'Processing Epoch {epoch:02d}')

            for batch in batch_iterator:
                enc_input = batch['encoder_input'].to(device)  # (B, seq_len)
                dec_input = batch['decoder_input'].to(device)  # (B, seq_len)
                enc_mask = batch['encoder_mask'].to(device)    # (B, 1, 1, seq_len)
                dec_mask = batch['decoder_mask'].to(device)    # (B, 1, seq_len, seq_len)
                labels = batch['labels'].to(device)            # (B, seq_len)

                enc_output = model.encode(enc_input, enc_mask)                        # (B, seq_len, d_model)
                dec_output = model.decode(dec_input, enc_output, enc_mask, dec_mask)  # (B, seq_len, d_model)
                projected_output = model.project(dec_output)                          # (B, seq_len, tgt_vocab)

                # (B, seq_len, tgt_vocab) -> (B * seq_len, tgt_vocab), scored against (B * seq_len,)
                loss = loss_fn(projected_output.view(-1, tgt_vocab_size), labels.view(-1))

                loss_value = loss.item()  # one host sync per step instead of two
                batch_iterator.set_postfix({'loss': f'{loss_value: 6.3f}'})
                writer.add_scalar('train loss', loss_value, global_step)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Flush once per epoch; flushing after every step is needless disk I/O.
            writer.flush()

            # Run validation at the end of each epoch
            validate(model, val_dataloader, tgt_tokenizer, config['seq_len'], device,
                     lambda msg: batch_iterator.write(msg), global_step, writer)

            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Release the event file even if training is interrupted.
        writer.close()

    return model


warnings.simplefilter("ignore")  # suppress all warnings for the long training run
train_model(config)
print('Training completed!')

Using device: cuda
Loading tokenizer for en from tokenizer_en.json
Loading tokenizer for it from tokenizer_it.json
Max length of source sentence: 309
Max length of target sentence: 274


## 9. Test Inference on Validation Examples

In [None]:
# Load the most recent checkpoint (as chosen by latest_weights_file_path). This is the
# latest epoch, not necessarily the best one.
model_filename = latest_weights_file_path(config)
if model_filename is None:
    raise FileNotFoundError(f"No checkpoint found in {config['model_folder']}; run the training cell first.")
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])

# NOTE(review): val_dataloader is built from the section-4 split, while train_model re-splits
# inside get_ds. Unless both splits use the same seed, some of these sentences may have been
# seen during training.
print('Testing model on validation set...')
validate(model, val_dataloader, tgt_tokenizer, config['seq_len'], device, print, 0, None, num_examples=10)
