In [None]:
import numpy as np
import pickle
!pip install bcolz
import bcolz



In [None]:
# Extract the Flickr-8K archive (verbose file listing) into the current working directory.
! tar -xvf drive/MyDrive/Flickr-8K.tar

In [None]:
import nltk
nltk.__version__

'3.2.5'

In [None]:
import os
import torch
from PIL import Image
import pandas as pd
import spacy
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
spacy_eng = spacy.load("en")

In [None]:
class Vocabulary:
    """Bidirectional word <-> integer-id mapping for caption text.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Ordinary words receive ids from 4 upwards, in
    the order in which their running count first reaches ``freq_thresh``.
    """

    def __init__(self, freq_thresh):
        self.freq_thresh = freq_thresh
        special_tokens = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
        self.itos = dict(enumerate(special_tokens))
        self.stoi = {token: index for index, token in enumerate(special_tokens)}

    def __len__(self):
        # Number of known tokens, special tokens included.
        return len(self.itos)

    @staticmethod
    def tokenizer(text):
        """Split ``text`` with the spaCy tokenizer and lowercase every token."""
        lowered = []
        for tok in spacy_eng.tokenizer(text):
            lowered.append(tok.text.lower())
        return lowered

    def build_vocab(self, sentence_list):
        """Register every word whose count across ``sentence_list`` reaches ``freq_thresh``."""
        counts = {}
        next_index = 4  # ids 0-3 belong to the special tokens
        for sentence in sentence_list:
            for token in self.tokenizer(sentence):
                counts[token] = counts.get(token, 0) + 1
                if counts[token] != self.freq_thresh:
                    continue
                # The count has just hit the threshold: assign the next free id.
                self.stoi[token] = next_index
                self.itos[next_index] = token
                next_index += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of ids; unknown words map to ``<UNK>``."""
        unknown_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unknown_id) for token in self.tokenizer(text)]

In [None]:
# for test and train we need to create new text files of the format (Flickr8k.token.txt)
class FlickrDataset(Dataset):
    """Flickr8k-style captioning dataset: one item per (image, caption) pair.

    Args:
        root_dir: directory that contains the image files.
        caption_file: text file whose lines look like
            ``<image_name>#<n> <caption words...>``.
        img_file: text file listing one image name per line; only these
            images (and their captions) are part of the dataset.
        freq_thresh: minimum word count passed to ``Vocabulary``.
        transform: optional callable applied to each loaded PIL image.

    Raises:
        KeyError: if ``img_file`` names an image with no caption in
            ``caption_file``.
    """

    def __init__(self, root_dir, caption_file, img_file, freq_thresh, transform=None):
        self.freq_thresh = freq_thresh
        self.transform = transform
        self.root_dir = root_dir
        self.img_file = img_file
        self.caption_file = caption_file

        self.caption_dict = self.imgId_caption_dict()
        self.imgs, self.captions = self.load_img_caption()
        self.vocab = Vocabulary(self.freq_thresh)
        self.vocab.build_vocab(self.captions)

    def imgId_caption_dict(self):
        """Parse ``caption_file`` into ``{image_name: [caption, ...]}``."""
        caption_dict = {}
        with open(self.caption_file, 'r') as f:
            for line in f:
                temp = line.split()
                if not temp:
                    continue  # tolerate blank lines (e.g. a trailing newline)
                # First field is "<image_name>#<caption_number>"; keep the name only.
                img_name = temp[0].split('#')[0]
                # Re-join with spaces so the caption keeps its word boundaries.
                description = " ".join(temp[1:])
                caption_dict.setdefault(img_name, []).append(description)
        return caption_dict

    def load_img_caption(self):
        """Return parallel lists of image names and captions for ``img_file``.

        Each image name is repeated once per caption it has.
        """
        imgs = []
        captions = []
        with open(self.img_file, 'r') as f:
            for line in f:
                img_id = line.strip()
                if not img_id:
                    continue  # skip blank lines instead of raising KeyError('')
                for caption in self.caption_dict[img_id]:
                    imgs.append(img_id)
                    captions.append(caption)
        return imgs, captions

    def __len__(self):
        # One dataset item per caption, not per image.
        return len(self.captions)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for the caption at ``index``.

        The caption is wrapped as ``<SOS> ... <EOS>`` and returned as a 1-D
        tensor of vocabulary ids.
        """
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption += [self.vocab.stoi["<EOS>"]]

        return img, torch.tensor(numericalized_caption)



In [None]:
# we can also define simply a function instead of a class
class MyCollate:
    """Collate function object that batches images and pads captions.

    Args:
        pad_idx: id used to pad captions up to the longest one in the batch.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` pairs into two batched tensors."""
        image_list = []
        caption_list = []
        for image, caption in batch:
            image_list.append(image)
            caption_list.append(caption)
        # Stacking adds a new leading batch dimension: (batch, *image_shape).
        imgs = torch.stack(image_list, dim=0)
        # batch_first=False -> padded captions come out as (max_len, batch).
        captions = pad_sequence(
            caption_list, batch_first=False, padding_value=self.pad_idx
        )
        return imgs, captions

In [None]:
def get_loader(root_dir, caption_file, img_file, transform, batch=32, num_worker=2, shuffle=True, pin_memory=True, freq_thresh=5):
    """Build a ``DataLoader`` over a ``FlickrDataset``.

    Args:
        root_dir: directory that contains the image files.
        caption_file: path of the caption file.
        img_file: path of the file listing the images to include.
        transform: callable applied to every loaded image.
        batch: batch size.
        num_worker: number of loader worker processes.
        shuffle: whether to reshuffle the data every epoch.
        pin_memory: forwarded to ``DataLoader``.
        freq_thresh: minimum word count for a word to enter the vocabulary
            (previously hard-coded to 5, which remains the default).

    Returns:
        The ``DataLoader``. The underlying dataset, and therefore its
        vocabulary, is reachable as ``loader.dataset.vocab``.
    """
    dataset = FlickrDataset(root_dir=root_dir, caption_file=caption_file, img_file=img_file, freq_thresh=freq_thresh, transform=transform)
    # Captions in a batch are padded with the vocabulary's <PAD> id.
    pad_idx = dataset.vocab.stoi["<PAD>"]
    loader = DataLoader(dataset=dataset,
                        batch_size=batch,
                        shuffle=shuffle,
                        collate_fn=MyCollate(pad_idx),
                        pin_memory=pin_memory,
                        num_workers=num_worker)
    return loader

In [None]:
# Image preprocessing: resize every image to 224x224, then convert it to a tensor.
trans = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Loader over the images listed in the test split file, with default loader settings.
dataloader = get_loader(
    root_dir="Flickr-8K/Flicker8k_Dataset",
    caption_file="Flickr-8K/Flickr8k.token.txt",
    img_file="Flickr-8K/Flickr_8k.testImages.txt",
    transform=trans,
)

In [None]:
# Print the image and caption tensor shapes of the first six batches, then stop.
for idx, (img, caption) in enumerate(dataloader):
    print(img.shape, caption.shape, sep="\n")
    if idx >= 5:
        break

## Embedding layer using GloVe

In [None]:
# Unpack the GloVe 6B word-vector archive into the current working directory.
! unzip drive/MyDrive/glove.6B.zip

Archive:  drive/MyDrive/glove.6B.zip
  inflating: glove.6B.50d.txt        
  inflating: glove.6B.100d.txt       
  inflating: glove.6B.200d.txt       
  inflating: glove.6B.300d.txt       


In [None]:
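# One-time preprocessing, kept commented out: it converts glove.6B.300d.txt into a bcolz array plus pickled word/index lookups. Run it once; later sessions only need to load the saved files.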


# words = []
# idx = 0
# word2idx = {}
# vectors = bcolz.carray(np.zeros(1), rootdir=f'drive/MyDrive/6B.300.dat', mode='w')

# with open(f'glove.6B.300d.txt', 'rb') as f:
#     for l in f:
#         line = l.decode().split()
#         word = line[0]
#         words.append(word)
#         word2idx[word] = idx
#         idx += 1
#         vect = np.array(line[1:]).astype(np.float)
#         vectors.append(vect)
    
# vectors = bcolz.carray(vectors[1:].reshape((400000, 300)), rootdir=f'drive/MyDrive/6B.300.dat', mode='w')
# vectors.flush()
# pickle.dump(words, open(f'drive/MyDrive/6B.300_words.pkl', 'wb'))
# pickle.dump(word2idx, open(f'drive/MyDrive/6B.300_idx.pkl', 'wb'))

In [None]:
# Load the preprocessed GloVe artefacts: the full vector array, the word list,
# and the word -> row-index lookup.
vectors = bcolz.open('drive/MyDrive/6B.300.dat')[:]

# NOTE: pickle can execute arbitrary code on load; only use files you created yourself.
# The context managers close the file handles, which the bare open() calls did not.
with open('drive/MyDrive/6B.300_words.pkl', 'rb') as words_file:
    words = pickle.load(words_file)
with open('drive/MyDrive/6B.300_idx.pkl', 'rb') as index_file:
    word2idx = pickle.load(index_file)

# word -> embedding vector
glove = {w: vectors[word2idx[w]] for w in words}

In [None]:
# Sanity check: look up one word and display the shape of its vector.
glove['the'].shape

(300,)

In [None]:
# Get the flickr Dataset and access the vocab from it

In [None]:
# Build the embedding weight matrix for the caption vocabulary.
# `target_vocab` was never defined (this cell raised NameError); take it from the
# dataset's vocabulary, ordered by word id so that row i belongs to word id i.
caption_vocab = dataloader.dataset.vocab
target_vocab = [caption_vocab.itos[i] for i in range(len(caption_vocab))]

matrix_len = len(target_vocab)  # target vocabulary
weights_matrix = np.zeros((matrix_len, 300))
words_found = 0

for i, word in enumerate(target_vocab):
    if word in glove:
        weights_matrix[i] = glove[word]
        words_found += 1
    else:
        # No pretrained vector for this word (e.g. the special tokens):
        # fall back to a random initialisation.
        weights_matrix[i] = np.random.normal(scale=0.6, size=(300, ))

NameError: ignored

In [None]:
def create_emb_layer(weights_matrix, non_trainable=False):
    """Create an ``nn.Embedding`` initialised from a pretrained weight matrix.

    Args:
        weights_matrix: 2-D ``(num_embeddings, embedding_dim)`` NumPy array or
            tensor. It is converted to a float32 tensor, so the float64 array
            built with ``np.zeros`` in this notebook is accepted (calling
            ``.size()`` on it directly raised a TypeError).
        non_trainable: when True, the embedding weights are frozen.

    Returns:
        ``(emb_layer, num_embeddings)``.

    Raises:
        ValueError: if ``weights_matrix`` is not 2-D.
    """
    weights = torch.as_tensor(weights_matrix, dtype=torch.float32)
    if weights.dim() != 2:
        raise ValueError(
            "weights_matrix must be 2-D (num_embeddings, embedding_dim), got shape %s"
            % (tuple(weights.shape),)
        )
    num_embeddings, embedding_dim = weights.shape
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)
    emb_layer.load_state_dict({'weight': weights})
    if non_trainable:
        emb_layer.weight.requires_grad = False

    return emb_layer, num_embeddings

## Code for Encoder CNN

In [None]:
from torchvision import models
class EncoderCNN(nn.Module):
    """Image encoder: a pretrained ResNet-152 whose final layer is replaced.

    The ``fc`` layer is swapped for a new ``nn.Linear`` that outputs
    ``embed_size`` features. Only that layer's weight and bias stay
    trainable; every other parameter is frozen.
    """

    def __init__(self, embed_size):
        super().__init__()
        backbone = models.resnet152(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, embed_size)
        self.model = backbone
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        for name, param in self.model.named_parameters():
            # Train only the replacement head; freeze the rest.
            param.requires_grad = ("fc.weight" in name) or ("fc.bias" in name)

    def forward(self, X):
        """Return the dropout-regularised, ReLU-activated image features for ``X``."""
        features = self.model(X)
        activated = self.relu(features)
        return self.dropout(activated)

In [None]:
class DecoderRNN(nn.Module):
    """
    Caption decoder: a single LSTM cell wrapped with input and output
    attention over a per-image set of attribute word embeddings.

    Takes as input image vectors and, per image, the vocabulary ids of its
    attribute words.

    Shape conventions (B = batch, A = attributes per image, D = word size):
      * attribute embeddings: B x A x D
      * word embeddings:      B x D
      * hidden / cell state:  B x hidden_size

    Args:
        embedding_weights: optional (vocab_size, D) array or tensor of
            pretrained word vectors. When omitted, the notebook-level
            ``weights_matrix`` built in an earlier cell is used.
    """
    def __init__(self, embedding_weights=None):
        nn.Module.__init__(self)
        if embedding_weights is None:
            # Fall back to the matrix built earlier in the notebook.
            embedding_weights = weights_matrix
        # Accept NumPy arrays as well as tensors.
        embedding_weights = torch.as_tensor(embedding_weights, dtype=torch.float32)

        self.input_size = 512  # m = input_size in paper
        self.hidden_size = 512 # n = hidden_size in paper
        self.word_size = embedding_weights.shape[1]  # 300 for the GloVe matrix used here
        self.image_size = 1024
        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size) # using single cell rather than lstm
        self.U = nn.Linear(self.word_size, self.word_size)
        self.V = nn.Linear(self.hidden_size, self.word_size)
        self.W_Y_A = nn.Linear(self.word_size, self.hidden_size)
        self.W_x_Y = nn.Linear(self.word_size, self.input_size)
        self.W_Y_h = nn.Linear(self.hidden_size, self.word_size)
        self.image_mapping = nn.Linear(self.image_size, self.input_size)
        # Registered as a Parameter so it is trained, saved and moved with the module.
        self.diag_weight = nn.Parameter(torch.randn(self.word_size))

        # use the pretrained (frozen) embedding
        self.vocab, self.num_embedding = create_emb_layer(embedding_weights, True)
        self.soft = nn.Softmax(dim=1)
        self.tanh = nn.Tanh()
        self.ReLU = nn.ReLU(inplace = True)

    @property
    def input_diag_mat(self):
        """Diagonal matrix view of ``diag_weight``, rebuilt on every access."""
        return torch.diag(self.diag_weight)

    def input_attention(self, previous_word, attributes):
        """
        Apply input attention at each step.
        previous_word -> B x D embedding of the previously emitted word
        attributes -> B x A x D attribute embeddings
        Returns the next LSTM input (B x input_size) and the attention
        weights over the attributes (B x A).
        """
        query = self.U(previous_word)  # B x D
        score = torch.bmm(attributes, query.unsqueeze(2)).squeeze(2)  # B x A
        score = self.soft(score)  # B x A, each row sums to 1
        weighted_y = torch.bmm(score.unsqueeze(1), attributes).squeeze(1)  # B x D
        # Element-wise scaling equals multiplication by diag(diag_weight).
        scaled_y = weighted_y * self.diag_weight  # B x D
        final_y = scaled_y + previous_word  # B x D
        x = self.W_x_Y(final_y)  # B x input_size
        return x, score

    def output_attention(self, hidden_state, attributes):
        """
        Apply output attention at each step.
        hidden_state -> B x hidden_size, current hidden state of the LSTMCell
        attributes -> B x A x D attribute embeddings
        Returns logits over the attributes (B x A) and the attention weights
        over the attributes (B x A).
        """
        activated = self.tanh(attributes)  # B x A x D
        query = self.V(hidden_state)  # B x D
        score = torch.bmm(activated, query.unsqueeze(2)).squeeze(2)  # B x A
        score = self.soft(score)  # B x A
        weighted_y = torch.bmm(score.unsqueeze(1), activated).squeeze(1)  # B x D
        y_to_h = self.W_Y_A(weighted_y)  # B x hidden_size
        hidden = hidden_state + y_to_h  # B x hidden_size
        hidden = self.W_Y_h(hidden)  # B x D
        logits = torch.bmm(attributes, hidden.unsqueeze(2)).squeeze(2)  # B x A
        return logits, score

    def prepare_attributes(self, attr):
        """Look up the embeddings of the attribute ids: (B x A) -> (B x A x D)."""
        ids = torch.as_tensor(attr, device=self.vocab.weight.device).long()
        return self.vocab(ids)

    def next_word(self, logits, attributes=None):
        """
        Samples uniformly from the top 3 candidates and returns the embedding
        of the chosen word (B x D).

        logits -> B x A scores over the attributes
        attributes -> optional B x A x D attribute embeddings. When given,
        the chosen position selects that image's attribute embedding. When
        omitted, the position is looked up directly in the embedding table,
        as the original code did.
        """
        batch_size, no_of_candidates = logits.shape
        top_k = min(3, no_of_candidates)  # fewer than 3 attributes is allowed
        # Softmax is monotonic, so sorting the raw logits gives the same ranking.
        _, indices = torch.sort(logits, dim=1, descending=True)
        rows = torch.arange(batch_size, device=logits.device)
        picks = torch.randint(top_k, (batch_size,), device=logits.device)
        chosen = indices[rows, picks]  # B, position of the sampled candidate
        if attributes is None:
            return self.vocab(chosen)
        return attributes[rows, chosen]  # B x D

    def forward(self, seq_len, image_vectors, attributes, caption_ids=None):
        """
        image_vectors -> batch x img_size
        attributes -> batch x no_of_attr ids of attribute words in our dictionary
        (e.g. 20 per image); it should also contain <PAD>, <SOS>, <EOS>.
        caption_ids -> accepted for teacher forcing, but teacher forcing is not
        implemented yet, so the argument is currently ignored.

        Returns (Hidden_logits, Hidden_scores, Input_scores), each of shape
        batch x seq_len x no_of_attr.
        """
        image_vectors = self.ReLU(self.image_mapping(image_vectors)) # batch x 512
        attributes = self.prepare_attributes(attributes)  # batch x no_of_attr x D
        batch_size = image_vectors.shape[0]
        no_of_attr = attributes.shape[1]

        # new_zeros keeps the buffers on the same device/dtype as the inputs.
        Hidden_logits = image_vectors.new_zeros(batch_size, seq_len, no_of_attr)
        Hidden_scores = image_vectors.new_zeros(batch_size, seq_len, no_of_attr)
        Input_scores = image_vectors.new_zeros(batch_size, seq_len, no_of_attr)

        # The image is the first LSTM input; the word sampled from it seeds the loop.
        hidden, cell = self.lstm(image_vectors) # batch x 512
        logits, hidden_score = self.output_attention(hidden, attributes)
        word = self.next_word(logits, attributes)  # batch x D

        for i in range(seq_len):
            # prepare the input using input attention
            next_input, input_score = self.input_attention(word, attributes)
            Input_scores[:, i, :] = input_score
            hidden, cell = self.lstm(next_input, (hidden, cell))
            logits, hidden_score = self.output_attention(hidden, attributes)
            Hidden_scores[:, i, :] = hidden_score
            Hidden_logits[:, i, :] = logits
            word = self.next_word(logits, attributes)  # batch x D

        return Hidden_logits, Hidden_scores, Input_scores


In [None]:
# loss function
class MyLoss(nn.Module):
    """Crossentropy + regularization of the two attention-score tensors.

    Per sample, the loss is the cross-entropy summed over the sequence plus
    ``regularize(Hidden_scores)`` plus ``regularize(Input_scores)``.

    Args:
        reduction: how per-sample losses are combined. ``"mean"`` (default)
            and ``"sum"`` return a scalar that can be passed to
            ``backward()``; ``"none"`` returns the per-sample vector.

    Raises:
        ValueError: if ``reduction`` is not one of the three values above.
    """
    def __init__(self, reduction="mean"):
        nn.Module.__init__(self)
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("reduction must be 'mean', 'sum' or 'none', got %r" % (reduction,))
        self.reduction = reduction
        # Keep per-token losses so they can be summed over the sequence per sample.
        self.criterion = nn.CrossEntropyLoss(reduction="none")

    def forward(self, output, target, Hidden_scores, Input_scores):
        """
        output -> batch x seq x no_of_attr logits
        target -> batch x seq class indices
        Hidden_scores, Input_scores -> batch x seq x no_of_attr attention weights
        """
        batch_size, seq_len, no_of_classes = output.shape
        # CrossEntropyLoss wants the class dimension second, so flatten batch
        # and time, score every token, then restore batch x seq.
        token_loss = self.criterion(output.reshape(-1, no_of_classes), target.reshape(-1))
        token_loss = token_loss.reshape(batch_size, seq_len)  # batch x seq_len
        # No .item() here: it would detach the loss from the autograd graph.
        per_sample = torch.sum(token_loss, dim=1)  # batch
        per_sample = per_sample + self.regularize(Hidden_scores) + self.regularize(Input_scores)
        if self.reduction == "mean":
            return per_sample.mean()
        if self.reduction == "sum":
            return per_sample.sum()
        return per_sample

    def regularize(self, scores):
        """Attention regulariser; ``scores`` is batch x time x attr, result is batch."""
        # batch x time x i(attr index)
        s1 = torch.sum(scores, dim=1)  # total attention per attribute over time
        s1 = s1 * s1
        s1 = torch.sqrt(torch.sum(s1, dim=1)) # sum_i_()^{0.5}

        s2 = torch.sqrt(scores)
        s2 = torch.sum(s2, dim=2)
        s2 = s2 * s2
        s2 = torch.sum(s2, dim=1)  # sum_time_()^{2}
        return s1 + s2


## Training Loop

In [None]:
# Training setup: loss, image encoder and caption decoder.
NUM_EPOCHS = 100
criterion = MyLoss()
encoder = EncoderCNN(embed_size=1024)  # matches the decoder's image_size of 1024
decoder = DecoderRNN()

# Skeleton only: iterates over every batch of every epoch without training yet.
for epoch in range(NUM_EPOCHS):
    for idx, (img, caption) in enumerate(dataloader):
        # TODO: write training loop
        continue