In [83]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
# Model creation and training.
import torch
import torch.nn as nn
from transformer_components import (
    TransformerDecoder,
    TransformerEncoder,
    TransformerEncoderDecoder,
    get_causal_mask,
)
from torch.optim import Adam
import pickle

# Data download and decompression.
import os
import requests
import gzip
import shutil

# Data processing
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import spacy

In [85]:
# Set device to gpu if available
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print("You are using device: %s" % device)

You are using device: cuda


## Download and decompress the training, validation and test data.

In [86]:
MULTI30_URL = (
    "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
)
LOCAL_DATA_DIR = "multi30k"
DATA_FILES_CONFIG = {
    "train": {"en": "train.en.gz", "fr": "train.fr.gz"},
    "val": {"en": "val.en.gz", "fr": "val.fr.gz"},
    "test": {"en": "test_2016_flickr.en.gz", "fr": "test_2016_flickr.fr.gz"},
}
DATA_FILES_NAMES = [
    filename for split in DATA_FILES_CONFIG.values() for filename in split.values()
]

# Download data into a local directory.
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

for file_name in DATA_FILES_NAMES:

    local_path = os.path.join(LOCAL_DATA_DIR, file_name)
    decompressed_path = os.path.splitext(local_path)[0]

    if not os.path.exists(decompressed_path):

        # Download file.
        with requests.get(
            os.path.join(MULTI30_URL, file_name), stream=True
        ) as response:
            response.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Decompress file.
        with gzip.open(local_path, "rb") as f_in:
            with open(decompressed_path, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)

        # Remove compressed version of file.
        os.remove(local_path)

## Load the data from files into ram, and tokenize.

In [87]:
# To save time, save the data object. That way, spacy only needs tokenize the data once.
if os.path.exists("data_cache.pt"):
    data = torch.load("data_cache.pt", pickle_module=pickle)
else:
    spacy_en = spacy.load("en_core_web_sm", disable=["parser", "ner"])
    spacy_fr = spacy.load("fr_core_news_sm", disable=["parser", "ner"])

    def load_data(eng_file_path, fr_file_path):

        pairs = []
        with open(eng_file_path, "r") as f1, open(fr_file_path, "r") as f2:
            for eng_line, fr_line in zip(f1, f2):
                eng_tokens = [
                    token.text.lower() for token in spacy_en(eng_line.strip())
                ]
                fr_tokens = [token.text.lower() for token in spacy_fr(fr_line.strip())]
                pairs.append((eng_tokens, fr_tokens))

        return pairs

    data = {}
    for split, langs in DATA_FILES_CONFIG.items():
        eng_file_path = os.path.join(LOCAL_DATA_DIR, os.path.splitext(langs["en"])[0])
        fr_file_path = os.path.join(LOCAL_DATA_DIR, os.path.splitext(langs["fr"])[0])

        data[split] = load_data(eng_file_path, fr_file_path)

    torch.save(data, "data_cache.pt")

data["train"][0]  # Example

(['two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.'],
 ['deux',
  'jeunes',
  'hommes',
  'blancs',
  'sont',
  'dehors',
  'près',
  'de',
  'buissons',
  '.'])

## Create English and French vocabularies.

In [88]:
PAD_TOKEN, PAD_IDX = "<PAD>", 0
SOS_TOKEN, SOS_IDX = "<SOS>", 1
EOS_TOKEN, EOS_IDX = "<EOS>", 2
UNK_TOKEN, UNK_IDX = "<UNK>", 3


