In [5]:
import torch
import torch.nn
import torch.nn.functional as F
from torch import nn

In [1]:
# --- Data / batching ---------------------------------------------------------
NB_IMAGES = 8000
BATCH_SIZE = 64
BUFFER_SIZE = 1000
# Number of full BATCH_SIZE batches that fit in NB_IMAGES.
num_steps = NB_IMAGES // BATCH_SIZE

# --- Model widths ------------------------------------------------------------
embed_dim = 256  # encoder output / word-embedding width
units = 512      # decoder hidden size

# --- CNN feature geometry ----------------------------------------------------
# Each image yields a (64, 2048) feature map: 64 locations (an 8x8 grid) x 2048 channels.
attention_shape = 64
features_shape = 2048

In [None]:
class EncoderCNN(nn.Module):
    """Project CNN feature maps into the decoder's embedding space.

    Args:
        cnn: optional backbone module applied to the input first. When None
            (the default) the input is assumed to already be extracted CNN
            features (e.g. cached backbone outputs) and is passed through.
        feature_dim: channel width of the CNN features (2048 in the config cell).
        embedding_dim: output width (``embed_dim`` = 256 in the config cell).
    """

    def __init__(self, cnn=None, feature_dim=2048, embedding_dim=256):
        super(EncoderCNN, self).__init__()

        # nn.Identity keeps forward() uniform whether or not a backbone is supplied.
        self.cnn = cnn if cnn is not None else nn.Identity()
        self.fc = nn.Linear(feature_dim, embedding_dim)

    def forward(self, images):
        """Return ReLU-activated features of shape (batch, locations, embedding_dim).

        After ``self.cnn``, the tensor is expected to be either a channels-last
        grid (batch, H, W, feature_dim), e.g. (batch, 8, 8, 2048), or an
        already-flattened (batch, locations, feature_dim) tensor.
        NOTE(review): torchvision backbones emit channels-first (batch, C, H, W);
        permute to channels-last before calling this module.
        """
        features = self.cnn(images)
        if features.dim() == 4:
            # (batch, H, W, C) -> (batch, H*W, C), e.g. (batch, 8, 8, 2048) -> (batch, 64, 2048)
            features = features.flatten(1, 2)
        features = self.fc(features)  # (batch, locations, embedding_dim)
        return F.relu(features)

In [None]:
class Attention(nn.Module):
    """Additive ("concat", tanh-based) attention over image feature locations.

    Args:
        units: width of the intermediate scoring space.
        features_dim: last-dim width of ``features`` (encoder output;
            ``embed_dim`` = 256 in the config cell).
        hidden_dim: width of ``hidden_state``; defaults to ``units``.
    """

    def __init__(self, units, features_dim=256, hidden_dim=None):
        super(Attention, self).__init__()

        hidden_dim = units if hidden_dim is None else hidden_dim
        self.W1 = nn.Linear(features_dim, units)
        self.W2 = nn.Linear(hidden_dim, units)
        self.V = nn.Linear(units, 1)

    def forward(self, features, hidden_state):
        """Score every location against the hidden state and pool the features.

        Args:
            features: (batch, locations, features_dim)
            hidden_state: (batch, hidden_dim)

        Returns:
            (alignment, context_vector) where alignment is (batch, locations, 1)
            and sums to 1 over locations, and context_vector is
            (batch, features_dim).
        """
        # Add a locations axis so the hidden projection broadcasts over every location.
        hidden_with_axis = hidden_state.unsqueeze(1)  # (batch, 1, hidden_dim)

        # "concat" scores (tanh)
        concat = torch.tanh(
            self.W1(features) + self.W2(hidden_with_axis)
        )  # (batch, locations, units)
        attention_scores = self.V(concat)  # (batch, locations, 1)

        alignment = F.softmax(attention_scores, dim=1)  # normalised over locations

        # alignment's trailing singleton dim broadcasts across features_dim.
        context_vector = (alignment * features).sum(dim=1)  # (batch, features_dim)

        return alignment, context_vector

