![alt text](https://miro.medium.com/v2/resize:fit:856/1*ZCFSvkKtppgew3cc7BIaug.png)

![](https://miro.medium.com/v2/resize:fit:786/format:webp/1*LpDpZojgoKTPBBt8wdC4nQ.png)

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F

## Embedding and Positional Encoding

In [None]:
class Embedding(nn.Module):
    """Token-embedding lookup scaled by sqrt(embedding_dim).

    ``embedding_dim`` corresponds to ``d_model`` in "Attention Is All You Need".
    The wrapped ``nn.Embedding`` holds a weight matrix of shape
    ``(vocab_size, embedding_dim)``: one row per vocabulary entry.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        """Replace each token id with its embedding row, multiplied by sqrt(embedding_dim).

        Input ids of shape (B, max_input_len) become (B, max_input_len, embedding_dim).
        """
        scale = math.sqrt(self.embedding_dim)
        vectors = self.embedding(x)
        return vectors * scale

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    The table ``pe`` has shape (1, max_len, embedding_dim). Even feature
    columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i), where
    w_i = exp(-2i * ln(10000) / embedding_dim). It is registered as a buffer,
    so it is stored in the state_dict but is not a trainable parameter.
    """

    def __init__(self, embedding_dim, max_len, dropout):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_len = max_len
        self.dropout = nn.Dropout(dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        frequencies = torch.exp(
            torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)
        )
        angles = positions * frequencies

        table = torch.zeros(max_len, embedding_dim)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.register_buffer('pe', table.unsqueeze(0))  # (1, max_len, embedding_dim)

    def forward(self, x):
        """Add the encodings for the first ``x.shape[1]`` positions to ``x`` and apply dropout."""
        seq_len = x.shape[1]
        encoding = self.pe[:, :seq_len, :].requires_grad_(False)
        return self.dropout(x + encoding)

# Encoder and Decoder Building Blocks

## Normalisation Layer

In [None]:
class LayerNormalisation(nn.Module):
    """Layer normalisation over the last dimension with a learnable scalar gain and bias.

    Computes ``alpha * (x - mean) / (std + eps) + bias``. The mean and std are
    taken over the last dimension, and ``std`` is ``Tensor.std``'s default
    (unbiased) estimate. ``alpha`` and the bias are single scalars shared by
    all features.

    Args:
        eps: small constant added to the standard deviation for numerical stability.
    """

    def __init__(self, eps = 10**-6):
        super(LayerNormalisation, self).__init__()
        self.eps = eps
        # Multiplicative gain, initialised to 1 (identity scale).
        self.alpha = nn.Parameter(torch.ones(1))
        # Additive bias, initialised to 0 so a freshly built layer outputs a
        # zero-mean signal (initialising it to 1 shifted every normalised value
        # by +1). The misspelt attribute name is kept so that existing
        # state_dict keys ("...bais") still load.
        self.bais = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        mean = x.mean(-1, keepdim = True)
        std = x.std(-1, keepdim = True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bais

## Feed-Forward Network

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Shapes: (batch, seq_len, embedding_dim) -> (batch, seq_len, hidden_dim)
    -> (batch, seq_len, embedding_dim).
    """

    def __init__(self, embedding_dim, hidden_dim, dropout):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)   # expand
        self.linear2 = nn.Linear(hidden_dim, embedding_dim)   # project back

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        embedding_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability applied to the attention weights.

    After each forward call, the attention weights are kept in
    ``self.attention_weights`` with shape (batch, num_heads, q_len, kv_len).
    """

    def __init__(self, embedding_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        # One dropout module, applied to the attention weights (it was
        # previously constructed a second time at the end of __init__).
        self.dropout = nn.Dropout(dropout)

        assert embedding_dim % num_heads == 0, 'embedding_dim must be divisible by num_heads'
        self.head_dim = embedding_dim // num_heads
        self.W_Q = nn.Linear(embedding_dim, embedding_dim)
        self.W_K = nn.Linear(embedding_dim, embedding_dim)
        self.W_V = nn.Linear(embedding_dim, embedding_dim)
        self.W_O = nn.Linear(embedding_dim, embedding_dim)

    @staticmethod
    def attention(Q, K, V, mask, dropout = None):
        """Scaled dot-product attention.

        Args:
            Q: (batch, num_heads, q_len, head_dim).
            K, V: (batch, num_heads, kv_len, head_dim).
            mask: ``None`` or a tensor broadcastable to (batch, num_heads, q_len, kv_len).
                Positions where ``mask == 0`` get score -1e9, so they receive
                (numerically) zero weight after the softmax.
            dropout: optional module applied to the attention weights.

        Returns:
            (output of shape (batch, num_heads, q_len, head_dim),
             weights of shape (batch, num_heads, q_len, kv_len)).
        """
        d_k = Q.shape[-1]

        # (.., q_len, head_dim) @ (.., head_dim, kv_len) --> (.., q_len, kv_len)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = scores.softmax(dim = -1)
        if dropout is not None:
            attention_weights = dropout(attention_weights)

        # (.., q_len, kv_len) @ (.., kv_len, head_dim) --> (.., q_len, head_dim)
        return (attention_weights @ V), attention_weights

    def _split_heads(self, t):
        """Reshape (batch, seq_len, embedding_dim) -> (batch, num_heads, seq_len, head_dim)."""
        return t.view(t.shape[0], -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, Q, K, V, mask = None):
        """Attend from ``Q`` (batch, q_len, d) over ``K``/``V`` (batch, kv_len, d); returns (batch, q_len, d)."""
        query = self._split_heads(self.W_Q(Q))
        key = self._split_heads(self.W_K(K))
        value = self._split_heads(self.W_V(V))

        # x: (batch, num_heads, q_len, head_dim)
        x, self.attention_weights = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # (batch, num_heads, q_len, head_dim) --> (batch, q_len, num_heads, head_dim) --> (batch, q_len, embedding_dim)
        x = x.permute(0, 2, 1, 3).contiguous().view(x.shape[0], -1, self.embedding_dim)

        return self.W_O(x)

In [None]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: returns ``x + dropout(subLayer(norm(x)))``.

    Args:
        dropout: dropout probability applied to the sub-layer output.

    The sub-layer receives a normalised input while the skip path carries
    ``x`` unchanged (pre-LN). This matches the final ``LayerNormalisation``
    that ``Encoder`` and ``Decoder`` apply after their last layer.
    """

    def __init__(self, dropout):
        super(ResidualConnection, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalisation()

    def forward(self, x, subLayer):
        # Normalise *before* the sub-layer. The form x + dropout(norm(subLayer(x)))
        # is neither pre-LN nor the paper's post-LN norm(x + dropout(subLayer(x))),
        # and it feeds un-normalised activations into every sub-layer.
        return x + self.dropout(subLayer(self.norm(x)))

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer is wrapped in its own ``ResidualConnection``. The optional
    ``src_mask`` is forwarded to the self-attention block.
    """

    def __init__(self, self_attention_block, feed_forward_block, dropout):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(2)]
        )

    def forward(self, x, src_mask = None):
        attend_residual, feed_residual = self.residual_connection

        def self_attend(hidden):
            return self.self_attention_block(hidden, hidden, hidden, src_mask)

        x = attend_residual(x, self_attend)
        return feed_residual(x, self.feed_forward_block)



In [None]:
class Encoder(nn.Module):
    """Run a stack of encoder layers in order, then apply a final ``LayerNormalisation``.

    Every layer is called as ``layer(x, src_mask)``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalisation()

    def forward(self, x, src_mask = None):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.norm(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, then feed-forward.

    Each of the three sub-layers is wrapped in its own ``ResidualConnection``.

    Args:
        self_attention_block: attention over the target sequence, masked with ``tgt_mask``.
        cross_attention_block: attention whose queries come from the decoder and
            whose keys/values come from ``encoder_output``, masked with ``src_mask``.
        feed_forward_block: position-wise feed-forward module.
        dropout: dropout probability for the residual connections.
    """

    def __init__(self, self_attention_block, cross_attention_block, feed_forward_block, dropout):
        super(DecoderBlock, self).__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block

        self.residual_connection = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        # Cross-attention must use its own block. Reusing self_attention_block
        # here left cross_attention_block's weights unused (never trained) and
        # made one set of weights serve both roles.
        x = self.residual_connection[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connection[2](x, self.feed_forward_block)
        return x


In [None]:
class Decoder(nn.Module):
    """Run a stack of decoder layers in order, then apply a final ``LayerNormalisation``.

    Every layer receives the same ``encoder_output``, ``src_mask`` and ``tgt_mask``.
    """

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalisation()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

In [None]:
class ProjectionLayer(nn.Module):
    """Project decoder states to log-probabilities over the target vocabulary.

    (batch, seq_len, embedding_dim) -> (batch, seq_len, vocab_size), normalised
    with log_softmax along the last dimension.
    """

    def __init__(self, embedding_dim, vocab_size):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.projection = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        logits = self.projection(x)
        return logits.log_softmax(dim=-1)

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from pre-built components.

    Args:
        encoder: called as ``encoder(x, src_mask)``.
        decoder: called as ``decoder(x, encoder_output, src_mask, tgt_mask)``.
        src_embedding, tgt_embedding: token-id -> vector modules.
        src_pos, tgt_pos: positional-encoding modules applied after the embeddings.
        projection_layer: maps decoder output to vocabulary scores.
    """

    def __init__(self, encoder, decoder, src_embedding, tgt_embedding, src_pos, tgt_pos, projection_layer):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embedding = src_embedding
        self.tgt_embedding = tgt_embedding
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, x, src_mask):
        """Embed and position-encode source ids, then run the encoder under ``src_mask``."""
        x = self.src_embedding(x)
        x = self.src_pos(x)
        # The mask must reach the encoder so self-attention ignores source
        # padding; dropping it made the encoder attend to [PAD] positions.
        return self.encoder(x, src_mask)

    def decode(self, x, encoder_output, src_mask, tgt_mask):
        """Embed and position-encode target ids, then run the decoder against ``encoder_output``."""
        x = self.tgt_embedding(x)
        x = self.tgt_pos(x)
        return self.decoder(x, encoder_output, src_mask, tgt_mask)

    def projection(self, x):
        """Apply the output projection to decoder states."""
        return self.projection_layer(x)

In [None]:
def transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_sqe_len, embedding_dim = 512, num_heads = 8, num_hidden = 2048, num_layers = 6, dropout = 0.1):
    """Build an encoder-decoder ``Transformer`` and Xavier-initialise its weights.

    Args:
        src_vocab_size, tgt_vocab_size: vocabulary sizes for the embeddings / output projection.
        src_seq_len: length of the source positional-encoding table.
        tgt_sqe_len: length of the target positional-encoding table (name kept
            as-is for keyword-argument compatibility).
        embedding_dim: model width (d_model).
        num_heads: attention heads per attention block.
        num_hidden: hidden width of each feed-forward block.
        num_layers: number of encoder layers and of decoder layers.
        dropout: dropout probability used throughout.

    Returns:
        The assembled ``Transformer``.
    """
    def new_attention():
        return MultiHeadAttention(embedding_dim, num_heads, dropout)

    def new_feed_forward():
        return FeedForward(embedding_dim, num_hidden, dropout)

    src_embedding = Embedding(src_vocab_size, embedding_dim)
    tgt_embedding = Embedding(tgt_vocab_size, embedding_dim)
    src_pos = PositionalEncoding(embedding_dim, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(embedding_dim, tgt_sqe_len, dropout)

    encoder = Encoder(nn.ModuleList([
        EncoderBlock(new_attention(), new_feed_forward(), dropout)
        for _ in range(num_layers)
    ]))
    # Per decoder layer: self-attention, cross-attention, feed-forward (in that order).
    decoder = Decoder(nn.ModuleList([
        DecoderBlock(new_attention(), new_attention(), new_feed_forward(), dropout)
        for _ in range(num_layers)
    ]))
    projection_layer = ProjectionLayer(embedding_dim, tgt_vocab_size)

    model = Transformer(encoder, decoder, src_embedding, tgt_embedding, src_pos, tgt_pos, projection_layer)

    # Xavier-uniform for every parameter with more than one dimension (weight
    # matrices, embedding tables); 1-D parameters keep their default init.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param.data)

    return model


# Training


## Tokenizer

In [None]:
!pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/547.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.6/547.8 kB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any

In [None]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from typing import Any
from torch.utils.data import Dataset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield ``item['translation'][lang]`` for every item in ``ds``."""
    yield from (pair['translation'][lang] for pair in ds)

In [None]:
def get_or_build_tokenizer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and saving it first if no file exists.

    The file path is ``config['tokenizer_file']`` formatted with ``lang``. A
    newly trained tokenizer splits on whitespace, maps unknown words to
    "[UNK]", registers "[UNK]", "[PAD]", "[SOS]" and "[EOS]" as special tokens,
    and uses ``min_frequency=2``.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token = "[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency = 2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

## Dataset

In [None]:
def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Entry [0, i, j] is True exactly when j <= i, i.e. position i may look at position j.
    """
    lower = torch.ones(1, size, size, dtype=torch.int64).tril()
    return lower != 0

In [None]:
class BilingualDataset(Dataset):
    """Map-style dataset of (source, target) sentence pairs prepared for the Transformer.

    Each item of ``ds`` must provide ``item['translation'][src_lang]`` and
    ``item['translation'][tgt_lang]``. Every returned id tensor has length ``seq_len``:

    * ``encoder_input``: [SOS] source ids [EOS] [PAD]...
    * ``decoder_input``: [SOS] target ids [PAD]...
    * ``label``:         target ids [EOS] [PAD]...  (decoder_input shifted left by one)

    Source-side special tokens come from ``tokenizer_src``; target-side ones
    (decoder input, label, decoder mask) come from ``tokenizer_tgt``.

    ``__getitem__`` raises ``ValueError`` if a sentence does not fit in ``seq_len``.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Source-side special-token ids (encoder_input / encoder_mask).
        self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)

        # Target-side special-token ids (decoder_input / label / decoder_mask).
        # These sequences hold target-vocabulary ids, so their special tokens
        # must come from tokenizer_tgt. Borrowing the source ids only works
        # while both tokenizers happen to assign identical special-token ids.
        self.tgt_sos_token = torch.tensor([tokenizer_tgt.token_to_id('[SOS]')], dtype=torch.int64)
        self.tgt_eos_token = torch.tensor([tokenizer_tgt.token_to_id('[EOS]')], dtype=torch.int64)
        self.tgt_pad_token = torch.tensor([tokenizer_tgt.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: Any) -> Any:
        src_target_pair = self.ds[index]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The source gets SOS + EOS (2 extra tokens). The decoder input gets
        # SOS and the label gets EOS (1 extra token each).
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        enc_tokens = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_tokens = torch.tensor(dec_input_tokens, dtype=torch.int64)
        # repeat() builds the padding directly (an empty int64 tensor for 0 pads)
        # instead of converting a Python list of 1-element tensors.
        enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
        dec_padding = self.tgt_pad_token.repeat(dec_num_padding_tokens)

        # Add SOS and EOS to the source text
        encoder_input = torch.cat([self.sos_token, enc_tokens, self.eos_token, enc_padding])

        # Add SOS to the decoder input
        decoder_input = torch.cat([self.tgt_sos_token, dec_tokens, dec_padding])

        # Add EOS to the label (what we expect as output from the decoder)
        label = torch.cat([dec_tokens, self.tgt_eos_token, dec_padding])

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,   # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_input": decoder_input,   # (seq_len)
            # (1, 1, seq_len) & (1, seq_len, seq_len) --> (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.tgt_pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "label": label,   # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }


In [None]:

def get_ds(config):
    """Load the opus_books pair, build both tokenizers, and return train/validation dataloaders.

    The ``train`` split of ``opus_books`` for ``lang_src-lang_tgt`` is randomly
    split 90% / 10% into training and validation sets. The longest tokenised
    source and target sentences are printed so they can be checked against
    ``config['seq_len']``.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The
        validation loader uses batch_size 1, which ``run_validation`` asserts.
    """
    lang_src = config["lang_src"]
    lang_tgt = config["lang_tgt"]
    ds_raw = load_dataset('opus_books', f'{lang_src}-{lang_tgt}', split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, lang_tgt)

    # Keep 90% for training and 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, lang_src, lang_tgt, config['seq_len'])

    # Longest tokenised sentence on each side. seq_len must leave room for
    # these plus 2 special tokens on the source side and 1 on the target side.
    max_src_len = 0
    max_tgt_len = 0
    for item in ds_raw:
        pair = item['translation']
        # Measure each side with its own tokenizer and its own ids. Encoding the
        # target with tokenizer_src and then taking the max over src_ids made
        # max_tgt_len merely repeat max_src_len.
        src_ids = tokenizer_src.encode(pair[lang_src]).ids
        tgt_ids = tokenizer_tgt.encode(pair[lang_tgt]).ids

        max_src_len = max(max_src_len, len(src_ids))
        max_tgt_len = max(max_tgt_len, len(tgt_ids))

    print(f"max src len: {max_src_len},\nmax tgt len: {max_tgt_len}")

    train_dataloader = DataLoader(train_ds, batch_size = config['batch_size'], shuffle = True)
    val_dataloader = DataLoader(val_ds, batch_size = 1, shuffle = True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

In [None]:
def get_model(config, vocsb_src_size, vocab_tgt_src):
    """Build a ``Transformer`` sized from ``config`` and the two vocabulary sizes.

    Args:
        config: must provide 'seq_len' (used for both the source and target
            positional tables) and 'embedding_dim'.
        vocsb_src_size: source vocabulary size.
        vocab_tgt_src: target vocabulary size.
        (Parameter names are kept unchanged for keyword-argument compatibility.)
    """
    seq_len = config['seq_len']
    return transformer(
        src_vocab_size=vocsb_src_size,
        tgt_vocab_size=vocab_tgt_src,
        src_seq_len=seq_len,
        tgt_sqe_len=seq_len,
        embedding_dim=config['embedding_dim'],
    )

In [None]:
def get_config():
    """Return a fresh dict holding the default training configuration.

    NOTE: 'preload' is ' 1' with a leading space. That matches the checkpoint
    suffix that train_model writes via f"{epoch: 02d}" for epoch 1. Because it
    is not None, train_model tries to resume from that checkpoint.
    """
    config = dict(
        batch_size=4,
        num_epochs=4,
        lr=0.001,
        seq_len=400,
        embedding_dim=512,
        lang_src='en',
        lang_tgt='it',
        tokenizer_file='tokenizer_{0}.json',
        model_folder='weights',
        model_filename='tmodel_',
        preload=' 1',
        experiment_name='runs/tmodel',
    )
    return config

In [None]:
def get_weights_file_name(config, _epoch):
    """Return the checkpoint path ``./<model_folder>/<model_filename><_epoch>.pt``.

    ``_epoch`` is interpolated as-is, so a pre-formatted suffix such as ' 1' is used verbatim.
    """
    folder = config['model_folder']
    stem = f"{config['model_filename']}{_epoch}"
    return f"./{folder}/{stem}.pt"

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode a single source sentence.

    The source is encoded once. Starting from [SOS], the highest-scoring next
    token is appended repeatedly until [EOS] is produced or the sequence
    reaches ``max_len`` tokens (the leading [SOS] counts).

    Args:
        model: a ``Transformer``.
        source, source_mask: encoder ids and mask. A batch size of 1 is
            assumed, because ``.item()`` is taken on the predicted token.
        tokenizer_src: unused.
        tokenizer_tgt: supplies the [SOS]/[EOS] ids.
        max_len: maximum length of the generated sequence.
        device: device on which the decoder input is built.

    Returns:
        1-D tensor of generated ids, starting with [SOS].
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Reuse one encoder pass for every decoding step.
    encoder_output = model.encode(source, source_mask)
    generated = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    while generated.size(1) != max_len:
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)
        decoded = model.decode(generated, encoder_output, source_mask, step_mask)

        # Score only the last position and take the arg-max token.
        scores = model.projection(decoded[:, -1])
        _, next_word = torch.max(scores, dim=1)
        next_column = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        generated = torch.cat([generated, next_column], dim=1)

        if next_word == eos_idx:
            break

    return generated.squeeze(0)

In [None]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode validation sentences and print source / target / prediction triples.

    The model is switched to eval mode (and left there). At most
    ``num_examples`` batches are shown; a non-positive value means the loop
    never hits the stop condition and every batch is shown. Each batch must
    have size 1. ``global_step`` and ``writer`` are accepted but not used.
    """
    model.eval()

    console_width = 80  # fixed width for the separator lines
    separator = '-' * console_width

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)    # (b, 1, 1, seq_len)

            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            print_msg(separator)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg(separator)
                break

In [None]:
def train_model(config):
    """Train the Transformer described by ``config`` and checkpoint it after every epoch.

    Dataloaders and tokenizers come from ``get_ds`` and the model from
    ``get_model``. The per-step training loss is logged to TensorBoard under
    ``config['experiment_name']``. At the end of each epoch the model state,
    optimizer state, epoch index and global step are saved to
    ``get_weights_file_name(config, f"{epoch: 02d}")``.

    When ``config['preload']`` is not None, the checkpoint it names is loaded
    first (model weights, optimizer state, global step). Training then resumes
    with the epoch after the one stored in it.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Path(config['model_folder']).mkdir(parents= True, exist_ok= True)
    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    # TensorBoard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr = config['lr'], eps= 1e-9)

    initial_epoch = 0
    global_step = 0
    if config['preload'] is not None:
        model_filename = get_weights_file_name(config, config['preload'])
        # map_location lets a checkpoint written on one device be resumed on another.
        state = torch.load(model_filename, map_location=device)
        # Restore the weights too. Restoring only the optimizer made "resumed"
        # training continue from randomly initialised weights.
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # The checkpoint is written after epoch `state['epoch']` has finished,
        # so continue with the next epoch instead of repeating it.
        initial_epoch = state['epoch'] + 1
        global_step = state['global_step']

    # Labels hold target-vocabulary ids, so padding is identified by the target
    # tokenizer's [PAD] id. ProjectionLayer already returns log-probabilities;
    # CrossEntropyLoss re-applies log_softmax, which leaves them unchanged.
    loss_fn = nn.CrossEntropyLoss(ignore_index= tokenizer_tgt.token_to_id('[PAD]'), label_smoothing= 0.1).to(device)

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(train_dataloader, desc = f'processing epoch {epoch: 02d}')
            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device)      # (B, seq_len)
                decoder_input = batch['decoder_input'].to(device)      # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device)        # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device)        # (B, 1, seq_len, seq_len)

                # run transformer
                encoder_output = model.encode(encoder_input, encoder_mask)   # (B, seq_len, embedding_dim)
                decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)   # (B, seq_len, embedding_dim)
                proj_output = model.projection(decoder_output)   # (B, seq_len, tgt_vocab_size)

                label = batch['label'].to(device)   # (B, seq_len)

                # (B, seq_len, tgt_vocab_size) --> (B * seq_len, tgt_vocab_size)
                loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))

                batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

                # log the loss
                writer.add_scalar('train loss', loss.item(), global_step)
                writer.flush()

                # backpropagate and update the weights
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                global_step += 1

            # save the model at the end of each epoch; the " 02d" suffix format
            # is what config['preload'] values such as ' 1' refer to.
            model_filename = get_weights_file_name(config, f"{epoch: 02d}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Close the TensorBoard writer even if training is interrupted.
        writer.close()

In [None]:
# Build the default configuration and run training. train_model first tries
# to resume from the checkpoint named by config['preload'] when it is not None.
config = get_config()
train_model(config)

max src len: 309,
max tgt len: 309


processing epoch  1:   0%|          | 33/7275 [00:08<32:38,  3.70it/s, loss=7.837]


KeyboardInterrupt: 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the dataloaders/tokenizers and a fresh model, then load the model
# weights from the checkpoint named by config['preload'] (the optimizer state
# is not restored here).
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

model_filename = get_weights_file_name(config, config['preload'])
# map_location makes a checkpoint saved on GPU loadable on a CPU-only machine.
state = torch.load(model_filename, map_location=device)

model.load_state_dict(state['model_state_dict'])

max src len: 309,
max tgt len: 309


<All keys matched successfully>

In [None]:
# Print a couple of greedy-decoded validation examples. run_validation does not
# use its global_step and writer arguments, so None is passed for both.
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, print, None, None)


--------------------------------------------------------------------------------
    SOURCE: But a week passed, and another, and a third, and no impression was noticeable in Society. His friends, the specialists and the scholars, sometimes – from politeness – mentioned it; his other acquaintances, not interested in learned works, did not mention it to him at all.
    TARGET: Ma passò una settimana, ne passarono due, tre e nella società non si notava alcuna impressione; gli amici specialisti e studiosi, a volte, evidentemente per cortesia, ne cominciavano a parlare. Ma gli altri suoi conoscenti, non interessati a un libro di contenuto scientifico, non ne parlavano affatto.
 PREDICTED: — No , non vi , — disse , — — e , — e , e , con un sorriso , e , con un sorriso , e la signora Reed , e , e , con un sorriso , e con un sorriso , e con un sorriso , e la sua voce , e con un sorriso , e la sua voce , e con un sorriso di una voce , e con un sorriso di una voce di cui disse , e con la sua voc