## Additive attention
Paper: [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/pdf/1409.0473) — Bahdanau et al., 2015

Dataset: [Multi30k English–German dataset](https://huggingface.co/datasets/bentrevett/multi30k). This notebook translates German (source) into English (target).

Model: an LSTM encoder and an LSTM decoder.

#### Model variations
- Stacked (multi-layer) LSTM encoder–decoder
- BiLSTM encoder + LSTM decoder



In [2]:
from datasets import load_dataset

train_dataset = load_dataset("bentrevett/multi30k", split="train")

# Quick look: number of sentence pairs, one example pair, and the language columns.
for overview in (len(train_dataset), train_dataset[0], train_dataset.column_names):
    print(overview)

29000
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
['en', 'de']


In [3]:
import re

TOKEN_RE = re.compile(r"\w+|[^\w\s]")


def word_tokenize(text):
    """Lower-case ``text`` and split it into word and punctuation tokens.

    A token is either a maximal run of word characters or a single character
    that is neither a word character nor whitespace (so "don't" becomes
    ["don", "'", "t"]).

    Args:
        text: raw sentence string.

    Returns:
        List of token strings, in order of appearance.
    """
    normalized = text.lower()
    normalized = normalized.strip()
    return [match.group() for match in TOKEN_RE.finditer(normalized)]

In [4]:
sample_sentence = "Hello, world!"
word_tokenize(sample_sentence)

['hello', ',', 'world', '!']

In [5]:
from collections import Counter

PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"

# Reserved tokens; list position fixes their index (PAD=0, BOS=1, EOS=2, UNK=3).
SPECIAL_TOKENS = [PAD, BOS, EOS, UNK]

def build_vocab(tokenized_texts, max_vocab_size=10000, min_freq=3):
    """Build token<->index lookup tables from a tokenized corpus.

    Special tokens occupy the first indices, in SPECIAL_TOKENS order. Corpus
    tokens follow in descending frequency; ties keep first-seen order, per
    ``Counter.most_common``. Collection stops at the first token rarer than
    ``min_freq``, or once the vocabulary holds ``max_vocab_size`` entries
    (special tokens included).

    Args:
        tokenized_texts: iterable of token lists.
        max_vocab_size: upper bound on the total vocabulary size.
        min_freq: minimum corpus count for a token to be included.

    Returns:
        Tuple ``(vocab_to_index, index_to_vocab)`` of mutually inverse dicts.
    """
    counter = Counter(token for text in tokenized_texts for token in text)

    # Grow the dict directly so membership checks are O(1). A list scan
    # (`token not in vocab`) would make the build quadratic in vocabulary size.
    vocab_to_index = {token: index for index, token in enumerate(SPECIAL_TOKENS)}
    for token, freq in counter.most_common():
        if freq < min_freq or len(vocab_to_index) >= max_vocab_size:
            break
        if token not in vocab_to_index:  # skip corpus tokens spelled like a special token
            vocab_to_index[token] = len(vocab_to_index)

    index_to_vocab = {index: token for token, index in vocab_to_index.items()}
    return vocab_to_index, index_to_vocab

In [6]:
def add_special_tokens(tokens):
    """Return a new list with BOS prepended and EOS appended to ``tokens``.

    Args:
        tokens: list of token strings; it is not modified.
    """
    head, tail = [BOS], [EOS]
    return head + tokens + tail

In [7]:
def remove_special_tokens(tokens):
    """Drop PAD, BOS and EOS tokens, keeping UNK so unknown words stay visible.

    Args:
        tokens: iterable of token strings.

    Returns:
        New list of the remaining tokens, in their original order.
    """
    structural = {PAD, BOS, EOS}  # built once per call, not once per token
    return [token for token in tokens if token not in structural]

def encode(token_to_index, text):
    """Map a sequence of tokens to vocabulary indices.

    Args:
        token_to_index: token -> index mapping; it must contain UNK.
        text: iterable of already-tokenized strings.

    Returns:
        List of int indices; out-of-vocabulary tokens map to UNK's index.
    """
    unk_index = token_to_index[UNK]  # hoisted: dict.get evaluates its default eagerly
    return [token_to_index.get(token, unk_index) for token in text]

def decode(index_to_token, indices):
    """Turn indices back into a space-joined sentence without PAD/BOS/EOS.

    Args:
        index_to_token: index -> token mapping.
        indices: iterable of ints, or a tensor/array exposing ``tolist()``.

    Returns:
        Space-joined token string; unknown indices render as UNK.
    """
    # Tensor elements hash by identity and never match int dict keys, which
    # would silently decode everything to UNK. Convert to Python ints first.
    if hasattr(indices, "tolist"):
        indices = indices.tolist()
    tokens = [index_to_token.get(index, UNK) for index in indices]
    return " ".join(remove_special_tokens(tokens))


In [8]:
# Tokenize both languages once, then build a separate vocabulary per language
# from the training split only.
tokenized_train_dataset = train_dataset.map(
    lambda example: {lang: word_tokenize(example[lang]) for lang in ("en", "de")},
    batched=False,
)

en_vocab_to_index, en_index_to_vocab = build_vocab([example["en"] for example in tokenized_train_dataset])
de_vocab_to_index, de_index_to_vocab = build_vocab([example["de"] for example in tokenized_train_dataset])

In [9]:
# Sanity check: vocabulary sizes (special tokens included) and one tokenized pair.
print(f"English Vocab size: {len(en_vocab_to_index)}")
print(f"German Vocab size: {len(de_vocab_to_index)}")
print(tokenized_train_dataset[0])

English Vocab size: 4560
German Vocab size: 5422
{'en': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.'], 'de': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.']}