class Vocab:

    def __init__(self):

        self.word_to_ordinal = {
            PAD_TOKEN: PAD_IDX,
            SOS_TOKEN: SOS_IDX,
            EOS_TOKEN: EOS_IDX,
            UNK_TOKEN: UNK_IDX,
        }
        self.ordinal_to_word = {
            PAD_IDX: PAD_TOKEN,
            SOS_IDX: SOS_TOKEN,
            EOS_IDX: EOS_TOKEN,
            UNK_IDX: UNK_TOKEN,
        }
        self.count = 4

    # This method builds the vocabulary, for each sentence passed in.
    def add_sentence(self, sentence):

        for word in sentence:
            if word not in self.word_to_ordinal:
                self.word_to_ordinal[word] = self.count
                self.ordinal_to_word[self.count] = word
                self.count += 1

    # This method is for creating model inputs.
    def to_ordinals(self, sentence):

        ordinal_sentence = [SOS_IDX]
        for word in sentence:
            ordinal_sentence.append(self.word_to_ordinal.get(word, UNK_IDX))
        ordinal_sentence.append(EOS_IDX)

        return torch.tensor(ordinal_sentence, dtype=torch.int64)

    # This method is for viewing model outputs.
    def to_words(self, ordinal_sentence):

        tokens = []
        for ordinal in ordinal_sentence:
            ordinal = ordinal.item()
            if ordinal == EOS_IDX:
                break
            if ordinal != SOS_IDX and ordinal != PAD_IDX:
                tokens.append(self.ordinal_to_word.get(ordinal, UNK_TOKEN))

        return " ".join(tokens)


# Build the vocabularies from training data.
en_vocab = Vocab()
fr_vocab = Vocab()

for en_fr_pair in data["train"]:
    en_vocab.add_sentence(en_fr_pair[0])
    fr_vocab.add_sentence(en_fr_pair[1])

## Creating a dataloader for train, val, and test datasets.

In [89]:
class Multi30k(Dataset):

    def __init__(self, pairs, en_vocab, fr_vocab):
        super().__init__()

        self.ordinal_pairs = []
        for pair in pairs:
            ordinal_pair = (
                en_vocab.to_ordinals(pair[0]),
                fr_vocab.to_ordinals(pair[1]),
            )
            self.ordinal_pairs.append(ordinal_pair)

    def __len__(self):
        return len(self.ordinal_pairs)

    def __getitem__(self, index):
        return self.ordinal_pairs[index]

In [90]:
BATCH_SIZE = 16


def collate_fn(batch):
    X_src = pad_sequence(
        [batch[i][0] for i in range(len(batch))],
        batch_first=True,
        padding_value=PAD_IDX,
    )
    X_tgt = pad_sequence(
        [batch[i][1] for i in range(len(batch))],
        batch_first=True,
        padding_value=PAD_IDX,
    )
    return (X_src, X_tgt)


dataloaders = {}
for split, pairs in data.items():
    dataset = Multi30k(pairs, en_vocab, fr_vocab)
    dataloaders[split] = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),  # Shuffle only True for training data
        collate_fn=collate_fn,
    )

## Model creation and training loop.

In [91]:
# Wrapper for the TransformerEncoder that embeds tokens and adds positional encodings.
class Encoder(nn.Module):

    def __init__(self, transformer_encoder_config, vocab_size, context_size):
        super().__init__()

        self.embedding = nn.Embedding(
            vocab_size, transformer_encoder_config["hidden_size"]
        )
        self.positional_encoding = nn.Embedding(
            context_size, transformer_encoder_config["hidden_size"]
        )
        self.transformer_encoder = TransformerEncoder(**transformer_encoder_config)

    def forward(self, X, key_padding_mask):

        X = self.embedding(X) + self.positional_encoding(
            torch.arange(X.shape[1], device=X.device)
        )
        X = self.transformer_encoder(X, key_padding_mask)

        return X


# Similar idea to above.
class Decoder(nn.Module):

    def __init__(self, transformer_decoder_config, vocab_size, context_size):
        super().__init__()
        
        self.embedding = nn.Embedding(
            vocab_size, transformer_decoder_config["hidden_size"]
        )
        self.positional_encoding = nn.Embedding(
            context_size, transformer_decoder_config["hidden_size"]
        )
        self.project = nn.Linear(
            transformer_decoder_config["hidden_size"],
            vocab_size,
        )
        self.transformer_decoder = TransformerDecoder(**transformer_decoder_config)
        
        # Store for generate function 
        self.key_size = transformer_decoder_config["key_size"]
        self.value_size = transformer_decoder_config["value_size"]
        self.num_heads = transformer_decoder_config["num_heads"]
        self.vocab_size = vocab_size
        self.stack_size = transformer_decoder_config["stack_size"]


    def forward(
        self, X_tgt, X_src, tgt_mask, tgt_key_padding_mask, src_key_padding_mask, all_kv_cache=None
    ):

        X_tgt = self.embedding(X_tgt) + self.positional_encoding(
            torch.arange(X_tgt.shape[1], device=X_tgt.device)
        )
        features = self.transformer_decoder(
            X_tgt, X_src, tgt_mask, tgt_key_padding_mask, src_key_padding_mask, all_kv_cache
        )
        logits = self.project(features)

        return logits

