<a href="https://colab.research.google.com/github/kla55/transformer/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [64]:
!pip install datasets



In [65]:
!pip install dataset



In [66]:
from pathlib import Path


def get_config():
    return {
        'lang_source': 'en',
        'lang_target': 'it',
        'tokenizer_file': 'tokenizer_{0}.json',  # Provide the path to your tokenizer directory
        'batch_size': 1,
        'num_layers': 4,
        'd_model': 512,
        'num_heads': 8,
        'dff': 1024,
        'dropout': 0.1,
        'learning_rate': 10 ** -4,
        'num_epochs': 20,
        'model_folder': "weights",
        'model_basename': "transformer_model_",
        'preload': None,
        'experiment_name': "runs/transformer_model"

    }


def get_weights_file_path(config, epoch):
    model_folder = config['model_folder']
    model_basename = config['model_basename']
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

In [67]:
import torch
from torch.utils.data import Dataset


class BilingualDataset(Dataset):
    def __init__(self, dataset, tokenizer_source, tokenizer_target, source_lang, target_lang, seq_len):
        """
        Bilingual dataset class for training a sequence-to-sequence model.

        Args:
            dataset (list): List of dictionaries containing source and target translations.
            tokenizer_source: Tokenizer for source language.
            tokenizer_target: Tokenizer for target language.
            source_lang (str): Key for accessing source language in the dataset dictionary.
            target_lang (str): Key for accessing target language in the dataset dictionary.
            seq_len (int): Maximum sequence length for encoder and decoder inputs.
        """
        self.seq_len = seq_len
        self.dataset = dataset
        self.tokenizer_source = tokenizer_source
        self.tokenizer_target = tokenizer_target
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.sos_token = torch.tensor([tokenizer_source.token_to_id('[SOS]')], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_source.token_to_id('[EOS]')], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_source.token_to_id('[PAD]')], dtype=torch.int64)

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Retrieves a sample from the dataset and preprocesses it.

        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            encoder_input (Tensor): Padded encoder input sequence.
            decoder_input (Tensor): Padded decoder input sequence.
            masks, labels, and text data.
        """
        source_target_pair = self.dataset[index]
        source_text = source_target_pair['translation'][self.source_lang]
        target_text = source_target_pair['translation'][self.target_lang]

        encoder_input_tokens = self.tokenizer_source.encode(source_text).ids
        decoder_input_tokens = self.tokenizer_target.encode(target_text).ids

        encoder_num_padding_tokens = self.seq_len - len(encoder_input_tokens) - 2
        decoder_num_padding_tokens = self.seq_len - len(decoder_input_tokens) - 1

        if encoder_num_padding_tokens < 0 or decoder_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * encoder_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(decoder_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * decoder_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        label = torch.cat(
            [
                torch.tensor(decoder_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * decoder_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            'encoder_input': encoder_input,  # (seq_len)
            'decoder_input': decoder_input,  # (seq_len)
            'encoder_mask': (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            'decoder_mask': (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            # (1, seq_len) & (1, seq_len, seq_len),
            'target_label': label,  # (seq_len)
            'source_text': source_text,
            'target_text': target_text,
        }


def causal_mask(size):
    """
        Creates a causal mask for the decoder.

        Args:
            size (int): Size of the mask.

        Returns:
            causal_mask (Tensor): Causal mask.
        """
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

In [68]:
from pathlib import Path
import torch
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# from config import get_weights_file_path, get_config
# from dataset import BilingualDataset, causal_mask
# from model import build_transformer

In [69]:
def build_tokenizer(config, dataset, lang):
    # eg config['tokenizer_file'] = '../tokenizer/tokenizer_en.json'
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    print(tokenizer_path)
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))

    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))

    return tokenizer

def get_all_sentences(dataset, lang):
  for sentence in dataset:
      yield sentence['translation'][lang]

In [70]:
config = get_config()

In [71]:
dataset = load_dataset('opus_books', f'{config["lang_source"]}-{config["lang_target"]}', split='train')

In [72]:
dataset

Dataset({
    features: ['id', 'translation'],
    num_rows: 32332
})

In [73]:
tokenizer_source = build_tokenizer(config, dataset, config['lang_source'])
tokenizer_source

tokenizer_en.json


Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"[UNK]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"[PAD]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"[SOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"[EOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=Whitespace(), post_processor=None, decoder=None, model=WordLevel(vocab={"[UNK]":0, "[PAD]":1, "[SOS]":2, "[EOS]":3, ",":4, "the":5, "and":6, ".":7, "to":8, "I":9, "of":10, "a":11, "'":12, "in":13, "was":14, "that":15, "he":16, "it":17, ";":18, "had":19, "his":20, "not":21, "with":22, "her":23, "you":24, "as":25, "for":26, "she":27, "my":28, "-":29, "at":30, "but":31, "him":32, "me":33, "is":34, """:35, "on":36, "be":37, ":

