In [7]:
import torch
import torch.nn as nn
from tqdm import tqdm
import os
import wandb

import components
import utils

# Download dataset

In [8]:
from datasets import load_dataset

In [9]:
# Load the WMT14 dataset for German-English translation
dataset = load_dataset('wmt14', 'de-en')

In [10]:
# In this notebook, we will train on a small segment of the dataset as we will be working locally. 
# We will figure out the parameters and then train on the full set in the cloud. 

# Take a small subset for experimentation
small_train_dataset = dataset['train'].select(range(20000))
small_val_dataset = dataset['validation'].select(range(1000))

### Tokenization

In [11]:
# as we are following the original `Attention is all you need paper` we will use Byte-Pair Encoding
from tokenizers import ByteLevelBPETokenizer

In [None]:
dataset['train'][0]

In [None]:
# we will be training our own BPE tokenizer for this task. 

with open('train_texts.txt', 'w', encoding='utf-8') as f:

    for example in tqdm(dataset['train']):
        f.write(example['translation']['de'] + '\n')
        f.write(example['translation']['en'] + '\n')



In [None]:
# now train a BPE tokenizer
bpe_tokenizer = ByteLevelBPETokenizer()

In [None]:
bpe_tokenizer.train(
    files=['train_texts.txt'],
    vocab_size=37000,
    min_frequency=2,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
)

In [None]:
# save the tokenizer
save_directory = 'bpe_tokenizer'
if not os.path.exists(save_directory):
    os.makedirs(save_directory)
bpe_tokenizer.save_model('bpe_tokenizer')

In [12]:
# Load the trained tokenizer
tokenizer = ByteLevelBPETokenizer(
    "bpe_tokenizer/vocab.json",
    "bpe_tokenizer/merges.txt"
)

In [13]:
# Test the tokenizer
print(tokenizer.encode("Das ist ein Beispiel.").ids)
# Should return something like ['<s>', 'Das', 'ist', 'ein', 'Beispiel', '</s>']

print(tokenizer.token_to_id("</s>"))
# Should return a valid token ID for '</s>'


[789, 423, 328, 3010, 18]
2


### Define the tokenization pipeline

In [14]:
def tokenize(examples):

    # Extract German and English sentences from the list of dictionaries
    src_texts = [example['de'] for example in examples['translation']]
    tgt_texts = [example['en'] for example in examples['translation']]

    # tokenize src and tgt
    src_tokens = tokenizer.encode_batch(src_texts)
    tgt_tokens = tokenizer.encode_batch(tgt_texts)

    # return dictionary format expected by PyTorch
    return {
        'input_ids': [[tokenizer.token_to_id('<s>')] + encoding.ids + [tokenizer.token_to_id('</s>')] for encoding in src_tokens],
        'attention_mask': [[tokenizer.token_to_id('<pad>')] + encoding.attention_mask + [tokenizer.token_to_id('<pad>')] for encoding in src_tokens],
        'labels': [[tokenizer.token_to_id('<s>')] + encoding.ids + [tokenizer.token_to_id('</s>')] for encoding in tgt_tokens]
    }


In [15]:
# tokenize the data
tokenized_train = small_train_dataset.map(tokenize, batched=True)
tokenized_val = small_val_dataset.map(tokenize, batched=True)

In [16]:
example = tokenized_train[98]

assert len(example['input_ids']) == len(example['attention_mask'])

In [17]:
tokenized_train[0]

{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode',
  'en': 'Resumption of the session'},
 'input_ids': [0, 23062, 17719, 319, 26699, 2],
 'attention_mask': [1, 1, 1, 1, 1, 1],
 'labels': [0, 8859, 27958, 304, 280, 9974, 2]}

In [18]:
tokenizer.decode(tokenized_train[0]['labels'])

'<s>Resumption of the session</s>'

## Set up Dataloaders

In [19]:
# we need to write a collate_fn to pad sentences to be of the same size
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids = [torch.tensor(item['input_ids']) for item in batch]
    attention_mask = [torch.tensor(item['attention_mask']) for item in batch]
    labels = [torch.tensor(item['labels']) for item in batch]

    # Pad sequences to the length of the longest sequence in the batch
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.token_to_id('<pad>'))
    attention_mask_padded = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=tokenizer.token_to_id('<pad>'))

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_mask_padded,
        'labels': labels_padded
    }

    

In [22]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

# create the data loaders
train_dl = DataLoader(
    tokenized_train, 
    batch_size=BATCH_SIZE,
    shuffle=True, 
    collate_fn=collate_fn
)

val_dl = DataLoader(
    tokenized_val, 
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn
)

In [23]:
# Get the first batch from the training DataLoader
for batch in train_dl:
    print("Input IDs:", batch['input_ids'].shape)
    print("Attention Mask:", batch['attention_mask'].shape)
    print("Labels:", batch['labels'].shape)
    break


