In [0]:
!pip install -q transformers

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from transformers import AutoModelWithLMHead, AutoTokenizer

from collections import namedtuple
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [0]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [209]:
BATCH_SIZE = 2
transform = transforms.Compose([
     transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True)

valset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [0]:
class Embedding(nn.Module):
    def __init__(self, cnn, embed_size):
        super().__init__()
        self.cnn = cnn
        self.embed_size = embed_size
        _, last_module = list(cnn.named_children())[-1]
        self.proj = nn.Linear(last_module.out_features, embed_size)

    def forward(self, x):
        x = self.cnn(x)
        return self.proj(x)

    def train(self, mode=True):
        self.cnn.eval()
        self.proj.train(mode)

In [0]:
SPECIAL_TOKENS = ('img', 'desc', 'pad')

In [0]:
SpecialIds = namedtuple('speacil_ids', SPECIAL_TOKENS)

In [0]:
class Image2TextDescriptor(nn.Module):
    def __init__(self, pretrained_model_text='gpt2', 
                     pretrained_model_image='resnet18'):
        super().__init__()
        self.tokenizer, self.special_ids = self._build_tokenizer(pretrained_model_text)
        self.gpt2 = self._build_gpt2(len(self.tokenizer), pretrained_model_text).to(device)

        self.img_tok_emb, self.desc_tok_emb = self.gpt2.transformer.wte(
                    torch.tensor([self.special_ids.img, self.special_ids.desc]).to(device)
        )
        self.img_tok_emb.to(device)
        self.desc_tok_emb.to(device)

        cnn = self._build_cnn(pretrained_model_image).to(device)
        self.embedding = Embedding(cnn, self.gpt2.config.hidden_size).to(device)
        self.to(device)

    def forward(self, img_data, text_tok_ids=None):
        context_embeds = self.get_context_embeds(img_data).to(device)
        labels_ids = None
        if text_tok_ids:
            text_tok_ids = torch.tensor(text_tok_ids).to(device)
            text_embeds = self.gpt2.transformer.wte(text_tok_ids).to(device)
            full_embeds = torch.cat((context_embeds, text_embeds), dim=1).to(device)
            if self.training:
                labels_ids = self.get_labels(
                    full_embeds.shape[:-1],
                    context_embeds.size(1),
                    text_tok_ids
                ).to(device)
        else:
            full_embeds = context_embeds
        
        out = self.gpt2(inputs_embeds=full_embeds, labels=labels_ids)        
        return out

    def get_context_embeds(self, img_data):
        img_embeds = self.embedding(img_data)[:, None, :] # new shape: (batch_size, seq_len, embed_size)
        context_embeds = torch.cat((self.img_tok_emb.expand_as(img_embeds), 
                                    img_embeds, 
                                    self.desc_tok_emb.expand_as(img_embeds)), dim=1)
        return context_embeds

    def get_labels(self, shape, context_len, tok_ids):
        labels_ids = torch.empty(shape, dtype=torch.long).fill_(-100) # -100 is mask value for labels in hugginface models
        mask = (tok_ids == self.special_ids.pad)
        labels_ids[:, context_len:] = tok_ids.masked_fill(mask, -100)
        print(label_ids)

        return label_ids

    def _build_tokenizer(self, pretrained_model):
        special_dct = {t: f"<{t.upper()}>" for t in SPECIAL_TOKENS}
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        tokenizer.add_special_tokens({'additional_special_tokens': list(special_dct.values())})
        special_ids = SpecialIds(**{k: tokenizer.encode(v)[0] for k, v in special_dct.items()})
        
        return tokenizer, special_ids

    def _build_gpt2(self, vocab_size, pretrained_model):
        gpt2 = AutoModelWithLMHead.from_pretrained(pretrained_model)
        gpt2.resize_token_embeddings(vocab_size)
        
        return gpt2

    def _build_cnn(self, pretrained_model):
        if pretrained_model == 'resnet18':
            return models.resnet18(pretrained=True)
        else:
            raise ValueError(f'{pretrained_model} is not supported yet :(')

In [0]:
descriptor = Image2TextDescriptor()

In [0]:
data, target = next(iter(trainloader))

In [0]:
def pad(seq, max_len, pad_elem=special_ids.PAD):
    return seq + [pad_elem] * max(0, max_len - len(seq))

In [0]:
ids_1 = tok.encode('The man is seeting on chair')
ids_2 = tok.encode('The cat eat meet')
text_ids = [ids_1, pad(ids_2, len(ids_1))]

In [280]:
len(text_ids[0])

7

In [281]:
text_ids

[[464, 582, 318, 384, 13629, 319, 5118],
 [464, 3797, 4483, 1826, 50259, 50259, 50259]]

In [283]:
out = descriptor(data.float().to(device), text_ids)

tensor([[ -100,  -100,  -100,   464,   582,   318,   384, 13629,   319,  5118],
        [ -100,  -100,  -100,   464,  3797,  4483,  1826,  -100,  -100,  -100]])


In [286]:
out[0]

tensor(82.0220, device='cuda:0', grad_fn=<NllLossBackward>)