In [None]:
%%capture
!pip install -U sentence-transformers datasets

In [None]:
from os import path

from tqdm import trange
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision.io import read_image
from datasets import Dataset
from sentence_transformers import SentenceTransformer

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

import time
import torchvision
import torch
import pickle
import pathlib
import collections
import urllib
import zipfile

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Pipeline / training constants.
BATCH_SIZE = 128   # rows per DataLoader batch
TARGET_DIM = 128   # output size of both projection heads (shared embedding space)
N_CAPTIONS = 5     # captions assumed per image; images and labels are repeated this often

# Checkpoint name passed to SentenceTransformer for the text encoder.
sentence_transformer_ckp = 'nreimers/MiniLM-L6-H384-uncased'

# Helper Functions

## Generic Helpers

In [None]:
def dict_to_device(d, device):
  """Return a new dict with the same keys whose values were moved with `.to(device)`."""
  moved = {}
  for key, value in d.items():
    moved[key] = value.to(device)
  return moved

In [None]:
def repeat_list(my_list, n):
  """Return a new list where every element of `my_list` appears `n` times in a row."""
  repeated = []
  for item in my_list:
    repeated.extend([item] * n)
  return repeated

## Selection Helpers

In [None]:
def cosine_similarity_matrix(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
  """
  Pairwise cosine similarity between the rows of `a` and the rows of `b`.

  Source: https://stackoverflow.com/a/50426321

  Parameters:
  a (torch.Tensor): Tensor of shape (n, d).
  b (torch.Tensor): Tensor of shape (m, d).
  eps (float, optional): Lower bound applied to the row norms. An all-zero row would
      otherwise be divided by 0 and turn its whole row/column into NaN; with the bound
      its similarities are 0. Defaults to 1e-8.

  Returns:
  torch.Tensor: Tensor of shape (n, m) whose entry (i, j) is cos(a[i], b[j]).
  """
  a = a / a.norm(dim=-1, keepdim=True).clamp_min(eps)
  b = b / b.norm(dim=-1, keepdim=True).clamp_min(eps)
  return a @ b.t()

## Metrics and Loss

In [None]:
def random_negative_criterion(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels: torch.Tensor, loss_fn: torch.nn.TripletMarginLoss) -> torch.Tensor:
  """
  Calculates triplet loss to map representations `vectors_a` and `vectors_b` to the same vectors,
  using a randomly chosen in-batch negative.

  For anchor `vectors_a[i]` the positive is `vectors_b[i]` and the negative is `vectors_b[j]`
  for a uniformly random `j` whose label differs from `labels[i]`. If every sample in the batch
  shares the anchor's label there is no valid negative and index 0 is used.

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the first set of vectors.
  vectors_b (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the second set of vectors.
  labels (torch.Tensor): A tensor of shape (batch_size) representing the label / class for each sample in `vectors_a` and `vectors_b`.
  loss_fn (torch.nn.TripletMarginLoss): Called as `loss_fn(anchor, positive, negative)`.

  Returns:
  torch.Tensor: The value returned by `loss_fn`.
  """
  non_positive_msk = labels.unsqueeze(0) != labels.unsqueeze(1)

  # A random score per candidate, zeroed for same-label candidates, so the arg-max is a
  # uniformly random negative. The scores are created on the labels' device:
  # `Tensor.get_device()` returns -1 for CPU tensors, which `torch.rand` rejects as a device.
  scores = torch.rand(non_positive_msk.shape, device=labels.device) * non_positive_msk
  negative_idx = torch.max(scores, dim=1).indices
  return loss_fn(vectors_a, vectors_b, vectors_b[negative_idx])

In [None]:
def hard_negative_criterion(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels: torch.Tensor, loss_fn: torch.nn.TripletMarginLoss) -> torch.Tensor:
  """
  Calculates triplet loss to map representations `vectors_a` and `vectors_b` to the same vectors,
  using the hardest in-batch negative.

  For anchor `vectors_a[i]` the positive is `vectors_b[i]` and the negative is the `vectors_b[j]`
  with the highest cosine similarity to `vectors_a[i]` among samples whose label differs from
  `labels[i]`. If every sample shares the anchor's label, all candidates are masked and index 0
  is used.

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the first set of vectors.
  vectors_b (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the second set of vectors.
  labels (torch.Tensor): A tensor of shape (batch_size) representing the label / class for each sample in `vectors_a` and `vectors_b`.
  loss_fn (torch.nn.TripletMarginLoss): Called as `loss_fn(anchor, positive, negative)`.

  Returns:
  torch.Tensor: The value returned by `loss_fn`.
  """
  positive_msk = labels.unsqueeze(0) == labels.unsqueeze(1)

  # Negative mining only produces indices, which carry no gradient, so the similarity
  # matrix does not need to be part of the autograd graph.
  with torch.no_grad():
    cos_sim = cosine_similarity_matrix(vectors_a, vectors_b)
    # Same-label candidates (including the sample itself) can never be the negative.
    cos_sim = cos_sim.masked_fill(positive_msk, float('-inf'))
    negative_idx = torch.max(cos_sim, dim=1).indices

  # Gather outside no_grad so gradients flow into the selected negatives.
  return loss_fn(vectors_a, vectors_b, vectors_b[negative_idx])

In [None]:
def topk_accuracy(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels_a: torch.Tensor, labels_b: torch.Tensor, k: int = 5) -> torch.Tensor:
  """
  Calculates the top-k retrieval accuracy of `vectors_a` against `vectors_b`.

  For every row of `vectors_a`, the `k` rows of `vectors_b` with the highest cosine similarity
  are retrieved; the row counts as correct if any of them has the same label.

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (n, vector_size) of queries.
  vectors_b (torch.Tensor): A tensor of shape (m, vector_size) of candidates.
  labels_a (torch.Tensor): A tensor of shape (n) with the label of each row in `vectors_a`.
  labels_b (torch.Tensor): A tensor of shape (m) with the label of each row in `vectors_b`.
  k (int, optional): # of top predictions to consider. Defaults to 5. Values larger than `m`
      are clamped to `m`.

  Returns:
  torch.Tensor: Scalar fraction of queries with a correct match in their top-k.
  """
  pos_mask = labels_a.unsqueeze(1) == labels_b.unsqueeze(0)
  sim_matrix = cosine_similarity_matrix(vectors_a, vectors_b)
  # torch.topk raises when k exceeds the number of candidates (e.g. a small last batch).
  k = min(k, sim_matrix.shape[1])
  topk_idx = torch.topk(sim_matrix, k=k, dim=1).indices
  topk_pos = pos_mask.gather(dim=1, index=topk_idx)
  true_pred = torch.any(topk_pos, dim=1)
  return torch.sum(true_pred) / true_pred.nelement()

# Train Utils

In [None]:
def train(model, iterator, optimizer, criterion, device):
    """
    Run one training epoch and return the average per-batch loss.

    `iterator` yields `(a, b, labels)` batches: `a` and `labels` are moved to
    `device` with `.to`, and `b` is a dict whose values are moved with
    `dict_to_device`. For each batch the gradients are reset, the model is
    called as `model(a, b)`, `criterion(a_t, b_t, labels)` is back-propagated
    and `optimizer` takes one step.

    Returns:
        float: sum of the batch losses divided by `len(iterator)`.
    """
    model.train()

    # Captured up front: some iterators change length after each loop.
    n_batches = len(iterator)
    running_loss = 0

    for batch_a, batch_b, batch_labels in iterator:
        inputs_a = batch_a.to(device)
        inputs_b = dict_to_device(batch_b, device)
        targets = batch_labels.to(device)

        optimizer.zero_grad()
        out_a, out_b = model(inputs_a, inputs_b)
        batch_loss = criterion(out_a, out_b, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / n_batches

In [None]:
def evaluate(model, iterator, criterion, device):
    """
    Compute the average per-batch loss of `model` over `iterator` without training it.

    The model is put in eval mode and gradients are disabled. Batches use the
    same `(a, b, labels)` layout as `train`: `a` and `labels` are moved with
    `.to(device)`, the dict `b` with `dict_to_device`.

    Returns:
        float: sum of the batch losses divided by `len(iterator)`.
    """
    model.eval()
    n_batches = len(iterator)
    total = 0

    with torch.no_grad():
        for batch_a, batch_b, batch_labels in iterator:
            inputs_a = batch_a.to(device)
            inputs_b = dict_to_device(batch_b, device)
            targets = batch_labels.to(device)

            out_a, out_b = model(inputs_a, inputs_b)
            total += criterion(out_a, out_b, targets).item()

    return total / n_batches

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss. An improvement (a loss at
    least `delta` below the best seen so far) records the new best values, optionally saves
    a checkpoint and resets the counter. Otherwise the counter is incremented, and
    `early_stop` is set once it reaches `patience`. Acting on `early_stop` is up to the caller.
    """
    def __init__(self, patience=4, verbose=False, delta=1e-3, path='checkpoint.pt', trace_func=print, save_checkpoint_file=True):
        """
        Args:
            patience (int): Number of consecutive non-improving calls after which
                            `early_stop` is set.
                            Default: 4
            verbose (bool): If True, reports each validation loss improvement via `trace_func`.
                            Default: False
            delta (float): Minimum decrease of the validation loss that counts as an improvement.
                            Default: 1e-3
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Callable that receives progress messages.
                            Default: print
            save_checkpoint_file (bool): If True, `model.state_dict()` is written to `path`
                            on every improvement.
                            Default: True
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        # np.Inf / np.NINF were removed in NumPy 2.0; np.inf is available in every version.
        self.val_loss_min = np.inf
        self.val_acc_max = -np.inf
        self.time_at_stop = 0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.save_checkpoint_file = save_checkpoint_file

        # Number of __call__ invocations; shown in the checkpoint-saving message.
        self.n_calls = 0

    def __call__(self, val_loss, val_acc, model, current_time=0):
        """
        Record one validation result.

        Args:
            val_loss (float): Validation loss of the current epoch.
            val_acc: Validation accuracy; only stored together with the best loss.
            model (nn.Module): Model whose state_dict is saved on improvement
                               (only used when `save_checkpoint_file` is True).
            current_time (float): Stored as `time_at_stop` when this call is an improvement.
        """
        self.n_calls += 1

        # Phrased as a positive "improved" test so that a NaN loss (every comparison with
        # NaN is False) counts against patience instead of being saved as the best model.
        improved = val_loss <= self.val_loss_min - self.delta
        if improved:
            self.save_checkpoint(val_loss, val_acc, model)
            self.counter = 0
            self.time_at_stop = current_time
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, val_acc, model):
        '''Store `val_loss` / `val_acc` as the new best values, saving the model first if enabled.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
        if self.save_checkpoint_file:
            torch.save(model.state_dict(), self.path)
            print(f'{self.n_calls:3d}. call: Saving model...')
        self.val_loss_min = val_loss
        self.val_acc_max = val_acc

In [None]:
def train_planned(model, train_dataloader, test_loader, criterion, device='cpu', num_epochs=75, early_stop=True, patience=5, save_model=False, losses=None, optimizer=None):
  """
  Train `model` for up to `num_epochs` epochs, evaluating on `test_loader` after each one.

  Parameters:
  model (nn.Module): Model passed to `train` / `evaluate`.
  train_dataloader: Iterable of `(a, b, labels)` batches used for training.
  test_loader: Iterable of `(a, b, labels)` batches evaluated after every epoch.
  criterion: Callable `(a_t, b_t, labels) -> loss tensor`.
  device: Device the batches are moved to. Defaults to 'cpu'.
  num_epochs (int): Maximum number of epochs. Defaults to 75.
  early_stop (bool): If True, stop as soon as `EarlyStopping` sets `early_stop`. Defaults to True.
  patience (int): Patience passed to `EarlyStopping`. Defaults to 5.
  save_model (bool): If True, `EarlyStopping` writes a checkpoint on every improvement. Defaults to False.
  losses: Unused; kept for backward compatibility.
  optimizer (torch.optim.Optimizer, optional): Optimizer used by `train`. If omitted, the
      notebook-global `optimizer` is used, which is what earlier versions did implicitly.

  Returns:
  dict: `train_losses` / `test_losses` (per-epoch lists); `loss` (last test loss, or the
  best one if early stopping triggered); `epochs` (epochs the reported model was trained
  for); `time` (accumulated training time in seconds, evaluation excluded).

  Raises:
  NameError: If no optimizer is passed and no global `optimizer` exists.
  """
  if optimizer is None:
    # Backward compatibility: this function used to read the global `optimizer` implicitly.
    optimizer = globals().get('optimizer')
    if optimizer is None:
      raise NameError("name 'optimizer' is not defined; pass optimizer=... to train_planned")

  pbar = trange(num_epochs, desc='Training', position=0, leave=True)

  early_stopping = EarlyStopping(verbose=True, patience=patience, delta=1e-3, trace_func=pbar.set_description, save_checkpoint_file=save_model)
  train_losses, test_losses = [], []

  info = {
      'train_losses': train_losses,
      'test_losses': test_losses,
      # np.Inf was removed in NumPy 2.0; np.inf works on every version.
      'loss': np.inf,
      'epochs': num_epochs,
      'time': 0
  }

  for epoch in pbar:

    # Only the training pass is timed; evaluation is not included in info['time'].
    start = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    stop = time.time()

    test_loss = evaluate(model, test_loader, criterion, device)

    info['loss'] = test_loss
    info['time'] = info['time'] + stop - start

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # No accuracy is computed here; 1 is a placeholder for EarlyStopping's val_acc.
    early_stopping(test_loss, 1, model, current_time=info['time'])

    pbar.set_description(f'Test / Train | Loss: {test_loss:.3f}/{train_loss:.3f}')

    if early_stop and early_stopping.early_stop:
      pbar.close()
      # `epoch` is 0-based, so epoch + 1 epochs have been completed.
      print(f'Early stopping. Completed {epoch + 1}/{num_epochs} epochs.')

      info['loss'] = early_stopping.val_loss_min
      info['time'] = early_stopping.time_at_stop
      # Number of epochs the reported (best) model was trained for: the last
      # `patience` epochs brought no improvement.
      info['epochs'] = epoch - early_stopping.patience + 1

      return info

  pbar.close()
  return info

# Model Definitions

## Text

In [None]:
# Probe the sentence-transformer once to learn the width of its sentence embedding;
# `text_feature_dim` sizes the first linear layer of TextFCModel (via MergedModel).
sentence_model = SentenceTransformer(sentence_transformer_ckp)
tokenized = sentence_model.tokenize(['dummy_text'])
# A single input sentence, so the element count presumably equals the embedding
# width (assumes one vector per sentence — TODO confirm).
text_feature_dim = sentence_model(tokenized)['sentence_embedding'].nelement()

In [None]:
# Release the probe model and its input; TextFCModel loads its own copy of the checkpoint.
del sentence_model, tokenized

## Image

In [None]:
# Default pretrained ResNet-18 weights, used to build the image backbone.
resnet_weights = torchvision.models.ResNet18_Weights.DEFAULT

# Preprocessing pipeline provided by these weights; CustomImageDataset applies it to every image.
transforms = resnet_weights.transforms()

In [None]:
def init_base_img_model(weights):
  """
  Build a ResNet-18 with `weights` and return it without its last two children.

  The two dropped children are presumably the global average pool and the classification
  head, so the returned model outputs a spatial feature map rather than class scores.

  Parameters:
  weights: Weights value passed to `torchvision.models.resnet18`
      (e.g. `ResNet18_Weights.DEFAULT`), or None for random initialisation.

  Returns:
  nn.Sequential: The remaining children of the ResNet, in order.
  """
  # Pass `weights` by keyword: positional use is deprecated since torchvision 0.13.
  image_model = torchvision.models.resnet18(weights=weights)
  image_feature_layers = list(image_model.children())[:-2]
  return nn.Sequential(*image_feature_layers)

In [None]:
# Push one dummy 3x224x224 input through the backbone to record the shape of its output
# feature map; ImageFCModel (via MergedModel) uses it to size its first linear layer.
image_model = init_base_img_model(resnet_weights)
dummy_input = torch.randn(1, 3, 224, 224)
image_feature_shape = image_model(dummy_input).shape

In [None]:
# The probe backbone is no longer needed; ImageFCModel builds its own.
del image_model, dummy_input

## Merged

In [None]:
class ImageFCModel(nn.Module):
  """
  Image tower: ResNet backbone from `init_base_img_model` -> 2D max-pool -> flatten ->
  three linear layers (ReLU after the first two) projecting to `transformed_dim`.

  Args:
    img_shape: Shape of the backbone's output for a probe batch; only used to size `fc1`.
      MergedModel is given `image_feature_shape`, the output shape of a batch-of-one probe.
    transformed_dim (int): Size of the output embedding.
    maxpool_kernel_size (int): Kernel size of the max-pool applied to the feature map.

  NOTE(review): the backbone is built from the notebook-global `resnet_weights` and is not
  frozen here, so its parameters are trained together with the linear layers.
  """
  def __init__(self, img_shape, transformed_dim, maxpool_kernel_size=2):
    super().__init__()
    self.maxpool_kernel_size = maxpool_kernel_size
    self.img_base_model = init_base_img_model(resnet_weights)

    # calc img dim after feature extraction
    # (a random tensor of `img_shape` is pooled and flattened just to count the features)
    img_dim = F.max_pool2d(torch.randn(img_shape), kernel_size=self.maxpool_kernel_size).view(-1).nelement()

    self.fc1 = nn.Linear(img_dim, 1024)
    self.fc2 = nn.Linear(1024, 256)
    self.fc3 = nn.Linear(256, transformed_dim)
    # NOTE(review): `nn.utils.weight_norm` appears to reparametrize fc3's weight in place
    # and return fc3 itself. If so, `l2_norm` is fc3 registered under a second name, not a
    # separate L2-normalisation layer, and the weight norm is active through `self.fc3`
    # even though the call in forward is commented out — confirm.
    self.l2_norm = nn.utils.weight_norm(self.fc3)

  def forward(self, x):
    """Map a batch of preprocessed images to `transformed_dim` embeddings."""
    x = self.img_base_model(x)
    x = F.max_pool2d(x, self.maxpool_kernel_size)
    x = torch.flatten(x, start_dim=1) # flatten
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # x = self.l2_norm(x)
    x = self.fc3(x)
    return x

In [None]:
class TextFCModel(nn.Module):
  """
  Text tower: SentenceTransformer encoder (run under no_grad, so not trained) followed by
  three linear layers (ReLU after the first two) projecting to `transformed_dim`.

  Args:
    text_dim (int): Width of the sentence embedding; sizes `fc1`.
    transformed_dim (int): Size of the output embedding.
    sentence_transformer_ckp (str): Checkpoint name passed to `SentenceTransformer`.

  NOTE(review): the encoder is created on the notebook-global `device`.
  """
  def __init__(self, text_dim, transformed_dim, sentence_transformer_ckp='nreimers/MiniLM-L6-H384-uncased'):
    super().__init__()
    self.text_transformer = SentenceTransformer(sentence_transformer_ckp, device=device)

    self.fc1 = nn.Linear(text_dim, 512)
    self.fc2 = nn.Linear(512, 256)
    self.fc3 = nn.Linear(256, transformed_dim)
    # NOTE(review): `nn.utils.weight_norm` appears to return fc3 itself (with its weight
    # reparametrized), so `l2_norm` likely aliases fc3 rather than being a separate
    # normalisation layer — confirm.
    self.l2_norm = nn.utils.weight_norm(self.fc3)

  def forward(self, x):
    """
    Map a tokenized caption batch (the dict built by the tokenizer and moved to the
    device by `dict_to_device` in train/evaluate) to `transformed_dim` embeddings.
    """
    # The encoder runs without gradients, so only fc1-fc3 are trained.
    # NOTE(review): `model.train()` still puts the encoder in training mode, so any
    # dropout inside it would be active here — consider `self.text_transformer.eval()`.
    with torch.no_grad():
      x = self.text_transformer(x)['sentence_embedding']
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # x = self.l2_norm(x)
    x = self.fc3(x)
    return x

In [None]:
class MergedModel(nn.Module):
  """
  Two-tower model: an ImageFCModel and a TextFCModel that both project into a shared
  `transformed_dim`-dimensional embedding space.

  Args:
    img_shape: Backbone output shape used to size the image tower's first linear layer.
    text_dim (int): Sentence-embedding width used to size the text tower's first linear layer.
    transformed_dim (int): Size of both output embeddings.
    sentence_transformer_ckp (str): Checkpoint name for the text encoder.
    maxpool_kernel_size (int): Max-pool kernel size of the image tower.
  """
  def __init__(self, img_shape, text_dim, transformed_dim, sentence_transformer_ckp='nreimers/MiniLM-L6-H384-uncased', maxpool_kernel_size=2):
    super().__init__()

    self.img_model = ImageFCModel(img_shape, transformed_dim, maxpool_kernel_size=maxpool_kernel_size)
    self.text_model = TextFCModel(text_dim, transformed_dim, sentence_transformer_ckp=sentence_transformer_ckp)

  def forward(self, img, tokenized_text):
    """Return `(image_embeddings, text_embeddings)` for a batch of images and tokenized captions."""
    img = self.img_model(img)
    text_x = self.text_model(tokenized_text)

    return img, text_x

# Dataset

## Download The Dataset

In [None]:
# Folder the Flickr8k images are extracted to by flickr8k() ('Flicker8k_Dataset' is the
# folder name flickr8k() expects, spelling included).
dataset_img_path = pathlib.Path('flickr8k') / 'Flicker8k_Dataset'

In [None]:
# Reference: https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/image_captioning.ipynb#scrollTo=kaNy_l7tGuAZ&line=1&uniqifier=1

def flickr8k(path='flickr8k'):
  """
  Download (when missing) and parse the Flickr8k images, captions and official splits.

  Parameters:
  path (str | os.PathLike): Directory the archives are extracted into. Defaults to 'flickr8k'.

  Returns:
  tuple: `(train, dev, test)`. Each is a list of `(image_path, captions)` pairs in the order of
  its split file, where `image_path` is a str under `<path>/Flicker8k_Dataset` and `captions`
  is the list of captions given for that file name in `Flickr8k.token.txt` (empty if none).
  """
  # `import urllib` alone does not guarantee that the `urllib.request` submodule is loaded.
  import urllib.request

  path = pathlib.Path(path)
  dataset_path = path / 'Flicker8k_Dataset'
  token_path = path / 'Flickr8k.token.txt'

  # Each archive is fetched only if its content is missing, so an interrupted first run
  # (images extracted, captions not) is repaired without downloading the images again.
  archives = (
      (dataset_path, 'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip'),
      (token_path, 'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip'),
  )
  for marker, url in archives:
    if marker.exists():
      continue
    file_path, _ = urllib.request.urlretrieve(url)
    try:
      with zipfile.ZipFile(file_path, 'r') as zip_ref:
        zip_ref.extractall(path)
    finally:
      # urlretrieve() without a filename downloads into a temporary file; delete it.
      urllib.request.urlcleanup()

  # Token lines look like "<file name>#<suffix>\t<caption>"; group captions by file name.
  cap_dict = collections.defaultdict(list)
  for line in token_path.read_text().splitlines():
    if not line.strip():
      continue  # tolerate blank lines
    fname, caption = line.split('\t', 1)
    cap_dict[fname.split('#')[0]].append(caption)

  def read_split(list_name):
    # One (image path, captions) pair per non-empty line of the split file.
    fnames = (path / list_name).read_text().splitlines()
    return [(str(dataset_path / fname), cap_dict[fname]) for fname in fnames if fname]

  return (
      read_split('Flickr_8k.trainImages.txt'),
      read_split('Flickr_8k.devImages.txt'),
      read_split('Flickr_8k.testImages.txt'),
  )

In [None]:
# Each split is a list of (image_path, captions) pairs; flickr8k() may download the archives first.
train_raw, dev_raw, test_raw = flickr8k()

## Preprocess The Dataset

In [None]:
class CustomImageDataset(torch.utils.data.Dataset):
  """
  Map-style dataset that loads the image at `image_paths[idx]`.

  Subclasses `torch.utils.data.Dataset` explicitly: in this notebook the bare name
  `Dataset` is the `datasets` package's class, imported after (and shadowing) the torch one.
  NOTE(review): that class appears to define a batched `__getitems__`, which recent PyTorch
  DataLoaders call with a list of indices when present. Inheriting it made `__getitem__`
  receive a list and fail on `self.image_paths[idx]`.

  Args:
    image_paths: Sequence of image file paths (anything `str()` can convert).
    transform (callable, optional): Applied to each decoded image tensor.
    target_transform (callable, optional): Stored but not used by this class.
  """
  def __init__(self, image_paths, transform=None, target_transform=None):
    self.image_paths = image_paths
    self.transform = transform
    self.target_transform = target_transform

  def __len__(self):
    return len(self.image_paths)

  def __getitem__(self, idx):
    """Read image `idx` as a 3-channel tensor and apply `transform` if one is set."""
    img_path = self.image_paths[idx]
    # Decode as RGB so grayscale/CMYK files still yield 3 channels, as the 3-channel
    # normalisation of the ResNet transforms and batch stacking require.
    image = read_image(str(img_path), mode=torchvision.io.ImageReadMode.RGB)
    if self.transform:
      image = self.transform(image)
    return image

In [None]:
class MergedDataset(torch.utils.data.Dataset):
  """
  Zips several indexable datasets: item `idx` is `tuple(ds[idx] for ds in datasets)`.

  Subclasses `torch.utils.data.Dataset` explicitly: in this notebook the bare name
  `Dataset` is the `datasets` package's class, imported after (and shadowing) the torch one.
  NOTE(review): that class appears to define a batched `__getitems__`, which recent PyTorch
  DataLoaders call with a list of indices when present. This integer-indexed `__getitem__`
  cannot handle a list of indices.

  Args:
    datasets: Sequence of datasets supporting `len()` and integer indexing.
  """
  def __init__(self, datasets):
    self.datasets = datasets

  def __len__(self):
    # Length of the shortest component; extra items in longer components are ignored.
    return min(len(ds) for ds in self.datasets)

  def __getitem__(self, idx):
    return tuple(ds[idx] for ds in self.datasets)

## Create Text Dataset

In [None]:
# Flatten each split to one caption per row; an image's captions stay contiguous and
# images keep the order of `*_raw`.
train_captions = [caption for _img_path, caption_list in train_raw for caption in caption_list]
dev_captions = [caption for _img_path, caption_list in dev_raw for caption in caption_list]
test_captions = [caption for _img_path, caption_list in test_raw for caption in caption_list]

In [None]:
# Tokenize all captions once, up front and on the CPU, so the DataLoaders serve
# pre-tokenized rows; the model used only for its tokenizer is discarded afterwards.
sentence_transformer = SentenceTransformer(sentence_transformer_ckp, device='cpu')

tokenized_train_text = sentence_transformer.tokenize(train_captions)
tokenized_dev_text = sentence_transformer.tokenize(dev_captions)
tokenized_test_text = sentence_transformer.tokenize(test_captions)

del sentence_transformer

In [None]:
# NOTE: `Dataset` here is the `datasets` package's class. It is imported after
# `torch.utils.data.Dataset` at the top of the notebook and shadows it.
# Built from the tokenizer's output dict — presumably one row per caption (TODO confirm).
text_train_ds = Dataset.from_dict(tokenized_train_text)
text_dev_ds = Dataset.from_dict(tokenized_dev_text)
text_test_ds = Dataset.from_dict(tokenized_test_text)

In [None]:
# Have row indexing return torch tensors, so DataLoader batches can be moved with `.to(device)`.
text_train_ds.set_format('torch')
text_dev_ds.set_format('torch')
text_test_ds.set_format('torch')

## Create Image Dataset

In [None]:
# One path per image (not yet repeated per caption), in split-file order.
img_train_paths = [img_path for img_path, _captions in train_raw]
img_dev_paths = [img_path for img_path, _captions in dev_raw]
img_test_paths = [img_path for img_path, _captions in test_raw]

In [None]:
# Give every image a unique integer label; train, dev and test images are numbered
# consecutively, so labels never collide across splits.
# Split sizes come from `*_raw` (one entry per image) rather than from `img_*_paths`, so the
# result does not depend on whether the path-repeating cell has already been run.
start = 0
end = len(train_raw)
label_train_tensor = torch.arange(start=start, end=end)

start = end
end += len(dev_raw)
label_dev_tensor = torch.arange(start=start, end=end)

start = end
end += len(test_raw)
label_test_tensor = torch.arange(start=start, end=end)

# Labels, repeated image paths and flattened captions are zipped row by row, which only
# lines up if every image has exactly N_CAPTIONS captions; fail loudly instead of misaligning.
assert all(len(caps) == N_CAPTIONS for _, caps in train_raw + dev_raw + test_raw), \
  f'every image must have exactly {N_CAPTIONS} captions'

# repeat the label for the number of captions. each image is repeated N_CAPTIONS times,
# so all captions of an image share that image's label
label_train_tensor = torch.repeat_interleave(label_train_tensor, repeats=N_CAPTIONS)
label_dev_tensor = torch.repeat_interleave(label_dev_tensor, repeats=N_CAPTIONS)
label_test_tensor = torch.repeat_interleave(label_test_tensor, repeats=N_CAPTIONS)

label_train_ds = label_train_tensor
label_dev_ds = label_dev_tensor
label_test_ds = label_test_tensor

In [None]:
# One image path per caption row: each path is repeated N_CAPTIONS times so it lines up
# with the flattened captions and the repeated labels.
# The lists are rebuilt from `*_raw` instead of repeating `img_*_paths` in place, so
# re-running this cell does not multiply them again.
# A more efficient implementation could avoid materialising the repeated lists
# (i.e. the pipeline can be changed to avoid repetition in this step).
img_train_paths = repeat_list([img_path for img_path, _ in train_raw], N_CAPTIONS)
img_dev_paths = repeat_list([img_path for img_path, _ in dev_raw], N_CAPTIONS)
img_test_paths = repeat_list([img_path for img_path, _ in test_raw], N_CAPTIONS)

assert len(img_train_paths) == len(train_captions), 'train: image paths and captions are misaligned'
assert len(img_dev_paths) == len(dev_captions), 'dev: image paths and captions are misaligned'
assert len(img_test_paths) == len(test_captions), 'test: image paths and captions are misaligned'

In [None]:
# Image datasets over the per-caption path lists, preprocessed with the ResNet weights' transforms.
img_train_ds = CustomImageDataset(img_train_paths, transform=transforms)
img_dev_ds = CustomImageDataset(img_dev_paths, transform=transforms)
img_test_ds = CustomImageDataset(img_test_paths, transform=transforms)

## Create Merged Dataset

In [None]:
# Zip images, tokenized captions and labels row by row: each item is an
# (image, tokenized_caption, label) triple. MergedDataset's length is that of its
# shortest component, so a length mismatch would be truncated silently.
train_ds = MergedDataset([img_train_ds, text_train_ds, label_train_ds])
dev_ds = MergedDataset([img_dev_ds, text_dev_ds, label_dev_ds])
test_ds = MergedDataset([img_test_ds, text_test_ds, label_test_ds])

In [None]:
# Only the training loader shuffles; dev/test batches keep the caption order.
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
dev_dataloader = DataLoader(dev_ds, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Train

In [None]:
# Two-tower model projecting images and captions into a shared TARGET_DIM-sized space,
# trained with a triplet margin loss (margin 1.0) and Adam (lr 5e-4).
# `loss_fn` is captured by the criterion lambda below; `optimizer` is picked up by
# train_planned, which does not receive it as an argument in the call below.
model = MergedModel(image_feature_shape, text_feature_dim, TARGET_DIM, sentence_transformer_ckp)
loss_fn = torch.nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

In [None]:
# Move all parameters (including the sentence-transformer submodule) to the selected device.
model = model.to(device)

In [None]:
# Train for a fixed 13 epochs with hardest-in-batch negatives. early_stop=False, so
# `patience` only drives the progress messages. With save_model=True, a checkpoint is
# written to 'checkpoint.pt' whenever the dev loss improves by at least 1e-3.
info = train_planned(model, train_dataloader, dev_dataloader, lambda a, b, labels: hard_negative_criterion(a, b, labels, loss_fn), device=device, num_epochs=13, early_stop=False, patience=5, save_model=True)

In [None]:
# Save the final (last-epoch) weights; this is separate from the best-dev-loss 'checkpoint.pt'.
torch.save(model.state_dict(), 'model.pt')