# Training Network

In this notebook, we train a CNN-RNN model for image captioning.

A pre-trained [ResNet](https://arxiv.org/pdf/1512.03385.pdf) CNN is used as the encoder to extract image features.

In [None]:

from data_loader import get_loader
from data_loader_val import get_loader as val_get_loader
from model import *

import torch.nn as nn
import torch.utils.data as data
import torch
import torch.nn as nn
import torchvision.models as models

from pycocotools.coco import COCO
from torchvision import transforms
from tqdm.notebook import tqdm
from collections import defaultdict
from nlp_utils import clean_sentence, bleu_score

import math
import json
import os
import sys
import numpy as np

%load_ext autoreload
%autoreload 2

In [None]:
# Home directory and login name from the environment (same as `echo $HOME` /
# `echo $USER`); either is None when the variable is not set.
HOME: str = os.environ.get('HOME')
USER: str = os.environ.get('USER')
print(HOME, USER)

In [None]:
# Dataset location: the shared project scratch area for user "alijanif",
# otherwise ~/datasets/MS_COCO. NOTE(review): absolute, machine-specific path.
if USER == "alijanif":
    cocoapi_dir = os.path.join("/scratch/project_2004072/IMG_Captioning", "MS_COCO")
else:
    cocoapi_dir = os.path.join(HOME, "datasets/MS_COCO")
folders = list(os.listdir(cocoapi_dir))
print(folders)

In [10]:
# --- Training configuration ---------------------------------------------------
batch_size = 128  # batch size
vocab_threshold = 5  # minimum word count threshold
vocab_from_file = True  # if True, load existing vocab file
embed_size = 256  # dimensionality of image and word embeddings
hidden_size = 512  # number of features in hidden state of the RNN decoder
num_epochs = 1  # training epochs
save_every = 1  # save model weights every `save_every` epochs
print_every = 200  # print the current step's loss/perplexity every `print_every` steps
log_file = "training_log.txt"  # name of file with saved training loss and perplexity

# Checkpoint directory: defined once and created up front so later saves/loads
# all use the same path.
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)
# Checkpoint file names include the epoch count, so runs with a different
# num_epochs write to different files.
encoder_fname = f"encoder_{num_epochs}_nEpochs.pkl"
decoder_fname = f"decoder_{num_epochs}_nEpochs.pkl"

In [None]:
# Display the directory that model checkpoints are saved to and loaded from.
models_dir

In [12]:
# Training-time preprocessing: resize, random crop and flip for augmentation,
# then tensor conversion and per-channel normalization with the mean/std values
# expected by the pre-trained encoder (the standard ImageNet statistics).
transform_train = transforms.Compose([
    transforms.Resize(256),             # smaller image edge -> 256 px
    transforms.RandomCrop(224),         # 224x224 crop at a random location
    transforms.RandomHorizontalFlip(),  # left-right flip with probability 0.5
    transforms.ToTensor(),              # PIL image -> tensor
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

In [None]:
# Training data loader over the COCO dataset at `cocoapi_dir`, using the
# augmentation pipeline above and the vocabulary settings from the config cell.
data_loader = get_loader(
    mode="train",
    transform=transform_train,
    cocoapi_loc=cocoapi_dir,
    batch_size=batch_size,
    vocab_from_file=vocab_from_file,
    vocab_threshold=vocab_threshold,
)

### CNN Encoder and RNN Decoder 

In [None]:
# Choose the compute device up front: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The vocabulary size sets the decoder's output dimension (one score per word).
vocab_size = len(data_loader.dataset.vocab)
print("vocab size is : ",vocab_size)

# Build the CNN encoder and RNN decoder and place both on `device`.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
encoder.to(device)
decoder.to(device)

In [15]:
# Cross-entropy loss over the vocabulary. It is placed on the same `device` as
# the models, so there is no separate CUDA check here.
criterion = nn.CrossEntropyLoss().to(device)

# Learnable parameters: the whole decoder plus only the encoder's `embed` layer.
# The other encoder parameters are not given to the optimizer, so this optimizer
# never updates them.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Defining the optimizer
optimizer = torch.optim.Adam(params, lr=0.001)

# Number of training steps (batches) per epoch, rounded up so a final partial
# batch's worth of samples is still counted.
total_step = math.ceil(len(data_loader.dataset) / data_loader.batch_sampler.batch_size)

In [None]:
# Steps (batches) per training epoch.
print(f"{total_step}")

## Training the Model


In [None]:
# Train the CNN-RNN model for `num_epochs` epochs. Each step samples a batch of
# captions that share one length, runs encoder + decoder, and updates the
# parameters held by `optimizer`. Per-step loss/perplexity is written to
# `log_file`; weights are checkpointed to `models_dir` every `save_every` epochs.

# Set training mode explicitly so this cell does not depend on the models'
# current mode (e.g. after an earlier eval() call).
encoder.train()
decoder.train()

# `with` closes the log file even if training raises or is interrupted.
with open(log_file, "w") as log_f:
    for epoch in range(1, num_epochs + 1):
        for i_step in range(1, total_step + 1):
            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            # Create and assign a batch sampler to retrieve a batch with the sampled indices.
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch and move it to the same device as the models.
            images, captions = next(iter(data_loader))
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Forward pass through the CNN-RNN model.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Batch loss over every token position. reshape (unlike view) also
            # works if the decoder output is not contiguous in memory.
            loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))

            # Backward pass and parameter update.
            loss.backward()
            optimizer.step()

            # Training statistics; perplexity is exp(cross-entropy loss).
            stats = (
                f"Epoch [{epoch}/{num_epochs}], Step [{i_step}/{total_step}], "
                f"Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):.4f}"
            )

            # Log every step to file; flush so progress is on disk if the run dies.
            log_f.write(stats + "\n")
            log_f.flush()

            # Echo to the notebook every `print_every` steps.
            if i_step % print_every == 0:
                print("\r" + stats)

        # Save the weights.
        if epoch % save_every == 0:
            torch.save(decoder.state_dict(), os.path.join(models_dir, decoder_fname))
            torch.save(encoder.state_dict(), os.path.join(models_dir, encoder_fname))