In [None]:
class DecoderRNN(nn.Module):
    """GRU decoder that attends over encoder features at every time step.

    Args:
        units: GRU hidden size (also the attention scoring width).
        vocab_size: number of output tokens.
        embed_shape: word-embedding width.
        features_dim: last-dim width of the encoder features; defaults to
            ``embed_shape`` (the encoder projects to ``embed_dim`` in this notebook).
    """

    def __init__(self, units, vocab_size, embed_shape, features_dim=None):
        super().__init__()

        self.units = units
        self.vocab_size = vocab_size
        features_dim = embed_shape if features_dim is None else features_dim

        self.attention = Attention(self.units, features_dim=features_dim, hidden_dim=self.units)

        self.embedding = nn.Embedding(vocab_size, embed_shape)
        # Each step consumes [context_vector; word_embedding].
        self.gru = nn.GRU(features_dim + embed_shape, self.units, batch_first=True)
        self.fc1 = nn.Linear(self.units, self.units)
        self.fc2 = nn.Linear(self.units, vocab_size)

    def _step(self, tokens, features, hidden):
        """Run one decoding step; returns (logits (B, V), new hidden (B, U), alignment (B, L))."""
        # Attention returns (alignment, context_vector) in that order.
        alignment, context_vector = self.attention(features, hidden)

        embedded = self.embedding(tokens)  # (batch, embed_shape)

        # (batch, 1, features_dim + embed_shape)
        x = torch.cat([context_vector, embedded], dim=-1).unsqueeze(1)

        # nn.GRU expects h_0 as (num_layers, batch, hidden).
        output, state = self.gru(x, hidden.unsqueeze(0))  # (B, 1, U), (1, B, U)

        logits = self.fc2(self.fc1(output.squeeze(1)))  # (batch, vocab_size)
        return logits, state.squeeze(0), alignment.squeeze(-1)

    def forward(self, features, captions, hidden=None):
        """Decode ``captions`` with teacher forcing.

        Args:
            features: encoder output, (batch, locations, features_dim).
            captions: token ids, (batch, max_length) or (batch,) for a single step.
                Each step's input token is taken from here (teacher forcing).
            hidden: optional initial state (batch, units); zeros when None.

        Returns:
            (logits, state, attention_weights): logits is
            (batch * max_length, vocab_size), state is (batch, units), and
            attention_weights is (batch, max_length, locations).

        Raises:
            ValueError: if ``captions`` has no time steps.
        """
        if captions.dim() == 1:
            captions = captions.unsqueeze(1)
        if captions.size(1) == 0:
            raise ValueError("captions must contain at least one time step")
        if hidden is None:
            hidden = features.new_zeros(features.size(0), self.units)

        step_logits, step_weights = [], []
        for t in range(captions.size(1)):
            logits, hidden, alignment = self._step(captions[:, t], features, hidden)
            step_logits.append(logits)
            step_weights.append(alignment)

        # (batch, max_length, vocab) -> (batch * max_length, vocab)
        outputs = torch.stack(step_logits, dim=1).reshape(-1, self.vocab_size)
        attention_weights = torch.stack(step_weights, dim=1)

        return outputs, hidden, attention_weights

In [None]:
class EncoderDecoder(nn.Module):
    """End-to-end captioner: EncoderCNN features feed a DecoderRNN.

    Args:
        vocab_size: number of output tokens for the decoder (required; the
            decoder cannot be built without it).
        units: decoder hidden size (``units`` = 512 in the config cell).
        embed_shape: word-embedding width passed to DecoderRNN.
            NOTE(review): keep this equal to the encoder's output width
            (``embed_dim`` = 256 in the config cell), since the decoder
            attends over the encoder's features.
    """

    def __init__(self, vocab_size, units=512, embed_shape=256):
        super().__init__()

        self.encoder = EncoderCNN()
        # DecoderRNN requires (units, vocab_size, embed_shape).
        self.decoder = DecoderRNN(units, vocab_size, embed_shape)

    def forward(self, images, captions):
        """Encode ``images`` and decode ``captions``; returns the decoder's outputs."""
        # Invoke submodules via __call__ (not .forward) so nn.Module hooks run.
        features = self.encoder(images)
        return self.decoder(features, captions)