# __SHOW, ATTEND, AND TELL__

This notebook demonstrates the work conducted for the METU MMI727 course project. The project examines the **Show, Attend, and Tell** [***put_reference] image captioning model. The implementation available in [A PyTorch Tutorial to Image Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning) is taken as the baseline, and several improvement strategies are examined. The baseline model and five modified models are trained on the MSCOCO image captioning dataset [***put_reference]. These six models are then benchmarked and compared on the MSCOCO test set.

# Implementation Details
In this work, six different models are trained and benchmarked. All models use a ResNet101 backbone, with modifications to either the ResNet encoder block, the attention block, or the decoder block.

The models are as follows:\
Model 1: The default implementation provided in [A PyTorch Tutorial to Image Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning).\
Model 2: The attention block is modified. Unlike Model 1, which sums the attention projections of the encoder output and the decoder hidden state, Model 2 concatenates them and passes the result through a linear layer.\
Model 3: The decoder block is modified. Unlike Model 1, a GRU is used as the decoder instead of an LSTM.\
Model 4: The encoder block is modified. Unlike Model 1, which uses the conv5_x output of ResNet101, Model 4 uses a modified conv4_x output of ResNet101 as the final layer of the encoder. See the project report for details.\
Model 5: The encoder block is modified. Unlike Model 1, a 2-level Feature Pyramid Network (FPN) is implemented in the encoder.\
Model 6: The decoder block is modified. Unlike Model 1, a 2-layer LSTM is used as the decoder.
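To make the difference between Model 1 and Model 2 concrete, here is a minimal PyTorch sketch of the two attention variants. The sum form follows the soft attention of the baseline tutorial as we understand it (project both inputs, add, ReLU, score, softmax over pixels); the concatenation form is one plausible reading of the Model 2 description, not the project's code. The class names and dimensions are hypothetical.

```python
import torch
from torch import nn


class SumAttention(nn.Module):
    """Additive (sum) soft attention in the style of the baseline tutorial."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # projects each pixel feature
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # projects the hidden state
        self.full_att = nn.Linear(attention_dim, 1)               # one score per pixel

    def forward(self, encoder_out, h):
        # encoder_out: (batch, num_pixels, encoder_dim), h: (batch, decoder_dim)
        att = self.encoder_att(encoder_out) + self.decoder_att(h).unsqueeze(1)
        alpha = torch.softmax(self.full_att(torch.relu(att)).squeeze(2), dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # weighted sum over pixels
        return context, alpha


class ConcatAttention(nn.Module):
    """Hypothetical concatenation variant: join both projections, then a linear layer."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.merge = nn.Linear(2 * attention_dim, attention_dim)  # replaces the sum
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, h):
        att1 = self.encoder_att(encoder_out)                     # (B, P, A)
        att2 = self.decoder_att(h).unsqueeze(1).expand_as(att1)  # (B, P, A), same for every pixel
        att = self.merge(torch.cat([att1, att2], dim=2))         # (B, P, 2A) -> (B, P, A)
        alpha = torch.softmax(self.full_att(torch.relu(att)).squeeze(2), dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha


enc = torch.randn(2, 196, 2048)  # e.g. a 14x14 feature map flattened to 196 pixels
h = torch.randn(2, 512)
for att_module in (SumAttention(2048, 512, 512), ConcatAttention(2048, 512, 512)):
    context, alpha = att_module(enc, h)
    print(context.shape, alpha.shape)  # torch.Size([2, 2048]) torch.Size([2, 196])
```

Both variants produce the same output shapes, so they can be swapped behind a config flag without touching the decoder; the concatenation form only adds the extra `merge` layer.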

Model details are read from `.yaml` config files. The config files for these models are provided under the `configs` folder.
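The exact FPN used in Model 5 is described in the project report and is not reproduced here. As a rough illustration only, a generic 2-level FPN top-down merge over ResNet101's conv4_x (1024 channels, stride 16) and conv5_x (2048 channels, stride 32) outputs could look like the sketch below. `TwoLevelFPN`, `out_channels`, and returning the finer merged map are assumptions, not the project's design.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TwoLevelFPN(nn.Module):
    """Generic 2-level FPN top-down merge (hypothetical, not the project's code)."""

    def __init__(self, c4_channels=1024, c5_channels=2048, out_channels=256):
        super().__init__()
        # 1x1 lateral convs bring both levels to a common channel count
        self.lateral4 = nn.Conv2d(c4_channels, out_channels, kernel_size=1)
        self.lateral5 = nn.Conv2d(c5_channels, out_channels, kernel_size=1)
        # 3x3 conv smooths the merged map to reduce upsampling artifacts
        self.smooth = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, c4, c5):
        p5 = self.lateral5(c5)
        # Upsample the coarser level to the finer level's spatial size and add
        p4 = self.lateral4(c4) + F.interpolate(p5, size=c4.shape[-2:], mode="nearest")
        return self.smooth(p4)


c4 = torch.randn(1, 1024, 16, 16)  # conv4_x output for a 256x256 input
c5 = torch.randn(1, 2048, 8, 8)    # conv5_x output for the same input
p4 = TwoLevelFPN()(c4, c5)
print(p4.shape)  # torch.Size([1, 256, 16, 16])
```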

## Load Model For Training

In [1]:
import json
import yaml
from torch import nn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from datasets import *  # CaptionDataset and related data utilities
from utils import *  # training helpers (AverageMeter, accuracy, clip_gradient, ...)
import time
from torch.nn.utils.rnn import pack_padded_sequence
from nltk.translate.bleu_score import corpus_bleu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
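The training loop further below relies on helpers brought in by `from utils import *`: `AverageMeter`, `accuracy`, `clip_gradient` and `adjust_learning_rate`. The stand-ins below are a minimal sketch of what these helpers are assumed to do, inferred from how the loop calls them; the project's `utils.py` may differ in details. Here `accuracy` returns a percentage, which matches the Top-5 values in the training log, and `AverageMeter.update(val, n)` weights each batch by `n`, so passing `sum(decode_lengths)` gives a per-word average.

```python
import torch


class AverageMeter:
    """Tracks the latest value and a count-weighted running average."""

    def __init__(self):
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(scores, targets, k):
    """Percentage of rows whose target index is among the top-k scores."""
    _, topk = scores.topk(k, dim=1)                    # (N, k) predicted indices
    correct = topk.eq(targets.view(-1, 1)).any(dim=1)  # (N,) hit or miss per row
    return correct.float().mean().item() * 100.0


def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element to [-grad_clip, grad_clip] (clipping by value)."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.clamp_(-grad_clip, grad_clip)


def adjust_learning_rate(optimizer, shrink_factor):
    """Multiply the learning rate of every parameter group by shrink_factor."""
    for group in optimizer.param_groups:
        group["lr"] *= shrink_factor


scores = torch.tensor([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]])
targets = torch.tensor([1, 2])
print(accuracy(scores, targets, 1))  # 50.0
```

Clipping by value, as sketched here, behaves like PyTorch's built-in `torch.nn.utils.clip_grad_value_`; clipping by global norm (`clip_grad_norm_`) would be the usual alternative.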

## DATASET AND CONFIG FILES
Download the `dataset` folder from this link [***put_link] and place it in the same directory as this notebook.

For data loading, `hdf5` files are used because the whole dataset is too large to load into memory at once.
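The benefit of `hdf5` here is that indexing an `h5py` dataset reads only the requested slice from disk, so a `Dataset` can fetch one image per `__getitem__` call without loading the whole split. The sketch below writes a tiny synthetic file and reads one image back. The dataset name `images` and the `captions_per_image` attribute mirror what we believe `create_input_files.py` produces, but treat them as assumptions; the sketch needs `h5py` and `numpy`.

```python
import os
import tempfile

import h5py
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "TOY_IMAGES.hdf5")

# Write a tiny synthetic file: image k is filled with the value k, stored as uint8
with h5py.File(path, "w") as h:
    h.attrs["captions_per_image"] = 5  # assumed attribute name
    data = np.stack([np.full((3, 8, 8), k, dtype=np.uint8) for k in range(4)])
    h.create_dataset("images", data=data)  # assumed dataset name

# Read lazily: h["images"] is an on-disk handle; indexing it loads only that slice
with h5py.File(path, "r") as h:
    images = h["images"]
    print(images.shape)  # (4, 3, 8, 8), known without reading any pixels
    img = images[2]      # NumPy array holding just the third image
    cpi = int(h.attrs["captions_per_image"])
```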

Note that the TRAIN `hdf5` file in the `dataset` folder is actually the `TEST` set of the COCO dataset. I could not upload the actual `TRAIN` set because it is 22 GB.

Define the paths to the config file and the dataset-related files.

In [2]:
config_path = "configs_test/model_1_config.yaml" # path to model config file
image_data_folder = "dataset/" # path to directory which holds image data
image_data_name = "coco_5_cap_per_img_5_min_word_freq" # name for the dataset created with create_input_files.py
word_map_path = "dataset/WORDMAP_coco_5_cap_per_img_5_min_word_freq.json" # path to file which maps words to integers
output_path = "test_out/" # path to save checkpoints

Load the word map and the config file.

In [3]:
# Load wordmap file and config file
word_map = None
with open(word_map_path, 'r') as j:
    word_map = json.load(j)

cfgData = None
with open(config_path, "r") as cfgFile:
    cfgData = yaml.safe_load(cfgFile)

modelTypes = cfgData["Model Type"]
modelParams = cfgData["Model Parameters"]
trainParams = cfgData["Training Parameters"]
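The three top-level sections read above (`Model Type`, `Model Parameters`, `Training Parameters`) are what the rest of the notebook needs from the config. Because a missing key would otherwise surface only as a `KeyError` inside the training loop, a small check right after loading can help. The `check_config` helper below is a hypothetical addition, not part of the project; the required training keys are exactly those the training loop reads from `trainParams`, and the example values are made up.

```python
REQUIRED_SECTIONS = ("Model Type", "Model Parameters", "Training Parameters")
# Keys the training loop in this notebook reads from trainParams
REQUIRED_TRAIN_KEYS = ("start_epoch", "num_epochs", "batch_size",
                       "fine_tune_encoder", "alpha_c", "grad_clip")


def check_config(cfg):
    """Return a list of human-readable problems; an empty list means the config looks usable."""
    problems = [f"missing section: {s!r}" for s in REQUIRED_SECTIONS if s not in cfg]
    train = cfg.get("Training Parameters") or {}  # tolerate an empty YAML section (None)
    problems += [f"missing training key: {k!r}" for k in REQUIRED_TRAIN_KEYS if k not in train]
    return problems


# Usage with the config loaded above:
# assert not check_config(cfgData), check_config(cfgData)
example = {
    "Model Type": {},
    "Model Parameters": {},
    "Training Parameters": {"start_epoch": 0, "num_epochs": 2, "batch_size": 32,
                            "fine_tune_encoder": False, "alpha_c": 1.0},
}
print(check_config(example))  # ["missing training key: 'grad_clip'"]
```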

## Create Model According To Config File

In [4]:
encoder, decoder, encoder_optimizer, decoder_optimizer, epochs_since_improvement, best_bleu4, \
            encoderType, decoderType, attentionType, enable2LayerDecoder = create_model_for_training(config_path, len(word_map))

Encoder Type: default
Decoder Type: LSTM
Attention Type: default
Encoder Dim: 2048
Enable2LayerDecoder: False
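The printed fields indicate that the config switches between an LSTM and a GRU decoder (Model 3) and between one and two recurrent layers (Model 6). The main practical difference at each decoding step is the recurrent state: an LSTM cell carries `(h, c)`, a GRU cell carries only `h`, and a 2-layer decoder feeds the first layer's hidden state into the second. The sketch below shows one step of each with `nn.LSTMCell` and `nn.GRUCell`; `decode_step` and the dimensions are hypothetical, and the project's decoder may wire the layers differently.

```python
import torch
from torch import nn

batch, input_dim, hidden_dim = 2, 2560, 512  # e.g. word embedding (512) + attention context (2048)
x = torch.randn(batch, input_dim)
zeros = (torch.zeros(batch, hidden_dim), torch.zeros(batch, hidden_dim))

# Model 1: single LSTMCell, state is (h, c)
lstm = nn.LSTMCell(input_dim, hidden_dim)
h, c = lstm(x, zeros)

# Model 3: single GRUCell, state is just h
gru = nn.GRUCell(input_dim, hidden_dim)
h_gru = gru(x, torch.zeros(batch, hidden_dim))

# Model 6: two stacked LSTMCells, layer 2 consumes layer 1's hidden state
lstm1 = nn.LSTMCell(input_dim, hidden_dim)
lstm2 = nn.LSTMCell(hidden_dim, hidden_dim)


def decode_step(x, state1, state2):
    """One time step of a hypothetical 2-layer LSTM decoder."""
    h1, c1 = lstm1(x, state1)
    h2, c2 = lstm2(h1, state2)
    return (h1, c1), (h2, c2)  # h2 would feed the word classifier


state1, state2 = decode_step(x, zeros, zeros)
print(h.shape, h_gru.shape, state2[0].shape)  # each is torch.Size([2, 512])
```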


## Print Details of Encoder and Decoder

In [5]:
print("***ENCODER***")
print(encoder)
print("\n\n***DECODER***")
print(decoder)

***ENCODER***
Encoder(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0)

## Define Loss Function
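`nn.CrossEntropyLoss` is used to calculate the loss between the decoder output (unnormalized scores over the vocabulary) and the ground-truth words at each time step. The targets are integer word indices, not one-hot vectors, which is the class-index form that `nn.CrossEntropyLoss` accepts.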

In [6]:
# Loss function
criterion = nn.CrossEntropyLoss().to(device)
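The training loop applies this criterion to packed scores and then adds the doubly stochastic attention penalty. The sketch below repeats those steps on random tensors so the shapes are visible: `pack_padded_sequence(...).data` keeps only the `sum(decode_lengths)` real time steps, the targets are word indices, and the penalty pushes each pixel's attention, summed over time, toward 1. The sizes, the `alpha_c` value, and zeroing attention at padded steps (our reading of the tutorial's decoder) are illustrative assumptions.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence

batch, max_len, vocab, num_pixels = 3, 5, 10, 4
decode_lengths = [5, 3, 2]  # sorted in decreasing order, as the decoder returns them
scores = torch.randn(batch, max_len, vocab)          # decoder logits per time step
targets = torch.randint(0, vocab, (batch, max_len))  # word indices, like caps_sorted[:, 1:]
alphas = torch.softmax(torch.randn(batch, max_len, num_pixels), dim=2)
for b, n in enumerate(decode_lengths):
    alphas[b, n:] = 0.  # assumed: padded steps carry no attention weight

# Drop padded time steps; .data avoids the tuple-unpacking issue mentioned in the loop
packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
print(packed_scores.shape, packed_targets.shape)  # torch.Size([10, 10]) torch.Size([10])

criterion = nn.CrossEntropyLoss()
ce = criterion(packed_scores, packed_targets)  # mean of -log softmax at each target index

# Doubly stochastic regularization: sum attention over time (dim=1), penalize deviation from 1
alpha_c = 1.0
penalty = ((1. - alphas.sum(dim=1)) ** 2).mean()
loss = ce + alpha_c * penalty
```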

## Load Dataset Using Custom Data Loader
Normalize the images with the ImageNet channel means and standard deviations, and define the data loaders using the custom `CaptionDataset` class provided by the [A PyTorch Tutorial to Image Captioning](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning) repository.

In [7]:
# Custom dataloaders
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(image_data_folder, image_data_name, 'TRAIN', transform=transforms.Compose([normalize])),
    batch_size=trainParams["batch_size"], shuffle=True, num_workers=1, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    CaptionDataset(image_data_folder, image_data_name, 'VAL', transform=transforms.Compose([normalize])),
    batch_size=trainParams["batch_size"], shuffle=True, num_workers=1, pin_memory=True)
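The two loaders yield different tuples: training batches unpack to `(imgs, caps, caplens)`, while validation batches also carry `allcaps`, all reference captions of each image, which BLEU-4 needs. The toy dataset below is a hypothetical stand-in for `CaptionDataset` that mimics this split-dependent behaviour on random data. It scales stored `uint8` pixels to [0, 1] before normalizing, since the ImageNet statistics above assume that range. The class name, tensor sizes, and the `i // cpi` caption-to-image indexing are assumptions modeled on the tutorial, not its code.

```python
import torch
from torch.utils.data import DataLoader, Dataset

MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


class ToyCaptionDataset(Dataset):
    """Hypothetical stand-in for CaptionDataset: one item per caption, cpi captions per image."""

    def __init__(self, split, num_images=4, cpi=5, max_len=7, vocab=20):
        self.split, self.cpi = split, cpi
        self.imgs = torch.randint(0, 256, (num_images, 3, 8, 8), dtype=torch.uint8)
        self.captions = torch.randint(0, vocab, (num_images * cpi, max_len))
        self.caplens = torch.randint(2, max_len + 1, (num_images * cpi, 1))

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, i):
        img_idx = i // self.cpi  # caption i belongs to image i // cpi
        img = (self.imgs[img_idx].float() / 255. - MEAN) / STD  # same math as transforms.Normalize
        caption, caplen = self.captions[i], self.caplens[i]
        if self.split == "TRAIN":
            return img, caption, caplen
        # Validation also returns all cpi reference captions of this image, for BLEU
        allcaps = self.captions[img_idx * self.cpi:(img_idx + 1) * self.cpi]
        return img, caption, caplen, allcaps


imgs, caps, caplens = next(iter(DataLoader(ToyCaptionDataset("TRAIN"), batch_size=2)))
imgs, caps, caplens, allcaps = next(iter(DataLoader(ToyCaptionDataset("VAL"), batch_size=2)))
print(allcaps.shape)  # torch.Size([2, 5, 7])
```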

## TRAINING
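Define the training loop and run training. Each epoch trains on the `TRAIN` split, evaluates loss, top-5 accuracy and BLEU-4 on the `VAL` split, and saves a checkpoint.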
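The learning-rate rule at the top of the loop fires whenever `epochs_since_improvement` is a positive multiple of 8, so during a long plateau the rate is multiplied by 0.8 after 8, 16, 24, ... epochs without improvement, and the counter resets once BLEU-4 improves. The pure-Python simulation below replays that rule on a made-up BLEU-4 history; `simulate_lr`, the history, and the initial rate are hypothetical, and `lr *= shrink` stands in for the `adjust_learning_rate` call.

```python
def simulate_lr(bleu_history, lr=4e-4, shrink=0.8, patience=8):
    """Replay the notebook's decay rule; return the learning rate used in each epoch."""
    best, since_improvement, lrs = 0.0, 0, []
    for bleu in bleu_history:
        # Check at the start of the epoch, as in the training loop
        if since_improvement > 0 and since_improvement % patience == 0:
            lr *= shrink
        lrs.append(lr)
        # After validation: update the counter the same way the loop does
        if bleu > best:
            best, since_improvement = bleu, 0
        else:
            since_improvement += 1
    return lrs


# One improvement in epoch 0, then a 20-epoch plateau
history = [0.10] + [0.05] * 20
lrs = simulate_lr(history)
decay_epochs = [e for e in range(1, len(lrs)) if lrs[e] < lrs[e - 1]]
print(decay_epochs)  # [9, 17]
```

Because the check runs before the counter is updated, the first decay lands in the epoch after the eighth non-improving validation.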

In [10]:
# Epochs
print_freq = 50  # print training/validation stats every __ batches
for epoch in range(trainParams["start_epoch"], trainParams["num_epochs"]):

    # Decay learning rate if there is no improvement for 8 consecutive epochs
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if trainParams["fine_tune_encoder"]:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # ***************TRAINING********************
    # put models in train mode
    decoder.train()
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence is an easy trick to do this
        # .data is taken instead of tuple unpacking to avoid the "too many values to unpack" error, see https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning/issues/86
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += trainParams["alpha_c"] * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if trainParams["grad_clip"] is not None:
            clip_gradient(decoder_optimizer, trainParams["grad_clip"])
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, trainParams["grad_clip"])

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

    # ***************VALIDATION********************

    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time_eval = AverageMeter()
    losses_eval = AverageMeter()
    top5accs_eval = AverageMeter()

    start_eval = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # explicitly disable gradient calculation to avoid CUDA memory error
    # solves the issue #57
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += trainParams["alpha_c"] * ((1. - alphas.sum(dim=1)) ** 2).mean()
            
            # Keep track of metrics
            losses_eval.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs_eval.update(top5, sum(decode_lengths))
            batch_time_eval.update(time.time() - start_eval)

            start_eval = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time_eval.val:.3f} ({batch_time_eval.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time_eval=batch_time_eval,
                                                                                loss=losses_eval, top5=top5accs_eval))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)

            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses_eval,
                top5=top5accs_eval,
                bleu=bleu4))

    recent_bleu4 = bleu4

    # Check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint
    save_checkpoint(image_data_name, epoch, epochs_since_improvement, encoderType, decoderType, enable2LayerDecoder, attentionType, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_bleu4, is_best, epoch, output_path)

Epoch: [0][5/25000]	Batch Time 0.167 (0.196)	Data Load Time 0.000 (0.014)	Loss 9.5502 (9.5459)	Top-5 Accuracy 22.222 (23.333)
Validation: [5/25000]	Batch Time 0.023 (0.033)	Loss 9.3255 (9.1914)	Top-5 Accuracy 22.222 (23.438)	

 * LOSS - 9.191, TOP-5 ACCURACY - 23.438, BLEU-4 - 1.377957355695106e-231

Epoch: [1][5/25000]	Batch Time 0.166 (0.173)	Data Load Time 0.000 (0.006)	Loss 8.2541 (8.8383)	Top-5 Accuracy 33.333 (26.761)
Validation: [5/25000]	Batch Time 0.024 (0.146)	Loss 8.3851 (9.1483)	Top-5 Accuracy 40.000 (26.389)	

 * LOSS - 9.148, TOP-5 ACCURACY - 26.389, BLEU-4 - 1.3183254360496076e-231


Epochs since last improvement: 1
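The BLEU-4 values around `1e-231` in the output above are effectively zero. As far as we know, NLTK's `corpus_bleu` without a smoothing function replaces any zero n-gram precision with `sys.float_info.min` (and emits a warning) before taking the weighted geometric mean in log space. If the hypotheses share some unigrams with the references but no bigrams, trigrams or 4-grams, three of the four factors collapse to that tiny value. The arithmetic below reproduces the order of magnitude; the unigram precision and brevity penalty are made-up values.

```python
import math
import sys

weights = (0.25, 0.25, 0.25, 0.25)   # BLEU-4: uniform weights over 1- to 4-gram precisions
p1 = 0.3                             # made-up unigram precision
tiny = sys.float_info.min            # stand-in for a zero precision (about 2.2e-308)
precisions = (p1, tiny, tiny, tiny)  # no 2-, 3- or 4-gram matches
brevity_penalty = 1.0                # assume hypotheses are not shorter than references

bleu4 = brevity_penalty * math.exp(math.fsum(w * math.log(p) for w, p in zip(weights, precisions)))
print(f"{bleu4:.3e}")  # on the order of 1e-231, like the logged values
```

Such values are expected after only a couple of epochs on a small subset; passing a smoothing function from `nltk.translate.bleu_score.SmoothingFunction` to `corpus_bleu` could make early scores more readable, at the cost of no longer matching unsmoothed BLEU-4.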

