In [2]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision.models import inception_v3

In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class EncoderCNN(nn.Module):
    """Image encoder: pretrained Inception-v3 with its classifier head
    replaced by a linear projection to ``embed_dim``.

    Args:
        embed_dim: size of the image embedding produced by ``forward``.
        train_cnn: if True, fine-tune the whole backbone; by default only the
            new projection layer (``inception.fc``) receives gradients.
    """

    def __init__(self, embed_dim, train_cnn=False):
        super().__init__()
        # Load with the default aux_logits=True and drop the auxiliary head
        # afterwards: recent torchvision raises ValueError when
        # aux_logits=False is combined with pretrained weights.
        self.inception = inception_v3(pretrained=True)
        self.inception.aux_logits = False
        self.inception.AuxLogits = None
        self.inception.fc = nn.Linear(
            self.inception.fc.in_features, embed_dim
        )
        # Freeze the backbone via requires_grad instead of wrapping forward()
        # in no_grad(): no_grad() also blocked gradients to the freshly
        # initialised fc layer, so it could never be trained.
        # NOTE(review): frozen BatchNorm layers still update their running
        # statistics in train() mode; call self.inception.eval() if undesired.
        for name, param in self.inception.named_parameters():
            param.requires_grad = train_cnn or name.startswith("fc.")
        # Move only after replacing fc so the new layer lands on `device` too.
        self.inception.to(device)
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Map a batch of images to embeddings of size ``embed_dim``.

        NOTE(review): Inception-v3 is designed for 299x299 RGB input — confirm
        the caller's transforms resize accordingly.
        """
        x = self.inception(x)
        return self.dropout(F.relu(x))

In [6]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image embedding.

    The LSTM keeps the default ``batch_first=False``, so tensors are laid out
    sequence-first: (seq_len, batch, ...).

    Args:
        vocab_size: number of tokens in the vocabulary.
        embedding_dim: word-embedding size. The image features are fed in as
            the first LSTM step, so they must have this size as well.
        hidden_size: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        max_seq_len: default maximum number of tokens produced by ``generate``.
    """

    def __init__(
        self, vocab_size, embedding_dim, hidden_size, num_layers, max_seq_len
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced pass over a batch of captions.

        Args:
            features: image embeddings, shape (batch, embedding_dim).
            captions: token indices, shape (seq_len, batch).

        Returns:
            Logits of shape (seq_len + 1, batch, vocab_size); step 0 is the
            prediction made from the image features alone.
        """
        embeddings = self.embed(captions)
        # Prepend the image features as time step 0.
        embeddings = torch.cat([features.unsqueeze(0), embeddings], dim=0)
        hiddens, _ = self.lstm(embeddings)
        return self.fc(hiddens)

    def generate(self, features, idx2word, eos_token="<EOS>", max_len=None):
        """Greedily decode a caption for a single image.

        Args:
            features: embedding of ONE image, shape (1, embedding_dim);
                ``.item()`` below only works for a batch of one.
            idx2word: mapping from token index to word.
            eos_token: word that stops decoding; it is included in the result.
            max_len: optional override of ``self.max_seq_len``.

        Returns:
            List of predicted words, at most ``max_len`` (or ``max_seq_len``)
            long.
        """
        limit = self.max_seq_len if max_len is None else max_len
        words = []
        with torch.no_grad():
            states = None
            # Add a sequence dimension: (1, E) -> (seq=1, batch=1, E).
            inputs = features.unsqueeze(0)
            for _ in range(limit):
                hiddens, states = self.lstm(inputs, states)
                logits = self.fc(hiddens.squeeze(0))
                predicted = logits.argmax(1)
                word = idx2word[predicted.item()]
                words.append(word)
                if word == eos_token:
                    break
                # Feed the predicted token back in as the next step.
                inputs = self.embed(predicted).unsqueeze(0)
        return words

In [None]:
class Vocabulary:
    """Bidirectional mapping between words and consecutive integer indices."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        # Number of distinct words added so far; also the next free index.
        self.count = 0

    def add(self, word):
        """Register ``word`` if unseen and return its index.

        Adding an existing word is a no-op that returns the existing index.
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.count
            self.idx2word[self.count] = word
            self.count += 1
        return self.word2idx[word]

    def __len__(self):
        """Number of distinct words (usable as ``vocab_size``)."""
        return self.count

In [None]:
# Size of the image embedding; must equal the decoder's embedding_dim because
# the features are fed to the LSTM as its first input step.
EMBED_DIM = 256
# Move the whole module (including the replaced fc head) to the target device.
encoder = EncoderCNN(embed_dim=EMBED_DIM).to(device)

**References**

- Image-captioning tutorial (yunjey/pytorch-tutorial): https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning
- Video walkthrough: https://www.youtube.com/watch?v=y2BaTt1fxJU&t=588s