In [1]:
# Mount Google Drive at /content/drive so the project files stored there are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Make the project folder on Drive the working directory.
# The commented-out line is an alternative folder that is not used here.
#%cd /content/drive/MyDrive/a-PyTorch-Tutorial-to-Image-Captioning
%cd /content/drive/MyDrive/PyTorch

/content/drive/MyDrive/PyTorch


In [3]:
# List the working directory to check which scripts, data folders and checkpoints are present.
!ls


BEST_checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar
captioneers.ipynb
caption.py
checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar
create_input_files.py
datasets.py
enlarge_resnet.json
eval.py
flickr_prep
gan_model.py
img
models.py
__pycache__
TRAIN_IMAGES_flickr30k_5_cap_per_img_5_min_word_freq-003.hdf5
train.py
train_utils.py
utils.py


In [26]:
#import argparse
#parser = argparse.ArgumentParser(description='Process some integers.')
#parser.add_argument('--device', type=str)
#args = parser.parse_args()

from datasets import CaptionDataset
from models import *
from utils import *
from train_utils import *
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import os
import json
import numpy as np

In [27]:
# Device selection: use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
# (Disabled script-mode option: os.environ['CUDA_VISIBLE_DEVICES'] = args.device)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [28]:
# Switch to the tutorial folder before building the datasets. The next cell opens its data
# through the relative 'flickr_prep/' path, which is resolved against the current directory
# (presumably this folder holds the copy of the data that should be used - TODO confirm).
%cd /content/drive/MyDrive/a-PyTorch-Tutorial-to-Image-Captioning

/content/drive/.shortcut-targets-by-id/1qbeagtpdhUjJ0qiqM0snKrjUhBUYcQFB/a-PyTorch-Tutorial-to-Image-Captioning


In [29]:
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
# Earlier alternatives, kept for reference only:
#   main_folder = "./a-PyTorch-Tutorial-to-Image-Captioning"
#   data_folder = '/media/ssd/caption data'
#   data_name = 'coco_5_cap_per_img_5_min_word_freq'
data_folder = 'flickr_prep/'  # directory holding the files written by create_input_files.py
data_name = 'flickr30k_5_cap_per_img_5_min_word_freq'  # prefix common to all of those files
word_map_file = 'flickr_prep/WORDMAP_flickr30k_5_cap_per_img_5_min_word_freq.json'
with open(word_map_file, 'r') as word_map_fp:
    word_map = json.load(word_map_fp)

# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------
emb_dim = 512  # word-embedding width
attention_dim = 512  # width of the attention linear layers (not passed to Transformer in this cell)
decoder_dim = 512  # decoder RNN width (not passed to Transformer in this cell)
dropout = 0.5
n_heads = 8
encoder_layers = 2
decoder_layers = 6
attention_method = "ByPixel"

# Device used for the models and every tensor created below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only worthwhile when the model's inputs have a fixed size; otherwise it adds a lot of overhead
cudnn.benchmark = True

# ---------------------------------------------------------------------------
# Training hyper-parameters and bookkeeping
# ---------------------------------------------------------------------------
start_epoch = 0
epochs = 10  # upper bound on epochs (early stopping may end training sooner)
epochs_since_improvement = 0  # epochs elapsed since validation BLEU last improved
batch_size = 64
workers = 0  # DataLoader worker processes; 0 loads batches in the main process
encoder_lr = 1e-4  # encoder learning rate, only relevant when fine-tuning
decoder_lr = 4e-4  # decoder learning rate
grad_clip = 5.  # clip gradients at an absolute value of
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # best validation BLEU-4 observed so far
print_freq = 100  # report training/validation stats every this many batches
fine_tune_encoder = False  # whether the encoder's weights are updated
checkpoint = None  # checkpoint path to resume from; None means start fresh

# ---------------------------------------------------------------------------
# Decoder and its optimizer
# ---------------------------------------------------------------------------
decoder = Transformer(
    vocab_size=len(word_map),
    embed_dim=emb_dim,
    encoder_layers=encoder_layers,
    decoder_layers=decoder_layers,
    dropout=dropout,
    attention_method=attention_method,
    n_heads=n_heads,
).to(device)

decoder_optimizer = torch.optim.Adam(
    params=[p for p in decoder.parameters() if p.requires_grad],
    lr=decoder_lr)

