In [0]:
# NOTE(review): unpinned install. This notebook uses AutoModelWithLMHead and
# unpacks model outputs as tuples, both of which depend on the transformers
# release; consider pinning the version it was developed against.
!pip install -q transformers

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

import pandas as pd

from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelWithLMHead, AutoTokenizer

from PIL import Image
import requests
from io import BytesIO
from IPython.display import display

from collections import namedtuple
from time import perf_counter
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

In [0]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Создадим датасет

In [4]:
# Google Colab only: mount Google Drive at /content/drive so the dataset folder
# stored there can be copied locally in the following cell.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Copy the dataset folder from the mounted Drive into a local `data/` directory.
# NOTE(review): the relative source path assumes the working directory is the
# parent of the `drive` mount point (e.g. /content) — confirm. Also, `cp -r`
# copies the folder's contents into `data/` only when `data/` does not exist
# yet; otherwise it creates `data/image_captioning/`, and the later read of
# 'data/Train_GCC-training.tsv' would then miss the file.
!cp -r drive/My\ Drive/image_captioning/ data/

In [0]:
# Per-channel normalization statistics used by torchvision's ImageNet-pretrained CNNs.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Image preprocessing: resize to 224x224, convert to a float tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]
)

class GCCDataset(Dataset):
    """Image/caption pairs read from a Google Conceptual Captions style TSV file.

    The TSV has no header and two columns: a caption (`desc`) and an image URL
    (`url`). Images are downloaded lazily in `__getitem__`. If a row's image
    cannot be downloaded or decoded, the following rows are tried in turn,
    wrapping around to the start of the table.

    Args:
        path: TSV file with columns `desc`, `url`.
        img_transform: callable applied to the decoded (RGB) PIL image.
        tokenizer: object with an `encode(text, max_length=...)` method.
        max_length: `max_length` passed to the tokenizer for each caption.
        timeout: seconds to wait for each image download before giving up on that row.
    """

    def __init__(self, path='data/Train_GCC-training.tsv', img_transform=transform, tokenizer=None, max_length=512,
                 timeout=10):
        self.data = pd.read_csv(path, sep='\t', header=None, names=['desc', 'url'])
        self.img_transform = img_transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.timeout = timeout

    def __len__(self):
        return len(self.data)

    def _load_image(self, url):
        """Download `url`, decode it as an RGB image and apply `img_transform`."""
        # A timeout keeps a single unresponsive host from stalling training forever.
        response = requests.get(url, timeout=self.timeout)
        response.raise_for_status()
        # Convert grayscale / RGBA / palette images to 3 channels so they match
        # the 3-channel normalization instead of failing and being skipped.
        img = Image.open(BytesIO(response.content)).convert('RGB')
        return self.img_transform(img)

    def __getitem__(self, idx):
        """Return `(image_tensor, caption_token_ids)` for row `idx` or the next loadable row.

        Raises:
            RuntimeError: if no row in the table yields a loadable image.
        """
        n_rows = len(self.data)
        for offset in range(n_rows):
            # Wrap around so failures near the end of the table fall back to the start
            # instead of running past the last row.
            row = (idx + offset) % n_rows
            desc, url = self.data.iloc[row]
            try:
                img = self._load_image(url)
            except Exception:
                # Dead link, HTTP error, timeout or undecodable image: try the next row.
                continue
            text_tok_ids = self.tokenizer.encode(desc, max_length=self.max_length)
            return img, text_tok_ids
        raise RuntimeError(f'no loadable image found in {n_rows} rows starting at index {idx}')