Input IDs: torch.Size([16, 64])
Attention Mask: torch.Size([16, 64])
Labels: torch.Size([16, 68])


## Masking and Batching

In [24]:
import torch
import torch.optim as optim

def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    #print("Target (tgt):", tgt)
    
    # Padding mask
    tgt_padding_mask = (tgt != pad).unsqueeze(1).unsqueeze(2)
    #print("Padding Mask:", tgt_padding_mask)
    
    # Look-ahead mask (subsequent mask)
    tgt_seq_len = tgt.size(-1)
    look_ahead_mask = torch.triu(torch.ones((1, tgt_seq_len, tgt_seq_len), device=tgt.device), diagonal=1).type_as(tgt_padding_mask.data)
    #print("Look-Ahead Mask (Subsequent Mask):", look_ahead_mask)
    
    # Combined mask
    tgt_mask = tgt_padding_mask & (look_ahead_mask == 0)
    #print("Combined Target Mask:", tgt_mask)
    
    return tgt_mask


In [25]:
# Example English sentence
sentence = "<s>The cat sat on the mat.</s>"

# Tokenize the sentence using your trained tokenizer
tgt_tokens = tokenizer.encode(sentence)
tgt_token_ids = tgt_tokens.ids  # Get the list of token IDs
print("Tokenized Sentence IDs:", tgt_token_ids)

# Convert the token IDs to a tensor (assuming <pad> token ID is 0)
tgt_tensor = torch.tensor([tgt_token_ids + [tokenizer.token_to_id('<pad>')] * (20 - len(tgt_token_ids))])  # Pad to length 10
print("Padded Tokenized Sentence Tensor:", tgt_tensor)


Tokenized Sentence IDs: [32, 87, 34, 465, 16218, 22524, 385, 280, 4226, 18, 32, 19, 87, 34]
Padded Tokenized Sentence Tensor: tensor([[   32,    87,    34,   465, 16218, 22524,   385,   280,  4226,    18,
            32,    19,    87,    34,     1,     1,     1,     1,     1,     1]])


In [26]:
make_std_mask(tgt_tensor, tokenizer.token_to_id('<pad>'))

