In [56]:
!pip install datasets



In [110]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
import torch.optim as optim
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
from concurrent.futures import ThreadPoolExecutor
import threading

nltk.download("punkt")
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [58]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading Dataset

In [59]:
dataset = load_dataset("wmt14", "de-en")

In [72]:
dataset['train'] = dataset['train'].select(range(500000))

# Transformer Architecture

In [73]:
class SelfAttention(nn.Module):
  def __init__(self, embed_size, num_heads):
    super(SelfAttention, self).__init__()
    self.embed_size = embed_size
    self.num_heads = num_heads
    self.head_dim = embed_size // num_heads

    assert (
      self.head_dim * self.num_heads == self.embed_size
    ), "The embed_size must be divisible by num_heads."

    self.values = nn.Linear(embed_size, embed_size)
    self.keys = nn.Linear(embed_size, embed_size)
    self.queries = nn.Linear(embed_size, embed_size)
    self.fc_out = nn.Linear(embed_size, embed_size)

  def forward(self, values, keys, queries, mask):
    N = queries.shape[0] # Number of training examples

    values_len, keys_len, queries_len = values.shape[1], keys.shape[1], queries.shape[1]

    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # Split the embedding into self.num_heads different parts
    # embed_size => self.num_heads * self.head_dim
    values = values.reshape(N, values_len, self.num_heads, self.head_dim)
    keys = keys.reshape(N, keys_len, self.num_heads, self.head_dim)
    queries = queries.reshape(N, queries_len, self.num_heads, self.head_dim)

    # queries shape: (N, query_len, num_heads, heads_dim) = nqhd
    # keys shape: (N, key_len, num_heads, heads_dim) = nkhd
    # temp shape: (N, num_heads, query_len, key_len) = nhqk
    temp = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

    if mask is not None:
      MINUS_INF = float("-1e16")
      temp = temp.masked_fill(mask == 0, MINUS_INF)

    attention = torch.softmax(temp / (self.embed_size ** (1/2)), dim=3)

    # attention shape: (N, num_heads, query_len, key_len)
    # values shape: (N, value_len, num_heads, heads_dim)
    # out shape: (N, query_len, num_heads, head_dim)
    out = torch.einsum("nhql,nlhd->nqhd", [attention, values])

    # concatenate
    out = out.reshape(
      N, queries_len, self.num_heads * self.head_dim
    )

    out = self.fc_out(out)

    return out

In [74]:
class TransformerBlock(nn.Module):
  def __init__(self, embed_size, num_heads, dropout, forward_expansion):
    super(TransformerBlock, self).__init__()
    self.attention = SelfAttention(embed_size, num_heads)

    self.norm1 = nn.LayerNorm(embed_size)

    self.feed_forward = nn.Sequential(
      nn.Linear(embed_size, forward_expansion * embed_size),
      nn.ReLU(),
      nn.Linear(forward_expansion * embed_size, embed_size),
    )

    self.norm2 = nn.LayerNorm(embed_size)

    self.dropout = nn.Dropout(dropout)

  def forward(self, values, keys, queries, mask):
    attention_out = self.attention(values, keys, queries, mask)

    norm1_out = self.dropout(self.norm1(attention_out + queries)) # the queries represent the skip connection

    feed_forward_out = self.feed_forward(norm1_out)

    out = self.dropout(self.norm2(feed_forward_out + norm1_out)) # the norm1_out represent the skip connection

    return out

