## Additive attention
Paper: [Neural machine translation by jointly learning to align and translate](https://arxiv.org/pdf/1409.0473) - Bahdanau et. al 2015

Dataset: [Multi30K English to Deutsche dataset](https://huggingface.co/datasets/bentrevett/multi30k)

Model: Use LSTM as encoder and decoder

#### Model variations
- Stacked LSTM encoder decoder
- BiLSTM encoder + LSTM decoder



In [1]:
from datasets import load_dataset

train_dataset = load_dataset("bentrevett/multi30k", split="train")
print(len(train_dataset))
print(train_dataset[0])
print(train_dataset.column_names)

29000
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
['en', 'de']


In [2]:
import re

TOKEN_RE = re.compile(r"\w+|[^\w\s]")
def word_tokenize(text):
    text = text.lower().strip()
    return TOKEN_RE.findall(text)

In [3]:
word_tokenize("Hello, world!")

['hello', ',', 'world', '!']

In [4]:
from collections import Counter

PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"

SPECIAL_TOKENS = [PAD, BOS, EOS, UNK]

def build_vocab(tokenized_texts, max_vocab_size=10000, min_freq=3):
    counter = Counter(token for text in tokenized_texts for token in text)
    vocab = SPECIAL_TOKENS.copy()

    for token, freq in counter.most_common():
        if freq < min_freq:
            break
        if len(vocab) >= max_vocab_size:
            break
        if token not in vocab:
            vocab.append(token)

    vocab_to_index = {token: index for index, token in enumerate(vocab)}
    index_to_vocab = {index: token for token, index in vocab_to_index.items()}

    return vocab_to_index, index_to_vocab

In [5]:
def add_special_tokens(tokens):
    return [BOS] + tokens + [EOS]

In [36]:
def remove_special_tokens(tokens):
    return [token for token in tokens if token not in [PAD, BOS, EOS]]

def encode(token_to_index, text):
    return [token_to_index.get(token, token_to_index[UNK]) for token in text]

def decode(index_to_token, indices):
    return " ".join(remove_special_tokens([index_to_token.get(index, UNK) for index in indices]))


In [7]:
# Load data and build vocab

tokenized_train_dataset = train_dataset.map(lambda x: {"en": word_tokenize(x["en"]), "de": word_tokenize(x["de"])}, batched=False)

en_vocab_to_index, en_index_to_vocab = build_vocab(
    [item["en"] for item in tokenized_train_dataset]
)
de_vocab_to_index, de_index_to_vocab = build_vocab(
    [item["de"] for item in tokenized_train_dataset]
)

In [8]:
print(f"English Vocab size: {len(en_vocab_to_index)}")
print(f"German Vocab size: {len(de_vocab_to_index)}")
print(tokenized_train_dataset[0])

English Vocab size: 4560
German Vocab size: 5422
{'en': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.'], 'de': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.']}


In [9]:
def preprocess(batch, source_lang, target_lang, source_vocab_to_index, target_vocab_to_index):
    source_encodings = [encode(source_vocab_to_index, add_special_tokens(word_tokenize(text))) for text in batch[source_lang]]
    target_encodings = [encode(target_vocab_to_index, add_special_tokens(word_tokenize(text))) for text in batch[target_lang]]

    return {"source": source_encodings, "target": target_encodings}


In [10]:
preprocessed_train_dataset = train_dataset.map(
    lambda x: preprocess(x, "de", "en", de_vocab_to_index, en_vocab_to_index),
    batched=True,
    remove_columns=["en", "de"]
)
preprocessed_train_dataset.set_format(type="torch", columns=["source", "target"])

print(preprocessed_train_dataset[0])
print(f"German: {decode(de_index_to_vocab, preprocessed_train_dataset[0]['source'].tolist())}")
print(f"English: {decode(en_index_to_vocab, preprocessed_train_dataset[0]['target'].tolist())}")

{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
German: zwei junge weiße männer sind im freien in der nähe <unk> büsche .
English: two young , white males are outside near many bushes .


In [11]:
print(f"type: {type(preprocessed_train_dataset)}")
print(f"length: {len(preprocessed_train_dataset)}")
print(preprocessed_train_dataset[0])
print(preprocessed_train_dataset[1])
print(preprocessed_train_dataset[2])
print(preprocessed_train_dataset[3])

type: <class 'datasets.arrow_dataset.Dataset'>
length: 29000
{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
{'source': tensor([   1,   77,   31,   11,  831, 2082,    5,    3,    4,    2]), 'target': tensor([   1,  113,   30,    6,  325,  280,   17, 1180,    4,  712, 3814, 2644,
           5,    2])}
{'source': tensor([  1,   5,  67,  26, 226,   7,   5,   3,  58, 492,   4,   2]), 'target': tensor([   1,    4,   53,   33,  231,   69,    4,  248, 3815,    5,    2])}
{'source': tensor([  1,   5,  13,   7,   6,  47,  41,  30,  12,  14, 546,  10, 684,   5,
        250,   4,   2]), 'target': tensor([  1,   4,   9,   6,   4,  29,  23,  10,  36,   8,   4, 574, 575,   4,
        240,   5,   2])}


In [12]:
import torch

def pad_batch(sequences, pad_idx=0):
    lengths = torch.tensor([len(seq) for seq in sequences])
    max_length = lengths.max().item()
    padded_batch = torch.full((len(sequences), max_length), pad_idx)
    for i, seq in enumerate(sequences):
        end = lengths[i]
        padded_batch[i, :end] = seq
    return padded_batch, lengths


In [13]:
def collate_fn(batch):
    source = [item["source"] for item in batch]
    target = [item["target"] for item in batch]

    source, source_lengths = pad_batch(source) # defer padding till batching
    target, target_lengths = pad_batch(target)

    return {"source": source, "source_lengths": source_lengths, "target": target, "target_lengths": target_lengths}


In [14]:
from torch.utils.data import DataLoader
loader = DataLoader(
    preprocessed_train_dataset,
    batch_size=3,
    shuffle=True,
    collate_fn=collate_fn,
)

batch = next(iter(loader))
print(batch["source"].shape)
print(batch["source_lengths"].shape)
print(batch["target"].shape)
print(batch["target_lengths"].shape)

torch.Size([3, 20])
torch.Size([3])
torch.Size([3, 19])
torch.Size([3])


In [15]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.directions = 2
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True, bidirectional=True)
        self.hidden_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.cell_projection = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, source_encodings, source_lengths): # source_encodings: (batch_size, max_length), source_lengths: (batch_size)
        B, T = source_encodings.size()
        h_0 = torch.zeros(self.num_layers * self.directions, B, self.hidden_dim)
        c_0 = torch.zeros(self.num_layers * self.directions, B, self.hidden_dim)

        embedded = self.embedding(source_encodings) # (B, T, embedding_dim)

        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, source_lengths, batch_first=True, enforce_sorted=False)
        packed_outputs, (last_hidden, last_cell) = self.lstm(packed_embedded, (h_0, c_0)) # hidden: (num_layers * directions, B, hidden_dim)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True) # outputs: (B, T, hidden_dim * directions)

        h_f = last_hidden[0::2] # (num_layers, B, hidden_dim)
        h_b = last_hidden[1::2] # (num_layers, B, hidden_dim)
        h_cat = torch.cat((h_f, h_b), dim=2) # (num_layers, B, hidden_dim * 2) concatenate along the hidden dimension

        c_f = last_cell[0::2] # (num_layers, B, hidden_dim)
        c_b = last_cell[1::2] # (num_layers, B, hidden_dim)
        c_cat = torch.cat((c_f, c_b), dim=2) # (num_layers, B, hidden_dim * 2) concatenate along the hidden dimension

        last_hidden = torch.tanh(self.hidden_projection(h_cat)) # (num_layers, B, hidden_dim)
        last_cell = torch.tanh(self.cell_projection(c_cat)) # (num_layers, B, hidden_dim)

        return outputs, last_hidden, last_cell


