In [None]:
import numpy as np
import cv2
import os
import torch
import spacy
import pandas as pd
import pickle
import torchvision.transforms as T
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.models.video as video


from tqdm.notebook import tqdm
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from collections import Counter
from IPython.display import clear_output

In [None]:
# Download the large Russian spaCy pipeline; it is loaded below as `spacy_ru` and its
# tokenizer is used by Vocabulary.tokenize.
!python -m spacy download ru_core_news_lg

Collecting ru-core-news-lg==3.6.0
  Downloading https://github.com/explosion/spacy-models/releases/download/ru_core_news_lg-3.6.0/ru_core_news_lg-3.6.0-py3-none-any.whl (513.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m513.4/513.4 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('ru_core_news_lg')


In [None]:
# Colab only: expose Google Drive (dataset archive and model checkpoints live under
# MyDrive) at /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the competition archive from the mounted Drive into the working directory.
!cp ./drive/MyDrive/rutube_hackathon_novosibirsk.zip ./rutube_hackathon_novosibirsk.zip

In [None]:
# Rebuild the archive with `zip -FF` (fix-fix mode) into a new file; presumably the
# copied archive is damaged or incomplete and cannot be unzipped directly.
!zip -FF ./rutube_hackathon_novosibirsk.zip --out ./pleasework.zip
# Hide zip's very long per-entry log.
clear_output()

In [None]:
# Extract the repaired archive quietly; later cells expect ./rutube_hackathon_novosibirsk/train/.
!unzip -qq ./pleasework.zip

In [None]:
# Load the Russian spaCy pipeline; only its tokenizer is used (see Vocabulary.tokenize).
spacy_ru = spacy.load("ru_core_news_lg")

# Sanity check: tokens come back lower-cased, with punctuation split off.
text = "Сергей Чуприн любит hi! ОН - КАРТОфЕЛЕЛЮБ!!!"
list(map(lambda tok: tok.text.lower(), spacy_ru.tokenizer(text)))

['сергей',
 'чуприн',
 'любит',
 'hi',
 '!',
 'он',
 '-',
 'картофелелюб',
 '!',
 '!',
 '!']

In [None]:
# Training metadata: one row per video (description, video_name, stt_name, ...).
train_csv = pd.read_csv(filepath_or_buffer="./rutube_hackathon_novosibirsk/train/train.csv")

In [None]:
train_csv["len"] = train_csv["description"].apply(lambda x: len(x.split()))

In [None]:
# Longest description in words (motivates the 300-token sequence budget used below).
train_csv["len"].max()

321

In [None]:
# First GPU when available, otherwise CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda', index=0)

In [None]:
class Vocabulary:
    """Token <-> index mapping over spaCy-tokenized, lower-cased text.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``. Every
    other token gets the next free index the moment its count reaches ``freq_threshold``.
    """

    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Tokenize ``text`` with the global ``spacy_ru`` tokenizer and lower-case each token."""
        return [token.text.lower() for token in spacy_ru.tokenizer(text)]

    @staticmethod
    def del_timestamps(texts):
        """Drop the leading ``...]  `` timestamp prefix from every line of an STT transcript.

        For each line only the text after the first ``"]  "`` separator is kept; lines
        without the separator (or with nothing after it) are skipped. The kept pieces
        are joined with single spaces.
        """
        parts = []
        for line in texts.split("\n"):
            _, sep, tail = line.partition("]  ")
            tail = tail.strip()
            if sep and tail:
                parts.append(tail)
        return " ".join(parts)

    def build_vocab(self, sentence_list, texts_list=None):
        """Add every token whose count reaches ``freq_threshold``.

        Args:
            sentence_list: iterable of plain strings (e.g. video descriptions).
            texts_list: optional iterable of raw STT transcripts; their timestamps are
                removed with ``del_timestamps`` before tokenizing.
        """
        frequencies = Counter()
        # Continue after the existing entries so a second call never re-uses indices.
        next_idx = len(self.itos)

        def documents():
            yield from tqdm(sentence_list)
            if texts_list is not None:
                for raw in tqdm(texts_list):
                    yield self.del_timestamps(raw)

        for document in documents():
            for word in self.tokenize(document):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Map the tokens of ``text`` to indices, using ``<UNK>`` for unknown tokens."""
        unk = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk) for token in self.tokenize(text)]

