In [1]:
import torch
from torch import nn
import torchvision
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence


import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2


from tqdm import tqdm
import os
import gc
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
from sklearn.model_selection import train_test_split
import pickle


from collections import Counter
import nltk
nltk.download('punkt')


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
INPUT_IMAGES_DIR = "/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/"
LABEL_PATH = "/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv"
OUTPUT_PATH = "/kaggle/working"

In [3]:
df = pd.read_csv(LABEL_PATH, delimiter="|")
df.columns = ['image', 'caption_number', 'caption']
df['caption'] = df['caption'].str.lstrip()
df['caption_number'] = df['caption_number'].str.lstrip()
df.loc[19999, 'caption_number'] = "4"
df.loc[19999, 'caption'] = "A dog runs across the grass ."
ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
df['id'] = ids
df.head()

Unnamed: 0,image,caption_number,caption,id
0,1000092795.jpg,0,Two young guys with shaggy hair look at their ...,0
1,1000092795.jpg,1,"Two young , White males are outside near many ...",0
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .,0
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .,0
4,1000092795.jpg,4,Two friends enjoy time spent together .,0


In [4]:
df["length"] = df["caption"].apply(lambda row: len(row.strip().split()))

In [5]:
import nltk
from collections import Counter

class Vocabulary():
    def __init__(self, df,vocab_threshold,vocab_file='vocab.pkl',
               start_word="<start>",end_word="<end>",unk_word="<unk>",vocab_from_file=False):
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.vocab_from_file = vocab_from_file
        self.df = df
        self.get_vocab()
        
    def get_vocab(self):
        """Load the vocabulary from file OR build the vocabulary from scratch."""
        if os.path.exists(self.vocab_file) & self.vocab_from_file:
            with open(self.vocab_file, 'rb') as f:
                vocab = pickle.load(f)
                self.word2idx = vocab.word2idx
                self.idx2word = vocab.idx2word
            print('Vocabulary successfully loaded from vocab.pkl file!')
        else:
            self.build_vocab()
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)
                
    def build_vocab(self):
        """Populate the dictionaries for converting tokens to integers (and vice-versa)."""
        self.init_vocab()
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Initialize the dictionaries for converting tokens to integers (and vice versa)."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0 

    def add_word(self, word):
        """Add a token to the vocabulary."""
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        """Loop over training captions and add all tokens to the vocabulary that meet or exceed the threshold."""
        counter = Counter()
        x = df['caption'].apply(lambda row:counter.update(nltk.tokenize.word_tokenize(row)))
        words = [word for word, cnt in counter.items() if cnt>=self.vocab_threshold]
        for i, word in enumerate(words):
            self.add_word(word)

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]
    
    def __len__(self):
        return len(self.word2idx)

In [6]:
def collate_batch(batch):
    image_list,caption_list,caplen_list = [],[],[]
    for (image, caption,caplen) in batch:
        image_list.append(image)
        caption_list.append(caption)
        caplen_list.append(caplen)
    
    image_list = torch.stack(image_list)
    caplen_list = torch.tensor(caplen_list)
    caption_list = pad_sequence(caption_list, batch_first=True, padding_value=0)
    return image_list,caption_list,caplen_list.unsqueeze(1)

In [7]:
class ImageCaptioningDataset(torch.utils.data.Dataset):
    def __init__(self,df,transform,mode,batch_size,vocab_threshold):
        self.caption_lengths = df['length']
        self.batch_size = batch_size
        self.df = df
        self.transform = transform
        self.mode = mode
        if mode == 'train':
            self.vocab = Vocabulary(df,vocab_threshold,'vocab.pkl1',vocab_from_file=False)
        elif (mode=='val') or (mode=='test'):
            self.vocab = Vocabulary(df,vocab_threshold,'vocab_pkl1',vocab_from_file=True)
    def __getitem__(self,index):
        if (self.mode == 'train') or (self.mode =='val'):
            image = Image.open(f"{INPUT_IMAGES_DIR}/{self.df['image'][index]}")
            image = self.transform(image)

            #Convert caption to tensor of word ids 
            tokens = nltk.tokenize.word_tokenize(self.df['caption'][index].lower())
            caption = []
            caption.append(self.vocab(self.vocab.start_word))
            caption.extend([self.vocab(token) for token in tokens])
            caption.append(self.vocab(self.vocab.end_word))
            caption = torch.Tensor(caption).long()
            
            caplen = self.df['length'][index]

            return (image,caption,caplen)
        else:
            p_image = Image.open(f"{INPUT_IMAGES_DIR}/{self.df['image'][index]}").convert('RGB')
            image = np.array(p_image)
            trans_image = self.transform(p_image)

            return image, trans_image,self.df['caption'][index]

    def get_train_indices(self):
        sel_length = np.random.choice(self.caption_lengths)
        all_indices = np.where([self.caption_lengths[i] == sel_length for i in np.arange(len(self.caption_lengths))])[0]
        indices = list(np.random.choice(all_indices, size=self.batch_size))
        return indices

    def __len__(self):
        return len(self.df['caption'])