In [10]:
def preprocess(batch, source_lang, target_lang, source_vocab_to_index, target_vocab_to_index):
    """Tokenize, wrap in BOS/EOS, and index a batched slice of parallel text.

    Args:
        batch: mapping from language code to a list of raw sentences (the
            batched format ``Dataset.map(batched=True)`` passes in).
        source_lang: key of the source-language column, e.g. "de".
        target_lang: key of the target-language column, e.g. "en".
        source_vocab_to_index: token -> index mapping for the source side.
        target_vocab_to_index: token -> index mapping for the target side.

    Returns:
        dict with "source" and "target" lists of index lists.
    """
    def to_ids(sentences, vocab):
        return [encode(vocab, add_special_tokens(word_tokenize(sentence))) for sentence in sentences]

    return {
        "source": to_ids(batch[source_lang], source_vocab_to_index),
        "target": to_ids(batch[target_lang], target_vocab_to_index),
    }


In [None]:
# German is the source side and English the target; raw text columns are dropped.
preprocessed_train_dataset = train_dataset.map(
    lambda batch: preprocess(
        batch,
        source_lang="de",
        target_lang="en",
        source_vocab_to_index=de_vocab_to_index,
        target_vocab_to_index=en_vocab_to_index,
    ),
    batched=True,
    remove_columns=["en", "de"],
)
preprocessed_train_dataset.set_format("torch", columns=["source", "target"])

print(preprocessed_train_dataset[0])
print(f"German: {decode(de_index_to_vocab, preprocessed_train_dataset[0]['source'].tolist())}")
print(f"English: {decode(en_index_to_vocab, preprocessed_train_dataset[0]['target'].tolist())}")

