In [None]:
import torch
import torch.nn as nn
import math
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import warnings
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
import torchmetrics
import os

In [None]:
# Config
def get_config():
    """Return a fresh dictionary with the default training configuration.

    The dictionary holds optimisation settings (batch size, epochs, learning
    rate), model dimensions (sequence length, model width), the dataset name
    and language pair, and the naming scheme used for weights, tokenizer files
    and the TensorBoard experiment folder.
    """
    return dict(
        # Optimisation
        batch_size=8,
        num_epochs=20,
        lr=10**-4,
        # Model dimensions
        seq_len=350,
        d_model=512,
        # Data
        datasource='opus_books',
        lang_src="en",
        lang_tgt="it",
        # Checkpoint naming; "preload" selects which checkpoint to resume from
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        # "{0}" is filled with the language code
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )

def get_weights_file_path(config, epoch: str):
    """Build the relative path of the checkpoint file for a given epoch.

    The folder is ``<datasource>_<model_folder>`` and the file name is
    ``<model_basename><epoch>.pt``; the result is returned as a string.
    """
    folder_name = "{}_{}".format(config['datasource'], config['model_folder'])
    file_name = "{}{}.pt".format(config['model_basename'], epoch)
    return str(Path('.').joinpath(folder_name, file_name))

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the path of the most recent checkpoint, or ``None`` if there is none.

    Checkpoints are the files in ``<datasource>_<model_folder>`` whose name
    starts with ``model_basename``. Files whose suffix (the part of the stem
    after the basename) is a decimal number are ordered by that number, so
    epoch 100 ranks after epoch 99; a plain lexicographic sort would rank
    ``tmodel_100.pt`` before ``tmodel_99.pt``. Files with a non-numeric suffix
    rank before all numeric ones and are ordered by name among themselves.
    """
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    basename = config['model_basename']
    weights_files = list(Path(model_folder).glob(f"{basename}*"))
    if not weights_files:
        return None

    def _epoch_key(path):
        suffix = path.stem[len(basename):]
        if suffix.isdecimal():
            return (1, int(suffix), path.name)
        return (0, 0, path.name)

    return str(max(weights_files, key=_epoch_key))

In [None]:
# Dataset
class BilingualDataset(Dataset):
    """Map-style dataset that turns translation pairs into fixed-length tensors.

    Each item of ``ds`` is expected to be a mapping with a ``'translation'``
    entry keyed by language code. Source and target sentences are tokenized
    with their own tokenizer and padded to ``seq_len``.

    The encoder input and its padding mask use the special-token ids of the
    SOURCE tokenizer, while decoder input and label use those of the TARGET
    tokenizer, so the dataset stays correct even when the two vocabularies
    assign different ids to ``[SOS]``/``[EOS]``/``[PAD]``.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Target-side special tokens (attribute names kept for existing callers).
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token_id = tokenizer_tgt.token_to_id("[PAD]")

        # Source-side special tokens, used for the encoder input and its mask.
        self.src_sos_token = torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64)
        self.src_eos_token = torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64)
        self.src_pad_token_id = tokenizer_src.token_to_id("[PAD]")

    def __len__(self):
        """Number of sentence pairs in the underlying dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the tensors and raw texts for the pair at position ``idx``.

        Returns a dict with:
            encoder_input: (seq_len,)  [SOS] src tokens [EOS] [PAD]...
            decoder_input: (seq_len,)  [SOS] tgt tokens [PAD]...
            encoder_mask:  (1, 1, seq_len) int, 1 where the encoder input is not padding
            decoder_mask:  (1, seq_len, seq_len) int, padding mask AND causal mask
            label:         (seq_len,)  tgt tokens [EOS] [PAD]...
            src_text, tgt_text: the raw sentences

        Raises:
            ValueError: if either sentence does not fit into ``seq_len`` once
                its special tokens are added.
        """
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder sequence carries [SOS] and [EOS]; decoder input carries only
        # [SOS] and the label only [EOS], hence -2 versus -1.
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence length exceeds maximum sequence length")

        enc_padding = torch.tensor([self.src_pad_token_id] * enc_num_padding_tokens, dtype=torch.int64)
        dec_padding = torch.tensor([self.pad_token_id] * dec_num_padding_tokens, dtype=torch.int64)
        dec_tokens = torch.tensor(dec_input_tokens, dtype=torch.int64)

        encoder_input = torch.cat([
            self.src_sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            self.src_eos_token,
            enc_padding,
        ])
        decoder_input = torch.cat([self.sos_token, dec_tokens, dec_padding])
        label = torch.cat([dec_tokens, self.eos_token, dec_padding])

        assert encoder_input.shape[0] == self.seq_len
        assert decoder_input.shape[0] == self.seq_len
        assert label.shape[0] == self.seq_len

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.src_pad_token_id).unsqueeze(0).unsqueeze(0).int(),
            "decoder_mask": (decoder_input != self.pad_token_id).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
        }
    