In [8]:
mode ='train'
## TODO #1: Select appropriate values for the Python variables below.
batch_size = 128         # batch size
vocab_threshold = 6        # minimum word count threshold
embed_size = 512           # dimensionality of image and word embeddings
hidden_size = 512          # number of features in hidden state of the RNN decoder
num_epochs = 2             # number of training epochs (1 for testing)
save_every = 1             # determines frequency of saving model weights
print_every = 200          # determines window for printing average loss
log_file = 'training_log.txt'       # name of file with saved training loss and perplexity

In [9]:
transform_train = transforms.Compose([ 
    transforms.Resize((256,256)),                          # smaller edge of image resized to 256                      # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    ])
transform_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(), 
    ])

In [10]:
train, test =train_test_split(df,test_size=0.2,shuffle=True)
train_df,valid_df = train_test_split(train,test_size=0.2,shuffle=True)
train_df, valid_df, test = train_df.reset_index(drop=True), valid_df.reset_index(drop=True), test.reset_index(drop=True)

train_dataset = ImageCaptioningDataset(train_df,transform_train,mode,batch_size,vocab_threshold)
valid_dataset = ImageCaptioningDataset(valid_df,transform_train,'val',batch_size,vocab_threshold)
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,
                                               collate_fn=collate_batch,num_workers=2)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,shuffle=True,
                                               collate_fn=collate_batch,num_workers=2)
# indices = train_dataset.get_train_indices()
# initial_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices=indices)
#train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_sampler=torch.utils.data.sampler.BatchSampler(sampler=initial_sampler,batch_size=batch_size,drop_last=False))

In [11]:
test_dataset = ImageCaptioningDataset(test,transform_test,'test',1,0)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=1,shuffle=True)

Vocabulary successfully loaded from vocab.pkl file!


In [25]:
class Encoder(nn.Module):
    """
    Encoder.
    shift to only output the feature map
    """

    def __init__(self, encoded_image_size=14):
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size

        resnet = torchvision.models.resnet50(pretrained=True)  # pretrained ImageNet ResNet-101

        # Remove linear and pool layers (since we're not doing classification)
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # Resize image to fixed size to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune()

    def forward(self, images):
        """
        Forward propagation.
        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: encoded images
        """
        feature_map = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        #out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        #out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return feature_map

    def fine_tune(self, fine_tune=False):
        """
        Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
        :param fine_tune: Allow?
        """
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

In [13]:
class Spatial_attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self,feature_map,decoder_dim,K = 512):
        """
        :param feature_map: feature map in level L
        :param decoder_dim: size of decoder's RNN
        """
        super(Spatial_attention, self).__init__()
        _,C,H,W = tuple([int(x) for x in feature_map])
        self.W_s = nn.Parameter(torch.randn(C,K))
        self.W_hs = nn.Parameter(torch.randn(K,decoder_dim))
        self.W_i = nn.Parameter(torch.randn(K,1))
        self.bs = nn.Parameter(torch.randn(K))
        self.bi = nn.Parameter(torch.randn(1))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim = 0)  # softmax layer to calculate weights
        
    def forward(self, feature_map, decoder_hidden):
        """
        Forward propagation.
        :param feature_map: feature map in level L(batch_size, C, H, W)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: alpha
        """
        V_map = feature_map.view(feature_map.shape[0],2048,-1) 
        V_map = V_map.permute(0,2,1)#(batch_size,W*H,C)
        # print(V_map.shape)
        # print("m1",torch.matmul(V_map,self.W_s).shape)
        # print("m2",torch.matmul(decoder_hidden,self.W_hs).shape)
        att = self.tanh((torch.matmul(V_map,self.W_s)+self.bs) + (torch.matmul(decoder_hidden,self.W_hs).unsqueeze(1)))#(batch_size,W*H,C)
        # print("att",att.shape)
        alpha = self.softmax(torch.matmul(att,self.W_i) + self.bi)
