## Additive attention
Paper: [Neural machine translation by jointly learning to align and translate](https://arxiv.org/pdf/1409.0473) - Bahdanau et al., 2015

Dataset: [Multi30K English–German dataset](https://huggingface.co/datasets/bentrevett/multi30k) (this notebook translates German → English)

Model: LSTMs as both the encoder and the decoder

#### Model variation
- BiLSTM encoder + LSTM decoder

This additive attention implementation reuses many of the components built for the seq2seq training.

In [3]:
# Detect Colab by probing for its runtime package; the import fails everywhere else.
is_running_on_colab = False
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
except ImportError:
    pass
else:
    is_running_on_colab = True

In [2]:
# colab specific
%pip install early-stopping-pytorch

/Users/niranjan/Documents/programming/AI/deeplearning/projects/deeplearning-practice/.venv/bin/python: No module named pip
Note: you may need to restart the kernel to use updated packages.


In [4]:
import random
import numpy as np
import torch

SEED = 1557

# Seed Python's, NumPy's and PyTorch's global RNGs so runs are reproducible.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x10f8b3370>

In [5]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from utils.multi30k_data_processing_utils import word_tokenize, build_vocab, preprocess, decode, collate_fn
from utils.multi30k_data_processing_utils import BOS

train_dataset = load_dataset("bentrevett/multi30k", split="train")


def _tokenize_pair(example):
    """Word-tokenize the English and German sides of one Multi30K example."""
    return {lang: word_tokenize(example[lang]) for lang in ("en", "de")}


tokenized_train_dataset = train_dataset.map(_tokenize_pair, batched=False)

# One vocabulary per language, built from the tokenized training split.
en_vocab_to_index, en_index_to_vocab = build_vocab([row["en"] for row in tokenized_train_dataset])
de_vocab_to_index, de_index_to_vocab = build_vocab([row["de"] for row in tokenized_train_dataset])

# German is the source language and English the target.
preprocessed_train_dataset = train_dataset.map(
    lambda examples: preprocess(examples, "de", "en", de_vocab_to_index, en_vocab_to_index),
    batched=True,
    remove_columns=["en", "de"],
)
preprocessed_train_dataset.set_format(type="torch", columns=["source", "target"])

# Sanity check: show one encoded pair and decode it back to text.
first_example = preprocessed_train_dataset[0]
print(first_example)
print(f"German: {decode(de_index_to_vocab, first_example['source'].tolist())}")
print(f"English: {decode(en_index_to_vocab, first_example['target'].tolist())}")

loader = DataLoader(preprocessed_train_dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)

# Peek at one collated batch to check the padded tensor shapes.
batch = next(iter(loader))
for key_name in ("source", "source_lengths", "target", "target_lengths"):
    print(batch[key_name].shape)

{'source': tensor([   1,   18,   27,  215,   31,   85,   20,   89,    7,   15,  115,    3,
        3149,    4,    2]), 'target': tensor([   1,   16,   24,   15,   25,  776,   17,   57,   80,  204, 1305,    5,
           2])}
German: zwei junge weiße männer sind im freien in der nähe <unk> büsche .
English: two young , white males are outside near many bushes .
torch.Size([3, 15])
torch.Size([3])
torch.Size([3, 16])
torch.Size([3])


In [6]:
import torch.nn as nn

class Encoder(nn.Module):
    """Bidirectional LSTM encoder.

    Embeds a padded batch of source token ids and runs a (possibly multi-layer)
    bidirectional LSTM over the non-padded part of each sequence. It then projects
    the concatenated forward/backward final states of every layer down to
    ``hidden_dim`` so they can initialise a unidirectional decoder.

    Args:
        vocab_size: Size of the source vocabulary.
        embedding_dim: Token embedding size.
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout applied to the embeddings, and between LSTM layers when
            ``num_layers > 1``.
        device: Stored as ``self.device`` for backward compatibility. Initial LSTM
            states now follow the input batch's device, so moving the module with
            ``.to()`` keeps working.
        padding_idx: Token id used for padding (its embedding is kept at zero).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, device, padding_idx=0):
        super().__init__()
        self.device = device
        self.directions = 2
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True,
        )
        # Map the concatenated [forward; backward] final states to the decoder's size.
        self.hidden_projection = nn.Linear(2 * hidden_dim, hidden_dim)
        self.cell_projection = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, source_encodings, source_lengths):
        """Encode a padded batch of source sequences.

        Args:
            source_encodings: (B, T) padded token ids.
            source_lengths: (B,) number of real (non-padding) tokens per row.

        Returns:
            outputs: (B, T, 2 * hidden_dim) per-token encoder states, zero at padded
                positions. The time axis always matches ``source_encodings``.
            last_hidden: (num_layers, B, hidden_dim) projected final hidden states.
            last_cell: (num_layers, B, hidden_dim) projected final cell states.
        """
        T = source_encodings.size(1)

        embedded = self.dropout(self.embedding(source_encodings))  # (B, T, embedding_dim)

        # pack_padded_sequence requires lengths on the CPU; enforce_sorted=False lets
        # it sort the batch internally and restore the original order afterwards.
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, source_lengths.to("cpu"), batch_first=True, enforce_sorted=False
        )
        # No explicit (h_0, c_0). nn.LSTM defaults to zeros on the input's device and
        # dtype, which avoids a stale device captured at construction time.
        packed_outputs, (last_hidden, last_cell) = self.lstm(packed_embedded)  # (num_layers * 2, B, hidden_dim)
        # total_length=T keeps the time axis aligned with source_encodings (and with
        # any padding mask built from it), even if every row is shorter than T.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True, total_length=T)

        # h_n is ordered layer-major, direction-minor: even indices are forward
        # states and odd indices are backward states of each layer.
        h_cat = torch.cat((last_hidden[0::2], last_hidden[1::2]), dim=2)  # (num_layers, B, 2 * hidden_dim)
        c_cat = torch.cat((last_cell[0::2], last_cell[1::2]), dim=2)  # (num_layers, B, 2 * hidden_dim)

        last_hidden = torch.tanh(self.hidden_projection(h_cat))  # (num_layers, B, hidden_dim)
        last_cell = torch.tanh(self.cell_projection(c_cat))  # (num_layers, B, hidden_dim)

        return outputs, last_hidden, last_cell


In [7]:
from torch.nn import functional as F

class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over bidirectional encoder outputs.

    Scores each source position j as ``v^T tanh(W_q s + W_k h_j)`` and normalises
    the scores with a softmax over the source positions. Padding positions can be
    excluded by passing ``mask``.

    Args:
        hidden_dim: Decoder hidden size. Encoder outputs are expected to be
            ``2 * hidden_dim`` wide (forward and backward states concatenated).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.query_projection = nn.Linear(hidden_dim, hidden_dim)
        self.key_projection = nn.Linear(2 * hidden_dim, hidden_dim)  # bidirectional LSTM
        self.score_projection = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, decoder_hidden_states, encoder_outputs, mask=None):
        """Compute attention weights over the encoder outputs.

        Args:
            decoder_hidden_states: (num_layers, B, hidden_dim). The top layer is
                used as the query.
            encoder_outputs: (B, T, 2 * hidden_dim).
            mask: Optional (B, T) tensor, True / non-zero for real source tokens and
                False / zero for padding. When omitted, every position is attended.

        Returns:
            (B, T) attention weights; each row sums to 1. Masked positions get
            weight 0, except in a fully masked row, which becomes uniform.
        """
        # Use the top layer of the multi-layer decoder state as the query.
        query = decoder_hidden_states[-1].unsqueeze(1)  # (B, 1, hidden_dim)

        query_projected = self.query_projection(query)  # (B, 1, hidden_dim)
        keys_projected = self.key_projection(encoder_outputs)  # (B, T, hidden_dim)

        # query_projected broadcasts over T without an explicit expand/repeat.
        energy = self.score_projection(torch.tanh(query_projected + keys_projected)).squeeze(-1)  # (B, T)

        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf, so that a
            # fully masked row gives a uniform distribution rather than NaNs.
            energy = energy.masked_fill(~mask.bool(), torch.finfo(energy.dtype).min)

        return F.softmax(energy, dim=1)  # (B, T)

In [16]:
class AttentionDecoder(nn.Module):
    """
    Single-step LSTM decoder with attention for the Seq2Seq model.

    It works on a batch of single-step targets (B,) rather than the entire target
    sequence, because the decision to teacher force and which search strategy to
    use is made at the sequence level, not the step level.

    At each step the incoming hidden state queries the encoder outputs through
    ``attention``. The resulting context vector is concatenated with the input
    token embedding and fed to the LSTM, whose output is projected to vocabulary
    logits.

    Args:
        attention: Callable ``(hidden, encoder_outputs[, mask]) -> (B, T)`` weights.
        vocab_size: Size of the target vocabulary.
        embedding_dim: Target token embedding size.
        hidden_dim: LSTM hidden size. Encoder outputs must be ``2 * hidden_dim`` wide.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout on embeddings and LSTM outputs (and between LSTM layers
            when ``num_layers > 1``).
        padding_idx: Target padding token id (its embedding is kept at zero).
    """
    def __init__(self, attention, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, padding_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.dropout = nn.Dropout(dropout)
        self.attention = attention
        self.lstm = nn.LSTM(
            2 * hidden_dim + embedding_dim,
            hidden_dim,
            num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=False,
        )
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs, encoder_last_hidden, encoder_last_cell, encoder_outputs, source_mask=None):
        """Run one decoding step.

        Args:
            inputs: (B,) token ids for the current step.
            encoder_last_hidden: (num_layers, B, hidden_dim) incoming hidden state,
                also used as the attention query. This is the encoder's projected
                final state at the first step; presumably the decoder's previous
                hidden state on later steps.
            encoder_last_cell: (num_layers, B, hidden_dim) incoming cell state.
            encoder_outputs: (B, T, 2 * hidden_dim) encoder outputs.
            source_mask: Optional (B, T) mask, True for real source tokens. When
                given, it is passed to ``attention`` as a third argument so padding
                receives no weight.

        Returns:
            logits: (B, 1, vocab_size)
            hidden: (num_layers, B, hidden_dim)
            cell: (num_layers, B, hidden_dim)
            attention weights: (B, T)
        """
        # Treat each token as a length-1 sequence for the batch_first LSTM.
        inputs = inputs.unsqueeze(1)  # (B, 1)

        embedded = self.dropout(self.embedding(inputs))  # (B, 1, embedding_dim)

        # Only pass the mask when one is given, so attention callables with the
        # two-argument signature keep working.
        if source_mask is None:
            attention_weights = self.attention(encoder_last_hidden, encoder_outputs)  # (B, T)
        else:
            attention_weights = self.attention(encoder_last_hidden, encoder_outputs, source_mask)  # (B, T)
        attention_weights = attention_weights.unsqueeze(1)  # (B, 1, T)
        context_vector = attention_weights @ encoder_outputs  # (B, 1, 2 * hidden_dim)

        lstm_input = torch.cat((embedded, context_vector), dim=2)  # (B, 1, embedding_dim + 2 * hidden_dim)

        outputs, (hidden, cell) = self.lstm(lstm_input, (encoder_last_hidden, encoder_last_cell))  # outputs: (B, 1, hidden_dim)
        logits = self.output_projection(self.dropout(outputs))  # (B, 1, vocab_size)

        return logits, hidden, cell, attention_weights.squeeze(1)

In [None]:
test_encoder = Encoder(len(de_vocab_to_index), 256, 256, 2, 0.0, 'cpu')
test_attention = BahdanauAttention(256)
# Use the default padding index, as the encoder does. Passing a special-token id
# here (previously the German BOS id) would make nn.Embedding zero that token's
# embedding and exclude it from gradient updates.
test_decoder = AttentionDecoder(test_attention, len(en_vocab_to_index), 256, 256, 2, 0.0)

encoder_outputs, hidden, cell = test_encoder.forward(batch["source"], batch["source_lengths"])

# Single decoding step fed with the first target token of each sequence.
logits, hidden, cell, attention_weights = test_decoder.forward(batch["target"][:, 0], hidden, cell, encoder_outputs)

print(f"target: {batch['target'][:, 0]}")
print(f"Logits: {logits.shape}")
print(f"Hidden: {hidden.shape}")
print(f"Cell: {cell.shape}")
print(f"Attention weights: {attention_weights.shape}")

print(f"Attention weights: {attention_weights}")


target: tensor([1, 1, 1])
Logits: torch.Size([3, 1, 4560])
Hidden: torch.Size([2, 3, 256])
Cell: torch.Size([2, 3, 256])
Attention weights: torch.Size([3, 15])
Attention weights: tensor([[0.0672, 0.0667, 0.0666, 0.0673, 0.0660, 0.0663, 0.0665, 0.0665, 0.0675,
         0.0675, 0.0666, 0.0670, 0.0666, 0.0661, 0.0657],
        [0.0671, 0.0667, 0.0664, 0.0662, 0.0662, 0.0666, 0.0661, 0.0668, 0.0668,
         0.0668, 0.0668, 0.0668, 0.0668, 0.0668, 0.0668],
        [0.0675, 0.0671, 0.0665, 0.0667, 0.0676, 0.0679, 0.0671, 0.0661, 0.0657,
         0.0658, 0.0661, 0.0664, 0.0665, 0.0660, 0.0669]],
       grad_fn=<SqueezeBackward1>)