{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
English: two young , white males are outside near many bushes .
German: zwei junge weiße männer sind im freien in der nähe <unk> büsche .


In [12]:
print(f"type: {type(preprocessed_train_dataset)}")
print(f"length: {len(preprocessed_train_dataset)}")
# First four encoded pairs.
for row_index in range(4):
    print(preprocessed_train_dataset[row_index])

type: <class 'datasets.arrow_dataset.Dataset'>
length: 29000
{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
{'source': tensor([   1,   77,   31,   11,  831, 2082,    5,    3,    4,    2]), 'target': tensor([   1,  113,   30,    6,  325,  280,   17, 1180,    4,  712, 3814, 2644,
           5,    2])}
{'source': tensor([  1,   5,  67,  26, 226,   7,   5,   3,  58, 492,   4,   2]), 'target': tensor([   1,    4,   53,   33,  231,   69,    4,  248, 3815,    5,    2])}
{'source': tensor([  1,   5,  13,   7,   6,  47,  41,  30,  12,  14, 546,  10, 684,   5,
        250,   4,   2]), 'target': tensor([  1,   4,   9,   6,   4,  29,  23,  10,  36,   8,   4, 574, 575,   4,
        240,   5,   2])}


In [13]:
import torch

def pad_batch(sequences, pad_idx=0):
    lengths = torch.tensor([len(seq) for seq in sequences])
    max_length = lengths.max().item()
    padded_batch = torch.full((len(sequences), max_length), pad_idx)
    for i, seq in enumerate(sequences):
        end = lengths[i]
        padded_batch[i, :end] = seq
    return padded_batch, lengths


In [14]:
def collate_fn(batch):
    """Collate dataset rows into padded source/target tensors plus lengths.

    Padding is deferred to batch time, so each batch is padded only to its own
    longest sequence.

    Args:
        batch: list of rows, each a dict with "source" and "target" index sequences.

    Returns:
        dict with keys "source", "source_lengths", "target", "target_lengths".
    """
    source_padded, source_lengths = pad_batch([example["source"] for example in batch])
    target_padded, target_lengths = pad_batch([example["target"] for example in batch])
    return {
        "source": source_padded,
        "source_lengths": source_lengths,
        "target": target_padded,
        "target_lengths": target_lengths,
    }


In [None]:
from torch.utils.data import DataLoader
# Small, seeded loader for shape checks. Seeding the shuffle makes `batch`
# (reused by later cells) identical on every Restart & Run All.
loader = DataLoader(
    preprocessed_train_dataset,
    batch_size=3,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(0),
)

batch = next(iter(loader))
print(batch["source"].shape)
print(batch["source_lengths"].shape)
print(batch["target"].shape)
print(batch["target_lengths"].shape)

torch.Size([3, 14])
torch.Size([3])
torch.Size([3, 18])
torch.Size([3])


In [27]:
import torch.nn as nn

class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder for right-padded, variable-length batches.

    Args:
        vocab_size: source vocabulary size.
        embedding_dim: token embedding size.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between stacked LSTM layers (passed to nn.LSTM).
        padding_idx: padding index passed to nn.Embedding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.directions = 2
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True, bidirectional=True)

    def forward(self, source_encodings, source_lengths):
        """Encode a padded batch of source index sequences.

        Args:
            source_encodings: (B, T) source indices, padded after each sequence's end.
            source_lengths: (B,) true lengths, as a tensor on any device or a list of ints.

        Returns:
            all_hiddens: (B, T, 2 * hidden_dim) top-layer outputs of both directions;
                steps at or beyond a sequence's length are zero.
            last_hidden: (num_layers * 2, B, hidden_dim) final hidden state per layer/direction.
            last_cell: (num_layers * 2, B, hidden_dim) final cell state per layer/direction.
        """
        T = source_encodings.size(1)
        embedded = self.embedding(source_encodings)  # (B, T, embedding_dim)

        # pack_padded_sequence requires lengths as an int64 CPU tensor, even
        # when the model and the batch live on GPU.
        lengths_cpu = torch.as_tensor(source_lengths, dtype=torch.int64).cpu()
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths_cpu, batch_first=True, enforce_sorted=False)

        # No explicit (h_0, c_0) is passed, so nn.LSTM starts from zeros on the
        # input's own device/dtype. CPU-only torch.zeros states would fail on GPU.
        packed_outputs, (last_hidden, last_cell) = self.lstm(packed_embedded)

        # total_length=T keeps the time axis aligned with source_encodings, even if
        # the batch carries more padding than its longest sequence needs.
        all_hiddens, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True, total_length=T)

        return all_hiddens, last_hidden, last_cell


In [28]:
# Smoke test on a tiny random batch. Seeding makes the weights, the random
# inputs, and therefore the printed values reproducible.
torch.manual_seed(0)
encoder = Encoder(5, 20, 32, 2, 0.0)

random_source_encodings = torch.randint(0, 5, (3, 5))
random_source_lengths = torch.tensor([5, 4, 3])

all_hiddens, last_hidden, last_cell = encoder(random_source_encodings, random_source_lengths)

# Expected: B=3, T=5, hidden_dim=32, 2 layers x 2 directions.
assert all_hiddens.shape == (3, 5, 64)
assert last_hidden.shape == last_cell.shape == (4, 3, 32)

