In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
# nn.Module is the base class for all neural-network modules in PyTorch.
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    During training (``forward``), the image feature is fed to the LSTM as the
    first time step, followed by the embedded caption tokens, and a linear layer
    maps every LSTM output to vocabulary logits. At inference (``sample``),
    tokens are decoded greedily, feeding each prediction back as the next input.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        """
        Args:
            embed_size: size of the word embeddings; must equal the size of the
                image feature vectors passed to ``forward`` / ``sample``
                (256 in our case).
            hidden_size: hidden size of the LSTM.
            vocab_size: size of the vocabulary, i.e. the number of distinct
                token indices the model can embed and predict.
            num_layers: number of stacked LSTM layers (1 by default).
        """
        # Initialize nn.Module so parameters/submodules are registered.
        super(DecoderRNN, self).__init__()

        # Stored for later use by callers.
        self.hidden_dim = hidden_size
        self.num_layers = num_layers

        # Map each token index to a dense embedding vector of size embed_size.
        self.embed = nn.Embedding(vocab_size, embed_size)

        # batch_first=True: input/output tensors are (batch, seq_len, features).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

        # Map each LSTM output to unnormalized scores (logits) over the vocabulary.
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Zero (h, c) placeholder of shape (num_layers, 1, hidden_size).
        # Uses num_layers (previously hard-coded to 1, which mismatched the LSTM
        # when num_layers > 1). forward() overwrites it with the final state.
        self.hidden = (
            torch.zeros(num_layers, 1, hidden_size),
            torch.zeros(num_layers, 1, hidden_size),
        )

    def forward(self, features, captions):
        """Compute per-step vocabulary logits with teacher forcing.

        Args:
            features: image features of shape (batch, embed_size).
            captions: token indices of shape (batch, cap_length). The last token
                (presumably <end>) is dropped, because nothing has to be
                predicted after it.

        Returns:
            Tensor of shape (batch, cap_length, vocab_size): logits for the token
            at each position; position 0 is predicted from the image alone.
        """
        # (bs, cap_length) -> (bs, cap_length - 1, embed_size)
        cap_embedding = self.embed(captions[:, :-1])

        # Prepend the image feature as time step t=0:
        # (bs, 1, embed_size) ++ (bs, cap_length - 1, embed_size)
        #   -> (bs, cap_length, embed_size)
        embeddings = torch.cat((features.unsqueeze(dim=1), cap_embedding), dim=1)

        # lstm_out: (bs, cap_length, hidden_size)
        # hidden:   (h_n, c_n), each (num_layers, bs, hidden_size)
        lstm_out, hidden = self.lstm(embeddings)

        # Keep the final state available on the module, but detached so the
        # attribute does not hold this batch's autograd graph alive.
        self.hidden = tuple(state.detach() for state in hidden)

        # (bs, cap_length, hidden_size) -> (bs, cap_length, vocab_size)
        return self.linear(lstm_out)

    def sample(self, inputs, states=None, max_len=20, end_idx=1):
        """Greedily decode a caption (as token indices) for a single image.

        Args:
            inputs: first LSTM input; must be batch size 1, i.e. shape
                (1, 1, embed_size) — presumably the pre-processed image feature
                unsqueezed to a length-1 sequence (TODO confirm with caller).
            states: optional initial (h, c) LSTM state; None means zeros.
            max_len: maximum number of tokens to generate.
            end_idx: token index that stops decoding (default 1, the previously
                hard-coded stop index). It is included in the result.

        Returns:
            List of predicted token indices (Python ints), at most max_len long.
        """
        res = []

        # Decoding yields plain ints, so no autograd graph is needed.
        with torch.no_grad():
            for _ in range(max_len):
                # lstm_out: (1, 1, hidden_size)
                lstm_out, states = self.lstm(inputs, states)

                # (1, 1, hidden_size) -> (1, vocab_size)
                outputs = self.linear(lstm_out.squeeze(dim=1))

                # Most probable token; predicted_idx has shape (1,).
                _, predicted_idx = outputs.max(dim=1)
                token = predicted_idx.item()
                res.append(token)

                if token == end_idx:
                    break

                # Feed the prediction back in:
                # (1,) -> (1, embed_size) -> (1, 1, embed_size)
                inputs = self.embed(predicted_idx).unsqueeze(1)

        return res