In [1]:
import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

import os
import json
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter

from datasets import CaptionDataset
from models import Encoder, Attention, DecoderWithAttention

%load_ext autoreload
%autoreload 2

# 01 Math behind the attention

In the paper, the authors mention that incorporating the attention mechanism:
> is inspired by recent success in employing attention in machine translation (Bahdanau et al., 2014)

They also mention that they closely follow this paper:
> There has been a long line of previous work incorporating attention into neural networks for vision related tasks. In particular however, our work directly extends the work of Bahdanau et al. (2014); Mnih et al. (2014); Ba et al. (2014).

Let's compare the main steps of attention in these papers. First, let's look at *Bahdanau et al. (2014)*. This paper uses a seq-to-seq model with two RNNs (an *Encoder* and a *Decoder*).

To compute attention we need 3 steps (in the notation of *Bahdanau et al. (2014)*):
1) Compute alignment scores using an alignment model $e_{ij} = a(s_{i-1}, h_j)$. Here $h_j$ is a hidden state of the *Encoder*, $s_{i-1}$ is the previous hidden state of the *Decoder*, and $a$ is a small feedforward neural network.
2) Compute attention weights by applying `softmax` to the scores over all *Encoder* positions: $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_k \exp(e_{ik})}$.
3) Compute the **context vector** as a weighted sum of the *Encoder's* hidden states (this is the key idea of attention - we use those hidden states of the *Encoder* that are relevant for the current step of the *Decoder*): $c_i = \sum_j \alpha_{ij} h_j$
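
The three steps above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the sizes, the random weights and the names `W_h`, `W_s` and `v` are hypothetical, and the alignment model is written in the additive form $v^\top \tanh(W_s s_{i-1} + W_h h_j)$ with a single hidden layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
src_len, enc_dim, dec_dim, att_dim = 5, 8, 6, 4

h = rng.normal(size=(src_len, enc_dim))  # Encoder hidden states h_j
s_prev = rng.normal(size=(dec_dim,))     # previous Decoder state s_{i-1}

# Randomly initialised parameters of the alignment model a(.) (assumed shapes).
W_h = rng.normal(size=(enc_dim, att_dim))
W_s = rng.normal(size=(dec_dim, att_dim))
v = rng.normal(size=(att_dim,))

# Step 1: alignment scores e_ij, one score per Encoder position j.
e = np.tanh(h @ W_h + s_prev @ W_s) @ v  # (src_len,)

# Step 2: softmax over the Encoder positions (shifted for numerical stability).
exp_e = np.exp(e - e.max())
alpha = exp_e / exp_e.sum()              # (src_len,)

# Step 3: context vector as the weighted sum of Encoder hidden states.
c = alpha @ h                            # (enc_dim,)

print(alpha.shape, c.shape)
```

The weights `alpha` are non-negative and sum to one, so `c` is a convex combination of the rows of `h`.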

These are exactly the same steps that we see in *Xu et al. (2016)*:

$$
\begin{aligned}
e_{ti} &= f_{att}(a_i, h_{t-1}) \\
\alpha_{ti} &= \frac{\exp(e_{ti})}{\sum_k \exp(e_{tk})} \\
\hat{z}_t &= \phi(\{a_i\}, \{\alpha_{ti}\})
\end{aligned}
$$

In the first step we use image features $a_i$ from the CNN instead of the hidden states of the *Encoder* RNN used in the previous model. $f_{att}$ is again a neural network. For the last step the paper describes two options (hard and soft attention), but in practice we use the soft version, which is the same weighted sum as before: $\hat{z}_t = \sum_i \alpha_{ti} a_i$.

# 02 Code for attention module

Now we can easily interpret our code. The first step is to apply some linear layers (`self.encoder_att` etc.) to our input from *Encoder* (image features from CNN) and *Decoder*:

```python
# (batch_size, num_pixels, attention_dim)
att1 = self.encoder_att(encoder_out)  
# (batch_size, attention_dim)
att2 = self.decoder_att(decoder_hidden)
# (batch_size, num_pixels)  
att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  
```
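
The `unsqueeze(1)` in the last line is what makes the addition work: it turns the decoder projection into shape `(batch_size, 1, attention_dim)` so that it broadcasts across all pixels. A small shape check, with hypothetical sizes and random tensors standing in for the outputs of the linear layers:

```python
import torch

# Hypothetical sizes, for illustration only.
batch_size, num_pixels, attention_dim = 2, 196, 512

att1 = torch.randn(batch_size, num_pixels, attention_dim)  # stands in for encoder_att(encoder_out)
att2 = torch.randn(batch_size, attention_dim)              # stands in for decoder_att(decoder_hidden)

# (batch_size, 1, attention_dim) is broadcast over the num_pixels axis.
combined = att1 + att2.unsqueeze(1)

print(att2.unsqueeze(1).shape)  # torch.Size([2, 1, 512])
print(combined.shape)           # torch.Size([2, 196, 512])
```

Every pixel location receives the same decoder projection, so the score for a pixel depends on that pixel's features and on the current decoder state only.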

The next step is to apply `softmax` to get attention weights:

```python
# (batch_size, num_pixels)
alpha = self.softmax(att)  
```
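
For the weights to be a distribution over pixels, the softmax has to normalise along the `num_pixels` axis. The definition of `self.softmax` is not shown here, so the sketch below assumes it is `nn.Softmax(dim=1)` and checks that each row of `alpha` sums to one:

```python
import torch
import torch.nn as nn

# Assumption: softmax is taken over dim=1, the num_pixels axis.
softmax = nn.Softmax(dim=1)

att = torch.randn(2, 196)     # hypothetical scores: (batch_size, num_pixels)
alpha = softmax(att)          # (batch_size, num_pixels)

row_sums = alpha.sum(dim=1)   # one sum per image in the batch
print(row_sums)
```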

Finally, we use a weighted sum of the image features from the *Encoder* to get our context vector:
```python
# (batch_size, encoder_dim)
attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) 
```
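
Putting the three snippets together gives a module along the following lines. This is a minimal sketch, not the `Attention` class imported from `models`: the class name `AttentionSketch`, the constructor arguments, the layer definitions and the tensor sizes are assumptions inferred from the attribute names and shape comments used above.

```python
import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Soft attention over image features (illustrative sketch)."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Assumed layer definitions, matching the names used in the snippets.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise over num_pixels

    def forward(self, encoder_out, decoder_hidden):
        # Step 1: alignment scores e_{ti}, one per pixel.
        att1 = self.encoder_att(encoder_out)     # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        # Step 2: attention weights alpha_{ti}.
        alpha = self.softmax(att)                # (batch_size, num_pixels)
        # Step 3: context vector as the weighted sum of image features.
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return context, alpha


torch.manual_seed(0)
# Hypothetical sizes, kept small so the example runs quickly.
attention = AttentionSketch(encoder_dim=32, decoder_dim=16, attention_dim=8)
encoder_out = torch.randn(4, 49, 32)   # (batch_size, num_pixels, encoder_dim)
decoder_hidden = torch.randn(4, 16)    # (batch_size, decoder_dim)

context, alpha = attention(encoder_out, decoder_hidden)
print(context.shape, alpha.shape)
```

The weighted sum in step 3 is the same operation as a batched matrix product, `torch.bmm(alpha.unsqueeze(1), encoder_out).squeeze(1)`. The element-wise form keeps the link to $\sum_i \alpha_{ti} a_i$ visible.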