In [74]:
tokenizer_target = build_tokenizer(config, dataset, config['lang_target'])
tokenizer_target

tokenizer_it.json


Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"[UNK]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"[PAD]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"[SOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"[EOS]", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=Whitespace(), post_processor=None, decoder=None, model=WordLevel(vocab={"[UNK]":0, "[PAD]":1, "[SOS]":2, "[EOS]":3, ",":4, ".":5, "e":6, "di":7, "che":8, "—":9, "’":10, "la":11, "non":12, "a":13, "il":14, "un":15, "in":16, "per":17, "si":18, ";":19, "con":20, "una":21, "era":22, "le":23, "l":24, "mi":25, "ma":26, "è":27, "da":28, "'":29, "?":30, "del":31, "i":32, "come":33, "più":34, "della":35, "lo":36, "disse":37, "gli":

In [75]:
training_dataset_size = int(len(dataset) * 0.9)
validation_dataset_size = len(dataset) - training_dataset_size
training_dataset_raw, validation_dataset_raw = torch.utils.data.random_split(dataset, [training_dataset_size,
                                                                                        validation_dataset_size])
print(len(training_dataset_raw), len(validation_dataset_raw))

29098 3234


In [76]:
max_len_source = 0
max_len_target = 0
for item in training_dataset_raw:
    source_text = item['translation'][config['lang_source']]
    target_text = item['translation'][config['lang_target']]
    max_len_source = max(max_len_source, len(tokenizer_source.encode(source_text).ids))
    max_len_target = max(max_len_target, len(tokenizer_target.encode(target_text).ids))

In [77]:
training_dataset = BilingualDataset(training_dataset_raw, tokenizer_source, tokenizer_target,
                                    config['lang_source'],
                                    config['lang_target'], max_len_target)

In [78]:
def get_dataset(config):
    dataset = load_dataset('opus_books', f'{config["lang_source"]}-{config["lang_target"]}', split='train')

    # Build tokenizers
    tokenizer_source = build_tokenizer(config, dataset, config['lang_source'])
    tokenizer_target = build_tokenizer(config, dataset, config['lang_target'])

    # Keep 90% for training and 10% for validation
    training_dataset_size = int(len(dataset) * 0.9)
    validation_dataset_size = len(dataset) - training_dataset_size
    training_dataset_raw, validation_dataset_raw = torch.utils.data.random_split(dataset, [training_dataset_size,
                                                                                           validation_dataset_size])

    # # Calculate the maximum sequence lengths for source and target languages
    '''The goal is to determine the longest sequence of tokens (in terms of tokenized IDs) in both the source and target languages within the training dataset. These maximum lengths are used later to define the maximum sequence lengths for padding or truncation during training.
    '''
    max_len_source = 0
    max_len_target = 0
    for item in training_dataset_raw:
        source_text = item['translation'][config['lang_source']]
        target_text = item['translation'][config['lang_target']]
        max_len_source = max(max_len_source, len(tokenizer_source.encode(source_text).ids))
        max_len_target = max(max_len_target, len(tokenizer_target.encode(target_text).ids))

    training_dataset = BilingualDataset(training_dataset_raw, tokenizer_source, tokenizer_target,
                                        config['lang_source'],
                                        config['lang_target'], max_len_target)

    validation_dataset = BilingualDataset(validation_dataset_raw, tokenizer_source, tokenizer_target,
                                          config['lang_source'],
                                          config['lang_target'], max_len_target)

    # Set the maximum sequence lengths in the configuration
    config['seq_len'] = max_len_source
    config['max_seq_len'] = max_len_source

    print(f" Max length of source text: {max_len_source}")
    print(f" Max length of target text: {max_len_target}")

    # Create data loaders for training and validation datasets
    training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=config['batch_size'], shuffle=True)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False)

    return training_dataloader, validation_dataloader, tokenizer_source, tokenizer_target



