In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components

In [None]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
# The conditional expression is evaluated lazily, so CUDA is only probed when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

device

# Download dataset

### Download and inspect the dataset

In [None]:
from datasets import load_dataset, DatasetDict

In [None]:
# Fetch the German-English ('de-en') configuration of the WMT14 dataset.
dataset = load_dataset(path='wmt14', name='de-en')

In [None]:
# Show the loaded dataset object as the cell output.
dataset

In [None]:
# Peek at one raw record of the training split (index 4).
dataset['train'][4]

In [None]:
# Carve a tiny slice off the train and validation splits (first 20 and first 5 rows)
# so that experiments run quickly.
small_train_dataset, small_val_dataset = (
    dataset[split_name].select(range(row_count))
    for split_name, row_count in (('train', 20), ('validation', 5))
)

In [None]:
# Show the 20-row training subset as the cell output.
small_train_dataset

### Tokenization

In [None]:
# Following the original "Attention Is All You Need" paper, we use Byte-Pair Encoding.
from tokenizers import ByteLevelBPETokenizer

In [None]:
# Load the trained tokenizer.
# Both files are resolved relative to the current working directory, so the notebook
# has to be started from the directory that contains `transformers-based-translator/`.
vocab_path, merges_path = (
    os.path.join(os.getcwd(), relative_path)
    for relative_path in (
        "transformers-based-translator/de-en-bpetokenizer/vocab.json",
        "transformers-based-translator/de-en-bpetokenizer/merges.txt",
    )
)
tokenizer = ByteLevelBPETokenizer(vocab=vocab_path, merges=merges_path)

In [None]:
# Sanity-check the tokenizer on ONE fixed sentence, so that every printout below
# describes the same encoding (ids -> tokens -> decoded text).
sample_sentence = "Das ist ein Beispiel."
sample_ids = tokenizer.encode(sample_sentence).ids

# Raw token ids of the sample sentence.
print(sample_ids)

# The sub-word token behind each id, in order.
# NOTE(review): presumably this listing contains no '<s>' / '</s>' markers — the BOS/EOS ids
# are handed to the dataset and translator explicitly further down — confirm.
print([tokenizer.id_to_token(token_id) for token_id in sample_ids])

# Id of the end-of-sequence token; None here means '</s>' is not in the vocabulary.
print(tokenizer.token_to_id("</s>"))

# Round trip: decoding the ids should reproduce the sample sentence.
print(tokenizer.decode(sample_ids))

In [None]:
# Look up the ids of the special tokens used for padding and for sequence boundaries.
PAD_TOKEN_ID = tokenizer.token_to_id('<pad>')
BOS_TOKEN_ID = tokenizer.token_to_id('<s>')
EOS_TOKEN_ID = tokenizer.token_to_id('</s>')

# token_to_id returns None for a token that is absent from the vocabulary. Stop here with a
# clear message instead of letting a None id flow into the dataset, model and trainer.
missing_special_tokens = [
    token
    for token, token_id in (('<pad>', PAD_TOKEN_ID), ('<s>', BOS_TOKEN_ID), ('</s>', EOS_TOKEN_ID))
    if token_id is None
]
if missing_special_tokens:
    raise ValueError(
        f"Special tokens missing from the tokenizer vocabulary: {missing_special_tokens}"
    )

print(BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID)

In [None]:
from data_processing import TranslationDataset, collate_fn

# Wrap the 20-row subset as a German -> English TranslationDataset (max_length=30)
# and display its first item.
small_translation_ds = TranslationDataset(
    small_train_dataset,
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
    max_length=30,
)
small_translation_ds[0]

## Instantiate Encoder-Decoder and Translator classes

In [None]:
import wandb

# Hyperparameters for the model and the training run; the same dictionary is logged to wandb.
config = {
    # model architecture
    "num_blocks": 6,
    "num_heads": 8,
    "d_model": 512,
    "d_ff": 2048,
    "vocab_size": tokenizer.get_vocab_size(),
    "max_len": 512,
    "dropout": 0.1,
    "verbose": False,
    # optimizer / schedule
    "betas": (0.9, 0.98),
    "eps": 1e-9,
    "warmup_steps": 5000,
    # training loop
    "batch_size": 32,
    "num_epochs": 5
}

# The key is not passed to wandb explicitly; it is only inspected here so that a missing
# credential is visible in the notebook output before the run is started.
api_key = os.getenv("WANDB_API_KEY")
if not api_key:
    print("WANDB_API_KEY is not set in the environment.")

wandb.init(
    project="transformers-based-translator",
    config=config
)

In [None]:
# Build the encoder-decoder network from the hyperparameter dictionary and display it.
# NOTE(review): `config` also carries training-only entries (betas, eps, warmup_steps,
# batch_size, num_epochs) — confirm from_hyperparameters accepts or ignores them.
model = components.EncoderDecoder.from_hyperparameters(**config)
model

In [None]:
# Count the trainable scalars: sum the element counts of every parameter with requires_grad set.
pytorch_total_params = sum(
    weight.numel()
    for weight in filter(lambda weight: weight.requires_grad, model.parameters())
)
pytorch_total_params

In [None]:
# Construct the Translator from the model, the tokenizer, the special-token ids
# and the target device, then display it.
translator = components.Translator(
    model=model, tokenizer=tokenizer, device=device,
    bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID,
)
translator

## Set up Trainer

In [None]:
from train import Trainer
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Take betas/eps from `config` so the values logged to wandb are the ones actually used.
# NOTE(review): lr=0 presumably relies on a warm-up schedule applied elsewhere
# (e.g. inside Trainer) — confirm.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0,
    betas=config["betas"],
    eps=config["eps"],
)

In [None]:
# Assemble the trainer.
# The loss skips padded target positions (ignore_index=PAD_TOKEN_ID), so padding added by
# the collate function does not contribute to the loss or its gradient.
# NOTE(review): NLLLoss expects log-probabilities as input — confirm the model's output
# layer applies log-softmax.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
    criterion=torch.nn.NLLLoss(ignore_index=PAD_TOKEN_ID),
    device=device,
)

In [None]:
# Create the full-size datasets.
# The training split is intentionally NOT shuffled here: the training DataLoader is built
# with shuffle=True, which already randomises the order each epoch. An extra unseeded
# `.shuffle()` on the Hugging Face dataset is not reproducible and adds an indices mapping
# that slows down row access.
# NOTE(review): max_length is left at TranslationDataset's default here — confirm it does
# not exceed config["max_len"].
train_ds = TranslationDataset(
    dataset['train'],
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
)
val_ds = TranslationDataset(
    dataset['validation'],
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
)

In [None]:
# Build the train/validation loaders; both share the batch size and the collate helper.
batch_size = config['batch_size']


def _collate_with_pad_id(batch):
    """Call ``collate_fn`` on ``batch``, forwarding ``PAD_TOKEN_ID`` as its second argument."""
    return collate_fn(batch, PAD_TOKEN_ID)


train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=_collate_with_pad_id)
val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=_collate_with_pad_id)

In [None]:
# Smoke test of the input pipeline: pull a single batch from the training loader,
# print its keys, and stop iterating immediately.
for batch in train_dl:
    print(batch.keys())
    break

In [None]:
# Launch training: the epoch count comes from `config`, and 'test_model' is passed as save_dir.
trainer.train(
    training_dl=train_dl, validation_dl=val_dl,
    n_epochs=config["num_epochs"], save_dir="test_model",
)