# ---------------------------------------------------------------------------
# Encoder and (only when fine-tuning) its optimizer
# ---------------------------------------------------------------------------
encoder = CNN_Encoder()
encoder.fine_tune(fine_tune_encoder)
if fine_tune_encoder:
    encoder_optimizer = torch.optim.Adam(
        params=[p for p in encoder.parameters() if p.requires_grad],
        lr=encoder_lr)
else:
    encoder_optimizer = None
encoder = encoder.to(device)

# Loss function; targets equal to 0 are ignored
# (presumably 0 is the <pad> index - verify against the word map)
criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)

# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
train_loader, val_loader = (
    torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, split, transform=transforms.Compose([normalize])),
        **loader_options)
    for split in ('TRAIN', 'VAL'))

# Per-epoch validation metrics are accumulated here and dumped to `save_name`
save_name = 'enlarge_resnet.json'
training_track = {'loss': [], 'bleu': []}

In [30]:
# Return to the PyTorch folder before training. Files opened by relative name afterwards
# (for example the `save_name` JSON written by the training loop) are created in the
# current directory.
%cd /content/drive/MyDrive/PyTorch

/content/drive/MyDrive/PyTorch


In [None]:
# Main training loop: one pass of training and one of validation per epoch
for epoch in range(start_epoch, epochs):

    # Early stopping: give up after 20 epochs without a better validation BLEU-4
    if epochs_since_improvement == 20:
        break

    # Every 8th consecutive epoch without improvement, scale the learning rate(s) by 0.8
    lr_decay_due = epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0
    if lr_decay_due:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # Keyword arguments that train() and validate() have in common
    shared_kwargs = dict(
        encoder=encoder,
        decoder=decoder,
        criterion=criterion,
        alpha_c=alpha_c,
        print_freq=print_freq,
        n_heads=n_heads,
        decoder_layers=decoder_layers,
    )

    # Train for one epoch
    train(train_loader=train_loader,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch,
          grad_clip=grad_clip,
          **shared_kwargs)

    # Validate; returns this epoch's BLEU-4 and loss
    recent_bleu4, loss = validate(val_loader=val_loader,
                                  word_map=word_map,
                                  **shared_kwargs)

    # Record the metrics and rewrite the tracking file after every epoch
    training_track['bleu'].append(recent_bleu4)
    training_track['loss'].append(loss)
    with open(save_name, 'w') as track_file:
        json.dump(training_track, track_file)

    # Update the best score and the no-improvement counter
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if is_best:
        epochs_since_improvement = 0
    else:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))

    # Persist the current state (save_checkpoint is told whether this epoch is the best so far)
    save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_bleu4, is_best)

Epoch: [0][0/2266]	Batch Time 2.517 (2.517)	Data Load Time 0.526 (0.526)	Loss 9.5918 (9.5918)	Top-5 Accuracy 0.113 (0.113)
Epoch: [0][100/2266]	Batch Time 2.315 (2.308)	Data Load Time 0.260 (0.264)	Loss 6.0170 (6.3699)	Top-5 Accuracy 30.626 (28.692)
Epoch: [0][200/2266]	Batch Time 2.318 (2.302)	Data Load Time 0.268 (0.256)	Loss 6.1959 (6.2697)	Top-5 Accuracy 28.390 (28.806)
Epoch: [0][300/2266]	Batch Time 2.284 (2.299)	Data Load Time 0.240 (0.252)	Loss 6.1832 (6.2257)	Top-5 Accuracy 28.656 (28.986)
Epoch: [0][400/2266]	Batch Time 2.274 (2.295)	Data Load Time 0.227 (0.248)	Loss 5.9820 (6.1662)	Top-5 Accuracy 31.096 (29.740)
Epoch: [0][500/2266]	Batch Time 2.291 (2.290)	Data Load Time 0.263 (0.246)	Loss 5.9607 (6.1167)	Top-5 Accuracy 33.531 (30.323)
Epoch: [0][600/2266]	Batch Time 2.268 (2.288)	Data Load Time 0.246 (0.244)	Loss 5.8630 (6.0791)	Top-5 Accuracy 34.479 (30.771)
Epoch: [0][700/2266]	Batch Time 2.289 (2.285)	Data Load Time 0.249 (0.242)	Loss 5.7272 (6.0468)	Top-5 Accuracy 36.0

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().



 * LOSS - 10.600, TOP-5 ACCURACY - 18.344, BLEU-4 - 0.6770720407899168