In [0]:
def get_dataloader(dataset, batch_size, pad_elem, shuffle=True, num_workers=0):
    """Wrap `dataset` in a DataLoader whose batches are padded per batch.

    `dataset` must yield `(img_tensor, token_id_sequence)` pairs. Each batch is
    a tuple `(imgs, ids, mask)`:
        imgs: the images stacked along a new first dimension,
        ids:  LongTensor (batch, max_len), right-padded with `pad_elem`,
        mask: IntTensor (batch, max_len), 1 where `ids != pad_elem`, else 0.

    Args:
        dataset: map-style dataset of `(img, ids)` pairs.
        batch_size: number of samples per batch.
        pad_elem: token id used for padding (and to build the mask).
        shuffle: reshuffle the samples every epoch.
        num_workers: number of loader worker processes. NOTE(review): the
            collate function is a closure, so values > 0 may fail on platforms
            that spawn workers (they must pickle it) — confirm before raising it.
    """
    def pad(seq, max_len):
        # Right-pad one id sequence up to `max_len`.
        return list(seq) + [pad_elem] * max(0, max_len - len(seq))

    def collate_fn(batch_data):
        batch_img, batch_ids = zip(*batch_data)
        batch_len = max(map(len, batch_ids))
        ids = torch.tensor([pad(seq, batch_len) for seq in batch_ids], dtype=torch.long)
        mask = ids.ne(pad_elem).int()
        imgs = torch.stack(batch_img, 0)
        return imgs, ids, mask

    return DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=shuffle,
        collate_fn=collate_fn, num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; without CUDA it is useless.
        pin_memory=torch.cuda.is_available(),
    )

## Основные архитектуры

In [0]:
class Embedding(nn.Module):
    """Projects the output of a CNN to a vector of size `embed_size`.

    The CNN's last child module must expose `out_features` (e.g. the final
    `nn.Linear` of a torchvision ResNet). Its output is mapped to `embed_size`
    by a trainable linear layer `proj`.

    The CNN is kept in eval mode even when this module is in training mode,
    so layers such as BatchNorm/Dropout inside it behave as at inference.
    NOTE(review): this does not stop gradients from reaching the CNN's
    parameters; set `requires_grad_(False)` on it if freezing is intended.
    """

    def __init__(self, cnn, embed_size):
        super().__init__()
        self.cnn = cnn
        self.embed_size = embed_size
        _, last_module = list(cnn.named_children())[-1]
        self.proj = nn.Linear(last_module.out_features, embed_size)

    def forward(self, x):
        """Return `proj(cnn(x))`."""
        return self.proj(self.cnn(x))

    def train(self, mode=True):
        """Set training mode for this module and `proj`, keeping the CNN in eval mode.

        Follows the `nn.Module.train` contract: updates `self.training`
        (and the children's flags) and returns `self`.
        """
        super().train(mode)
        self.cnn.eval()
        return self

In [0]:
# Names of the extra special tokens registered with the tokenizer; the model
# turns each name `t` into the token string f"<{t.upper()}>" (e.g. "<IMG>").
SPECIAL_TOKENS = tuple("img desc pad".split())

In [0]:
# Named container for the token ids of the special tokens, with one field per
# name in SPECIAL_TOKENS. The typename matches the variable name because pickle
# looks the class up by that name; a mismatch makes instances unpicklable.
SpecialIds = namedtuple('SpecialIds', SPECIAL_TOKENS)

