### 1) Initial imports and loading the utility functions. The dataset used is <a href='https://www.kaggle.com/adityajn105/flickr8k'>Flickr 8k</a> from Kaggle.<br>The custom dataset and dataloader are implemented in <a href="https://www.kaggle.com/mdteach/torch-data-loader-flicker-8k">this</a> notebook.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pandas as pd
import os
from collections import Counter
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T
import pickle
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

In [None]:
#location of the data 
train_data_location =  "../input/flickr30k"
# test_data_location = "../input/flickr8k"
!ls $data_location


<h2>2) Writing the custom dataset</h2>
<p>We write a custom torch dataset class so that we can abstract away the data-loading steps during the training and validation process.</p>
<p>Here, a dataloader is created that yields batches of images and their captions, with the following processing applied:</p>

<li>caption word tokenized to unique numbers</li>
<li>vocab instance created to store all the relevant words in the dataset</li>
<li>each batch, caption padded to have same sequence length</li>
<li>image resized to the desired size and converted into tensors</li>

<br><p>In this way the dataprocessing is done, and the dataloader is ready to be used with <b>Pytorch</b></p>

In [None]:
# spaCy 3 removed shortcut links such as "en", so load the packaged English
# model by its full name and only fall back to the legacy shortcut when that
# package is not installed (older spaCy 2 environments).
try:
    spacy_eng = spacy.load("en_core_web_sm")
except OSError:
    spacy_eng = spacy.load("en")

In [None]:
# SOS: start of sentence, EOS: end of sentence, PAD: padding token, UNK: unknown
class Vocabulary:
    """Two-way mapping between caption words and integer ids.

    Ids 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Further words are added by :meth:`build_vocab` once they have been seen
    ``freq_threshold`` times.
    """

    def __init__(self,freq_threshold):
        #setting the pre-reserved tokens int to string tokens
        self.itos = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}

        #string to int tokens (the reverse of self.itos)
        self.stoi = {v:k for k,v in self.itos.items()}

        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of entries, reserved tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Split ``text`` into lower-cased tokens using the spaCy tokenizer."""
        return [token.text.lower() for token in spacy_eng.tokenizer(str(text))]

    def build_vocab(self, sentence_list):
        """Add every word that reaches ``freq_threshold`` occurrences.

        Args:
            sentence_list: iterable of caption strings.

        New ids continue after the highest id already in use, and words that
        are already mapped are left alone, so calling this on a vocabulary
        that was built or loaded before extends it instead of overwriting
        existing entries.
        """
        frequencies = Counter()
        # Continue numbering after the existing entries (4 for a fresh vocab).
        idx = max(self.itos, default=-1) + 1

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1

                #add the word to the vocab when it reaches the minimum frequency threshold
                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self,text):
        """Return the list of ids for the tokens of ``text``.

        Tokens that are not in the vocabulary map to the ``<UNK>`` id.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

    def save_vocab(self, filepath):
        """Pickle the ``itos`` and ``stoi`` dictionaries to ``filepath``."""
        with open(filepath, 'wb') as f:
            pickle.dump({'itos': self.itos, 'stoi': self.stoi}, f)

    def load_vocab(self, filepath):
        """Replace ``itos`` and ``stoi`` with the mappings stored in ``filepath``.

        Raises:
            KeyError: if the pickled object lacks the 'itos' or 'stoi' key.
        """
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; only load vocabulary files that come from a trusted source.
        with open(filepath, 'rb') as f:
            vocab_data = pickle.load(f)
        self.itos = vocab_data['itos']
        self.stoi = vocab_data['stoi']


