In [1]:
!pip install tensorboardX



In [2]:
from tensorboardX import SummaryWriter
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
import argparse, json
import torch
import torch.nn as nn
import torch.optim as optim
from nltk.translate.bleu_score import corpus_bleu


import sys
sys.path.append("/kaggle/input/train-requirements")
from dataset import ImageCaptionDataset
import sys
sys.path.append("/kaggle/input/transformer")
from decoder import Decoder
from transformerDecoder import DecoderOnlyTransformer
from encoder import Encoder
from utils import AverageMeter, accuracy, calculate_caption_lengths, collate_fn

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
parser = argparse.ArgumentParser(description='Show, Attend and Tell')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=10, metavar='E',
                    help='number of epochs to train for (default: 10)')
parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                    help='learning rate of the decoder (default: 1e-4)')
parser.add_argument('--step-size', type=int, default=5,
                    help='step size for learning rate annealing (default: 5)')
parser.add_argument('--alpha-c', type=float, default=1, metavar='A',
                    help='regularization constant (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='L',
                    help='number of batches to wait before logging training stats (default: 100)')
parser.add_argument('--data', type=str, default='/kaggle/input/image-captioning-dataset',
                    help='path to data images (default: /kaggle/input/image-captioning-dataset)')
parser.add_argument('--network', choices=['vgg19', 'resnet152', 'densenet161'], default='vgg19',
                    help='Network to use in the encoder (default: vgg19)')
parser.add_argument('--model', type=str, help='path to model')
parser.add_argument('--tf', action='store_true', default=False,
                    help='Use teacher forcing when training LSTM (default: False)')