tensor([[[[ True, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False],
          [ True,  Tru

# Training

In [27]:
import wandb

# initialize wandb
wandb.init(project='transformer-translator')

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011176713433168414, max=1.0…

In [42]:
# Model config
model_config ={
    'd_model': 512, 
    'num_heads': 8,
    'num_encoder_layers': 6, 
    'num_decoder_layers': 6,
    'd_ff': 2048,
    'dropout': 0.1, 
    'src_vocab': tokenizer.get_vocab_size(), 
    'tgt_vocab': tokenizer.get_vocab_size()
}
d_model = 512  # Model dimension
num_heads = 8  # Number of attention heads
num_encoder_layers = 6  # Number of encoder layers
num_decoder_layers = 6  # Number of decoder layers
d_ff = 2048  # Dimension of feedforward layers
dropout = 0.1  # Dropout rate

vocab_size = tokenizer.get_vocab_size()  # Vocabulary size from your tokenizer

# Initialize the encoder, decoder, and the full model
encoder = components.Encoder(num_encoder_layers, num_heads, d_model, d_ff, dropout)
decoder = components.Decoder(num_decoder_layers, num_heads, d_model, d_ff, dropout)
src_embed = nn.Sequential(nn.Embedding(vocab_size, d_model), components.PositionalEncoding(d_model, dropout))
tgt_embed = nn.Sequential(nn.Embedding(vocab_size, d_model), components.PositionalEncoding(d_model, dropout))
generator = components.Generator(d_model, vocab_size)

# Initialize the EncoderDecoder model
model = components.EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

In [43]:
# Create learning rate scheduler, following `Attention is All You Need` for now. 
# lr = d_model ** (-0.5) * min(step_num ** (-0.5), step_num * warmup_steps ** (-1.5))
warmup_steps = 4000

def get_lr(step_num):
    return d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)


# initialize optimizer, criterion
optimizer = optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('<pad>'))


In [44]:
# specify device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model.to(device)

EncoderDecoder(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0-5): 6 x EncoderLayer(
        (mha): MultiHeadAttention(
          (query_linear): Linear(in_features=512, out_features=512, bias=True)
          (key_linear): Linear(in_features=512, out_features=512, bias=True)
          (value_linear): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (output_linear): Linear(in_features=512, out_features=512, bias=True)
        )
        (ffn): PositionwiseFFN(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (layernorm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (layernorm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (layernorm): LayerNorm((

In [45]:
optimizer.defaults

{'lr': 0,
 'betas': (0.9, 0.98),
 'eps': 1e-09,
 'weight_decay': 0.01,
 'amsgrad': False,
 'foreach': None,
 'maximize': False,
 'capturable': False,
 'differentiable': False,
 'fused': None}

In [46]:
from tqdm.notebook import tqdm


# Number of epochs to train
num_epochs = 5
step_num = 0
pad_token_id = tokenizer.token_to_id('<pad>')

# log hyperparameters
hyperparameters = model_config.copy()
hyperparameters['num_epochs'] = num_epochs
hyperparameters['batch_size'] = train_dl.batch_size
hyperparameters['initial_lr'] = 0
hyperparameters['warmup_steps'] = warmup_steps
hyperparameters['betas'] = optimizer.defaults['betas']
hyperparameters['eps'] = optimizer.defaults['eps']
hyperparameters['model'] = model.__class__.__name__


wandb.config.update(hyperparameters)

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_dl, desc=f"Training Epoch: {epoch + 1}"):
        step_num += 1

        # adjust the learning rate according to the schedule
        lr = get_lr(step_num)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device).unsqueeze(1).unsqueeze(2)
        labels = batch['labels'].to(device)


        # shift the target token ids for the decoder input
        tgt_input = labels[:, :-1]
        tgt_y = labels[:, 1:]

        # create the target mask (combining padding and look-ahead masks)
        tgt_mask = make_std_mask(tgt_input, pad_token_id)

        # forward pass
        optimizer.zero_grad()
        logits = model(src=input_ids, tgt=tgt_input, src_mask=attention_mask, tgt_mask=tgt_mask)

        # compute the loss
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_y.reshape(-1))
        total_train_loss += loss.item()

        # backward pass and optimization
        loss.backward()
        optimizer.step()

        # log the learning rate and training loss to wandb
        wandb.log({
            'train_loss': loss.item(),
            'learning_rate': lr, 
            'step': step_num,
            'epoch': epoch + 1,
            })
    
    avg_loss = total_train_loss / len(train_dl)
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_loss:.4f}")

    # Validation loop
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_dl:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device).unsqueeze(1).unsqueeze(2)
            labels = batch['labels'].to(device)

            tgt_input = labels[:, :-1]
            tgt_y = labels[:, 1:]

            tgt_mask = make_std_mask(tgt_input, pad_token_id)

            logits = model(input_ids, tgt_input, attention_mask, tgt_mask)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_y.reshape(-1))
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_dl)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")

    # log validation loss to wandb
    wandb.log({
        'val_loss': avg_val_loss, 
        'epoch': epoch + 1
    })

# finish the wandb run
wandb.finish()


Training Epoch: 1:   0%|          | 0/1250 [00:00<?, ?it/s]

### Inference

In [None]:
# check on a few german sentences

model.eval()

num_examples = 5
examples = []
for i in range(num_examples):
    examples.append(tokenized_val[i])

def decode_tokens(tokens, tokenizer):
    return tokenizer.decode(tokens, skip_special_tokens=True)

# perform inference
with torch.no_grad():
    for i, src in enumerate(examples):

        # convert to tensor and move to device
        src_tensor = torch.tensor(src['input_ids']).unsqueeze(0).to(device)
        attention_mask = (src_tensor != pad_token_id).unsqueeze(1).unsqueeze(2)

        print(f"src_tensor shape: {src_tensor.shape}")
        print(f"attention_mask shape: {attention_mask.shape}")

        tgt_tensor = torch.tensor([tokenizer.token_to_id('<s>')]).unsqueeze(0).to(device)

        for _ in range(100):  # limit the length of the generated sequence for now at least...
            # create the tgt_mask
            tgt_mask = make_std_mask(tgt_tensor, pad_token_id)

            # run the model
            logits = model(src_tensor, tgt_tensor, attention_mask, tgt_mask)

            # get the predicted next_token
            next_token = logits[:, -1, :].argmax(dim=-1)
            # print(f"Next token predicted: {next_token.item()} (Token: {tokenizer.decode([next_token.item()])})")

            tgt_tensor = torch.cat([tgt_tensor, next_token.unsqueeze(0)], dim=1)

            # Check if the predicted token is </s>
            if next_token.item() == tokenizer.token_to_id('</s>'):
                print("End of sentence token encountered, stopping inference.")
                break

        # decode the source and target
        src_sentence = decode_tokens(src['input_ids'], tokenizer=tokenizer)
        tgt_sentence = decode_tokens(tgt_tensor.squeeze().tolist(), tokenizer)
        actual_tgt = src['translation']['en']

        print(f"German: {src_sentence}")
        print(f"Actual translation: {actual_tgt}")
        print(f"NMT: {tgt_sentence}")
        print("-" * 50)
