Attention took the NLP community by storm a few years ago when it was first announced. I've personally heard about attention many times, but never had the chance to fully dive into what it was. In this post, we will attempt to bake a simple attention mechanism into a seq2seq model.

Before anything, I highly recommend that you check out Ben Trevett's [sequence modeling tutorials](https://github.com/bentrevett/pytorch-seq2seq). In particular, this post is heavily based on his [NMT notebook](https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb). Now let's get started!

# Setup

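For this tutorial, we will need to import a number of dependencies, mainly from `torch` and `torchtext`. `torchtext` is a library that provides a convenient interface for working with text data in PyTorch. Note that `Field` and `BucketIterator` belong to the older `torchtext` API (they moved to `torchtext.legacy` in version 0.9 and were later removed), so the imports below assume an older release.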

In [29]:
import random
import time

import torch
import torchtext
from torch import nn
import torch.nn.functional as F
from torchtext.data import BucketIterator, Field
from torchtext.datasets import Multi30k

In particular, `torchtext` includes a `Field` class, which essentially allows us to define preprocessing steps to be applied to the data. We will be using the `Multi30k` dataset, which pairs short English image descriptions with their translations in other languages. In this tutorial, we will translate from German to English, so we define preprocessing steps for each language. The preprocessing, as defined below, tells `torchtext` to:

* Tokenize the dataset using `spacy`
* Add an `"<sos>"` token to the start and an `"<eos>"` token to the end of each sentence
* Lowercase every word

In [30]:
SRC = Field(
    tokenize="spacy",
    tokenizer_language="de",
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
)

TRG = Field(
    tokenize="spacy",
    tokenizer_language="en",
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
)
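
To make these three steps concrete, here is a minimal sketch of what a field does to a single sentence. It assumes a plain whitespace split in place of the `spacy` tokenizer, and `preprocess` is a hypothetical helper written for illustration, not a `torchtext` function.

```python
def preprocess(sentence, init_token="<sos>", eos_token="<eos>", lower=True):
    """Mimic the Field settings above on one sentence (whitespace tokenizer assumed)."""
    if lower:
        sentence = sentence.lower()
    tokens = sentence.split()  # stand-in for the spacy tokenizer
    # "<sos>" goes in front of the sentence, "<eos>" goes at the end
    return [init_token] + tokens + [eos_token]


tokens = preprocess("Zwei Hunde spielen im Gras")
print(tokens)
# ['<sos>', 'zwei', 'hunde', 'spielen', 'im', 'gras', '<eos>']
```

A real tokenizer would also split off punctuation, which a whitespace split does not do.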

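We can now prepare the data by calling `splits()` on the `Multi30k` dataset, using the fields we have defined above. We then build a vocabulary for each language from the training set, keeping at most 10,000 tokens that appear at least twice.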

In [31]:
train_data, validation_data, test_data = Multi30k.splits(
    root="data", exts=(".de", ".en"), fields=(SRC, TRG)
)

SRC.build_vocab(train_data, max_size=10000, min_freq=2)
TRG.build_vocab(train_data, max_size=10000, min_freq=2)
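
The effect of `max_size` and `min_freq` can be sketched in plain Python. This is only an approximation of what `build_vocab` does, under my assumptions that the special tokens come first and do not count toward `max_size`; the `build_vocab` function below is a hypothetical stand-in, not the `torchtext` implementation.

```python
from collections import Counter


def build_vocab(tokenized_sentences, max_size=10000, min_freq=2,
                specials=("<unk>", "<pad>", "<sos>", "<eos>")):
    """Map tokens to indices, keeping at most `max_size` tokens seen >= `min_freq` times."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    for special in specials:
        counts.pop(special, None)  # special tokens are handled separately
    # Most frequent first; ties broken alphabetically so the result is deterministic.
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    kept = [tok for tok, freq in ranked if freq >= min_freq][:max_size]
    itos = list(specials) + kept
    return {tok: idx for idx, tok in enumerate(itos)}


corpus = [["a", "dog", "runs"], ["a", "cat", "runs"], ["a", "dog", "sleeps"]]
stoi = build_vocab(corpus, max_size=2, min_freq=2)
print(stoi)
# {'<unk>': 0, '<pad>': 1, '<sos>': 2, '<eos>': 3, 'a': 4, 'dog': 5}

# Words that did not make it into the vocabulary fall back to "<unk>".
print(stoi.get("sleeps", stoi["<unk>"]))
# 0
```

Here "runs" is frequent enough but is cut off by `max_size=2`, while "cat" and "sleeps" are dropped by `min_freq=2`.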

Next, we create iterators that load the dataset in batches to be fed into our model. These iterators play the role that data loaders play in plain PyTorch.

In [33]:
BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, validation_iterator, test_iterator = BucketIterator.splits(
    (train_data, validation_data, test_data), batch_size=BATCH_SIZE, device=device
)

Below, we can see that the data have been batched properly. Each tensor has shape `(seq_len, batch_size)`. Notice that the sequence length differs from batch to batch, while within a batch every sentence has the same length: `torchtext` pads shorter sentences with `"<pad>"` tokens up to the length of the longest sentence in that batch. One benefit of using `BucketIterator` is that it groups sentences of similar length together, which keeps the amount of padding small, and we never have to pad sentences by hand or to a single global length.

In [34]:
for i, batch in enumerate(train_iterator):
    print(batch.src.shape)
    if i == 5:
        break

torch.Size([37, 128])
torch.Size([28, 128])
torch.Size([28, 128])
torch.Size([37, 128])
torch.Size([28, 128])
torch.Size([27, 128])
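
For intuition, per-batch padding can be reproduced with `pad_sequence` from `torch.nn.utils.rnn`. This is a minimal sketch with made-up token indices; the value of `PAD_IDX` is an assumption based on the usual ordering of special tokens in the vocabulary.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 1  # assumed index of "<pad>"

# Three "sentences" of different lengths, already converted to token indices.
sentences = [
    torch.tensor([2, 10, 11, 12, 3]),
    torch.tensor([2, 10, 3]),
    torch.tensor([2, 10, 11, 3]),
]

# Pad every sentence to the length of the longest one in this batch.
batch = pad_sequence(sentences, padding_value=PAD_IDX)
print(batch.shape)  # (longest_sentence, batch_size) == (5, 3)
print(batch[:, 1])  # the second sentence, followed by two pad indices
```

The batch is only as long as its longest sentence, which is why each batch printed above has a different first dimension.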


# Modeling

Now is the time for the fun part: modeling and implementing attention. Attention mechanisms originally arose in the context of sequence-to-sequence modeling. The underlying question is this: when an input is encoded by the encoder and then decoded by the decoder, can the decoder learn which part of the encoding to focus on at each decoding step? A natural real-life example is machine translation. Given the input "I love you," the Korean translation would be "나는 너를 사랑해," or, translated word by word, "I you love." In this instance, the decoder has to know that word order differs between Korean and English, and which part of the original English sequence to focus on when producing each word of the translation.
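
Numerically, "focusing" boils down to a weighted average. The toy sketch below uses made-up encodings and made-up relevance scores (both are assumptions for illustration, not outputs of a trained model) to show how scores become weights and weights become a context vector.

```python
import torch
import torch.nn.functional as F

# Hypothetical encodings of the three source words "I", "love", "you".
encoder_states = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

# Hypothetical relevance scores while producing the second Korean word,
# "너를" ("you"): the third source word should matter most.
scores = torch.tensor([0.1, 0.2, 3.0])

weights = F.softmax(scores, dim=0)  # non-negative and summing to 1
context = weights @ encoder_states  # weighted average of the encodings

print(weights)
print(context)
```

Because most of the weight lands on the third source word, the context vector ends up close to that word's encoding.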

Now that we have some idea of what attention is, let's start coding the encoder. 

In [35]:
class Encoder(nn.Module):
    def __init__(
        self, 
        vocab_size, 
        embed_dim, 
        encoder_hidden_size, 
        decoder_hidden_size, 
        dropout
    ):
        super(Encoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, encoder_hidden_size, bidirectional=True)
        self.fc = nn.Linear(encoder_hidden_size * 2, decoder_hidden_size)

    def forward(self, x):
        embedding = self.dropout(self.embed(x))
        outputs, hidden = self.gru(embedding)
        # outputs.shape == (seq_len, batch_size, 2 * encoder_hidden_size)
        # hidden.shape == (2, batch_size, encoder_hidden_size)
        # hidden[-2] is the final forward state, hidden[-1] the final backward state
        concat_hidden = torch.cat((hidden[-2], hidden[-1]), dim=1)
        # concat_hidden.shape == (batch_size, encoder_hidden_size * 2)
        hidden = torch.tanh(self.fc(concat_hidden))
        # hidden.shape = (batch_size, decoder_hidden_size)
        return outputs, hidden
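
The shape comments in the encoder can be checked on a standalone bidirectional GRU. The sizes below are arbitrary values chosen for illustration. The sketch also shows where the two final hidden states sit inside `outputs`: the forward direction finishes at the last time step, the backward direction at the first.

```python
import torch
from torch import nn

torch.manual_seed(0)
seq_len, batch_size, embed_dim, hidden_size = 7, 4, 8, 16

gru = nn.GRU(embed_dim, hidden_size, bidirectional=True)
embedding = torch.randn(seq_len, batch_size, embed_dim)
outputs, hidden = gru(embedding)

print(outputs.shape)  # (seq_len, batch_size, 2 * hidden_size)
print(hidden.shape)   # (2, batch_size, hidden_size)

# Each time step of `outputs` stores the forward state first, then the backward state.
forward_final = outputs[-1, :, :hidden_size]
backward_final = outputs[0, :, hidden_size:]
print(torch.allclose(hidden[-2], forward_final))
print(torch.allclose(hidden[-1], backward_final))
```

Both comparisons should print `True`, which is why concatenating the two entries of `hidden` summarizes the sentence as read in both directions.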

In [36]:
class Attention(nn.Module):
    def __init__(
        self, 
        encoder_hidden_size,
        decoder_hidden_size,
    ):
        super(Attention, self).__init__()
        self.fc1 = nn.Linear(
            encoder_hidden_size * 2 + decoder_hidden_size, 
            decoder_hidden_size
        )
        self.fc2 = nn.Linear(decoder_hidden_size, 1)
    
    def forward(self, hidden, encoder_outputs):
        # hidden.shape == (batch_size, decoder_hidden_size)
        # encoder_outputs.shape == (seq_len, batch_size, encoder_hidden_size * 2)
        seq_len = encoder_outputs.size(0)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # repeat the decoder state once per source position
        hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        # hidden.shape == (batch_size, seq_len, decoder_hidden_size)
        # encoder_outputs.shape == (batch_size, seq_len, encoder_hidden_size * 2)
        concat = torch.cat((hidden, encoder_outputs), dim=2)
        # concat.shape == (batch_size, seq_len, encoder_hidden_size * 2 + decoder_hidden_size)
        energy = torch.tanh(self.fc1(concat))
        # energy.shape == (batch_size, seq_len, decoder_hidden_size)
        attention = self.fc2(energy)
        # attention.shape == (batch_size, seq_len, 1)
        attention = attention.squeeze(2)
        # attention.shape == (batch_size, seq_len)
        return F.softmax(attention, dim=1)
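
In equation form, the forward pass above can be summarized as follows. The notation is introduced here for illustration and does not appear in the code: $s$ is the decoder hidden state, $h_i$ is the encoder output at source position $i$, $W_1, b_1$ are the parameters of `fc1`, and $v, b_2$ are those of `fc2`.

$$
\begin{aligned}
u_i &= \tanh\left(W_1 [s; h_i] + b_1\right) \\
e_i &= v^\top u_i + b_2 \\
a_i &= \frac{\exp(e_i)}{\sum_{j=1}^{T} \exp(e_j)}
\end{aligned}
$$

Here $u_i$ corresponds to `energy`, $e_i$ to `attention` before the softmax, and $a_i$ to the returned weights, which are non-negative and sum to one over the $T$ source positions.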


In [37]:
class Decoder(nn.Module):
    def __init__(
        self, 
        vocab_size,
        embed_dim,
        decoder_hidden_size,
        encoder_hidden_size, 
        dropout,
    ):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            encoder_hidden_size * 2 + embed_dim, decoder_hidden_size
        )
        self.attention = Attention(encoder_hidden_size, decoder_hidden_size)
        self.fc = nn.Linear(
            encoder_hidden_size * 2 + decoder_hidden_size + embed_dim, 
            vocab_size
        )
    
    def forward(self, x, hidden, encoder_outputs):
        x = x.unsqueeze(0)
        # x.shape == (1, batch_size)
        # hidden.shape = (batch_size, decoder_hidden_size)
        embedding = self.dropout(self.embedding(x))
        # embedding.shape == (1, batch_size, embed_dim)
        attention = self.attention(hidden, encoder_outputs)
        # attention.shape == (batch_size, seq_len)
        attention = attention.unsqueeze(1)
        # attention.shape == (batch_size, 1, seq_len)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # encoder_outputs.shape == (batch_size, seq_len, encoder_hidden_size * 2)
        weighted = torch.bmm(attention, encoder_outputs)
        # weighted.shape == (batch_size, 1, encoder_hidden_size * 2)
        # permute is not in-place, so the result must be assigned
        weighted = weighted.permute(1, 0, 2)
        # weighted.shape == (1, batch_size, encoder_hidden_size * 2)
        weighted_concat = torch.cat((embedding, weighted), dim=2)
        # weighted_concat.shape == (1, batch_size, encoder_hidden_size * 2 + embed_dim)
        # the GRU expects its initial state as (1, batch_size, decoder_hidden_size)
        output, hidden = self.gru(weighted_concat, hidden.unsqueeze(0))
        # output.shape == (1, batch_size, decoder_hidden_size)
        # hidden.shape == (1, batch_size, decoder_hidden_size)
        embedding = embedding.squeeze(0)
        output = output.squeeze(0)
        hidden = hidden.squeeze(0)
        weighted = weighted.squeeze(0)
        # embedding.shape == (batch_size, embed_dim)
        # output.shape == (batch_size, decoder_hidden_size)
        # weighted.shape == (batch_size, encoder_hidden_size * 2)
        fc_in = torch.cat((output, weighted, embedding), dim=1)
        prediction = self.fc(fc_in)
        # prediction.shape == (batch_size, vocab_size)
        return prediction, hidden
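
The key step in the decoder is the batched matrix product that turns attention weights into a context vector. The sketch below, with arbitrary sizes and random tensors standing in for real activations, checks that `torch.bmm` computes the same thing as an explicit weighted sum over source positions.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, enc_dim = 4, 7, 32

# Random stand-ins for the attention weights and the encoder outputs.
attention = torch.softmax(torch.randn(batch_size, seq_len), dim=1)
encoder_outputs = torch.randn(seq_len, batch_size, enc_dim)

# (batch_size, 1, seq_len) @ (batch_size, seq_len, enc_dim) -> (batch_size, 1, enc_dim)
weighted = torch.bmm(attention.unsqueeze(1), encoder_outputs.permute(1, 0, 2))
print(weighted.shape)

# The same quantity written as an explicit sum over source positions.
manual = (attention.t().unsqueeze(2) * encoder_outputs).sum(dim=0)
print(torch.allclose(weighted.squeeze(1), manual, atol=1e-5))
```

Note that `permute` and `unsqueeze` return new tensors rather than modifying their input, so their results have to be assigned or passed along directly.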

In [38]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
    
    def forward(self, source, target, teacher_force_ratio=0.5):
        seq_len = target.size(0)
        batch_size = target.size(1)    
        outputs = torch.zeros(
            seq_len, batch_size, self.decoder.vocab_size
        ).to(self.device)
        
        encoder_outputs, hidden = self.encoder(source)
        x = target[0]
        
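        # outputs[0] stays zero, since position 0 of the target is always "<sos>"
        for t in range(1, seq_len):
            output, hidden = self.decoder(x, hidden, encoder_outputs)
            outputs[t] = output
            teacher_force = random.random() < teacher_force_ratio
            if teacher_force:
                # feed the ground-truth token as the next input
                x = target[t]
            else:
                # feed the model's own most likely token as the next input
                x = output.argmax(1)
        return outputs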
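
Teacher forcing means that, with probability `teacher_force_ratio`, the decoder is fed the ground-truth token as its next input rather than its own prediction. The decision rule can be isolated in a few lines; `choose_next_input` is a hypothetical helper for illustration, with strings standing in for token tensors.

```python
import random


def choose_next_input(target_token, predicted_token, teacher_force_ratio):
    """Return the ground-truth token with probability `teacher_force_ratio`, else the model's guess."""
    teacher_force = random.random() < teacher_force_ratio
    return target_token if teacher_force else predicted_token


random.seed(0)
picks = [choose_next_input("gold", "guess", 0.5) for _ in range(1000)]
print(picks.count("gold") / len(picks))  # roughly 0.5
```

A ratio of 1.0 always feeds the ground truth and a ratio of 0.0 always feeds the model's own prediction, which is the setting one would use at inference time when no target is available.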

In [39]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
# Decoder expects (vocab_size, embed_dim, decoder_hidden_size, encoder_hidden_size, dropout)
decoder = Decoder(OUTPUT_DIM, DEC_EMB_DIM, DEC_HID_DIM, ENC_HID_DIM, DEC_DROPOUT)
model = Seq2Seq(encoder, decoder, device).to(device)