In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV with 'image' and 'caption' columns.

    The rows are shuffled with a fixed seed and divided into a training part
    and a held-out part; ``test`` selects which part ``__len__`` and
    ``__getitem__`` expose.

    The division is made on unique image names rather than on rows. When an
    image has several caption rows, a row-level split would put the same
    image into both parts, and the held-out loss would then be measured on
    images already seen during training.

    Args:
        root_dir: directory that contains the image files.
        captions_file: path of the CSV file.
        test: truthy to expose the held-out part, falsy for the training part.
        transform: optional callable applied to each loaded PIL image.
        freq_threshold: minimum word count passed to ``Vocabulary``.
        split_ratio: fraction of the unique images assigned to training.
    """
    def __init__(self,root_dir,captions_file,test,transform=None,freq_threshold=5,split_ratio=0.8):
        self.root_dir = root_dir
        self.transform = transform
        self.test = test

        # Fixed seed: every instance built from the same file gets the same
        # shuffle and therefore the same split.
        self.df = pd.read_csv(captions_file)
        self.df = self.df.sample(frac=1, random_state=42).reset_index(drop=True)

        self.full_captions = self.df["caption"]

        # Unique image names in shuffled order; the first `split_ratio` share
        # of them forms the training part.
        unique_images = self.df["image"].drop_duplicates()
        split_index = int(len(unique_images) * split_ratio)
        in_train = self.df["image"].isin(unique_images.iloc[:split_index])

        self.imgs = self.df.loc[in_train, "image"]
        self.imgs_test = self.df.loc[~in_train, "image"]
        self.captions = self.df.loc[in_train, "caption"]
        self.captions_test = self.df.loc[~in_train, "caption"]

        #Initialize vocabulary and build vocab
        # Built from all captions, so instances created with test=True and
        # test=False end up with the same word-to-id mapping.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.full_captions.tolist())
        # Written to the working directory on every construction.
        self.vocab.save_vocab('flickr30k_vocab.pkl')


    def __len__(self):
        """Return the number of caption rows in the selected part."""
        if self.test:
            return len(self.imgs_test)
        return len(self.imgs)

    def __getitem__(self,idx: int):
        """Return ``(image, caption_tensor)`` for position ``idx`` of the selected part.

        The caption tensor holds ``<SOS>``, the numericalized caption tokens
        and ``<EOS>``, in that order.
        """
        if self.test:
            images, captions = self.imgs_test, self.captions_test
        else:
            images, captions = self.imgs, self.captions

        # Positional access: the filtered Series keep their original labels.
        caption = captions.iloc[idx]
        img_id = images.iloc[idx]

        img_location = os.path.join(self.root_dir,img_id)
        img = Image.open(img_location).convert("RGB")

        #apply the transformation to the image
        if self.transform is not None:
            img = self.transform(img)

        #numericalize the caption text
        caption_vec = (
            [self.vocab.stoi["<SOS>"]]
            + self.vocab.numericalize(caption)
            + [self.vocab.stoi["<EOS>"]]
        )

        return img, torch.tensor(caption_vec)

    def get_dataset_sizes(self):
        """Return ``(train_size, valid_size)`` counted in caption rows."""
        train_size = len(self.imgs)
        valid_size = len(self.imgs_test)
        return train_size, valid_size

In [None]:
#defining the transform to be applied
# Resize every image to 224x224 and convert it to a tensor (no normalization).
# NOTE(review): `transforms1` is not referenced by any other visible cell.
transforms1 = T.Compose([
    T.Resize((224,224)),
    T.ToTensor()
])

In [None]:
def show_image(inp, title=None):
    """Plot an image tensor with matplotlib.

    Axis 0 of ``inp`` is moved to the last position before plotting, i.e. the
    tensor is treated as channels-first. ``title`` is drawn when given.
    """
    # matplotlib wants the channel axis last.
    plt.imshow(inp.numpy().transpose((1, 2, 0)))
    if title is not None:
        plt.title(title)
    # pause a bit so that plots are updated
    plt.pause(0.001)

In [None]:
class CapsCollate:
    """Collate callable that batches ``(image, caption)`` samples.

    Images are concatenated along a new leading axis; captions are padded
    with ``pad_idx`` up to the longest caption in the batch.
    """

    def __init__(self,pad_idx,batch_first=False):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self,batch):
        # Give every image a leading batch axis, then join along it.
        image_batch = torch.cat([sample[0][None] for sample in batch], dim=0)

        caption_batch = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return image_batch, caption_batch

In [None]:
def get_data_loader(dataset,batch_size,shuffle=False,num_workers=1):
    """Build a ``DataLoader`` over ``dataset`` that pads captions per batch.

    The padding id is taken from ``dataset.vocab.stoi["<PAD>"]`` and batches
    are produced batch-first by ``CapsCollate``.
    """
    collate = CapsCollate(pad_idx=dataset.vocab.stoi["<PAD>"], batch_first=True)
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
    )

### 2) **Implementing the helper function to plot the tensor image**

In [None]:
#show the tensor image
import matplotlib.pyplot as plt
def show_image(img, title=None):
    """Imshow for a normalized, channels-first image tensor.

    Reverses the per-channel normalization (std 0.229/0.224/0.225, mean
    0.485/0.456/0.406) and plots the result. The un-normalization is done on
    a new tensor, so the caller's ``img`` is left unchanged; the previous
    in-place version altered the tensor that was passed in.

    Args:
        img: tensor whose axis 0 holds the three color channels.
        title: optional plot title.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    #unnormalize (out of place); detach/cpu so .numpy() also works for
    #tensors that track gradients or live on another device
    img = img.detach().cpu() * std + mean

    # Keep values inside the [0, 1] range that imshow expects for floats.
    img = img.clamp(0, 1).numpy().transpose((1, 2, 0))

    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