print(f"all_hiddens (B, T, hidden_dim * directions): {all_hiddens.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

print(f"last_hidden: {last_hidden}")


all_hiddens (B, T, hidden_dim * directions): torch.Size([3, 5, 64])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([4, 3, 32])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([4, 3, 32])
last_hidden: tensor([[[ 3.6276e-02, -1.6420e-01, -1.2759e-01,  1.0308e-01,  1.8728e-02,
           5.7011e-02, -2.2099e-01, -3.6541e-02,  3.0304e-02,  7.3551e-02,
           4.2598e-02, -4.3661e-02, -5.9864e-02, -4.8414e-03, -1.6167e-01,
           3.8792e-02,  1.4658e-01, -2.2599e-02,  6.9806e-02,  1.1741e-01,
           6.7002e-02,  5.4525e-02, -6.4591e-02,  1.6797e-02, -1.6648e-01,
          -1.1272e-01, -2.4655e-02, -9.3729e-02,  4.4855e-02,  7.0224e-04,
           1.2342e-02,  1.7151e-02],
         [-2.3649e-02, -1.0516e-01,  9.1599e-02,  1.5202e-01, -1.0838e-01,
           7.9911e-02, -1.0120e-01,  6.1571e-02, -1.6074e-03,  3.2993e-01,
          -1.2820e-01,  3.3278e-02,  9.7030e-02, -9.8207e-02, -1.5009e-01,
          -1.0774e-01,  7.8268e-02,  1.2507e-01,  8.859

In [29]:
encoder = Encoder(len(de_vocab_to_index), 120, 256, 2, 0.0)

all_hiddens, last_hidden, last_cell = encoder(batch["source"], batch["source_lengths"])

print(f"source lengths: {batch['source_lengths']}")
print(f"all_hiddens (B, T, hidden_dim * directions): {all_hiddens.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

# Inspect a padded step of the shortest sequence rather than a hard-coded
# [0][13]. With a shuffled loader, row 0 is not guaranteed to be padded at
# step 13, and the batch may be shorter than 14 steps.
batch_source_lengths = batch["source_lengths"]
shortest_row = int(batch_source_lengths.argmin())
first_pad_step = int(batch_source_lengths[shortest_row])
if first_pad_step < all_hiddens.size(1):
    print(all_hiddens[shortest_row][first_pad_step])  # hidden state for padded token after re padding
    # pad_packed_sequence fills every step past a sequence's length with zeros.
    assert torch.all(all_hiddens[shortest_row, first_pad_step:] == 0)
else:
    print("batch has no padded source positions")

source lengths: tensor([11, 14, 13])
all_hiddens (B, T, hidden_dim * directions): torch.Size([3, 14, 512])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([4, 3, 256])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([4, 3, 256])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.,

In [None]:
import torch.nn.functional as F

class Decoder(nn.Module):
    """Unidirectional LSTM decoder initialised from a bidirectional encoder's final states.

    For each layer, the encoder's final forward and backward states are
    concatenated and projected (Linear + tanh) down to ``hidden_dim``. These form
    the decoder's initial hidden and cell states.

    NOTE(review): ``encoder_outputs`` and ``target_lengths`` are accepted but not
    used yet. The additive (Bahdanau) attention this notebook targets still needs
    to be wired in here.

    Args:
        vocab_size: target vocabulary size (number of output classes).
        embedding_dim: target token embedding size.
        hidden_dim: LSTM hidden size. It must equal the encoder's per-direction
            hidden size, because the projections take ``2 * hidden_dim`` inputs.
        num_layers: number of LSTM layers. It must equal the encoder's
            ``num_layers``, because each encoder layer's states seed one decoder layer.
        dropout: dropout between stacked LSTM layers (passed to nn.LSTM).
        padding_idx: padding index passed to nn.Embedding.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True, bidirectional=False)
        self.encoder_hidden_projection = nn.Linear(2*hidden_dim, hidden_dim)
        self.encoder_cell_projection = nn.Linear(2*hidden_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, target_encodings, target_lengths, encoder_outputs, encoder_last_hidden, encoder_last_cell):
        """Run the decoder over a whole target batch, feeding ``target_encodings`` as inputs.

        Args:
            target_encodings: (B, T) target indices used as decoder inputs.
            target_lengths: (B,) true target lengths. Currently unused.
            encoder_outputs: presumably the encoder's (B, S, 2 * hidden_dim)
                per-step outputs. Currently unused (TODO: attention).
            encoder_last_hidden: (num_layers * 2, B, hidden_dim) final encoder hidden states.
            encoder_last_cell: (num_layers * 2, B, hidden_dim) final encoder cell states.

        Returns:
            (B, T, vocab_size) log-probabilities over the target vocabulary.
        """
        h_0, c_0 = self._get_initial_hidden_state(encoder_last_hidden, encoder_last_cell)

        embedded = self.embedding(target_encodings)  # (B, T, embedding_dim)
        outputs, _ = self.lstm(embedded, (h_0, c_0))  # (B, T, hidden_dim)
        logits = self.output_projection(outputs)  # (B, T, vocab_size)
        return F.log_softmax(logits, dim=2)  # (B, T, vocab_size)

    def _get_initial_hidden_state(self, encoder_hidden, encoder_cell=None):
        """Project final bidirectional encoder states to the decoder's initial (h_0, c_0).

        Args:
            encoder_hidden: (num_layers * 2, B, hidden_dim) final encoder hidden states.
            encoder_cell: matching final encoder cell states. If omitted, the hidden
                states feed both projections (the former one-argument behaviour).

        Returns:
            Tuple ``(h_0, c_0)``, each of shape (num_layers, B, hidden_dim).
        """
        if encoder_cell is None:
            encoder_cell = encoder_hidden
        h_0 = torch.tanh(self.encoder_hidden_projection(self._merge_directions(encoder_hidden)))
        c_0 = torch.tanh(self.encoder_cell_projection(self._merge_directions(encoder_cell)))
        return h_0, c_0

    @staticmethod
    def _merge_directions(states):
        """Concatenate each layer's forward/backward states: (L*2, B, H) -> (L, B, 2H)."""
        # nn.LSTM orders bidirectional final states as [layer0_fwd, layer0_bwd, layer1_fwd, ...].
        return torch.cat((states[0::2], states[1::2]), dim=2)