In [None]:
class TextDataset(Dataset):
    """Pairs each STT transcript with its description as fixed-length index tensors.

    Args:
        root_dir: dataset root; transcripts are read from ``<root_dir>/train_stt/<file>``.
        captions: description strings (targets).
        captions_files: transcript file names, aligned with ``captions``.
        transform: stored but not used.
        freq_threshold: forwarded to ``Vocabulary``.
        build: when True, build the vocabulary from the captions and transcripts.
        max_len: every returned sequence is padded/truncated to this length.
    """

    def __init__(self, root_dir, captions, captions_files, transform=None, freq_threshold=1, build=True, max_len=300):
        self.root_dir = root_dir
        self.transform = transform
        self.max_len = max_len
        self.captions = captions

        self.texts = []
        for file in tqdm(captions_files):
            path = os.path.join(self.root_dir, "train_stt", file)
            with open(path, encoding="utf-8") as cf:
                # readlines() keeps each "\n", so the joined text still splits per line.
                self.texts.append(" ".join(cf.readlines()))

        self.vocab = Vocabulary(freq_threshold)
        if build:
            self.vocab.build_vocab(self.captions, self.texts)

    def __len__(self):
        return len(self.captions)

    def _encode(self, text):
        """Return ``<SOS> + ids + <EOS>`` right-padded with ``<PAD>`` and truncated to ``max_len``."""
        stoi = self.vocab.stoi
        ids = [stoi["<SOS>"], *self.vocab.numericalize(text), stoi["<EOS>"]]
        ids = ids[: self.max_len]
        ids += [stoi["<PAD>"]] * (self.max_len - len(ids))
        return torch.tensor(ids)

    def __getitem__(self, idx):
        """Return ``(transcript_ids, caption_ids)``, each a 1-D tensor of length ``max_len``."""
        # NOTE(review): the transcript is numericalized with its timestamps still in it, while
        # Vocabulary.build_vocab strips them; kept as-is so the trained text model's inputs don't change.
        return self._encode(self.texts[idx]), self._encode(self.captions[idx])

In [None]:
class VideoDataset(Dataset):
    """Pairs sampled, resized RGB frames of each video with its tokenized description.

    Args:
        root_dir: dataset root; videos are read from ``<root_dir>/train_video/`` and
            transcripts from ``<root_dir>/train_stt/``.
        captions: description strings, aligned with ``video_files``.
        video_files: video file names.
        texts: transcript file names; only read when ``build`` is True.
        transform: stored but not used.
        freq_threshold: forwarded to ``Vocabulary``.
        build: when True, build the vocabulary from the captions and transcript contents.
        max_len: caption tensors are padded/truncated to this length.
    """

    def __init__(self, root_dir, captions, video_files, texts, transform=None, freq_threshold=1, build=True, max_len=300):
        self.root_dir = root_dir
        self.transform = transform
        self.max_len = max_len

        self.imgs = video_files
        self.captions = captions

        self.vocab = Vocabulary(freq_threshold)
        if build:
            # `texts` holds file *names*; the vocabulary must be built from their contents.
            transcripts = []
            for file in tqdm(texts):
                with open(os.path.join(self.root_dir, "train_stt", file), encoding="utf-8") as cf:
                    transcripts.append(" ".join(cf.readlines()))
            self.vocab.build_vocab(self.captions, transcripts)

    def __len__(self):
        return len(self.imgs)

    def _read_video(self, path, frames_num=25, window=30):
        """Decode up to ``frames_num`` evenly spaced frames as RGB 226x226 uint8 arrays.

        Frames at indices 1, 1 + step, 1 + 2*step, ... are kept, with
        ``step = max(frame_count // frames_num, 1)``. Decoding stops once enough
        frames are collected or the stream ends. ``window`` is unused.

        Returns:
            np.ndarray of shape (n_frames, 226, 226, 3) with 1 <= n_frames <= frames_num.

        Raises:
            ValueError: if no frame could be decoded.
        """
        video_path = os.path.join(self.root_dir, "train_video", path)
        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # A step of 0 (videos shorter than frames_num) would stop after a single frame.
            step = max(length // frames_num, 1)
            next_keep = 1
            for i in range(length):
                if len(frames) >= frames_num:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                if i == next_keep:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(cv2.resize(frame, (226, 226)))
                    next_keep += step
        finally:
            cap.release()

        if not frames:
            raise ValueError(f"could not decode any frames from {video_path}")
        return np.array(frames)

    def _encode(self, text):
        """Return ``<SOS> + ids + <EOS>`` right-padded with ``<PAD>`` and truncated to ``max_len``."""
        stoi = self.vocab.stoi
        ids = [stoi["<SOS>"], *self.vocab.numericalize(text), stoi["<EOS>"]]
        ids = ids[: self.max_len]
        ids += [stoi["<PAD>"]] * (self.max_len - len(ids))
        return torch.tensor(ids)

    def __getitem__(self, idx):
        """Return ``(video, caption_ids)``: video as float32 (C, T, H, W) in [0, 1]."""
        # (T, H, W, C) frames -> (C, T, H, W)
        frames = np.rollaxis(self._read_video(self.imgs[idx]), 3, 0)
        video = torch.tensor(np.asarray(frames / 255, dtype=np.float32))
        return video, self._encode(self.captions[idx])


In [None]:
class TextEncoderRNN(nn.Module):
    """Embedding + multi-layer LSTM encoder over token-id sequences.

    Args:
        vocab_size, emb_dim, hid_dim, n_layers, dropout: usual encoder settings.
        batch_first: axis layout the LSTM assumes for ``src``. Every caller in this
            notebook passes ``src`` as (batch, seq), which requires ``batch_first=True``.
            With the default ``False`` the LSTM treats the batch axis as time and the
            token axis as batch. The default is kept only so checkpoints presumably
            trained with this class as originally written behave the same; set True
            for new models.
    """

    def __init__(self, vocab_size=300, emb_dim=300, hid_dim=300, n_layers=3, dropout=0.5, batch_first=False):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # batch_first does not change the parameters, so old state_dicts still load.
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=batch_first)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Return LSTM outputs for every position: ``src``'s shape plus a trailing ``hid_dim`` axis."""
        embedded = self.dropout(self.embedding(src))
        outputs, _ = self.rnn(embedded)
        return outputs


