In [23]:
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchmetrics

from datasets import load_dataset, get_dataset_split_names
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path
import numpy as np
from tqdm import tqdm
import warnings
import math
import os

warnings.filterwarnings("ignore")

In [24]:
class InputEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(d_module)``."""

    def __init__(self, d_module: int, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_module = d_module
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_module)

    def forward(self, x):
        """Look up the rows for the ids in ``x`` and apply the sqrt(d_module) scale."""
        scale = math.sqrt(self.d_module)
        vectors = self.embedding(x)
        return vectors * scale

In [25]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to a batch of embeddings.

    A ``(1, seq_length, d_model)`` table is precomputed and stored as the
    non-trainable buffer ``pe``; ``forward`` adds its first ``x.size(1)``
    positions to ``x`` and then applies dropout.

    Args:
        d_model: Width of each embedding vector. Odd values are supported.
        seq_length: Number of positions in the precomputed table.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model: int, seq_length: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)  # (seq_length, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (seq_length, ceil(d_model / 2))

        pe = torch.zeros(seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd columns; for an odd d_model that is one
        # fewer than the even columns, so the last frequency is dropped for cos.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis so the table broadcasts over the batch in forward().
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``."""
        # `pe` is a buffer, not a parameter, so it never takes part in autograd.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [26]:
class LayerNormalization(nn.Module):
    """Normalises the last axis with a single learnable scale and a single learnable shift.

    ``alpha`` and ``beta`` are one-element parameters shared by every feature.
    """

    def __init__(self, eps: float = 10**-6):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + beta`` over the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        denominator = x.std(dim=-1, keepdim=True) + self.eps
        # Scale first, then divide: same evaluation order as the plain formula.
        scaled = self.alpha * centered
        return scaled / denominator + self.beta

