In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components

In [2]:
# Pick the fastest backend available on this machine: an NVIDIA GPU via CUDA first,
# then Apple-silicon MPS, and the CPU as the last resort.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

device(type='mps')

# Download dataset

### Download and inspect the dataset

In [3]:
from datasets import load_dataset, DatasetDict

In [4]:
# WMT14 translation corpus, German-English ('de-en') configuration
dataset = load_dataset(path='wmt14', name='de-en')

In [5]:
# Show the available splits and their row counts
dataset

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4508785
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 3003
    })
})

In [6]:
# Inspect one raw training example: a 'translation' dict holding the 'de' and 'en' sentences
dataset['train'][4]

{'translation': {'de': 'Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen -, allen Opfern der Stürme, insbesondere in den verschiedenen Ländern der Europäischen Union, in einer Schweigeminute zu gedenken.',
  'en': "In the meantime, I should like to observe a minute' s silence, as a number of Members have requested, on behalf of all the victims concerned, particularly those of the terrible storms, in the various countries of the European Union."}}

In [7]:
# Tiny slices for quick experiments: the first 20 training rows and the first 5 validation rows
small_train_dataset, small_val_dataset = (
    dataset['train'].select(range(20)),
    dataset['validation'].select(range(5)),
)

In [8]:
# Confirm the subset kept the 'translation' feature and has the expected number of rows
small_train_dataset

Dataset({
    features: ['translation'],
    num_rows: 20
})

### Tokenization

In [9]:
# As we are following the original "Attention Is All You Need" paper, we use Byte-Pair Encoding (BPE)
from tokenizers import ByteLevelBPETokenizer

In [10]:
# Load the previously trained byte-level BPE tokenizer from its saved vocabulary
# (vocab.json) and merge rules (merges.txt); both paths are relative to the notebook.
tokenizer = ByteLevelBPETokenizer(
    "bpe_tokenizer/vocab.json",
    "bpe_tokenizer/merges.txt"
)

In [11]:
# Sanity-check the tokenizer on a short German sentence
sample_text = "Das ist ein Beispiel."
sample_ids = tokenizer.encode(sample_text).ids
print(sample_ids)

# Sub-word pieces for the same sentence without the final period.
# encode() does not add '<s>' / '</s>' here; in the printed pieces a leading 'Ġ'
# appears on every token that followed a space.
print([tokenizer.id_to_token(piece_id) for piece_id in tokenizer.encode("Das ist ein Beispiel").ids])

# The end-of-sequence marker must resolve to an id in the vocabulary
print(tokenizer.token_to_id("</s>"))

# Decoding the ids should reproduce the input sentence
print(tokenizer.decode(sample_ids))

[789, 423, 328, 3010, 18]
['Das', 'Ġist', 'Ġein', 'ĠBeispiel']
2
Das ist ein Beispiel.


In [12]:
# Ids of the special tokens used for padding and for marking sequence boundaries
PAD_TOKEN_ID = tokenizer.token_to_id('<pad>')
BOS_TOKEN_ID = tokenizer.token_to_id('<s>')
EOS_TOKEN_ID = tokenizer.token_to_id('</s>')

# token_to_id returns None for a token that is not in the vocabulary; fail here with a
# clear message instead of later, when the ids are used to build tensors.
assert None not in (BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID), \
    "tokenizer vocabulary is missing one of '<s>', '</s>', '<pad>'"
print(BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID)

0 2 1


In [13]:
from data_processing import TranslationDataset, collate_fn

# Wrap the 20-row subset as a German -> English dataset with sequences fixed at 30 tokens,
# then look at the first item (raw sentences plus padded token tensors).
small_translation_ds = TranslationDataset(
    small_train_dataset,
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
    max_length=30,
)
small_translation_ds[0]

{'src_sentence': 'Wiederaufnahme der Sitzungsperiode',
 'tgt_sentence': 'Resumption of the session',
 'src_tokens': tensor([    0, 23062, 17719,   319, 26699,     2,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]),
 'tgt_tokens': tensor([    0,  8859, 27958,   304,   280,  9974,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1])}