In [79]:
training_dataloader, validation_dataloader, tokenizer_source, tokenizer_target = get_dataset(config)

tokenizer_en.json
tokenizer_it.json
 Max length of source text: 309
 Max length of target text: 274


# Get Model

In [80]:
def get_model(config, vocab_source_length, vocab_target_length):
    """
    Builds and returns a transformer model.

    Args:
        config (dict): Configuration settings.
        vocab_source_length (int): Vocabulary size for source language.
        vocab_target_length (int): Vocabulary size for target language.

    Returns:
        model (nn.Module): Transformer model.
    """
    # Extract model configuration parameters from config
    num_layers = config['num_layers']
    d_model = config['d_model']
    num_heads = config['num_heads']
    dff = config['dff']
    dropout = config['dropout']
    max_seq_len = config['seq_len']

    # Build the transformer model
    model = build_transformer(num_layers, d_model, num_heads, dff, dropout, vocab_source_length,
                              vocab_target_length, max_seq_len)

    return model

In [81]:
config1 = {
    'num_layers': 6,
    'd_model': 512,
    'num_heads': 8,
    'dff': 2048,
    'dropout': 0.1,
    'seq_len': 100
}
vocab_source_length = 30000
vocab_target_length = 30000

model = get_model(config1, vocab_source_length, vocab_target_length)

In [85]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
vocab_source_length = tokenizer_source.get_vocab_size()
vocab_target_length = tokenizer_target.get_vocab_size()
print(f'vocab source len {vocab_source_length}')
print(f'vocab target len {vocab_target_length}')
model = get_model(config, vocab_source_length, vocab_target_length)
writer = SummaryWriter(config['experiment_name'])
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer_target.token_to_id("[PAD]"),
                                      label_smoothing=0.1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], eps=1e-9)


Using device: cpu
vocab source len 15698
vocab target len 22463


In [83]:
    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer_target.token_to_id("[PAD]"),
                                          label_smoothing=0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], eps=1e-9)

    initial_epoch = 0
    global_step = 0

In [60]:
  if config['preload']:
      model_filename = get_weights_file_path(config, config['preload'])
      print(" Pre-Loading model", model_filename)
      state = torch.load(model_filename)
      initial_epoch = state['epoch'] + 1
      optimizer.load_state_dict(state['optimizer_state_dict'])
      global_step = state['global_step']

In [58]:
writer

<torch.utils.tensorboard.writer.SummaryWriter at 0x7bf8a7baba00>

In [86]:
# How the Masks Work
# Encoder Self-Attention:

# Applies the encoder_mask to attention scores, ensuring padding tokens don't contribute.
# Decoder Self-Attention:

# Applies the decoder_mask to attention scores, ensuring both:
# Causal masking: Only attends to previous and current tokens.
# Padding masking: Ignores padding tokens.
# Encoder-Decoder Attention:

# Uses the encoder_mask to prevent attention to padding tokens in the source sequence.

# Mask Shape: (Batch, 1, 1, Seq_Len)
# Example Mask (Seq_Len = 5):
# [1, 1, 1, 0, 0] -> Indicates valid tokens (1) and padding tokens (0).


# Mask Shape: (Batch, 1, Seq_Len, Seq_Len)
# Example Mask (Seq_Len = 5):
# [[1, 0, 0, 0, 0],
#  [1, 1, 0, 0, 0],
#  [1, 1, 1, 0, 0],
#  [1, 1, 1, 1, 0],
#  [1, 1, 1, 1, 1]] -> Prevents attending to future tokens.