#Initiate the Dataset and Dataloader
#setting the constants
BATCH_SIZE = 256
NUM_WORKER = 4

#defining the transform to be applied to training images
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

# Validation uses a deterministic center crop: with a random crop the same
# image yields different inputs on every pass, which adds noise to the
# validation loss that drives early stopping.
valid_transforms = T.Compose([
    T.Resize(226),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])


#training split of the dataset
train_dataset =  FlickrDataset(
    root_dir = train_data_location+"/Images",
    captions_file = train_data_location+"/captions.txt",
    transform=transforms,
    test = False
)

#writing the dataloader
train_data_loader = get_data_loader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
)

#held-out split of the same captions file
valid_dataset =  FlickrDataset(
    root_dir = train_data_location+"/Images",
    captions_file = train_data_location+"/captions.txt",
    transform=valid_transforms,
    test = True
)

#writing the dataloader
valid_data_loader = get_data_loader(
    dataset=valid_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=False,
)

In [None]:
#vocab_size
# Number of entries in the vocabulary held by the validation dataset.
vocab_size = len(valid_dataset.vocab)
print(vocab_size)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
# (train_size, valid_size) as returned by FlickrDataset.get_dataset_sizes.
print(train_dataset.get_dataset_sizes())

In [None]:
# Length of valid_dataset, which was built with test=True (held-out part).
print(len(valid_dataset))

### 3) Defining the Model Architecture

The model is a seq2seq model. In the **encoder**, a pretrained ResNet model is used to extract the image features. The decoder is an implementation of the Bahdanau attention decoder and uses an **LSTM cell**.

