In [None]:
# Fetch the COCO 2014 training images and caption annotations, unpack them into
# ./train2014 and ./annotations, and delete the archives afterwards. Each step is
# skipped when its target directory already exists, so re-running the notebook does
# not re-download; `&&` ensures an archive is only removed after a successful unzip.
!test -d train2014 || (wget http://images.cocodataset.org/zips/train2014.zip && unzip -q train2014.zip && rm train2014.zip)
!test -d annotations || (wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip && unzip -q annotations_trainval2014.zip && rm annotations_trainval2014.zip)

In [None]:
# Install into the running kernel's environment: `%pip` uses the kernel's own
# interpreter, whereas `!python3.9 -m pip` assumes that binary is the kernel.
%pip install padl transformers

In [1]:
from transformers import pipeline
import padl

# Name the checkpoints explicitly (they are the ones the pipelines otherwise fall back
# to, per the "No model was supplied" warnings) so the notebook does not silently change
# if the library's defaults do.
tg = pipeline('text-generation', model='gpt2')
pl = pipeline('image-classification', model='google/vit-base-patch16-224')

No model was supplied, defaulted to gpt2 (https://huggingface.co/gpt2)
No model was supplied, defaulted to google/vit-base-patch16-224 (https://huggingface.co/google/vit-base-patch16-224)


In [2]:
import pandas
import random
import json

SEED = 0               # fixes the train/validation split across re-runs
N_VALID_IMAGES = 1000  # number of images held out for validation

with open('annotations/captions_train2014.json') as f:
    data = json.load(f)

# COCO image id -> image path on disk.
image_lookup = {image['id']: 'train2014/' + image['file_name'] for image in data['images']}

# One row per (image path, caption) pair; the same image can appear in several rows.
annotations = pandas.DataFrame([
    {'image': image_lookup[annotation['image_id']], 'caption': annotation['caption']}
    for annotation in data['annotations']
])
all_images = annotations['image'].unique().tolist()

# Split by image rather than by caption so no validation image is also seen in training.
# A dedicated seeded RNG makes the split reproducible without touching global `random` state.
random.Random(SEED).shuffle(all_images)

n_train = max(len(all_images) - N_VALID_IMAGES, 0)
train_images = all_images[:n_train]
valid_images = all_images[n_train:]

# Rows are [image_path, caption] lists (mutable, so they can be normalised in place below).
train_annotations = annotations[annotations['image'].isin(train_images)].to_dict('split')['data']
valid_annotations = annotations[annotations['image'].isin(valid_images)].to_dict('split')['data']

# Normalise training captions: drop surrounding whitespace, then make sure each ends
# with '.'. The decoder's stop token is the first token of '.#' (presumably '.').
for row in train_annotations:
    row[1] = row[1].strip()
    if not row[1].endswith('.'):
        row[1] += '.'

In [3]:
import torch


@padl.transform
class SimpleRNN(torch.nn.Module):
    """Teacher-forced decoder: (initial hidden state, token ids) -> per-position logits.

    :param rnn: recurrent module called as ``rnn(embedded, hidden)``; only the first
        element of its result (the per-step outputs) is used.
    :param proj: maps RNN outputs to vocabulary logits.
    :param embed: maps token ids to embeddings.
    """

    def __init__(self, rnn, proj, embed):
        super().__init__()
        self.rnn = rnn
        self.proj = proj
        self.embed = embed

    def forward(self, hidden, input_ids):
        """Embed ``input_ids``, run the RNN from ``hidden`` and project every output step."""
        embedded = self.embed(input_ids)
        step_outputs = self.rnn(embedded, hidden)[0]
        return self.proj(step_outputs)
        
        
@padl.transform
class Greedy(torch.nn.Module):
    """Greedy (arg-max) decoder that generates token ids for a single example.

    It reuses the ``embed`` / ``rnn`` / ``proj`` modules and repeatedly feeds its own
    most likely token back into the RNN.

    :param rnn: recurrent module called as ``rnn(embedded, hidden)``; the first element
        of its result is used as the next hidden state (for a single-layer GRU fed a
        length-1 input this holds the same values as the new hidden state).
    :param proj: maps a hidden state to vocabulary logits.
    :param embed: maps token ids to embeddings.
    :param end: token id that stops decoding; it is included in the output.
    :param max_len: maximum number of token ids returned.
    :param start: optional token id fed to the RNN before the first prediction. The
        training pipeline prefixes every caption with ``'!'``, so passing that token's
        id reproduces the training-time input. With the default ``None`` the first
        token is predicted directly from the conditioning ``hidden`` state.
    """

    def __init__(self, rnn, proj, embed, end, max_len=20, start=None):
        super().__init__()
        self.rnn = rnn
        self.proj = proj
        self.embed = embed
        self.end = end
        self.max_len = max_len
        self.start = start

    def _advance(self, token_id, hidden):
        """Run one RNN step on ``token_id`` and return the updated hidden state."""
        # Build the (1, 1) id tensor on the hidden state's device so decoding still
        # works after the shared modules have been moved to the GPU.
        token = torch.tensor([[token_id]], device=hidden.device)
        return self.rnn(self.embed(token), hidden)[0]

    def forward(self, hidden):
        """Decode from ``hidden``.

        :param hidden: initial hidden state; assumed to describe exactly one example
            (``.item()`` below requires a single arg-max) — TODO confirm for batches.
        :return: 1-D ``torch.long`` CPU tensor of at most ``max_len`` token ids, ending
            with ``end`` if it was produced.
        """
        if self.start is not None:
            hidden = self._advance(self.start, hidden)
        output_ids = []
        for _ in range(self.max_len):
            next_id = self.proj(hidden).squeeze().topk(1)[1].item()
            output_ids.append(next_id)
            if next_id == self.end:
                break
            hidden = self._advance(next_id, hidden)
        # Explicit dtype so an empty result (max_len == 0) is still an id tensor.
        return torch.tensor(output_ids, dtype=torch.long)

In [4]:
# Captioning head. The conditioner maps the ViT feature (its size read from the model
# config instead of hard-coding 768) to the GRU's initial hidden state; the GRU runs
# over token embeddings and `proj` maps its outputs to logits over the GPT-2 tokenizer's vocabulary.
EMBED_DIM = 64
HIDDEN_DIM = 512
VOCAB_SIZE = tg.tokenizer.vocab_size

conditioner = padl.transform(torch.nn.Linear(pl.model.config.hidden_size, HIDDEN_DIM))
rnn = torch.nn.GRU(EMBED_DIM, HIDDEN_DIM, 1, batch_first=True)
embed = torch.nn.Embedding(VOCAB_SIZE, EMBED_DIM)
proj = torch.nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

In [None]:
# Both transforms share embed / rnn / proj, so training `logits` also trains `generator`.
# The greedy decoder stops at the first token produced by encoding '.#'.
end_token_id = tg.tokenizer.encode('.#')[0]
generator = Greedy(rnn, proj, embed, end_token_id)
logits = SimpleRNN(rnn, proj, embed)

In [None]:
# Image feature extractor: preprocess one image (given by path), batch, run the ViT
# backbone and keep the first-token ([CLS]) vector of its last layer.
# `last_hidden_state` is always part of the backbone's output, so
# `output_hidden_states` stays at its default; enabling it would only make the ViT
# return every layer's activations, which this pipeline never reads.
image_features = (
    padl.transform(pl.preprocess)
    # assumes preprocess returns a batch of one image; [0] drops that axis before padl.batch
    >> padl.transform(lambda x: x['pixel_values'][0])
    >> padl.batch
    >> padl.transform(pl.model.vit)
    >> padl.transform(lambda x: x.last_hidden_state[:, 0, :])
)

@padl.transform
def myloss(x, y):
    """Batch mean of each sequence's token-averaged cross-entropy.

    Vectorised equivalent of slicing every sequence to its length and averaging
    ``cross_entropy`` per sequence, without a Python loop over the batch.

    :param x: logits indexed as ``x[batch, position, vocab]``.
    :param y: ``(targets, lens)`` — target ids indexed as ``targets[batch, position]``
        and the number of valid positions per sequence. Lengths beyond the padded
        length are effectively clamped, as slicing would do.
    :return: scalar loss tensor.
    """
    targets, lens = y
    seq_len = x.shape[1]
    lens = torch.as_tensor(lens, device=x.device)
    # mask[i, t] is True for the first lens[i] positions of sequence i.
    mask = torch.arange(seq_len, device=x.device)[None, :] < lens[:, None]
    # cross_entropy expects the class axis second: (batch, vocab, position).
    per_token = torch.nn.functional.cross_entropy(x.transpose(1, 2), targets, reduction='none')
    per_sequence = (per_token * mask).sum(dim=1) / mask.sum(dim=1)
    return per_sequence.mean()


@padl.transform
def mypad(x):
    """Truncate / zero-pad a 1-D tensor of token ids to exactly 15 positions.

    Padding ids are 0; padded positions are excluded from the loss through the
    returned length, not through their id.

    :param x: 1-D tensor of token ids.
    :return: ``(ids, length)`` — a ``torch.long`` tensor of 15 ids and the number of
        real ids it contains (at most 15, so it never points past the padded tensor).
    """
    max_len = 15
    length = min(len(x), max_len)
    # Cast before padding so the zeros are created as long directly (no float round-trip).
    ids = x[:max_len].long()
    ids = torch.cat([ids, ids.new_zeros(max_len - length)])
    return ids, length

# Caption string -> GPT-2 token ids -> tensor -> (ids padded/truncated to 15, length),
# then batched.
tokenize = padl.transform(tg.tokenizer.encode)
to_tensor = padl.transform(torch.tensor)
text_preprocess = tokenize >> to_tensor >> mypad >> padl.batch

# Teacher-forced model over [image_path, caption] pairs:
#   image   -> ViT feature -> conditioner -> initial GRU hidden state (leading layer axis)
#   caption -> '!' + caption -> padded token ids (the length output is dropped)
# and SimpleRNN turns (hidden, input_ids) into per-position vocabulary logits.
image_to_hidden = image_features >> conditioner >> padl.same.unsqueeze(0)
caption_to_input_ids = padl.transform(lambda x: '!' + x) >> text_preprocess >> padl.same[0]
training_model = (image_to_hidden / caption_to_input_ids) >> logits

# Inference pipeline: image path -> ViT feature -> initial GRU hidden state -> greedy
# decoding -> caption string. The hidden state gets a leading size-1 axis, the same
# (num_layers, batch, hidden) layout `training_model` uses; `Greedy` decodes one
# example at a time.
inference_model = (
    image_features
    >> conditioner
    >> padl.same.unsqueeze(0)
    >> generator
    >> padl.transform(tg.tokenizer.decode)
)

# Training objective: logits from `training_model` scored against the caption with a
# trailing '#'. The model input is '!' + caption, so (assuming '!' and '#' become
# separate tokens) the targets are the inputs shifted left by one token.
caption_targets = padl.same[1] >> padl.transform(lambda x: x + '#') >> text_preprocess
loss = (training_model + caption_targets) >> myloss

In [None]:
import os

# Select the GPU. This only takes effect if CUDA has not been initialised in this
# kernel yet; setting it before torch first touches CUDA is the reliable way.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

N_EPOCHS = 1   # passes over train_annotations
LOG_EVERY = 1  # print the loss every LOG_EVERY iterations

# Freeze the ViT backbone; only the remaining trainable parameters are optimised.
for p in pl.model.vit.parameters():
    p.requires_grad = False

# Move to the GPU before constructing the optimiser, as PyTorch recommends.
loss.pd_to('cuda')
o = torch.optim.Adam([p for p in loss.pd_parameters() if p.requires_grad], lr=0.001)

for epoch in range(N_EPOCHS):
    for it, l_ in enumerate(loss.train_apply(train_annotations, batch_size=250, num_workers=5)):
        o.zero_grad()
        l_.backward()
        o.step()
        if it % LOG_EVERY == 0:
            # .item() prints a plain number instead of the tensor repr (device, grad_fn).
            print(f'TRAIN epoch: {epoch}; iteration: {it}; loss: {l_.item():.4f};')

In [None]:
import random

import PIL.Image
from IPython.display import display

# Caption a random validation image. Training moved the shared modules to the GPU
# (`loss.pd_to('cuda')`), so put the inference pipeline on whatever device the
# weights currently live on before applying it.
inference_model.pd_to(str(next(proj.parameters()).device))

image = random.choice(valid_annotations)[0]

# Close the image file once it has been rendered.
with PIL.Image.open(image) as img:
    display(img)
inference_model.infer_apply(image)