In [1]:
pip install tensorboardX

Note: you may need to restart the kernel to use updated packages.


In [2]:
from tensorboardX import SummaryWriter
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
import argparse, json
import torch
import torch.nn as nn
import torch.optim as optim
from nltk.translate.bleu_score import corpus_bleu

import sys
sys.path.append("/kaggle/input/train-requirements")
from dataset import ImageCaptionDataset
from decoder import Decoder
from encoder import Encoder
from utils import AverageMeter, accuracy, calculate_caption_lengths, collate_fn

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
parser = argparse.ArgumentParser(description='Show, Attend and Tell')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='E',
                    help='number of epochs to train for (default: 20)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate of the decoder (default: 1e-3)')
parser.add_argument('--step-size', type=int, default=5,
                    help='step size for learning rate annealing (default: 5)')
parser.add_argument('--alpha-c', type=float, default=1, metavar='A',
                    help='regularization constant (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='L',
                    help='number of batches to wait before logging training stats (default: 100)')
parser.add_argument('--data', type=str, default='/kaggle/input/image-captioning-dataset',
                    help='path to data images (default: /kaggle/input/flicker30k-dataset)')
parser.add_argument('--network', choices=['vgg19', 'resnet152', 'densenet161'], default='vgg19',
                    help='Network to use in the encoder (default: vgg19)')
parser.add_argument('--model', type=str, help='path to model')
parser.add_argument('--tf', action='store_true', default=False,
                    help='Use teacher forcing when training LSTM (default: False)')


_StoreTrueAction(option_strings=['--tf'], dest='tf', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Use teacher forcing when training LSTM (default: False)', metavar=None)

In [5]:
args, unknown = parser.parse_known_args()

print(args)

Namespace(batch_size=64, epochs=10, lr=0.001, step_size=5, alpha_c=1, log_interval=100, data='/kaggle/input/image-captioning-dataset', network='vgg19', model=None, tf=False)


In [6]:
tf = True # teacher forcing
network = "resnet152" # other options ["densenet161", "vgg19"]
arg_lr = 0.0001
step_size = 5
data = "/kaggle/input/image-captioning-dataset/data"
batch_size = 64

In [7]:
writer = SummaryWriter()

word_dict = json.load(open("/kaggle/input/image-captioning-dataset/word_dict.json", 'r'))
vocabulary_size = len(word_dict)

encoder = Encoder(args.network)
decoder = Decoder(vocabulary_size, encoder.dim, args.tf)

encoder.to(device)
decoder.to(device)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 202MB/s] 