In [None]:
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-101 trunk that yields one feature vector per spatial location.

    The last two children of the torchvision model are dropped, so the output
    is the final convolutional feature map, reshaped to
    ``(batch_size, num_locations, channels)``.
    """
    def __init__(self):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)

        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.resnet.eval()

    def train(self, mode=True):
        """Set the module mode while keeping the frozen backbone in eval mode.

        ``requires_grad_(False)`` only freezes the weights. In training mode
        the BatchNorm layers would still normalize with batch statistics and
        update their running estimates, so the "frozen" encoder would change
        whenever the parent model calls ``train()``.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return features of shape ``(batch_size, H*W, channels)`` for ``images``."""
        # Shape notes below are for 224x224 inputs.
        features = self.resnet(images)                                       #(batch_size,2048,7,7)
        features = features.permute(0, 2, 3, 1)                              #(batch_size,7,7,2048)
        # reshape (not view): does not depend on the memory layout left by permute
        features = features.reshape(features.size(0), -1, features.size(-1)) #(batch_size,49,2048)
        return features


In [None]:
#Bahdanau Attention
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    Args:
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space.
    """
    def __init__(self, encoder_dim,decoder_dim,attention_dim):
        super(Attention, self).__init__()

        self.attention_dim = attention_dim

        # Projections of the decoder state (W) and of the encoder features (U)
        # into the attention space, followed by a scalar scoring layer (A).
        self.W = nn.Linear(decoder_dim,attention_dim)
        self.U = nn.Linear(encoder_dim,attention_dim)

        self.A = nn.Linear(attention_dim,1)

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)`` for one decoding step.

        Args:
            features: ``(batch_size, num_locations, encoder_dim)``.
            hidden_state: ``(batch_size, decoder_dim)``.

        Returns:
            alpha: ``(batch_size, num_locations)`` weights, softmax over locations.
            context: ``(batch_size, encoder_dim)`` alpha-weighted sum of ``features``.
        """
        projected_features = self.U(features)        #(batch_size,num_locations,attention_dim)
        projected_hidden = self.W(hidden_state)      #(batch_size,attention_dim)

        # Broadcast the hidden-state projection over every location.
        energy = torch.tanh(projected_features + projected_hidden.unsqueeze(1))

        scores = self.A(energy).squeeze(2)           #(batch_size,num_locations)
        alpha = F.softmax(scores,dim=1)              #(batch_size,num_locations)

        weighted = features * alpha.unsqueeze(2)     #(batch_size,num_locations,encoder_dim)
        context = weighted.sum(dim=1)                #(batch_size,encoder_dim)

        return alpha,context
        

In [None]:
#Attention Decoder
class DecoderRNN(nn.Module):
    """LSTM-cell caption decoder that attends over encoder features at every step.

    Args:
        embed_size: size of the word embeddings.
        vocab_size: number of entries in the vocabulary (output classes).
        attention_dim: size of the attention projection space.
        encoder_dim: size of each encoder feature vector.
        decoder_dim: size of the LSTM hidden and cell state.
        drop_prob: dropout probability applied to the hidden state before
            the output layer.
    """
    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()

        #save the model param
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)

        # Initial hidden/cell state are linear functions of the mean feature.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        # Not used by forward() or generate_caption(); kept so that the set
        # of parameters (and therefore saved state dicts) stays the same.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Run the decoder over ``captions`` with the given tokens as inputs.

        Args:
            features: ``(batch_size, num_locations, encoder_dim)``.
            captions: ``(batch_size, caption_len)`` token ids.

        Returns:
            preds: ``(batch_size, caption_len - 1, vocab_size)`` scores.
            alphas: ``(batch_size, caption_len - 1, num_locations)`` attention weights.
        """
        #vectorize the caption
        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        #get the seq length to iterate
        seq_length = captions.size(1) - 1 #Exclude the last one
        batch_size = captions.size(0)
        num_features = features.size(1)

        # Allocate on the device of the inputs instead of relying on a
        # notebook-level `device` global.
        out_device = features.device
        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=out_device)

        for s in range(seq_length):
            alpha,context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:,s] = output
            alphas[:,s] = alpha

        return preds, alphas

    def generate_caption(self,features,max_len=20,vocab=None):
        """Greedily decode a caption for a single image.

        Args:
            features: encoder output for one image, ``(1, num_locations, encoder_dim)``.
            max_len: maximum number of tokens to generate.
            vocab: vocabulary providing ``stoi`` and ``itos``; required.

        Returns:
            ``(words, alphas)``: the generated tokens as strings (including
            ``<EOS>`` when it was produced) and one numpy attention array per
            generated token.

        Raises:
            ValueError: if ``vocab`` is missing or ``features`` holds more
                than one image.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocabulary (vocab=...)")

        batch_size = features.size(0)
        if batch_size != 1:
            # The start token and the `.item()` calls below handle one image only.
            raise ValueError("generate_caption supports a batch size of 1, got {}".format(batch_size))

        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        alphas = []

        #starting input
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1,-1)
        embeds = self.embedding(word)

        captions = []

        for i in range(max_len):
            alpha,context = self.attention(features, h)

            #store the alpha score
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size,-1)

            #select the word with most val
            predicted_word_idx = output.argmax(dim=1)
            token_id = predicted_word_idx.item()

            #save the generated word
            captions.append(token_id)

            #end if <EOS detected>
            if vocab.itos[token_id] == "<EOS>":
                break

            #send generated word as the next caption
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        #convert the vocab idx to words and return sentence
        return [vocab.itos[idx] for idx in captions],alphas

    def init_hidden_state(self, encoder_out):
        """Return the initial ``(h, c)`` computed from the mean over locations."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c


In [None]:
class EncoderDecoder(nn.Module):
    """Captioning model: ``EncoderCNN`` features fed into ``DecoderRNN``.

    Args:
        embed_size, vocab_size, attention_dim, encoder_dim, decoder_dim,
        drop_prob: forwarded unchanged to ``DecoderRNN``.
    """
    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size = vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            # Previously dropped, so a caller-supplied value had no effect.
            drop_prob=drop_prob
        )

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder output for ``captions``."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs


### 4) Setting hyperparameters and initializing the model

In [None]:
#Hyperparams
# Word-embedding size used by the decoder.
embed_size=300
# Number of output classes: one per vocabulary entry of the training dataset.
vocab_size = len(train_dataset.vocab)
# Size of the attention projection space.
attention_dim=256
# Size of each encoder feature vector; the EncoderCNN shape comments give
# 2048 channels for its output.
encoder_dim=2048
# Size of the LSTM hidden and cell state.
decoder_dim=512
# Learning rate passed to the Adam optimizer.
learning_rate = 3e-4


In [None]:
#init model
# Build the encoder-decoder with the hyperparameters above and move it to `device`.
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size = vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim
).to(device)

