## Definitions

In [None]:
import torch
import torch.utils.data as data
import torch.nn as nn
import torchvision.models as models
import nltk
import pickle
import os.path
from pycocotools.coco import COCO
from collections import Counter
import math
import os
from PIL import Image
import numpy as np
from tqdm import tqdm
import random
import json
import warnings

class Vocabulary(object):
    """Token <-> integer-id mapping built from caption annotations.

    The mapping is either restored from a pickle file or built from the
    captions in ``annotations_file``; a freshly built vocabulary is written
    back to ``vocab_file``.  Calling the instance with a token returns its id,
    falling back to the id of ``unk_word`` for unknown tokens.
    """

    def __init__(self,
        vocab_threshold,
        vocab_file='./vocab.pkl',
        start_word=0,
        end_word=1,
        unk_word="<unk>",
        annotations_file='annotations/captions_train2014.json',
        vocab_from_file=False):
        """Store the configuration and immediately load or build the vocabulary.

        Args:
            vocab_threshold: minimum number of occurrences a token needs to be
                added to the vocabulary.
            vocab_file: path of the pickle file to load from / save to.
            start_word: token marking the start of a caption.
            end_word: token marking the end of a caption.
            unk_word: token standing in for out-of-vocabulary tokens.
            annotations_file: COCO captions annotation file used when building.
            vocab_from_file: if true and ``vocab_file`` exists, load instead
                of rebuilding.
        """
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.annotations_file = annotations_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()

    def get_vocab(self):
        """Load the vocabulary from file OR build the vocabulary from scratch."""
        if self.vocab_from_file and os.path.exists(self.vocab_file):
            with open(self.vocab_file, 'rb') as f:
                # NOTE(security): unpickling can execute arbitrary code; only
                # load vocabulary files that this project produced itself.
                vocab = pickle.load(f)
            self.word2idx = vocab.word2idx
            self.idx2word = vocab.idx2word
            # Restore the next free id so add_word() keeps working after a
            # load; ids are handed out sequentially starting at 0.
            self.idx = getattr(vocab, 'idx', len(self.word2idx))
            print('Vocabulary successfully loaded from vocab.pkl file!')
        else:
            self.build_vocab()
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)

    def build_vocab(self):
        """Populate the dictionaries for converting tokens to integers (and vice-versa)."""
        self.init_vocab()
        # The special tokens are added first, so they receive ids 0, 1 and 2.
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Initialize the dictionaries for converting tokens to integers (and vice-versa)."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Add a token to the vocabulary (no-op if it is already present)."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        """Loop over training captions and add all tokens to the vocabulary that meet or exceed the threshold."""
        coco = COCO(self.annotations_file)
        counter = Counter()
        ann_ids = coco.anns.keys()
        total = len(ann_ids)
        for i, ann_id in enumerate(ann_ids):
            caption = str(coco.anns[ann_id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions..." % (i, total))

        # Counter preserves first-seen order, which determines the ids.
        for word, cnt in counter.items():
            if cnt >= self.vocab_threshold:
                self.add_word(word)

    def __call__(self, word):
        """Return the id of ``word``, or the id of ``unk_word`` if unknown."""
        if word not in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.word2idx)

def get_loader(transform,
               mode='train',
               batch_size=1,
               vocab_threshold=None,
               vocab_file='./vocab.pkl',
               start_word=0,
               end_word=1,
               unk_word="<unk>",
               vocab_from_file=True,
               num_workers=0,
               cocoapi_loc=''):
    """Build a ``DataLoader`` over the COCO captioning data for one split.

    Args:
        transform: image transform applied by the dataset to every image.
        mode: one of ``'train'``, ``'val'`` or ``'test'``.
        batch_size: batch size (must be 1 in test mode).
        vocab_threshold: minimum token count used when building a vocabulary.
        vocab_file: path of the pickled vocabulary.
        start_word: start-of-caption token.
        end_word: end-of-caption token.
        unk_word: out-of-vocabulary token.
        vocab_from_file: load the vocabulary from ``vocab_file`` instead of
            building it.
        num_workers: number of DataLoader worker processes.
        cocoapi_loc: directory containing the ``images/`` and
            ``annotations/`` folders.

    Returns:
        A ``torch.utils.data.DataLoader``.  In train/val mode batches are
        produced by ``collate_fn``; in test mode the default collation is used.

    Raises:
        ValueError: if ``mode`` is not one of the supported splits.
        AssertionError: if the vocabulary / batch-size preconditions of the
            selected mode are not met.
    """
    # (image folder, annotation file) of each split, relative to cocoapi_loc.
    split_paths = {
        'train': ('images/train2014/', 'annotations/captions_train2014.json'),
        'val': ('images/val2014/', 'annotations/captions_val2014.json'),
        'test': ('images/test2014/', 'annotations/image_info_test2014.json'),
    }
    if mode not in split_paths:
        # Fail early with a clear message instead of an unbound-name error.
        raise ValueError("mode must be one of 'train', 'val' or 'test', got %r" % (mode,))

    if mode == 'test':
        assert batch_size==1, "Please change batch_size to 1 if testing your model."
        assert os.path.exists(vocab_file), "Must first generate vocab.pkl from training data."
        assert vocab_from_file==True, "Change vocab_from_file to True."
    elif vocab_from_file:
        assert os.path.exists(vocab_file), "vocab_file does not exist.  Change vocab_from_file to False to create vocab_file."

    img_dir, ann_path = split_paths[mode]
    img_folder = os.path.join(cocoapi_loc, img_dir)
    annotations_file = os.path.join(cocoapi_loc, ann_path)

    # COCO caption dataset.
    dataset = CoCoDataset(transform=transform,
                          mode=mode,
                          batch_size=batch_size,
                          vocab_threshold=vocab_threshold,
                          vocab_file=vocab_file,
                          start_word=start_word,
                          end_word=end_word,
                          unk_word=unk_word,
                          annotations_file=annotations_file,
                          vocab_from_file=vocab_from_file,
                          img_folder=img_folder)

    if mode == 'test':
        return data.DataLoader(dataset=dataset,
                               batch_size=dataset.batch_size,
                               shuffle=True,
                               num_workers=num_workers)

    # Train / val: captions in a batch have different lengths, so collate_fn
    # pads them.  num_workers is honoured here as well (it used to be
    # silently ignored for these splits).
    return data.DataLoader(dataset,
                           batch_size=batch_size,
                           shuffle=True,
                           collate_fn=collate_fn,
                           num_workers=num_workers)

class CoCoDataset(data.Dataset):
    """COCO captioning dataset.

    In ``'train'`` / ``'val'`` mode each item is ``(image, caption)`` where the
    caption is a tensor of token ids wrapped in the start and end tokens.  In
    any other mode each item is ``(original image as array, transformed
    image)`` for the files listed in the annotations file.
    """

    def __init__(self, transform, mode, batch_size, vocab_threshold, vocab_file, start_word, 
        end_word, unk_word, annotations_file, vocab_from_file, img_folder):
        """Load the vocabulary and index the annotations of the chosen split.

        Args:
            transform: callable applied to each PIL image.
            mode: ``'train'`` / ``'val'`` for captioned data, anything else
                for image-only data.
            batch_size: number of indices returned by ``get_train_indices``.
            vocab_threshold, vocab_file, start_word, end_word, unk_word,
            vocab_from_file: forwarded to ``Vocabulary``.
            annotations_file: annotation JSON of the split.
            img_folder: directory containing the split's images.
        """
        self.transform = transform
        self.mode = mode
        self.batch_size = batch_size
        self.vocab = Vocabulary(vocab_threshold, vocab_file, start_word,
            end_word, unk_word, annotations_file, vocab_from_file)
        self.img_folder = img_folder
        if self.mode == 'train' or self.mode == 'val':
            self.coco = COCO(annotations_file)
            self.ids = list(self.coco.anns.keys())
            print('Obtaining caption lengths...')
            # Only the token counts are needed, so the token lists themselves
            # are not kept in memory.
            self.caption_lengths = [
                len(nltk.tokenize.word_tokenize(str(self.coco.anns[ann_id]['caption']).lower()))
                for ann_id in tqdm(self.ids)
            ]
        else:
            # Close the annotation file deterministically instead of leaking
            # the handle.
            with open(annotations_file) as f:
                test_info = json.load(f)
            self.paths = [item['file_name'] for item in test_info['images']]

    def __getitem__(self, index):
        """Return the item at ``index`` (see the class docstring for its layout)."""
        # obtain image and caption if in training mode
        if self.mode == 'train' or self.mode == 'val':
            annotation = self.coco.anns[self.ids[index]]
            caption = annotation['caption']
            img_id = annotation['image_id']
            path = self.coco.loadImgs(img_id)[0]['file_name']

            # Convert image to tensor and pre-process using transform
            image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
            image = self.transform(image)

            # Convert caption to tensor of word ids: <start> tokens... <end>.
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())
            token_ids = [self.vocab(self.vocab.start_word)]
            token_ids.extend(self.vocab(token) for token in tokens)
            token_ids.append(self.vocab(self.vocab.end_word))
            # Kept as a float tensor, exactly as before; collate_fn copies the
            # ids into a long tensor when it builds the batch.
            caption = torch.Tensor(token_ids)

            # return pre-processed image and caption tensors
            return image, caption

        # obtain image if in test mode
        path = self.paths[index]

        # Convert image to tensor and pre-process using transform
        PIL_image = Image.open(os.path.join(self.img_folder, path)).convert('RGB')
        orig_image = np.array(PIL_image)
        image = self.transform(PIL_image)

        # return original image and pre-processed image tensor
        return orig_image, image

    def get_train_indices(self):
        """Sample ``batch_size`` indices of captions that share one length.

        A caption length is drawn at random from ``caption_lengths``, then
        indices with that length are drawn (with replacement).
        """
        sel_length = np.random.choice(self.caption_lengths)
        # Vectorised comparison instead of building a Python list of booleans.
        all_indices = np.flatnonzero(np.asarray(self.caption_lengths) == sel_length)
        indices = list(np.random.choice(all_indices, size=self.batch_size))
        return indices

    def __len__(self):
        """Return the number of captions (train/val) or images (other modes)."""
        if self.mode == 'train' or self.mode == 'val':
            return len(self.ids)
        return len(self.paths)

        

        