## Validating the Model Using the BLEU Score

In [None]:
# Evaluation-time preprocessing: no random augmentation, only resize, tensor
# conversion and the same normalization used for training.
transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Data loader for the validation ("valid" mode) split.
val_data_loader = val_get_loader(
    mode="valid",
    transform=transform_test,
    cocoapi_loc=cocoapi_dir,
)

In [None]:
# Sanity check: which loader class the validation helper returned.
print(val_data_loader.__class__)

In [None]:
# Device for evaluation: GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fresh encoder/decoder instances for evaluation, placed on `device`; the trained
# weights are loaded into them in a later cell.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
encoder.to(device)
decoder.to(device)

In [None]:
!ls -l models

In [None]:
# Checkpoint paths written by the training loop.
encoder_path = os.path.join(models_dir, encoder_fname)
decoder_path = os.path.join(models_dir, decoder_fname)

print(models_dir)

print(encoder_path)
print(decoder_path)

# Load the trained weights. map_location remaps tensors saved on a GPU onto the
# current `device`, so a CUDA-trained checkpoint also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
encoder.load_state_dict(torch.load(encoder_path, map_location=device))
decoder.load_state_dict(torch.load(decoder_path, map_location=device))

# Switch both models to evaluation mode for inference.
encoder.eval()
decoder.eval()

In [None]:
# Generate a caption for every validation image, collected per image id.
# NOTE(review): img_id.item() assumes the validation loader yields one image per batch.
pred_result = defaultdict(list)
for img_id, img in tqdm(val_data_loader):
    # The models live on `device`, so the input must too; feeding a CPU tensor
    # to a CUDA model raises a device-mismatch error.
    img = img.to(device)
    with torch.no_grad():
        features = encoder(img).unsqueeze(1)
        output = decoder.sample(features)
    sentence = clean_sentence(output, val_data_loader.dataset.vocab.idx2word)
    pred_result[img_id.item()].append(sentence)

In [14]:
# Load the ground-truth (reference) captions for the validation images.
# For the COCO 2014 layout the file would be:
#   os.path.join(cocoapi_dir, "cocoapi", "annotations", "captions_val2014.json")
annotations_path = os.path.join(cocoapi_dir, "annotations", "captions_val2017.json")
# Explicit UTF-8 so reading the JSON does not depend on the system locale.
with open(annotations_path, "r", encoding="utf-8") as f:
    caption = json.load(f)

# Group lower-cased reference captions by image id (a list per id, since an
# image can have more than one caption).
valid_annot = caption["annotations"]
valid_result = defaultdict(list)
for annot in valid_annot:
    valid_result[annot["image_id"]].append(annot["caption"].lower())

In [None]:
# Peek at the reference captions of the first three images.
[*valid_result.values()][:3]

In [None]:
# Peek at the predicted captions of the first three images.
[*pred_result.values()][:3]

In [None]:
# BLEU score of the predicted captions against the reference captions.
bleu_score(predicted_sentences=pred_result, true_sentences=valid_result)

Not a bad BLEU score for such a short training run (the configuration above sets `num_epochs = 1`)!