In [None]:
!pip install datasets
!pip install transformers
!pip install -U sentence-transformers
!pip install rank_bm25

from datasets import load_dataset
from datasets import get_dataset_config_names

import torch
import numpy as np
import time
from rank_bm25 import BM25Okapi

# For debugging torch: CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous so
# errors point at the op that actually failed, but it slows down every GPU call.
# Flip DEBUG_CUDA_SYNC to True only while debugging.
import os
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
  # Has to be set before the first CUDA call for it to take effect.
  os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Mount Google Drive; data, candidate files and checkpoints live under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')

from transformers import BertTokenizer, BertModel
# One shared tokenizer and two separately trained BERT encoders, both initialised from
# the same checkpoint: model_q encodes queries, model_p encodes passages (see train_mepr).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model_q = BertModel.from_pretrained('bert-base-uncased').to(device)
model_p = BertModel.from_pretrained('bert-base-uncased').to(device)

# Additional info when using CUDA
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    # torch.cuda.memory_cached() is deprecated; memory_reserved() is its replacement.
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
### Settings ###
#@title Dataset Setting 
dataset = 'COGS' #@param ["SCAN", "CFQ", "COGS"]

print(f'Using {dataset} dataset')

%cd /content/drive/My Drive/Colab Notebooks/UCL MSc Project/Data

data = load_dataset('csv', data_files={'train': f"./{dataset.lower()}_train.csv", 'test': f"./{dataset.lower()}_test.csv"})

train = data['train']
test = data['test']

if dataset == 'SCAN':
  input = 'commands'
  target = 'actions'

if dataset == 'COGS':
  input = 'source'
  target = 'target'

In [None]:
def MC_pair_loss(q, p_pos, p_neg:list, similarity_metric='dot'):
  """Softmax contrastive loss of one query against one positive and several negatives.

  loss = -log( exp(s(q, p_pos)) / (exp(s(q, p_pos)) + sum_j exp(s(q, n_j))) )
  where s is the dot product of L2-normalised vectors (i.e. cosine similarity).

  Args:
    q: 1-D query representation (the caller passes the query encoder's CLS state).
    p_pos: 1-D representation of the positive passage.
    p_neg: the negatives; each element is indexed with [0] to take its first (CLS)
      position, e.g. a (num_neg, seq_len, hidden) last_hidden_state tensor or a
      list of (seq_len, hidden) tensors. May be empty.
    similarity_metric: only 'dot' is implemented.

  Returns:
    0-dim differentiable tensor holding the loss (0 when p_neg is empty).

  Raises:
    ValueError: if similarity_metric is not 'dot'.
  """
  if similarity_metric != 'dot':
    raise ValueError(f"Unsupported similarity_metric {similarity_metric!r}; only 'dot' is implemented")

  normalize = torch.nn.functional.normalize
  # Normalising bounds every score to [-1, 1] (cosine similarity).
  q = normalize(q, dim=-1)
  p_pos = normalize(p_pos, dim=-1)

  scores = torch.dot(q, p_pos).unsqueeze(0)
  if len(p_neg) > 0:
    # CLS position of every negative, normalised and scored against q in one matmul.
    neg_cls = normalize(torch.stack([neg[0] for neg in p_neg]), dim=-1)
    scores = torch.cat([scores, neg_cls @ q])

  # -log softmax(scores)[0]; logsumexp avoids computing the exponentials explicitly.
  return torch.logsumexp(scores, dim=0) - scores[0]


