In [None]:
# Mount Google Drive so the files under `colab_dir` (dataset CSV, checkpoints) are reachable.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

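Paste the code below into the browser's developer console to keep the Colab runtime from disconnecting. It clicks the Connect button every 60 seconds; run `stopClickConnect()` in the console to stop it.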

```
var startClickConnect = function startClickConnect() {
  var clickConnect = function clickConnect() {
    console.log("Connect button has been clicked.");
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click();
  };

  var intervalId = setInterval(clickConnect, 60000);

  var stopClickConnectHandler = function stopClickConnect() {
    clearInterval(intervalId);
    console.log("Stop auto clicker.");
  };

  return stopClickConnectHandler;
};

var stopClickConnect = startClickConnect();
```


In [None]:
# Install / upgrade the runtime dependencies (versions are not pinned).
# NOTE(review): upgrading torch inside Colab may require restarting the runtime before the
# new version is imported; consider pinning versions (and `%pip`) for reproducibility.
!pip install --upgrade torch triton
!pip install wandb
!pip install torchmetrics

In [None]:
import os
import re
import pdb
import math
from tqdm import tqdm
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from torchmetrics.text import BLEUScore

import wandb
import numpy as np
from google.colab import userdata

from transformers import BertTokenizer, PreTrainedTokenizerFast, DataCollatorForSeq2Seq
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import TemplateProcessing
from datasets import Dataset

In [None]:
# ---- Preprocessing ----
vocab_size = 30_000  # BPE vocabulary size, also the size of the model's embedding tables
max_len = 20         # every sentence is padded / truncated to this many tokens
# Special-token ids; these must match the order of `special_tokens` given to the BPE trainer.
pad_id, ukn_id, bos_id, eos_id = 0, 1, 2, 3

# ---- Model ----
n_encoder_layers = n_decoder_layers = 6
d_model, n_heads, d_ff = 512, 8, 2048
dropout = 0.1

# ---- Training ----
batch_size = 512
n_epochs = 50
warmup_steps = 4_000  # warm-up length of the Noam learning-rate schedule

colab_dir = '/content/drive/My Drive/Colab Notebooks/'

In [None]:
# Snapshot of every tunable value above; passed to wandb.init as the run config.
hyperparameters = dict(
    vocab_size=vocab_size,
    max_len=max_len,
    pad_id=pad_id,
    ukn_id=ukn_id,
    bos_id=bos_id,
    eos_id=eos_id,
    n_encoder_layers=n_encoder_layers,
    n_decoder_layers=n_decoder_layers,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    dropout=dropout,
    batch_size=batch_size,
    n_epochs=n_epochs,
    warmup_steps=warmup_steps,
)

In [None]:
# Load the English/French sentence-pair CSV from Google Drive and preview it.
file_path = os.path.join(colab_dir, 'dataset', 'eng-fra.csv')
df = pd.read_csv(file_path, encoding='utf-8')
print(df.head())

In [None]:
# `Dataset` here is the Hugging Face `datasets.Dataset` (its import shadows torch's Dataset).
dataset = Dataset.from_pandas(df)
first_example = dataset[0]
print(first_example)

In [None]:
# Split 60/20/20 into train / validation / test.
# The shuffle is seeded so the split is identical across kernel restarts: training can be
# resumed from a checkpoint in a fresh session, and an unseeded shuffle would then move
# previous validation/test rows into the training set (and change the test set).
split_seed = 42

split_dataset = dataset.train_test_split(test_size=0.4, shuffle=True, seed=split_seed)
train_dataset = split_dataset['train']
temp_dataset = split_dataset['test']

# The 40% hold-out is already shuffled, so it is halved without reshuffling.
split_temp_dataset = temp_dataset.train_test_split(test_size=0.5, shuffle=False)
val_dataset = split_temp_dataset['train']
test_dataset = split_temp_dataset['test']

print("Train dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
print("Test dataset size:", len(test_dataset))

In [None]:
# Train one byte-level BPE tokenizer shared by the source (English) and target (French) text.
# NOTE(review): this trains on the full dataset, including validation/test sentences;
# consider training on `train_dataset` only to avoid vocabulary leakage.
tokenizer = ByteLevelBPETokenizer()

tokenizer_text_columns = ["English words/sentences", "French words/sentences"]
tokenizer_chunk_size = 1000

def batch_iterator():
    """Yield the corpus in chunks of sentences: all English text first, then all French text."""
    for column in tokenizer_text_columns:
        for start in range(0, len(dataset), tokenizer_chunk_size):
            yield dataset[start : start + tokenizer_chunk_size][column]

tokenizer.train_from_iterator(
    batch_iterator(),
    vocab_size=vocab_size,
    # Total number of sentences the iterator yields (used for progress reporting);
    # computed from the row count instead of materialising both text columns.
    length=len(tokenizer_text_columns) * len(dataset),
    special_tokens=["[PAD]", "[UNK]", "<s>", "</s>"],
)

# The model hard-codes pad/unk/bos/eos ids in the config cell; fail fast if the trainer
# assigned different ids to the special tokens.
expected_special_ids = {"[PAD]": pad_id, "[UNK]": ukn_id, "<s>": bos_id, "</s>": eos_id}
for special_token, expected_id in expected_special_ids.items():
    actual_id = tokenizer.token_to_id(special_token)
    assert actual_id == expected_id, f"{special_token} has id {actual_id}, expected {expected_id}"

# Wrap every encoded sequence as "<s> ... </s>".
tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
        ("<s>", tokenizer.token_to_id("<s>")),
        ("</s>", tokenizer.token_to_id("</s>")),
    ]
)

tokenizer_save_path = "pretrained-bpe-tokenizer"
os.makedirs(tokenizer_save_path, exist_ok=True)

tokenizer_json_path = os.path.join(tokenizer_save_path, "tokenizer.json")
tokenizer.save(tokenizer_json_path)

In [None]:
# Wrap the trained tokenizer with PreTrainedTokenizerFast so it can be used with
# `datasets.map` and the Hugging Face data collators.
new_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_json_path)
new_tokenizer.add_special_tokens(dict(
    pad_token='[PAD]',
    unk_token='[UNK]',
    bos_token='<s>',
    eos_token='</s>',
))

src_col, tgt_col = 'English words/sentences', 'French words/sentences'

def tokenize_function(examples):
    """Tokenize a batch: English text becomes `input_ids`, French text becomes `labels`.

    Both sides are padded / truncated to exactly `max_len` tokens
    (99.38% of sentences have 20 tokens or less).
    """
    source_texts = examples[src_col]
    target_texts = examples[tgt_col]
    return new_tokenizer(
        source_texts,
        text_target=target_texts,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
        return_token_type_ids=False,
    )

def _tokenize_split(split):
    """Tokenize one split and drop its raw text columns."""
    return split.map(tokenize_function, batched=True).remove_columns([src_col, tgt_col])

tokenized_train_dataset = _tokenize_split(train_dataset)
tokenized_val_dataset = _tokenize_split(val_dataset)
tokenized_test_dataset = _tokenize_split(test_dataset)

In [None]:
# Inspect one tokenized training example (input_ids / labels are padded to max_len tokens).
print(tokenized_train_dataset[0])

In [None]:
# Vocabulary size actually learned by the BPE trainer; expected to be <= `vocab_size`,
# which is the size the model's embedding tables are built with.
print(tokenizer.get_vocab_size())

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to embeddings, then applies dropout.

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
  The table is precomputed for `max_len` positions and stored as a buffer.
  """

  def __init__(self, d_model: int, dropout: float=0.1, max_len: int=512):
    super().__init__()
    assert d_model % 2 == 0, "d_model must be even"
    self.dropout = nn.Dropout(p=dropout)

    positions = torch.arange(max_len).unsqueeze(1)        # [max_len, 1]
    even_dims = torch.arange(0, d_model, 2).unsqueeze(0)  # [1, d_model//2]
    inv_freq = torch.exp(even_dims / d_model * -math.log(10000))  # [1, d_model//2]
    angles = positions * inv_freq                         # [max_len, d_model//2]

    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = angles.sin()  # even feature indices
    table[:, 1::2] = angles.cos()  # odd feature indices

    # [1, max_len, d_model] so it broadcasts over the batch dimension.
    self.register_buffer('positional_encoding', table.unsqueeze(0))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, seq_len, d_model] with seq_len <= max_len
    Returns:
        x + PE[:seq_len] after dropout: [batch_size, seq_len, d_model]
    """
    return self.dropout(x + self.positional_encoding[:, :x.size(1)])

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with bias-free projections.

  Mask convention: key positions where `mask == 0` (False) receive a score of -1e9 and
  therefore ~zero attention weight. The mask must broadcast against the score tensor
  [batch_size, n_heads, q_seq_len, k_seq_len].
  """

  def __init__(self, d_model: int, n_heads: int):
    super().__init__()
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads

    # Query / key / value input projections, then the output projection.
    self.W_q = nn.Linear(d_model, d_model, bias=False)
    self.W_k = nn.Linear(d_model, d_model, bias=False)
    self.W_v = nn.Linear(d_model, d_model, bias=False)
    self.W_o = nn.Linear(d_model, d_model, bias=False)

  def scaled_dot_product_attention(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """softmax(Q K^T / sqrt(d_k)) V over the last two dims.

    Args:
        Q: [batch_size, n_heads, q_seq_len, d_k]
        K, V: [batch_size, n_heads, k_seq_len, d_k]
        mask: broadcastable to [batch_size, n_heads, q_seq_len, k_seq_len]
    Returns:
        [batch_size, n_heads, q_seq_len, d_k]
    """
    scale = math.sqrt(Q.size(-1))
    scores = Q @ K.transpose(-2, -1) / scale
    if mask is not None:
      # A large finite negative (not -inf) keeps the softmax finite even for fully masked rows.
      scores = scores.masked_fill(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    return weights @ V

  def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Args:
        Q: [batch_size, q_seq_len, d_model]
        K, V: [batch_size, k_seq_len, d_model]
        mask: broadcastable to [batch_size, n_heads, q_seq_len, k_seq_len]
    Returns:
        [batch_size, q_seq_len, d_model]
    """
    batch_size = Q.size(0)

    def split_heads(t: torch.Tensor) -> torch.Tensor:
      # [batch, seq, d_model] -> [batch, n_heads, seq, d_k]
      return t.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    q, k, v = (split_heads(proj(t)) for proj, t in ((self.W_q, Q), (self.W_k, K), (self.W_v, V)))

    heads = self.scaled_dot_product_attention(q, k, v, mask)
    # [batch, n_heads, q_seq, d_k] -> [batch, q_seq, d_model]
    merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
    return self.W_o(merged)

In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear(d_model -> d_ff) -> ReLU -> Linear(d_ff -> d_model).

  `dropout` is accepted for signature compatibility but not used inside this module;
  dropout on this sub-layer's output is applied by the encoder/decoder layers.
  """

  def __init__(self, d_model: int, d_ff: int, dropout: float=0.1):
    super().__init__()
    self.linear1 = nn.Linear(d_model, d_ff, bias=True)
    self.linear2 = nn.Linear(d_ff, d_model, bias=True)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, seq_len, d_model]
    Returns:
        [batch_size, seq_len, d_model]
    """
    hidden = self.linear1(x).relu()
    return self.linear2(hidden)

In [None]:
class EncoderLayer(nn.Module):
  """One post-LayerNorm encoder layer.

  h   = LayerNorm(x + Dropout(SelfAttention(x)))
  out = LayerNorm(h + Dropout(FeedForward(h)))
  """

  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float=0.1):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)

    self.mha = MultiHeadAttention(d_model, n_heads)
    self.norm1 = nn.LayerNorm(d_model)

    self.ff = FeedForward(d_model, d_ff, dropout)
    self.norm2 = nn.LayerNorm(d_model)

  def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, src_len, d_model]
        src_mask: key padding mask broadcastable to [batch_size, 1, src_len, src_len]
                  (the Transformer passes [batch_size, 1, 1, src_len])
    Returns:
        [batch_size, src_len, d_model]
    """
    attended = self.norm1(x + self.dropout(self.mha(x, x, x, src_mask)))
    return self.norm2(attended + self.dropout(self.ff(attended)))

In [None]:
class Encoder(nn.Module):
  """Token embedding + sinusoidal positional encoding, followed by a stack of EncoderLayers."""

  def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int, vocab_size: int, dropout: float):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    # Use the configured dropout rate for the embedding dropout as well, instead of
    # PositionalEncoding's own default.
    self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

    self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

  def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, src_len] token ids
        src_mask: key padding mask broadcastable to [batch_size, 1, src_len, src_len]
                  (the Transformer passes [batch_size, 1, 1, src_len])
    Returns:
        outp: [batch_size, src_len, d_model]
    """
    outp = self.embedding(x)
    outp = self.pos_encoding(outp)

    for layer in self.layers:
      outp = layer(outp, src_mask)

    return outp

In [None]:
class DecoderLayer(nn.Module):
  """One post-LayerNorm decoder layer: masked self-attention, attention over the encoder
  outputs, then a feed-forward network. Each sub-layer is followed by dropout, a residual
  connection and LayerNorm.
  """

  def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
    super().__init__()

    self.dropout = nn.Dropout(p=dropout)

    self.mha1 = MultiHeadAttention(d_model, n_heads)  # masked self-attention
    self.norm1 = nn.LayerNorm(d_model)

    self.mha2 = MultiHeadAttention(d_model, n_heads)  # attention over encoder outputs
    self.norm2 = nn.LayerNorm(d_model)

    self.ff = FeedForward(d_model, d_ff, dropout)
    self.norm3 = nn.LayerNorm(d_model)

  def forward(self, x: torch.Tensor, encoder_outp: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, tgt_len, d_model]
        encoder_outp: [batch_size, src_len, d_model]
        src_mask: [batch_size, 1, 1, src_len] padding mask over encoder positions
        tgt_mask: [batch_size, 1, tgt_len, tgt_len] causal + padding mask
    Returns:
        outp: [batch_size, tgt_len, d_model]
    """

    # Masked self-attention sub-layer.
    outp = self.mha1(x, x, x, tgt_mask)
    outp = self.dropout(outp)
    outp = outp + x
    outp = self.norm1(outp)

    # Cross-attention: queries from the decoder, keys/values from the encoder.
    res = outp
    outp = self.mha2(outp, encoder_outp, encoder_outp, src_mask)
    outp = self.dropout(outp)
    outp = outp + res
    outp = self.norm2(outp)

    # Position-wise feed-forward sub-layer.
    res = outp
    outp = self.ff(outp)
    outp = self.dropout(outp)
    outp = outp + res
    outp = self.norm3(outp)

    return outp

In [None]:
class Decoder(nn.Module):
  """Target embedding + positional encoding, a stack of DecoderLayers, and an output
  projection to vocabulary logits that shares its weight matrix with the embedding.
  """

  def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int, vocab_size: int, dropout: float):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    # Use the configured dropout rate for the embedding dropout as well, instead of
    # PositionalEncoding's own default.
    self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)

    self.layers = nn.ModuleList([
        DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
    ])

    self.linear = nn.Linear(d_model, vocab_size, bias=False)

    # Weight tying: both matrices are [vocab_size, d_model].
    self.linear.weight = self.embedding.weight

  def forward(self, x: torch.Tensor, encoder_outp: torch.Tensor, src_mask: torch.Tensor, tgt_mask: torch.Tensor) -> torch.Tensor:
    """
    Args:
        x: [batch_size, tgt_len] token ids
        encoder_outp: [batch_size, src_len, d_model]
        src_mask: [batch_size, 1, 1, src_len]
        tgt_mask: [batch_size, 1, tgt_len, tgt_len]
    Returns:
        outp: [batch_size, tgt_len, vocab_size] unnormalised logits
    """
    outp = self.embedding(x)
    outp = self.pos_encoding(outp)

    for layer in self.layers:
      outp = layer(outp, encoder_outp, src_mask, tgt_mask)

    outp = self.linear(outp)

    return outp

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer that returns vocabulary logits for every target position.

  `bos_id` and `eos_id` are stored on the module but not used by `forward`; only `pad_id`
  is used, to build the attention masks.
  """

  def __init__(self, n_encoder_layers: int, n_decoder_layers: int, d_model: int,
               n_heads: int, d_ff: int, vocab_size: int, dropout: float,
               pad_id: int, bos_id: int, eos_id: int):
    super().__init__()

    self.pad_id = pad_id
    self.bos_id = bos_id
    self.eos_id = eos_id

    self.encoder = Encoder(n_encoder_layers, d_model, n_heads, d_ff, vocab_size, dropout)
    self.decoder = Decoder(n_decoder_layers, d_model, n_heads, d_ff, vocab_size, dropout)

  def get_attn_pad_mask(self, k_seq: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Boolean mask that is True where the key token is not padding.

    Args:
        k_seq: [batch_size, k_seq_len]
    Returns:
        mask: [batch_size, k_seq_len] (bool)
    """
    return k_seq != pad_id

  def get_attn_look_ahead_mask(self, k_seq: torch.Tensor) -> torch.Tensor:
    """Causal mask: query position i may attend to key position j only when j <= i.

    Args:
        k_seq: [batch_size, k_seq_len]
    Returns:
        mask: [batch_size, k_seq_len, k_seq_len] (bool; an expanded view, not a copy)
    """
    batch_size, k_seq_len = k_seq.size()
    causal = torch.ones(k_seq_len, k_seq_len, dtype=torch.bool, device=k_seq.device).tril()
    return causal.unsqueeze(0).expand(batch_size, -1, -1)

  def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
    """
    Args:
        src: [batch_size, src_seq_len] source token ids
        tgt: [batch_size, tgt_seq_len] decoder input token ids
    Returns:
        decoder_outp: [batch_size, tgt_seq_len, vocab_size]
    """
    # Source: hide padded keys; [batch_size, 1, 1, src_len] broadcasts over heads and queries.
    src_mask = self.get_attn_pad_mask(src, self.pad_id).unsqueeze(1).unsqueeze(2)

    # Target: causal AND non-padded keys; [batch_size, 1, tgt_len, tgt_len].
    tgt_mask = self.get_attn_look_ahead_mask(tgt) & self.get_attn_pad_mask(tgt, self.pad_id).unsqueeze(1)
    tgt_mask = tgt_mask.unsqueeze(1)

    encoder_outp = self.encoder(src, src_mask)
    return self.decoder(tgt, encoder_outp, src_mask, tgt_mask)

In [None]:
model = Transformer(
    n_encoder_layers=n_encoder_layers,
    n_decoder_layers=n_decoder_layers,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    vocab_size=vocab_size,
    dropout=dropout,
    pad_id=pad_id,
    bos_id=bos_id,
    eos_id=eos_id,
)
model.to('cuda')

# The base lr of 1 is a placeholder: LambdaLR multiplies it by lr_lambda(step), so the
# effective learning rate is d_model^-0.5 * min(s^-0.5, s * warmup_steps^-1.5) with s = step + 1.
optimizer = optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-09)

def lr_lambda(step):
    step += 1  # LambdaLR counts from 0; the schedule is defined from step 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
def count_parameters(model):
    """Return the number of scalar parameters in `model` that have requires_grad=True."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)
# Trainable parameter count; the tied decoder embedding / output projection share one
# Parameter, which Module.parameters() yields only once.
print(count_parameters(model))

In [None]:
# Examples are already padded to max_len; the collator turns each batch into tensors.
# NOTE(review): if it ever had to pad `labels` it would use -100, which the test cell maps back to pad_id.
data_collator = DataCollatorForSeq2Seq(tokenizer=new_tokenizer, padding=True)

def _make_loader(split, shuffle):
    """Batch a tokenized split with the shared collator."""
    return DataLoader(split, batch_size=batch_size, shuffle=shuffle, collate_fn=data_collator)

train_dataloader = _make_loader(tokenized_train_dataset, shuffle=True)
val_dataloader = _make_loader(tokenized_val_dataset, shuffle=False)
test_dataloader = _make_loader(tokenized_test_dataset, shuffle=False)

In [None]:
def preprocess_decoder_input(tgt: torch.Tensor, pad_id: int, eos_id: int, ignore_index: int = -100) -> torch.Tensor:
  """Build the teacher-forcing decoder input from a target batch.

  Every </s> is replaced by padding, so the decoder never reads </s>. Any `ignore_index`
  value (the label padding used by DataCollatorForSeq2Seq) is also mapped to `pad_id`,
  so the result is always a valid embedding index.

  Args:
      tgt: [batch_size, tgt_len] target token ids
      pad_id: padding token id
      eos_id: end-of-sequence token id
      ignore_index: label-padding value to map to `pad_id` (default -100)
  Returns:
      decoder_input: [batch_size, tgt_len], a new tensor (`tgt` is not modified)
  """
  decoder_input = tgt.detach().clone()
  decoder_input[(decoder_input == eos_id) | (decoder_input == ignore_index)] = pad_id
  return decoder_input

def preprocess_decoder_label(tgt: torch.Tensor, pad_id: int) -> torch.Tensor:
  """Build next-token labels from a target batch.

  Position t of the result holds tgt[:, t + 1] (the leading <s> is dropped), and the
  last position is filled with padding. Every padding position is then set to -100,
  the ignore_index used by the cross-entropy loss.

  Args:
      tgt: [batch_size, tgt_len]
  Returns:
      decoder_label: [batch_size, tgt_len], a new tensor (`tgt` is not modified)
  """
  filler = torch.full((tgt.size(0), 1), pad_id, device=tgt.device)
  shifted = torch.cat([tgt[:, 1:], filler], dim=1)
  return shifted.masked_fill(shifted == pad_id, -100)

In [None]:
# Training state; overwritten below when resuming from a checkpoint.
start_epoch = 0
best_loss = np.inf
patience = 3  # epochs without validation improvement before early stopping
counter = 0   # consecutive epochs without improvement
wandb_run_id = None

# Set to a checkpoint written by the training loop to resume, e.g.
# os.path.join(colab_dir, 'transformer-from-scratch/modelXXX.pt')
checkpoint_path = None
if checkpoint_path:

  print("Loading checkpoint...")
  checkpoint = torch.load(checkpoint_path)

  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
  start_epoch = checkpoint['epoch'] + 1
  best_loss = checkpoint['best_loss']
  # Restore the early-stopping counter when the checkpoint stored one; otherwise
  # patience starts over.
  counter = checkpoint.get('counter', 0)
  wandb_run_id = checkpoint['wandb_run_id']

  print(f"Resuming training from epoch {start_epoch}")

In [None]:
# Log in with the API key stored in Colab's secrets, then start a wandb run.
wandb_api_key = userdata.get('WANDB_API_KEY')
wandb.login(key=wandb_api_key)

# Resume the previous run only when a checkpoint supplied its id.
resume_mode = "must" if wandb_run_id else None
wandb.init(
    project="transformer-from-scratch",
    config=hyperparameters,
    id=wandb_run_id,
    resume=resume_mode,
)

In [None]:
# Train
def _batch_loss(batch):
  """Teacher-forced cross-entropy for one batch; padding positions are ignored.

  Shared by training and validation so both feed the decoder the same input
  (</s> and label padding removed) and score the same shifted labels.
  """
  src = batch['input_ids'].to('cuda')
  tgt = batch['labels'].to('cuda')
  decoder_input = preprocess_decoder_input(tgt, pad_id, eos_id)
  decoder_label = preprocess_decoder_label(tgt, pad_id)

  decoder_outp = model(src, decoder_input)
  return F.cross_entropy(decoder_outp.reshape(-1, decoder_outp.size(-1)), decoder_label.reshape(-1), ignore_index=-100)


for epoch in range(start_epoch, n_epochs):

  # ---- Training ----
  model.train()
  total_train_loss = 0.0

  train_progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{n_epochs} [Train]")

  for batch in train_progress_bar:

    optimizer.zero_grad()
    loss = _batch_loss(batch)
    loss.backward()
    optimizer.step()
    scheduler.step()  # the learning-rate schedule advances once per batch

    batch_loss = loss.item()
    total_train_loss += batch_loss
    train_progress_bar.set_postfix(train_loss=batch_loss)

    wandb.log({
            "train_batch_loss": batch_loss,
            "learning_rate": optimizer.param_groups[0]['lr']
    })

  avg_train_loss = total_train_loss / len(train_dataloader)

  # ---- Validation ----
  model.eval()
  total_val_loss = 0.0

  with torch.no_grad():

    val_progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{n_epochs} [Validation]")

    for batch in val_progress_bar:
      batch_loss = _batch_loss(batch).item()
      total_val_loss += batch_loss
      val_progress_bar.set_postfix(val_loss=batch_loss)

  avg_val_loss = total_val_loss / len(val_dataloader)

  wandb.log({
        "epoch": epoch + 1,
        "avg_train_loss": avg_train_loss,
        "avg_val_loss": avg_val_loss
    })

  # Report before the early-stopping check so the final epoch is printed too.
  print(f"Epoch: {epoch+1}, Avg Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")

  # ---- Early stopping bookkeeping ----
  if avg_val_loss < best_loss:
    best_loss = avg_val_loss
    counter = 0
    print("Validation loss improved.")

  else:
    counter += 1
    print(f"Validation loss did not improve. Counter: {counter}/{patience}")

  # Save latest checkpoint (includes the early-stopping counter so resuming keeps patience state).
  print("Saving latest checkpoint...")
  checkpoint = {
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'scheduler_state_dict': scheduler.state_dict(),
      'best_loss': best_loss,
      'counter': counter,
      'wandb_run_id': wandb.run.id,
  }
  torch.save(checkpoint, os.path.join(colab_dir, f"model{epoch:03d}.pt"))

  if counter >= patience:
    print("Early stopping triggered.")
    break

In [None]:
# Optionally reload weights from a checkpoint written by the training loop, e.g.
# os.path.join(colab_dir, 'transformer-from-scratch/modelXXX.pt')
best_model_path = None
if best_model_path:
  best_checkpoint = torch.load(best_model_path)
  # Training-loop checkpoints wrap the weights under 'model_state_dict' alongside optimizer /
  # scheduler state; fall back to treating the file as a bare state_dict.
  best_state_dict = best_checkpoint.get('model_state_dict', best_checkpoint)
  model.load_state_dict(best_state_dict)

In [None]:
model.eval()

predictions = []
references = []

bleu_score = BLEUScore()

with torch.no_grad():

    test_progress_bar = tqdm(test_dataloader, desc="[Test]")

    for batch in test_progress_bar:

        src = batch['input_ids'].to('cuda')
        tgt = batch['labels'].clone()
        tgt[tgt == -100] = pad_id  # undo collator label padding so the ids can be decoded
        tgt = tgt.to('cuda')

        # Greedy decoding: start every French sentence with <s>.
        decoder_input = torch.full((src.shape[0], 1), bos_id, device=src.device)
        finished = torch.zeros(src.shape[0], dtype=torch.bool, device=src.device)

        for _ in range(max_len + 1):

            decoder_outp = model(src, decoder_input)
            next_token = decoder_outp[:, -1].argmax(dim=-1).unsqueeze(1)  # [batch_size, 1]

            # After a sentence has emitted </s>, append padding instead of further predictions:
            # skip_special_tokens removes </s> but would keep any tokens generated after it.
            next_token = next_token.masked_fill(finished.unsqueeze(1), pad_id)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            finished |= next_token.squeeze(1) == eos_id
            if finished.all():
                break

        pred_text = new_tokenizer.batch_decode(decoder_input, skip_special_tokens=True)
        ref_text = new_tokenizer.batch_decode(tgt, skip_special_tokens=True)
        src_text = new_tokenizer.batch_decode(src, skip_special_tokens=True)

        # Print one example per batch
        print("\n" + "="*50)
        print(f"Source:      {src_text[0]}")
        print(f"Reference:   {ref_text[0]}")
        print(f"Prediction:  {pred_text[0]}")
        print("="*50)

        predictions.extend(pred_text)
        # Wrap references of each prediction with list as BLEU can have multiple references
        references.extend([[r] for r in ref_text])

bleu = bleu_score(predictions, references)
print(f"\nTest BLEU Score: {bleu.item() * 100:.2f}")