In [None]:
class TextAttention(nn.Module):
    """Additive attention of a decoder state over encoder features.

    score_i = A(tanh(U(features_i) + W(hidden))); alpha = softmax over positions;
    context = sum_i alpha_i * features_i.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        # Attribute names and creation order (W, U, A) define the state_dict keys and
        # the parameter-initialisation order, so both stay fixed.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)``.

        Assumes features is (batch, positions, encoder_dim) and hidden_state is
        (batch, decoder_dim); alpha is (batch, positions), context (batch, encoder_dim).
        """
        projected_features = self.U(features)
        projected_hidden = self.W(hidden_state).unsqueeze(1)
        energy = torch.tanh(projected_features + projected_hidden)

        scores = self.A(energy).squeeze(-1)
        alpha = F.softmax(scores, dim=1)

        weighted = features * alpha.unsqueeze(-1)
        context = weighted.sum(dim=1)
        return alpha, context

In [None]:
class TextDecoderRNN(nn.Module):
    """LSTMCell decoder with additive attention over encoder features.

    Args:
        embed_size: size of the target-token embeddings.
        vocab_size: number of output classes.
        attention_dim, encoder_dim, decoder_dim: attention / feature / hidden sizes.
        drop_prob: dropout on the hidden state before the output projection.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()

        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = TextAttention(encoder_dim, decoder_dim, attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # Not used by forward/generate_caption; kept because removing it would change the
        # state_dict keys and break loading checkpoints saved from this class.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder outputs indexed as (batch, num_features, encoder_dim).
            captions: (batch, seq) target ids; the token at position s is fed to predict s + 1.

        Returns:
            (preds, alphas): logits (batch, seq - 1, vocab_size) and attention weights
            (batch, seq - 1, num_features).
        """
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)

        batch_size = captions.size(0)
        seq_length = captions.size(1) - 1
        num_features = features.size(1)

        # Allocate next to the inputs rather than on the notebook-global `device`.
        preds = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_features)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=300, vocab=None):
        """Greedy decoding for a single example (batch size 1).

        Args:
            features: encoder outputs for one example.
            max_len: maximum number of generated tokens.
            vocab: object with ``stoi``/``itos`` mappings containing ``<SOS>`` and ``<EOS>``.

        Returns:
            (words, alphas): generated tokens (ending with ``<EOS>`` if it was produced)
            and the per-step attention weights as numpy arrays.

        Raises:
            ValueError: if ``vocab`` is not given.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab")

        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)

        alphas = []
        captions = []
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1, -1)
        embeds = self.embedding(word)

        for _ in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.detach().cpu().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h)).view(batch_size, -1)

            predicted_word_idx = output.argmax(dim=1)
            token_id = predicted_word_idx.item()
            captions.append(token_id)
            if vocab.itos[token_id] == "<EOS>":
                break

            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Initial (h, c): linear projections of ``encoder_out`` averaged over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        return self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)

In [None]:


class TextEncoderDecoder(nn.Module):
    """TextEncoderRNN + TextDecoderRNN: transcript token ids in, description logits out.

    Args:
        embed_size, vocab_size, attention_dim, decoder_dim: forwarded to the decoder.
        encoder_dim: size of the encoder outputs; also used as the encoder LSTM's
            ``hid_dim`` so the two halves always agree.
        drop_prob: decoder dropout probability.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.encoder = TextEncoderRNN(
            vocab_size=vocab_size,
            hid_dim=encoder_dim,
        )
        self.decoder = TextDecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """``images`` holds transcript token ids (name kept for call compatibility)."""
        features = self.encoder(images)
        return self.decoder(features, captions)