In [27]:
class FeedForwardBlock(nn.Module):
    """Two linear layers (d_model -> d_ff -> d_model) with ReLU and dropout in between."""

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.layer1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.layer2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Apply layer1 -> ReLU -> dropout -> layer2."""
        hidden = self.dropout(self.relu(self.layer1(x)))
        return self.layer2(hidden)

In [28]:
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    ``d_model`` is split evenly across ``h`` heads, each ``d_k = d_model // h``
    wide. The attention weights of the most recent call are kept on
    ``self.attention_scores``.
    """

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, d_k, mask, dropout: nn.Dropout=None):
        """Return ``(weighted_values, attention_weights)``.

        The ``d_k`` argument is ignored: the scaling factor is always taken from
        the last dimension of ``query``. Positions where ``mask == 0`` are set to
        -1e9 before the softmax.
        """
        d_k = query.shape[-1]

        # (batch_size, h, seq_length, d_k) --> (batch_size, h, seq_length, seq_length)
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = torch.softmax(scores, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        # (batch_size, h, seq_length, seq_length) --> (batch_size, h, seq_length, d_k)
        return weights @ value, weights

    def _split_heads(self, projected):
        """(batch_size, seq_length, d_model) --> (batch_size, h, seq_length, d_k)."""
        batch_size, seq_length = projected.shape[0], projected.shape[1]
        return projected.view(batch_size, seq_length, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        """Project q/k/v, attend per head, merge the heads and apply ``W_O``."""
        query = self._split_heads(self.W_Q(q))
        key = self._split_heads(self.W_K(k))
        value = self._split_heads(self.W_V(v))

        per_head, self.attention_scores = self.attention(query, key, value, self.d_k, mask, self.dropout)

        # (batch_size, h, seq_length, d_k) --> (batch_size, seq_length, h * d_k == d_model)
        batch_size = per_head.shape[0]
        merged = per_head.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.W_O(merged)



In [29]:
class ResidualConnection(nn.Module):
    """Pre-norm skip connection: ``x + dropout(sublayer(layer_norm(x)))``."""

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = LayerNormalization()

    def forward(self, x, sublayer):
        """Run ``sublayer`` on the normalised input and add the result back onto ``x``."""
        normalized = self.layer_norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch

In [30]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection1 = ResidualConnection(dropout)
        self.residual_connection2 = ResidualConnection(dropout)

    def forward(self, x, src_mask):
        """Apply masked self-attention followed by the feed-forward block."""

        def self_attend(normed):
            # Query, key and value all come from the same (normalised) input.
            return self.self_attention_block(normed, normed, normed, src_mask)

        attended = self.residual_connection1(x, self_attend)
        return self.residual_connection2(attended, self.feed_forward_block)


In [31]:
class Encoder(nn.Module):
    """Runs the input through every layer in ``layers`` and normalises the result."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, src_mask):
        """Feed ``x`` through each layer in order, passing ``src_mask`` to all of them."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.norm(hidden)

In [32]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, feed-forward.

    Each of the three sub-layers has its own residual connection, stored in
    ``residual_connection1`` in that order.
    """

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock,
                 feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection1 = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply the three sub-layers in sequence and return the updated decoder state."""
        self_residual, cross_residual, ff_residual = self.residual_connection1

        def self_attend(normed):
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            # Queries come from the decoder, keys and values from the encoder output.
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return ff_residual(x, self.feed_forward_block)

In [33]:
class Decoder(nn.Module):
    """Runs the target through every layer in ``layers`` and normalises the result."""

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Feed ``x`` through each layer in order with the same encoder output and masks."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [34]:
class ProjectionLayer(nn.Module):
    """Linear map from ``d_model`` to ``vocab_size`` followed by a log-softmax."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """(batch_size, seq_length, d_model) --> (batch_size, seq_length, vocab_size) log-probabilities."""
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

In [35]:
class Transformer(nn.Module):
    """Bundles the embeddings, positional encodings, encoder, decoder and projection.

    There is no ``forward``; callers invoke ``encode``, ``decode`` and
    ``project`` explicitly.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, tgt_embed: InputEmbedding,
                 src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed ``src``, add positions and run the encoder stack."""
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        """Embed ``tgt``, add positions and run the decoder stack against ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Apply the projection layer to ``x``."""
        return self.projection_layer(x)

In [36]:
def build_transformer(vocab_size: int, src_seq_len: int, tgt_seq_len: int,
                      d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
    """Assemble a Transformer with ``N`` encoder and ``N`` decoder blocks.

    Source and target use separate embeddings over the same ``vocab_size``.
    Every parameter with more than one dimension is re-initialised with Xavier
    uniform before the model is returned.
    """

    def new_attention():
        return MultiHeadAttentionBlock(d_model, h, dropout)

    def new_feed_forward():
        return FeedForwardBlock(d_model, d_ff, dropout)

    # Modules are created in a fixed order so that seeded initialisation is stable.
    src_embed = InputEmbedding(d_model, vocab_size)
    tgt_embed = InputEmbedding(d_model, vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder_blocks = [
        EncoderBlock(new_attention(), new_feed_forward(), dropout)
        for _ in range(N)
    ]
    # Argument order: self-attention, cross-attention, feed-forward.
    decoder_blocks = [
        DecoderBlock(new_attention(), new_attention(), new_feed_forward(), dropout)
        for _ in range(N)
    ]

    encoder = Encoder(nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    projection_layer = ProjectionLayer(d_model, vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Biases and the one-element layer-norm parameters keep their default values.
    for weight in (p for p in transformer.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return transformer

## Training

In [37]:
class Eli5Dataset(Dataset):
    """Wraps ELI5 rows as fixed-length encoder/decoder token tensors.

    Each row is read as ``row["title"]`` (source text) and
    ``row["answers"]["text"]`` (list of candidate target texts). One answer is
    drawn at random on every access, so repeated reads of the same index may
    return different targets.

    Args:
        ds: Indexable collection of rows with the structure above.
        tokenizer: Object providing ``token_to_id(str)`` and ``encode(str).ids``.
        src_seq_len: Fixed length of ``encoder_input`` (includes [SOS] and [EOS]).
        tgt_seq_len: Fixed length of ``decoder_input`` and ``label``.
    """

    def __init__(self, ds, tokenizer, src_seq_len, tgt_seq_len):
        super().__init__()
        self.src_seq_len = src_seq_len
        self.tgt_seq_len = tgt_seq_len

        self.ds = ds
        self.tokenizer = tokenizer

        self.sos_token = torch.tensor([tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def _choose_answer(self, answers):
        """Return ``(answer_text, token_ids)`` for one entry of ``answers``.

        A random answer is tried first. If it does not fit into
        ``tgt_seq_len - 1`` tokens, the answer with the FEWEST tokens is used
        instead so that as little as possible has to be truncated.
        """
        # torch's generator is used instead of NumPy's; DataLoader seeds torch
        # separately in each worker process.
        choice = int(torch.randint(len(answers), (1,)).item())
        token_ids = self.tokenizer.encode(answers[choice]).ids

        if len(token_ids) > self.tgt_seq_len - 1:
            all_ids = [self.tokenizer.encode(text).ids for text in answers]
            # min() returns the first shortest entry on ties.
            choice = min(range(len(all_ids)), key=lambda i: len(all_ids[i]))
            token_ids = all_ids[choice]

        return answers[choice], token_ids

    def _padding(self, count):
        """Return ``count`` [PAD] ids as an int64 tensor (empty when ``count <= 0``)."""
        return torch.full((max(count, 0),), self.pad_token.item(), dtype=torch.int64)

    def __getitem__(self, idx):
        """Build the model inputs for row ``idx``.

        Returns:
            dict with ``encoder_input`` (src_seq_len), ``decoder_input`` and
            ``label`` (tgt_seq_len), ``encoder_mask`` (1, 1, src_seq_len),
            ``decoder_mask`` (1, tgt_seq_len, tgt_seq_len), plus the raw
            ``src_text`` and the (word-truncated) ``tgt_text``.

        Raises:
            ValueError: If the row has no answers.
        """
        row = self.ds[idx]  # fetch once; both fields come from the same row
        title = row["title"]
        answers = row["answers"]["text"]
        if len(answers) == 0:
            raise ValueError(f"Example {idx} has no answers to sample from")

        answer, dec_input_tokens = self._choose_answer(answers)
        enc_input_tokens = self.tokenizer.encode(title).ids

        # Truncate so that [SOS] + tokens + [EOS] (source) and [SOS] + tokens /
        # tokens + [EOS] (target) fit; the slices are no-ops for short inputs.
        enc_input_tokens = enc_input_tokens[:self.src_seq_len - 2]
        dec_input_tokens = dec_input_tokens[:self.tgt_seq_len - 1]

        enc_num_padding_tokens = self.src_seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.tgt_seq_len - len(dec_input_tokens) - 1

        enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)

        # [SOS] tokens [EOS] [PAD]...
        encoder_input = torch.cat(
            [self.sos_token, enc_ids, self.eos_token, self._padding(enc_num_padding_tokens)],
            dim=0,
        )
        # [SOS] tokens [PAD]...
        decoder_input = torch.cat(
            [self.sos_token, dec_ids, self._padding(dec_num_padding_tokens)],
            dim=0,
        )
        # tokens [EOS] [PAD]...  (decoder_input shifted left by one)
        label = torch.cat(
            [dec_ids, self.eos_token, self._padding(dec_num_padding_tokens)],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.src_seq_len
        assert decoder_input.size(0) == self.tgt_seq_len
        assert label.size(0) == self.tgt_seq_len

        # Keep the reference text in the same unit as the check: whole words.
        words = answer.split()
        if len(words) <= self.tgt_seq_len:
            tgt_text = answer
        else:
            tgt_text = " ".join(words[:self.tgt_seq_len])

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len)
            "src_text": title,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    """Return a boolean ``(1, size, size)`` mask that is True on and below the diagonal."""
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

In [38]:
def get_all_sentences(ds):
    """Yield every title and every answer text in ``ds``, with a tqdm progress bar over the rows."""
    for row in tqdm(ds):
        yield row['title']
        yield from row["answers"]['text']

def get_or_build_tokenizer(config, ds):
    """Load the tokenizer from ``config['tokenizer_file']``, or train and save it if the file is missing."""
    tokenizer_path = Path(config['tokenizer_file'])

    if tokenizer_path.exists():
        print(f"Loading tokenizer from file {tokenizer_path}")
        return Tokenizer.from_file(str(tokenizer_path))

    # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds_and_tokenizer(config):
    """Load the ELI5 splits, get the tokenizer and wrap everything in DataLoaders.

    Args:
        config: Dict providing ``tokenizer_file``, ``src_seq_len``,
            ``tgt_seq_len`` and ``batch_size``. The optional key
            ``check_max_len`` (default ``False``) enables a full pass over the
            training split that reports the longest tokenised title/answer and
            warns when they exceed the configured sequence lengths.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader, tokenizer)``.
        Validation and test loaders use ``batch_size=1``.
    """
    print('Loading dataset')
    ds_train_raw = load_dataset("eli5", split="train_eli5")
    ds_val_raw = load_dataset("eli5", split="validation_eli5")
    ds_test_raw = load_dataset("eli5", split="test_eli5")
    print('Dataset loaded')

    print('Building tokenizer')
    tokenizer = get_or_build_tokenizer(config, ds_train_raw)
    print('Tokenizer built')

    # The scan tokenises the whole training split, so it is opt-in. Without it
    # there is nothing to compare against and no warning is printed.
    if config.get('check_max_len', False):
        max_len_src = 0
        max_len_tgt = 0

        print('Finding max length of source and target sentences')
        for item in tqdm(ds_train_raw):
            src_ids = tokenizer.encode(item['title']).ids
            max_len_src = max(max_len_src, len(src_ids))

            for sentence in item["answers"]["text"]:
                tgt_ids = tokenizer.encode(sentence).ids
                max_len_tgt = max(max_len_tgt, len(tgt_ids))

        print(f'Max length of source sentence: {max_len_src}')
        print(f'Max length of target sentence: {max_len_tgt}')

        if max_len_src > config['src_seq_len']:
            print(f'Warning: max length of source sentence is greater than the config value of {config["src_seq_len"]}')
        if max_len_tgt > config['tgt_seq_len']:
            print(f'Warning: max length of target sentence is greater than the config value of {config["tgt_seq_len"]}')

    train_ds = Eli5Dataset(ds_train_raw, tokenizer, config['src_seq_len'], config['tgt_seq_len'])
    val_ds = Eli5Dataset(ds_val_raw, tokenizer, config['src_seq_len'], config['tgt_seq_len'])
    test_ds = Eli5Dataset(ds_test_raw, tokenizer, config['src_seq_len'], config['tgt_seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=2)
    # batch_size=1: run_validation decodes one example at a time.
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True, num_workers=2)
    test_dataloader = DataLoader(test_ds, batch_size=1, shuffle=True, num_workers=2)

    return train_dataloader, val_dataloader, test_dataloader, tokenizer

In [40]:
def get_weights_file_path(config, epoch: str):
    """Return ``<datasource>_<model_folder>/<model_basename><epoch>.pt`` as a string."""
    folder = Path('.') / f"{config['datasource']}_{config['model_folder']}"
    checkpoint = folder / f"{config['model_basename']}{epoch}.pt"
    return str(checkpoint)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the last checkpoint path in sorted order, or ``None`` if there is none.

    Files are matched with ``<model_basename>*`` inside
    ``<datasource>_<model_folder>``; ordering is the default ordering of the
    resulting paths.
    """
    folder = Path(f"{config['datasource']}_{config['model_folder']}")
    candidates = sorted(folder.glob(f"{config['model_basename']}*"))
    return str(candidates[-1]) if candidates else None

In [39]:
def load_prev_state(config, model, optimizer, initial_epoch, global_step):
    """Optionally restore model/optimizer state from a saved checkpoint.

    ``config['preload']`` selects the checkpoint: ``'latest'`` picks the result
    of ``latest_weights_file_path``, any other truthy value is treated as an
    epoch suffix for ``get_weights_file_path``, and a falsy value disables
    preloading.

    Returns:
        ``(model, optimizer, initial_epoch, global_step)``. When a checkpoint is
        loaded, ``initial_epoch`` is the saved epoch + 1 and ``global_step`` is
        the saved step; otherwise the arguments are returned unchanged.
    """
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if not model_filename:
        print('No model to preload, starting from scratch')
        return model, optimizer, initial_epoch, global_step

    print(f'Preloading model {model_filename}')
    # Map the stored tensors onto the device the model currently lives on, so a
    # checkpoint written on a GPU can still be resumed on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    state = torch.load(model_filename, map_location=map_location)

    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    initial_epoch = state['epoch'] + 1
    global_step = state['global_step']
    return model, optimizer, initial_epoch, global_step

In [41]:
def greedy_decode(model, source, source_mask, tokenizer, max_dec_seq_len, device):
    """Generate token ids one at a time, always taking the highest-scoring next token.

    Decoding starts from [SOS] and stops once [EOS] has been appended or the
    sequence holds ``max_dec_seq_len`` tokens. Returns the generated ids with
    the batch axis removed.
    """
    sos_idx = tokenizer.token_to_id('[SOS]')
    eos_idx = tokenizer.token_to_id('[EOS]')

    def as_column(token_id):
        # (1, 1) tensor holding token_id, with source's dtype, placed on `device`.
        return torch.empty(1, 1).type_as(source).fill_(token_id).to(device)

    # The encoder runs once; its output is reused for every decoding step.
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while decoder_input.size(1) != max_dec_seq_len:
        # Mask sized to the tokens generated so far.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Only the last position is projected to choose the next token.
        prob = model.project(out[:, -1])
        next_word = torch.max(prob, dim=1).indices
        decoder_input = torch.cat([decoder_input, as_column(next_word.item())], dim=1)

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

def run_validation(model, validation_ds, tokenizer, max_dec_seq_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation batches and report them.

    The model is switched to eval mode. Source, target and prediction of each
    example are sent to ``print_msg``. If ``writer`` is truthy, character error
    rate, word error rate and BLEU over the collected examples are logged at
    ``global_step``.
    """
    model.eval()

    console_width = 20
    separator = '-' * console_width

    source_texts, expected, predicted = [], [], []

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            print(f"Validation example: {count}")
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # greedy_decode handles a single sequence at a time
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer, max_dec_seq_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Show the example: source, reference and model output
            print_msg(separator)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg(separator)
                break

    if writer:
        # (scalar tag, torchmetrics class name), logged in this order
        metric_specs = (
            ('validation cer', 'CharErrorRate'),
            ('validation wer', 'WordErrorRate'),
            ('validation BLEU', 'BLEUScore'),
        )
        for tag, metric_name in metric_specs:
            metric = getattr(torchmetrics, metric_name)()
            score = metric(predicted, expected)
            writer.add_scalar(tag, score, global_step)
            writer.flush()

In [42]:
# import torchmetrics
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train_model(config, model, optimizer, loss_fn, train_dataloader, val_dataloader, test_loader, tokenizer, device):
    """Train ``model`` for ``config['num_epochs']`` epochs, validating and checkpointing after each.

    Training resumes from a checkpoint when ``load_prev_state`` finds one.
    The training loss is logged to TensorBoard under
    ``config['experiment_name']`` at every step. ``test_loader`` is accepted
    but not used here.

    The TensorBoard writer is closed when the function exits, including when
    training is interrupted (e.g. by ``KeyboardInterrupt``) or fails.
    """
    model, optimizer, initial_epoch, global_step = load_prev_state(config, model, optimizer, 0, 0)

    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            torch.cuda.empty_cache()
            # run_validation leaves the model in eval mode, so switch back every epoch.
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

            for batch in batch_iterator:

                encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
                decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

                # Run the tensors through the encoder, decoder and the projection layer
                encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
                proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

                # Compare the output with the label
                label = batch['label'].to(device) # (B, seq_len)

                # Flatten to (B * seq_len, vocab_size) vs (B * seq_len) for the loss
                loss = loss_fn(proj_output.view(-1, tokenizer.get_vocab_size()), label.view(-1))
                batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

                # Log the loss
                writer.add_scalar('train loss', loss.item(), global_step)
                writer.flush()

                # Backpropagate the loss
                loss.backward()

                # Update the weights
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1

            # Run validation at the end of every epoch
            run_validation(model, val_dataloader, tokenizer, config['tgt_seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer, 2)

            # Save the model at the end of every epoch
            model_filename = get_weights_file_path(config, f"{epoch:02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Release the event-file handle even if training stops early.
        writer.close()

In [43]:
# Hyper-parameters, file locations and resume behaviour for this run.
config = dict(
    # optimisation
    batch_size=16,
    num_epochs=8,
    lr=10**-4,
    # model / sequence sizes
    src_seq_len=128,
    tgt_seq_len=256,
    d_model=512,
    # data source and language tags
    datasource='eli5',
    lang_src="en",
    lang_tgt="it",
    # checkpoints are written to "<datasource>_<model_folder>/<model_basename><epoch>.pt"
    model_folder="weights",
    model_basename="tmodel_",
    # "latest", an epoch suffix, or a falsy value to start from scratch
    preload="latest",
    tokenizer_file="tokenizer.json",
    # TensorBoard log directory
    experiment_name="runs/tmodel",
)

In [None]:
# Build the torch.device up front so the CUDA queries below receive a real
# device object rather than an attribute looked up on a plain string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == 'cuda':
    print(f"Device name: {torch.cuda.get_device_name(device)}")
    print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")

# Checkpoint folder must exist before train_model saves into it.
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)
train_dataloader, val_dataloader, test_loader, tokenizer = get_ds_and_tokenizer(config)

model = build_transformer(tokenizer.get_vocab_size(), config["src_seq_len"], config['tgt_seq_len'], d_model=config['d_model']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
# [PAD] positions in the label do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

In [44]:
# Start (or resume) training with the objects built in the setup cell.
train_model(
    config=config,
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_loader=test_loader,
    tokenizer=tokenizer,
    device=device,
)

Using device: cuda
Device name: NVIDIA GeForce RTX 3090
Device memory: 23.68316650390625 GB
Loading dataset


Dataset loaded
Building tokenizer
Loading tokenizer from file tokenizer.json
Tokenizer built
No model to preload, starting from scratch


Processing Epoch 00: 100%|██████████| 17040/17040 [39:18<00:00,  7.23it/s, loss=5.312]


Validation example: 1
--------------------
    SOURCE: My Daughter asked me what condensation is, I tried I explaining but she didn't understand. I need a genuine ELI5 on what condensation is.
    TARGET: You can find water in the air. It's called humidity or moisture.

**"BUT WATER IS A LIQUID!!"** Kinda, water prefers to be in the liquid state but a small percentage goes to the air, this percentage depends mostly on the temperature. Hot environment can hold more water in the air and Cold ones don't hold up to as much water. 

**Are you serious?! This happens?**" Yes. Actually how to clothes dry up in the sun? The shirts don't heat up to 100ºC / 212ºF. The sun just heats up the air around the shirt and the water mixes up with the air. Removing it from the shirt to the air.

**WHAT HAS THIS TO DO WITH CONDENSATION?!"** Well, hot environment hold more water than cold ones. So what happens when he pass from hot air to cold air? Well, the cold air can't hold up as much water and the water

Processing Epoch 01:   5%|▍         | 800/17040 [01:51<37:34,  7.20it/s, loss=5.462]


KeyboardInterrupt: 