In [3]:
#imports 
import os
import pandas as pd
from collections import Counter
from sklearn.model_selection import train_test_split

import torch
import torch.optim as optim
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader,Dataset

from torchtext.datasets import translation, imdb, language_modeling, nli
from torchtext.datasets import sequence_tagging, unsupervised_learning, text_classification, sst
from torchtext.data import  Field, RawField,  BucketIterator

import spacy
import statistics
import torchtext
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

from PIL import Image
import matplotlib.pyplot as plt


import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.metrics import Accuracy

from typing import *

In [4]:
# # https://spacy.io/usage/spacy-101#annotations-token
# spacy_de = spacy.load('de_core_news_sm')
# spacy_en = spacy.load('en_core_web_sm')


# def tokenize_de(text: str) -> List[str]:
#     return [tok.text for tok in spacy_de.tokenizer(text)][::-1]

# def tokenize_en(text: str) -> List[str]:
#     return [tok.text for tok in spacy_en.tokenizer(text)]

# source = Field(tokenize=tokenize_de, init_token='<sos>', eos_token='<eos>', lower=True)
# target = Field(tokenize=tokenize_en, init_token='<sos>', eos_token='<eos>', lower=True)


# train_data, valid_data, test_data = translation.WMT14.splits(
#     exts=('.de', '.en'), fields=(source, target), root='../data/translation/'
# )
# print(f"Number of training examples: {len(train_data.examples)}")
# print(f"Number of validation examples: {len(valid_data.examples)}")
# print(f"Number of testing examples: {len(test_data.examples)}")

In [36]:
spacy_eng = spacy.load('en_core_web_sm')

class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1

                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]


class FlickrDataset(Dataset):
    def __init__(self, root_dir, caption_file, transform=None, freq_threshold=5,
                 train=True, split_val=0.2):
        self.root_dir = root_dir
        self.caption_file = caption_file
        self.df = pd.read_csv(caption_file)
        self.transform = transform
        
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.df['caption'].tolist())
        
        self.train = train 
        self.split_val = split_val
        self._do_split_train_valid()
        
#         # Get img, caption columns
#         self.imgs = self.df["image"]
#         self.captions = self.df["caption"]

        # Initialize vocabulary and build vocab

        
    def _do_split_train_valid(self):
        imgs_train, imgs_valid, caps_train, caps_valid = train_test_split(
            self.df["image"], self.df["caption"], 
            test_size=self.split_val, random_state=16
        )
        
        if self.train:
            self.imgs = imgs_train
            self.captions = caps_train
        else:
            self.imgs = imgs_valid
            self.captions = caps_valid
            
        self.imgs = self.imgs.tolist()
        self.captions = self.captions.tolist()
        

    def __len__(self):
        return len(self.imgs)
    
    def _numericalized_caption(self, caption):
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])
        
        return numericalized_caption

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        ncaption = self._numericalized_caption(caption)

        return img, torch.tensor(ncaption)


class CaptionCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

        return imgs, targets
    

In [37]:
def flickr8k_dataloader(root_folder, caption_file, transform, train=True,
                        batch_size=32, num_workers=8, shuffle=True, pin_memory=True):
    
    dataset = FlickrDataset(root_folder, caption_file, transform=transform, train=train)
    pad_idx = dataset.vocab.stoi["<PAD>"]
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=shuffle, pin_memory=pin_memory, 
                            collate_fn=CaptionCollate(pad_idx=pad_idx))
    
    return dataloader, dataset

