## Imports

In [1]:
import sys, pdb, os
import os.path as osp

from PIL import Image
import pickle

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence, pad_sequence

from vist_api.vist import *
from hyperparams import *
from bean_search import *

## Vocabulary

In [2]:
# a vocabulary calss adapted from 11731 assignment 1 starter code
# https://phontron.com/class/mtandseq2seq2019/assignments.html
from collections import Counter
import torch
import pickle
import pdb

class Vocabulary():
    def __init__(self, sents, freq_cutoff=1):
        self.w2i = {"<s>": 0, "</s>": 1, "<unk>": 2, "<pad>": 3}
        self.i2w = {v: k for k, v in self.w2i.items()}
        self.unk_id = 2
        self.sents = sents
        self.cutoff = freq_cutoff
        self.build()

    def build(self):
        # Start a counter and only include words that appear frequently.
        # freq_cutoff is to be set to 1, until we have a different tokenization method.
        word_freq = Counter()
        for sent in self.sents:
            word_freq["<s>"] += 1
            for word in sent.split():
                word_freq[word] += 1
            word_freq["</s>"] += 1
        
        # Gather valid words that pass cutoff and add them to the respective dictionaries
        valid_words = [w for w, v in word_freq.items() if v >= self.cutoff]
        valid_words = list(set(valid_words)) # Verifying one of each word
        for word in valid_words:
            wid = len(self.w2i)
            self.w2i[word] = wid
            self.i2w[wid] = word

    def __getitem__(self, word):
        pdb.set_trace()
        return self.w2i.get(word, self.unk_id)

    def __contains__(self, word):
        return word in self.w2i

    def __len__(self):
        return len(self.w2i)

    def sent2vec(self, sent, tokenized=False):
        if not tokenized:
            sent = sent.split()
        return torch.tensor([self.w2i[w] for w in sent]).type(torch.LongTensor)

    def vec2sent(self, sent):
        result = [self.i2w[i] for i in sent]
        return " ".join(result)

## Globals

In [3]:
vocab_save_path = "vocab.pt"
vist_annotations_dir = './vist_api/'
images_dir = './vist_api/images/'
sis_train = Story_in_Sequence(images_dir + "train", vist_annotations_dir)
# sis_val = Story_in_Sequence(images_dir+"val", vist_annotations_dir)
# sis_test = Story_in_Sequence(images_dir+"test", vist_annotations_dir)

cuda = True
cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

# build/read vocabulary
corpus = []
for story in sis_train.Stories:
    sent_ids = sis_train.Stories[story]['sent_ids']
    for sent_id in sent_ids:
        corpus.append(sis_train.Sents[sent_id]['text'])
vocab = Vocabulary(corpus, freq_cutoff=1) # reads and builds
pickle.dump(vocab, open(vocab_save_path, 'wb'))

Make mapping ...
Mapping for [Albums][Images][Stories][Sents] done.
2741 stories remaining.


## Model

In [4]:
class fc7_Extractor(nn.Module):
    def __init__(self, fine_tune=False):
        super(fc7_Extractor, self).__init__()
        self.pretrained = models.vgg16(pretrained=True)
        self.fine_tune(fine_tune)

    def forward(self, x):
        x = self.pretrained.features(x)
        x = self.pretrained.avgpool(x)
        x = torch.flatten(x, 1)
        x = nn.Sequential(*list(self.pretrained.classifier.children())[:-1])(x)
        return x

    def fine_tune(self, fine_tune):
        if not fine_tune:
            for p in self.pretrained.parameters():
                p.requires_grad = False

                
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fc7 = fc7_Extractor()
        self.gru = nn.GRU(FC7_SIZE, HIDDEN_SIZE)

    # batch * 5 * 3 * w * h
    def forward(self, images, hidden):
        batch_size, num_pics, channels, width, height = images.size()
        embedded = torch.zeros((num_pics, batch_size, FC7_SIZE)).to(device)
        for i in range(num_pics):
            batch_i = images[:, -(i+1), :, :, :]  # ith pics
            features = self.fc7(batch_i)  # out shape:batch * 5 * 4096
            embedded[i, :, :] = features  # shape: num_pic * batch * 4096
        output, hidden = self.gru(embedded, hidden)
        # output: num_pic, batch, 1000
        # hidden: 1, batch, 1000
        return output, hidden


class Decoder(nn.Module):
    def __init__(self, vocab_size):
        super(Decoder, self).__init__()
        self.hidden_size = HIDDEN_SIZE
        self.embedding = nn.Embedding(vocab_size, EMBEDDING_SIZE, padding_idx=3)
        self.gru = nn.GRU(EMBEDDING_SIZE, HIDDEN_SIZE)

    def forward(self, padded_stories, hidden, lens):
        padded_stories = self.embedding(padded_stories)
        packed_stories = pack_padded_sequence(padded_stories, lens, enforce_sorted=False)
        output, hidden = self.gru(packed_stories, hidden)
        return output, hidden


