In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import time
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import MarianTokenizer

lr = 3e-4

debug = False

if debug:
  batch_size = 12
  num_epochs = 3
  emb_dim=16
  num_heads=4
  num_blocks=2
  train_size = 0.5
else:
  batch_size = 12
  num_epochs = 10
  emb_dim=256
  num_heads=16
  num_blocks=3
  train_size = 0.9

books = load_dataset("opus_books", "en-fr")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
dataset = books["train"].to_list()

if debug:
  dataset = dataset[:5_000]

In [None]:
books

DatasetDict({
    train: Dataset({
        features: ['id', 'translation'],
        num_rows: 127085
    })
})

In [None]:
len(dataset)

127085

In [None]:
dataset[1]

{'id': '1', 'translation': {'en': 'Alain-Fournier', 'fr': 'Alain-Fournier'}}

In [None]:
num_batches = len(dataset) // batch_size

training_batch_start = 0
training_batch_end = int(num_batches*train_size)
validation_batch_start = training_batch_end + 1
validation_batch_end = num_batches

training_batch_start, training_batch_end, validation_batch_start, validation_batch_end

(0, 9531, 9532, 10590)

In [None]:
model_name = 'Helsinki-NLP/opus-mt-en-fr'
tokenizer = MarianTokenizer.from_pretrained(model_name)

vocab_size = tokenizer.vocab_size
vocab_size



59514

In [None]:
tokenizer.all_special_tokens

['</s>', '<unk>', '<pad>']

In [None]:
tokenizer.encode(tokenizer.all_special_tokens)

[0, 1, 59513, 0]

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
tokenized_dataset_en = [tokenizer.encode(d["translation"]["en"])[:-1] for d in dataset]  # don't need EOS token for English sentences
tokenized_dataset_fr = [tokenizer.encode(d["translation"]["fr"]) for d in dataset]

max_en_len = max(len(tokens) for tokens in tokenized_dataset_en)
max_fr_len = max(len(tokens) for tokens in tokenized_dataset_fr)


print(f"{max_en_len=}, {max_fr_len=}")

Token indices sequence length is longer than the specified maximum sequence length for this model (533 > 512). Running this sequence through the model will result in indexing errors


max_en_len=532, max_fr_len=884


In [None]:
def get_batch(batch_id, batch_size=8, max_en_len=max_en_len, max_fr_len=max_fr_len):
  en_tokens = tokenized_dataset_en[batch_id*batch_size:(1+batch_id)*batch_size]
  fr_tokens = tokenized_dataset_fr[batch_id*batch_size:(1+batch_id)*batch_size]

  # Pad sequences
  en_tokens_padded = [tokens + [tokenizer.pad_token_id] * (max_en_len - len(tokens)) for tokens in en_tokens]
  fr_tokens_padded = [tokens + [tokenizer.pad_token_id] * (max_fr_len - len(tokens)) for tokens in fr_tokens]

  en_tokens_tensor = torch.tensor(en_tokens_padded).to(device)
  fr_tokens_tensor = torch.tensor(fr_tokens_padded).to(device)

  # Create labels (shifted French tokens)
  ys = [tokens[1:] + [tokenizer.pad_token_id] * (max_fr_len - len(tokens[1:])) for tokens in fr_tokens]
  ys_tensor = torch.tensor(ys).to(device)

  return en_tokens_tensor, fr_tokens_tensor, ys_tensor

In [None]:
get_batch(1)[0].shape, get_batch(1)[1].shape, get_batch(1)[2].shape

(torch.Size([8, 532]), torch.Size([8, 884]), torch.Size([8, 884]))

In [None]:
class EncoderTransformer(torch.nn.Module):
  def __init__(self, emb_dim=32, num_heads=8):
    super().__init__()
    self.mha = torch.nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device)
    self.layer_norm1 = torch.nn.LayerNorm(emb_dim)
    self.layer_norm2 = torch.nn.LayerNorm(emb_dim)
    self.ffn = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim, emb_dim)
    )

  def forward(self, X, padding_mask=None):
    res = X
    attn, _ = self.mha(X, X, X, need_weights=False, key_padding_mask=padding_mask)
    X = self.layer_norm1(res + attn) # note: modern LLMs use pre-LN, original paper implementation here is using post-LN

    res = X
    X = self.layer_norm2(res + self.ffn(X))
    return X


class Encoder(torch.nn.Module):
  def __init__(self, emb_dim=32, num_heads=8, num_blocks=2):
    super().__init__()
    self.embedding = torch.nn.Embedding(vocab_size, emb_dim)
    self.positional_encoding = torch.nn.Embedding(max_en_len, emb_dim)
    self.transformer_blocks = torch.nn.ModuleList([EncoderTransformer(emb_dim=emb_dim, num_heads=num_heads) for i in range(num_blocks)])

  def forward(self, X, padding_mask=None):
    # shape of X = [batch_size, seq_len]
    positions = torch.arange(X.shape[1]).to(device).unsqueeze(0)
    X = self.embedding(X) + self.positional_encoding(positions) # shape = [batch_size, seq_len, emb_dim]
    for layer in self.transformer_blocks:
      X = layer(X, padding_mask)

    return X