In [75]:
class Encoder(nn.Module):
  def __init__(
    self,
    vocab_size,
    embed_size,
    num_layers,
    num_heads,
    device,
    forward_expansion,
    dropout,
    max_length # for positional embedding
  ):
    super(Encoder, self).__init__()
    self.embed_size = embed_size
    self.device = device
    self.word_embedding = nn.Embedding(vocab_size, embed_size)
    self.position_embedding = nn.Embedding(max_length, embed_size)

    self.layers = nn.ModuleList(
      [
        TransformerBlock(
          embed_size,
          num_heads,
          dropout=dropout,
          forward_expansion=forward_expansion
        )
        for _ in range(num_layers)
      ]
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    N, seq_length = x.shape
    positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
    out = self.word_embedding(x) + self.position_embedding(positions)
    out = self.dropout(out)

    for layer in self.layers:
      out = layer(out, out, out, mask) # values, keys and queries are all the same

    return out

In [76]:
class DecoderBlock(nn.Module):
  def __init__(
      self,
      embed_size,
      num_heads,
      forward_expansion,
      dropout,
      device
    ):
    super(DecoderBlock, self).__init__()
    self.attention = SelfAttention(embed_size, num_heads)
    self.norm = nn.LayerNorm(embed_size)
    self.transformer_block = TransformerBlock(
      embed_size, num_heads, dropout, forward_expansion
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key, src_mask, trg_mask):
    attention_out = self.attention(x, x, x, trg_mask)
    norm_out = self.dropout(self.norm(attention_out + x)) # the x represents the skip connection
    out = self.transformer_block(value, key, norm_out, src_mask)
    return out

In [77]:
class Decoder(nn.Module):
  def __init__(
    self,
    vocab_size,
    embed_size,
    num_layers,
    num_heads,
    forward_expansion,
    dropout,
    device,
    max_length,
  ):
    super(Decoder, self).__init__()
    self.device = device
    self.word_embedding = nn.Embedding(vocab_size, embed_size)
    self.position_embedding = nn.Embedding(max_length, embed_size)

    self.layers = nn.ModuleList(
      [
        DecoderBlock(embed_size, num_heads, forward_expansion, dropout, device)
        for _ in range(num_layers)
      ]
    )
    self.fc_out = nn.Linear(embed_size, vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_out, src_mask, trg_mask):
    N, seq_length = x.shape
    positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
    x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

    for layer in self.layers:
      x = layer(x, enc_out, enc_out, src_mask, trg_mask)

    out = self.fc_out(x)

    return out

In [78]:
class Transformer(nn.Module):
  def __init__(
    self,
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    trg_pad_idx,
    embed_size=512,
    num_layers=6,
    forward_expansion=4,
    heads=8,
    dropout=0,
    device="cpu",
    max_length=100,
  ):

    super(Transformer, self).__init__()

    self.encoder = Encoder(
      src_vocab_size,
      embed_size,
      num_layers,
      heads,
      device,
      forward_expansion,
      dropout,
      max_length,
    )

    self.decoder = Decoder(
      trg_vocab_size,
      embed_size,
      num_layers,
      heads,
      forward_expansion,
      dropout,
      device,
      max_length,
    )

    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device

  def make_src_mask(self, src):
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    # (N, 1, 1, src_len)
    return src_mask.to(self.device)

  def make_trg_mask(self, trg):
    N, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
        N, 1, trg_len, trg_len
    )

    return trg_mask.to(self.device)

  def forward(self, src, trg):
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)
    enc_src = self.encoder(src, src_mask)
    out = self.decoder(trg, enc_src, src_mask, trg_mask)
    return out

In [79]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
    device
)
trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)

src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10
model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
    device
)
out = model(x, trg[:, :-1])
print(out.shape)

cuda
torch.Size([2, 7, 10])


# Preprocessing

In [80]:
def build_vocab(data, language, vocab_size=10000):
  token_counter = Counter()
  for translation in data["translation"]:
    tokens = word_tokenize(translation[language].lower())
    token_counter.update(tokens)
  most_common = token_counter.most_common(vocab_size - 4)  # Reserve 4 for special tokens
  vocab = {word: idx + 4 for idx, (word, _) in enumerate(most_common)}
  vocab["<PAD>"] = 0 # Padding
  vocab["<SOS>"] = 1 # Start of sequence
  vocab["<EOS>"] = 2 # End of sequence
  vocab["<UNK>"] = 3 # Unknown
  return vocab

In [81]:
src_vocab = build_vocab(dataset["train"], "en")
trg_vocab = build_vocab(dataset["train"], "de")

In [96]:
id_to_en = {v: k for k, v in src_vocab.items()}
id_to_de = {v: k for k, v in trg_vocab.items()}

In [83]:
def preprocess_translation(example, src_vocab, trg_vocab):
  src_text = word_tokenize(example["translation"]["en"].lower())
  trg_text = word_tokenize(example["translation"]["de"].lower())

  src_ids = [src_vocab.get(token, src_vocab["<UNK>"]) for token in src_text]
  trg_ids = [trg_vocab.get(token, trg_vocab["<UNK>"]) for token in trg_text]

  return (
    [src_vocab["<SOS>"]] + src_ids + [src_vocab["<EOS>"]],
    [trg_vocab["<SOS>"]] + trg_ids + [trg_vocab["<EOS>"]],
  )

train_data_pairs = [
    preprocess_translation(example, src_vocab, trg_vocab) for example in dataset["train"]
]
validation_data_pairs = [
    preprocess_translation(example, src_vocab, trg_vocab) for example in dataset["validation"]
]
test_data_pairs = [
    preprocess_translation(example, src_vocab, trg_vocab) for example in dataset["test"]
]

In [84]:
print(train_data_pairs[0])
dataset["train"][0]

([1, 5090, 7, 4, 1621, 2], [1, 3606, 7, 3137, 2])