import sys
nltk.download('punkt')
#from data_loader import get_loader
from torchvision import transforms

# Training-time preprocessing: random crop + random flip for augmentation,
# followed by tensor conversion and per-channel normalization.
transform_train = transforms.Compose([
    # smaller edge of image resized to 256
    transforms.Resize(size=256),
    # get 224x224 crop from random location
    transforms.RandomCrop(size=224),
    # horizontally flip image with probability=0.5
    transforms.RandomHorizontalFlip(p=0.5),
    # convert the PIL Image to a tensor
    transforms.ToTensor(),
    # normalize image for pre-trained model
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Evaluation-time preprocessing: same resize and normalization, but with a
# center crop and no random flip.
transform_test = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])



class Encoder(nn.Module):
    """Image encoder built from a pretrained ``torchvision.models`` backbone.

    The backbone's last two children are dropped (the original comment
    describes them as the pooling and linear layers) and the remaining
    feature map is pooled to a fixed spatial grid.
    """

    def __init__(self, model_type, encoded_image_size=14, fine_tune=False):
        """
        :param model_type: name of a model constructor in ``torchvision.models``
        :param encoded_image_size: side length of the pooled feature grid
        :param fine_tune: passed to :meth:`fine_tune`
        """
        super().__init__()
        self.enc_image_size = encoded_image_size

        backbone = getattr(models, model_type)(pretrained=True)

        # Keep everything except the last two children of the backbone.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # Pool to a fixed grid so input images may have variable size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune(fine_tune)
        # Not used by forward(); it is still created here, so it is part of
        # the module's parameters.
        self.embed = nn.Linear(backbone.fc.in_features, encoded_image_size)

    def forward(self, images):
        """Encode ``images`` into a sequence of grid-cell features.

        The pooled feature map has its two spatial dimensions (dims 2 and 3)
        merged, and that merged dimension is moved to the front, giving a
        tensor laid out as (grid cells, batch, channels).
        """
        pooled = self.adaptive_pool(self.resnet(images))
        cells = pooled.flatten(2, 3)
        return cells.permute(2, 0, 1)

    def fine_tune(self, fine_tune=False):
        """
        Freeze the backbone, then allow or prevent gradients for the
        sub-modules of its last child only.
        :param fine_tune: Allow?
        """
        for param in self.resnet.parameters():
            param.requires_grad = False
        # Only the sub-modules of the final child are (optionally) unfrozen.
        final_child = list(self.resnet.children())[-1]
        for sub_module in final_child:
            for param in sub_module.parameters():
                param.requires_grad = fine_tune