# Encoder Output
encoder_output = model.encode(encoder_input, encoder_mask)  # (Batch, Seq_Len, d_model)

Input:
  - encoder_input: The tokenized source text ((Batch, Seq_Len)).
  - encoder_mask: A mask to ignore padding tokens in the encoder.

Process:
The encoder processes the input sequence and generates contextualized embeddings for each token.
Output:
encoder_output: A tensor containing contextualized embeddings for each position in the sequence.

Shape: (Batch, Seq_Len, d_model)

Batch: Number of sequences in the batch.

Seq_Len: Length of each sequence.

d_model: Dimension of the model's embeddings.

# Decoder Output
decoder_output = model.decode(decoder_input, encoder_output, encoder_mask, decoder_mask)  # (Batch, Seq_Len, d_model)

Input:

- decoder_input: The tokenized target text up to the current step during training ((Batch, Seq_Len)).

- encoder_output: The output of the encoder ((Batch, Seq_Len, d_model)).

- encoder_mask: Ensures the decoder does not attend to padding tokens in the encoder output.

- decoder_mask: Ensures causal masking (no attending to future tokens) and ignores padding tokens in the target sequence.

Process:

The decoder attends to the encoder_output and its own decoder_input to produce predictions for the target sequence.
Output:

- decoder_output: A tensor containing contextualized embeddings for the decoder's current prediction.

- Shape: (Batch, Seq_Len, d_model)

# Projection to Vocabulary
projection_output = model.project(decoder_output)  # (Batch, Seq_Len, target_vocab_size)

Input:

- decoder_output: Contextualized embeddings from the decoder.
Process:

- Applies a linear transformation to project the decoder output into a probability distribution over the target vocabulary.
Output:

- projection_output: Logits representing the likelihood of each token in the target vocabulary.
- Shape: (Batch, Seq_Len, target_vocab_size)
- target_vocab_size: Number of tokens in the target vocabulary.

# Target Labels

target_label = batch['target_label'].to(device)  # (Batch, Seq_Len)

Input:

- batch['target_label']: Ground truth token indices for the target sequence.
Process:

- Moves the target labels to the same device (CPU/GPU) as the model for loss computation.

Output:

- target_label: Token indices for the target sequence.
- Shape: (Batch, Seq_Len)

In [88]:
def run_validation(model, validation_dataset, tokenizer_source, tokenizer_target, max_length, device, print_msg,
                   global_state, writer, num_examples=2):
    model.eval()
    count = 0

    # source_texts = []
    # expected = []
    # predicted = []

    # Size of the control window( just us a default value)
    console_width = 80
    with torch.no_grad():
        for batch in validation_dataset:
            count += 1
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)

            assert encoder_input.size(0) == 1, " Batch size must be 1 for validation"

            model_output = greedy_decode(model, encoder_input, encoder_mask, tokenizer_source, tokenizer_target,
                                         max_length,
                                         device)

            source_text = batch['source_text'][0]
            target_text = batch['target_text'][0]

            model_output_text = tokenizer_target.decode(model_output.detach().cpu().numpy())

            # source_texts.append(source_text)
            # expected.append(target_text)
            # predicted.append(model_output_text)

            # Print to the console

            print_msg('-' * console_width)
            print_msg(f'SOURCE{source_text}')
            print_msg(f'TARGET{target_text}')
            print_msg(f'PREDICT{model_output_text}')

            if count == num_examples:
                break

