# Training Network

In this notebook, we train a CNN-RNN model for image captioning.

A [ResNet](https://arxiv.org/pdf/1512.03385.pdf) CNN is used as the encoder for image feature extraction.

In [1]:

from data_loader import get_loader
from data_loader_val import get_loader as val_get_loader
from model import *

import torch.nn as nn
import torch.utils.data as data
import torch
import torch.nn as nn
import torchvision.models as models

from pycocotools.coco import COCO
from torchvision import transforms
from tqdm.notebook import tqdm
from collections import defaultdict
from nlp_utils import clean_sentence, bleu_score

import math
import json
import os
import sys
import numpy as np

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/farid/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
  warn(


In [2]:
# Resolve the current user's home directory and login name from the environment.
# os.getenv returns None when a variable is unset, which would break any later
# os.path.join(HOME, ...) call, so fall back to os.path.expanduser("~").
HOME: str = os.getenv("HOME") or os.path.expanduser("~")  # echo $HOME
# USER may be None when the variable is unset.
USER = os.getenv("USER")  # echo $USER
print(HOME, USER)

/home/farid farid


In [3]:
# Locate the MS COCO dataset root: a fixed /scratch path for the user
# "alijanif", otherwise a folder below the current user's home directory.
if USER == "alijanif":
    cocoapi_dir = os.path.join("/scratch/project_2004072/IMG_Captioning", "MS_COCO")
else:
    cocoapi_dir = os.path.join(HOME, "datasets/MS_COCO")

# Show what the dataset root contains as a quick sanity check.
folders = os.listdir(cocoapi_dir)
print(folders)

['val2017.zip', 'train2017.zip', 'images', 'annotations_trainval2017.zip', 'annotations']


In [4]:
# --- Data / vocabulary -------------------------------------------------------
batch_size = 128
vocab_threshold = 5  # words seen fewer times than this are left out of the vocabulary
vocab_from_file = True  # reuse the vocab file on disk when True

# --- Model sizes -------------------------------------------------------------
embed_size = 256  # size of both the image embedding and the word embedding
hidden_size = 512  # width of the RNN decoder's hidden state

# --- Training schedule / logging ---------------------------------------------
num_epochs = 1
save_every = 1  # write model weights every `save_every` epochs
print_every = 20  # print the training statistics every `print_every` steps
log_file = "training_log.txt"  # per-step loss and perplexity are written here

In [5]:
# Augmentation and preprocessing pipeline applied to every training image.
train_steps = [
    transforms.Resize(256),  # shorter image side becomes 256 px
    transforms.RandomCrop(224),  # random 224x224 patch
    transforms.RandomHorizontalFlip(),  # left/right mirror with probability 0.5
    transforms.ToTensor(),  # PIL image -> tensor
    # per-channel mean / std normalisation for the pre-trained model
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
transform_train = transforms.Compose(train_steps)

In [6]:
# Create the training data loader from the COCO data under `cocoapi_dir`,
# using the augmentation pipeline and vocabulary settings configured above.
data_loader = get_loader(
    mode="train",
    transform=transform_train,
    batch_size=batch_size,
    cocoapi_loc=cocoapi_dir,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
)

Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=0.48s)
creating index...
index created!
Obtaining caption lengths...


100%|██████████| 591753/591753 [00:23<00:00, 25295.75it/s]


### CNN Encoder and RNN Decoder 

In [8]:
# Number of distinct tokens known to the training vocabulary.
vocab_size = len(data_loader.dataset.vocab)
print(f"vocab size is :  {vocab_size}")

# Build the image encoder and the caption decoder.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Prefer the GPU when one is available, then place both models on that device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
encoder.to(device)
decoder.to(device)

In [9]:
# Loss function; moved to the GPU when CUDA is available.
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()

# Learnable parameters: the whole decoder plus the encoder's `embed` layer only.
params = [*decoder.parameters(), *encoder.embed.parameters()]

# Optimizer over those parameters.
optimizer = torch.optim.Adam(params, lr=1e-3)

# Number of training steps per epoch: dataset size divided by the batch size,
# rounded up.
n_samples = len(data_loader.dataset)
total_step = math.ceil(n_samples / data_loader.batch_sampler.batch_size)

vocab size is :  11543


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /home/farid/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:28<00:00, 3.62MB/s]


In [10]:
# Show the number of training steps that make up one epoch.
print(total_step)

4624


## Training the Model


In [11]:
# Train the captioning model.
#
# For every step a batch is drawn, the loss is back-propagated and the
# optimizer updates the parameters it was built with. Per-step statistics are
# written to `log_file`; every `print_every` steps they are also printed.
# Encoder and decoder weights are saved every `save_every` epochs.

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the (long) training run starts.
checkpoint_dir = "./models"
os.makedirs(checkpoint_dir, exist_ok=True)

# The context manager closes the log file even when training raises or is
# interrupted from the notebook.
with open(log_file, "w") as f:
    for epoch in range(1, num_epochs + 1):
        for i_step in range(1, total_step + 1):

            # Randomly sample a caption length, and sample indices with that length.
            indices = data_loader.dataset.get_train_indices()
            # Create and assign a batch sampler to retrieve a batch with the sampled indices.
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            data_loader.batch_sampler.sampler = new_sampler

            # Obtain the batch. The iterator is created after the sampler swap
            # so that the batch is drawn from the newly assigned sampler.
            images, captions = next(iter(data_loader))

            # Move the batch to the same device as the models.
            images = images.to(device)
            captions = captions.to(device)

            # Zero the gradients.
            decoder.zero_grad()
            encoder.zero_grad()

            # Forward pass through the CNN encoder and the RNN decoder.
            features = encoder(images)
            outputs = decoder(features, captions)

            # Batch loss: predictions are flattened to (N, vocab_size) and the
            # target captions to (N,) before being passed to the criterion.
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

            # Backward pass.
            loss.backward()

            # Parameter update.
            optimizer.step()

            # Training statistics; read the scalar loss once and reuse it.
            loss_value = loss.item()
            stats = (
                f"Epoch [{epoch}/{num_epochs}], Step [{i_step}/{total_step}], "
                f"Loss: {loss_value:.4f}, Perplexity: {np.exp(loss_value):.4f}"
            )

            # Write the statistics to the log file and flush so they reach
            # disk while training is still running.
            f.write(stats + "\n")
            f.flush()

            # Print training statistics (on different line).
            if i_step % print_every == 0:
                print("\r" + stats)

        # Save the weights.
        if epoch % save_every == 0:
            torch.save(
                decoder.state_dict(), os.path.join(checkpoint_dir, "decoder-%d.pkl" % epoch)
            )
            torch.save(
                encoder.state_dict(), os.path.join(checkpoint_dir, "encoder-%d.pkl" % epoch)
            )

Epoch [1/1], Step [20/4624], Loss: 4.8973, Perplexity: 133.9271
Epoch [1/1], Step [40/4624], Loss: 4.9032, Perplexity: 134.7243
Epoch [1/1], Step [60/4624], Loss: 4.3973, Perplexity: 81.2294
Epoch [1/1], Step [80/4624], Loss: 3.9979, Perplexity: 54.4858
Epoch [1/1], Step [100/4624], Loss: 3.9218, Perplexity: 50.4938
Epoch [1/1], Step [120/4624], Loss: 4.0951, Perplexity: 60.0455
Epoch [1/1], Step [140/4624], Loss: 3.6245, Perplexity: 37.5077
Epoch [1/1], Step [160/4624], Loss: 3.9668, Perplexity: 52.8144
Epoch [1/1], Step [180/4624], Loss: 3.4522, Perplexity: 31.5700
Epoch [1/1], Step [200/4624], Loss: 3.4176, Perplexity: 30.4957
Epoch [1/1], Step [220/4624], Loss: 3.6445, Perplexity: 38.2623
Epoch [1/1], Step [240/4624], Loss: 3.8861, Perplexity: 48.7210
Epoch [1/1], Step [260/4624], Loss: 3.5298, Perplexity: 34.1169
Epoch [1/1], Step [280/4624], Loss: 3.4250, Perplexity: 30.7239
Epoch [1/1], Step [300/4624], Loss: 3.1883, Perplexity: 24.2463
Epoch [1/1], Step [320/4624], Loss: 3.1663


## Validating the Model using the BLEU Score

In [None]:
# Preprocessing for validation images: resize, convert to tensor, normalise.
# No random crop or flip is applied here.
val_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    # per-channel mean / std normalisation for the pre-trained model
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
transform_test = transforms.Compose(val_steps)

# Data loader over the validation split.
val_data_loader = val_get_loader(
    mode="valid",
    transform=transform_test,
    cocoapi_loc=cocoapi_dir,
)

In [None]:
# Checkpoint file names for the last training epoch. They must match the names
# written by the training cell ("encoder-<epoch>.pkl" / "decoder-<epoch>.pkl"),
# otherwise loading the weights fails with a missing-file error.
encoder_file = f"encoder-{num_epochs}.pkl"
decoder_file = f"decoder-{num_epochs}.pkl"

In [None]:
# Rebuild the encoder and decoder with the same sizes used for training.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move the models to the GPU if CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

# Load the trained weights. map_location remaps the stored tensors onto
# `device`, so checkpoints saved on a GPU can also be loaded on a CPU-only
# machine.
encoder.load_state_dict(
    torch.load(os.path.join("./models", encoder_file), map_location=device)
)
decoder.load_state_dict(
    torch.load(os.path.join("./models", decoder_file), map_location=device)
)

# Switch both models to evaluation mode for inference.
encoder.eval()
decoder.eval()

In [None]:
# Generate a caption for every validation image and collect them per image id.
pred_result = defaultdict(list)
for img_id, image in tqdm(val_data_loader):
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        # unsqueeze(1) inserts an extra dimension at index 1 before sampling.
        features = encoder(image.to(device)).unsqueeze(1)
        token_ids = decoder.sample(features)
    # Turn the sampled output into a sentence using the vocabulary's idx2word mapping.
    predicted = clean_sentence(token_ids, val_data_loader.dataset.vocab.idx2word)
    pred_result[img_id.item()].append(predicted)

In [14]:
# Read the validation caption annotations from the dataset folder.
annotation_path = os.path.join(cocoapi_dir, "annotations/captions_val2017.json")
with open(annotation_path, "r") as annot_file:
    caption = json.load(annot_file)

# Group the lower-cased reference captions by image id.
valid_annot = caption["annotations"]
valid_result = defaultdict(list)
for annot in valid_annot:
    valid_result[annot["image_id"]].append(annot["caption"].lower())

In [None]:
# Peek at the reference captions of the first three images.
list(valid_result.values())[:3]

In [None]:
# Peek at the predicted captions of the first three images.
list(pred_result.values())[:3]

In [None]:
# Compare the predicted captions with the reference captions using the
# bleu_score helper imported from nlp_utils.
bleu_score(true_sentences=valid_result, predicted_sentences=pred_result)

Not a bad BLEU score for such a short training run (the number of epochs is set by `num_epochs` in the configuration cell)!