In [0]:
class Image2TextDescriptor(nn.Module):
    """Image captioning model: a CNN image embedding used as a prefix for GPT-2.

    The sequence fed to the language model is
        [<IMG> embedding, projected image embedding, <DESC> embedding, caption token embeddings...]
    and, in training mode, the LM loss is computed on the caption tokens only.

    Args:
        pretrained_model_text: name of the pretrained language model to load.
        pretrained_model_image: name of the pretrained CNN ('resnet18' is the only one supported).
        device: device to place the model on; defaults to CUDA when available, else CPU.
    """

    def __init__(self, pretrained_model_text='gpt2',
                       pretrained_model_image='resnet18', device=None):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.tokenizer, self.special_ids = self._build_tokenizer(pretrained_model_text)
        self.gpt2 = self._build_gpt2(len(self.tokenizer), pretrained_model_text)

        cnn = self._build_cnn(pretrained_model_image)
        self.embedding = Embedding(cnn, self.gpt2.config.hidden_size)
        # Move to this instance's device (not a notebook-global variable).
        self.to(self.device)

    def forward(self, img_data, text_tok_ids=None, attn_mask=None):
        """Run the language model on the image context followed by the optional caption.

        Args:
            img_data: batch of images accepted by `self.embedding`.
            text_tok_ids: optional LongTensor (batch, text_len) of caption token ids.
            attn_mask: optional attention mask. It may cover only the caption
                tokens; missing leading (context) positions are filled with ones.

        Returns:
            The language model output. It contains the LM loss when the model is
            in training mode and `text_tok_ids` is given.
        """
        context_embeds = self.get_context_embeds(img_data)
        labels_ids = None
        if text_tok_ids is not None:
            text_embeds = self.gpt2.transformer.wte(text_tok_ids)
            full_embeds = torch.cat((context_embeds, text_embeds), dim=1)
            if self.training:
                labels_ids = self.get_labels(
                    full_embeds.shape[:-1],
                    context_embeds.size(1),
                    text_tok_ids
                )
        else:
            full_embeds = context_embeds

        attn_mask = self._extend_attn_mask(attn_mask, full_embeds.size(1))
        return self.gpt2(inputs_embeds=full_embeds, labels=labels_ids, attention_mask=attn_mask)

    @staticmethod
    def _extend_attn_mask(attn_mask, total_len):
        """Left-pad `attn_mask` (batch, len) with ones up to `total_len` columns.

        Returns None for a None mask, and the mask unchanged when it is already
        at least `total_len` long. The prepended ones correspond to the context
        positions, which are always attended to.
        """
        if attn_mask is None:
            return None
        missing = total_len - attn_mask.size(1)
        if missing <= 0:
            return attn_mask
        ones = attn_mask.new_ones((attn_mask.size(0), missing))
        return torch.cat((ones, attn_mask), dim=1)

    def get_context_embeds(self, img_data):
        """Build the 3-position prefix [<IMG>, image embedding, <DESC>] for each image.

        Assumes `self.embedding(img_data)` is 2-D (batch, hidden); the result is
        (batch, 3, hidden).
        """
        special = torch.tensor([self.special_ids.img, self.special_ids.desc], device=self.device)
        img_tok_emb, desc_tok_emb = self.gpt2.transformer.wte(special)

        img_embeds = self.embedding(img_data)[:, None, :]  # (batch, 1, hidden)
        return torch.cat((img_tok_emb.expand_as(img_embeds),
                          img_embeds,
                          desc_tok_emb.expand_as(img_embeds)), dim=1)

    def get_labels(self, shape, context_len, tok_ids):
        """Build LM labels of `shape` (batch, context_len + text_len).

        Context positions and padding tokens get -100 (the value Hugging Face
        models ignore in the loss); the remaining positions hold `tok_ids`.
        """
        labels_ids = torch.full(tuple(shape), -100, dtype=torch.long, device=tok_ids.device)
        pad_mask = tok_ids == self.special_ids.pad
        labels_ids[:, context_len:] = tok_ids.masked_fill(pad_mask, -100)
        return labels_ids

    def save(self, path='image2text_descriptor.pt'):
        """Save the model's `state_dict` to `path` (used as a training checkpoint)."""
        torch.save(self.state_dict(), path)

    def _build_tokenizer(self, pretrained_model):
        """Load the tokenizer and register the special tokens.

        Each name `t` in SPECIAL_TOKENS is added as the special token `<T>`
        (e.g. 'img' -> '<IMG>').

        Returns:
            (tokenizer, SpecialIds) where SpecialIds holds each token's id.
        """
        special_dct = {t: f"<{t.upper()}>" for t in SPECIAL_TOKENS}
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        tokenizer.add_special_tokens({'additional_special_tokens': list(special_dct.values())})
        # Look the ids up directly rather than encoding the token strings and
        # taking element [0], which breaks for tokenizers that prepend a BOS token.
        special_ids = SpecialIds(**{k: tokenizer.convert_tokens_to_ids(v) for k, v in special_dct.items()})

        return tokenizer, special_ids

    def _build_gpt2(self, vocab_size, pretrained_model):
        """Load the pretrained LM and resize its embeddings to include the added tokens."""
        gpt2 = AutoModelWithLMHead.from_pretrained(pretrained_model)
        gpt2.resize_token_embeddings(vocab_size)

        return gpt2

    def _build_cnn(self, pretrained_model):
        """Return the pretrained image CNN; raises ValueError for unsupported names."""
        if pretrained_model == 'resnet18':
            return models.resnet18(pretrained=True)
        raise ValueError(f'{pretrained_model} is not supported yet :(')