In [None]:
def format_pos_neg(LIST, B=50):
  """Sample one positive and a list of negatives for every row of LIST.

  Row i is indexed as LIST[i][0] (candidates for the positive, the "top k") and
  LIST[i][1] (candidates for negatives, the "bottom k"). Rows are grouped into
  consecutive windows of B (matching the training batches). Each row gets:
    * one positive sampled from its own LIST[i][0];
    * one negative sampled from its own LIST[i][1];
    * B-1 in-batch negatives, each sampled from LIST[pick][1] of a randomly chosen
      other row `pick` in the same window (picks are with replacement).

  A trailing window shorter than B only draws from rows that exist. A row that is
  alone in its window gets just its own negative.

  Args:
    LIST: indexable sequence of (top, bottom) rows, e.g. a numpy array.
    B: window / batch size.

  Returns:
    (positive_list, negative_list). negative_list[i][0] is row i's own negative.
  """
  import random

  positive_list = []
  negative_list = []
  n = len(LIST)
  for i in range(n):
    positive_list.append(random.choice(LIST[i][0]))
    negative = [random.choice(LIST[i][1])]

    # Other rows of the same window, clipped to the end of LIST so a partial
    # trailing window never indexes past it.
    start = (i // B) * B
    others = [j for j in range(start, min(start + B, n)) if j != i]
    if others:
      for _ in range(B - 1):
        pick = random.choice(others)
        # Only the picked row's bottom-k list is used for in-batch negatives.
        negative.append(random.choice(LIST[pick][1]))

    negative_list.append(negative)

  return positive_list, negative_list


In [None]:
# Training loop
def train_mepr(epochs=30, batch_size=50, lr=10**-4):
  """Jointly fine-tune the query encoder (model_q) and passage encoder (model_p).

  Loads the positive/negative candidate file for `dataset`. Samples one positive and
  batch_size negatives per row with format_pos_neg, then optimises MC_pair_loss with
  Adam and a LinearLR schedule. Both encoders are saved to Drive after every epoch.

  Relies on notebook globals: dataset, tokenizer, model_q, model_p, device, train,
  test, input.

  Args:
    epochs: number of passes over the data (originally hard-coded to 30).
    batch_size: queries per optimiser step; also the in-batch-negative window size.
    lr: Adam learning rate.

  Returns:
    List of the mean batch loss for each epoch.

  Raises:
    ValueError: if there are fewer than batch_size rows (no full batch to train on).
  """
  import pandas as pd
  import ast

  base_dir = '/content/drive/My Drive/Colab Notebooks/UCL MSc Project'

  # Load top-p and bottom-p data; columns 0 and 1 hold stringified Python lists.
  candidates = pd.read_csv(f'{base_dir}/Top and Bottom Five/{dataset}_mepr',
                           converters={0: ast.literal_eval, 1: ast.literal_eval}).to_numpy()
  positives, negatives = format_pos_neg(candidates, B=batch_size)

  optimizer = torch.optim.Adam(list(model_q.parameters()) + list(model_p.parameters()), lr=lr)
  # Stepped once per epoch: LR goes linearly from lr to 0.1*lr over 10 epochs, then stays there.
  scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=10)

  # Only full batches are trained on; a trailing partial batch is dropped.
  batches = len(positives) // batch_size
  if batches == 0:
    raise ValueError(f'Need at least batch_size={batch_size} rows to train, got {len(positives)}')

  # from_pretrained() returns models in eval mode (dropout off); enable training mode.
  model_q.train()
  model_p.train()

  epoch_loss_storage = []

  for epoch in range(epochs):
    epoch_loss = 0
    for batch in range(batches):
      loss = 0
      torch.cuda.empty_cache()
      optimizer.zero_grad()
      for i in range(batch * batch_size, (batch + 1) * batch_size):
        # NOTE(review): queries come from the *test* split, row-aligned with the
        # candidate file; confirm this is intended and not a train/test leak.
        q_rep = model_q(**tokenizer(test[i][input], return_tensors='pt').to(device)).last_hidden_state[0][0]
        p_rep_positive = model_p(**tokenizer(train[positives[i]][input], return_tensors='pt').to(device)).last_hidden_state[0][0]
        # All negatives for this query in one padded forward pass.
        p_rep_negative = model_p(**tokenizer(train.select(negatives[i])[input], return_tensors='pt', padding=True).to(device)).last_hidden_state
        loss += MC_pair_loss(q_rep, p_rep_positive, p_rep_negative, similarity_metric='dot')
      loss /= batch_size
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()
      print(f'Loss per {batch}batch {loss.item()} at epoch {epoch}')

    epoch_loss_storage.append(epoch_loss / batches)
    print(f'epoch {epoch } loss is {epoch_loss/batches}')
    scheduler.step()

    # Checkpoint every epoch so an interrupted run keeps the latest weights.
    model_q.save_pretrained(f'{base_dir}/Dense Retriever/{dataset}_E_q_mepr')
    model_p.save_pretrained(f'{base_dir}/Dense Retriever/{dataset}_E_p_mepr')

  return epoch_loss_storage