# Setup

In [None]:
!nvidia-smi

Wed May  3 18:49:36 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from datasets import load_dataset

import math

import re
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, random_split, RandomSampler
from torchtext.data.metrics import bleu_score

from tqdm.notebook import tqdm

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data

In [None]:
NUM_SAMPLES = 10000
NUM_EPOCHS = 10
TRAIN_SIZE = NUM_SAMPLES * NUM_EPOCHS

In [None]:
train_data = load_dataset("wmt14", "de-en", split=f"train[:{TRAIN_SIZE}]")
val_data = load_dataset("wmt14", "de-en", split="validation")
test_data = load_dataset("wmt14", "de-en", split="test")



## Vocabulary

In [None]:
pad_word = "<pad>"
bos_word = "<s>"
eos_word = "</s>"
unk_word = "<unk>"
pad_id = 0
bos_id = 1
eos_id = 2
unk_id = 3
    
def normalize_sentence(s):
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

class Vocabulary:
    def __init__(self):
        self.word_to_id = {pad_word: pad_id, bos_word: bos_id, eos_word: eos_id, unk_word: unk_id}
        self.word_count = {}
        self.id_to_word = {pad_id: pad_word, bos_id: bos_word, eos_id: eos_word, unk_id: unk_word}
        self.num_words = 4
    
    def get_ids_from_sentence(self, sentence):
        sentence = normalize_sentence(sentence)
        sent_ids = [bos_id] + [self.word_to_id[word] if word in self.word_to_id \
                               else unk_id for word in sentence.split()] + \
                               [eos_id]
        return sent_ids
    
    def tokenized_sentence(self, sentence):
        sent_ids = self.get_ids_from_sentence(sentence)
        return [self.id_to_word[word_id] for word_id in sent_ids]

    def decode_sentence_from_ids(self, sent_ids):
        words = list()
        for i, word_id in enumerate(sent_ids):
            if word_id in [bos_id, eos_id, pad_id]:
                # Skip these words
                continue
            else:
                words.append(self.id_to_word[word_id])
        return ' '.join(words)

    def add_words_from_sentence(self, sentence):
        sentence = normalize_sentence(sentence)
        for word in sentence.split():
            if word not in self.word_to_id:
                # add this word to the vocabulary
                self.word_to_id[word] = self.num_words
                self.id_to_word[self.num_words] = word
                self.word_count[word] = 1
                self.num_words += 1
            else:
                # update the word count
                self.word_count[word] += 1
    
    def prune(self, min_count=1):
        special = set([pad_word, bos_word, eos_word, unk_word])
        word_to_id = {pad_word: pad_id, bos_word: bos_id, eos_word: eos_id, unk_word: unk_id}
        id_to_word = {pad_id: pad_word, bos_id: bos_word, eos_id: eos_word, unk_id: unk_word}
        word_count = {}
        num_words = 4

        for word in self.word_count:
            if self.word_count[word] >= min_count and word not in special:
                word_to_id[word] = num_words
                id_to_word[num_words] = word
                word_count[word] = self.word_count[word]
                num_words += 1
        
        self.word_to_id = word_to_id
        self.id_to_word = id_to_word
        self.word_count = word_count
        self.num_words = num_words

In [None]:
en_vocab = Vocabulary()
de_vocab = Vocabulary()
for item in tqdm(train_data):
    en_vocab.add_words_from_sentence(item['translation']['en'])
    de_vocab.add_words_from_sentence(item['translation']['de'])

print(f"Total words in the English vocabulary = {en_vocab.num_words}")
print(f"Total words in the German vocabulary = {de_vocab.num_words}")

en_vocab.prune(min_count=2)
de_vocab.prune(min_count=2)

print()
print(f"Total words in the English vocabulary after pruning = {en_vocab.num_words}")
print(f"Total words in the German vocabulary after pruning = {de_vocab.num_words}")

  0%|          | 0/100000 [00:00<?, ?it/s]

Total words in the English vocabulary = 28809
Total words in the German vocabulary = 67353

Total words in the English vocabulary after pruning = 19760
Total words in the German vocabulary after pruning = 36138


## Dataset

In [None]:
class TranslationDataset(Dataset):
    def __init__(self, dataset: list):
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int):
        en = self.dataset[index]['translation']['en']
        de = self.dataset[index]['translation']['de']
        
        en = torch.tensor(en_vocab.get_ids_from_sentence(en))
        de = torch.tensor(de_vocab.get_ids_from_sentence(de))
        return en, de

def collate_fn(batch):
    en_batch = pad_sequence([item[0] for item in batch], batch_first=True)
    de_batch = pad_sequence([item[1] for item in batch], batch_first=True)
    return en_batch, de_batch

In [None]:
train_set = TranslationDataset(train_data)
val_set = TranslationDataset(val_data)
test_set = TranslationDataset(test_data)

In [None]:
BATCH_SIZE = 16

lengths = [(1.0 / NUM_EPOCHS)] * NUM_EPOCHS
train_subsets = random_split(train_set, lengths)
train_loaders = [DataLoader(subset, batch_size=BATCH_SIZE, collate_fn=collate_fn) for subset in train_subsets]

val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

# UniDeT