In [16]:
encoder = Encoder(5, 20, 32, 2, 0.0)

random_source_encodings = torch.randint(0, 5, (3, 5))
random_source_lengths = torch.tensor([5, 4, 3])

outputs, last_hidden, last_cell = encoder(random_source_encodings, random_source_lengths)

print(f"outputs (B, T, hidden_dim * directions): {outputs.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

print(f"last_hidden: {last_hidden}")


outputs (B, T, hidden_dim * directions): torch.Size([3, 5, 64])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 32])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 32])
last_hidden: tensor([[[ 0.0788, -0.0284,  0.1819, -0.0922,  0.1050, -0.0100, -0.2240,
           0.0224, -0.1046,  0.1614, -0.0043,  0.0925, -0.1726, -0.0251,
          -0.1199, -0.0522,  0.0073, -0.1281, -0.0866, -0.1276,  0.1677,
           0.0379, -0.1042, -0.1321, -0.1767, -0.2932, -0.1137, -0.0041,
           0.0403,  0.1457, -0.0557, -0.0722],
         [ 0.1437, -0.0456,  0.2158, -0.1223,  0.1345, -0.0075, -0.2380,
           0.0491, -0.1186,  0.1785,  0.0361,  0.0780, -0.1970, -0.0717,
          -0.0765, -0.0764,  0.0210, -0.1607, -0.0897, -0.1058,  0.1730,
           0.0217, -0.1530, -0.1606, -0.1826, -0.3062, -0.0887,  0.0848,
           0.0294,  0.1544, -0.0967, -0.0838],
         [-0.1216,  0.0375,  0.1296, -0.0756,  0.0616, -0.0055, -0.1475,
          -0.0135, -

In [17]:
encoder = Encoder(len(de_vocab_to_index), 120, 256, 2, 0.0)

outputs, last_hidden, last_cell = encoder(batch["source"], batch["source_lengths"])

print(f"source lengths: {batch['source_lengths']}")
print(f"outputs (B, T, hidden_dim * directions): {outputs.shape}")
print(f"last_hidden (num_layers * directions, B, hidden_dim): {last_hidden.shape}")
print(f"last_cell (num_layers * directions, B, hidden_dim): {last_cell.shape}")

print(outputs[0][13]) # hidden state for padded token after re padding

source lengths: tensor([14,  9, 20])
outputs (B, T, hidden_dim * directions): torch.Size([3, 20, 512])
last_hidden (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 256])
last_cell (num_layers * directions, B, hidden_dim): torch.Size([2, 3, 256])
tensor([ 4.3524e-02,  1.5765e-02, -2.1352e-02, -1.8995e-02,  2.7187e-02,
         6.0332e-02, -3.6823e-02,  5.0755e-02,  1.5351e-02, -2.9590e-02,
        -1.3017e-02, -2.6168e-02, -9.1572e-03,  3.5017e-02,  4.3192e-02,
         3.2589e-02,  3.3422e-02, -2.4076e-02, -1.2255e-02,  2.2332e-02,
        -3.8486e-02, -2.3635e-02, -8.1174e-02, -8.9897e-02,  3.2353e-02,
         3.5341e-02,  9.8116e-02, -1.1213e-02,  3.3061e-02, -2.4547e-02,
        -2.5842e-02, -5.9218e-02,  4.0249e-03,  1.5704e-02,  1.2798e-02,
         1.4629e-02, -1.2037e-02, -3.0187e-02, -4.5610e-02,  5.2475e-02,
         3.7832e-02, -4.4547e-02, -1.7846e-02,  3.8613e-02,  3.1487e-02,
         2.4184e-02,  1.9012e-02, -4.7324e-02,  4.3101e-02, -8.8056e-03,
        -4.59

In [18]:
class Decoder(nn.Module):
    """
    Decoder for the Seq2Seq model. This works on a batch of single step targets (B, 1).
    It doesn't take the entire target sequence, because the decision to teacher force and what kind of search strategy to use
    is done at the sequence level, not the step level.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=False,
        )
        self.encoder_hidden_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.encoder_cell_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, target, encoder_last_hidden, encoder_last_cell):
        # Normalize target shape to (B, 1)
        if target.dim() == 0:
            target = target.view(1, 1)
        else:
            target = target.unsqueeze(1)

        target = target.long()

        embedded = self.embedding(target)  # (B, T, embedding_dim)
        # use dropout on the embedded input later
        outputs, (hidden, cell) = self.lstm(embedded, (encoder_last_hidden, encoder_last_cell)) # (B, 1, hidden_dim)
        logits = self.output_projection(outputs) # (B, 1, vocab_size)

        return logits, hidden, cell

In [19]:
decoder = Decoder(len(en_vocab_to_index), 120, 256, 2, 0.0)

# One decoding step for the whole batch (B,)
step0 = batch["target"][:, 0]
print(step0.shape, step0)

logits, hidden, cell = decoder(step0, last_hidden, last_cell)
print(f"logits (B, 1, vocab_size): {logits.shape}")
print(f"hidden (num_layers, B, hidden_dim): {hidden.shape}")
print(f"cell (num_layers, B, hidden_dim): {cell.shape}")
print(logits)


torch.Size([3]) tensor([1, 1, 1])
logits (B, 1, vocab_size): torch.Size([3, 1, 4560])
hidden (num_layers, B, hidden_dim): torch.Size([2, 3, 256])
cell (num_layers, B, hidden_dim): torch.Size([2, 3, 256])
tensor([[[ 0.0064,  0.0087,  0.0349,  ..., -0.0150,  0.0367, -0.0055]],

        [[ 0.0049,  0.0064,  0.0333,  ..., -0.0152,  0.0376, -0.0006]],

        [[ 0.0094,  0.0095,  0.0435,  ..., -0.0171,  0.0397, -0.0075]]],
       grad_fn=<ViewBackward0>)


In [81]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, bos_idx, eos_idx, max_target_length):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.max_target_length = max_target_length
    def forward(self, source, source_lengths, target=None):
        logits_all = []
        preds_all = []
        B = source.shape[0]


        encoder_outputs, hidden, cell = self.encoder(source, source_lengths)

        if target is not None:
            steps = target.shape[1] - 1
            inputs = target[:, 0]
        else:
            steps = self.max_target_length
            inputs = torch.full((B,), self.bos_idx, dtype=torch.long)


        for t in range(steps):
            logits, hidden, cell = self.decoder(inputs, hidden, cell)
            step_logits = logits.squeeze(1) # (B, 1, vocab_size) -> (B, vocab_size)
            step_preds = step_logits.argmax(dim=1)

            logits_all.append(step_logits)
            preds_all.append(step_preds)

            # Teacher forcing
            if target is not None:
                inputs = target[:, t + 1]
            else:
                inputs = step_preds
                if step_preds.eq(self.eos_idx).all():
                    break


        logits_all = torch.stack(logits_all, dim=1) # (B, steps, vocab_size)
        preds_all = torch.stack(preds_all, dim=1) # (B, steps)

        return logits_all, preds_all


In [42]:
criterion = nn.CrossEntropyLoss(ignore_index=de_vocab_to_index[PAD])

In [82]:
model = Seq2Seq(encoder, decoder, de_vocab_to_index[BOS], de_vocab_to_index[EOS], 30)

logits_all, preds_all = model(batch["source"], batch["source_lengths"], batch["target"])
target = batch["target"][:, 1:].reshape(-1)
loss = criterion(logits_all.reshape(-1, logits_all.shape[-1]), target)


print(f"logits_all.shape: {logits_all.shape}")
print(f"preds_all.shape: {preds_all.shape}")
print(f"batch['target'].shape: {batch['target'].shape}")
print(f"loss: {loss}\n")

for i in range(len(batch.items())-1):
    print(f"Source: {decode(de_index_to_vocab, batch['source'][i].tolist())}")
    print(f"Target: {decode(en_index_to_vocab, batch['target'][i].tolist())}")
    print(f"Pred: {decode(en_index_to_vocab, preds_all[i].tolist())}")
    print("-"*100)

logits_all.shape: torch.Size([3, 18, 4560])
preds_all.shape: torch.Size([3, 18])
batch['target'].shape: torch.Size([3, 19])
loss: 8.42668342590332

Source: ein brauner hund läuft mit etwas in seinem mund durch wasser .
Target: a brown dog <unk> the water with something in its mouth .
Pred: reddish balanced he baker puck puck horn slowly garden garden puck feet feet passing passing skater skater skater
----------------------------------------------------------------------------------------------------
Source: zwei mädchen zeigen sich gegenseitig etwas .
Target: two girls showing something to each other .
Pred: floaties mic puck trots bra 6 puck puck puck squinting skater skater skater skater puck puck puck puck
----------------------------------------------------------------------------------------------------
Source: eine frau kauft artikel von einem mann in einem roten hemd , der auf seinem hof verkauft .
Target: a woman is purchasing items from the man in the red shirt ' s yard sale 

In [83]:
tgt_vocab_size = len(en_vocab_to_index) # 4560
src_vocab_size = len(de_vocab_to_index) # 5422
emb_dim = 512
hidden_dim = 512
enc_num_layers = 2
dec_num_layers = 2
enc_dropout = 0.0
dec_dropout = 0.0
padding_idx = de_vocab_to_index[PAD]
bos_idx = de_vocab_to_index[BOS]
eos_idx = de_vocab_to_index[EOS]
max_target_length = 30

batch_size = 64

In [85]:
encoder = Encoder(src_vocab_size, emb_dim, hidden_dim, enc_num_layers, enc_dropout, padding_idx)
decoder = Decoder(tgt_vocab_size, emb_dim, hidden_dim, dec_num_layers, dec_dropout, padding_idx)

model = Seq2Seq(encoder, decoder, bos_idx, eos_idx, max_target_length)

### Model parameter calculation

#### Encoder
- source embedding: vocab * emb_dim + bias
- 2 layers BiLSTM:  
    - direction * gates * ( W_x + W_h + b_x + b_h)
    - direction * gates * ( W_h + W_h + b_h + b_h)
- hidden projection: 2 * (2 * hidden_dim * hidden_dim + b_h)

#### Decoder
- target embedding: vocab * emb_dim + bias
- 2 layers LSTM: 
    - gates * ( W_x + W_h + b_x + b_h)
    - gates * ( W_x + W_h + b_x + b_h)
- output projection: hidden_dim * target_vocab + b_v

In [86]:
p_src_emd = src_vocab_size * emb_dim + src_vocab_size
p_enc_lstm_1 = 2 * 4 * (emb_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_enc_lstm_2 = 2 * 4 * (hidden_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_enc_hidden_proj = 2 * (2 * hidden_dim * hidden_dim + hidden_dim)
p_enc_cell_proj = 2 * (2 * hidden_dim * hidden_dim + hidden_dim)

p_encoder = p_src_emd + p_enc_lstm_1 + p_enc_lstm_2 + p_enc_hidden_proj + p_enc_cell_proj

print(f"Encoder parameters: {p_encoder:,}")

p_tgt_emd = tgt_vocab_size * emb_dim + tgt_vocab_size
p_dec_lstm_1 = 4 * (emb_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_dec_lstm_2 = 4 * (hidden_dim * hidden_dim + hidden_dim * hidden_dim + (4*hidden_dim) + (4*hidden_dim))
p_dec_output_proj = hidden_dim * tgt_vocab_size + tgt_vocab_size

p_decoder = p_tgt_emd + p_dec_lstm_1 + p_dec_lstm_2 + p_dec_output_proj

print(f"Decoder parameters: {p_decoder:,}")

p_seq2seq = p_encoder + p_decoder

print(f"Seq2Seq parameters: {p_seq2seq:,}")



Encoder parameters: 13,334,830
Decoder parameters: 8,905,632
Seq2Seq parameters: 22,240,462


In [87]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The Seq2Seq model has {count_parameters(model):,} trainable parameters')
print(f"Encoder parameters: {count_parameters(encoder):,}")
print(f"Decoder parameters: {count_parameters(decoder):,}")


The Seq2Seq model has 24,253,904 trainable parameters
Encoder parameters: 14,327,808
Decoder parameters: 9,926,096


### Rough estimate of memory needed
- long or float32 = 4 bytes
- memory for model for Adam optimizer = 4 * mode parameters * 4 bytes
- activations of LSTM = 6 * hidden_dim * num_layers * batch_size * seq_len * 4 bytes
- activations embedding = batch_size * seq_len * embed_dim * 2
- hidden projection activations = batch_size * seq_len * hidden_dim * 2
- output project activation = batch_size * seq_len * tgt_vocab

total = sum of above + 30%

In [88]:
model_memory = 4 * 4 * p_seq2seq
activations_memory = 6 * hidden_dim * (enc_num_layers + dec_num_layers) * batch_size * max_target_length * 4
embedding_memory = batch_size * max_target_length * emb_dim * 2
hidden_projection_memory = batch_size * max_target_length * hidden_dim * 2
output_projection_memory = batch_size * max_target_length * tgt_vocab_size

total_activations_memory = activations_memory + embedding_memory + hidden_projection_memory + output_projection_memory
total_memory = model_memory + total_activations_memory


print(f"Total memory: {total_memory:,} bytes")
print(f"Model memory: {model_memory:,} bytes")
print(f"Activations memory: {activations_memory + embedding_memory + hidden_projection_memory + output_projection_memory:,} bytes")
print(f"Total required memory: {total_memory * 1.3:,} bytes")

Total memory: 462,906,592 bytes
Model memory: 355,847,392 bytes
Activations memory: 107,059,200 bytes
Total required memory: 601,778,569.6 bytes