In [None]:
class VideoEncoderCNN(nn.Module):
    """Frozen pretrained Swin3D-S backbone, with its last two children removed, used as a feature extractor."""

    def __init__(self):
        super(VideoEncoderCNN, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag (Kinetics-400 weights).
        backbone = video.swin3d_s(weights=video.Swin3D_S_Weights.KINETICS400_V1)
        for param in backbone.parameters():
            param.requires_grad_(False)
        # Drop the last two children (presumably pooling + classification head; verify for
        # the torchvision version in use). The attribute keeps the name `swin_t` so saved
        # state_dict keys (`swin_t.*`) still match.
        self.swin_t = nn.Sequential(*list(backbone.children())[:-2])

    def forward(self, images):
        """Flatten every axis between the batch axis and the last axis into one position axis.

        Returns (batch, positions, last_dim). Assumes the backbone output keeps channels
        in the last axis; TODO confirm.
        """
        features = self.swin_t(images)
        # reshape (not view) also works if the backbone ever returns a non-contiguous tensor.
        return features.reshape(features.size(0), -1, features.size(-1))

In [None]:
class VideoAttention(nn.Module):
    """Additive attention of a decoder state over video features.

    score_i = A(tanh(U(features_i) + W(hidden))); alpha = softmax over positions;
    context = sum_i alpha_i * features_i.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        # Attribute names and creation order (W, U, A) define the state_dict keys and
        # the parameter-initialisation order, so both stay fixed.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)``.

        Assumes features is (batch, positions, encoder_dim) and hidden_state is
        (batch, decoder_dim); alpha is (batch, positions), context (batch, encoder_dim).
        """
        projected_features = self.U(features)
        projected_hidden = self.W(hidden_state).unsqueeze(1)
        energy = torch.tanh(projected_features + projected_hidden)

        scores = self.A(energy).squeeze(-1)
        alpha = F.softmax(scores, dim=1)

        weighted = features * alpha.unsqueeze(-1)
        context = weighted.sum(dim=1)
        return alpha, context

In [None]:

class VideoDecoderRNN(nn.Module):
    """LSTMCell decoder with additive attention over video features.

    Args:
        embed_size: size of the target-token embeddings.
        vocab_size: number of output classes.
        attention_dim, encoder_dim, decoder_dim: attention / feature / hidden sizes.
        drop_prob: dropout on the hidden state before the output projection.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()

        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = VideoAttention(encoder_dim, decoder_dim, attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # Not used by forward/generate_caption; kept because removing it would change the
        # state_dict keys and break loading checkpoints saved from this class.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder outputs indexed as (batch, num_features, encoder_dim).
            captions: (batch, seq) target ids; the token at position s is fed to predict s + 1.

        Returns:
            (preds, alphas): logits (batch, seq - 1, vocab_size) and attention weights
            (batch, seq - 1, num_features).
        """
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)

        batch_size = captions.size(0)
        seq_length = captions.size(1) - 1
        num_features = features.size(1)

        # Allocate next to the inputs rather than on the notebook-global `device`.
        preds = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_features)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=300, vocab=None):
        """Greedy decoding for a single example (batch size 1).

        Args:
            features: encoder outputs for one example.
            max_len: maximum number of generated tokens.
            vocab: object with ``stoi``/``itos`` mappings containing ``<SOS>`` and ``<EOS>``.

        Returns:
            (words, alphas): generated tokens (ending with ``<EOS>`` if it was produced)
            and the per-step attention weights as numpy arrays.

        Raises:
            ValueError: if ``vocab`` is not given.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab")

        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)

        alphas = []
        captions = []
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1, -1)
        embeds = self.embedding(word)

        for _ in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.detach().cpu().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h)).view(batch_size, -1)

            predicted_word_idx = output.argmax(dim=1)
            token_id = predicted_word_idx.item()
            captions.append(token_id)
            if vocab.itos[token_id] == "<EOS>":
                break

            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Initial (h, c): linear projections of ``encoder_out`` averaged over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        return self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)

In [None]:


class VideoEncoderDecoder(nn.Module):
    """VideoEncoderCNN + VideoDecoderRNN: video clip in, description logits out.

    Args:
        embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim: forwarded to the decoder.
        drop_prob: decoder dropout probability (previously accepted but ignored).
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.encoder = VideoEncoderCNN()
        self.decoder = VideoDecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """Encode the clip ``images`` and decode ``captions`` with teacher forcing."""
        features = self.encoder(images)
        return self.decoder(features, captions)