## Instantiate Encoder-Decoder and Translator classes

In [14]:
# Hyperparameters of the encoder-decoder; the layer count, head count, widths and
# dropout mirror the base model of "Attention Is All You Need".
model_hparams = dict(
    num_blocks=6,
    num_heads=8,
    d_model=512,
    d_ff=2048,
    vocab_size=tokenizer.get_vocab_size(),  # output layer must cover the whole BPE vocabulary
    max_len=512,
    dropout=0.1,
    verbose=False,
)
model = components.EncoderDecoder.from_hyperparameters(**model_hparams)
model

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

In [15]:
# Number of trainable scalars: parameters with requires_grad=False are left out
pytorch_total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
pytorch_total_params

82065544

In [16]:
# Inference wrapper around the model; it is given the tokenizer, the special-token ids
# and the target device.
translator = components.Translator(
    model=model, tokenizer=tokenizer, device=device,
    bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID,
)
translator

<components.Translator at 0x16dee06d0>

## Set up Trainer

In [17]:
from train import Trainer
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# AdamW with beta2=0.98 and eps=1e-9 in place of the PyTorch defaults.
# NOTE(review): with lr=0 the optimizer itself never changes a parameter. This presumably
# relies on Trainer overwriting the learning rate every step (e.g. a warmup schedule) —
# confirm in train.Trainer; if it does not, the model will not learn.
optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)

Setting device to: mps


In [18]:
# Build the training harness.
# NLLLoss expects log-probabilities, so the model output is presumably a log-softmax —
# TODO confirm in components.EncoderDecoder.
# ignore_index removes padded target positions from the loss; without it every <pad>
# slot counts as a prediction target and skews both the loss value and the gradients.
# Assumes collate_fn pads the targets with PAD_TOKEN_ID — confirm in data_processing.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
    criterion=torch.nn.NLLLoss(ignore_index=PAD_TOKEN_ID),
    device=device,
)

In [19]:
# Training data: a random subset of the training split. The shuffle is seeded so that
# Restart & Run All always selects the same sentence pairs and runs stay comparable.
TRAIN_SUBSET_SIZE = 20000
SHUFFLE_SEED = 42
train_ds = TranslationDataset(
    dataset['train'].shuffle(seed=SHUFFLE_SEED).select(range(TRAIN_SUBSET_SIZE)),
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
)

# Validation data: the complete validation split, in its original order
val_ds = TranslationDataset(
    dataset['validation'],
    tokenizer=tokenizer,
    src_lang='de',
    tgt_lang='en',
    bos_token_id=BOS_TOKEN_ID,
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID,
)

In [20]:
# Batch the datasets; only the training loader reshuffles between epochs
batch_size = 16


def pad_collate(batch):
    """Collate a list of samples with collate_fn, forwarding PAD_TOKEN_ID as its second argument."""
    return collate_fn(batch, PAD_TOKEN_ID)


train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=pad_collate)
val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=pad_collate)

In [21]:
# Pull a single batch to check which fields collate_fn produces; break after the first
# one so the loader is not iterated in full.
for batch in train_dl:
    print(batch.keys())
    break

dict_keys(['src_tokens', 'tgt_input', 'tgt_output', 'src_mask', 'tgt_mask'])


In [22]:
# Train for two epochs on train_dl, passing val_dl for validation.
# save_dir is handed to the Trainer — presumably the directory where checkpoints are
# written; confirm in train.Trainer.
trainer.train(
    training_dl=train_dl,
    validation_dl=val_dl,
    n_epochs=2,
    save_dir='test_model'
)

Epoch 1/2


  0%|          | 1/1250 [00:02<1:00:09,  2.89s/it]

Epoch 1/2, average loss at batch 0: 10.5058


  1%|          | 9/1250 [00:23<52:59,  2.56s/it]  