In [3]:
from torch import nn, device as torch_device, cuda, tanh
import torchvision.models as models
import torch.nn.functional as F

# VGG-16 feature extractor

In [4]:
class VGG16FeatureExtractor(nn.Module):
    """Pretrained torchvision VGG-16, truncated after its first classifier layer.

    The module reuses four pieces of ``models.vgg16(pretrained=True)``:
    the ``features`` stack, the ``avgpool`` layer, a flatten step and
    ``classifier[0]``. Everything after ``classifier[0]`` is discarded.
    """

    def __init__(self):
        super().__init__()

        backbone = models.vgg16(pretrained=True)

        # Feature layers of VGG-16, re-wrapped in a fresh Sequential container.
        self.features = nn.Sequential(*backbone.features)

        # Pooling layer taken as-is from the pretrained network.
        self.pooling = backbone.avgpool

        # Collapses every non-batch axis into one vector per sample.
        self.flatten = nn.Flatten()

        # Only the first entry of the VGG-16 classifier head is kept.
        self.fc = backbone.classifier[0]

    def forward(self, x):
        """Apply features -> pooling -> flatten -> fc to ``x`` and return the result."""
        for stage in (self.features, self.pooling, self.flatten, self.fc):
            x = stage(x)

        return x

In [5]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if cuda.is_available():
    device = torch_device('cuda:0')
else:
    device = torch_device("cpu")

# Build the extractor, then place its parameters on the selected device.
model = VGG16FeatureExtractor()
model = model.to(device)
model  # last expression: echoes the module structure as the cell output

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/lucavaio/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:15<00:00, 36.2MB/s] 


VGG16FeatureExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, pad

In [6]:
class EncoderCNN(nn.Module):
    """Project the output of a CNN feature extractor to ``embed_dim`` and apply ReLU.

    Args:
        cnn: feature-extractor module. It must expose an ``fc`` attribute with
            ``out_features``; that value is used as the input width of the
            projection layer created here.
        embed_dim: size of the embedding produced for each input.
    """

    def __init__(self, cnn, embed_dim):
        super(EncoderCNN, self).__init__()

        self.cnn = cnn  # feature extractor supplied by the caller
        self.fc = nn.Linear(cnn.fc.out_features, embed_dim)

    def forward(self, images):
        """Encode ``images`` into non-negative embedding vectors.

        Args:
            images: batch accepted by ``self.cnn``.

        Returns:
            Tensor whose last axis has size ``embed_dim``.
        """
        # Call the modules themselves rather than ``.forward`` so that
        # nn.Module.__call__ runs and registered hooks are honoured.
        features = self.cnn(images)    # (..., cnn.fc.out_features)
        features = self.fc(features)   # (..., embed_dim)
        features = F.relu(features)

        return features

In [7]:
embed_dim = 256
# EncoderCNN creates a new nn.Linear on the CPU; move the whole encoder to
# `device` so the projection layer sits on the same device as the CNN.
encoder = EncoderCNN(model, embed_dim).to(device)

In [8]:
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    Each location is scored with ``V(tanh(W1(features) + W2(hidden_state)))``.
    The scores are normalised with a softmax over the location axis, and the
    features are averaged with those weights.

    Args:
        encoder_dim: size of the last axis of ``features``.
        decoder_dim: size of the last axis of ``hidden_state``.
        units: width of the intermediate projection shared by W1 and W2.
    """

    def __init__(self, encoder_dim, decoder_dim, units):
        super(Attention, self).__init__()

        self.W1 = nn.Linear(encoder_dim, units)
        self.W2 = nn.Linear(decoder_dim, units)
        self.V = nn.Linear(units, 1)

    def forward(self, features, hidden_state):
        """Compute attention weights and the resulting context vector.

        Args:
            features: tensor of shape (batch, num_locations, encoder_dim).
            hidden_state: tensor of shape (batch, decoder_dim) or
                (batch, 1, decoder_dim).

        Returns:
            Tuple ``(attention, context_vector)`` with shapes
            (batch, num_locations) and (batch, encoder_dim).
        """
        # A 2-D hidden state needs an explicit location axis; otherwise its
        # batch axis would be broadcast against the location axis of features.
        if hidden_state.dim() == 2:
            hidden_state = hidden_state.unsqueeze(1)  # (batch, 1, decoder_dim)

        # tanh scores
        scores = tanh(
            self.W1(features) + self.W2(hidden_state)
        )  # (batch, num_locations, units)

        # Drop the trailing singleton so the weights are 2-D; the unsqueeze
        # below restores it for the weighted sum.
        scores = self.V(scores).squeeze(2)  # (batch, num_locations)

        attention = F.softmax(scores, dim=1)  # (batch, num_locations)

        context_vector = features * attention.unsqueeze(2)  # (batch, num_locations, encoder_dim)
        context_vector = context_vector.sum(dim=1)  # (batch, encoder_dim)

        return attention, context_vector