class MLP_init(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear).

    Maps vectors of size ``encoder_hidden_size`` to vectors of size
    ``decoder_hidden_size``.
    """

    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super().__init__()

        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size

        layers = [
            nn.Linear(encoder_hidden_size, decoder_hidden_size),
            nn.ReLU(),
            nn.Linear(decoder_hidden_size, decoder_hidden_size),
        ]
        self.init_MLP = nn.Sequential(*layers)

    def forward(self, h):
        """Apply the perceptron to ``h`` (last dimension ``encoder_hidden_size``)."""
        projected = self.init_MLP(h)
        return projected


class ScaledDotAttention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections."""

    def __init__(self, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.Q = nn.Linear(hidden_size, hidden_size)
        self.K = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, hidden_size)
        # Scores are laid out as (batch, keys, queries), so normalizing over
        # dim 1 normalizes over the keys.
        self.softmax = nn.Softmax(dim=1)
        # 1 / sqrt(hidden_size)
        self.scaling_factor = torch.rsqrt(torch.tensor(self.hidden_size, dtype= torch.float))

    def forward(self, queries, keys, values):
        """The forward pass of the scaled dot attention mechanism.
        Arguments:
            queries: 2D or 3D tensor whose first dimension is the batch and whose last dimension is hidden_size. (batch_size x (k) x hidden_size)
            keys: 3D tensor. (batch_size x seq_len x hidden_size)
            values: 3D tensor. (batch_size x seq_len x hidden_size)
        Returns:
            context: weighted average of the projected values (batch_size x k x hidden_size)
            attention_weights: softmax over the seq_len keys for each query (batch_size x seq_len x k)
        """
        n_batch = queries.shape[0]
        # A 2D query batch becomes (batch_size x 1 x hidden_size).
        queries_3d = queries.view(n_batch, -1, queries.shape[-1])

        projected_q = self.Q(queries_3d)
        projected_k = self.K(keys)
        projected_v = self.V(values)

        scores = torch.matmul(projected_k, projected_q.transpose(2, 1)) * self.scaling_factor
        attention_weights = self.softmax(scores)
        context = torch.matmul(attention_weights.transpose(2, 1), projected_v)
        return context, attention_weights
        

class CausalScaledDotAttention(nn.Module):
    """Scaled dot-product attention in which a query cannot see later keys.

    Scores are laid out as (batch, keys, queries); entries whose key index is
    greater than the query index are replaced by a large negative value
    before the softmax over the keys.
    """

    def __init__(self, hidden_size):
        super(CausalScaledDotAttention, self).__init__()

        self.hidden_size = hidden_size
        # Large negative fill value standing in for -infinity.
        self.neg_inf = torch.tensor(-1e7)

        self.Q = nn.Linear(hidden_size, hidden_size)
        self.K = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, hidden_size)
        self.softmax = nn.Softmax(dim=1)
        # 1 / sqrt(hidden_size)
        self.scaling_factor = torch.rsqrt(torch.tensor(self.hidden_size, dtype= torch.float))

    def forward(self, queries, keys, values):
        """Compute causal attention.

        Arguments:
            queries: 2D or 3D tensor, batch first, last dimension hidden_size.
            keys: 3D tensor (batch_size x seq_len x hidden_size).
            values: 3D tensor (batch_size x seq_len x hidden_size).
        Returns:
            context: (batch_size x num_queries x hidden_size)
            attention_weights: (batch_size x seq_len x num_queries), a softmax
                over the keys each query is allowed to see.
        """
        batch_size = queries.shape[0]
        q = self.Q(queries.view(batch_size, -1, queries.shape[-1]))
        k = self.K(keys)
        v = self.V(values)
        unnormalized_attention = k@q.transpose(2,1)*self.scaling_factor

        # Build the mask from POSITIONS only.  Deriving it from the score
        # values (``~triu(scores).bool()``) also masks any allowed entry whose
        # score happens to be exactly 0 and, for all-zero scores, masks
        # everything - which makes the softmax uniform over future keys.
        num_keys, num_queries = unnormalized_attention.shape[-2:]
        key_pos = torch.arange(num_keys, device=unnormalized_attention.device).unsqueeze(1)
        query_pos = torch.arange(num_queries, device=unnormalized_attention.device).unsqueeze(0)
        mask = key_pos > query_pos  # True where the key lies in the query's future

        attention_weights = self.softmax(unnormalized_attention.masked_fill(mask, self.neg_inf))
        context = attention_weights.transpose(2,1)@v
        return context, attention_weights



