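In this post, we take a look at image captioning: generating a short natural-language description of an image. The model below pairs a pretrained Inception v3 image encoder with an LSTM text decoder.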

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import inception_v3

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
class EncoderCNN(nn.Module):
    """Encode images into fixed-size embeddings with a pretrained Inception v3.

    The ImageNet-pretrained backbone is frozen; only the replacement ``fc``
    head (``in_features -> embed_dim``) has trainable parameters.

    Device placement is left to the caller: move the whole model (e.g.
    ``model.to(device)``) so the backbone and the new ``fc`` head always end
    up on the same device.

    Args:
        embed_dim: Size of the output image embedding.
        dropout: Dropout probability applied to the (ReLU-ed) embedding.
    """

    def __init__(self, embed_dim, dropout):
        super(EncoderCNN, self).__init__()
        self.inception = inception_v3(pretrained=True, aux_logits=False)
        # Freeze every pretrained parameter. The fc head created right after
        # this loop is a fresh nn.Linear, so its parameters stay trainable.
        for param in self.inception.parameters():
            param.requires_grad = False
        self.inception.fc = nn.Linear(
            self.inception.fc.in_features, embed_dim
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(relu(inception(x)))``.

        NOTE(review): torchvision's Inception v3 is designed for 3x299x299
        inputs, so ``x`` is presumably ``(batch, 3, 299, 299)`` — confirm
        against the data pipeline.
        """
        # No torch.no_grad() here: it would also cut gradients to the
        # trainable fc head. Because the backbone parameters are frozen,
        # autograd does not track the backbone as long as ``x`` itself does
        # not require grad.
        x = self.inception(x)
        return self.dropout(F.relu(x))

In [6]:
class DecoderRNN(nn.Module):
    """LSTM decoder that maps an image embedding plus caption tokens to logits.

    The LSTM is sequence-first (``batch_first`` is left at its default of
    ``False``), so every tensor handled here has the time axis first.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_dim: Size of the token embeddings; must equal the size of the
            image features passed to ``forward``, since both are fed to the
            same LSTM.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(
        self, vocab_size, embed_dim, hidden_size, num_layers
    ):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return vocabulary logits for every decoding step.

        The image embedding is used as the LSTM input at step 0, and
        ``captions[t - 1]`` is the input at step ``t`` (teacher forcing).

        Args:
            features: Image embeddings of shape ``(batch, embed_dim)``.
            captions: Token indices of shape ``(seq_len, batch)``.

        Returns:
            Logits of shape ``(seq_len + 1, batch, vocab_size)``.
        """
        embeddings = self.embed(captions)
        # Prepend the image embedding as an extra time step.
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        # hiddens.shape == (seq_len + 1, batch, hidden_size)
        outputs = self.fc(hiddens)
        # outputs.shape == (seq_len + 1, batch, vocab_size)
        return outputs

In [None]:
class ImageCaptioner(nn.Module):
    """CNN-encoder / LSTM-decoder image captioning model.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_dim: Size of both the image embedding and the token embeddings.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        idx2token: Indexable mapping from vocabulary index to token string,
            used by ``generate``. Generation stops early only if it predicts
            the index mapped to ``"<eos>"``.
        dropout: Dropout probability applied to the image embedding.
    """

    def __init__(
        self, vocab_size, embed_dim, hidden_size, num_layers, idx2token, dropout=0.5
    ):
        super(ImageCaptioner, self).__init__()
        self.idx2token = idx2token
        self.encoder = EncoderCNN(embed_dim, dropout)
        # Caption length at inference is bounded by generate(max_len=...),
        # so the decoder needs no maximum-length argument.
        self.decoder = DecoderRNN(
            vocab_size, embed_dim, hidden_size, num_layers
        )

    def forward(self, images, captions):
        """Return teacher-forced vocabulary logits for ``images``/``captions``.

        The encoder output is passed to the decoder together with
        ``captions``; the returned tensor is whatever the decoder produces.
        """
        features = self.encoder(images)
        return self.decoder(features, captions)

    def generate(self, image, max_len=50):
        """Greedily decode a caption for a single image.

        Args:
            image: One image, either unbatched ``(C, H, W)`` or a batch of
                exactly one ``(1, C, H, W)`` (``.item()`` below requires a
                single prediction per step).
            max_len: Maximum number of tokens to emit.

        Returns:
            The predicted tokens joined by single spaces, excluding
            ``"<eos>"``; an empty string if ``"<eos>"`` is predicted first.
        """
        if image.dim() == 3:
            image = image.unsqueeze(0)
        # Inference must not apply dropout or update BatchNorm statistics;
        # switch to eval mode for the duration and restore the caller's mode.
        was_training = self.training
        self.eval()
        tokens = []
        try:
            with torch.no_grad():
                states = None
                # (1, embed_dim) -> (seq_len=1, batch=1, embed_dim)
                inputs = self.encoder(image).unsqueeze(0)
                for _ in range(max_len):
                    hiddens, states = self.decoder.lstm(inputs, states)
                    outputs = self.decoder.fc(hiddens.squeeze(0))
                    predicted = outputs.argmax(1)
                    token = self.idx2token[predicted.item()]
                    if token == "<eos>":
                        break
                    tokens.append(token)
                    # Feed the predicted token back in as the next step.
                    inputs = self.decoder.embed(predicted).unsqueeze(0)
        finally:
            self.train(was_training)
        return " ".join(tokens)

References:
- https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning
- https://www.youtube.com/watch?v=y2BaTt1fxJU&t=588s