#         print("alpha",alpha.shape)
        alpha = alpha.squeeze(2)
        feature_map = feature_map.view(feature_map.shape[0],2048,-1) 
        # print("feature_map",feature_map.shape)
        # print("alpha",alpha.shape)
        temp_alpha = alpha.unsqueeze(1)
        attention_weighted_encoding = torch.mul(feature_map,temp_alpha)
        return attention_weighted_encoding,alpha

In [14]:
class Channel_wise_attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self,feature_map,decoder_dim,K = 512):
        """
        :param feature_map: feature map in level L
        :param decoder_dim: size of decoder's RNN
        """
        super(Channel_wise_attention, self).__init__()
        _,C,H,W = tuple([int(x) for x in feature_map])
        self.W_c = nn.Parameter(torch.randn(1,K))
        self.W_hc = nn.Parameter(torch.randn(K,decoder_dim))
        self.W_i_hat = nn.Parameter(torch.randn(K,1))
        self.bc = nn.Parameter(torch.randn(K))
        self.bi_hat = nn.Parameter(torch.randn(1))
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim = 0)  # softmax layer to calculate weights
        
    def forward(self, feature_map, decoder_hidden):
        """
        Forward propagation.
        :param feature_map: feature map in level L(batch_size, C, H, W)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: alpha
        """
        V_map = feature_map.view(feature_map.shape[0],2048,-1) .mean(dim=2)
        V_map = V_map.unsqueeze(2)#(batch_size,C,1)
        # print(feature_map.shape)
        # print(V_map.shape)
        # print("wc",self.W_c.shape)
        # print("whc",self.W_hc.shape)
        # print("decoder_hidden",decoder_hidden.shape)
        # print("m1",torch.matmul(V_map,self.W_c).shape)
        # print("m2",torch.matmul(decoder_hidden,self.W_hc).shape)
        # print("bc",self.bc.shape)
        att = self.tanh((torch.matmul(V_map,self.W_c) + self.bc) + (torch.matmul(decoder_hidden,self.W_hc).unsqueeze(1)))#(batch_size,C,K)
#         print("att",att.shape)
        beta = self.softmax(torch.matmul(att,self.W_i_hat) + self.bi_hat)
        beta = beta.unsqueeze(2)
        # print("beta",beta.shape)
        attention_weighted_encoding = torch.mul(feature_map,beta)

        return attention_weighted_encoding,beta

In [15]:
class DecoderWithAttention(nn.Module):
    """
    Decoder.
    shift to sca attention
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size,encoder_out_shape=[1,2048,8,8], K=512,encoder_dim=2048, dropout=0.5):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.Spatial_attention = Spatial_attention(encoder_out_shape, decoder_dim, K)  # attention network
        self.Channel_wise_attention = Channel_wise_attention(encoder_out_shape, decoder_dim, K) # ATTENTION 
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution
        self.AvgPool = nn.AvgPool2d(8)
    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.
        :param embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
        :param fine_tune: Allow?
        """
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        mean_encoder_out = self.AvgPool(encoder_out).squeeze(-1).squeeze(-1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        # encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        # num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths; why? apparent below
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word predicion scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)#需要更改形状？
        #alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)#需要更改形状

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            # attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
            #                                                     h[:batch_size_t])
            #channel-spatial模式attention
            #channel_wise
            attention_weighted_encoding, beta = self.Channel_wise_attention(encoder_out[:batch_size_t],h[:batch_size_t])
            #spatial
            attention_weighted_encoding, alpha = self.Spatial_attention(attention_weighted_encoding[:batch_size_t],h[:batch_size_t])
            #对attention_weighted_encoding降维
            attention_weighted_encoding = attention_weighted_encoding.view(attention_weighted_encoding.shape[0],2048,8,8)
            attention_weighted_encoding = self.AvgPool(attention_weighted_encoding)
            attention_weighted_encoding = attention_weighted_encoding.squeeze(-1).squeeze(-1)
            # gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            # attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            #alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, sort_ind

In [16]:
def accuracy(scores, targets, k):
    """
    Computes top-k accuracy, from predicted and true labels.
    :param scores: scores from the model
    :param targets: true labels
    :param k: k in top-k accuracy
    :return: top-k accuracy
    """

    batch_size = targets.size(0)
    _, ind = scores.topk(k, 1, True, True)
    correct = ind.eq(targets.view(-1, 1).expand_as(ind))
    correct_total = correct.view(-1).float().sum()  # 0D tensor
    return correct_total.item() * (100.0 / batch_size)

In [17]:
def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