In [89]:
def train_model(config):
    # Define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    # Get training data, validation data, source tokenizer, and target tokenizer
    training_dataloader, validation_dataloader, tokenizer_source, tokenizer_target = get_dataset(config)

    # Get vocabulary sizes for source and target languages from tokenizers
    vocab_source_length = tokenizer_source.get_vocab_size()
    vocab_target_length = tokenizer_target.get_vocab_size()

    # Build the transformer model
    model = get_model(config, vocab_source_length, vocab_target_length)

    # Tensorboard # writes data into a specified directory, and TensorBoard reads from that directory to generate visualizations.
    writer = SummaryWriter(config['experiment_name'])

    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer_target.token_to_id("[PAD]"),
                                          label_smoothing=0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], eps=1e-9)

    initial_epoch = 0
    global_step = 0

    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(" Pre-Loading model", model_filename)
        state = torch.load(model_filename)
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    # Training loop
    num_epochs = config['num_epochs']
    for epoch in range(initial_epoch, num_epochs):
        model.train()
        batch_iterator = tqdm(training_dataloader, desc=f'Processing epoch {epoch:02d}')
        total_loss = 0

        for batch in training_dataloader:
            model.train()
            # The shapes of the encoder_mask and decoder_mask are specific to how the attention mechanism in a Transformer model works.
            # This represents the tokenized sequences for the source text. Each sequence has a length of Seq_Len.
            encoder_input = batch['encoder_input'].to(device)  # (Batch , Seq_Len)
            decoder_input = batch['decoder_input'].to(device)  # (Batch , Seq_Len)

            # The mask for the encoder is applied during self-attention in the encoder.
            # The attention score matrix has a shape of (Batch, Num_Heads, Seq_Len, Seq_Len)
            # It ensures the model only attends to valid tokens (e.g., ignores padding tokens).
            # The shape (Batch, 1, 1, Seq_Len) is designed to broadcast correctly when computing attention scores:
            # Doesn't need to consider future tokens since the encoder processes the entire source sequence at once.
            encoder_mask = batch['encoder_mask'].to(device)  # (Batch ,1 ,1 ,Seq_Len)
            # Needs to consider causal masking to prevent future tokens from being attended to.
            # Requires a mask of shape (Batch, 1, Seq_Len, Seq_Len) to apply both causal and padding masks for self-attention.
            # Only attends to tokens that have been generated up to the current position
            # Ignores padding tokens in the target text.
            decoder_mask = batch['decoder_mask'].to(device)  # (Batch ,1 ,Seq_len ,Seq_Len)

            print("Encoder Input Shape:", encoder_input.shape)
            print("Decoder Input Shape:", decoder_input.shape)
            print("Encoder Mask Shape:", encoder_mask.shape)
            print("Decoder Mask Shape:", decoder_mask.shape)

            # Run the tensors through the transformer

            encoder_output = model.encode(encoder_input, encoder_mask)  # (Batch , Seq_Len, d_model)
            decoder_output = model.decode(decoder_input, encoder_output, encoder_mask,
                                          decoder_mask)  # (Batch  Seq_Len, d_model)
            projection_output = model.project(decoder_output)  # (Batch, Seq_Len, target_vocab_size)

            target_label = batch['target_label'].to(device)  # (Batch, Seq_Len)

            # Calculate the loss
            # Flatten the projection_output and target_label tensors for the CrossEntropyLoss
            # Projection output shape after view: (Batch * Seq_Len, target_vocab_size)
            # Target label shape after view: (Batch * Seq_Len)
            loss = criterion(projection_output.view(-1, projection_output.shape[-1]), target_label.view(-1))
            batch_iterator.set_postfix({f"loss": f"{loss.item():6.3f}"})
            total_loss += loss.item()

            # Backpropagation and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()



            # Increment the global step count
            global_step += 1

        run_validation(model, validation_dataloader, tokenizer_source, tokenizer_target, config['seq_len'],
                       device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Calculate the average loss for the epoch
        avg_loss = total_loss / len(training_dataloader)

        # Log the loss to Tensorboard
        writer.add_scalar('Training Loss', avg_loss, global_step)
        writer.flush()

        # Print epoch info
        print(f"Epoch [{epoch + 1}/{num_epochs}] - Loss: {avg_loss:.4f}")

        # Save the trained model

        model_filename = get_weights_file_path(config, f'{epoch:02d}')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'global_step': global_step
        }, model_filename)

