Import some useful packages and set everything up.

(Remember to set the hardware accelerator to GPU in Notebook settings.)

In [1]:
%matplotlib inline
!pip install transformers

import copy
import math
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import unittest

from tqdm import tqdm

transformers.logging.set_verbosity_error()  # Disable needless warnings

def set_seed(seed):  # For reproducibility, fix random seeds.
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)



*Tip*: GPU memory is not always reclaimed promptly in Google Colab. If we repeatedly create a model and move it to the GPU, we can run into out-of-memory problems. Use the function below to release the model's GPU memory whenever that happens.

In [2]:
import gc

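def clear_from_gpu(thing):  # E.g., clear_from_gpu(model)
  # Rebinding the local name only drops this function's reference. The caller must
  # also drop its own references (e.g., `del model`) before the memory can be freed.
  thing = None
  gc.collect()  # Collect unreachable Python objects, including reference cycles
  torch.cuda.empty_cache()  # Release cached GPU memory that PyTorch no longer uses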

In [3]:
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')  # Keep the token in a variable so the secret is not echoed in the cell output

# Natural Questions (NQ) Dataset

Download [a preprocessed version of NQ](https://drive.google.com/file/d/1rHm8Nt3nIkmOeO2VU2YQPqdAoV0WQoMM/view?usp=share_link). We will assume that we have extracted the files (nq-train.json and nq-val.json) to the directory data/nq-toy/ in our Google Drive account. Let's load the data and stare at it.  

In [4]:
# Load the Drive helper and mount. You will have to authorize this operation.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import json

with open('/content/drive/My Drive/data/nq-toy/nq-train.json') as f:
  data_train_all = json.load(f)

with open('/content/drive/My Drive/data/nq-toy/nq-val.json') as f:
  data_val_all = json.load(f)

print('{:d} training and {:d} validation examples'.format(len(data_train_all), len(data_val_all)))
print('Each example has the following attributes: ' + str(list(data_train_all[0].keys())))

set_seed(42)
for example in random.sample(data_train_all, 5):
  print('\nQ: ' + example['question'])
  print('A: ' + str(example['answers']))
  print('C: ' + str(example['gold_context']))

58880 training and 6515 validation examples
Each example has the following attributes: ['question', 'answers', 'gold_context']

Q: who plays killer croc in the movie suicide squad
A: ['Adewale Akinnuoye - Agbaje']
C: {'title': 'Suicide Squad (film)', 'text': 'the film, and Scott Eastwood announced that he had been cast. Later that month, it was confirmed that Adewale Akinnuoye-Agbaje and Karen Fukuhara had been cast as Killer Croc and Katana, respectively. Adam Beach, Ike Barinholtz, and Jim Parrack were added to the cast in April 2015. In January 2016, Ben Affleck was confirmed to reprise his role as Batman from "". Filming began on April 13, 2015. On April 26 and 27, filming was to take place at the Hy\'s Steakhouse. A "snowstorm" scene was filmed on April 29 on the Adelaide St. and in Ching Lane. On May'}

Q: who is cate blanchett 's character in lord of the rings
A: ['Galadriel']
C: {'title': 'Galadriel', 'text': 'Annette Crosbie in Ralph Bakshi\'s 1978 animated film of "The Lord o

NQ consists of real-world Google queries (in question form), each manually labeled with a valid answer (possibly multiple answers, but mostly just one) and a piece of text from Wikipedia (~110 tokens), a.k.a. a "context", in which the answer can be found and justified. This is a subset of NQ with short answers (prepared by [this paper](https://arxiv.org/pdf/1906.00300.pdf)).

Read the [original NQ paper](https://aclanthology.org/Q19-1026.pdf) about the excruciating care the authors took to ensure data quality (e.g., 5-way annotations for evaluation data).

## Reducing Data

Even though NQ is not a large dataset, for our purposes it will be too slow to train models if we use the whole dataset. Thus we will reduce the data size.

In [6]:
set_seed(42)
num_samples_train = 500
num_samples_val = 1000
data_train = random.sample(data_train_all, num_samples_train)
data_val = random.sample(data_val_all, num_samples_val)

## Knowledge Base (KB)

We consider open-domain question answering (ODQA) in which the KB consists of all passages (a.k.a. contexts or text blocks) in Wikipedia, around 20 million. For obvious computational reasons we will use a small toy KB that consists only of the gold passages in the portion of NQ that we use.

In [7]:
KB = []
for example in data_train + data_val:
  KB.append({'title': example['gold_context']['title'], 'text': example['gold_context']['text']})
print(len(KB))

1500


# Dual Encoder

Our goal is to retrieve the right passage for a given question from the KB by training a dual encoder. The dual encoder has a trivial architecture:

1. The query encoder gets $Q \in \mathbb{R}^{B \times T}$ representing $B$ questions of length $T$ (assuming padding). It encodes them into $X \in \mathbb{R}^{B \times d}$ (i.e., $d$-dimensional question embeddings).

2. The passage encoder gets $P \in \mathbb{R}^{B' \times T'}$ representing $B'$ passages of length $T'$. It encodes them into $Y \in \mathbb{R}^{B' \times d}$, with the same dimension $d$ so that inner products between question and passage embeddings are defined.

In [8]:
class DualEncoder(torch.nn.Module):

  def __init__(self, query_encoder, passage_encoder):
    super().__init__()
    self.query_encoder = query_encoder
    self.passage_encoder = passage_encoder

  def forward(self, Q=None, Q_mask=None, P=None, P_mask=None):
    X = self.query_encoder(Q, Q_mask) if Q is not None else None
    Y = self.passage_encoder(P, P_mask) if P is not None else None
    return X, Y

We are free to choose whatever encoders we like. We will use [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert), which is a distilled version of [BERT](https://arxiv.org/pdf/1810.04805.pdf) with "only" 66 million parameters. We will take the embedding of the [CLS] token which is prepended to every input sequence as a single-vector representation of the sequence (if you don't know what [CLS] is, you need to review BERT).


In [9]:
from transformers import AutoTokenizer, DistilBertModel

In [10]:
class DistilBertEncoder(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')

  def forward(self, input_ids, attention_mask=None):
    # Passing the input_ids and attention_mask through the encoder to get the sequence of hidden states.
    outputs = self.encoder(input_ids, attention_mask=attention_mask)
    # Extracting the embeddings for the [CLS] token (i.e., the first token) from each sequence in the batch.
    cls_embeddings = outputs.last_hidden_state[:, 0, :]
    return cls_embeddings

In [11]:
class TestEncoder(unittest.TestCase):

  def setUp(self):
    self.input_ids = torch.LongTensor([[101, 1001, 1002, 1003, 1004], [101, 207, 350, 394, 999], [101, 10000, 10000, 10000, 10000]])
    self.places = 4

  def test_encoder(self):
    model = DistilBertEncoder()
    embeddings = model(self.input_ids)
    self.assertEqual(embeddings.dim(), 2)
    self.assertEqual(embeddings.size(0), 3)
    self.assertEqual(embeddings.size(1), 768)  # This is the hidden dimension of DistilBERT
    self.assertAlmostEqual(embeddings[0, 0].item(), 0.2982, places=self.places)
    self.assertAlmostEqual(embeddings[1, 0].item(), -0.0705, places=self.places)
    self.assertAlmostEqual(embeddings[2, 0].item(), 0.1151, places=self.places)

unittest.main(TestEncoder(), argv=[''], verbosity=2, exit=False)

test_encoder (__main__.TestEncoder) ... ok

----------------------------------------------------------------------
Ran 1 test in 8.934s

OK


<unittest.main.TestProgram at 0x7c2c9869f250>

We must use the same tokenizer that DistilBERT was pretrained with. We will load it once as a global constant and use it throughout the assignment.

In [12]:
distilbert_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [13]:
i1 = 0
i2 = 2
# Example (batch) encoding
texts = [data_train[i1]['gold_context']['title'], data_train[i2]['gold_context']['title']]
encoded = distilbert_tokenizer(texts, padding=True, truncation=True, max_length=10, return_tensors='pt')  # Try changing the max length
print(encoded.keys())  # input_ids, attention_mask
print(encoded['input_ids'].shape)
print(encoded['input_ids'])  # Special tokens [CLS] (index 101), [SEP] (index 102), [PAD] (index 0) are automatically inserted upon tokenization
print(encoded['attention_mask'])

# Example (batch) decoding
decoded = distilbert_tokenizer.batch_decode(encoded['input_ids'], skip_special_tokens=False)
print(decoded)

# BERT-style tokenizers also support appending a second text, as in "[CLS] text [SEP] another text [SEP]".
second_texts = [data_train[i1]['gold_context']['text'], data_train[i2]['gold_context']['text']]
encoded = distilbert_tokenizer(texts, second_texts, padding=True, truncation=True, max_length=15, return_tensors='pt')  # Try changing the max length
print('\n', encoded['input_ids'].shape)
print(encoded['input_ids'])
print(encoded['attention_mask'])
decoded = distilbert_tokenizer.batch_decode(encoded['input_ids'], skip_special_tokens=False)
print(decoded)

dict_keys(['input_ids', 'attention_mask'])
torch.Size([2, 7])
tensor([[  101,  5920,  4686,  1006,  2143,  1007,   102],
        [  101, 17266, 26692,  6632,  2278,  4101,   102]])
tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])
['[CLS] suicide squad ( film ) [SEP]', '[CLS] sacroiliac joint [SEP]']

 torch.Size([2, 15])
tensor([[  101,  5920,  4686,  1006,  2143,  1007,   102,  1996,  2143,  1010,
          1998,  3660, 24201,  2623,   102],
        [  101, 17266, 26692,  6632,  2278,  4101,   102, 17266, 26692,  6632,
          2278,  4101,  1996, 17266,   102]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
['[CLS] suicide squad ( film ) [SEP] the film, and scott eastwood announced [SEP]', '[CLS] sacroiliac joint [SEP] sacroiliac joint the sac [SEP]']


We're going to cut off the input and throw away any text beyond length $T_{\max}$ for computational reasons. Remember that the runtime is quadratic in the input length for transformers.

(You don't want to do this in reality since crucial information might be at the end of a passage!)
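To get a feel for what the cut-off buys us, here is a back-of-the-envelope sketch. It assumes that the cost of self-attention grows as the square of the input length and ignores every other cost. The helper name `relative_attention_cost` is ours, and the two lengths are the longest passage length (210) and the cap (100) printed by the next cell.

```python
def relative_attention_cost(length, reference_length):
  # Assumption: self-attention cost grows as length ** 2; everything else is ignored.
  return (length / reference_length) ** 2

full, capped = 210, 100
ratio = relative_attention_cost(full, capped)
print('Attention over {:d} tokens costs about {:.2f}x attention over {:d} tokens'.format(full, ratio, capped))
```

Under this assumption, attention over the longest passage is roughly 4.4 times as expensive as attention over a capped one.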

In [14]:
questions_encoded = distilbert_tokenizer([example['question'] for example in data_train], padding=True, return_tensors='pt')
passages_encoded = distilbert_tokenizer([passage['title'] for passage in KB],
                                        [passage['text'] for passage in KB], padding=True, return_tensors='pt')
max_question_length = questions_encoded['input_ids'].size(1)
max_passage_length = passages_encoded['input_ids'].size(1)
print('max question length: {:d}'.format(max_question_length))
print('max passage length: {:d}'.format(max_passage_length))

MAX_INPUT_LENGTH=100
print('OUR max input length: {:d}'.format(MAX_INPUT_LENGTH))

max question length: 23
max passage length: 210
OUR max input length: 100


Putting it together, here's a function that produces a DistilBERT-initialized dual encoder. We will have an option to tie the query and passage encoders. Encoder tying halves the number of parameters (and thus the memory needed for weights, gradients, and optimizer state), but it may not always work as well as not tying.

In [15]:
def make_distilbert_dual_encoder(tie_encoders=False):
  query_encoder = DistilBertEncoder()
  passage_encoder = query_encoder if tie_encoders else DistilBertEncoder()
  model = DualEncoder(query_encoder, passage_encoder)
  return model

# Evaluation

## Data Stuff

We first need a PyTorch dataset to feed our data to a DataLoader. We will reuse the same Dataset class for both the NQ data (which contains questions and passages) and the KB (which contains only passages) since both are just lists of samples.


In [16]:
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):

  def __init__(self, data):
    self.samples = data

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, index):
    return self.samples[index]

Let's write code to turn a batch of training/evaluation data into torch-friendly tensors. A sample has question, gold_context, and optionally hard_negative.

In [17]:
def tensorize_contrastive(samples, use_hard_negative=False):
  queries = []
  labels = []
  titles = []
  texts = []

  for sample in samples:
    queries.append(sample['question'])
    labels.append(len(titles))  # Index of correct context for the question

    titles.append(sample['gold_context']['title'])
    texts.append(sample['gold_context']['text'])

    if use_hard_negative and 'hard_negative' in sample:
      titles.append(sample['hard_negative']['title'])
      texts.append(sample['hard_negative']['text'])

  queries = distilbert_tokenizer(queries, padding=True, truncation=True, max_length=MAX_INPUT_LENGTH, return_tensors='pt')
  passages = distilbert_tokenizer(titles, texts, padding=True, truncation=True, max_length=MAX_INPUT_LENGTH, return_tensors='pt')

  Q, Q_mask = queries['input_ids'], queries['attention_mask']  # (batch_size, question_length)
  P, P_mask = passages['input_ids'], passages['attention_mask'] # (batch_size, passage_length) if no hard negatives else (2*batch_size, passage_length)
  labels = torch.LongTensor(labels)  # (batch_size,) elements in [0, batch_size) or [0, 2 * batch_size)

  return Q, Q_mask, P, P_mask, labels
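The label bookkeeping is the subtle part, so here is a minimal sketch that mirrors only that part of tensorize_contrastive (no tokenization). The helper name `contrastive_layout` and the toy titles are made up for illustration.

```python
def contrastive_layout(samples, use_hard_negative=False):
  # Mirrors only the bookkeeping of tensorize_contrastive: where does each gold passage land?
  labels, titles = [], []
  for sample in samples:
    labels.append(len(titles))  # Position of this question's gold passage
    titles.append(sample['gold_context']['title'])
    if use_hard_negative and 'hard_negative' in sample:
      titles.append(sample['hard_negative']['title'])
  return labels, titles

toy = [
    {'gold_context': {'title': 'gold-0'}, 'hard_negative': {'title': 'hard-0'}},
    {'gold_context': {'title': 'gold-1'}, 'hard_negative': {'title': 'hard-1'}},
    {'gold_context': {'title': 'gold-2'}, 'hard_negative': {'title': 'hard-2'}},
]
print(contrastive_layout(toy))                          # labels [0, 1, 2]
print(contrastive_layout(toy, use_hard_negative=True))  # labels [0, 2, 4]
```

With hard negatives the passages are interleaved as gold, hard, gold, hard, and so on, which is why the labels become the even positions.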

Let's also write code to tensorize a batch of the KB, which only consists of passages.

In [18]:
def tensorize_KB_batch(passages):
  titles = [passage['title'] for passage in passages]
  texts = [passage['text'] for passage in passages]

  # We must make sure we encode passages the same way as we do in training.
  passages = distilbert_tokenizer(titles, texts, padding=True, truncation=True, max_length=MAX_INPUT_LENGTH, return_tensors='pt')
  P, P_mask = passages['input_ids'], passages['attention_mask']  # (batch_size, passage_length)
  return P, P_mask

## Recall

Assuming a trained dual encoder, we can embed all KB passages offline.

In [19]:
def embed_KB(KB, model, batch_size, device='cuda', disable_tqdm=True):
  model.eval()
  collate_fn = lambda samples: tensorize_KB_batch(samples)
  loader = DataLoader(MyDataset(KB), batch_size, shuffle=False, collate_fn=collate_fn)
  passage_embeddings = []
  with torch.no_grad():
    for P, P_mask in tqdm(loader, disable=disable_tqdm):
      _, Y = model(Q=None, Q_mask=None, P=P.to(device), P_mask=P_mask.to(device))
      passage_embeddings.append(Y.cpu())
  passage_embeddings = torch.cat(passage_embeddings, 0)  # (|KB|, dim)
  return passage_embeddings

In a real-world implementation the passage embedding computation must be done in parallel using many GPUs. Assuming that the passages in the KB are already embedded, at test time we only need to embed the questions and perform maximum inner product search (MIPS). Let's write code for embedding questions.

In [20]:
def embed_questions(data, model, batch_size, device='cuda', disable_tqdm=True):
  model.eval()
  collate_fn = lambda samples: tensorize_contrastive(samples, use_hard_negative=False)
  loader = DataLoader(MyDataset(data), batch_size, shuffle=False, collate_fn=collate_fn)
  question_embeddings = []
  with torch.no_grad():
    for Q, Q_mask, _, _, _ in tqdm(loader, disable=disable_tqdm):
      X, _ = model(Q=Q.to(device), Q_mask=Q_mask.to(device), P=None, P_mask=None)
      question_embeddings.append(X.cpu())
  question_embeddings = torch.cat(question_embeddings, 0)  # (|val|, d)
  return question_embeddings

Write code to compute top-$k$ candidate passages per question.

In [21]:
def get_topk_candidates(KB, data, model, batch_size, num_candidates, device='cuda', disable_tqdm=True):
  passage_embeddings = embed_KB(KB, model, batch_size, device=device, disable_tqdm=disable_tqdm)
  question_embeddings = embed_questions(data, model, batch_size, device=device, disable_tqdm=disable_tqdm)
  # Performing MIPS between the questions and the passage embeddings
  scores = torch.matmul(question_embeddings, passage_embeddings.T)
  values, indices = torch.topk(scores, num_candidates)
  results = []
  for i in range(indices.size(0)):
    result = {'question': data[i]['question'],
              'answers': data[i]['answers'],
              'gold_title': data[i]['gold_context']['title'],
              'topk_titles': [KB[candidate_index]['title'] for candidate_index in indices[i]],
              'topk_indices': indices[i].tolist(),
              'topk_values': values[i].tolist(),
              }
    results.append(result)
  return results
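To make the search step concrete, here is a dependency-free sketch of brute-force MIPS for a single query. The function `topk_inner_product` and the toy vectors are hypothetical. The function above does the same thing for all questions at once with a matrix product followed by torch.topk.

```python
import heapq

def topk_inner_product(query, passages, k):
  # Brute-force MIPS: score every passage, then keep the k highest-scoring ones.
  scores = [sum(q * p for q, p in zip(query, passage)) for passage in passages]
  top = heapq.nlargest(k, range(len(passages)), key=lambda i: scores[i])
  return [(i, scores[i]) for i in top]

toy_passages = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]  # Made-up passage embeddings
toy_query = [1.0, 0.5]                                            # Made-up question embedding
print(topk_inner_product(toy_query, toy_passages, k=2))
```

Brute force is fine for our toy KB of 1500 passages. For a KB with millions of passages, one would typically switch to an approximate nearest-neighbor index instead.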

Given results, we can compute Recall@$k$ for various values of $k$.

In [22]:
def topk_retrieval_accuracy(results, k_values=[1, 5, 10, 20, 100]):
  k_num_correct = {}
  for k in k_values:
    k_num_correct[k] = 0
  for result in results:
    rank_min = float('inf')
    for rank, title in enumerate(result['topk_titles']):
      if title == result['gold_title']:
        rank_min = rank
        break
    for k in k_values:
      if rank_min < k:
        k_num_correct[k] += 1

  num_queries = len(results)
  k_accuracy = {k: num_correct / num_queries * 100. for (k, num_correct) in k_num_correct.items()}
  k_accuracy['num_queries'] = num_queries
  return k_accuracy
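As a sanity check on the metric, here is a tiny standalone sketch with made-up rankings. The helper `recall_at_k` is ours and is not used elsewhere in the notebook. Like topk_retrieval_accuracy, it counts a hit when the gold *title* appears among the first $k$ candidates.

```python
def recall_at_k(ranked_titles_per_query, gold_titles, k):
  # Percentage of queries whose gold title appears among the first k candidates.
  hits = sum(gold in ranked[:k] for ranked, gold in zip(ranked_titles_per_query, gold_titles))
  return 100. * hits / len(gold_titles)

ranked = [['A', 'B', 'C'],  # Gold 'A' is at rank 0
          ['B', 'C', 'A'],  # Gold 'A' is at rank 2
          ['B', 'C', 'D']]  # Gold 'A' is not retrieved at all
gold = ['A', 'A', 'A']
for k in [1, 2, 3]:
  print('R@{:d} = {:.2f}'.format(k, recall_at_k(ranked, gold, k)))
```

Recall@$k$ can only stay the same or grow as $k$ grows, which is a useful check on any results table.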

Let's try evaluating the recall of an untrained dual encoder. We use max length 100 (and thus may not see all of a passage) for computational reasons.

In [23]:
class TestEval(unittest.TestCase):

  def setUp(self):
    self.places = 4

  def test_eval(self):
    model = make_distilbert_dual_encoder(tie_encoders=False).to('cuda')
    results = get_topk_candidates(KB, data_val, model, batch_size=600, num_candidates=100, device='cuda', disable_tqdm=True)
    k_accuracy = topk_retrieval_accuracy(results, k_values=[1, 5, 10, 20, 100])
    self.assertAlmostEqual(k_accuracy[1], 6.20, places=self.places)
    self.assertAlmostEqual(k_accuracy[5], 11.30, places=self.places)
    self.assertAlmostEqual(k_accuracy[10], 14.50, places=self.places)
    self.assertAlmostEqual(k_accuracy[20], 19.20, places=self.places)
    self.assertAlmostEqual(k_accuracy[100], 37.20, places=self.places)
    clear_from_gpu(model)

unittest.main(TestEval(), argv=[''], verbosity=2, exit=False)

test_eval (__main__.TestEval) ... ok

----------------------------------------------------------------------
Ran 1 test in 8.697s

OK


<unittest.main.TestProgram at 0x7c2c90d0ef20>

Since the model is not finetuned, the recall is not good. Let's see if we can improve it by NCE training.

# Noise Contrastive Estimation (NCE)

The goal of a dual encoder is to retrieve the right passage (i.e., context) $p \in \mathcal{V}^+$ for a given question $q \in \mathcal{V}^+$. The dual encoder defines the score function

$$
\mathrm{score}_\theta(q,p) = \mathrm{QueryEncoder}_\theta(q) \cdot \mathrm{PassageEncoder}_\theta(p)
$$

(In our case, both encoders are initialized from the same pretrained DistilBERT.) Thus at test time, given any query $q$ the retriever outputs $p^\star = \arg\max_{p \in \textrm{KB}} \mathrm{score}_\theta(q,p)$. The usual way to train a model is by minimizing the empirical cross-entropy loss
$$
\hat{J}(\theta) = - \frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\mathrm{score}_\theta(q_i,p_i))}{\sum_{p \in \textrm{KB}} \exp(\mathrm{score}_\theta(q_i,p))}
$$
The problem is that KB is generally too large to compute the normalization term. In NCE, we instead minimize
$$
\hat{J}_{\text{nce}}(\theta) = - \frac{1}{N} \sum_{i=1}^N \log \frac{\exp(\mathrm{score}_\theta(q_i,p_i))}{\sum_{p \in \{p_i\} \cup \mathcal{N}_i} \exp(\mathrm{score}_\theta(q_i,p))}
$$
where $\mathcal{N}_i$ is a much smaller set of **negative** (i.e., incorrect) passages for question $i$. The choice of negative examples is very important for both theoretical and empirical reasons. See [this note](http://karlstratos.com/notes/nce.pdf) for details.
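Here is a small numeric sketch of the two losses for a single question. The scores and the helper `softmax_loss` are made up for illustration. Because the NCE denominator sums over a subset of the KB, the NCE loss can never exceed the full cross-entropy loss for the same scores, and it gets closer to it when the negatives are high-scoring.

```python
import math

def softmax_loss(scores, gold_index, candidate_indices):
  # -log softmax probability of the gold passage, normalized over candidate_indices only.
  log_z = math.log(sum(math.exp(scores[j]) for j in candidate_indices))
  return log_z - scores[gold_index]

toy_scores = [2.0, 1.0, 0.0, -1.0, 1.5]         # Hypothetical score(q, p) for a 5-passage KB; passage 0 is gold
full = softmax_loss(toy_scores, 0, range(5))    # Normalize over the whole KB
easy = softmax_loss(toy_scores, 0, [0, 2, 3])   # Gold + two low-scoring negatives
hard = softmax_loss(toy_scores, 0, [0, 1, 4])   # Gold + two high-scoring negatives
print('full {:.4f} | easy negatives {:.4f} | hard negatives {:.4f}'.format(full, easy, hard))
```

The low-scoring negatives give a loss far below the full loss, so they carry little training signal. This is one way to see why the choice of negatives matters.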

A naive way to implement NCE is to prepare a set of negative passages for every question in each epoch. A more computationally efficient way to simulate this is **in-batch** negative sampling, which uses the fact that other *gold* passages in the same batch can be treated as negative examples for the given question. It's also possible to sneak in so-called "hard negatives". That is, each question, in addition to a gold passage, can be associated with an incorrect but correct-looking passage. If we include such hard negatives in the in-batch negative sampling scheme, the model will be trained to distinguish the gold passage from the hard negative. See the following illustration from [Maillard et al., 2021](https://arxiv.org/pdf/2101.00117.pdf):

<p align="center">
<img src='https://drive.google.com/uc?id=1qK3t0NnSEqTjM17J6hYnmSNTtEwSQWb8' width="600">
</p>

Now write a function that computes the in-batch NCE loss.


In [24]:
def get_in_batch_contrastive_loss(model, batch, device='cuda'):
  Q, Q_mask, P, P_mask, labels = [tensor.to(device) for tensor in batch]
  # Get embeddings from the model
  Q_embeddings = model.query_encoder(Q, attention_mask=Q_mask)
  P_embeddings = model.passage_encoder(P, attention_mask=P_mask)

  # Calculate the score matrix by taking the dot product of Q and P embeddings
  scores = torch.matmul(Q_embeddings, P_embeddings.T)

  # Calculate the NCE loss
  loss = F.cross_entropy(scores, labels)

  # Calculate the number of correct predictions
  num_correct = (scores.argmax(dim=1) == labels).sum()
  return loss, num_correct

In [25]:
class TestGetInBatchContrastiveLoss(unittest.TestCase):

  def setUp(self):
    set_seed(42)
    self.samples = [
        {'question': 'a b c d', 'gold_context': {'title': 'e f g', 'text': 'the dog saw the cat'}, 'hard_negative': {'title': 'h i', 'text': 'the cat saw the dog'}},
        {'question': '1 2 3 4', 'gold_context': {'title': '5 6 7', 'text': 'foo bar'}, 'hard_negative': {'title': '8 9', 'text': 'z y'}},
        {'question': 'A B C D', 'gold_context': {'title': 'E F G', 'text': 'the DOG'}, 'hard_negative': {'title': 'H I', 'text': 'GOD'}}
    ]
    self.places = 4
    self.device = 'cpu'

  def test_get_in_batch_contrastive_loss(self):
    model = make_distilbert_dual_encoder()
    batch = tensorize_contrastive(self.samples, use_hard_negative=True)
    loss, num_correct = get_in_batch_contrastive_loss(model, batch, device=self.device)

    self.assertAlmostEqual(loss.item(), 1.1979, places=self.places)
    self.assertEqual(num_correct.item(), 1)

unittest.main(TestGetInBatchContrastiveLoss(), argv=[''], verbosity=2, exit=False)

test_get_in_batch_contrastive_loss (__main__.TestGetInBatchContrastiveLoss) ... ok

----------------------------------------------------------------------
Ran 1 test in 2.579s

OK


<unittest.main.TestProgram at 0x7c2c986cce20>

# Training

We'll use the standard setting for finetuning BERT.

In [26]:
from transformers import get_linear_schedule_with_warmup

def configure_optimization(model, num_train_steps, num_warmup_steps, lr, weight_decay=0.01):
  # Copied from: https://huggingface.co/transformers/training.html
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = [{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                                   'weight_decay': weight_decay},
                                  {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                                   'weight_decay': 0.}]
  optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr)
  scheduler = get_linear_schedule_with_warmup(optimizer, num_training_steps=num_train_steps, num_warmup_steps=num_warmup_steps)
  return optimizer, scheduler
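To see what the scheduler does to the learning rate, here is a minimal sketch of a linear warmup followed by a linear decay. It reflects our understanding of the multiplier that get_linear_schedule_with_warmup applies, so treat it as an illustration and not as the library's code. The helper name `linear_warmup_decay` is ours. The step counts are what you get for 500 training examples, batch size 128, and 5 epochs with 10% warmup (20 steps in total, 2 of them warmup).

```python
def linear_warmup_decay(step, num_warmup_steps, num_train_steps):
  # Learning-rate multiplier: ramps from 0 to 1 during warmup, then decays linearly to 0.
  if step < num_warmup_steps:
    return step / max(1, num_warmup_steps)
  return max(0., (num_train_steps - step) / max(1, num_train_steps - num_warmup_steps))

num_train_steps, num_warmup_steps, lr = 20, 2, 1e-4
for step in [0, 1, 2, 11, 20]:
  multiplier = linear_warmup_decay(step, num_warmup_steps, num_train_steps)
  print('step {:2d}: lr = {:.2e}'.format(step, lr * multiplier))
```

The peak learning rate is reached right after warmup, so with so few steps most of training happens on the decaying part of the schedule.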

In [27]:
def train_nce(model, batch_size, epochs, lr, weight_decay=0.01, use_hard_negative=False, disable_tqdm=True):
  set_seed(42)
  collate_fn = lambda samples: tensorize_contrastive(samples, use_hard_negative=use_hard_negative)
  loader_train = DataLoader(MyDataset(data_train), batch_size, shuffle=True, collate_fn=collate_fn)
  loader_val = DataLoader(MyDataset(data_val), batch_size, shuffle=False, collate_fn=collate_fn)
  num_train_steps = len(loader_train) * epochs
  optimizer, scheduler = configure_optimization(model, num_train_steps, int(0.1 * num_train_steps), lr, weight_decay=weight_decay)

  best_r5 = 0.  # Let's use Recall@5 for model selection
  best_state_dict = None
  for epoch in range(epochs):
    model.train()
    loss_total = 0.

    for batch in tqdm(loader_train, disable=disable_tqdm):
      loss, num_correct = get_in_batch_contrastive_loss(model, batch, device='cuda')
      loss_total += loss.item()
      loss.backward()

      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
      optimizer.step()
      scheduler.step()
      model.zero_grad()

    loss_avg = loss_total / len(loader_train)
    results = get_topk_candidates(KB, data_val, model, batch_size=600, num_candidates=100, device='cuda', disable_tqdm=True)
    k_accuracy = topk_retrieval_accuracy(results, k_values=[1, 5, 10, 20, 100])
    info = 'Epoch {:3d} | average loss {:8.4f} | R@1 {:5.2f} | R@5 {:5.2f} | R@10 {:5.2f} | R@20 {:5.2f} | R@100 {:5.2f}'.format(
        epoch, loss_avg, k_accuracy[1], k_accuracy[5], k_accuracy[10], k_accuracy[20], k_accuracy[100])
    if k_accuracy[5] > best_r5:
      best_r5 = k_accuracy[5]
      best_state_dict = copy.deepcopy(model.state_dict())
      info += ' <-------------- (saved)'
    print(info)

  if best_state_dict is not None:
    model.load_state_dict(best_state_dict)

In [28]:
model = make_distilbert_dual_encoder(tie_encoders=False).to('cuda')
train_nce(model, 128, 5, 1e-4, disable_tqdm=True)

Epoch   0 | average loss  10.1849 | R@1 26.10 | R@5 49.20 | R@10 56.70 | R@20 63.40 | R@100 81.50 <-------------- (saved)
Epoch   1 | average loss   2.2911 | R@1 40.20 | R@5 65.80 | R@10 71.80 | R@20 79.10 | R@100 92.40 <-------------- (saved)
Epoch   2 | average loss   1.1167 | R@1 53.40 | R@5 75.70 | R@10 82.40 | R@20 86.10 | R@100 95.10 <-------------- (saved)
Epoch   3 | average loss   0.6704 | R@1 60.40 | R@5 82.30 | R@10 86.30 | R@20 90.50 | R@100 96.90 <-------------- (saved)
Epoch   4 | average loss   0.4350 | R@1 62.60 | R@5 83.10 | R@10 87.60 | R@20 91.40 | R@100 97.70 <-------------- (saved)


So that's substantially better than an untrained dual encoder. Let's see if we can improve this even further by hard negative mining.

## Hard Negative Mining

The choice of hard negatives is flexible (e.g., you can use an off-the-shelf retriever like BM25 and take the top incorrect candidate). They can also be mined from the trained model itself, which is what we do here.

In [29]:
results = get_topk_candidates(KB, data_train, model, batch_size=600, num_candidates=2, device='cuda', disable_tqdm=True)

In [30]:
set_seed(42)
for result in random.sample(results, 7):
  print('Gold:         ', result['gold_title'])
  print('1st candidate:', result['topk_titles'][0])
  print('2nd candidate:', result['topk_titles'][1])
  print()

Gold:          Classical conditioning
1st candidate: Classical conditioning
2nd candidate: Wilhelm Wundt

Gold:          New York Mets
1st candidate: New York Mets
2nd candidate: 2016 World Series

Gold:          Water distribution on Earth
1st candidate: Water distribution on Earth
2nd candidate: Earliest known life forms

Gold:          Kirby's Epic Yarn
1st candidate: Kirby's Epic Yarn
2nd candidate: The Incredibles

Gold:          Little Shop of Horrors (film)
1st candidate: Little Shop of Horrors (film)
2nd candidate: Hotel Transylvania

Gold:          I'll Be There (The Jackson 5 song)
1st candidate: Ten Years After
2nd candidate: Fooled Around and Fell in Love

Gold:          Chorion
1st candidate: Meninges
2nd candidate: Cytokinesis



We see that in many cases the correct passage is ranked at the top. We will take the 2nd passage as a hard negative in that case. If the top passage is not correct, then that's our hard negative for that example.

In [31]:
for i, example in enumerate(data_train):
  topk_result = results[i]

  # Checking if the top passage is the correct one
  if topk_result['topk_titles'][0] == example['gold_context']['title']:
    # If yes, take the 2nd passage as a hard negative
    i_hard = topk_result['topk_indices'][1]
  else:
    # If not, take the top passage as a hard negative
    i_hard = topk_result['topk_indices'][0]

  # Populating hard_negative in the example allows tensorize_contrastive to use hard negatives.
  example['hard_negative'] = KB[i_hard]
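The selection rule can also be written as a small pure function, which makes it easy to test. The sketch below is a slightly more defensive variant and not exactly what the cell above does: it scans for the first candidate whose title differs from the gold title, which would matter if the KB happened to contain several passages with the gold title. The helper name `pick_hard_negative` and the KB indices are made up.

```python
def pick_hard_negative(topk_indices, topk_titles, gold_title):
  # Return the KB index of the highest-ranked candidate whose title differs from the gold title,
  # or None if every candidate shares the gold title.
  for index, title in zip(topk_indices, topk_titles):
    if title != gold_title:
      return index
  return None

print(pick_hard_negative([7, 3], ['Chorion', 'Meninges'], 'Chorion'))      # Gold on top: take the 2nd candidate
print(pick_hard_negative([5, 9], ['Meninges', 'Cytokinesis'], 'Chorion'))  # Gold missing: take the top candidate
print(pick_hard_negative([7, 8], ['Chorion', 'Chorion'], 'Chorion'))       # No usable negative among the candidates
```

With num_candidates=2 as above, the third case would leave an example without a hard negative. tensorize_contrastive already tolerates that, because it only adds a hard negative when the key is present.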

Now let's try training a dual encoder again, this time using $2B-1$ negative examples per question in NCE (the $B-1$ other gold passages plus $B$ hard negatives) rather than $B-1$. We could continue training the existing model, but we are limited in GPU memory here so we'll just start from the raw DistilBERT model again.

In [32]:
clear_from_gpu(model)
model = make_distilbert_dual_encoder(tie_encoders=False).to('cuda')
train_nce(model, 64, 5, 1e-4, use_hard_negative=True, disable_tqdm=True)  # Need a smaller batch size since each batch now encodes 2B passages (score matrix is B x 2B)

Epoch   0 | average loss   8.0642 | R@1 31.60 | R@5 54.20 | R@10 62.60 | R@20 69.50 | R@100 85.80 <-------------- (saved)
Epoch   1 | average loss   1.8812 | R@1 51.50 | R@5 74.20 | R@10 80.20 | R@20 85.50 | R@100 94.90 <-------------- (saved)
Epoch   2 | average loss   0.7474 | R@1 62.40 | R@5 84.20 | R@10 88.50 | R@20 91.60 | R@100 98.00 <-------------- (saved)
Epoch   3 | average loss   0.3230 | R@1 65.50 | R@5 85.90 | R@10 90.20 | R@20 93.10 | R@100 98.30 <-------------- (saved)
Epoch   4 | average loss   0.1838 | R@1 65.70 | R@5 86.50 | R@10 90.20 | R@20 93.70 | R@100 98.30 <-------------- (saved)


So that seems to converge faster and to a better model. You can iterate (i.e., keep doing hard negative mining). In practice, you can come up with various training schemes using hard negatives. For example, you can "refresh" hard negatives every epoch.

Hard negative mining was proposed for entity retrieval ([Gillick et al., 2019](https://aclanthology.org/K19-1049.pdf)). Also see [this paper](https://arxiv.org/pdf/2104.06245.pdf) for an analysis of hard negative mining.

## Results

TODO: Fill in the following table with the *best* validation recalls you could get without and with hard negative mining. You are encouraged to try different hyperparameter values like doing sufficiently many epochs to converge, using different learning rates and batch sizes.

| Model     | Recall@1 | Recall@5 | Recall@10 | Recall@20 | Recall@100 |
| :---:     | :---:    | :---:    | :---:     | :---:     | :---:      |
| Untrained | 6.20     | 11.30    | 14.50     | 19.20     | 37.20      |
| NCE       | 62.60    | 83.10    | 87.60     | 91.40     | 97.70      |
| NCE+hard  | 65.70    | 86.50    | 90.20     | 93.70     | 98.30      |

# Appendix

Learnable retrieval is generally useful for a variety of tasks. (E.g., finding $k$ documents that are most "similar" to a query document is of intrinsic interest.)

This assignment focused on a particular retrieval setting, namely retrieving evidence passages for QA. This is a so-called "retrieve-then-read" pipeline, where given a question $q$ the model retrieves top-$k$ passage candidates and then reads them to produce the final answer. The answer module can be any language model (encoder-decoder or decoder-only) that conditions on the question and passages and predicts an answer string. A popular model is [Fusion-in-Decoder](https://arxiv.org/pdf/2007.01282.pdf) (FiD). It's simply a standard encoder-decoder model (e.g., you can use T5) except that the encoder performs full self-attention over each $(q, p)$ pair separately, where $p$ is a single passage, while the decoder conditions on *all* of the encoded tokens. This can scale very well in the number of passages to condition on. If you naively concatenate $(q, p_1 \ldots p_k)$ and use it as the input to a transformer encoder, it will not scale well. With FiD, the model can condition on more than a hundred candidate passages.

<p align="center">
<img src='https://drive.google.com/uc?id=1MtZVfA2vzts29hK5DvjeuLNBGPlNUOLa' width="1000">
</p>
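A rough cost comparison shows why encoding the pairs separately scales better. The sketch assumes that encoder self-attention cost grows as the square of the sequence length, ignores all other costs, and approximates the concatenated input as $k$ times the length of one $(q, p)$ pair. The function name and the pair length of 100 tokens are hypothetical.

```python
def encoder_attention_cost(num_passages, tokens_per_pair, fusion_in_decoder):
  # Assumption: encoder self-attention cost grows as (sequence length) ** 2; other costs are ignored.
  if fusion_in_decoder:
    return num_passages * tokens_per_pair ** 2  # k independent sequences of length T
  return (num_passages * tokens_per_pair) ** 2  # One concatenated sequence of length about k * T

T = 100  # Hypothetical length of one (q, p) pair in tokens
for k in [1, 10, 100]:
  concat = encoder_attention_cost(k, T, fusion_in_decoder=False)
  fid = encoder_attention_cost(k, T, fusion_in_decoder=True)
  print('k = {:3d}: concatenation / FiD = {:.0f}x'.format(k, concat / fid))
```

Under these assumptions the encoder cost is linear in $k$ for FiD and quadratic in $k$ for naive concatenation, so the gap itself grows linearly with the number of passages.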

Also, in this assignment we assume that annotations for retrieval are available. I.e., somebody annotated the gold passage for each question. This is expensive and it's becoming more and more clear that it is not necessary. There are methods to learn to retrieve by doing a downstream task like language modeling and QA, without explicit retrieval annotations. They are based on marginalization (e.g., [REALM](https://arxiv.org/pdf/2002.08909.pdf), [EMDR2](https://arxiv.org/pdf/2106.05346.pdf)) or, more compellingly, distillation (e.g., [FiD-KD](https://arxiv.org/pdf/2012.04584.pdf)).

Lastly, instead of learning a retriever, we can use existing search engines like Google as an API for the model to use. This is exactly what you would want if your retrieval needs are consistent with what a search engine provides (e.g., [Sparrow](https://arxiv.org/pdf/2209.14375.pdf)). However, it's not what you want if you need your retriever to be specialized in some unique task not suitable for generic search engines (e.g., [entity linking](https://arxiv.org/pdf/2110.02369.pdf)).