In [1]:
import torch
import torch.nn as nn
import torchvision

from torchvision.datasets import Flickr8k
from torchvision import transforms
import matplotlib.pyplot as plt
import string
import nltk

import numpy as np

# Загрузка датасета

In [2]:
def im_transform(size):
    def inner(image):
        return transforms.ToTensor()(image.resize((size, size)))
    return inner

def target_transform(targets):
    target = targets[0].lower()
    target = target.translate(str.maketrans('', '', string.punctuation))
    return target

In [3]:
def train_valid_test_split(batch_size=64, image_size=256, prop=0.9):
    data = Flickr8k('Flickr8k/Flicker8k_Dataset', ann_file='Flickr8k/8k-pictures.html',
                transform=im_transform(image_size), target_transform=target_transform)
    inds = list(range(len(data)))
    split1 = int(np.floor(prop * len(inds)))
    split2 = int(np.floor(prop * split1))
    np.random.shuffle(inds)
        
    train_inds = inds[:split2]
    valid_inds = inds[split2:split1]
    test_inds = inds[split1:]
    
    train_sampler = torch.utils.data.SubsetRandomSampler(train_inds)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_inds)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_inds)
    
    train_dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, 
                                       sampler=train_sampler, num_workers=0)
    valid_dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, 
                                       sampler=valid_sampler, num_workers=0)
    test_dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, 
                                       sampler=valid_sampler, num_workers=0)
    return train_dataloader, valid_dataloader, test_dataloader

In [4]:
train_dataloader, valid_dataloader, test_dataloader = train_valid_test_split(image_size=256, batch_size=64)

Создаем словарь

In [5]:
word2ind = {'<START>': 0, '<END>': 1, '<UNK>': 2, '<PAD>': 3}
ind2word = ['<START>', '<END>', '<UNK>', '<PAD>']

In [6]:
for X, y in train_dataloader:
    for seq in y:
        target = nltk.tokenize.word_tokenize(seq)
        for word in target:
            if word not in ind2word:
                word2ind[word] = len(ind2word)
                ind2word.append(word)

Функция, преобразующая набор строк в матричку

In [7]:
def batch_capt_transform(batch, seq_len=20):
    seqs = [None] * len(batch)
    for i in range(len(batch)):
        seq = nltk.tokenize.word_tokenize(batch[i])
        seq.append('<END>')
        seq.insert(0, '<START>')
        
        seqs[i] = [word2ind['<PAD>']] * seq_len
        for j, word in enumerate(seq):
            if j > len(seqs[i]) - 1:
                break
            if word in word2ind:
                seqs[i][j] = word2ind[word]
            else:
                seqs[i][j] = word2ind['<UNK>']
        
    return seqs

In [8]:
for X, y in train_dataloader:
    conv = batch_capt_transform(y)
    for i in range(len(y)):
        print(y[i])
        print(conv[i])
    break

two bums are sitting on a sidewalk outside a peace mission in portland oregon
[0, 18, 1627, 198, 269, 23, 4, 309, 260, 4, 810, 1628, 42, 1629, 1630, 1, 3, 3, 3, 3]
a dog catching a biscuit in its mouth
[0, 4, 6, 548, 4, 3166, 42, 282, 87, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
a man in a checked shirt stands next to a large wooden bowl
[0, 4, 41, 42, 4, 1169, 80, 61, 311, 48, 4, 43, 9, 1318, 1, 3, 3, 3, 3, 3]
a person wearing skis makes a jump over the snow
[0, 4, 167, 57, 536, 217, 4, 222, 193, 21, 92, 1, 3, 3, 3, 3, 3, 3, 3, 3]
a brown dog picks up a twig from a stone surface
[0, 4, 32, 6, 2124, 28, 4, 2125, 8, 4, 861, 307, 1, 3, 3, 3, 3, 3, 3, 3]
a man and woman are sitting on concrete stairs
[0, 4, 41, 29, 243, 198, 269, 23, 613, 876, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3]
a medium brown dog is jumping over a short brick wall surrounding some dirt
[0, 4, 707, 32, 6, 25, 7, 193, 4, 244, 360, 320, 2072, 38, 93, 1, 3, 3, 3, 3]
men in drak outfits around acampfire
[0, 204, 42, 3653, 148, 153, 3654, 1

# Encoder

Для того, чтобы использовать pretrained cnn с нашими данными можно действовать несколькими способами:
1. Дообучать всю сетку на наших данных
2. Дообучать только последний полносвязный слой (а до этого всю сетку заморозить).

Пока запускала ноутбук локально, поэтому выбрала второй вариант. Но опыт написания последней домашки показывает, что так ничего не научим, поэтому при переходе на колаб параметры разморожу.

In [9]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size=1024):
        super().__init__()
        self.densenet = torchvision.models.densenet121(pretrained=True)
        
        for param in self.densenet.parameters():
            param.requires_grad = False
        self.densenet.classifier = nn.Sequential(nn.Linear(in_features=1024, out_features=1024), nn.ReLU())
        self.embed = nn.Linear(in_features=1024, out_features=embed_size)
        
    def forward(self, images):
        densenet_outputs = self.densenet(images)
        embeddings = self.embed(densenet_outputs)
        return embeddings

# Decoder

Схема декодера похожа на ту, что в этой статье:
https://arxiv.org/pdf/1411.4555.pdf

In [10]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(input_size = embed_size,hidden_size = hidden_size)
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self, features, captions):
        captions = captions[:, :-1] # убираем <end>, он не подается как x_t
        embed = self.embedding_layer(captions)
        embed = torch.cat((features.unsqueeze(1), embed), dim=1)
        lstm_outputs, hidden = self.lstm(embed)
        return self.linear(lstm_outputs)
    
    def sample(self, features, captions=None, seq_len=20):
        output_sentence = []
        for i in range(seq_len):
            lstm_outputs, hidden = self.lstm(features, captions)
            lstm_outputs = lstm_outputs.squeeze(1)
            out = self.linear(lstm_outputs)
            out = out.max(1)[1]
            output_sentence.append(out.item())
            features = self.embedding_layer(last_pick).unsqueeze(1)
        
        return output_sentence