In [7]:
def greedy_decode(model, source, encoder_mask, tokenizer_source, tokenizer_target, max_length, device):
    sos_idx = tokenizer_target.token_to_id('[SOS]')
    eos_idx = tokenizer_target.token_to_id('[EOS]')

    # Pre-compute the encoder output and reuse it for every token we get from the decoder
    encoder_output = model.encode(source, encoder_mask)
    # Initialize the decoder input with sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(encoder_mask).to(device)

    while True:
        if decoder_input.size(1) == max_length:
            break

        # Build the mask for the target ( decoder input )
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(encoder_mask).to(device)

        # Calculate the output of the decoder
        decoder_output = model.decode(decoder_input, encoder_output, encoder_mask,
                                      decoder_mask)

        # Get the next token
        probabilities = model.project(decoder_output[:, -1])

        # Select the token with the max probability (because it is greedy search)
        _, next_word = torch.max(probabilities, dim=1)

        decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)],
                                  dim=1)

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)


ModuleNotFoundError: No module named 'dataset.bilingual'

# Transformer model

In [60]:
import math
import torch
import torch.nn as nn

import math
import torch
import torch.nn as nn


class InputEmbeddings(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.word_embeddings = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        print("InputEmbeddings - Input x shape:", x.shape)
        embeddings = self.word_embeddings(x) * math.sqrt(self.d_model)
        print("InputEmbeddings - Output embeddings shape:", embeddings.shape)
        return embeddings


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, seq_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)

        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        print("PositionalEncoding - Input x shape:", x.shape)
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        print("PositionalEncoding - Output x shape:", x.shape)
        return self.dropout(x)


class LayerNormalisation(nn.Module):
    def __init__(self, epsilon=1e-6):
        super().__init__()
        self.epsilon = epsilon
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        x = x.float()
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.epsilon) + self.bias


class FeedForward(nn.Module):
    def __init__(self, d_model, dff, dropout):
        super().__init__()
        self.d_model = d_model
        self.dff = dff
        self.dropout = nn.Dropout(dropout)
        self.linear1 = nn.Linear(d_model, dff)
        self.linear2 = nn.Linear(dff, d_model)

    def forward(self, x):
        print("FeedForward - Input x shape:", x.shape)
        x = self.linear2(self.dropout(torch.relu(self.linear1(x))))
        print("FeedForward - Output x shape:", x.shape)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.attention_scores = None
        self.d_model = d_model
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)

        self.d_k = d_model // self.num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.wo = nn.Linear(d_model, d_model)

        self.layer_norm1 = LayerNormalisation()
        self.layer_norm2 = LayerNormalisation()
        self.layer_norm3 = LayerNormalisation()

    @staticmethod
    def attention(q, k, v, mask, dropout):
        d_k = q.shape[-1]
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)  # (Batch, num_heads, Seq_Len,  Seq_Len)

        if dropout is not None:
            attention_scores = dropout(attention_scores)

        attn = torch.matmul(attention_scores, v)  # (Batch, num_heads, Seq_Len, d_k)
        return attn, attention_scores

    def forward(self, q, k, v, mask):

        q = self.wq(q)

        k = self.wk(k)

        v = self.wv(v)

        q = q.view(q.shape[0], q.shape[1], self.num_heads, self.d_k).transpose(1, 2)

        k = k.view(k.shape[0], k.shape[1], self.num_heads, self.d_k).transpose(1, 2)

        v = v.view(v.shape[0], v.shape[1], self.num_heads, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttention.attention(q, k, v, mask, self.dropout)

        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.num_heads * self.d_k)

        x = self.wo(x)

        return x


class ResidualConnection(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = LayerNormalisation()

    def forward(self, x, sub_layer):
        return x + self.dropout(sub_layer(self.layer_norm(x)))


class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout):
        super().__init__()
        self.dff = dff  # Feed Forward Neural Network Output Size
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, dff, dropout)
        self.residual_mha = ResidualConnection(dropout)
        self.residual_ffn = ResidualConnection(dropout)

    def forward(self, x, mask):
        # Multi-Head Attention sub-layer
        attn_output = self.residual_mha(x, lambda x: self.mha(x, x, x, mask))

        # FeedForward sub-layer
        ffn_output = self.residual_ffn(attn_output, self.ffn)

        return ffn_output