_StoreTrueAction(option_strings=['--tf'], dest='tf', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Use teacher forcing when training LSTM (default: False)', metavar=None)

In [5]:
args, unknown = parser.parse_known_args()

print(args)

Namespace(batch_size=32, epochs=10, lr=0.0001, step_size=5, alpha_c=1, log_interval=100, data='/kaggle/input/image-captioning-dataset', network='vgg19', model=None, tf=False)


In [6]:
tf = True # teacher forcing
network = "resnet152" # other options ["densenet161", "vgg19"]
arg_lr = 0.0001
step_size = 5
data = "/kaggle/input/image-captioning-dataset/data"
batch_size = 32

In [7]:
writer = SummaryWriter()

word_dict = json.load(open("/kaggle/input/image-captioning-dataset/word_dict.json", 'r'))
vocabulary_size = len(word_dict)

encoder = Encoder(args.network)
# decoder = Decoder(vocabulary_size, encoder.dim, args.tf)
decoder = DecoderOnlyTransformer(vocabulary_size, d_model=512, max_len=44, num_heads=8, device = device, encoder_dim=encoder.dim)

encoder.to(device)
decoder.to(device)



DecoderOnlyTransformer(
  (token_embedding): Embedding(5507, 512)
  (position_encoding): PositionEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x TransformerDecoderLayer(
      (multi_head_attention): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=False)
        (W_k): Linear(in_features=512, out_features=512, bias=False)
        (W_v): Linear(in_features=512, out_features=512, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (cross_attention): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=False)
        (W_k): Linear(in_features=512, out_features=512, bias=False)
        (W_v): Linear(in_features=512, out_features=512, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,),

In [8]:
optimizer = optim.Adam(decoder.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size)
cross_entropy_loss = nn.CrossEntropyLoss().to(device)

In [9]:
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [10]:
# print(decoder.state_dict())

In [11]:
train_loader = torch.utils.data.DataLoader(
    ImageCaptionDataset(data_transforms, args.data),
    batch_size=args.batch_size, shuffle=True, num_workers=1)

val_loader = torch.utils.data.DataLoader(
    ImageCaptionDataset(data_transforms, args.data, split_type='val'),
    batch_size=args.batch_size, shuffle=True, num_workers=1, collate_fn=collate_fn)

In [15]:
def train(epoch, encoder, decoder, optimizer, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.train()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for batch_idx, (imgs, captions, captions_len) in enumerate(data_loader):
        imgs, captions = imgs.to(device), captions.to(device)
        img_features = encoder(imgs)
        optimizer.zero_grad()
        preds, alphas = decoder(captions, img_features)
        targets = captions[:, 1:]

        caption_lengths = [ (cap != word_dict['<pad>']).sum().item() - 1 for cap in captions ]
        targets = pack_padded_sequence(targets, caption_lengths, batch_first=True, enforce_sorted=False)[0]
        preds = pack_padded_sequence(preds, caption_lengths, batch_first=True, enforce_sorted=False)[0]

        att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()

        loss = cross_entropy_loss(preds, targets)
        loss += att_regularization
        loss.backward()
        optimizer.step()

        total_caption_length = calculate_caption_lengths(word_dict, captions)
        acc1 = accuracy(preds, targets, 1)
        acc5 = accuracy(preds, targets, 5)
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        if batch_idx % log_interval == 0:
            print('Train Batch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx, len(data_loader), loss=losses, top1=top1, top5=top5))
    writer.add_scalar('train_loss', losses.avg, epoch)
    writer.add_scalar('train_top1_acc', top1.avg, epoch)
    writer.add_scalar('train_top5_acc', top5.avg, epoch)

In [13]:
def validate(epoch, encoder, decoder, cross_entropy_loss, data_loader, word_dict, alpha_c, log_interval, writer):
    encoder.eval()
    decoder.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # used for calculating bleu scores
    references = []
    hypotheses = []
    with torch.no_grad():
        for batch_idx, (imgs, captions, captions_len, all_captions) in enumerate(data_loader):
            imgs, captions = imgs.to(device), captions.to(device)
            img_features = encoder(imgs)
            preds, alphas = decoder(captions, img_features)
            targets = captions[:, 1:]
            # print(f"cap={captions.size()}")
            # print((targets.size()))
            # print((preds.size()))
            targets = pack_padded_sequence(targets, captions_len, batch_first=True, enforce_sorted=False)[0]
            packed_preds = pack_padded_sequence(preds, captions_len, batch_first=True, enforce_sorted=False)[0]

            att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()
            # print((targets.size()))
            # print((packed_preds.size()))
            loss = cross_entropy_loss(packed_preds, targets)
            loss += att_regularization

            total_caption_length = calculate_caption_lengths(word_dict, captions)
            acc1 = accuracy(packed_preds, targets, 1)
            acc5 = accuracy(packed_preds, targets, 5)
            losses.update(loss.item(), total_caption_length)
            top1.update(acc1, total_caption_length)
            top5.update(acc5, total_caption_length)

            for cap_set in all_captions:
                caps = []
                for caption in cap_set:
                    cap = [word_idx for word_idx in caption
                                    if word_idx != word_dict['<start>'] and word_idx != word_dict['<pad>']]
                    caps.append(cap)
                references.append(caps)

            word_idxs = torch.max(preds, dim=2)[1]
            for idxs in word_idxs.tolist():
                hypotheses.append([idx for idx in idxs
                                       if idx != word_dict['<start>'] and idx != word_dict['<pad>']])

            if batch_idx % log_interval == 0:
                print('Validation Batch: [{0}/{1}]\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                          batch_idx, len(data_loader), loss=losses, top1=top1, top5=top5))
        writer.add_scalar('val_loss', losses.avg, epoch)
        writer.add_scalar('val_top1_acc', top1.avg, epoch)
        writer.add_scalar('val_top5_acc', top5.avg, epoch)

        bleu_1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(references, hypotheses, weights=(0.33, 0.33, 0.33, 0))
        bleu_4 = corpus_bleu(references, hypotheses)

        writer.add_scalar('val_bleu1', bleu_1, epoch)
        writer.add_scalar('val_bleu2', bleu_2, epoch)
        writer.add_scalar('val_bleu3', bleu_3, epoch)
        writer.add_scalar('val_bleu4', bleu_4, epoch)
        print('Validation Epoch: {}\t'
              'BLEU-1 ({})\t'
              'BLEU-2 ({})\t'
              'BLEU-3 ({})\t'
              'BLEU-4 ({})\t'.format(epoch, bleu_1, bleu_2, bleu_3, bleu_4))

In [None]:
print('Starting training with {}'.format(args))
for epoch in range(1, args.epochs + 1):
    scheduler.step()
    train(epoch, encoder, decoder, optimizer, cross_entropy_loss,
          train_loader, word_dict, args.alpha_c, args.log_interval, writer)
    validate(epoch, encoder, decoder, cross_entropy_loss, val_loader,
             word_dict, args.alpha_c, args.log_interval, writer)
    model_file = '/kaggle/working/' + args.network + '_' + str(epoch) + '.pth'
    torch.save(decoder.state_dict(), model_file)
    print('Saved model to ' + model_file)
writer.close()

Starting training with Namespace(batch_size=32, epochs=10, lr=0.0001, step_size=5, alpha_c=1, log_interval=100, data='/kaggle/input/image-captioning-dataset', network='vgg19', model=None, tf=False)
Train Batch: [0/1138]	Loss 9.3726 (9.3726)	Top 1 Accuracy 0.000 (0.000)	Top 5 Accuracy 0.255 (0.255)
Train Batch: [100/1138]	Loss 5.9067 (6.6313)	Top 1 Accuracy 15.309 (10.909)	Top 5 Accuracy 36.790 (32.038)
Train Batch: [200/1138]	Loss 5.6650 (6.1209)	Top 1 Accuracy 16.704 (14.353)	Top 5 Accuracy 40.312 (36.191)
Train Batch: [300/1138]	Loss 5.3680 (5.8905)	Top 1 Accuracy 19.213 (15.879)	Top 5 Accuracy 40.972 (38.323)
Train Batch: [400/1138]	Loss 5.1035 (5.7173)	Top 1 Accuracy 25.648 (18.032)	Top 5 Accuracy 44.301 (39.888)
Train Batch: [500/1138]	Loss 5.0057 (5.5886)	Top 1 Accuracy 27.340 (19.580)	Top 5 Accuracy 46.798 (41.053)
Train Batch: [600/1138]	Loss 4.6455 (5.4839)	Top 1 Accuracy 29.155 (20.694)	Top 5 Accuracy 48.688 (42.017)
Train Batch: [700/1138]	Loss 4.7755 (5.3965)	Top 1 Accuracy