# Train

In [14]:
embed_size = 1024
hidden_size = 1024
vocab_size = len(ind2word)

num_epochs = 1

In [15]:
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, len(ind2word))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(decoder.parameters()) + list(encoder.embed.parameters()), lr=1e-3)

In [16]:
for epoch in range(1, num_epochs + 1):
    step = 0
    decoder.train()
    encoder.train()
    print('***********Train*************')
    for images, captions in train_dataloader:
        images = torch.Tensor(images).to(device)
        captions = torch.Tensor(batch_capt_transform(captions)).long().to(device)
        
        decoder.zero_grad()
        encoder.zero_grad()
        
        features = encoder(images)
        outputs = decoder(features, captions)
        
        loss = criterion(outputs.view(-1, len(ind2word)), captions.view(-1))
        loss.backward()
        optimizer.step()
      
        print(f'Epoch: [{epoch}/{num_epochs}], Step: {step}, Loss: {loss.item()}, Perplexity: {np.exp(loss.item())}')
        step += 1
        
    # Validate
    encoder.eval()
    decoder.eval()
    print('***********Validation*************')
    with torch.no_grad():
        for images, captions in valid_dataloader:
            images = torch.Tensor(images).to(device)
            captions = torch.Tensor(batch_capt_transform(captions)).long().to(device)

            features = encoder(images)
            outputs = decoder(features, captions)
        
            loss = criterion(outputs.view(-1, len(ind2word)), captions.view(-1))
            print(print(f'Epoch: [{epoch}/{num_epochs}], Step: {step}, Loss: {loss.item()}, Perplexity: {np.exp(loss.item())}'))

***********Train*************
Epoch: [1/1], Step: 0, Loss: 8.295736312866211, Perplexity: 4006.7523839870137
Epoch: [1/1], Step: 1, Loss: 6.209278106689453, Perplexity: 497.3420937351666
Epoch: [1/1], Step: 2, Loss: 5.2470526695251465, Perplexity: 190.00543357697808
Epoch: [1/1], Step: 3, Loss: 4.984119415283203, Perplexity: 146.0748870610914
Epoch: [1/1], Step: 4, Loss: 4.66823673248291, Perplexity: 106.50977153925982
Epoch: [1/1], Step: 5, Loss: 4.0283708572387695, Perplexity: 56.16932881151616
Epoch: [1/1], Step: 6, Loss: 3.5147347450256348, Perplexity: 33.60701231433022
Epoch: [1/1], Step: 7, Loss: 3.2462894916534424, Perplexity: 25.69482196533845
Epoch: [1/1], Step: 8, Loss: 3.0619912147521973, Perplexity: 21.370067217755498
Epoch: [1/1], Step: 9, Loss: 3.121738910675049, Perplexity: 22.68579393352146
Epoch: [1/1], Step: 10, Loss: 3.0718846321105957, Perplexity: 21.582539517227964
Epoch: [1/1], Step: 11, Loss: 3.2466368675231934, Perplexity: 25.7037492769433
Epoch: [1/1], Step: 12

Epoch: [1/1], Step: 97, Loss: 2.4316515922546387, Perplexity: 11.377657822333475
None
Epoch: [1/1], Step: 97, Loss: 2.477728843688965, Perplexity: 11.91417471392518
None
Epoch: [1/1], Step: 97, Loss: 2.3168039321899414, Perplexity: 10.143204077144656
None
Epoch: [1/1], Step: 97, Loss: 2.2650763988494873, Perplexity: 9.631860436057192
None
Epoch: [1/1], Step: 97, Loss: 2.379883289337158, Perplexity: 10.803641890142043
None
Epoch: [1/1], Step: 97, Loss: 2.3781332969665527, Perplexity: 10.78475213254616
None


Ну. видно что оно чему-то учится. 