class DecoderTransformer(torch.nn.Module):
  def __init__(self, emb_dim=32, num_heads=8):
    super().__init__()
    self.mh_self_attention = torch.nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device, )
    self.mh_cross_attention = torch.nn.MultiheadAttention(embed_dim=emb_dim, num_heads=num_heads, dropout=0.1, batch_first=True, device=device)
    self.layer_norm1 = torch.nn.LayerNorm(emb_dim)
    self.layer_norm2 = torch.nn.LayerNorm(emb_dim)
    self.layer_norm3 = torch.nn.LayerNorm(emb_dim)
    self.ffn = torch.nn.Sequential(
        torch.nn.Linear(emb_dim, emb_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(emb_dim, emb_dim)
    )

  def forward(self, X, encoder_X, encoder_padding_mask=None, decoder_padding_mask=None):
    res = X
    attn_mask = torch.tril(torch.ones(X.shape[1], X.shape[1])).to(device)
    attn1, _ = self.mh_self_attention(X, X, X, is_causal=True, need_weights=False, attn_mask=attn_mask, key_padding_mask=decoder_padding_mask)
    X = self.layer_norm1(res + attn1) # note: modern LLMs use pre-LN, original paper implementation here is using post-LN

    res = X
    attn2, _ = self.mh_cross_attention(X, encoder_X, encoder_X, need_weights=False, key_padding_mask=encoder_padding_mask)
    X = self.layer_norm2(res + attn2) # note: modern LLMs use pre-LN, original paper implementation here is using post-LN

    res = X
    X = self.layer_norm3(res + self.ffn(X))
    return X


class Decoder(torch.nn.Module):
  def __init__(self, emb_dim=32, num_heads=8, num_blocks=2):
    super().__init__()
    self.embedding = torch.nn.Embedding(vocab_size, emb_dim)
    self.positional_encoding = torch.nn.Embedding(max_fr_len, emb_dim)
    self.transformer_blocks = torch.nn.ModuleList([DecoderTransformer(emb_dim=emb_dim, num_heads=num_heads) for i in range(num_blocks)])

  def forward(self, X, X_encoder, encoder_padding_mask=None, decoder_padding_mask=None):
    # shape of X = [batch_size, seq_len]
    positions = torch.arange(X.shape[1]).to(device).unsqueeze(0)
    X = self.embedding(X) + self.positional_encoding(positions) # shape = [batch_size, seq_len, emb_dim]

    for layer in self.transformer_blocks:
      X = layer(X, X_encoder, encoder_padding_mask, decoder_padding_mask)

    return X


class Model(torch.nn.Module):
  def __init__(self, emb_dim=32, num_heads=4, num_blocks=2):
    super().__init__()
    self.encoder = Encoder(emb_dim=emb_dim, num_heads=num_heads, num_blocks=num_blocks)
    self.decoder = Decoder(emb_dim=emb_dim, num_heads=num_heads, num_blocks=num_blocks)
    self.linear = torch.nn.Linear(emb_dim, vocab_size)
    self.loss = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

  def forward(self, X_en, X_fr, y=None):
    X_en_padding = (X_en == tokenizer.pad_token_id)
    X_fr_padding = (X_fr == tokenizer.pad_token_id)

    X_encoder = self.encoder(X_en, X_en_padding)
    X = self.decoder(X_fr, X_encoder, X_en_padding, X_fr_padding)
    logits = self.linear(X)
    if y is not None:
      return logits, self.loss(logits.view(-1, logits.size(-1)), y.reshape(-1))
    else:
      return logits, None

model = Model(emb_dim=emb_dim, num_heads=num_heads, num_blocks=num_blocks).to(device)

In [None]:
f"Number of model parameters: {sum(len(param) for param in model.parameters()):,}"

'Number of model parameters: 271,728'

In [None]:
def eval_model():
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for i in tqdm(range(validation_batch_start, validation_batch_end)):
            X_en, X_fr, y = get_batch(i, batch_size=batch_size)
            preds, batch_loss = model(X_en, X_fr, y)
            total_loss += batch_loss.item()
        total_loss /= (validation_batch_end - validation_batch_start)
    return total_loss

import time
time_ = int(time.time())
print(f"{time_}")


optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

for epoch in range(num_epochs):
  model.train()
  training_loss = 0
  for i in tqdm(range(training_batch_start, training_batch_end)):
    X_en, X_fr, y = get_batch(i, batch_size=batch_size)
    preds, batch_loss = model(X_en, X_fr, y)


    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    training_loss += batch_loss.item()
  training_loss /= (training_batch_end - training_batch_start)
  validation_loss = eval_model()
  print(f"\n{epoch=}, {training_loss=}, {validation_loss=}")

  torch.save(
      {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "training_loss": training_loss,
        "validation_loss": validation_loss
      },
      f"/content/drive/MyDrive/transformer_model_checkpoints/{time_}_ep{epoch}_vl{validation_loss:.2f}_en-fr-MT_model.pth")


1756343129


100%|██████████| 9531/9531 [1:55:52<00:00,  1.37it/s]
100%|██████████| 1058/1058 [04:29<00:00,  3.92it/s]



epoch=0, training_loss=0.7251606587749327, validation_loss=0.058508697710194076


100%|██████████| 9531/9531 [1:55:54<00:00,  1.37it/s]
100%|██████████| 1058/1058 [04:29<00:00,  3.93it/s]



epoch=1, training_loss=0.07137991474920877, validation_loss=0.03246214348512098


100%|██████████| 9531/9531 [1:55:55<00:00,  1.37it/s]
100%|██████████| 1058/1058 [04:29<00:00,  3.93it/s]



epoch=2, training_loss=0.03632437859526867, validation_loss=0.02101494767434222


 97%|█████████▋| 9238/9531 [1:52:21<03:34,  1.37it/s]

In [None]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()