## Тренировочный код

In [0]:
def train_epoch(dataloader, model, optimizer, save_every=200):
    """Run one training epoch over `dataloader`, checkpointing via `model.save()`.

    Each batch is `(imgs, ids, mask)`, and all three are moved to `model.device`.
    The mask is not passed to the model: captions are right-padded and GPT-2
    attention is causal, so real tokens never attend to the trailing padding.
    Padding is expected to be excluded from the loss through the model's labels.

    Args:
        dataloader: sized iterable of `(imgs, ids, mask)` batches.
        model: module exposing `device` and `save()`, whose output's first
            element (index 0) is the loss.
        optimizer: optimizer over the model's trainable parameters.
        save_every: call `model.save()` after every `save_every` batches
            (0/None disables intermediate saves); it is always called at the end.
    """
    torch.cuda.empty_cache()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
    for idx_batch, batch in pbar:
        imgs, ids, _mask = [x.to(model.device) for x in batch]

        out = model(imgs, ids)
        # Index instead of unpacking: newer transformers return a ModelOutput,
        # and iterating it yields field names; out[0] is the loss for tuples too.
        loss = out[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        pbar.set_description(f'loss:{loss.item():.4f}')

        if save_every and (idx_batch + 1) % save_every == 0:
            model.save()

    model.save()

In [0]:
def train(model, dataloader, n_epochs=5, batch_size=16, lr=1e-2):
    """Train `model` for `n_epochs` epochs with Adam.

    Puts the model in training mode and optimizes only the parameters with
    `requires_grad=True`. `batch_size` is not used here; the effective batch
    size is whatever `dataloader` was built with.
    NOTE(review): lr=1e-2 is high for fine-tuning a pretrained transformer with Adam — confirm.
    """
    model.train()

    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(trainable_params, lr=lr)

    for epoch_idx in range(n_epochs):
        train_epoch(dataloader, model, optimizer)
        print(f'Epoch #{epoch_idx} finished')

In [0]:
# Build the captioning model with its defaults: pretrained 'gpt2' for text and
# pretrained ResNet-18 for images.
descriptor = Image2TextDescriptor()

In [0]:
# Training data: GCC captions + image URLs, tokenized with the model's tokenizer.
# 10 of GPT-2's `n_positions` slots are held back so the image context prefix
# (<IMG>, image embedding, <DESC>) still fits in front of the longest caption.
path = 'data/Train_GCC-training.tsv'
batch_size = 32
max_caption_length = descriptor.gpt2.config.n_positions - 10

gcc_dataset = GCCDataset(
    path=path,
    img_transform=transform,
    tokenizer=descriptor.tokenizer,
    max_length=max_caption_length,
)
dataloader = get_dataloader(gcc_dataset, batch_size, descriptor.special_ids.pad)

In [16]:
# Train with `train`'s defaults (5 epochs, Adam, lr=1e-2); checkpoints are
# written through `model.save()` inside `train_epoch`.
train(descriptor, dataloader)

HBox(children=(FloatProgress(value=0.0, max=103698.0), HTML(value='')))

KeyboardInterrupt: ignored