Epoch: [1][0/2266]	Batch Time 3.046 (3.046)	Data Load Time 0.991 (0.991)	Loss 5.5104 (5.5104)	Top-5 Accuracy 34.957 (34.957)
Epoch: [1][100/2266]	Batch Time 2.272 (2.347)	Data Load Time 0.239 (0.309)	Loss 5.8351 (5.6116)	Top-5 Accuracy 33.054 (35.332)
Epoch: [1][200/2266]	Batch Time 2.312 (2.317)	Data Load Time 0.271 (0.279)	Loss 5.3784 (5.6106)	Top-5 Accuracy 38.040 (35.359)
Epoch: [1][300/2266]	Batch Time 2.262 (2.307)	Data Load Time 0.231 (0.270)	Loss 5.5269 (5.6131)	Top-5 Accuracy 35.723 (35.312)
Epoch: [1][400/2266]	Batch Time 2.299 (2.298)	Data Load Time 0.266 (0.262)	Loss 5.4686 (5.6130)	Top-5 Accuracy 38.421 (35.325)
Epoch: [1][500/2266]	Batch Time 2.266 (2.293)	Data Load Time 0.240 (0.257)	Loss 5.5204 (5.6087)	Top-5 Accuracy 36.121 (35.402)
Epoch: [1][600/2266]	Batch Time 2.209 (2.288)	Data Load Time 0.214 (0.255)	Loss 5.5914 (5.6074)	Top-5 Accuracy 35.853 (35.395)
Epoch: [1][700/2266]	Batch Time 2.204 (2

In [4]:
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


In [None]:

# Parameters
data_folder = 'flickr_prep/'  # folder with data files saved by create_input_files.py
data_name = 'flickr30k_5_cap_per_img_5_min_word_freq'  # base name shared by data files
checkpoint_path = 'BEST_checkpoint_flickr30k_5_cap_per_img_5_min_word_freq.pth.tar'  # model checkpoint
word_map_file = f'flickr_prep/WORDMAP_flickr30k_5_cap_per_img_5_min_word_freq.json'  # word map, ensure it's the same the data was encoded with and the model was trained with
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Load model
# NOTE(review): the checkpoint stores whole pickled modules, so torch.load executes whatever
# code the pickle references - only load checkpoint files you trust.
# The path is kept in `checkpoint_path`; `checkpoint` is the loaded dictionary.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# The tensors are deserialized onto the CPU first and the modules are then moved to `device`.
# evaluate() moves every image to `device`, so the models must live there too; leaving them
# on the CPU fails with a device mismatch as soon as CUDA is available.
decoder = checkpoint['decoder'].to(device)
decoder.eval()
encoder = checkpoint['encoder'].to(device)
encoder.eval()

# Load word map (word2ix)
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word
vocab_size = len(word_map)

# Normalization transform
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def evaluate(beam_size):
    """
    Evaluation

    Generates one caption per image of the TEST split with beam search and scores the generated
    captions against the reference captions with corpus-level BLEU.

    Relies on these names from the enclosing notebook scope: `encoder`, `decoder`, `word_map`,
    `vocab_size`, `device`, `normalize`, `data_folder` and `data_name`.

    :param beam_size: beam size at which to generate captions for evaluation
    :return: BLEU-4 score
    """
    # DataLoader
    loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TEST', transform=transforms.Compose([normalize])),
        batch_size=1, shuffle=True, num_workers=1, pin_memory=True)

    # TODO: Batched Beam Search
    # Therefore, do not use a batch_size greater than 1 - IMPORTANT!

    # Lists to store references (true captions), and hypothesis (prediction) for each image
    # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
    # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]
    references = list()
    hypotheses = list()

    start_token = word_map['<start>']
    end_token = word_map['<end>']
    # Tokens stripped from both references and hypotheses before scoring
    special_tokens = {start_token, end_token, word_map['<pad>']}

    # Width of the word tensor handed to the decoder
    # (presumably the padded caption length used during training - TODO confirm)
    padded_len = 52
    # Decoding stops after this many steps even if some beams never emit <end>
    max_steps = 50

    # For each image
    with torch.no_grad():
        for image, caps, caplens, allcaps in tqdm(loader, desc="EVALUATING AT BEAM SIZE " + str(beam_size)):
            k = beam_size
            # Move to GPU device, if available
            image = image.to(device)

            # Encode
            encoder_out = encoder(image)  # [1, enc_image_size, enc_image_size, encoder_dim]
            enc_image_size = encoder_out.size(1)
            encoder_dim = encoder_out.size(-1)
            # We'll treat the problem as having a batch size of k, where k is beam_size
            encoder_out = encoder_out.expand(k, enc_image_size, enc_image_size, encoder_dim)
            # Tensor to store top k previous words at each step; now they're just <start>
            # Important: [1, 52] (eg: [[<start> <start> <start> ...]]) will not work, since it contains the position encoding
            k_prev_words = torch.LongTensor([[start_token] * padded_len] * k).to(device)  # (k, 52)
            # Tensor to store top k sequences; now they're just <start>
            seqs = torch.LongTensor([[start_token]] * k).to(device)  # (k, 1)
            # Tensor to store top k sequences' scores; now they're just 0
            top_k_scores = torch.zeros(k, 1).to(device)
            # Lists to store completed sequences and their scores (as Python floats)
            complete_seqs = []
            complete_seqs_scores = []
            step = 1

            # Start decoding
            # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
            while True:
                # cap_len = torch.LongTensor([52]).repeat(k, 1).to(device) may cause different sorted results on GPU/CPU in transformer.py
                cap_len = torch.LongTensor([padded_len]).repeat(k, 1)  # [s, 1], deliberately left on the CPU
                scores, _, _, _, _ = decoder(encoder_out, k_prev_words, cap_len)
                # Keep only the prediction for the position being decoded at this step
                scores = scores[:, step-1, :].squeeze(1)  # [s, vocab_size]
                scores = F.log_softmax(scores, dim=1)
                # Add the running score of each beam; top_k_scores: [s, 1]
                scores = top_k_scores.expand_as(scores) + scores  # [s, vocab_size]
                # For the first step, all k points will have the same scores (same k previous words)
                if step == 1:
                    top_k_scores, top_k_words = scores[0].topk(k, 0, True, True)  # (s)
                else:
                    # Unroll and find top scores, and their unrolled indices
                    top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True)  # (s)

                # Convert unrolled indices to actual indices of scores
                prev_word_inds = top_k_words // vocab_size  # (s) which beam each winner extends
                next_word_inds = top_k_words % vocab_size  # (s) which word it appends

                # Add new words to sequences
                seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1)  # (s, step+1)

                # Which sequences are incomplete (didn't reach <end>)?
                next_words = next_word_inds.tolist()
                incomplete_inds = [ind for ind, word in enumerate(next_words) if word != end_token]
                complete_inds = [ind for ind, word in enumerate(next_words) if word == end_token]

                # Set aside complete sequences
                if len(complete_inds) > 0:
                    complete_seqs.extend(seqs[complete_inds].tolist())
                    complete_seqs_scores.extend(top_k_scores[complete_inds].tolist())
                k -= len(complete_inds)  # reduce beam length accordingly

                # Proceed with incomplete sequences
                if k == 0:
                    break
                seqs = seqs[incomplete_inds]
                encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
                top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
                # Keep the full-width (s, 52) layout the decoder is fed and overwrite the decoded
                # prefix; every position past the prefix still holds <start> in all rows, so which
                # rows are kept does not matter.
                k_prev_words = k_prev_words[incomplete_inds]
                k_prev_words[:, :step+1] = seqs  # [s, 52]
                # Break if things have been going on too long
                if step > max_steps:
                    break
                step += 1

            # If the step limit was hit before any beam produced <end>, fall back to the
            # unfinished beams instead of failing on an empty candidate list.
            if not complete_seqs:
                complete_seqs = seqs.tolist()
                complete_seqs_scores = top_k_scores.reshape(-1).tolist()

            # Choose the candidate with the highest accumulated log-probability
            best_idx = complete_seqs_scores.index(max(complete_seqs_scores))
            seq = complete_seqs[best_idx]

            # References: every caption of this image, without <start>, <end> and pads
            img_caps = allcaps[0].tolist()
            img_captions = [[w for w in c if w not in special_tokens] for c in img_caps]
            references.append(img_captions)

            # Hypotheses
            hypotheses.append([w for w in seq if w not in special_tokens])

            assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    return bleu4


# Entry point: evaluate with a single beam and report the resulting BLEU-4 score.
if __name__ == '__main__':
    beam_size = 1
    bleu4_score = evaluate(beam_size)
    print("\nBLEU-4 score @ beam size of %d is %.4f." % (beam_size, bleu4_score))