class TransformerDecoder(nn.Module):
    """Stack of decoder layers: causal self-attention, encoder attention, MLP.

    Each of the ``num_layers`` layers runs ``num_heads`` independent
    attention modules, concatenates their outputs along the feature
    dimension and projects them back to ``hidden_size``.  Every sub-layer is
    followed by dropout, a residual connection and layer normalization.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, dropout):
        """
        Arguments:
            vocab_size: number of token ids (embedding rows and output logits).
            hidden_size: model width; must be even for the positional encodings.
            num_layers: number of decoder layers.
            num_heads: attention modules per layer, each of width hidden_size.
            dropout: dropout probability.
        """
        super(TransformerDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, hidden_size)        
        self.num_layers = num_layers
        self.num_heads = num_heads
        
        self.self_attentions = nn.ModuleList([nn.ModuleList([CausalScaledDotAttention(
                                    hidden_size=hidden_size, 
                                 ) for i in range(self.num_heads)]) for j in range(self.num_layers)])
        self.encoder_attentions = nn.ModuleList([nn.ModuleList([ScaledDotAttention(
                                    hidden_size=hidden_size, 
                                 ) for i in range(self.num_heads)]) for j in range(self.num_layers)])
        self.attention_mlps = nn.ModuleList([nn.Sequential(
                                    nn.Linear(hidden_size, hidden_size),
                                    nn.ReLU(),
                                 ) for i in range(self.num_layers)])
        
        # Project the concatenated heads back to hidden_size.
        self.linear_after_causal = nn.ModuleList([nn.Linear(self.num_heads*hidden_size, hidden_size) for j in range(self.num_layers)])
        self.linear_after_scaled = nn.ModuleList([nn.Linear(self.num_heads*hidden_size, hidden_size) for j in range(self.num_layers)])

        self.out = nn.Linear(hidden_size, vocab_size)

        # Plain tensor attribute (not a parameter or buffer), so it is not
        # part of the state_dict; forward() moves it to the input's device.
        self.positional_encodings = self.create_positional_encodings()

        self.dropout = nn.Dropout(p=dropout)

        self.layernorms1 = nn.ModuleList([nn.LayerNorm([self.hidden_size]) for i in range(self.num_layers)])
        self.layernorms2 = nn.ModuleList([nn.LayerNorm([self.hidden_size]) for i in range(self.num_layers)])
        self.layernorms3 = nn.ModuleList([nn.LayerNorm([self.hidden_size]) for i in range(self.num_layers)])

    def forward(self, inputs, annotations):
        """Run the decoder over a batch of token sequences.

        Arguments:
            inputs: token ids, batch_size x seq_len (cast to long internally).
            annotations: encoder outputs attended to by every layer; used as
                both keys and values (batch_size x num_annotations x hidden_size).
        Returns:
            output: logits, batch_size x seq_len x vocab_size.
            (encoder_attention_weights, self_attention_weights): per-layer
                attention weights stacked along a new first dimension.  Only
                the weights of the LAST head of each layer are kept.
        """
        batch_size, seq_len = inputs.size()
        inputs = inputs.long()
        embed = self.embedding(inputs)  # batch_size x seq_len x hidden_size
        # Follow the device of the activations instead of assuming CUDA.
        embed = embed + self.positional_encodings[:seq_len].to(embed.device)
        embed = self.dropout(embed)

        encoder_attention_weights_list = []
        self_attention_weights_list = []
        contexts = embed

        for i in range(self.num_layers):
            # --- causal self-attention over the tokens decoded so far ---
            head_outputs = []
            for head in self.self_attentions[i]:
                new_contexts, self_attention_weights = head(contexts, contexts, contexts)  # batch_size x seq_len x hidden_size
                head_outputs.append(new_contexts)
            # batch_size x seq_len x hidden_size*num_heads -> batch_size x seq_len x hidden_size
            new_contexts = self.linear_after_causal[i](torch.cat(head_outputs, dim=2))
            new_contexts = self.dropout(new_contexts)
            residual_contexts = self.layernorms1[i](contexts + new_contexts)  # add and norm

            # --- attention over the encoder annotations ---
            head_outputs = []
            for head in self.encoder_attentions[i]:
                new_contexts, encoder_attention_weights = head(residual_contexts, annotations, annotations)  # batch_size x seq_len x hidden_size
                head_outputs.append(new_contexts)
            new_contexts = self.linear_after_scaled[i](torch.cat(head_outputs, dim=2))
            new_contexts = self.dropout(new_contexts)
            residual_contexts = self.layernorms2[i](residual_contexts + new_contexts)  # add and norm

            # --- position-wise MLP ---
            new_contexts = self.attention_mlps[i](residual_contexts)
            new_contexts = self.dropout(new_contexts)
            contexts = self.layernorms3[i](residual_contexts + new_contexts)  # add and norm

            encoder_attention_weights_list.append(encoder_attention_weights)
            self_attention_weights_list.append(self_attention_weights)

        output = self.out(contexts)
        encoder_attention_weights = torch.stack(encoder_attention_weights_list)
        self_attention_weights = torch.stack(self_attention_weights_list)

        return output, (encoder_attention_weights, self_attention_weights)

    def create_positional_encodings(self, max_seq_len=1000):
        """Creates positional encodings for the inputs.
        Arguments:
          max_seq_len: a number larger than the maximum string length we expect to encounter during training
        Returns:
          pos_encodings: (max_seq_len, hidden_dim) Positional encodings for a sequence with length max_seq_len.
              Even columns hold the sine terms and odd columns the cosine terms.
              The tensor is placed on the GPU when one is available, otherwise on the CPU.
        """
        pos_indices = torch.arange(max_seq_len)[..., None]
        dim_indices = torch.arange(self.hidden_size//2)[None, ...]
        exponents = (2*dim_indices).float()/(self.hidden_size)
        trig_args = pos_indices / (10000**exponents)
        sin_terms = torch.sin(trig_args)
        cos_terms = torch.cos(trig_args)

        pos_encodings = torch.zeros((max_seq_len, self.hidden_size))
        pos_encodings[:, 0::2] = sin_terms
        pos_encodings[:, 1::2] = cos_terms

        # Only move to the GPU when there is one; an unconditional .cuda()
        # made the decoder impossible to construct on CPU-only machines.
        if torch.cuda.is_available():
            pos_encodings = pos_encodings.cuda()

        return pos_encodings

class EncoderDecoder(nn.Module):
    """Image-captioning model: an image encoder followed by a token decoder.

    In training mode ``forward`` returns teacher-forced logits; in eval mode
    it runs beam search.  The decoder is always called as
    ``decoder(token_ids, annotations)`` with batch-first tensors and must
    return ``(logits, extra)``; ``decoder_type`` and ``cell_type`` are stored
    for compatibility but do not change how the decoder is called.
    """

    def __init__(
            self, encoder_class, decoder_class,
            target_vocab_size, target_sos=-2, target_eos=-1, encoder_type='resnet18', fine_tune=False, encoder_hidden_size=512,
            decoder_hidden_size=1024, word_embedding_size=1024, attention_dim=512, cell_type='lstm', decoder_type='rnn', beam_width=4, dropout=0.0,
            transformer_layers=3, num_heads=1):
        """Store the hyper-parameters and build the encoder and decoder.

        ``encoder_class`` is instantiated as
        ``encoder_class(encoder_type, fine_tune=fine_tune)`` and
        ``decoder_class`` as ``decoder_class(target_vocab_size,
        encoder_hidden_size, transformer_layers, num_heads, dropout)``.
        ``target_sos`` / ``target_eos`` are the start / end token ids and
        ``beam_width`` is the number of beams kept during search.
        """
        super().__init__()
        self.target_vocab_size = target_vocab_size
        self.target_sos = target_sos
        self.target_eos = target_eos
        self.encoder_type = encoder_type
        self.fine_tune = fine_tune
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.word_embedding_size = word_embedding_size
        self.attention_dim = attention_dim
        self.cell_type = cell_type
        self.decoder_type = decoder_type
        self.beam_width = beam_width
        self.dropout = dropout
        self.transformer_layers = transformer_layers
        self.num_heads = num_heads
        self.encoder = self.decoder = None
        self.init_submodules(encoder_class, decoder_class)

    def init_submodules(self, encoder_class, decoder_class):
        """Instantiate ``self.encoder`` and ``self.decoder``."""
        self.encoder = encoder_class(self.encoder_type, fine_tune=self.fine_tune)

        # The decoder width is the encoder's feature width.
        self.decoder = decoder_class(self.target_vocab_size, 
                                self.encoder_hidden_size,
                                self.transformer_layers,
                                self.num_heads,
                                self.dropout)

    def get_target_padding_mask(self, E):
        '''Determine what parts of a target sequence batch are padding
        `E` is right-padded with end-of-sequence symbols. This method
        creates a mask of those symbols, excluding the first in every sequence
        (the first eos symbol should not be excluded in the loss).
        Parameters
        ----------
        E : torch.LongTensor
            A tensor of shape ``(T - 1, N)``, where ``E[t', n]`` is
            the ``t'``-th token id of a gold-standard transcription for the
            ``n``-th source sequence. *Should* exclude the initial
            start-of-sequence token.
        Returns
        -------
        pad_mask : torch.BoolTensor
            A boolean tensor of shape ``(T - 1, N)``, where ``pad_mask[t, n]``
            is :obj:`True` when ``E[t, n]`` is considered padding.
        '''
        is_eos = E == self.target_eos  # (T - 1, N)
        # A position is padding when it AND its predecessor are both eos.
        prev_is_eos = torch.cat([is_eos[:1], is_eos[:-1]], 0)
        return is_eos & prev_is_eos

    def forward(self, images, captions=None, max_T=100, on_max='raise'):
        """Encode ``images``, then return logits (training) or beams (eval).

        ``captions`` is only used in training mode; ``max_T`` and ``on_max``
        are only used in eval mode (see :meth:`beam_search`).
        """
        h = self.encoder(images)  # (L, N, H)
        if self.training:
            return self.get_logits_for_teacher_forcing(h, captions)
        return self.beam_search(h, max_T, on_max)

    def get_logits_for_teacher_forcing(self, h, captions):
        """Return decoder logits for the gold captions.

        ``captions`` is ``(T, N)``; its last time step is dropped and the
        rest is fed batch-first to the decoder together with the batch-first
        encoder states.
        """
        logits, _ = self.decoder(captions[:-1, :].T, h.permute(1, 0, 2))
        return logits

    def beam_search(self, h, max_T, on_max):
        """Decode with beam search.

        Parameters
        ----------
        h : encoder states of shape ``(S, N, L)``.
        max_T : maximum number of decoding steps.
        on_max : ``'raise'`` to raise ``RuntimeError`` when ``max_T`` is
            reached, ``'halt'`` to print a message and return what was found.

        Returns
        -------
        A float tensor of token ids of shape ``(t + 1, N, K)`` whose first
        row is the start token; beam ``k = 0`` is the best path.
        '''
        """
        assert not self.training
        if on_max not in ('raise', 'halt'):
            # Any other value used to let the loop run past max_T unbounded.
            raise ValueError("on_max must be 'raise' or 'halt', got %r" % (on_max,))

        N, K = h.shape[1], self.beam_width
        # Only beam 0 is alive at the start; the others get -inf log-prob.
        logpb_tm1 = torch.full((N, K), -float('inf'), device=h.device)
        logpb_tm1[:, 0] = 0.
        # Token history, (t, N, K).  Kept as float, as before.
        b_tm1_1 = torch.full(
            (1, N, K), self.target_sos, dtype=torch.float, device=h.device)

        # We treat each beam within the batch as just another batch when
        # computing logits, then recover the original batch dimension by
        # reshaping
        h = h.unsqueeze(2).repeat(1, 1, K, 1)
        h = h.flatten(1, 2)  # (S, N * K, L)
        annotations = h.permute(1, 0, 2)  # (N * K, S, L), loop-invariant
        v_is_eos = torch.arange(self.target_vocab_size, device=h.device)
        v_is_eos = v_is_eos == self.target_eos  # (V,)

        t = 0
        while torch.any(b_tm1_1[-1, :, 0] != self.target_eos):
            if t == max_T:
                if on_max == 'raise':
                    raise RuntimeError(
                        f'Beam search has not finished by t={t}. Increase the '
                        f'number of parameters and train longer')
                print(f'Beam search not finished by t={t}. Halted')
                break
            finished = (b_tm1_1[-1] == self.target_eos)

            # Rebuild the decoder input from the CURRENT beam histories.
            # update_beam() reorders histories whenever beams are re-ranked,
            # so accumulating the last tokens column by column would pair new
            # tokens with prefixes of different beams.
            decoder_input = b_tm1_1.flatten(1, 2).transpose(0, 1).contiguous()  # (N * K, t + 1)
            op, _ = self.decoder(decoder_input, annotations)
            logits_t = op[:, -1, :].view(
                -1, K, self.target_vocab_size)  # (N, K, V)
            logpy_t = nn.functional.log_softmax(logits_t, -1)
            # We length-normalize the extensions of the unfinished paths
            if t:
                logpb_tm1 = torch.where(
                    finished, logpb_tm1, logpb_tm1 * (t / (t + 1)))
                logpy_t = logpy_t / (t + 1)
            # Finished paths may only be extended by eos, at no cost.
            logpy_t = logpy_t.masked_fill(
                finished.unsqueeze(-1) & v_is_eos, 0.)
            logpy_t = logpy_t.masked_fill(
                finished.unsqueeze(-1) & (~v_is_eos), -float('inf'))

            # No recurrent state is carried between steps: the decoder
            # re-reads the whole prefix.  (The former recurrent-state branch
            # read an undefined variable and could never run.)
            b_tm1_1, logpb_tm1 = self.update_beam(
                None, b_tm1_1, logpb_tm1, logpy_t)
            t += 1
        return b_tm1_1

    def update_beam(self, htilde_t, b_tm1_1, logpb_tm1, logpy_t):
        """Extend every beam by one token and keep the best ``K`` extensions.

        Parameters
        ----------
        htilde_t : ``None``, or per-beam state ``(N, K, H)`` (a pair of such
            tensors when ``cell_type == 'lstm'``) to reorder with the beams.
        b_tm1_1 : token histories, ``(t, N, K)``.
        logpb_tm1 : log-probabilities of the current beams, ``(N, K)``.
        logpy_t : log-probabilities of the next token, ``(N, K, V)``.

        Returns
        -------
        ``(b_t_1, logpb_t)`` when ``htilde_t`` is ``None``, otherwise
        ``(b_t_0, b_t_1, logpb_t)`` where ``b_t_0`` is the reordered state.
        """
        V = logpy_t.shape[2]  # Vocab size
        K = logpy_t.shape[1]  # Beam width

        # Score of every (beam, token) extension, flattened to (N, K * V).
        extension_scores = logpb_tm1.unsqueeze(-1).expand_as(logpy_t) + logpy_t
        logpb_t, top_k_ind = torch.topk(torch.flatten(extension_scores, 1, 2), K, dim=1)
        source_beam = top_k_ind // V  # which beam each winner extends
        next_token = top_k_ind % V

        # Reorder the histories to follow the winners, then append the tokens
        # (cast explicitly so the history keeps its dtype).
        history = torch.gather(b_tm1_1, 2, source_beam.expand_as(b_tm1_1))
        b_t_1 = torch.cat((history, next_token.unsqueeze(0).to(b_tm1_1.dtype)))

        if htilde_t is None:
            return b_t_1, logpb_t

        if self.cell_type == 'lstm':
            index = source_beam.unsqueeze(-1).expand_as(htilde_t[0])
            b_t_0 = (torch.gather(htilde_t[0], 1, index), torch.gather(htilde_t[1], 1, index))
        else:
            index = source_beam.unsqueeze(-1).expand_as(htilde_t)
            b_t_0 = torch.gather(htilde_t, 1, index)
        return b_t_0, b_t_1, logpb_t


from torch.nn.utils.rnn import pad_sequence

def train_for_epoch(model, dataloader, optimizer, device, n_iter, epoch, losses):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: captioning model, either bare or wrapped so that the real
            model is available as ``model.module``.
        dataloader: yields ``(images, captions, cap_lens)`` batches, where
            ``cap_lens`` holds the true caption lengths.
        optimizer: optimizer updating the model parameters.
        device: device the batch is moved to.
        n_iter: global step counter before this epoch.
        epoch: epoch number, used for logging and checkpoint names.
        losses: list that receives the running loss every 10 steps.

    Returns:
        ``(mean loss per caption, updated n_iter, losses)``.  The mean is
        ``nan`` if the dataloader produced no batches.

    Note:
        Gradient clipping reads the module-level ``grad_clip`` setting.
    """
    # Works with and without a DataParallel-style wrapper.
    core = getattr(model, 'module', model)

    criterion1 = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
    total_loss = 0.0
    total_num = 0
    checkpoint_dir = './checkpoints1'

    for batch in tqdm(dataloader):
        images, captions, cap_lens = batch
        captions = pad_sequence(captions, padding_value=core.target_eos) #(seq_len, batch_size)
        images, captions = images.to(device), captions.to(device)
        optimizer.zero_grad()
        logits = model(images, captions).permute(1, 0, 2)

        targets = captions[1:]
        mask = core.get_target_padding_mask(targets)
        # The batch arrives already padded by collate_fn (with zeros, not
        # eos), so the eos-based mask alone does not cover the padding.
        # Mask every position at or beyond each caption's true length too.
        lengths = torch.as_tensor(cap_lens, device=targets.device)
        positions = torch.arange(1, captions.shape[0], device=targets.device).unsqueeze(1)
        mask = mask | (positions >= lengths.unsqueeze(0))
        targets = targets.masked_fill(mask, -1)

        loss = criterion1(torch.flatten(logits, 0, 1), torch.flatten(targets))
        total_loss += float(loss.item())
        total_num += len(cap_lens)
        loss.backward()
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)
        optimizer.step()
        n_iter += 1
        torch.cuda.empty_cache()

        if n_iter % 10 == 0:
            # Perplexity is exp of the mean loss per target token; exp of the
            # summed loss overflows.
            n_tokens = (~mask).sum().clamp(min=1)
            perplexity = torch.exp(loss.detach() / n_tokens).item()
            # Get training statistics.
            stats = 'Epoch %d, Step %d, Loss: %.4f, Perplexity: %5.4f' % (epoch, n_iter, total_loss/total_num, perplexity)
            print(stats, "\n")
            losses.append(total_loss/total_num)

        if n_iter % 2000 == 0:
            # torch.save does not create missing directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(core.decoder.state_dict(), os.path.join(checkpoint_dir, 'transformerdecoder-epoch-%d-step-%d.pkl'%(epoch,n_iter)))
        del images
        del captions

    mean_loss = total_loss / total_num if total_num else float('nan')
    return mean_loss, n_iter, losses


def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element in place to ``[-grad_clip, grad_clip]``.

    Walks all parameters in ``optimizer.param_groups``; parameters without a
    gradient are left untouched.
    """
    lower, upper = -grad_clip, grad_clip
    for param_group in optimizer.param_groups:
        for tensor in param_group['params']:
            if tensor.grad is None:
                continue
            tensor.grad.data.clamp_(lower, upper)


                
def collate_fn(data):
    """Build mini-batch tensors from a list of ``(image, caption)`` pairs.

    ``data`` is sorted in place by caption length, longest first. Images are
    stacked along a new leading axis and captions are copied into a
    zero-padded ``long`` tensor.

    Returns:
        ``(images, targets, lengths)`` where ``lengths`` is a tensor holding
        the caption lengths in the sorted order.
    """
    data.sort(key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*data)

    # Tuple of 3D tensors -> one 4D tensor.
    batch_images = torch.stack(images, 0)

    # Tuple of 1D tensors -> one zero-padded 2D tensor.
    lengths = [len(caption) for caption in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (caption, size) in enumerate(zip(captions, lengths)):
        targets[row, :size] = caption[:size]

    return batch_images, targets, torch.tensor(lengths)
                

## Training

In [None]:
# ---- Training configuration -------------------------------------------------
encoder_class = Encoder
decoder_class = TransformerDecoder
encoder_type = 'resnet50'
decoder_type = 'transformer' #transformer, rnn
warmup_steps = 4000
n_iter = 1
CNN_channels = 2048
fine_tune = False

epoch = 0
max_epochs = 4
beam_width = 4

print("Epochs are read correctly: ", max_epochs)
print("Encoder type is read correctly: ", encoder_type)
print("Number of CNN channels being used: ", CNN_channels)
print("Fine tune setting is set to: ", fine_tune)


word_embedding_size = 512
attention_dim = 512
model_save_path = './model_saves/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lamda = 1.


# print("Label smoothing set to: ", bool(1))   
learning_rate = 0.00004
decoder_hidden_size = CNN_channels
dropout = 0.1

batch_size = 64
batch_size_val = 64
grad_clip = 5.  # read as a global by train_for_epoch
transformer_layers = 3
heads = 2
beta1 = 0.9
beta2 = 0.98
mode = 'train'
losses = []
vocab_threshold = 5

# ---- Data -------------------------------------------------------------------
# Use the batch_size setting above rather than repeating the literal.
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=False,
                         num_workers=4)
vocab_size = len(data_loader.dataset.vocab)

os.makedirs(model_save_path, exist_ok=True)

# ---- Model and optimizer ----------------------------------------------------
model = EncoderDecoder(encoder_class, decoder_class, vocab_size, target_sos=0,
                       target_eos=1, fine_tune=fine_tune, encoder_type=encoder_type,
                       encoder_hidden_size=CNN_channels,
                       decoder_hidden_size=decoder_hidden_size,
                       word_embedding_size=word_embedding_size, attention_dim=attention_dim,
                       decoder_type=decoder_type, cell_type='lstm', beam_width=beam_width,
                       dropout=dropout,
                       transformer_layers=transformer_layers, num_heads=heads)

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    model = torch.nn.DataParallel(model, device_ids=[0]).to(device)
else:
    # train_for_epoch reads model.module, so the wrapper is needed on CPU too;
    # without visible GPUs DataParallel just forwards to the wrapped module.
    model = torch.nn.DataParallel(model).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, beta2))

# ---- Training loop ----------------------------------------------------------
# NOTE(review): the inclusive bound runs max_epochs + 1 passes (epochs
# 0..max_epochs); kept as-is because later cells load an "epoch-4" checkpoint.
while epoch <= max_epochs:
    model.train()
    loss, n_iter, losses = train_for_epoch(model, data_loader, optimizer, device, n_iter, epoch, losses)
    print(f'Epoch {epoch}: loss={loss}')
    epoch += 1
    # The loss history is pickled despite the .txt suffix; the with-block
    # closes the file, so no explicit close() is needed.
    with open("losses_after_%d_epochs.txt"%epoch, "wb") as loss_file:
        pickle.dump(losses, loss_file)
        

## Testing

In [None]:
# ---- Evaluation configuration -----------------------------------------------
encoder_class = Encoder
decoder_class = TransformerDecoder
encoder_type = 'resnet50'
decoder_type = 'transformer' #transformer, rnn
warmup_steps = 4000
n_iter = 1
CNN_channels = 2048
fine_tune = False

epoch = 0
max_epochs = 4
beam_width = 4

print("Epochs are read correctly: ", max_epochs)
print("Encoder type is read correctly: ", encoder_type)
print("Number of CNN channels being used: ", CNN_channels)
print("Fine tune setting is set to: ", fine_tune)


word_embedding_size = 512
attention_dim = 512
model_save_path = './model_saves/'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lamda = 1.


print("Label smoothing set to: ", bool(1))   
learning_rate = 0.00004
decoder_hidden_size = CNN_channels
dropout = 0.1

batch_size = 32 #CHANGE THIS WHEN YOU NEED TO
# batch_size_val = 64
grad_clip = 5.
transformer_layers = 3
# NOTE(review): the training cell uses heads = 2; confirm the checkpoint loaded
# below was trained with 4 heads.
heads = 4
beta1 = 0.9
beta2 = 0.98
mode = 'test'
losses = []
directory = './qcheckpoints2'
decoder_file = 'transformerdecoder-epoch-4-step-64000.pkl'

transform_val = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.485, 0.456, 0.406),
                                                          (0.229, 0.224, 0.225))
                                    ])

# vocab_threshold is not set in this cell; it must already be defined (the
# training cell sets it). The loader follows the batch_size setting above.
val_loader = get_loader(transform=transform_val,
                         mode='val',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=False,
                         num_workers=4)

# SECURITY: pickle.load executes arbitrary code from the file; only load a
# data_loader_parallel.txt that you produced yourself.
with open('data_loader_parallel.txt', 'rb') as dl:
    data_loader = pickle.load(dl)

vocab_size = len(data_loader.dataset.vocab)

model = EncoderDecoder(encoder_class, decoder_class, vocab_size, target_sos=0,
                       target_eos=1, fine_tune=fine_tune, encoder_type=encoder_type,
                       encoder_hidden_size=CNN_channels,
                       decoder_hidden_size=decoder_hidden_size,
                       word_embedding_size=word_embedding_size, attention_dim=attention_dim,
                       decoder_type=decoder_type, cell_type='lstm', beam_width=beam_width,
                       dropout=dropout,
                       transformer_layers=transformer_layers, num_heads=heads)

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
model.decoder.load_state_dict(
    torch.load(os.path.join(directory, decoder_file), map_location=device))

In [12]:
import matplotlib.pyplot as plt

def get_output_sentence(model, device, images, vocab):
    """Decode ``images`` with ``model`` and return one list of token ids per row.

    The model is called with ``on_max='halt'``; index 0 of its last axis is
    kept, the result is transposed, and the ids ``0`` and ``len(vocab) + 1``
    are filtered out of every sequence. Token ids are assumed to be integers.

    NOTE(review): the filtered "eos" id is ``len(vocab) + 1`` -- confirm it
    matches the eos id the model was built with.
    """
    with torch.no_grad():
        torch.cuda.empty_cache()

        sos_id = 0
        eos_id = len(vocab) + 1

        decoded = model(images.to(device), on_max='halt')
        sequences = decoded[..., 0].T.tolist()

    return [
        [token for token in sequence if token != sos_id and token != eos_id]
        for sequence in sequences
    ]

def clean_sentence(output):
    """Map token ids to vocabulary entries using the global ``data_loader``.

    Ids equal to 1 are skipped; after the lookup, the first and last
    remaining entries are dropped.
    """
    words = [
        data_loader.dataset.vocab.idx2word[token]
        for token in output
        if token != 1
    ]
    return words[1:-1]


def get_prediction(model):
    """Display one sample from ``data_loader_test`` and print its predicted caption.

    Relies on the module-level ``data_loader_test``, ``data_loader`` (for the
    vocabulary) and ``device``.
    """
    sample_image, batch = next(iter(data_loader_test))
    vocabulary = data_loader.dataset.vocab

    # Show the first element yielded by the loader.
    plt.imshow(np.squeeze(sample_image, 0))
    plt.title('Sample Image')
    plt.show()

    batch = batch.to(device)
    model.eval()
    model = model.to(device)
    print("images: ", batch.shape)

    token_ids = get_output_sentence(model, device, batch, vocabulary)[0]
    print(clean_sentence(token_ids))

In [None]:
# Show one sample from data_loader_test and print the caption predicted for it.
get_prediction(model)

In [None]:
# Open the pickled losses file and plot it. Replace the filename with whatever
# name your losses were saved under.

with open('q2_losses_after_5_epochs.txt', 'rb') as five:
    losses5 = pickle.load(five)
# float() accepts both plain Python floats (what train_for_epoch appends) and
# 0-d tensors, whereas .item() only exists on the latter.
ploss = [float(value) for value in losses5]
plt.plot(ploss)

## Definitions for computing BLEU scores

In [18]:
def clean_sentence_bleu(output, data_loader):
    """Map token ids to entries of ``data_loader.dataset.vocab.idx2word``.

    Every element of ``output`` is cast to ``int``; ids equal to 1 are
    skipped, and the first and last remaining entries are dropped.
    """
    token_ids = (int(raw) for raw in output)
    words = [
        data_loader.dataset.vocab.idx2word[token]
        for token in token_ids
        if token != 1
    ]
    return words[1:-1]

def get_avg_bleu(model, dataloader, device):
    """Return the average BLEU-1..4 scores of ``model`` over ``dataloader``.

    For every batch the reference captions are converted to words with
    ``clean_sentence_bleu``; a caption is predicted for each image with
    ``get_output_sentence`` (one image at a time, as the evaluation cell of
    this notebook does) and scored with ``get_batch_bleu``.

    Args:
        model: Captioning model passed to ``get_output_sentence``.
        dataloader: Loader yielding ``(images, captions_ref, cap_lens)``;
            ``dataloader.dataset.vocab`` provides the vocabulary.
        device: Device used for inference.

    Returns:
        Tuple ``(bleu1, bleu2, bleu3, bleu4)``: the mean of the per-batch
        scores, or four zeros when the loader yields no batch.
    """
    with torch.no_grad():
        totals = [0, 0, 0, 0]
        num_batches = 0
        vocab = dataloader.dataset.vocab

        for batch in tqdm(dataloader):
            torch.cuda.empty_cache()

            # Load images and reference captions from the dataset.
            images, captions_ref, cap_lens = batch
            refs = [clean_sentence_bleu(caption, dataloader) for caption in captions_ref]

            # Predict a caption for every image of the *current* batch
            # (previously an unrelated global list was read here).
            cands = []
            for image in images:
                token_ids = get_output_sentence(
                    model, device, torch.unsqueeze(image, 0), vocab)[0]
                try:
                    cands.append(clean_sentence_bleu(token_ids, dataloader))
                except KeyError:
                    # Unknown token id: keep an empty candidate so candidates
                    # stay aligned with their references.
                    cands.append([])

            batch_scores = get_batch_bleu(refs, cands)
            for k, score in enumerate(batch_scores):
                totals[k] += score

            # get_batch_bleu already returns a per-batch mean, so the overall
            # average divides by the number of batches, not of images.
            num_batches += 1

        if num_batches == 0:
            return 0.0, 0.0, 0.0, 0.0

        avg_score1, avg_score2, avg_score3, avg_score4 = (
            total / num_batches for total in totals)
        return avg_score1, avg_score2, avg_score3, avg_score4

def get_batch_candidates(model, images):
    """Run ``model`` on ``images`` with ``on_max='halt'`` and wrap the output in ``torch.Tensor``.

    Gradients are disabled for the call; the input shape and the raw model
    output are printed for inspection.
    """
    with torch.no_grad():
        print("images: ", images.shape)
        predictions = model(images, on_max = 'halt')
        print("captions cand: ", predictions)
        candidates = torch.Tensor(predictions)
    return candidates

def get_batch_bleu(captions_ref, captions_cand):
    """Compute the mean BLEU-1..4 scores over the caption pairs of a batch.

    Args:
        captions_ref: List of reference captions (each a list of tokens).
        captions_cand: List of candidate captions, aligned with
            ``captions_ref`` by index.

    Returns:
        Tuple ``(bleu1, bleu2, bleu3, bleu4)`` averaged over the number of
        reference captions; four zeros when ``captions_ref`` is empty.

    Raises:
        IndexError: If ``captions_cand`` is shorter than ``captions_ref``.

    The inputs are not modified: '.' tokens and 0 entries are removed from
    local copies before scoring.
    """
    # One weight tuple per BLEU order. BLEU-3 must not weight 4-grams; the
    # accumulated score previously used (0.33, 0.33, 0.33, 0.33).
    bleu_weights = (
        (1, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (0.33, 0.33, 0.33, 0),
        (0.25, 0.25, 0.25, 0.25),
    )
    totals = [0, 0, 0, 0]
    num_pairs = len(captions_ref)

    for i, raw_ref in enumerate(captions_ref):
        ref = [x for x in raw_ref if x != '.' and x != 0]
        cand = [x for x in captions_cand[i] if x != '.' and x != 0]

        print('captions_ref= ', ref)
        print('captions_cand= ', cand)

        # Score each order once and reuse the value for both sum and log.
        pair_scores = [
            nltk.translate.bleu_score.sentence_bleu([ref], cand, weights=weights)
            for weights in bleu_weights
        ]
        for k, score in enumerate(pair_scores):
            totals[k] += score
        for score in pair_scores:
            print('Score : ', score)

    if num_pairs == 0:
        return 0.0, 0.0, 0.0, 0.0

    # Average over the pairs actually scored rather than a global batch size,
    # so partial batches are averaged correctly.
    score1, score2, score3, score4 = (total / num_pairs for total in totals)
    return score1, score2, score3, score4

In [None]:
# images, captions_ref, cap_lens = next(iter(val_loader))

# Generate one caption per validation image for the first `num_batches`
# batches, keeping candidates index-aligned with their reference captions.
model = model.to(device)
model.eval()

caption_list = []   # one list of predicted captions (word lists) per batch
references = []     # the matching reference caption tensors per batch
vocab = data_loader.dataset.vocab
count = 0
num_batches = 100 #Determine over how many batches you would like to calculate BLEU score

# The loop variable is not called `data` so the `torch.utils.data as data`
# module alias is not overwritten.
for batch in tqdm(val_loader):
    images, captions_ref, cap_lens = batch
    sublist = []
    for image in images:
        image = torch.unsqueeze(image, 0)
        sentence = get_output_sentence(model, device, image, vocab)[0]
        try:
            sentence = clean_sentence(sentence)
        except KeyError:
            # Token id missing from idx2word: keep an empty caption so the
            # candidate list stays aligned with the references.
            sentence = []
        sublist.append(sentence)

    caption_list.append(sublist)
    references.append(captions_ref)
    count += 1
    if count == num_batches:
        break
        

In [None]:
# Get BLEU scores: score every collected batch and average over the batches.

bleu_list = []          # per-batch [bleu1, bleu2, bleu3, bleu4]
bleus = [0, 0, 0, 0]    # running sums, turned into means below

# Iterate over the batches that were actually collected instead of a fixed
# count, which raised IndexError once it exceeded len(references).
for batch_refs, batch_cands in zip(references, caption_list):
    refs = [clean_sentence_bleu(caption, val_loader) for caption in batch_refs]
    batch_scores = get_batch_bleu(refs, batch_cands)
    for k, score in enumerate(batch_scores):
        bleus[k] += score
    bleu_list.append(list(batch_scores))

num_scored_batches = len(bleu_list)
if num_scored_batches:
    bleus = [total / num_scored_batches for total in bleus]

print(bleus)