In [1]:
import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

import os
import json
import h5py
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter

from datasets import CaptionDataset
from models import Encoder, Attention, DecoderWithAttention

%load_ext autoreload
%autoreload 2

# 01 Dimensions

Let's start by understanding the key dimensions used in the `Decoder`:

- `encoder_dim`: the last dimension of `encoder_out`, 2048 in our case.
- `decoder_dim`: the size of the LSTM hidden state, 512.
- `attention_dim`: the hidden size of the linear layers in our `Attention` module.
- `embed_dim`: the dimension of the embedding vector we produce for every token.
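To see where each dimension ends up, here is a minimal sketch that builds standalone layers with these sizes and checks their weight shapes. `encoder_dim` and `decoder_dim` come from the list above. `attention_dim`, `embed_dim` and `vocab_size` are illustrative assumptions, and the two attention projections (`encoder_att`, `decoder_att`) are hypothetical stand-ins for whatever linear layers the real `Attention` module uses.

```python
import torch
import torch.nn as nn

# encoder_dim and decoder_dim come from the text; the other three sizes are assumptions.
encoder_dim, decoder_dim = 2048, 512
attention_dim, embed_dim, vocab_size = 512, 512, 10000

embedding = nn.Embedding(vocab_size, embed_dim)       # token id -> embed_dim
encoder_att = nn.Linear(encoder_dim, attention_dim)   # hypothetical: encoded pixel -> attention space
decoder_att = nn.Linear(decoder_dim, attention_dim)   # hypothetical: hidden state -> attention space
decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)  # input = [embedding; context]
fc = nn.Linear(decoder_dim, vocab_size)               # hidden state -> vocabulary scores

for name, layer in [("embedding", embedding), ("encoder_att", encoder_att),
                    ("decoder_att", decoder_att), ("decode_step", decode_step), ("fc", fc)]:
    print(name, [tuple(p.shape) for p in layer.parameters()])
```

Note that `nn.LSTMCell` stores its four gates stacked, so its weight matrices have `4 * decoder_dim` rows.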

# 02 Initial state

The logic behind building an initial state is simple (see the paper):
> The initial memory state and hidden state of the LSTM are predicted by an average of the annotation vectors fed through two separate MLPs (`init_c` and `init_h`): 
>
>$$c_0 = f_{init-c}(\frac{1}{L} \sum_i^L{a_i})$$
>
>$$h_0 = f_{init-h}(\frac{1}{L} \sum_i^L{a_i})$$

That's exactly what we're doing:

```python
# (batch_size, num_pixels, encoder_dim) -> (batch_size, encoder_dim)
mean_encoder_out = encoder_out.mean(dim=1)

# (batch_size, encoder_dim) -> (batch_size, decoder_dim)
c = self.init_c(mean_encoder_out)
h = self.init_h(mean_encoder_out)  
```

Unsurprisingly, $f_{init-c}$ and $f_{init-h}$ are implemented as single linear layers:

```python
# initial states
# (BS, encoder_dim) -> (BS, decoder_dim)
# linear layer to find initial hidden state of LSTMCell
self.init_h = nn.Linear(encoder_dim, decoder_dim)
# linear layer to find initial cell state of LSTMCell  
self.init_c = nn.Linear(encoder_dim, decoder_dim)  
```

In [2]:
decoder = DecoderWithAttention()

In [3]:
decoder.init_h, decoder.init_c

(Linear(in_features=2048, out_features=512, bias=True),
 Linear(in_features=2048, out_features=512, bias=True))

In [4]:
encoder_out = torch.zeros(1, 196, 2048)

In [5]:
h, c = decoder.init_hidden_state(encoder_out)

In [6]:
h.shape, c.shape

(torch.Size([1, 512]), torch.Size([1, 512]))
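The cells above call the decoder's own `init_hidden_state`. As a self-contained sketch of the same idea, the hypothetical module `InitState` below reproduces the mean-then-linear logic shown earlier and also checks that `mean(dim=1)` is exactly the $\frac{1}{L}\sum_i^L a_i$ term from the paper's formulas. It is an illustration, not the class from `models.py`.

```python
import torch
import torch.nn as nn


class InitState(nn.Module):
    """Hypothetical stand-alone version of the decoder's initial-state logic."""

    def __init__(self, encoder_dim=2048, decoder_dim=512):
        super().__init__()
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # f_init-h
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # f_init-c

    def forward(self, encoder_out):
        # (batch_size, num_pixels, encoder_dim) -> (batch_size, encoder_dim)
        mean_encoder_out = encoder_out.mean(dim=1)
        return self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)


encoder_out = torch.randn(2, 196, 2048)  # L = 196 annotation vectors per image
h0, c0 = InitState()(encoder_out)

# mean over the pixel dimension == (1/L) * sum_i a_i
print(torch.allclose(encoder_out.mean(dim=1), encoder_out.sum(dim=1) / 196, atol=1e-5))
print(h0.shape, c0.shape)  # torch.Size([2, 512]) torch.Size([2, 512])
```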

# 03 `encoder_out`

As we know, `encoder_out` has shape `(BS, 14, 14, 2048)`. We flatten the two spatial dimensions to get `(BS, 196, 2048)`, one row per pixel: this is the layout that both the `Attention` module and the initialization of `h` and `c` expect.

In [7]:
encoder_out = torch.zeros(1, 14, 14, 2048)

In [9]:
encoder_out = encoder_out.view(1, -1, 2048)  
num_pixels = encoder_out.size(1)

In [10]:
encoder_out.shape, num_pixels

(torch.Size([1, 196, 2048]), 196)
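The cells above hard-code a batch of one. A slightly more general sketch (toy batch size of 2, an assumption) flattens with the batch and channel sizes taken from the tensor itself, and shows which of the 196 rows a given grid cell lands in: `view` flattens in row-major order, so pixel `(row, col)` becomes index `row * 14 + col`.

```python
import torch

batch_size, enc_size, encoder_dim = 2, 14, 2048
encoder_out = torch.randn(batch_size, enc_size, enc_size, encoder_dim)

# Flatten the 14 x 14 grid into 196 "pixels"; -1 lets PyTorch infer 14 * 14.
flat = encoder_out.view(batch_size, -1, encoder_dim)
num_pixels = flat.size(1)

# Row-major order: grid cell (row, col) ends up at index row * 14 + col.
row, col = 3, 5
assert torch.equal(flat[:, row * enc_size + col], encoder_out[:, row, col])
print(flat.shape, num_pixels)  # torch.Size([2, 196, 2048]) 196
```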

# 04 Main loop

## 04-1 `decode_step`

The most important thing to notice is the special form of the input to our LSTM cell. It is not only the embedded caption word `embeddings[...]` but also a context vector: the output of our attention module, `attention_weighted_encoding`. We concatenate these two inputs with `torch.cat(...)`, so the input size of our LSTM cell is `embed_dim + encoder_dim`.

Arguably, better names would be:
- `embedded_captions` instead of `embeddings`;
- `context_vector` instead of `attention_weighted_encoding`.
This step looks like this in code:

```python
# decode_step - LSTM cell
# (batch_size_t, embed_dim + encoder_dim) -> (batch_size_t, decoder_dim)
# input: the embedded caption (a single word per step) concatenated with the
# attention-weighted encoder output; the LSTM cell turns it into a new hidden state
h, c = self.decode_step(
    torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
    (h[:batch_size_t], c[:batch_size_t])
) 
```

Here `self.decode_step` is just an LSTM cell:

```python
# this is our LSTM; we call the cell once per time step in a loop
self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  
```
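A minimal sketch of the whole loop around `decode_step`, with toy sizes. Several things here are assumptions not shown in this notebook: the hypothetical `decode_lengths` list, the idea that captions are sorted by decreasing length (so the captions still being decoded at step `t` are always the first `batch_size_t` rows), and random tensors standing in for the real embeddings and attention output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes (assumptions) so the sketch runs quickly.
batch_size, max_len, embed_dim, encoder_dim, decoder_dim = 4, 6, 8, 16, 12
decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)

embeddings = torch.randn(batch_size, max_len, embed_dim)  # stand-in for embedded captions
decode_lengths = [6, 5, 3, 2]                             # hypothetical, sorted in decreasing order
h = torch.zeros(batch_size, decoder_dim)
c = torch.zeros(batch_size, decoder_dim)

for t in range(max(decode_lengths)):
    # Number of captions that still have a word to predict at step t.
    batch_size_t = sum(l > t for l in decode_lengths)
    # Stand-in for the context vector; the real code gets it from the Attention module.
    attention_weighted_encoding = torch.randn(batch_size_t, encoder_dim)
    x = torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1)
    h, c = decode_step(x, (h[:batch_size_t], c[:batch_size_t]))
    print(t, tuple(x.shape), tuple(h.shape))
```

Because the effective batch only shrinks, slicing the previous `h` and `c` with `[:batch_size_t]` is always valid.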

The input to the LSTM cell in formula $(1)$ of the paper looks like this:

$$
\begin{bmatrix} Ey_{t-1} \\ h_{t-1} \\ \hat{z}_t \end{bmatrix}
$$

Let's have a look at the first element of this vector:
- $y_{t-1}$ is a one-hot encoded vector in $\mathbb{R}^K$, where $K$ is the size of the vocabulary;
- $E \in \mathbb{R}^{m \times K}$ is an embedding matrix, where $m$ is the embedding dimension, so $Ey_{t-1}$ is an embedded vector in $\mathbb{R}^m$;
- in our case this is `embeddings[:batch_size_t, t, :]`, of size `embed_dim`;
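A quick sketch (toy $K$ and $m$, assumed) of why an embedding lookup is the same as $Ey_{t-1}$: multiplying by a one-hot vector just selects one column of $E$. `nn.Embedding` stores the transpose, a `(K, m)` weight, so the lookup returns one of its rows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim = 10, 4              # toy K and m (assumptions)
embedding = nn.Embedding(vocab_size, embed_dim)
E = embedding.weight.T                     # (m, K), the paper's E

token = torch.tensor([7])                  # y_{t-1} given as a token index
y = F.one_hot(token, num_classes=vocab_size).float()  # (1, K) one-hot row

Ey = y @ E.T                               # (1, m); row-vector form of E @ y
assert torch.allclose(Ey, embedding(token))
```

The lookup is what the code uses in practice, since it avoids materializing the one-hot vectors.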

Now let's have a look at the last element:

>$\hat{z}_t \in \mathbb{R}^D$ is a
>context vector, capturing the visual information associated with a particular input location; the extractor produces $L$ vectors, each of which is a $D$-dimensional representation corresponding to a part of the image.

- it is the result of applying the attention layer, which we consider later;
- in our case it is `attention_weighted_encoding`, in other words the context vector;
- its size is `encoder_dim`, which is `2048` in our case (`encoder_out` has shape `(BS, 14*14, 2048)` after flattening);
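To make the shape concrete: with the paper's soft (deterministic) attention, the context vector is a weighted average of the $L$ annotation vectors, $\hat{z}_t = \sum_i \alpha_{t,i} a_i$. The sketch below uses random scores as a stand-in for the attention network, so it only illustrates the weighted sum; the actual `Attention` module and decoder may add further steps on top of it.

```python
import torch

torch.manual_seed(0)
batch_size, num_pixels, encoder_dim = 2, 196, 2048
encoder_out = torch.randn(batch_size, num_pixels, encoder_dim)  # a_1 ... a_L
scores = torch.randn(batch_size, num_pixels)                     # stand-in for attention scores
alpha = torch.softmax(scores, dim=1)                             # weights sum to 1 over pixels

# z_t = sum_i alpha_{t,i} * a_i  -> (batch_size, encoder_dim)
context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)

# Same thing as a batched matrix product: (BS, 1, L) @ (BS, L, D) -> (BS, 1, D)
context_bmm = torch.bmm(alpha.unsqueeze(1), encoder_out).squeeze(1)
print(context.shape)  # torch.Size([2, 2048])
```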

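Finally, we pass the previous states $h_{t-1}$ and $c_{t-1}$; in our case these are `h[:batch_size_t]` and `c[:batch_size_t]`. Note that `nn.LSTMCell` takes $h_{t-1}$ as a separate argument instead of as part of the concatenated input: it applies one weight matrix to the input (`weight_ih`) and another to the hidden state (`weight_hh`) and adds the results, which is the same as applying a single affine map to the stacked vector above.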
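A minimal check of that equivalence, with toy sizes (assumptions, not the real 512/2048 dimensions). It recomputes one `nn.LSTMCell` step by hand from a single matrix built by stacking `weight_ih` and `weight_hh`, using PyTorch's documented gate order (input, forget, cell, output). Our stacked vector is ordered $[Ey_{t-1}; \hat{z}_t; h_{t-1}]$ rather than the paper's order, which only permutes the columns of that matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, encoder_dim, decoder_dim, bs = 8, 16, 12, 3  # toy sizes
cell = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)

x = torch.randn(bs, embed_dim + encoder_dim)  # [E y_{t-1}; z_t], already concatenated
h_prev = torch.randn(bs, decoder_dim)
c_prev = torch.randn(bs, decoder_dim)

# One affine map on the stacked vector [x; h_{t-1}]
W = torch.cat([cell.weight_ih, cell.weight_hh], dim=1)  # (4 * decoder_dim, input + hidden)
b = cell.bias_ih + cell.bias_hh
gates = torch.cat([x, h_prev], dim=1) @ W.T + b
i, f, g, o = gates.chunk(4, dim=1)  # PyTorch order: input, forget, cell, output
c_manual = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
h_manual = torch.sigmoid(o) * torch.tanh(c_manual)

h_new, c_new = cell(x, (h_prev, c_prev))
print(torch.allclose(h_new, h_manual, atol=1e-5), torch.allclose(c_new, c_manual, atol=1e-5))
```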

## 04-2 `predictions`

To get predictions at this stage, we project the hidden state from hidden space onto vocabulary space. The decoder only produces raw scores; all the logic for sampling a caption (including `softmax` and beam search) lives in `caption_image_beam_search()` (file `caption.py`):

```python
# project from hidden space onto vocabulary space
# (batch_size_t, decoder_dim) -> (batch_size_t, vocab_size)
# raw scores only: no softmax and no argmax here
preds = self.fc(self.dropout(h)) 
```

Here `self.fc` is just a Linear layer `decoder_dim -> vocab_size`:

```python
self.fc = nn.Linear(decoder_dim, vocab_size)
```
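A small sketch of what "raw scores" means in practice, with toy sizes and a hypothetical dropout rate. `softmax` turns the logits into a probability distribution over the vocabulary when sampling. For training, if the loss is `nn.CrossEntropyLoss` (an assumption here, not shown in this notebook), it should receive the raw logits, because it applies log-softmax internally; the last lines confirm that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size_t, decoder_dim, vocab_size = 3, 12, 20  # toy sizes (assumptions)
fc = nn.Linear(decoder_dim, vocab_size)
dropout = nn.Dropout(p=0.5)                        # hypothetical dropout rate

h = torch.randn(batch_size_t, decoder_dim)
preds = fc(dropout(h))                             # raw scores (logits), no softmax

probs = torch.softmax(preds, dim=1)                # distribution over the vocabulary

# CrossEntropyLoss = log-softmax + negative log-likelihood, so it takes the logits.
targets = torch.randint(0, vocab_size, (batch_size_t,))
loss = nn.CrossEntropyLoss()(preds, targets)
manual = F.nll_loss(torch.log_softmax(preds, dim=1), targets)
print(preds.shape, loss.item(), manual.item())
```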