# Positions whose target is the <PAD> id do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=train_dataset.vocab.stoi["<PAD>"])
# All model parameters are handed to Adam, including the encoder's, which
# EncoderCNN marks with requires_grad_(False).
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
#helper function to save the model
def save_model(model,num_epochs):
    """Write the model weights and hyperparameters to 'attention_model_state.pth'.

    The hyperparameters are read from the notebook-level variables; the file
    is overwritten on every call.
    """
    checkpoint = dict(
        num_epochs=num_epochs,
        embed_size=embed_size,
        vocab_size=vocab_size,
        attention_dim=attention_dim,
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        state_dict=model.state_dict(),
    )
    torch.save(checkpoint,'attention_model_state.pth')

## 5) Training Job from above configs

In [None]:
import wandb
# SECURITY: the API key must not be stored in the notebook. wandb.login()
# uses the WANDB_API_KEY environment variable (or a stored login) and
# otherwise prompts for the key. A key that was committed earlier should be
# revoked in the Weights & Biases settings.
wandb.login()

In [None]:
# Start a Weights & Biases run in the 'image-captioning' project.
wandb.init(project = 'image-captioning')

In [None]:
num_epochs = 40
print_every = 100
patience = 3  # Number of epochs to wait for improvement in validation loss
best_valid_loss = float("inf")  # Initialize best validation loss
epochs_without_improvement = 0

for epoch in range(1,num_epochs+1):
    # ---- training pass ----
    model.train()
    epoch_train_loss = 0
    for idx, (image, captions) in enumerate(train_data_loader):
        image,captions = image.to(device),captions.to(device)
        optimizer.zero_grad()
        outputs,attentions = model(image, captions)
        # The decoder output at step s is scored against token s+1, so the
        # targets are the captions without their first token.
        targets = captions[:,1:]
        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()

        if (idx+1)%print_every == 0:
            print("Train section")
            print("Epoch: {} |Loss: {:.5f}".format(epoch,loss.item()))

    wandb.log({"train_loss": epoch_train_loss / len(train_data_loader), "epoch": epoch})

    # ---- validation pass (no gradient tracking) ----
    model.eval()

    epoch_valid_loss = 0
    with torch.no_grad():
        for idx, (image, captions) in enumerate(valid_data_loader):
            image,captions = image.to(device),captions.to(device)
            outputs,attentions = model(image, captions)
            targets = captions[:,1:]
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))

            epoch_valid_loss += loss.item()

            if (idx+1)%print_every == 0:
                print("Valid section")
                print("Epoch: {} |Loss: {:.5f}".format(epoch,loss.item()))

    wandb.log({"valid_loss": epoch_valid_loss / len(valid_data_loader), "epoch": epoch})

    if epoch_valid_loss < best_valid_loss:
        best_valid_loss = epoch_valid_loss
        epochs_without_improvement = 0  # Reset the counter if there's improvement
        # Save only when validation improves. Saving on every epoch replaced
        # the best checkpoint with the weights of later, worse epochs.
        save_model(model,epoch)
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print(f"Early stopping at epoch {epoch} due to no improvement in validation loss.")
        break


## 6) Visualizing the attention
Defining helper functions
<li>Given the image generate captions and attention scores</li>
<li>Plot the attention scores in the image</li>

In [None]:
# #generate caption
# def get_caps_from(features_tensors):
#     #generate the caption
#     model.eval()
#     with torch.no_grad():
#         features = model.encoder(features_tensors.to(device))
#         caps,alphas = model.decoder.generate_caption(features,vocab=train_dataset.vocab)
#         caption = ' '.join(caps)
#         show_image(features_tensors[0],title=caption)
    
#     return caps,alphas

# #Show attention
# def plot_attention(img, result, attention_plot):
#     #untransform
#     img[0] = img[0] * 0.229
#     img[1] = img[1] * 0.224 
#     img[2] = img[2] * 0.225 
#     img[0] += 0.485 
#     img[1] += 0.456 
#     img[2] += 0.406
    
#     img = img.numpy().transpose((1, 2, 0))
#     temp_image = img

#     fig = plt.figure(figsize=(15, 15))

#     len_result = len(result)
#     for l in range(len_result):
#         temp_att = attention_plot[l].reshape(7,7)
        
#         ax = fig.add_subplot(len_result//2,len_result//2, l+1)
#         ax.set_title(result[l])
#         img = ax.imshow(temp_image)
#         ax.imshow(temp_att, cmap='gray', alpha=0.7, extent=img.get_extent())
        

#     plt.tight_layout()
#     plt.show()

In [None]:
# #show any 1
# dataiter = iter(train_data_loader)
# images,_ = next(dataiter)

# img = images[0].detach().clone()
# img1 = images[0].detach().clone()
# caps,alphas = get_caps_from(img.unsqueeze(0))

# plot_attention(img1, caps, alphas)