Decoder(
  (init_h): Linear(in_features=512, out_features=512, bias=True)
  (init_c): Linear(in_features=512, out_features=512, bias=True)
  (tanh): Tanh()
  (f_beta): Linear(in_features=512, out_features=512, bias=True)
  (sigmoid): Sigmoid()
  (deep_output): Linear(in_features=512, out_features=5507, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (attention): Attention(
    (U): Linear(in_features=512, out_features=512, bias=True)
    (W): Linear(in_features=512, out_features=512, bias=True)
    (v): Linear(in_features=512, out_features=1, bias=True)
    (tanh): Tanh()
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(5507, 512, padding_idx=0)
  (lstm): LSTMCell(1024, 512)
)

In [8]:
optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size)
cross_entropy_loss = nn.CrossEntropyLoss().to(device)
# cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=0).to(device)

In [9]:
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [10]:
# print(decoder.state_dict())

In [11]:
train_loader = torch.utils.data.DataLoader(
    ImageCaptionDataset(data_transforms, args.data),
    batch_size=args.batch_size, shuffle=True, num_workers=1)

val_loader = torch.utils.data.DataLoader(
    ImageCaptionDataset(data_transforms, args.data, split_type='val'),
    batch_size=args.batch_size, shuffle=True, num_workers=1, collate_fn=collate_fn)

In [12]:
def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (imgs, captions, captions_len) in enumerate(data_loader):
        imgs, captions = imgs.to(device), captions.to(device)
        img_features = encoder(imgs)
        optimizer.zero_grad()
        preds, alphas = decoder(img_features, captions)
        targets = captions[:, 1:]    # as the first word is <start>
        # print(f"t_bef{targets}")
        # print(f"p_bef{preds[0]}")
        targets = pack_padded_sequence(targets, captions_len, batch_first=True, enforce_sorted=False)[0]
        preds = pack_padded_sequence(preds, captions_len, batch_first=True, enforce_sorted=False)[0]
        # print(f"t_after{targets}")
        # print(f"p_after{preds}")
        # break
        att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()
        
        loss = cross_entropy_loss(preds, targets)
        loss += att_regularization
        loss.backward()    # computer gradients
        optimizer.step()   # update weights

        total_caption_length = calculate_caption_lengths(word_dict, captions)
        acc1 = accuracy(preds, targets, 1)  # was the most probable word = target?
        acc5 = accuracy(preds, targets, 5)  # was the target in the top 5 most probable words?
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print('Train Batch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx, len(data_loader), loss=losses, top1=top1, top5=top5))
    writer.add_scalar('train_loss', losses.avg, epoch)
    writer.add_scalar('train_top1_acc', top1.avg, epoch)
    writer.add_scalar('train_top5_acc', top5.avg, epoch)

In [13]:
def validate(epoch, encoder, decoder, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    references = []
    hypotheses = []

    with torch.no_grad():
        for batch_idx, (imgs, captions, captions_len, all_captions) in enumerate(data_loader):
            imgs, captions = imgs.to(device), captions.to(device)
            img_features = encoder(imgs)
            preds, alphas = decoder(img_features, captions)
            targets = captions[:, 1:]

            # Unpad sequences for loss calculation
            targets = pack_padded_sequence(targets, captions_len, batch_first=True, enforce_sorted=False)[0]
            packed_preds = pack_padded_sequence(preds, captions_len, batch_first=True, enforce_sorted=False)[0]

            att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()
            loss = cross_entropy_loss(packed_preds, targets)
            loss += att_regularization

            total_caption_length = calculate_caption_lengths(word_dict, captions)
            acc1 = accuracy(packed_preds, targets, 1)
            acc5 = accuracy(packed_preds, targets, 5)
            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            # References (GT captions)
            for cap_set in all_captions:
                caps = []
                for caption in cap_set:
                    cap = [word_idx for word_idx in caption
                                    if word_idx != word_dict['<start>'] and word_idx != word_dict['<pad>']]
                    caps.append(cap)
                references.append(caps)

            # Hypotheses (Predicted captions with BEAM SEARCH)
            for i in range(imgs.size(0)):
                sentence, alpha = decoder.caption(img_features[i].unsqueeze(0), beam_size=3)  
                hypothesis = [idx for idx in sentence if idx not in (word_dict['<start>'], word_dict['<pad>'])]
                hypotheses.append(hypothesis)

            if batch_idx % log_interval == 0:
                print('Validation Batch: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                          batch_idx, len(data_loader), loss=losses, top1=top1, top5=top5))

        # Write scalars
        writer.add_scalar('val_loss', losses.avg, epoch)
        writer.add_scalar('val_top1_acc', top1.avg, epoch)
        writer.add_scalar('val_top5_acc', top5.avg, epoch)

        # BLEU scores
        bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(references, hypotheses, weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(references, hypotheses)

        writer.add_scalar('val_bleu1', bleu_1, epoch)
        writer.add_scalar('val_bleu2', bleu_2, epoch)
        writer.add_scalar('val_bleu3', bleu_3, epoch)
        writer.add_scalar('val_bleu4', bleu_4, epoch)

        print('Validation Epoch: {}\t'
              'BLEU-1 ({})\t'
              'BLEU-2 ({})\t'
              'BLEU-3 ({})\t'
              'BLEU-4 ({})\t'.format(epoch, bleu_1, bleu_2, bleu_3, bleu_4))


In [None]:
print('Starting training with {}'.format(args))
for epoch in range(1, args.epochs + 1):
    scheduler.step()
    train(epoch, encoder, decoder, optimizer, cross_entropy_loss,
          train_loader, word_dict, args.alpha_c, args.log_interval, writer)
    validate(epoch, encoder, decoder, cross_entropy_loss, val_loader,
             word_dict, args.alpha_c, args.log_interval, writer)
    model_file = '/kaggle/working/' + args.network + '_' + str(epoch) + '.pth'
    torch.save(decoder.state_dict(), model_file)
    print('Saved model to ' + model_file)
writer.close()

Starting training with Namespace(batch_size=64, epochs=10, lr=0.001, step_size=5, alpha_c=1, log_interval=100, data='/kaggle/input/image-captioning-dataset', network='vgg19', model=None, tf=False)




Train Batch: [0/569]	Loss 9.2370 (9.2370)	Top 1 Accuracy 0.000 (0.000)	Top 5 Accuracy 0.125 (0.125)
Train Batch: [100/569]	Loss 5.4839 (5.8510)	Top 1 Accuracy 20.528 (16.160)	Top 5 Accuracy 40.216 (37.807)
Train Batch: [200/569]	Loss 5.4329 (5.6300)	Top 1 Accuracy 17.626 (17.088)	Top 5 Accuracy 41.127 (39.616)
Train Batch: [300/569]	Loss 5.3154 (5.5243)	Top 1 Accuracy 21.366 (17.592)	Top 5 Accuracy 43.478 (40.533)
Train Batch: [400/569]	Loss 5.1999 (5.4508)	Top 1 Accuracy 18.442 (17.882)	Top 5 Accuracy 43.986 (41.185)
Train Batch: [500/569]	Loss 5.1847 (5.3929)	Top 1 Accuracy 19.332 (18.056)	Top 5 Accuracy 41.766 (41.667)
Validation Batch: [0/64]	Loss 4.9511 (4.9511)	Top 1 Accuracy 20.449 (20.449)	Top 5 Accuracy 45.262 (45.262)
Validation Epoch: 1	BLEU-1 (0.3856281620576739)	BLEU-2 (0.22910143031053684)	BLEU-3 (0.0966502387022368)	BLEU-4 (0.0443747190254995)	
Saved model to /kaggle/working/vgg19_1.pth
Train Batch: [0/569]	Loss 5.0303 (5.0303)	Top 1 Accuracy 20.256 (20.256)	Top 5 Accura