In [None]:
class UniDeT(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, hidden_size=256, num_layers=3, dropout=0.0):
        super().__init__()
        
        self.src_embedding = nn.Embedding(src_vocab_size, hidden_size, padding_idx=pad_id)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, hidden_size, padding_idx=pad_id)
        
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True)

        self.l2r_decoder = nn.LSTM(hidden_size, hidden_size, num_layers=2*num_layers, batch_first=True, dropout=dropout)
        self.l2r_classifier = nn.Linear(hidden_size, tgt_vocab_size)


    def forward(self, src, tgt):
        src = self.src_embedding(src)
        tgt = self.tgt_embedding(tgt)
        
        _, (hidden, cell) = self.encoder(src)

        l2r, _ = self.l2r_decoder(tgt, (hidden, cell))
        l2r = self.l2r_classifier(l2r)
        return l2r
    
    def predict(self, src):
        src = self.src_embedding(src)
        
        _, (hidden, cell) = self.encoder(src)

        prev = torch.tensor([bos_id]).to(device)
        l2r = []
        while prev != eos_id:
            input = self.tgt_embedding(prev)
            out, _ = self.l2r_decoder(input, (hidden, cell))
            out = self.l2r_classifier(out)
            l2r.append(out.squeeze(dim=0))
            prev = out.argmax(dim=1)
        l2r = torch.tensor(l2r).to(device)

        output = l2r.argmax(dim=1)
        return output

# BiDeT

In [None]:
class BiDeT(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, hidden_size=256, num_layers=3, dropout=0.0):
        super().__init__()
        
        self.src_embedding = nn.Embedding(src_vocab_size, hidden_size, padding_idx=pad_id)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, hidden_size, padding_idx=pad_id)
        
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout, bidirectional=True)

        self.l2r_decoder = nn.LSTM(hidden_size, hidden_size, num_layers=2*num_layers, batch_first=True, dropout=dropout)
        self.l2r_classifier = nn.Linear(hidden_size, tgt_vocab_size)

        self.r2l_decoder = nn.LSTM(hidden_size, hidden_size, num_layers=2*num_layers, batch_first=True, dropout=dropout)
        self.r2l_classifier = nn.Linear(hidden_size, tgt_vocab_size)


    def forward(self, src, tgt):
        src = self.src_embedding(src)
        tgt = self.tgt_embedding(tgt)
        
        _, (hidden, cell) = self.encoder(src)

        l2r, _ = self.l2r_decoder(tgt, (hidden, cell))
        l2r = self.l2r_classifier(l2r)
        
        r2l, _ = self.r2l_decoder(tgt.flip(dims=[0]), (hidden, cell))
        r2l = self.r2l_classifier(r2l)
        r2l = r2l.flip(dims=[0])
        
        output = l2r + r2l
        return output
    
    def predict(self, src):
        src = self.src_embedding(src)
        
        _, (hidden, cell) = self.encoder(src)

        prev = torch.tensor([bos_id]).to(device)
        l2r = []
        while prev != eos_id and len(l2r) < 100:
            print(prev)
            input = self.tgt_embedding(prev)
            out, _ = self.l2r_decoder(input, (hidden, cell))
            out = self.l2r_classifier(out)
            l2r.append(out.squeeze(dim=0))
            prev = out.argmax(dim=1)
        l2r = torch.tensor(l2r).to(device)

        prev = torch.tensor([eos_id]).to(device)
        r2l = []
        while prev != bos_id and len(r2l) < 100:
            input = self.tgt_embedding(prev)
            out, _ = self.r2l_decoder(input, (hidden, cell))
            out = self.r2l_classifier(out)
            r2l.append(out.squeeze(dim=0))
            prev = out.argmax(dim=1)
        r2l = torch.tensor(r2l).to(device).flip(dims=[0])

        pred = l2r + r2l
        output = pred.argmax(dim=1)
        return output

# Training

In [None]:
model = BiDeT(en_vocab.num_words, de_vocab.num_words).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [None]:
unidet = UniDeT(en_vocab.num_words, de_vocab.num_words).to(device)
count_parameters(unidet)

30961962

In [None]:
def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 43,407,444 trainable parameters


In [None]:
def train(model: nn.Module, iterator: DataLoader, 
          optimizer: torch.optim.Optimizer, criterion):
    model.train()
    epoch_loss = 0.0

    for src, trg in tqdm(iterator):
        optimizer.zero_grad()
        src, trg = src.to(device), trg.to(device)
        output = model(src, trg).transpose(1, 2)

        loss = criterion(output, trg)
        epoch_loss += loss.item()

        loss.backward()
        optimizer.step()
    return epoch_loss / len(iterator)


def evaluate(model: nn.Module, iterator: DataLoader, criterion):
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for src, trg in iterator:
            src, trg = src.to(device), trg.to(device)
            output = model(src, trg)
            
            loss = criterion(output.transpose(1, 2), trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

In [None]:
for epoch in range(NUM_EPOCHS):
    train_loss = train(model, train_loaders[epoch], optimizer, criterion)
    print(f'Train Loss: {train_loss:.3f}')
    valid_loss = evaluate(model, val_loader, criterion)
    print(f'Val. Loss: {valid_loss:.3f}')

test_loss = evaluate(model, test_loader, criterion)

print(f'\nTest Loss: {test_loss:.3f}')

  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 3.542
Val. Loss: 3.560


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 2.953
Val. Loss: 2.674


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 2.221
Val. Loss: 2.241


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.961
Val. Loss: 2.029


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.771
Val. Loss: 1.900


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.676
Val. Loss: 1.810


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.556
Val. Loss: 1.728


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.493
Val. Loss: 1.686


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.438
Val. Loss: 1.613


  0%|          | 0/625 [00:00<?, ?it/s]

Train Loss: 1.369
Val. Loss: 1.543

Test Loss: 1.656
