In [None]:
!pip install sentence-transformers

In [2]:

"""
This examples show how to train a Cross-Encoder for the SQuAD Dataset.

The query and the passage are passed simoultanously to a Transformer network. The network then returns
a score between 0 and 1 how relevant the passage is for a given query.

The resulting Cross-Encoder can then be used for passage re-ranking: You retrieve for example 100 passages
for a given query, for example with ElasticSearch, and pass the query+retrieved_passage to the CrossEncoder
for scoring. You sort the results then according to the output of the CrossEncoder.

This gives a significant boost compared to out-of-the-box ElasticSearch / BM25 ranking.

Running this script:
python train_cross-encoder.py
"""
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers import InputExample
from google.colab import data_table
from pathlib import Path
import torch
import logging
from datetime import datetime
import gzip
import pandas as pd
import os
import tarfile
from tqdm import tqdm
import json
import shutil
import csv

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

root = '/gdrive/MyDrive/Project 2/retrieve-rerank'
result_dir = Path(root) / 'results'

Mounted at /gdrive


### loading corpus and questions

In [4]:
### Now we read the SQuAD dataset
data_folder = 'squad-data'
os.makedirs(data_folder, exist_ok=True)

def load_dataset(path, url):
  contexts = dict()
  questions = dict()
  golden_context_ids = dict()

  if not os.path.exists(path):
    logging.info("Download SQuAD dev-v1.1 ...")
    util.http_get(url, path)


  with open(path, 'r', encoding='utf8') as json_file:
    json_object = json.load(json_file)
    context_id = 0
    question_id = 0

    for i in range (len(json_object['data'])):
      data = json_object['data'][i]

      for j in range(len(data['paragraphs'])):
        paragraph = data['paragraphs'][j]
        contexts[str(context_id)] = paragraph['context']

        for k in range(len(paragraph['qas'])):
          questions[str(question_id)] = paragraph['qas'][k]['question']
          golden_context_ids[str(question_id)] = context_id

          question_id += 1

        context_id += 1

  return contexts, questions, golden_context_ids