class BaselineModel(nn.Module):
    def __init__(self, vocab):
        super(BaselineModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(vocab_size=len(vocab))
        self.vocab = vocab
        self.out_layer = nn.Linear(HIDDEN_SIZE, len(vocab))
        self.vocab_length = len(vocab)
#         self.logSoftmax = nn.LogSoftmax(dim=1)
#         self.loss = nn.NLLLoss()  # default mean
                                
#     def get_decoded_output(self, decoder_input, hidden, lens):
#         output, hidden = self.decoder(decoder_input, hidden, lens)
#         output, _ = pad_packed_sequence(output)
#         output = self.out_layer(output)
#         # output = output.view(output.size()[0], -1)
#         return output, hidden

    def forward(self, images, stories, story_lens):
        batch_size = images.size(0)
        hidden_1 = torch.rand(1, batch_size, HIDDEN_SIZE).to(device)
        out, hidden = self.encoder(images, hidden_1)
        out, hidden = self.decoder(stories, hidden, story_lens)
        n_tokens = story_lens.sum() - story_lens.size(0)
        
        loss = 0.0
        # https://pytorch.org/docs/stable/nn.html#ctcloss
        criterion = nn.CTCLoss(blank=3)
        
        out, out_lens = pad_packed_sequence(out)
        out = self.out_layer(out) 
        stories = stories.t()
        input_lengths = out_lens
        target_lengths = story_lens # flip these?
        loss = criterion(out, stories, input_lengths, target_lengths)
        
#         for i in range(out.size()[0]-1):
#             active = i + 1 < story_lens
#             loss += criterion(out[i, active,: ], stories[i+1, active])
#             pdb.set_trace()
            
#         loss /= n_tokens
        
        return loss, out


## Dataset

In [5]:
# Build dataset
class StoryDataset(Dataset):
    def __init__(self, sis, vocab):
        self.sis = sis
        self.story_indices = list(self.sis.Stories.keys())
        self.vocab = vocab

    def __len__(self):
        return 100
        return len(self.sis.Stories)

    @staticmethod
    def read_image(path):
        img = Image.open(path)
        img = torchvision.transforms.Resize((224, 224))(img)
        img = torchvision.transforms.ToTensor()(img)
        return img

    def __getitem__(self, idx):
        story_id = self.story_indices[idx]
        story = self.sis.Stories[story_id]
        sent_ids = story['sent_ids']
        img_ids = story['img_ids']
        imgs = []
        for i, img_id in enumerate(img_ids):
            img_file = osp.join(self.sis.images_dir, img_id + '.jpg')
            img_tensor = self.read_image(img_file)
            imgs.append(img_tensor)
        imgs = torch.stack(imgs)
        
        sent = ""
        for sent_id in sent_ids:
            # Add a space for the sentence, probably want to just remove puncuation
            sent += " " + self.sis.Sents[sent_id]["text"] 
        
        sent_tensor = self.vocab.sent2vec("<s> " + sent + " </s>")
        return imgs, sent_tensor


def collate_story(seq_list):
    imgs, sents = zip(*seq_list)
    imgs = torch.stack(imgs)
    sents_len = torch.Tensor([len(sent) for sent in sents])
    sents = pad_sequence(sents, padding_value=3)
    return imgs, sents, sents_len

## Main

In [6]:
def train(epochs, model, train_dataloader, optimizer):
    model.train()
    model.to(device)
    for epoch in range(epochs):
        avg_loss = 0
        for batch_num, (images, sents, sents_len) in enumerate(train_dataloader):
            optimizer.zero_grad()
            
            # Process data and put on device
            images = images.float()
            sents = sents.long()
            sents_len = sents_len.long()
            images, sents, sents_len = images.to(device), sents.to(device), sents_len.to(device)
            
            # Run through model
            loss, output = model(images, sents, sents_len)
            
#             greedy_decode(model, images, device, vocab)
            # comment out to see the current greedy decoded story
            avg_loss += loss.item()
            print(loss.item())
            
            loss.backward()
            optimizer.step()

            if batch_num % PRINT_LOSS == PRINT_LOSS-1:
                print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch + 1, batch_num + 1, avg_loss / 50))
                avg_loss = 0.0

        # torch.save(model.state_dict(), model_path + "/"+ str(epoch) +".pt")
        
def main():
    torch.cuda.empty_cache()
    
    train_story_set = StoryDataset(sis_train, vocab)
    # val_story_set = StoryDataset(sis_val, vocab)
    # test_story_set = StoryDataset(sis_test, vocab)

    train_loader = DataLoader(train_story_set, shuffle=False, batch_size=BATCH_SIZE, collate_fn=collate_story)
    # imgs of shape [BS, 5, 3, 224, 224]
    # sents BS * 5  * MAX_LEN

    baseline_model = BaselineModel(vocab)
    optimizer = torch.optim.Adam(baseline_model.parameters(), lr=0.01)

    train(1, baseline_model, train_loader, optimizer)
    
main()

0.00036869384348392487
-0.0013964317040517926
7.115257263183594
inf
nan
nan
nan
nan
nan
nan
Epoch: 1	Batch: 10	Avg-Loss: nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
Epoch: 1	Batch: 20	Avg-Loss: nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
Epoch: 1	Batch: 30	Avg-Loss: nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
Epoch: 1	Batch: 40	Avg-Loss: nan
nan
nan
nan
nan


KeyboardInterrupt: 