In [None]:
import time
import pathlib
import pickle
import torch

import numpy as np

import torch.nn as nn
import torch.nn.functional as F

from tqdm import trange
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Arguments
# Mini-batch size handed to every DataLoader built in this notebook.
BATCH_SIZE = 128

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Helper Functions

In [None]:
def cosine_similarity_matrix(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
  """
  Pairwise cosine similarity between the rows of `a` and the rows of `b`.

  Source: https://stackoverflow.com/a/50426321

  Parameters:
  a (torch.Tensor): A tensor of shape (n_a, vector_size).
  b (torch.Tensor): A tensor of shape (n_b, vector_size).
  eps (float, optional): Lower bound applied to each row norm before dividing, so that
    an all-zero row yields a similarity of 0 instead of NaN. Defaults to 1e-12.

  Returns:
  torch.Tensor: A tensor of shape (n_a, n_b) whose entry (i, j) is the cosine
    similarity between `a[i]` and `b[j]`.
  """
  # F.normalize divides by max(norm, eps); for rows whose norm is at least eps this
  # is the same as dividing by the plain L2 norm.
  a = F.normalize(a, p=2, dim=-1, eps=eps)
  b = F.normalize(b, p=2, dim=-1, eps=eps)
  return a @ b.t()

In [None]:
def random_negative_criterion(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels: torch.Tensor, loss_fn: torch.nn.TripletMarginLoss) -> torch.Tensor:
  """
  Calculates triplet loss to map representations `vectors_a` and `vectors_b` to the same vectors.
  For every anchor `vectors_a[i]` the positive is `vectors_b[i]` and the negative is a row of
  `vectors_b` drawn uniformly at random among the samples whose label differs from `labels[i]`.

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the first set of vectors.
  vectors_b (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the second set of vectors.
  labels (torch.Tensor): A tensor of shape (batch_size) representing the label / class for each sample in `vectors_a` and `vectors_b`.
  loss_fn (torch.nn.TripletMarginLoss): The triplet loss, called as `loss_fn(anchor, positive, negative)`. Required.

  Returns:
  torch.Tensor: A tensor representing the calculated loss value.

  Note:
  If every sample in the batch shares the label of anchor `i`, no valid negative exists
  and row 0 of `vectors_b` is used for that anchor.
  """
  # negative_candidates[i, j] is True when sample j carries a different label than sample i.
  negative_candidates = labels.unsqueeze(0) != labels.unsqueeze(1)

  # Draw the random scores on the mask's device; a CPU tensor here would fail as soon
  # as the batch lives on the GPU.
  scores = torch.rand(negative_candidates.shape, device=negative_candidates.device)
  # Same-label entries get a score below the [0, 1) range of torch.rand, so they can
  # never win the row-wise max, even when a random draw is exactly 0.
  scores = scores.masked_fill(~negative_candidates, -1.0)
  negative_idx = scores.max(dim=1).indices

  return loss_fn(vectors_a, vectors_b, vectors_b[negative_idx])

In [None]:
def hard_negative_criterion(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels: torch.Tensor, loss_fn: torch.nn.TripletMarginLoss) -> torch.Tensor:
  """
  Calculates triplet loss to map representations `vectors_a` and `vectors_b` to the same vectors.
  For every anchor `vectors_a[i]` the positive is `vectors_b[i]` and the negative is the row of
  `vectors_b` with the highest cosine similarity to the anchor among the samples whose label
  differs from `labels[i]` (the "hard" negative).

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the first set of vectors.
  vectors_b (torch.Tensor): A tensor of shape (batch_size, embedding_size) representing the embeddings for the second set of vectors.
  labels (torch.Tensor): A tensor of shape (batch_size) representing the label / class for each sample in `vectors_a` and `vectors_b`.
  loss_fn (torch.nn.TripletMarginLoss): The triplet loss, called as `loss_fn(anchor, positive, negative)`. Required.

  Returns:
  torch.Tensor: A tensor representing the calculated loss value.

  Note:
  If every sample in the batch shares the label of anchor `i`, its whole similarity row is
  masked to -inf and the index returned by the row-wise max then refers to a same-label sample.
  """
  # same_label[i, j] is True when samples i and j share a label; those pairs
  # must never be selected as negatives.
  same_label = labels.unsqueeze(0) == labels.unsqueeze(1)

  cos_sim = cosine_similarity_matrix(vectors_a, vectors_b)
  # masked_fill keeps the dtype of cos_sim, unlike mixing a Python float scalar
  # into torch.where.
  cos_sim = cos_sim.masked_fill(same_label, float('-inf'))

  # Row-wise argmax over the remaining entries: the most similar different-label sample.
  hardest_idx = cos_sim.max(dim=1).indices
  return loss_fn(vectors_a, vectors_b, vectors_b[hardest_idx])

In [None]:
def topk_accuracy(vectors_a: torch.Tensor, vectors_b: torch.Tensor, labels_a: torch.Tensor, labels_b: torch.Tensor, k :int=5) -> torch.Tensor:
  """
  Calculates the top-k retrieval accuracy of `vectors_a` against `vectors_b`.
  For every row of `vectors_a`, the `k` rows of `vectors_b` with the highest cosine similarity
  are retrieved; the row counts as correct when at least one retrieved row has the same label.

  Parameters:
  vectors_a (torch.Tensor): A tensor of shape (n_a, vector_size) with the query vectors.
  vectors_b (torch.Tensor): A tensor of shape (n_b, vector_size) with the candidate vectors. `n_b` may differ from `n_a`.
  labels_a (torch.Tensor): A tensor of shape (n_a) representing the labels for each sample in `vectors_a`.
  labels_b (torch.Tensor): A tensor of shape (n_b) representing the labels for each sample in `vectors_b`.
  k (int, optional): # of top predictions to consider. Values larger than `n_b` are clamped to `n_b`. Defaults to 5.

  Returns:
  torch.Tensor: Scalar tensor with the fraction of rows in `vectors_a` that were retrieved correctly.
  """
  # pos_mask[i, j] is True when query i and candidate j share a label.
  pos_mask = (labels_a.unsqueeze(1) == labels_b.unsqueeze(0))
  sim_matrix = cosine_similarity_matrix(vectors_a, vectors_b)

  # torch.topk raises when k exceeds the number of candidates, so clamp it.
  k = min(k, sim_matrix.shape[1])
  topk_idx = torch.topk(sim_matrix, k=k).indices

  # Look up, for each retrieved candidate, whether it is a positive for its query.
  topk_pos = pos_mask.gather(dim=1, index=topk_idx)
  true_pred = torch.any(topk_pos, dim=1)
  return torch.sum(true_pred) / true_pred.nelement()

# Train Utils

In [None]:
def train(model, iterator, optimizer, criterion, device):
    """Run one optimisation pass over `iterator` and return the mean batch loss.

    Each batch is an `(a, b, labels)` triple. The three tensors are moved to
    `device`, `model(a, b)` produces the two transformed representations, and
    `criterion(a_t, b_t, labels)` gives the loss that is back-propagated before
    `optimizer` takes a step.
    """
    model.train()

    # Read the batch count before iterating: some iterators report a different
    # length once they have been consumed.
    num_batches = len(iterator)
    running_loss = 0

    for batch in iterator:
        a, b, labels = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        a_t, b_t = model(a, b)
        batch_loss = criterion(a_t, b_t, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / num_batches

In [None]:
def evaluate(model, iterator, criterion, device):
    """Return the mean batch loss of `model` over `iterator` without updating it.

    The model is put in eval mode and gradients are disabled. Each batch is an
    `(a, b, labels)` triple whose tensors are moved to `device` before
    `criterion(a_t, b_t, labels)` is computed on the outputs of `model(a, b)`.
    """
    model.eval()
    num_batches = len(iterator)
    batch_losses = []

    with torch.no_grad():
        for a, b, labels in iterator:
            a_t, b_t = model(a.to(device), b.to(device))
            batch_losses.append(criterion(a_t, b_t, labels.to(device)).item())

    return sum(batch_losses) / num_batches

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=4, verbose=False, delta=1e-3, path='checkpoint.pt', trace_func=print, save_checkpoint_file=True):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 4
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 1e-3
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
            save_checkpoint_file (bool): If True, the model's state_dict is written to `path`
                            on every improvement.
                            Default: True
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        # Plain Python infinities: the np.Inf / np.NINF aliases were removed in NumPy 2.0.
        self.val_loss_min = float('inf')
        self.val_acc_max = float('-inf')
        self.time_at_stop = 0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.save_checkpoint_file = save_checkpoint_file

        # Number of times the instance has been called; reported when a checkpoint is saved.
        self.n_calls = 0

    def __call__(self, val_loss, val_acc, model, current_time=0):
        """
        Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss of the current epoch.
            val_acc (float): Validation accuracy stored alongside the best loss.
            model (torch.nn.Module): Model to checkpoint when the loss improves.
            current_time (float): Value stored in `time_at_stop` when the loss improves.
                            Default: 0
        """
        self.n_calls += 1

        # Written as a positive test so that a NaN loss (every comparison is False)
        # counts as "no improvement" instead of overwriting the best loss.
        improved = val_loss <= self.val_loss_min - self.delta

        if improved:
            self.save_checkpoint(val_loss, val_acc, model)
            self.counter = 0
            self.time_at_stop = current_time
        else:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, val_acc, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
        if self.save_checkpoint_file:
            torch.save(model.state_dict(), self.path)
            print(f'{self.n_calls:3d}. call: Saving model...')
        self.val_loss_min = val_loss
        self.val_acc_max = val_acc

In [None]:
def train_planned(model, train_dataloader, test_loader, criterion, device='cpu', num_epochs=75, early_stop=True, patience=5, save_model=False, losses=None, optimizer=None):
  """
  Trains `model` for up to `num_epochs` epochs, evaluating on `test_loader` after every epoch.

  Parameters:
  model (torch.nn.Module): Model called as `model(a, b)`.
  train_dataloader: Iterable of `(a, b, labels)` batches used for training.
  test_loader: Iterable of `(a, b, labels)` batches used for the per-epoch evaluation.
  criterion (callable): Called as `criterion(a_t, b_t, labels)`; returns the loss tensor.
  device: Device the batches are moved to. Defaults to 'cpu'.
  num_epochs (int, optional): Maximum number of epochs. Defaults to 75.
  early_stop (bool, optional): Stop once the evaluation loss has not improved for `patience` epochs. Defaults to True.
  patience (int, optional): Patience handed to `EarlyStopping`. Defaults to 5.
  save_model (bool, optional): Write a checkpoint on every improvement. Defaults to False.
  losses: Unused; kept so existing calls keep working.
  optimizer (torch.optim.Optimizer, optional): Optimizer used for the updates. When omitted,
    the module-level variable named `optimizer` is used, as in earlier versions of this function.

  Returns:
  dict: 'train_losses' and 'test_losses' (per-epoch lists), 'loss', 'epochs' and 'time'
    (seconds spent in training only). After an early stop, 'loss', 'time' and 'epochs'
    describe the best epoch; otherwise they describe the last epoch that was run.

  Raises:
  ValueError: If no optimizer is passed and no module-level `optimizer` exists.
  """
  if optimizer is None:
    # Backward compatibility: callers that do not pass an optimizer get the
    # notebook-level one, which this function previously read implicitly.
    optimizer = globals().get('optimizer')
    if optimizer is None:
      raise ValueError('No optimizer given and no module-level `optimizer` is defined.')

  pbar = trange(num_epochs, desc='Training', position=0, leave=True)

  early_stopping = EarlyStopping(verbose=True, patience=patience, delta=1e-5, trace_func=pbar.set_description, save_checkpoint_file=save_model)
  train_losses, test_losses = [], []

  info = {
      'train_losses': train_losses,
      'test_losses': test_losses,
      # float('inf') rather than np.Inf, which NumPy 2.0 removed.
      'loss': float('inf'),
      'epochs': num_epochs,
      'time': 0
  }

  for epoch in pbar:

    # Only the training pass is timed; evaluation time is not included.
    start = time.time()
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    stop = time.time()

    test_loss = evaluate(model, test_loader, criterion, device)

    info['loss'] = test_loss
    info['time'] = info['time'] + stop - start

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    # The accuracy argument is not tracked here; a constant placeholder is passed.
    early_stopping(test_loss, 1, model, current_time=info['time'])

    pbar.set_description(f'Test / Train | Loss: {test_loss:.3f}/{train_loss:.3f}')

    if early_stop and early_stopping.early_stop:
      pbar.close()
      # `epoch` is 0-based, so epoch + 1 epochs have been run at this point.
      print(f'Early stopping. Completed {epoch + 1}/{num_epochs} epochs.')

      info['loss'] = early_stopping.val_loss_min
      info['time'] = early_stopping.time_at_stop
      # Number of epochs the reported model were trained for
      info['epochs'] = epoch - early_stopping.patience + 1

      return info

  pbar.close()
  return info

# Loading Data

In [None]:
def load_data(filepath):
    """Return the object unpickled from the file at `filepath`."""
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(filepath, mode='rb') as pickle_file:
        return pickle.load(pickle_file)

def print_shapes(arrays):
  """
  Prints one `name: shape` line per entry of `arrays`, with the names padded to a common width.

  Parameters:
  arrays (dict): Maps a display name (str) to an object exposing a `.shape` attribute.
    An empty mapping prints nothing.
  """
  # default=0 keeps an empty mapping from raising ValueError in max().
  max_name_length = max((len(name) for name in arrays), default=0)
  for name, array in arrays.items():
    print(f'{name:{max_name_length}}: {array.shape}')

In [None]:
# Folder holding the pickled image / text vectors for the three splits.
# NOTE(review): this is a hard-coded absolute path; the `/content/drive` prefix
# suggests a mounted Google Drive in Colab. Adjust when running elsewhere.
parent_folder = pathlib.Path('/content/drive/MyDrive/collective_learning/px-multimodal-repr/binaries/flickr8k')

# One pickle per (split, modality) pair.
train_img_vectors_path = parent_folder / 'train_img_vectors.pkl'
train_text_vectors_path = parent_folder / 'train_text_vectors.pkl'
dev_img_vectors_path = parent_folder / 'dev_img_vectors.pkl'
dev_text_vectors_path = parent_folder / 'dev_text_vectors.pkl'
test_img_vectors_path = parent_folder / 'test_img_vectors.pkl'
test_text_vectors_path = parent_folder / 'test_text_vectors.pkl'

# Unpickle every file. Later cells call `.shape` and `np.concatenate` on the
# results, so they are presumably NumPy arrays; verify against the files.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
train_img_vectors = load_data(train_img_vectors_path)
train_text_vectors = load_data(train_text_vectors_path)
dev_img_vectors = load_data(dev_img_vectors_path)
dev_text_vectors = load_data(dev_text_vectors_path)
test_img_vectors = load_data(test_img_vectors_path)
test_text_vectors = load_data(test_text_vectors_path)

In [None]:
# Shapes of the vectors exactly as loaded from disk.
arrays = dict(
    train_img_vectors=train_img_vectors,
    train_text_vectors=train_text_vectors,
    dev_img_vectors=dev_img_vectors,
    dev_text_vectors=dev_text_vectors,
    test_img_vectors=test_img_vectors,
    test_text_vectors=test_text_vectors,
)
print_shapes(arrays)

In [None]:
# Each image has 5 captions
n_captions = 5

# Give every image across all three splits a unique integer label (its position
# in the train + dev + test concatenation).
all_img_vectors = np.concatenate((train_img_vectors, dev_img_vectors, test_img_vectors))
num_images = all_img_vectors.shape[0]
all_labels = np.arange(num_images)

# Slice the label range per split and repeat each label once per caption.
# The image counts are read here, BEFORE the image arrays are overwritten below.
num_train_images = train_img_vectors.shape[0]
train_labels = np.repeat(all_labels[:num_train_images], n_captions)

num_dev_images = dev_img_vectors.shape[0]
dev_labels = np.repeat(all_labels[num_train_images:num_train_images + num_dev_images], n_captions)

num_test_images = test_img_vectors.shape[0]
test_labels = np.repeat(all_labels[num_train_images + num_dev_images:], n_captions)

# Repeat every image vector n_captions times so that image rows line up with
# caption rows. Assumes the text vectors are stored with the n_captions captions
# of an image in consecutive rows; TODO confirm.
# NOTE: this overwrites the loaded arrays, so re-running the cell repeats them again.
train_img_vectors = np.repeat(train_img_vectors, n_captions, axis=0)
dev_img_vectors = np.repeat(dev_img_vectors, n_captions, axis=0)
test_img_vectors = np.repeat(test_img_vectors, n_captions, axis=0)

In [None]:
# Convert the image / text vectors of every split to torch tensors.
train_img_vectors, train_text_vectors = torch.Tensor(train_img_vectors), torch.Tensor(train_text_vectors)
dev_img_vectors, dev_text_vectors = torch.Tensor(dev_img_vectors), torch.Tensor(dev_text_vectors)
test_img_vectors, test_text_vectors = torch.Tensor(test_img_vectors), torch.Tensor(test_text_vectors)

# NOTE(review): torch.Tensor(...) builds tensors of the default floating dtype, so
# the integer image ids become floats. The criteria only compare labels for
# equality; confirm this is acceptable before using the labels as indices.
train_labels = torch.Tensor(train_labels)
dev_labels = torch.Tensor(dev_labels)
test_labels = torch.Tensor(test_labels)

In [None]:
# Shapes again, now that the image vectors were repeated once per caption and
# everything was converted to tensors.
arrays = {
    name: tensor
    for name, tensor in (
        ('train_img_vectors', train_img_vectors),
        ('train_text_vectors', train_text_vectors),
        ('dev_img_vectors', dev_img_vectors),
        ('dev_text_vectors', dev_text_vectors),
        ('test_img_vectors', test_img_vectors),
        ('test_text_vectors', test_text_vectors),
    )
}
print_shapes(arrays)

In [None]:
# Every dataset item is an (image vector, text vector, image label) triple.
train_dataset = TensorDataset(train_img_vectors, train_text_vectors, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation loaders are not shuffled: the criteria pick negatives from within a
# batch, so a fixed batch composition keeps the validation loss comparable
# between epochs (it drives early stopping).
dev_dataset = TensorDataset(dev_img_vectors, dev_text_vectors, dev_labels)
dev_dataloader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_dataset = TensorDataset(test_img_vectors, test_text_vectors, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model Definitions

In [None]:
class ImageModel(nn.Module):
    """Three-layer MLP (img_dim -> 1024 -> 256 -> transformed_dim) with ReLU between layers."""

    def __init__(self, img_dim, transformed_dim):
        super().__init__()
        wide_units, narrow_units = 1024, 256
        self.fc1 = nn.Linear(img_dim, wide_units)
        self.fc2 = nn.Linear(wide_units, narrow_units)
        self.fc3 = nn.Linear(narrow_units, transformed_dim)
        # weight_norm re-parameterises fc3 and returns that same module, so despite
        # its name `l2_norm` is an alias of `fc3`, not an output-normalisation layer.
        self.l2_norm = nn.utils.weight_norm(self.fc3)

    def forward(self, x):
        """Map a batch of image vectors to the shared embedding space."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        # The output layer is applied once, with no activation; `l2_norm` is not used here.
        return self.fc3(x)

In [None]:
class TextModel(nn.Module):
    """Three-layer MLP (text_dim -> 512 -> 256 -> transformed_dim) with ReLU between layers."""

    def __init__(self, text_dim, transformed_dim):
        super().__init__()
        wide_units, narrow_units = 512, 256
        self.fc1 = nn.Linear(text_dim, wide_units)
        self.fc2 = nn.Linear(wide_units, narrow_units)
        self.fc3 = nn.Linear(narrow_units, transformed_dim)
        # weight_norm re-parameterises fc3 and returns that same module, so despite
        # its name `l2_norm` is an alias of `fc3`, not an output-normalisation layer.
        self.l2_norm = nn.utils.weight_norm(self.fc3)

    def forward(self, x):
        """Map a batch of text vectors to the shared embedding space."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        # The output layer is applied once, with no activation; `l2_norm` is not used here.
        return self.fc3(x)

In [None]:
class MergedModel(nn.Module):
    """Pairs an ImageModel and a TextModel that project into the same `transformed_dim` space."""

    def __init__(self, img_dim, text_dim, transformed_dim):
        super().__init__()
        self.img_model = ImageModel(img_dim, transformed_dim)
        self.text_model = TextModel(text_dim, transformed_dim)

    def forward(self, img_vectors, text_vectors):
        """Return `(image embeddings, text embeddings)`; each input goes through its own sub-model."""
        return self.img_model(img_vectors), self.text_model(text_vectors)

In [None]:
# Input sizes are taken from the loaded vectors; both modalities are projected
# into a shared 128-dimensional space.
model = MergedModel(img_dim=train_img_vectors.shape[1], text_dim=train_text_vectors.shape[1], transformed_dim=128)
# Triplet loss handed to the negative-sampling criteria defined above.
loss_fn = torch.nn.TripletMarginLoss(margin=1.0)
# NOTE: `train_planned` reads this module-level `optimizer` by name.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

In [None]:
# Move the model parameters to the selected device (GPU when available).
model = model.to(device)

In [None]:
# Reload the dev images from disk: the in-memory `dev_img_vectors` was repeated
# once per caption earlier, but retrieval needs one row per image.
dev_img_vectors = load_data(dev_img_vectors_path)

# Inference only: no gradients are needed. The inputs are moved to the model's
# device (a CPU input fails once the model is on the GPU) and the embeddings are
# brought back to the CPU for the metric computation.
model.eval()
with torch.no_grad():
  dev_img_transformed = model.img_model(torch.Tensor(dev_img_vectors).to(device)).cpu()
  dev_text_transformed = model.text_model(torch.Tensor(dev_text_vectors).to(device)).cpu()

In [None]:
# Image -> caption retrieval on the dev split before training. Image i carries
# label i and each label is repeated n_captions times for the caption rows.
topk_accuracy(dev_img_transformed, dev_text_transformed, torch.arange(num_dev_images), torch.arange(num_dev_images).repeat_interleave(n_captions))

In [None]:
# Train with in-batch hard-negative mining, evaluating on the dev loader after
# every epoch. The lambda binds `loss_fn` so the criterion has the
# (a, b, labels) signature that `train` / `evaluate` call.
info = train_planned(model, train_dataloader, dev_dataloader, lambda a, b, labels: hard_negative_criterion(a, b, labels, loss_fn), device=device, num_epochs=13, early_stop=True, patience=5)

In [None]:
# Reload the test images from disk: the in-memory `test_img_vectors` was repeated
# once per caption earlier, but retrieval needs one row per image.
test_img_vectors = load_data(test_img_vectors_path)

# Inference only: no gradients are needed. The inputs are moved to the model's
# device (a CPU input fails once the model is on the GPU) and the embeddings are
# brought back to the CPU, where the label tensors below are created.
model.eval()
with torch.no_grad():
  test_img_transformed = model.img_model(torch.Tensor(test_img_vectors).to(device)).cpu()
  test_text_transformed = model.text_model(torch.Tensor(test_text_vectors).to(device)).cpu()

# Image -> caption top-5 retrieval accuracy on the test split.
topk_accuracy(test_img_transformed, test_text_transformed, torch.arange(num_test_images), torch.arange(num_test_images).repeat_interleave(n_captions), k=5)