def causal_mask(size):
    """Return a boolean (1, size, size) mask that is True on and below the diagonal.

    Position ``i`` may therefore attend to positions ``0..i`` only.
    """
    strictly_upper = torch.triu(torch.ones((1, size, size)), diagonal=1)
    # Entries above the diagonal are 1; everything else (0) is an allowed position.
    return strictly_upper.type(torch.int).eq(0)

In [None]:
# Model
class InputEmbedding(nn.Module):
    """Token embedding whose output is scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Look up the embeddings of token ids ``x`` and apply the scaling."""
        embedded = self.embedding(x)
        return embedded * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout."""

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Column vector of positions: (seq_len, 1)
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # One frequency per pair of feature dimensions: (d_model / 2,)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even feature indices receive the sine, odd indices the cosine.
        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Registered as a buffer with a leading batch axis: (1, seq_len, d_model)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.shape[1]`` positions to ``x``."""
        encodings = self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x + encodings)

    
class LayerNormalization(nn.Module):
    """Normalise the last dimension with one learnable scalar gain and one scalar bias."""

    def __init__(self, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))   # multiplicative gain
        self.bias = nn.Parameter(torch.zeros(1))   # additive shift (beta)

    def forward(self, x):
        """Return ``alpha * (x - mean) / (eps + std) + bias`` over the last axis."""
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        scaled = self.alpha * (x - mean)
        # eps is added to the standard deviation to avoid dividing by zero
        return scaled / (self.eps + std) + self.bias


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # W1 and B1
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)   # W2 and B2

    def forward(self, x):
        """Map (..., d_model) -> (..., d_ff) -> (..., d_model)."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an output projection.

    After every forward pass the post-softmax (and post-dropout) attention
    weights are stored on ``self.attention_score``.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = num_heads
        assert d_model % num_heads == 0 , "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads   # width of a single head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Compute scaled dot-product attention.

        Positions where ``mask == 0`` get a score of -1e9 before the softmax.
        Returns ``(weights @ value, weights)``.
        """
        head_dim = query.shape[-1]

        # (..., q_len, d_k) @ (..., d_k, k_len) -> (..., q_len, k_len)
        weights = (query @ key.transpose(-2, -1)) / math.sqrt(head_dim)
        if mask is not None:
            weights = weights.masked_fill(mask == 0, -1e9)
        weights = weights.softmax(dim=-1)

        if dropout is not None:
            weights = dropout(weights)

        return (weights @ value), weights

    def _split_heads(self, projected):
        """Reshape (B, seq_len, d_model) into (B, h, seq_len, d_k)."""
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask):
        """Project the inputs, attend per head, merge the heads and project again."""
        q_heads = self._split_heads(self.w_q(query))
        k_heads = self._split_heads(self.w_k(key))
        v_heads = self._split_heads(self.w_v(value))

        context, self.attention_score = MultiHeadAttention.attention(
            q_heads, k_heads, v_heads, mask, self.dropout
        )

        # (B, h, seq_len, d_k) -> (B, seq_len, h, d_k) -> (B, seq_len, d_model)
        merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.h * self.d_k)

        return self.w_o(merged)

    
