In [2]:
# (Re)load the autoreload extension so edited modules are picked up without a kernel restart.
%reload_ext autoreload
# Mode 2: re-import all modules before each cell is executed.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

In [21]:
# Model widths shared by the cells below.
# NOTE(review): image_dim is presumably the width of the image embedding; Net
# concatenates that embedding with the token embeddings along the sequence
# axis, so it has to equal text_dim - confirm before changing either value.
image_dim   = 512
# Token-embedding width; also the LSTM input size in Net.
text_dim    = 512
# LSTM hidden size; also the input width of Net's vocabulary projection.
decoder_dim = 1024

In [26]:
class Net(nn.Module):
    """CNN image encoder followed by a 2-layer LSTM token decoder.

    The image is encoded to a single vector that is fed to the LSTM as the
    first time step, ahead of the embedded tokens. The LSTM output at every
    token position is projected onto the vocabulary.

    Reads the notebook-level settings ``vocab_size``, ``image_dim``,
    ``text_dim`` and ``decoder_dim``.

    Raises:
        ValueError: if ``image_dim != text_dim``; the image vector and the
            token embeddings are concatenated along the sequence axis, so
            their feature widths must match.
    """

    def __init__(self,):
        super().__init__()
        if image_dim != text_dim:
            raise ValueError(
                'image_dim (%d) must equal text_dim (%d): the image embedding is '
                'fed to the LSTM as the first step of the token sequence'
                % (image_dim, text_dim)
            )

        m = timm.create_model('resnet26d')
        # Keep every child of the backbone except the last one (presumably the
        # classification head - confirm for other architectures) and project
        # the pooled features to image_dim. num_features is the backbone's
        # pre-classifier width, so the projection follows the chosen model
        # instead of a hard-coded 2048 -> 512.
        self.encoder = nn.Sequential(
            *list(m.children())[:-1],
            nn.Linear(m.num_features, image_dim),
        )
        self.token_embed = nn.Embedding(vocab_size, text_dim)
        self.logit = nn.Linear(decoder_dim, vocab_size)

        self.rnn = nn.LSTM(
            text_dim,
            decoder_dim,
            num_layers = 2,
            bias = True,
            batch_first = True,
            dropout = 0.2,  # applied between the two LSTM layers
            bidirectional = False
        )

        # Small uniform init for the embedding and output projection, zero bias.
        nn.init.uniform_(self.token_embed.weight, -0.1, 0.1)
        nn.init.zeros_(self.logit.bias)
        nn.init.uniform_(self.logit.weight, -0.1, 0.1)

    def forward(self, image, token, length):
        """Compute per-position vocabulary logits.

        Args:
            image: 4-D image batch, (batch, channels, height, width).
            token: integer token ids, (batch, seq_len).
            length: accepted for interface compatibility; not used here, the
                LSTM runs over the full padded sequence.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size). Position ``i`` is the
            LSTM output after consuming the image and ``token[:, :i + 1]``.

        Raises:
            ValueError: if ``image`` is not 4-dimensional.
        """
        if image.dim() != 4:
            raise ValueError(
                'image must be 4-D (batch, channels, height, width), got %d-D' % image.dim()
            )

        image_embed = self.encoder(image)
        text_embed  = self.token_embed(token)

        # Prepend the image vector as time step 0 of the decoder input.
        x = torch.cat([image_embed.unsqueeze(1), text_embed], 1)
        y, _ = self.rnn(x)

        # Drop the output of the image step; only token positions are scored.
        logit = self.logit(y[:, 1:])
        return logit

In [27]:
def seq_cross_entropy_loss(logit, token, length):
    """Cross-entropy between next-token logits and the shifted token sequence.

    Args:
        logit: float tensor (batch, steps, vocab); ``logit[:, i]`` is scored
            against ``token[:, i + 1]``.
        token: integer tensor (batch, seq_len) of token ids.
        length: per-sample sequence lengths (list, numpy array or tensor).
            Each sample contributes ``length - 1`` target positions. The batch
            does not have to be sorted by length.

    Returns:
        Scalar tensor: mean cross-entropy over all non-padded target
        positions; targets equal to ``STOI['<pad>']`` are ignored.
    """
    # Targets are the inputs shifted left by one step.
    truth = token[:, 1:]

    # Plain Python ints work for lists, numpy arrays and tensors alike and
    # keep the lengths on the CPU, as pack_padded_sequence requires.
    target_length = [int(l) - 1 for l in length]

    # Packing removes the padded tail of every sequence. Both tensors are
    # packed with the same lengths, so their rows stay aligned;
    # enforce_sorted=False lets unsorted batches through, and the mean loss
    # does not depend on row order.
    logit = pack_padded_sequence(
        logit, target_length, batch_first=True, enforce_sorted=False
    ).data
    truth = pack_padded_sequence(
        truth, target_length, batch_first=True, enforce_sorted=False
    ).data

    loss = F.cross_entropy(logit, truth, ignore_index=STOI['<pad>'])
    return loss

In [28]:
def run_check_net():
    """Smoke-test ``Net``: push one random batch through it and print the shapes.

    Uses the notebook-level ``max_length``, ``vocab_size`` and ``STOI``.
    Returns nothing; the result is the printed report.
    """
    n_samples = 7
    channels, height, width = 3, 224, 224
    image = torch.randn((n_samples, channels, height, width))

    # Start from an all-padding token matrix, then fill each row with a random sequence.
    token = np.full((n_samples, max_length), STOI['<pad>'], np.int64)

    # Random body lengths, longest first. The upper bound leaves room for the
    # two marker ids added around each body.
    length = np.random.randint(5, max_length - 2, n_samples)
    length = np.sort(length)[::-1].copy()

    for row, n_tokens in enumerate(length):
        body = np.random.choice(vocab_size, n_tokens)
        # Ids 0 and 1 bracket the body (presumably start/end markers - confirm against STOI).
        sequence = np.concatenate(([0], body, [1]))
        token[row, :len(sequence)] = sequence

    token = torch.from_numpy(token).long()

    # Forward pass in training mode (dropout active).
    net = Net()
    net.train()
    logit = net(image, token, length)

    print('vocab_size', vocab_size)
    print('max_length', max_length)
    print('')
    for item in (length, length.shape, token.shape, image.shape):
        print(item)
    print('---')

    print(logit.shape)
    print('---')

In [29]:
# Run the shape smoke test; it builds a fresh Net and prints the tensor shapes.
run_check_net()

vocab_size 193
max_length 300

[266 223 175 162 128  91  21]
(7,)
torch.Size([7, 300])
torch.Size([7, 3, 224, 224])
---
torch.Size([7, 300, 193])
---