class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, dropout):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dropout = nn.Dropout(dropout)

        self.layer = nn.ModuleList([EncoderLayer(d_model, num_heads, dff, dropout) for _ in range(num_layers)])
        self.layer_norm = LayerNormalisation()

    def forward(self, x, mask=None):
        for i in range(self.num_layers):
            x = self.layer[i](x, mask)
        return self.layer_norm(x)


class DecoderLayer(nn.Module):

    def __init__(self, d_model, num_heads, dff, dropout):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff  # Feed Forward Neural Network Output Size
        self.dropout = nn.Dropout(dropout)

        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, dff, dropout)
        self.residual_mha = ResidualConnection(dropout)
        self.residual_cross_mha = ResidualConnection(dropout)
        self.residual_ffn = ResidualConnection(dropout)

    def forward(self, x, encoder_output, source_mask, target_mask):
        # Multi-Head Attention sub-layer
        attn_output = self.residual_mha(x, lambda x: self.mha(x, x, x, target_mask))

        # Cross-Attention sub-layer
        cross_attn_output = self.residual_cross_mha(attn_output,
                                                    lambda x: self.mha(x, encoder_output, encoder_output, source_mask))

        # FeedForward sub-layer
        ffn_output = self.residual_ffn(cross_attn_output, self.ffn)

        return ffn_output


class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, dropout):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dropout = nn.Dropout(dropout)

        self.layer = nn.ModuleList([DecoderLayer(d_model, num_heads, dff, dropout) for _ in range(num_layers)])
        self.layer_norm = LayerNormalisation()

    def forward(self, x, encoder_output, source_mask, target_mask):
        for i in range(self.num_layers):
            x = self.layer[i](x, encoder_output, source_mask, target_mask)
        return self.layer_norm(x)


class ProjectionLayer(nn.Module):
    def __init__(self, d_model, vocabulary_size):
        super().__init__()
        self.d_model = d_model
        self.projection = nn.Linear(d_model, vocabulary_size)

    def forward(self, x):
        # (Batch, Seq_Len, D_Model) -->( Batch, Seq_Len, Vocab_Size)
        return torch.log_softmax(self.projection(x), dim=-1)


class Transformer(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, dff, dropout, source_embeddings, target_embeddings,
                 source_pos_encodings, target_pos_encodings, vocabulary_size):
        super().__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, dropout)
        self.projection = ProjectionLayer(d_model, vocabulary_size)
        self.source_embeddings = source_embeddings
        self.target_embeddings = target_embeddings
        self.source_pos_encodings = source_pos_encodings
        self.target_pos_encodings = target_pos_encodings

    def encode(self, source_input, source_mask):
        # Embedding and positional encoding for source inputs
        source_embedded = self.source_embeddings(source_input)
        source_embedded = self.source_pos_encodings(source_embedded)
        # Pass source input through the encoder
        encoder_output = self.encoder(source_embedded, source_mask)
        return encoder_output

    def decode(self, target_input, encoder_output, source_mask, target_mask):
        # Embedding and positional encoding for target inputs
        target_embedded = self.target_embeddings(target_input)
        target_embedded = self.target_pos_encodings(target_embedded)
        # Pass target input through the decoder
        decoder_output = self.decoder(target_embedded, encoder_output, source_mask, target_mask)
        return decoder_output

    def project(self, decoder_output):
        # Project the decoder output to the vocabulary size
        output_logits = self.projection(decoder_output)
        return output_logits


def build_transformer(num_layers, d_model, num_heads, dff, dropout, source_vocab_size, target_vocab_size,
                      max_seq_len):
    # Create embeddings and positional encodings
    source_embeddings = InputEmbeddings(d_model, source_vocab_size)
    target_embeddings = InputEmbeddings(d_model, target_vocab_size)
    source_pos_encodings = PositionalEncoding(d_model, max_seq_len, dropout)
    target_pos_encodings = PositionalEncoding(d_model, max_seq_len, dropout)

    # Create the Transformer model
    transformer = Transformer(num_layers, d_model, num_heads, dff, dropout,
                              source_embeddings, target_embeddings,
                              source_pos_encodings, target_pos_encodings, target_vocab_size)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer