#### Image Captioning

In [4]:
import torch
import torchvision
import torch.nn as nn
import numpy as np 

In [20]:
class Attention(nn.Module):
    """Scaled dot-product attention without learnable parameters.

    After every forward pass the detached attention weights are stored in
    ``self.attention_map`` with shape [batch, num_queries, num_objects].
    """

    def __init__(self):
        super().__init__()
        # Filled in by forward(); None until the module has been called once.
        self.attention_map = None

    def forward(self, query, key, values):
        """
        query  => [batch, num_queries, embeddings]
        key    => [batch, num_objects(vocab), embeddings]
        values => [batch, num_objects(vocab), value_embeddings]

        returns => [batch, num_queries, value_embeddings]
        """
        # Similarity of every query with every key: [batch, num_queries, num_objects]
        dot_product = query @ key.transpose(-1, -2)
        # Scale by sqrt(embedding size). A plain Python float avoids allocating
        # a new tensor on every call.
        scaled_dot = dot_product / (key.shape[-1] ** 0.5)
        # Normalise over the objects (last dimension) so that each query gets a
        # probability distribution over the keys. Normalising over the query
        # dimension would make all weights equal to 1 for a single query.
        weights = torch.softmax(scaled_dot, dim=-1)
        self.attention_map = weights.detach()
        # Weighted sum of the values for each query.
        return weights @ values

In [None]:
# Sanity check: the attention output keeps the number of queries and takes
# its feature size from the values.
q = torch.randn(2, 3, 5)  # [batch, num_queries, embeddings]
k = torch.randn(2, 4, 5)  # [batch, num_objects, embeddings]
v = torch.randn(2, 4, 7)  # [batch, num_objects, value_embeddings]

results = Attention()(q, k, v)

# (batch_size, num_queries, output_embeddings)
assert tuple(results.shape) == (q.shape[0], q.shape[1], v.shape[-1])

In [None]:
class CaptionNet(nn.Module):
    def __init__(self, n_tokens=n_tokens, emb_size=128, lstm_units=256, cnn_channels=512):
        """
        A recurrent 'head' network for image captioning with attention over
        image spatial positions.

        :param n_tokens: vocabulary size; sets the number of embedding rows and output logits
        :param emb_size: size of the word embeddings
        :param lstm_units: size of the LSTM hidden and cell states (also the attention key size)
        :param cnn_channels: number of channels in the incoming image features
        """
        # Zero-argument super(): super(self.__class__, self) recurses forever
        # as soon as this class is subclassed.
        super().__init__()

        # layers that convert (spatially averaged) conv features to initial LSTM states
        self.cnn_to_h0 = nn.Linear(cnn_channels, lstm_units)
        self.cnn_to_c0 = nn.Linear(cnn_channels, lstm_units)

        # embedding for input words
        self.emb = nn.Embedding(n_tokens, emb_size)

        # attention over image spatial positions:
        # the query is the previous lstm hidden state, the keys are transformed
        # cnn features, the values are the cnn features themselves
        self.attention = Attention()

        # transform from cnn features to attention keys; keys have lstm_units
        # dimensions because the lstm hidden state is the attention query
        self.cnn_to_attn_key = nn.Linear(cnn_channels, lstm_units)

        # recurrent core; its input is [word embedding, attention context]
        self.lstm = nn.LSTMCell(emb_size + cnn_channels, lstm_units)

        # logits: MLP that takes the lstm hidden state, the attention response
        # and the current word embedding and computes one number per token
        self.logits_mlp = nn.Sequential(
            nn.Linear(lstm_units + cnn_channels + emb_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_tokens)
        )

    def forward(self, image_features, captions_ix):
        """
        Apply the network in training mode.
        :param image_features: torch tensor containing VGG features for each position.
                               shape: [batch, cnn_channels, width * height]
        :param captions_ix: torch tensor containing captions as matrix. shape: [batch, word_i].
            padded with pad_ix
        :returns: tuple (logits, attention_map) where
            logits are for the next token at each tick, shape: [batch, word_i, n_tokens]
            attention_map holds the detached attention weights at each tick,
            shape: [batch, word_i, width * height]
        :raises ValueError: if image_features is not a 3-dimensional tensor
        """
        if image_features.dim() != 3:
            raise ValueError(
                "image_features must have shape [batch, cnn_channels, width * height], "
                "got {}".format(tuple(image_features.shape))
            )
        caption_length = captions_ix.shape[1]

        # Initialize LSTM states from spatially averaged image features
        mean_features = image_features.mean(2)        # [batch, cnn_channels]
        initial_cell = self.cnn_to_c0(mean_features)  # [batch, lstm_units]
        initial_hid = self.cnn_to_h0(mean_features)   # [batch, lstm_units]

        # Transpose to [batch, spatial_size, cnn_channels] for attention
        image_features = image_features.transpose(1, 2)

        # Attention keys depend only on the image, not on the time step,
        # so they are computed once outside the recurrent loop.
        image_keys = self.cnn_to_attn_key(image_features)  # [batch, spatial_size, lstm_units]

        # compute embeddings for captions_ix
        captions_emb = self.emb(captions_ix)  # [batch, caption_length, emb_size]

        # 1. initialize lstm state with initial_* from above
        h_t, c_t = initial_hid, initial_cell

        # Per-step outputs
        recurrent_outputs = []
        attention_maps = []

        # 2. recurrent loop over tokens
        for t in range(caption_length):
            # 2.1. use previous lstm state as the attention query
            query = h_t.unsqueeze(1)  # [batch, 1, lstm_units] - add query dimension

            # 2.2. apply attention (keys: transformed features, values: raw features)
            context = self.attention(query, image_keys, image_features)  # [batch, 1, cnn_channels]
            context = context.squeeze(1)  # [batch, cnn_channels] - remove query dimension

            # 2.3. store attention map
            attention_maps.append(self.attention.attention_map.squeeze(1))  # [batch, spatial_size]

            # current word embedding
            current_word_emb = captions_emb[:, t, :]  # [batch, emb_size]

            # 2.4. feed lstm with current token embedding concatenated with context vector
            lstm_input = torch.cat([current_word_emb, context], dim=1)  # [batch, emb_size + cnn_channels]

            # 2.5. update lstm hidden and cell vectors
            h_t, c_t = self.lstm(lstm_input, (h_t, c_t))

            # 2.6. store the new hidden state together with the context and word embedding
            combined_output = torch.cat([h_t, context, current_word_emb], dim=1)  # [batch, lstm_units + cnn_channels + emb_size]
            recurrent_outputs.append(combined_output)

        # Stack outputs along the time dimension
        recurrent_out = torch.stack(recurrent_outputs, dim=1)  # [batch, caption_length, lstm_units + cnn_channels + emb_size]
        attention_map = torch.stack(attention_maps, dim=1)     # [batch, caption_length, spatial_size]

        # compute logits for next token probabilities from the stored per-step outputs
        logits = self.logits_mlp(recurrent_out)  # [batch, caption_length, n_tokens]

        return logits, attention_map

Understood the architecture: at every time step the attention context vector is concatenated with the current word embedding and fed to the LSTM, and the per-step outputs are then stacked. The attention is calculated with the LSTM hidden state as the query, the linearly transformed CNN features as the keys, and the CNN features directly as the values.