In [92]:
CONTEXT_SIZE = 60  # The number of positional encodings to learn.

encoder_config = {
    "stack_size": 6,
    "num_heads": 8,
    "hidden_size": 512,
    "key_size": 64,
    "value_size": 64,
    "feedforward_size": 2048,
    "dropout": 0.1,
}

decoder_config = {
    "stack_size": 6,
    "num_heads": 8,
    "hidden_size": 512,
    "key_size": 64,
    "value_size": 64,
    "feedforward_size": 2048,
    "dropout": 0.1,
}

model = TransformerEncoderDecoder(
    Encoder(encoder_config, vocab_size=en_vocab.count, context_size=CONTEXT_SIZE),
    Decoder(decoder_config, vocab_size=fr_vocab.count, context_size=CONTEXT_SIZE),
).to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = Adam(model.parameters(), lr=0.0001)

epochs_completed = 0

if os.path.exists("checkpoint.pt"):
    checkpoint = torch.load("checkpoint.pt", pickle_module=pickle)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epochs_completed = checkpoint["epochs_completed"]

model.train()

epochs = 5
for _ in range(epochs):
    for batch_src, batch_tgt in dataloaders["train"]:

        # Move batch to gpu, prepare model inputs.
        encoder_in = batch_src.to(device)
        decoder_in = batch_tgt[:, :-1].to(device)  # Do not include the last token.
        ground_truth = batch_tgt[:, 1:].to(
            device
        )  # Do not include the first token. The ground truth for the SOS token is thus the first word of the French sentence.

        # Create masks.
        tgt_len = decoder_in.shape[1]

        tgt_causal_mask = get_causal_mask(tgt_len, device)
        tgt_key_padding_mask = decoder_in == PAD_IDX
        src_key_padding_mask = encoder_in == PAD_IDX

        # Update weights
        optimizer.zero_grad()

        features = model(
            decoder_in,
            encoder_in,
            tgt_causal_mask,
            tgt_key_padding_mask,
            src_key_padding_mask,
        )
        loss = criterion(features.view(-1, features.shape[-1]), ground_truth.view(-1))

        print(loss)

        loss.backward()
        optimizer.step()

    epochs_completed += 1

torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epochs_completed": epochs_completed,
    },
    "checkpoint.pt",
)

tensor(9.6150, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.7307, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(8.3914, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.9061, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.9173, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.9846, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.8313, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.6964, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.6513, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.5631, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.6090, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.5294, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.4195, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.3678, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1330, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.1952, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(7.2388, device='cuda:0', grad_fn=

KeyboardInterrupt: 

In [None]:
# The number of epochs completed in this demo notebook, for the example below:
epochs_completed

0

In [None]:
batches = iter(dataloaders["train"])

X_src, X_tgt = next(batches)
X_src = X_src.to(device)
X_tgt = X_tgt.to(device)

X_src = X_src[:1, :]
src_key_padding_mask = X_src == PAD_IDX

model.eval()
with torch.no_grad():
    sentence = model.generate(X_src, src_key_padding_mask, 7, 60, SOS_IDX, PAD_IDX)

print("English sentence:", en_vocab.to_words(X_src[0]))
print("French sentence ground truth:", fr_vocab.to_words(X_tgt[0]))
print("Model output:", fr_vocab.to_words(sentence))

English sentence: a teenage girl wearing a t - shirt and bathing suit bottom jumps into a body of water .
French sentence ground truth: une adolescente vêtue d' un t - shirt et d' un bas de maillot saute dans un plan d' eau .
Model output: shish falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises falaises
