In [2]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components

In [3]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPU(s)")
else:
    device = torch.device("cpu")

device

Using 4 GPU(s)


device(type='cuda')

# Download dataset

### download and inspect dataset

In [4]:
from datasets import load_dataset, DatasetDict

In [5]:
# Load the WMT14 dataset for German-English translation
dataset = load_dataset('wmt14', 'de-en')

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4508785
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 3000
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 3003
    })
})

In [7]:
dataset['train'][4]

{'translation': {'de': 'Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen -, allen Opfern der Stürme, insbesondere in den verschiedenen Ländern der Europäischen Union, in einer Schweigeminute zu gedenken.',
  'en': "In the meantime, I should like to observe a minute' s silence, as a number of Members have requested, on behalf of all the victims concerned, particularly those of the terrible storms, in the various countries of the European Union."}}

In [8]:
# select a very small segment for experimentation
# Take a small subset for experimentation
small_train_dataset = dataset['train'].select(range(20))
small_val_dataset = dataset['validation'].select(range(5))

In [9]:
small_train_dataset

Dataset({
    features: ['translation'],
    num_rows: 20
})

### Tokenization

In [10]:
# as we are following the original `Attention is all you need paper` we will use Byte-Pair Encoding
from tokenizers import ByteLevelBPETokenizer

In [11]:
# Load the trained tokenizer
vocab_path = os.path.join(os.getcwd(), "transformers-based-translator/de-en-bpetokenizer/vocab.json")
merges_path = os.path.join(os.getcwd(), "transformers-based-translator/de-en-bpetokenizer/merges.txt")
tokenizer = ByteLevelBPETokenizer(
    vocab_path,
    merges_path
)

In [12]:
# Test the tokenizer
print(tokenizer.encode("Das ist ein Beispiel.").ids)

print([tokenizer.id_to_token(token) for token in tokenizer.encode("Das ist ein Beispiel").ids])
# Should return something like ['<s>', 'Das', 'ist', 'ein', 'Beispiel', '</s>']

print(tokenizer.token_to_id("</s>"))
# Should return a valid token ID for '</s>'

print(tokenizer.decode(tokenizer.encode("Das ist ein Beispiel.").ids))

[789, 423, 328, 3010, 18]
['Das', 'Ġist', 'Ġein', 'ĠBeispiel']
2
Das ist ein Beispiel.


In [13]:
PAD_TOKEN_ID = tokenizer.token_to_id('<pad>')
BOS_TOKEN_ID = tokenizer.token_to_id('<s>')
EOS_TOKEN_ID = tokenizer.token_to_id('</s>')
print(BOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID)

0 2 1


In [14]:
from data_processing import TranslationDataset, collate_fn

small_translation_ds = TranslationDataset(small_train_dataset, tokenizer=tokenizer, src_lang='de', tgt_lang='en', bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID, max_length=30)
small_translation_ds[0]

{'src_sentence': 'Wiederaufnahme der Sitzungsperiode',
 'tgt_sentence': 'Resumption of the session',
 'src_tokens': tensor([    0, 23062, 17719,   319, 26699,     2,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]),
 'tgt_tokens': tensor([    0,  8859, 27958,   304,   280,  9974,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1])}

## Instantiate Encoder-Decoder and Translator classes

In [15]:
import wandb

config = {
    "num_blocks": 6,
    "num_heads": 8, 
    "d_model": 512, 
    "d_ff": 2048, 
    "vocab_size": tokenizer.get_vocab_size(), 
    "max_len": 512, 
    "dropout": 0.1, 
    "verbose": False, 
    "betas": (0.9, 0.98),
    "eps": 1e-9,
    "warmup_steps": 5000,
    "batch_size": 64,
    "num_epochs": 3
}

api_key = os.getenv("WANDB_API_KEY")

wandb.init(
    project="transformers-based-translator",
    config=config
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdominic-culver[0m ([33mdominic-l-culver[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [16]:
model = components.EncoderDecoder.from_hyperparameters(
    **config
)
model

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

In [17]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
pytorch_total_params

82065544

In [18]:
translator = components.Translator(
    model=model,
    tokenizer=tokenizer,
    bos_token_id=BOS_TOKEN_ID, 
    eos_token_id=EOS_TOKEN_ID,
    pad_token_id=PAD_TOKEN_ID, 
    device=device
)
translator

<components.Translator at 0x7f60c4a14eb0>

## Set up Dataloaders and Trainer

In [19]:
from train import Trainer
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)

Setting device to: cuda


In [20]:
# create dataset
train_ds = TranslationDataset(dataset['train'].shuffle(), tokenizer=tokenizer, src_lang='de', tgt_lang='en', bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)
val_ds = TranslationDataset(dataset['validation'], tokenizer=tokenizer, src_lang='de', tgt_lang='en', bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID, pad_token_id=PAD_TOKEN_ID)

In [21]:
# create dataloaders
batch_size = config['batch_size']
num_workers = 16
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=lambda batch: collate_fn(batch, PAD_TOKEN_ID), num_workers=num_workers, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, collate_fn=lambda batch: collate_fn(batch, PAD_TOKEN_ID), num_workers=num_workers, pin_memory=True)

In [22]:
for batch in train_dl:
    print(batch.keys())
    break

dict_keys(['src_tokens', 'tgt_input', 'tgt_output', 'src_mask', 'tgt_mask'])


In [23]:
trainer = Trainer(
    model=model,
    train_dl=train_dl,
    val_dl=val_dl,
    tokenizer=tokenizer,
    optimizer=optimizer,
    criterion=torch.nn.NLLLoss(),
    device=device,
)

Using 4 GPU(s)


In [24]:
trainer.train(
    n_epochs=config['num_epochs'],
    save_dir='test_model'
)

Epoch 1/3


  0%|          | 101/70450 [02:29<28:40:57,  1.47s/it]

Epoch 1/3, average loss at batch 100: 3.4904


  0%|          | 201/70450 [04:56<28:46:28,  1.47s/it]

Epoch 1/3, average loss at batch 200: 2.1079


  0%|          | 301/70450 [07:24<28:50:53,  1.48s/it]

Epoch 1/3, average loss at batch 300: 1.5897


  1%|          | 401/70450 [09:51<28:43:51,  1.48s/it]

Epoch 1/3, average loss at batch 400: 1.3091


  1%|          | 500/70450 [12:19<28:43:41,  1.48s/it]

Epoch 1/3, average loss at batch 500: 1.1349


100%|██████████| 47/47 [00:31<00:00,  1.51it/s]


Epoch 1/3, Batch 500, Average validation loss: 0.3623099235144067
Validation loss has improved!


  1%|          | 601/70450 [15:20<28:39:21,  1.48s/it] 

Epoch 1/3, average loss at batch 600: 1.0148


  1%|          | 701/70450 [17:48<28:40:51,  1.48s/it]

Epoch 1/3, average loss at batch 700: 0.9275


  1%|          | 801/70450 [20:16<28:31:22,  1.47s/it]

Epoch 1/3, average loss at batch 800: 0.8598


  1%|          | 807/70450 [20:25<28:29:58,  1.47s/it]