{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode',
  'en': 'Resumption of the session'}}

In [85]:
def pad_sequence(seq, max_length, pad_idx):
  return seq + [pad_idx] * (max_length - len(seq)) if len(seq) < max_length else seq[:max_length]

In [86]:
class TranslationDataset(Dataset):
  def __init__(self, data, src_pad_idx, trg_pad_idx, max_length=50):
    self.data = data
    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.max_length = max_length

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    src, trg = self.data[idx]
    src = pad_sequence(src, self.max_length, self.src_pad_idx)
    trg = pad_sequence(trg, self.max_length, self.trg_pad_idx)
    return torch.tensor(src), torch.tensor(trg)

In [87]:
src_pad_idx = src_vocab["<PAD>"]
trg_pad_idx = trg_vocab["<PAD>"]

train_dataset = TranslationDataset(train_data_pairs, src_pad_idx, trg_pad_idx)
validation_dataset = TranslationDataset(validation_data_pairs, src_pad_idx, trg_pad_idx)
test_dataset = TranslationDataset(test_data_pairs, src_pad_idx, trg_pad_idx)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# Tranining The Transformer

In [88]:
def evaluate_model(loader, model, criterion, device):
  model.eval()
  total_loss = 0

  with torch.no_grad():
    for src, trg in loader:
      src, trg = src.to(device), trg.to(device)

      output = model(src, trg[:, :-1])
      output = output.reshape(-1, output.shape[2])
      trg = trg[:, 1:].reshape(-1)

      loss = criterion(output, trg)
      total_loss += loss.item()

  return total_loss / len(loader)

In [90]:
src_vocab_size = len(src_vocab)
trg_vocab_size = len(trg_vocab)
model = Transformer(
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    device=device,
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

EPOCHS = 3

for epoch in range(EPOCHS):
  model.train()
  loop = tqdm(train_loader, leave=True)
  epoch_loss = 0
  for batch_idx, (src, trg) in enumerate(loop):
    src, trg = src.to(device), trg.to(device)

    output = model(src, trg[:, :-1])
    output = output.reshape(-1, output.shape[2])
    trg = trg[:, 1:].reshape(-1)

    loss = criterion(output, trg)
    epoch_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loop.set_description(f"Epoch [{epoch+1}/{EPOCHS}]")
    loop.set_postfix(loss=loss.item())

  val_loss = evaluate_model(validation_loader, model, criterion, device)
  print(f"Epoch {epoch+1} completed. Training Loss: {epoch_loss / len(train_loader):.4f}, Validation Loss: {val_loss:.4f}")

test_loss = evaluate_model(test_loader, model, criterion, device)
print(f"Test Loss: {test_loss:.4f}")

Epoch [1/3]: 100%|██████████| 7813/7813 [40:38<00:00,  3.20it/s, loss=1.97]


Epoch 1 completed. Training Loss: 2.9196, Validation Loss: 2.5239


Epoch [2/3]: 100%|██████████| 7813/7813 [40:33<00:00,  3.21it/s, loss=2.07]


Epoch 2 completed. Training Loss: 1.9441, Validation Loss: 2.3028


Epoch [3/3]: 100%|██████████| 7813/7813 [40:32<00:00,  3.21it/s, loss=1.57]


Epoch 3 completed. Training Loss: 1.7412, Validation Loss: 2.2023
Test Loss: 2.1899


In [None]:
test_loss = evaluate_model(test_loader, model, criterion, device)

In [None]:
print(test_loss)

In [91]:
def translate_sentence(sentence, model, src_vocab, trg_vocab, src_pad_idx, trg_pad_idx, max_length=50, device="cpu"):
  model.eval()

  tokens = word_tokenize(sentence.lower())
  src_ids = [src_vocab.get(token, src_vocab["<UNK>"]) for token in tokens]
  src_ids = [src_vocab["<SOS>"]] + src_ids + [src_vocab["<EOS>"]]

  src_ids = src_ids + [src_pad_idx] * (max_length - len(src_ids)) if len(src_ids) < max_length else src_ids[:max_length]
  src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)

  trg_ids = [trg_vocab["<SOS>"]]

  for _ in range(max_length):
    trg_tensor = torch.tensor(trg_ids, dtype=torch.long).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
      output = model(src_tensor, trg_tensor)

    next_token = output.argmax(dim=-1)[:, -1].item()

    trg_ids.append(next_token)

    if next_token == trg_vocab["<EOS>"]:
      break

  trg_tokens = [id_to_de[id] for id in trg_ids]

  translated_sentence = " ".join(trg_tokens[1:-1])

  return translated_sentence

In [92]:
input_sentence = "Hello, how are you?"