In [18]:
def adjust_learning_rate(optimizer, shrink_factor):
    """
    Shrinks learning rate by a specified factor.
    :param optimizer: optimizer whose learning rate must be shrunk.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
    """

    print("\nDECAYING learning rate.")
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor
    print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))

In [19]:
def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer,
                    is_best):
    """
    Saves model checkpoint.
    :param data_name: base name of processed dataset
    :param epoch: epoch number
    :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score
    :param encoder: encoder model
    :param decoder: decoder model
    :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
    :param decoder_optimizer: optimizer to update decoder's weights
    :param bleu4: validation BLEU-4 score for this epoch
    :param is_best: is this checkpoint the best so far?
    """
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'encoder': encoder,
             'decoder': decoder,
             'encoder_optimizer': encoder_optimizer,
             'decoder_optimizer': decoder_optimizer}
    filename = 'checkpoint_' + data_name + '.pth.tar'
    torch.save(state, filename)
    # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
    if is_best:
        torch.save(state, 'BEST_' + filename)

In [20]:
class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [21]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.

    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

#     batch_time = AverageMeter()  # forward prop. + back prop. time
#     data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    # Batches
    for i, (imgs, caps,caplens) in enumerate(train_loader):

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with , the targets are all words after , up to 
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
        scores = scores.data
        targets = targets.data
        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        #loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          loss=losses,top5=top5accs))

In [22]:
def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.

    :param val_loader: DataLoader for validation data.
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps,caplens) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with , the targets are all words after , up to 
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
            scores = scores.data
            targets = targets.data
            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            #loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))


            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader),loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
#             allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
#             for j in range(allcaps.shape[0]):
#                 img_caps = allcaps[j].tolist()
#                 img_captions = list(
#                     map(lambda c: [w for w in c if w not in {word_map[''], word_map['']}],
#                         img_caps))  # remove  and pads
#                 references.append(img_captions)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            #assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        #bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}\n'.format(
                loss=losses,
                top5=top5accs))

    return losses.avg

In [23]:
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5

# Training parameters
start_epoch = 0
epochs = 2  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 16
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # BLEU-4 score right now
print_freq = 100  # print training/validation stats every __ batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to checkpoint, None if none

In [26]:
word_map = train_dataloader.dataset.vocab.word2idx
decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),lr=decoder_lr)
encoder = Encoder()
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                     lr=encoder_lr) if fine_tune_encoder else None

decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [27]:
# Epochs
for epoch in range(start_epoch, epochs):
    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
    train(train_loader=train_dataloader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

        # One epoch's validation
    recent_bleu4 = validate(val_loader=valid_dataloader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

    # Check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint
    save_checkpoint('SCA_CNN', epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, is_best)

Epoch: [0][0/795]	Loss 8.9260 (8.9260)	Top-5 Accuracy 0.060 (0.060)
Epoch: [0][100/795]	Loss 6.3344 (7.7331)	Top-5 Accuracy 21.281 (14.296)
Epoch: [0][200/795]	Loss 5.7010 (6.8293)	Top-5 Accuracy 28.099 (19.519)
Epoch: [0][300/795]	Loss 5.6928 (6.4424)	Top-5 Accuracy 28.912 (22.521)
Epoch: [0][400/795]	Loss 5.5128 (6.2315)	Top-5 Accuracy 30.381 (24.167)
Epoch: [0][500/795]	Loss 5.5507 (6.0980)	Top-5 Accuracy 29.369 (25.176)
Epoch: [0][600/795]	Loss 5.4672 (6.0025)	Top-5 Accuracy 30.190 (25.907)
Epoch: [0][700/795]	Loss 5.5818 (5.9325)	Top-5 Accuracy 29.485 (26.453)
Validation: [0/199]	Loss 5.4317 (5.4317)	Top-5 Accuracy 30.322 (30.322)	
Validation: [100/199]	Loss 5.3572 (5.4131)	Top-5 Accuracy 29.962 (30.543)	

 * LOSS - 5.415, TOP-5 ACCURACY - 30.568

Epoch: [1][0/795]	Loss 5.3575 (5.3575)	Top-5 Accuracy 30.554 (30.554)
Epoch: [1][100/795]	Loss 5.4442 (5.4696)	Top-5 Accuracy 28.939 (29.998)
Epoch: [1][200/795]	Loss 5.5081 (5.4679)	Top-5 Accuracy 30.373 (30.031)
Epoch: [1][300/795]	Los

In [31]:
checkpoint = '/kaggle/working/BEST_checkpoint_SCA_CNN.pth.tar'  # model checkpoint
word_map_file =  test_dataloader.dataset.vocab.word2idx

cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Load model
checkpoint = torch.load(checkpoint)
decoder1 = checkpoint['decoder']
decoder1 = decoder1.to(device)
decoder1.eval()
encoder1 = checkpoint['encoder']
encoder1 = encoder1.to(device)
encoder1.eval()

rev_word_map = test_dataloader.dataset.vocab.idx2word
vocab_size = len(word_map_file)

In [341]:
def evaluate(beam_size,loader):
    """
    Evaluation
    :param beam_size: beam size at which to generate captions for evaluation
    :return: BLEU-4 score
    """
    # DataLoader
    # TODO: Batched Beam Search
    # Therefore, do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions), and hypothesis (prediction) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = list()
    hypotheses = list()

    # For each image
    for i, (orig_image, image, caps) in enumerate(tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size))):

        k = beam_size

        # Move to GPU device, if available
        image = image.to(device)  # (1, 3, 256, 256)

        # Encode
        encoder_out = encoder1(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)

        # Flatten encoding
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # We'll treat the problem as having a batch size of k
        encoder_out = encoder_out.expand(k, num_pixels, encoder_dim)  # (k, num_pixels, encoder_dim)

        # Tensor to store top k previous words at each step; now they're just <start>
        k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device)  # (k, 1)

        # Tensor to store top k sequences; now they're just <start>
        seqs = k_prev_words  # (k, 1)

        # Tensor to store top k sequences' scores; now they're just 0
        top_k_scores = torch.zeros(k, 1).to(device)  # (k, 1)

        # Lists to store completed sequences and scores
        complete_seqs = list()
        complete_seqs_scores = list()

        # Start decoding
        step = 1
        h, c = decoder1.init_hidden_state(encoder_out)
        # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
        while True:

            embeddings = decoder1.embedding(k_prev_words).squeeze(1)  # (s, embed_dim)
            
            #awe, _ = decoder1.Channel_wise_attention(encoder_out,h)
            awe, _ = decoder1.Spatial_attention(encoder_out,h)
            awe = awe.view(awe.shape[0],2048,8,8)
            awe = nn.AvgPool2d(8)(awe)
            awe = awe.squeeze(-1).squeeze(-1)  # (s, encoder_dim), (s, num_pixels)

            gate = decoder1.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder1.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))  # (s, decoder_dim)

            scores = decoder1.fc(h)  # (s, vocab_size)
            scores = F.log_softmax(scores, dim=1)

            # Add
            scores = top_k_scores.expand_as(scores) + scores  # (s, vocab_size)

            # For the first step, all k points will have the same scores (since same k previous words, h, c)
            if step == 1:
                top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
            else:
                # Unroll and find top scores, and their unrolled indices
                top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

            # Convert unrolled indices to actual indices of scores
            prev_word_inds = top_k_words / vocab_size  # (s)
            next_word_inds = top_k_words % vocab_size  # (s)
            prev_word_inds = prev_word_inds.type(torch.LongTensor)        
            # Add new words to sequences
            seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

            # Which sequences are incomplete (didn't reach <end>)?
            incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if
                               next_word != word_map['<end>']]
            complete_inds = list(set(range(len(next_word_inds))) - set(incomplete_inds))

            # Set aside complete sequences
            if len(complete_inds) > 0:
                complete_seqs.extend(seqs[complete_inds].tolist())
                complete_seqs_scores.extend(top_k_scores[complete_inds])
            k -= len(complete_inds)  # reduce beam length accordingly

            # Proceed with incomplete sequences
            if k == 0:
                break
            seqs = seqs[incomplete_inds]
            h = h[prev_word_inds[incomplete_inds]]
            c = c[prev_word_inds[incomplete_inds]]
            encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
            top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
            k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)

            # Break if things have been going on too long
            if step > 50:
                break
            step += 1

        i = complete_seqs_scores.index(max(complete_seqs_scores))
        seq = complete_seqs[i]

        # References
#         img_caps = allcaps[0].tolist()
#         img_captions = list(
#             map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}],
#                 img_caps))  # remove <start> and pads
#         references.append(img_captions)

        # Hypotheses
        hypotheses.append([w for w in seq if w not in {word_map['<start>'], word_map['<end>'], word_map['<pad>']}])


    # Calculate BLEU-4 scores
    return hypotheses

In [342]:
# beam_size=1
# evaluate(beam_size,test_dataloader)