class SublayerConnection(nn.Module):
    """Residual connection around a sub-layer that is applied to the normalised input."""

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``; ``sublayer`` is any callable."""
        normalised = self.norm(x)
        transformed = sublayer(normalised)
        return x + self.dropout(transformed)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by feed-forward, each inside a residual connection."""

    def __init__(self, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForward, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([SublayerConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        """Apply self-attention (masked with ``src_mask``) and then the feed-forward block."""
        attention_residual = self.residual_connections[0]
        feed_forward_residual = self.residual_connections[1]

        def attend(normalised):
            # query, key and value all come from the same (normalised) input
            return self.self_attention_block(normalised, normalised, normalised, src_mask)

        attended = attention_residual(x, attend)
        return feed_forward_residual(attended, self.feed_forward_block)

class Encoder(nn.Module):
    """Stack of encoder layers followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, src_mask):
        """Run ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return self.norm(hidden)
    

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, feed-forward."""

    def __init__(self, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForward, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([SublayerConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply the three sub-layers in order, each inside its own residual connection."""
        self_residual = self.residual_connections[0]
        cross_residual = self.residual_connections[1]
        feed_forward_residual = self.residual_connections[2]

        def attend_to_self(normalised):
            return self.self_attention_block(normalised, normalised, normalised, tgt_mask)

        def attend_to_encoder(normalised):
            # queries come from the decoder; keys and values from the encoder output
            return self.cross_attention_block(normalised, encoder_output, encoder_output, src_mask)

        x = self_residual(x, attend_to_self)
        x = cross_residual(x, attend_to_encoder)
        return feed_forward_residual(x, self.feed_forward_block)

class Decoder(nn.Module):
    """Stack of decoder layers followed by a final layer normalisation."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run ``x`` through every layer in order, then normalise the result."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from model width to vocabulary size, returning log-probabilities."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map (..., d_model) to (..., vocab_size) and apply log-softmax on the last axis."""
        logits = self.proj(x)
        return torch.log_softmax(logits, dim=-1)


class Transformer(nn.Module):
    """Encoder-decoder model assembled from pre-built components.

    The class exposes three separate steps (``encode``, ``decode`` and
    ``project``) rather than a single ``forward``.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection = projection_layer

    def encode(self, src, src_mask):
        """Embed the source ids, add positional encodings and run the encoder."""
        embedded = self.src_embed(src)
        positioned = self.src_pos(embedded)
        return self.encoder(positioned, src_mask)

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        """Embed the target ids, add positional encodings and run the decoder."""
        embedded = self.tgt_embed(tgt)
        positioned = self.tgt_pos(embedded)
        return self.decoder(positioned, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Apply the projection layer to decoder output ``x``."""
        return self.projection(x)


def build_transformer(src_vocab_size:int , tgt_vocab_size:int , src_seq_len:int , tgt_seq_len:int , d_model:int = 512 , N:int = 6 , H:int = 8 , dropout:float = 0.1 , d_ff:int = 2048) -> Transformer:
    """Assemble a Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab_size / tgt_vocab_size: vocabulary sizes of the two embeddings.
        src_seq_len / tgt_seq_len: lengths of the positional-encoding tables.
        d_model: model width.
        N: number of encoder layers and of decoder layers.
        H: number of attention heads.
        dropout: dropout probability used throughout.
        d_ff: hidden width of the feed-forward blocks.
    """

    def new_attention():
        return MultiHeadAttention(d_model, H, dropout)

    def new_feed_forward():
        return FeedForward(d_model, d_ff, dropout)

    # Embeddings and positional encodings
    src_embed = InputEmbedding(d_model, src_vocab_size)
    tgt_embed = InputEmbedding(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Layer stacks; arguments are evaluated left to right, so sub-modules are
    # created attention-first within each layer.
    encoder_blocks = [
        EncoderLayer(new_attention(), new_feed_forward(), dropout)
        for _ in range(N)
    ]
    decoder_blocks = [
        DecoderBlock(new_attention(), new_attention(), new_feed_forward(), dropout)
        for _ in range(N)
    ]

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    model = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier-uniform initialisation for every parameter with more than one dimension
    for weight in filter(lambda p: p.dim() > 1, model.parameters()):
        nn.init.xavier_uniform_(weight)

    return model

In [None]:
# Train
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every item's ``'translation'`` entry."""
    yield from (entry['translation'][lang] for entry in ds)


def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang`` from disk, training and saving it first if absent.

    The file location is ``config['tokenizer_file']`` formatted with ``lang``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # Reuse a previously saved tokenizer when one exists.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Otherwise train a whitespace-split word-level tokenizer on the dataset.
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=2,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer


def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily generate target token ids for a single source sequence.

    Starts from ``[SOS]`` and repeatedly appends the highest-scoring next
    token until ``[EOS]`` is produced or the sequence reaches ``max_len``
    tokens. Returns a 1-D tensor of token ids that includes the leading
    ``[SOS]``. ``tokenizer_src`` is accepted but not used.
    """
    start_id = tokenizer_tgt.token_to_id('[SOS]')
    stop_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output is computed once and reused at every step.
    memory = model.encode(source, source_mask)

    generated = torch.empty(1, 1).fill_(start_id).type_as(source).to(device)
    while generated.size(1) != max_len:
        # Causal mask covering the tokens generated so far
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)

        hidden = model.decode(generated, memory, source_mask, step_mask)

        # Only the last position is needed to choose the next token.
        scores = model.project(hidden[:, -1])
        next_word = torch.max(scores, dim=1).indices

        appended = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        generated = torch.cat([generated, appended], dim=1)

        if next_word == stop_id:
            break

    return generated.squeeze(0)


def get_ds(config):
    """Load the parallel corpus, build tokenizers and return the data loaders.

    The dataset name is read from ``config['datasource']`` (falling back to
    ``'opus_books'`` when the key is absent) so that it matches the name used
    for the weights folder. The first 90% of the rows form the training split
    and the remaining 10% the validation split.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt).
        The validation loader uses batch size 1.
    """
    dataset_name = config.get('datasource', 'opus_books')
    ds_raw = load_dataset(dataset_name, f'{config["lang_src"]}-{config["lang_tgt"]}', split='train')

    # Build (or load) one tokenizer per language
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Sequential 90/10 train/validation split
    train_ds_size = int(len(ds_raw) * 0.9)
    ds_train = ds_raw.select(range(0, train_ds_size))
    ds_val = ds_raw.select(range(train_ds_size, len(ds_raw)))

    train_ds = BilingualDataset(ds_train, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(ds_val, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Report the longest tokenized sentence on each side; these values are only
    # printed, as a guide for choosing config['seq_len'].
    max_len_src = 0
    max_len_tgt = 0
    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build a transformer sized by ``config`` for the given vocabulary sizes.

    ``config['seq_len']`` is used for both the source and the target sequence length.
    """
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config["seq_len"],
        config['seq_len'],
        d_model=config['d_model'],
    )

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode a few validation examples, print them and log text metrics.

    Args:
        model: the transformer; it is switched to eval mode and left there.
        validation_ds: iterable of batches with batch size 1.
        tokenizer_src, tokenizer_tgt: tokenizers, forwarded to ``greedy_decode``;
            the target tokenizer also decodes the prediction.
        max_len: maximum number of generated tokens.
        device: device the batch tensors are moved to.
        print_msg: callable used for all console output.
        global_step: step value attached to the logged scalars.
        writer: TensorBoard writer, or a falsy value to skip metric logging.
        num_examples: number of examples to decode before stopping.
    """
    # Local import so this block does not depend on the notebook's import cell.
    import shutil

    model.eval()

    source_texts = []
    expected = []
    predicted = []

    # Width of the separator line; falls back to 80 columns when the terminal
    # size cannot be determined.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # greedy_decode handles a single sequence at a time
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break

    # Metrics need at least one prediction; skip them for an empty loader.
    if writer and predicted:
        metrics = (
            ('validation cer', torchmetrics.CharErrorRate),   # character error rate
            ('validation wer', torchmetrics.WordErrorRate),   # word error rate
            ('validation BLEU', torchmetrics.BLEUScore),      # BLEU score
        )
        for tag, metric_cls in metrics:
            score = metric_cls()(predicted, expected)
            writer.add_scalar(tag, score, global_step)
            writer.flush()



def train_model(config):
    """Train the translation model described by ``config``.

    Steps: choose a device, build data loaders and model, optionally resume
    from a checkpoint (``config['preload']`` is ``'latest'``, an epoch string,
    or falsy for a fresh start), then for every epoch train over all batches,
    run validation and save a checkpoint holding the model state, optimizer
    state, epoch and global step.
    """
    # Select the device as a torch.device from the start so that its attributes
    # are real device attributes. MPS is used only when it reports itself as
    # available; the getattr guard covers torch builds without an MPS backend.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)
    if device.type == 'cuda':
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device be resumed on another
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # Padding positions in the label do not contribute to the loss
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            if device.type == 'cuda':
                torch.cuda.empty_cache()
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
            for batch in batch_iterator:

                encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
                decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

                # Run the tensors through the encoder, decoder and the projection layer
                encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
                decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask) # (B, seq_len, d_model)
                proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

                # Compare the output with the label
                label = batch['label'].to(device) # (B, seq_len)

                # Flatten batch and sequence axes before computing the cross entropy
                loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
                batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

                # Log the loss
                writer.add_scalar('train loss', loss.item(), global_step)
                writer.flush()

                # Backpropagate the loss
                loss.backward()

                # Update the weights
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Run validation at the end of every epoch
            run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            # Save the model at the end of every epoch
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Release the TensorBoard event file even if training is interrupted
        writer.close()

In [None]:
# Silence all warnings for the training run.
warnings.filterwarnings("ignore")

# Start from the defaults and override the settings for this run:
# larger batches, fewer epochs, and no checkpoint preloading.
config = get_config()
config.update({
    'batch_size': 16,
    'preload': None,
    'num_epochs': 10,
})

train_model(config)

In [None]:
def translate(sentence: str):
    """Translate ``sentence`` with the latest saved checkpoint and print the result.

    Uses the default configuration from ``get_config()``, the tokenizer files
    saved on disk and the most recent weights file. The source is encoded as a
    batch of one sequence and decoded with ``greedy_decode``, the same routine
    used during validation.

    Raises:
        FileNotFoundError: if no weights file is found.
        ValueError: if the tokenized sentence does not fit into ``seq_len``.
    """
    # MPS is used only when it reports itself as available; the getattr guard
    # covers torch builds without an MPS backend.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    config = get_config()
    seq_len = config['seq_len']
    tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))

    model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), seq_len, seq_len, d_model=config['d_model']).to(device)

    # Load the latest weights
    model_filename = latest_weights_file_path(config)
    if model_filename is None:
        raise FileNotFoundError("No weights file found; train the model before calling translate()")
    # map_location lets weights saved on one device be loaded on another
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    model.eval()
    with torch.no_grad():
        source_ids = tokenizer_src.encode(sentence).ids
        num_padding = seq_len - len(source_ids) - 2   # room for [SOS] and [EOS]
        if num_padding < 0:
            raise ValueError("Sentence length exceeds maximum sequence length")

        pad_id = tokenizer_src.token_to_id('[PAD]')
        token_ids = (
            [tokenizer_src.token_to_id('[SOS]')]
            + source_ids
            + [tokenizer_src.token_to_id('[EOS]')]
            + [pad_id] * num_padding
        )
        # Explicit batch axis: (1, seq_len). Without it the positional encoding
        # only lines up by broadcasting when d_model >= seq_len.
        source = torch.tensor(token_ids, dtype=torch.int64, device=device).unsqueeze(0)
        # (1, 1, 1, seq_len), the layout used for batched encoder masks
        source_mask = (source != pad_id).unsqueeze(1).unsqueeze(1).int()

        print(f"Translating: {sentence}")

        decoded = greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, seq_len, device)

        translation = tokenizer_tgt.decode(decoded.detach().cpu().numpy())
        print(f"Translation: {translation}")

# Smoke test: translate one English sentence. translate() loads the tokenizer
# files and the latest weights file from disk, so those files must already exist.
translate("I am learning to code.")