translated_sentence = translate_sentence(
    sentence=input_sentence,
    model=model,
    src_vocab=src_vocab,
    trg_vocab=trg_vocab,
    src_pad_idx=src_vocab["<PAD>"],
    trg_pad_idx=trg_vocab["<PAD>"],
    max_length=50,
    device=device
)

print(f"Translated Sentence: {translated_sentence}")

Translated Sentence: wie sollen sie denn sein ?


In [109]:
def evaluate_bleu_score(model, test_loader, src_vocab, trg_vocab, device):
  references = []
  predictions = []

  for src, trg in test_loader:
    src = src.to(device)
    trg = trg.to(device)

    for i in range(src.shape[0]):
      src_sentence = [id_to_en[idx] for idx in src[i].tolist() if idx not in [src_pad_idx]]
      src_sentence = " ".join(src_sentence)
      reference = [id_to_de[idx] for idx in trg[i].tolist() if idx not in [trg_pad_idx]]
      reference = " ".join(reference)

      prediction = translate_sentence(
        sentence=src_sentence,
        model=model,
        src_vocab=src_vocab,
        trg_vocab=trg_vocab,
        src_pad_idx=src_vocab["<PAD>"],
        trg_pad_idx=trg_vocab["<PAD>"],
        device=device
      )

      references.append(reference)
      predictions.append(prediction)

    print("Progress:", i+1, "/", len(test_loader))


  smooth_fn = SmoothingFunction().method4
  bleu_score = corpus_bleu(references, predictions, smoothing_function=smooth_fn)
  return bleu_score

bleu = evaluate_bleu_score(model, test_loader, src_vocab, trg_vocab, device)
print(f"BLEU Score: {bleu:.4f}")

Progress: 64 / 47


KeyboardInterrupt: 

In [113]:
def evaluate_bleu_score(model, test_loader, src_vocab, trg_vocab, device):
  references = []
  predictions = []

  progress_counter = threading.Lock()
  completed_batches = 0

  def process_batch_item(idx):
    src_sentence = [id_to_en[token] for token in src[idx].tolist() if token not in [src_pad_idx]]
    src_sentence = " ".join(src_sentence)

    reference = [id_to_de[token] for token in trg[idx].tolist() if token not in [trg_pad_idx]]
    reference = " ".join(reference)

    prediction = translate_sentence(
      sentence=src_sentence,
      model=model,
      src_vocab=src_vocab,
      trg_vocab=trg_vocab,
      src_pad_idx=src_vocab["<PAD>"],
      trg_pad_idx=trg_vocab["<PAD>"],
      device=device
    )
    return reference, prediction

  for batch_idx, (src, trg) in enumerate(test_loader):
    src = src.to(device)
    trg = trg.to(device)

    with ThreadPoolExecutor() as executor:
      results = list(executor.map(process_batch_item, range(src.shape[0])))

    with progress_counter:
      for reference, prediction in results:
        references.append(reference)
        predictions.append(prediction)

      completed_batches += 1
      print(f"Progress: {completed_batches} / {len(test_loader)}")

  smooth_fn = SmoothingFunction().method4
  bleu_score = corpus_bleu(references, predictions, smoothing_function=smooth_fn)
  return bleu_score

bleu = evaluate_bleu_score(model, test_loader, src_vocab, trg_vocab, device)
print(f"BLEU Score: {bleu:.4f}")

Progress: 1 / 47
Progress: 2 / 47
Progress: 3 / 47
Progress: 4 / 47
Progress: 5 / 47
Progress: 6 / 47
Progress: 7 / 47
Progress: 8 / 47
Progress: 9 / 47
Progress: 10 / 47
Progress: 11 / 47
Progress: 12 / 47
Progress: 13 / 47
Progress: 14 / 47
Progress: 15 / 47
Progress: 16 / 47
Progress: 17 / 47
Progress: 18 / 47
Progress: 19 / 47
Progress: 20 / 47
Progress: 21 / 47
Progress: 22 / 47
Progress: 23 / 47
Progress: 24 / 47
Progress: 25 / 47
Progress: 26 / 47
Progress: 27 / 47
Progress: 28 / 47
Progress: 29 / 47
Progress: 30 / 47
Progress: 31 / 47
Progress: 32 / 47
Progress: 33 / 47
Progress: 34 / 47
Progress: 35 / 47
Progress: 36 / 47
Progress: 37 / 47
Progress: 38 / 47
Progress: 39 / 47
Progress: 40 / 47
Progress: 41 / 47
Progress: 42 / 47
Progress: 43 / 47
Progress: 44 / 47
Progress: 45 / 47
Progress: 46 / 47
Progress: 47 / 47
BLEU Score: 0.0000