In [38]:
train_transform = transforms.Compose([
            transforms.Resize((356, 356)),
            transforms.RandomCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

valid_transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


root_dir = "/data/flickr/flickr8k/images/"
caption_file = "/data/flickr/flickr8k/captions.txt"

train_loader, trainset = flickr8k_dataloader(root_dir, caption_file, transform=train_transform, train=True)
valid_loader, validset = flickr8k_dataloader(root_dir, caption_file, transform=valid_transform, train=False)

In [39]:
img, cap = trainset[200]
print([trainset.vocab.itos[token] for token in cap.tolist()])

len(trainset), len(validset)
type(trainset.vocab.stoi["<PAD>"])

['<SOS>', 'the', 'clowns', 'are', 'striking', 'a', 'pose', 'for', 'the', 'camera', '.', '<EOS>']


int

In [97]:
imgs, caps = next(iter(valid_loader))
print([trainset.vocab.itos[token] for token in caps[0, :-1].tolist()])

caps.shape, caps[0, :-1].shape

# caps[:, :-1]

['<SOS>', 'some', 'african', 'kids', 'playing', 'with', 'a', 'ball', '.', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']


(torch.Size([32, 25]), torch.Size([24]))

In [112]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        self.resnet = models.resnet34(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.times = []
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        features = self.resnet(images)
        return self.dropout(self.relu(features))


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, dropout_p=0.5):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, dropout=dropout_p, 
                            batch_first=True, bidirectional=True)
        self.linear = nn.Linear(hidden_size * 2, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        captions = captions[:, :-1]
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs


class ImageCaptionNet(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(ImageCaptionNet, self).__init__()
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []

        with torch.no_grad():
            x = self.encoderCNN(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoderRNN.lstm(x, states)
                output = self.decoderRNN.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                x = self.decoderRNN.embed(predicted).unsqueeze(0)

                if vocabulary.itos[predicted.item()] == "<EOS>":
                    break

        return [vocabulary.itos[idx] for idx in result_caption]

    
    

In [113]:
class ImageCaptionTask(pl.LightningModule):
    def __init__(self, model, optimizers, criterion, scheduler=None):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.metric = Accuracy()
        
    def forward(self, imgs, captions):
        outputs = self.model(imgs, captions)
        return outputs
        
    
    def shared_step(self, batch, batch_idx):
        imgs, captions = batch
        outputs = self.model(imgs, captions)
        
        outputs_preprocess = outputs.reshape(-1, outputs.shape[2])
        captions_preprocess = captions.reshape(-1)
        loss = criterion(outputs_preprocess, captions_preprocess)
#         acc = (output.argmax(1) == cls).sum().item()
        acc = self.metric(outputs_preprocess.argmax(1), captions_preprocess)
        return loss, acc

    
    def training_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch, batch_idx)
        result = pl.TrainResult(loss)
#         result.log_dict({'trn_loss': loss})
        result.log_dict({'trn_loss': loss, 'trn_acc':acc})
        
        return result
    
    def validation_step(self, batch, batch_idx):
        loss, acc = self.shared_step(batch, batch_idx)
        result = pl.EvalResult(checkpoint_on=loss)
#         result.log_dict({'val_loss': loss})
        result.log_dict({'val_loss': loss, 'val_acc': acc})
        
        return result
    
    def configure_optimizers(self):
        if self.scheduler:
            return [self.optimizer], [self.scheduler]
        return self.optimizer
    
  

In [114]:
# from utils import save_checkpoint, load_checkpoint, print_examples
# from get_loader import get_loader
# from model import CNNtoRNN

torch.backends.cudnn.benchmark = True
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
embed_size = 256
hidden_size = 256
vocab_size = len(trainset.vocab)
num_layers = 1
learning_rate = 2e-4
num_epochs = 100

In [115]:
vocab_size

2994

In [134]:
# initialize model, loss etc
model = ImageCaptionNet(embed_size, hidden_size, vocab_size, num_layers)

In [135]:
pad_index = trainset.vocab.stoi["<PAD>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [136]:
checkpoint_path = '../saved_model'
# DEFAULTS used by the Trainer
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_top_k=1,
    verbose=True,
    monitor='checkpoint_on',
    mode='min',
    prefix='image_caption_net_'
)

In [None]:
tb_logger = pl_loggers.TensorBoardLogger('logs/flickr8k')
task = ImageCaptionTask(model, optimizer, criterion)
trainer = pl.Trainer(gpus=1, logger=tb_logger, checkpoint_callback=checkpoint_callback)
trainer.fit(task, train_loader, valid_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ImageCaptionNet  | 24 M  
1 | criterion | CrossEntropyLoss | 0     
2 | metric    | Accuracy         | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00000: val_checkpoint_on reached 1.61182 (best 1.61182), saving model to ../saved_model/image_caption_net_epoch=0.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00001: val_checkpoint_on reached 0.94836 (best 0.94836), saving model to ../saved_model/image_caption_net_epoch=1.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00002: val_checkpoint_on reached 0.63462 (best 0.63462), saving model to ../saved_model/image_caption_net_epoch=2.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00003: val_checkpoint_on reached 0.45017 (best 0.45017), saving model to ../saved_model/image_caption_net_epoch=3.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00004: val_checkpoint_on reached 0.33191 (best 0.33191), saving model to ../saved_model/image_caption_net_epoch=4.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00005: val_checkpoint_on reached 0.25035 (best 0.25035), saving model to ../saved_model/image_caption_net_epoch=5.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00006: val_checkpoint_on reached 0.19440 (best 0.19440), saving model to ../saved_model/image_caption_net_epoch=6.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00007: val_checkpoint_on reached 0.15382 (best 0.15382), saving model to ../saved_model/image_caption_net_epoch=7.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00008: val_checkpoint_on reached 0.12294 (best 0.12294), saving model to ../saved_model/image_caption_net_epoch=8.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00009: val_checkpoint_on reached 0.10149 (best 0.10149), saving model to ../saved_model/image_caption_net_epoch=9.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00010: val_checkpoint_on reached 0.08382 (best 0.08382), saving model to ../saved_model/image_caption_net_epoch=10.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


Epoch 00011: val_checkpoint_on reached 0.07000 (best 0.07000), saving model to ../saved_model/image_caption_net_epoch=11.ckpt as top 1


In [44]:
# models.resnet34(pretrained=True)
# models.resnet50(pretrained=True)
# models.resnet101(pretrained=True)
# models.resnet152(pretrained=True)

In [None]:
import torchtext