In [None]:
class ModelDataset(Dataset):
    """Inputs for the final EncoderDecoder, produced by two frozen first-stage models.

    For every sample the pretrained video model captions the video and the pretrained
    text model captions the STT transcript. Each generated caption becomes a
    ``<SOS> ... <EOS>`` id sequence of ``max_len``, and the two are concatenated
    (text first). The target is the description tensor from ``VideoDataset``.
    """

    max_len = 300  # per-stream length of the concatenated input

    def __init__(
        self, root_dir, captions, video_files, text_files, vocab: "Vocabulary",
        video_model_state, text_model_state, batch_size, transform=None, freq_threshold=1
        ):
        # NOTE(review): batch_size, transform and freq_threshold are accepted but unused.
        self.transforms = T.Compose([
            T.Resize(226),
            T.RandomCrop(224),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        self.device = device

        self.video_dataset = VideoDataset(
            root_dir=root_dir,
            captions=captions,
            video_files=video_files,
            texts=text_files,
            transform=self.transforms,
            build=False
        )
        self.text_dataset = TextDataset(
            root_dir=root_dir,
            captions=captions,
            captions_files=text_files,
            transform=self.transforms,
            build=False
        )
        self.vocab = vocab
        self.video_dataset.vocab = self.vocab
        self.text_dataset.vocab = self.vocab

        self.text_model = self._load_frozen(TextEncoderDecoder, text_model_state)
        self.video_model = self._load_frozen(VideoEncoderDecoder, video_model_state)

    def _load_frozen(self, model_cls, state):
        """Rebuild ``model_cls`` from a save_model-style checkpoint dict, frozen and in eval mode."""
        model = model_cls(
            embed_size=state["embed_size"],
            vocab_size=state['vocab_size'],
            attention_dim=state['attention_dim'],
            encoder_dim=state['encoder_dim'],
            decoder_dim=state['decoder_dim']
        ).to(self.device)
        model.load_state_dict(state["state_dict"])
        # requires_grad_ is a method: it must be called to actually freeze the parameters.
        model.requires_grad_(False)
        model.eval()
        return model

    def __len__(self):
        return len(self.video_dataset)

    def _generate(self, model, inputs):
        """Greedy-caption one un-batched input with ``model`` and return the generated tokens."""
        features = model.encoder(inputs.unsqueeze(0).to(self.device))
        words, _ = model.decoder.generate_caption(features, vocab=self.vocab)
        return words

    def _encode(self, words):
        """Map generated tokens to ``<SOS> + ids + <EOS>``, padded/truncated to ``max_len``.

        The tokens come from ``vocab.itos``, so they are looked up directly. Re-joining and
        re-tokenizing them could never map a generated ``<EOS>`` back to its id, because
        ``Vocabulary.tokenize`` lower-cases every token.
        """
        stoi = self.vocab.stoi
        unk = stoi["<UNK>"]
        if words and words[-1] == "<EOS>":
            words = words[:-1]
        ids = [stoi["<SOS>"], *(stoi.get(word, unk) for word in words), stoi["<EOS>"]]
        ids = ids[: self.max_len]
        ids += [stoi["<PAD>"]] * (self.max_len - len(ids))
        return torch.tensor(ids)

    def __getitem__(self, idx):
        """Return ``(cat(text_ids, video_ids), caption_ids)``."""
        text, _ = self.text_dataset[idx]
        video, caption = self.video_dataset[idx]

        with torch.no_grad():
            video_words = self._generate(self.video_model, video)
            text_words = self._generate(self.text_model, text)

        return torch.cat([self._encode(text_words), self._encode(video_words)]), caption

In [None]:
data_location = "./rutube_hackathon_novosibirsk/train"
BATCH_SIZE = 50

# Defined first: torch.load below maps the checkpoints onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Kept at module level: other cells may refer to `transforms`.
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Only load trusted, locally produced files: pickle.load / torch.load can execute code.
with open("./vocab.pkl", "rb") as vocab_file:
    train_vocab = pickle.load(vocab_file)

dataset = ModelDataset(
    root_dir=data_location,
    captions=train_csv.description.tolist(),
    video_files=train_csv.video_name.tolist(),
    text_files=train_csv.stt_name.tolist(),
    vocab=train_vocab,
    # map_location lets GPU-saved checkpoints load on a CPU-only runtime as well.
    video_model_state=torch.load("./drive/MyDrive/video_model.pth", map_location=device),
    text_model_state=torch.load("./drive/MyDrive/text_model.pth", map_location=device),
    batch_size=BATCH_SIZE
)

data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

vocab_size = len(dataset.vocab)

device

0it [00:00, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]



device(type='cuda', index=0)

In [None]:
class EncoderRNN(nn.Module):
    """Embedding + multi-layer LSTM encoder for the final model's (batch, seq) token ids.

    Args:
        vocab_size, emb_dim, hid_dim, n_layers, dropout: usual encoder settings.
        batch_first: defaults to True because inputs arrive as (batch, seq). Without it
            the LSTM would recur across the (shuffled) samples of a batch instead of
            along each token sequence.
    """

    def __init__(self, vocab_size=300, emb_dim=300, hid_dim=300, n_layers=3, dropout=0.5, batch_first=True):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=batch_first)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Return LSTM outputs for every position, shape (batch, seq, hid_dim) for (batch, seq) input."""
        embedded = self.dropout(self.embedding(src))
        outputs, _ = self.rnn(embedded)
        return outputs


In [None]:
class Attention(nn.Module):
    """Additive attention of a decoder state over encoder features.

    score_i = A(tanh(U(features_i) + W(hidden))); alpha = softmax over positions;
    context = sum_i alpha_i * features_i.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        # Attribute names and creation order (W, U, A) define the state_dict keys and
        # the parameter-initialisation order, so both stay fixed.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)``.

        Assumes features is (batch, positions, encoder_dim) and hidden_state is
        (batch, decoder_dim); alpha is (batch, positions), context (batch, encoder_dim).
        """
        projected_features = self.U(features)
        projected_hidden = self.W(hidden_state).unsqueeze(1)
        energy = torch.tanh(projected_features + projected_hidden)

        scores = self.A(energy).squeeze(-1)
        alpha = F.softmax(scores, dim=1)

        weighted = features * alpha.unsqueeze(-1)
        context = weighted.sum(dim=1)
        return alpha, context

In [None]:
class DecoderRNN(nn.Module):
    """LSTMCell decoder with additive attention over encoder features.

    Args:
        embed_size: size of the target-token embeddings.
        vocab_size: number of output classes.
        attention_dim, encoder_dim, decoder_dim: attention / feature / hidden sizes.
        drop_prob: dropout on the hidden state before the output projection.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()

        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # Not used by forward/generate_caption; kept because removing it would change the
        # state_dict keys and break loading checkpoints saved from this class.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder outputs indexed as (batch, num_features, encoder_dim).
            captions: (batch, seq) target ids; the token at position s is fed to predict s + 1.

        Returns:
            (preds, alphas): logits (batch, seq - 1, vocab_size) and attention weights
            (batch, seq - 1, num_features).
        """
        embeds = self.embedding(captions)
        h, c = self.init_hidden_state(features)

        batch_size = captions.size(0)
        seq_length = captions.size(1) - 1
        num_features = features.size(1)

        # Allocate next to the inputs rather than on the notebook-global `device`.
        preds = features.new_zeros(batch_size, seq_length, self.vocab_size)
        alphas = features.new_zeros(batch_size, seq_length, num_features)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            preds[:, s] = self.fcn(self.drop(h))
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, max_len=300, vocab=None):
        """Greedy decoding for a single example (batch size 1).

        Args:
            features: encoder outputs for one example.
            max_len: maximum number of generated tokens.
            vocab: object with ``stoi``/``itos`` mappings containing ``<SOS>`` and ``<EOS>``.

        Returns:
            (words, alphas): generated tokens (ending with ``<EOS>`` if it was produced)
            and the per-step attention weights as numpy arrays.

        Raises:
            ValueError: if ``vocab`` is not given.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocab")

        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)

        alphas = []
        captions = []
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1, -1)
        embeds = self.embedding(word)

        for _ in range(max_len):
            alpha, context = self.attention(features, h)
            alphas.append(alpha.detach().cpu().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h)).view(batch_size, -1)

            predicted_word_idx = output.argmax(dim=1)
            token_id = predicted_word_idx.item()
            captions.append(token_id)
            if vocab.itos[token_id] == "<EOS>":
                break

            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        return [vocab.itos[idx] for idx in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Initial (h, c): linear projections of ``encoder_out`` averaged over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        return self.init_h(mean_encoder_out), self.init_c(mean_encoder_out)

In [None]:


class EncoderDecoder(nn.Module):
    """Final model: EncoderRNN over the concatenated caption ids + attentive DecoderRNN.

    Args:
        embed_size, vocab_size, attention_dim, decoder_dim: forwarded to the decoder.
        encoder_dim: size of the encoder outputs; also used as the encoder LSTM's
            ``hid_dim`` so the two halves always agree.
        drop_prob: decoder dropout probability.
    """

    def __init__(self, embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim, drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderRNN(
            vocab_size=vocab_size,
            hid_dim=encoder_dim,
        )
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob,
        )

    def forward(self, images, captions):
        """``images`` holds input token ids (name kept for call compatibility)."""
        features = self.encoder(images)
        return self.decoder(features, captions)



In [None]:
# Hyper-parameters of the final EncoderDecoder.
# encoder_dim is the size of the encoder outputs consumed by the decoder.
embed_size, encoder_dim = 300, 300
attention_dim, decoder_dim = 256, 512
learning_rate = 4e-4
vocab_size = len(dataset.vocab)

In [None]:
# Build the final model, a loss that ignores <PAD> positions, and the optimizer.
model = EncoderDecoder(
    embed_size=embed_size, vocab_size=vocab_size, attention_dim=attention_dim,
    encoder_dim=encoder_dim, decoder_dim=decoder_dim,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def save_model(model, num_epochs, path='./drive/MyDrive/full_model.pth'):
    """Save ``model`` together with the hyper-parameters needed to rebuild it.

    The hyper-parameters are read from ``model.decoder`` instead of notebook globals, so
    the saved metadata always matches the saved weights. The keys are the ones
    ``ModelDataset`` reads back when it rebuilds a model from a checkpoint.

    Args:
        model: encoder-decoder whose ``decoder`` exposes ``embedding``, ``attention``,
            ``init_h``, ``vocab_size`` and ``decoder_dim`` (as the decoders here do).
        num_epochs: number of epochs trained so far, stored for reference.
        path: destination file.
    """
    decoder = model.decoder
    model_state = {
        'num_epochs': num_epochs,
        'embed_size': decoder.embedding.embedding_dim,
        'vocab_size': decoder.vocab_size,
        'attention_dim': decoder.attention.attention_dim,
        'encoder_dim': decoder.init_h.in_features,
        'decoder_dim': decoder.decoder_dim,
        'state_dict': model.state_dict()
    }

    torch.save(model_state, path)

In [None]:
# METEOR scoring utilities for the training-loop samples (used by score()).
import nltk
from nltk.translate import meteor
from nltk import word_tokenize
# NLTK data packages: punkt for word_tokenize, wordnet for meteor's synonym matching.
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [None]:
def strip_special_tokens(text):
    """Remove the vocabulary's special tokens (<PAD>, <SOS>, <EOS>, <UNK>) and collapse whitespace."""
    for token in ("<PAD>", "<SOS>", "<EOS>", "<UNK>"):
        text = text.replace(token, " ")
    return " ".join(text.split())


def score(text, text_sum):
    """METEOR score of a generated caption against a reference, rounded to 4 decimals.

    Args:
        text: reference string; special tokens in it are ignored.
        text_sum: generated caption; any non-string value scores 0.

    Special tokens are stripped from both sides. Otherwise markers such as ``<EOS>``
    that appear in both strings would count as matched words and inflate the score.
    """
    if not isinstance(text_sum, str):
        return 0
    reference = word_tokenize(strip_special_tokens(text))
    hypothesis = word_tokenize(strip_special_tokens(text_sum))
    return round(meteor([reference], hypothesis), 4)

In [None]:
num_epochs = 25
print_every = 1

torch.cuda.empty_cache()

for epoch in tqdm(range(1, num_epochs + 1)):
    losses = []
    print(f"Epoch: {epoch}")
    for idx, (image, captions) in tqdm(enumerate(data_loader), total=len(data_loader)):
        image, captions = image.to(device), captions.to(device)

        optimizer.zero_grad()
        outputs, _attentions = model(image, captions)
        # outputs[:, s] predicts captions[:, s + 1], hence the shift by one.
        loss = criterion(outputs.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if idx % print_every == 0:
            model.eval()
            with torch.no_grad():
                # Show the first sample of the current batch (already on the device). Drawing a
                # new batch from data_loader here would re-run the expensive ModelDataset
                # pipeline (video decoding + two caption models) for a whole extra batch.
                features = model.encoder(image[0:1])
                caps, alphas = model.decoder.generate_caption(features, vocab=dataset.vocab)
                caption = ' '.join(caps)
                target = " ".join(dataset.vocab.itos[i] for i in captions[0, 1:].tolist())
                print("true:", target)
                print("pred:", caption)
                print("loss:", loss.item(), "score:", score(target, caption))
            model.train()

    if losses:
        print("Epoch loss: {:.5f}".format(sum(losses) / len(losses)))
    else:
        print(losses)
    # Save the latest model after every epoch.
    save_model(model, epoch)

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch: 1


  0%|          | 0/10 [00:00<?, ?it/s]

loss: 8.901432037353516 score: 0.0068
loss: 8.893943786621094 score: 0.0094
loss: 8.875504493713379 score: 0.046
loss: 8.854141235351562 score: 0.0402
loss: 8.827407836914062 score: 0.1352
loss: 8.782187461853027 score: 0.0628
loss: 8.722457885742188 score: 0.0724
loss: 8.648955345153809 score: 0.0
loss: 8.485517501831055 score: 0.041
loss: 8.282970428466797 score: 0.042
Epoch loss: 8.72745
Epoch: 2


  0%|          | 0/10 [00:00<?, ?it/s]

loss: 8.075401306152344 score: 0.041
loss: 7.861804962158203 score: 0.0325
loss: 7.66754150390625 score: 0.028
loss: 7.648908615112305 score: 0.0443
loss: 7.377035617828369 score: 0.0296
loss: 7.347804546356201 score: 0.0185
loss: 7.27073335647583 score: 0.0191
loss: 7.206053733825684 score: 0.0371
loss: 7.223572731018066 score: 0.0274
loss: 7.175075531005859 score: 0.02
Epoch loss: 7.48539
Epoch: 3


  0%|          | 0/10 [00:00<?, ?it/s]

loss: 7.117039203643799 score: 0.0154
loss: 6.984583377838135 score: 0.0179
loss: 7.100799560546875 score: 0.1198
loss: 6.887745380401611 score: 0.098
loss: 7.0797624588012695 score: 0.0948
loss: 7.117587089538574 score: 0.0484
loss: 7.187944412231445 score: 0.132
loss: 7.162382125854492 score: 0.0693
loss: 7.169386386871338 score: 0.184
loss: 6.922194004058838 score: 0.1211
Epoch loss: 7.07294
Epoch: 4


  0%|          | 0/10 [00:00<?, ?it/s]

loss: 6.942858695983887 score: 0.2231
loss: 6.921104431152344 score: 0.0692
loss: 6.883472442626953 score: 0.0359
loss: 6.960726737976074 score: 0.0314


In [None]:
2