train_contexts, train_questions, train_golden_context_ids = load_dataset(os.path.join(data_folder, 'train-set'), 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json')
dev_contexts, dev_questions, dev_golden_context_ids = load_dataset(os.path.join(data_folder, 'dev-set'), 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json')

  0%|          | 0.00/8.12M [00:00<?, ?B/s]

  0%|          | 0.00/1.05M [00:00<?, ?B/s]

## Testing

In [5]:
cross_encoder_bm25_injection = torch.load(os.path.join(result_dir, 'cross-encoder-bm25-injection-v_1'))
cross_encoder_noinjection = torch.load(os.path.join(result_dir, 'cross-encoder-noinjection-v_2'))
# cross_encoder_raw = torch.load(os.path.join(result_dir, 'cross-encoder-bm25-injection-v_1'))
cross_encoder_raw = CrossEncoder('distilroberta-base')
# cross_encoder_raw = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens
num_doc_retrieve = 32

dev_context_embeds = bi_encoder.encode(list(dev_contexts.values()), convert_to_tensor=True, show_progress_bar=True)

.gitattributes:   0%|          | 0.00/737 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/11.5k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

data_config.json:   0%|          | 0.00/25.5k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

train_script.py:   0%|          | 0.00/13.8k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

Batches:   0%|          | 0/65 [00:00<?, ?it/s]

In [7]:
def search(query, k, cross_encoder):

  ### Retrieve ###
  # Encode the query using the bi-encoder and find potentially relevant passages
  question_embed = bi_encoder.encode(query, convert_to_tensor=True).cuda()
  hits = util.semantic_search(question_embed, dev_context_embeds, top_k=num_doc_retrieve)
  hits = hits[0] # Get the hits for the first query

  ### Re-Ranking ###
  # Score all retrieved passages with the cross_encoder
  cross_input = [[query, dev_contexts[str(hit['corpus_id'])]] for hit in hits]
  cross_scores = cross_encoder.predict(cross_input)

  # Sort results by the cross-encoder scores
  for i in range(len(cross_scores)):
    hits[i]['cross-score'] = cross_scores[i]

  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
  top_k_context_id = [hit['corpus_id'] for hit in hits[0:k]]

  return top_k_context_id

Just find top k accuracies for both k=5 and k=20.

In [None]:
def find_topk_acc(top_k, cross_encoder):
  correct = 0
  count = 0

  for id in tqdm(dev_questions.keys(), mininterval = 3, desc ="Evaluating..."):
    question = dev_questions[id]
    golden_context_id = dev_golden_context_ids[id]
    top_k_context_ids = search(question, top_k, cross_encoder)

    if golden_context_id in top_k_context_ids:
      correct += 1

    # # To limit the number of iterations
    # count += 1
    # if count == 100:
    #   break

  topk_acc = (correct/len(dev_questions)) * 100

  print(f"Successful retrievals: {correct}/{len(dev_questions.keys())}")

  return topk_acc

# print(f'Top-k retrieval accuracy with k=5: {find_topk_acc(5, cross_encoder_raw)}')
print(f'Top-k retrieval accuracy with k=20: {find_topk_acc(20, cross_encoder_raw)}')

Evaluating...: 100%|██████████| 10570/10570 [1:05:55<00:00,  2.67it/s]

Successful retrievals: 6314/10570
Top-k retrieval accuracy with k=20: 59.735099337748345





Also find the queries and the corresponding outputs that the model gets wrong.

In [15]:
def find_topk_acc(top_k, cross_encoder):
  correct = 0
  count = 0
  wrong_results = list()

  for id in tqdm(dev_questions.keys(), mininterval = 3, desc ="Evaluating..."):
    question = dev_questions[id]
    golden_context_id = dev_golden_context_ids[id]
    top_k_context_ids = search(question, top_k, cross_encoder)

    if golden_context_id in top_k_context_ids:
      correct += 1

    # To see where the model gets wrong
    else:
      wrong_results.append({'id':id,
                     'question':question,
                     'golden context':dev_contexts[str(golden_context_id)],
                     'retrived contexts':[dev_contexts[str(id)] for id in top_k_context_ids]})

    # To limit the number of iterations
    count += 1
    if count == 1000:
      break

  topk_acc = (correct/len(dev_questions)) * 100

  print(f"Successful retrievals: {correct}/{len(dev_questions.keys())}")

  return topk_acc, wrong_results

# bm25-injection
topk_acc_bm25_inject, wrong_results_bm25_inject = find_topk_acc(5, cross_encoder_bm25_injection)
print(f'Top-k retrieval accuracy with k=5: {topk_acc_bm25_inject}')

# no-injection
topk_acc_no_inject, wrong_results_no_inject = find_topk_acc(5, cross_encoder_noinjection)
print(f'Top-k retrieval accuracy with k=5: {topk_acc_no_inject}')

# no finetuning
# topk_acc_raw, wrong_results_raw = find_topk_acc(5, cross_encoder_raw)
# print(f'Top-k retrieval accuracy with k=5: {topk_acc_raw}')


bm25_inject_pd = pd.DataFrame(wrong_results_bm25_inject)
no_inject_pd = pd.DataFrame(wrong_results_no_inject)
# raw_pd = pd.DataFrame(wrong_results_raw)

Evaluating...:   9%|▉         | 999/10570 [02:57<28:17,  5.64it/s]


Successful retrievals: 969/10570
Top-k retrieval accuracy with k=5: 9.167455061494795


Evaluating...:   9%|▉         | 999/10570 [02:55<28:05,  5.68it/s]

Successful retrievals: 964/10570
Top-k retrieval accuracy with k=5: 9.120151371807001





In [19]:
def compare_wrong_results(dfA, dfB, diff=False):
  """
  If diff is False, find the ids that dfA gets right but dfB gets wrong.
  If diff is True, find the ids that both get wrong.
  """
  ids = list()
  if diff is False:
    for id in dfB.index:
      if dfB['id'][id] not in dfA['id'].values:
        ids.append(dfB['id'][id])
  else:
    for id in dfB.index:
      if dfB['id'][id] in dfA['id'].values:
        ids.append(dfB['id'][id])

  return dfB[dfB['id'].isin(ids)]

print(no_inject_pd['id'].values)
print(bm25_inject_pd['id'].values)

df = compare_wrong_results(bm25_inject_pd, no_inject_pd, False)
# df = compare_wrong_results(no_inject_pd, bm25_inject_pd, False)
# df = compare_wrong_results(bm25_inject_pd, no_inject_pd, True)
data_table.DataTable(df)

['57' '133' '163' '257' '312' '371' '385' '402' '407' '431' '471' '474'
 '477' '483' '498' '504' '523' '536' '543' '619' '625' '632' '670' '672'
 '680' '701' '707' '761' '785' '792' '794' '837' '848' '851' '949' '965']
['40' '50' '78' '133' '257' '312' '371' '385' '402' '407' '431' '471'
 '474' '498' '504' '523' '536' '543' '619' '625' '632' '672' '701' '707'
 '761' '792' '837' '848' '851' '949' '965']


Unnamed: 0,id,question,golden context,retrived contexts
0,57,Which team held the scoring lead throughout th...,The Broncos took an early lead in Super Bowl 5...,[Denver took the opening kickoff and started o...
2,163,"Prior to Super Bowl 50, what was the last Supe...","On May 21, 2013, NFL owners at their spring me...",[Super Bowl 50 was an American football game t...
12,477,How much money was spent on other festivities ...,"In addition, there are $2 million worth of oth...",[The annual NFL Experience was held at the Mos...
13,483,What was the cost of the other Super Bowl even...,"In addition, there are $2 million worth of oth...",[The annual NFL Experience was held at the Mos...
22,670,Who did the National Anthem at Super Bowl 50?,Six-time Grammy winner and Academy Award nomin...,[Super Bowl 50 was an American football game t...
24,680,Who lead the halftime show of Super Bowl 50?,"In late November 2015, reports surfaced statin...","[CBS broadcast Super Bowl 50 in the U.S., and ..."
28,785,How many picks did Cam Newton throw?,Manning finished the game 13 of 23 for 141 yar...,"[The Panthers offense, which led the NFL in sc..."
30,794,How many intercpetions did Newton have in Supe...,Manning finished the game 13 of 23 for 141 yar...,[Super Bowl 50 featured numerous records from ...


Run the following code to see how well our retrieve-rerank method does given a query and k value.

In [None]:
question_id = '766'
top_k = 5

question = dev_questions[question_id]
# golden_context_id = dev_contexts[555]
top_k_context_id = search(question, top_k)


print(f"Query: {question}")
# print(f"Golden context id: {golden_context_id}")
print(f"Top-k context id: {top_k_context_id}")
print(f"Top-k context:")

for id, context_id in enumerate(top_k_context_id):
  print(f'{id}. {dev_contexts[str(context_id)]}')

Query: On what yard line did Carolina begin with 4:51 left in the game?
Top-k context id: [51, 49, 47, 48, 46]
Top-k context:
0. With 4:51 left in regulation, Carolina got the ball on their own 24-yard line with a chance to mount a game-winning drive, and soon faced 3rd-and-9. On the next play, Miller stripped the ball away from Newton, and after several players dove for it, it took a long bounce backwards and was recovered by Ward, who returned it five yards to the Panthers 4-yard line. Although several players dove into the pile to attempt to recover it, Newton did not and his lack of aggression later earned him heavy criticism. Meanwhile, Denver's offense was kept out of the end zone for three plays, but a holding penalty on cornerback Josh Norman gave the Broncos a new set of downs. Then Anderson scored on a 2-yard touchdown run and Manning completed a pass to Bennie Fowler for a 2-point conversion, giving Denver a 24–10 lead with 3:08 left and essentially putting the game away. Ca