<a href="https://colab.research.google.com/github/juliatessler/1s2023-unicamp-dl-for-search-systems/blob/main/2-cross-encoder/2_Cross_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Week 2 Assignment: Cross-Encoder
by Júlia Ferreira Tessler

In [3]:
!pip install transformers -q
!pip install pyserini -q
!pip install faiss-cpu -q

This notebook is based on [this one](https://colab.research.google.com/drive/10etP7Lb915EC-uEuf1IKC8DYkyg_om6-?usp=sharing#scrollTo=wHeZ9nAOEB0U) provided in the assignment description. When a cell contains the comment "Uses the suggested notebook", its code came from there.

In [4]:
import random
import torch
import torch.nn.functional as F
import numpy as np

from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import IndexReader
from pyserini.search import get_topics, get_qrels
from torch.utils import data

import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, BatchEncoding

from torch import nn
from torch import optim
from tqdm.auto import tqdm
from statistics import mean

In [5]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f3ecc1650b0>

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Preparing the data

In [7]:
# Download training data
!wget https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv

--2023-03-16 00:41:22--  https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.200.128, 74.125.68.128, 74.125.24.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.200.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8076179 (7.7M) [text/tab-separated-values]
Saving to: ‘msmarco_triples.train.tiny.tsv.1’


2023-03-16 00:41:23 (149 MB/s) - ‘msmarco_triples.train.tiny.tsv.1’ saved [8076179/8076179]



According to ChatGPT:

The MSMARCO Triples Tiny dataset is a smaller subset of the full MSMARCO training triples, made of (query, positive document, negative document) triples. Each triple consists of the following fields:

* Query: The search query issued by the user.
* Positive Document: A document that is relevant to the query, used as a positive example.
* Negative Document: A document that is not marked as relevant to the query, used as a negative example.

All three fields are strings. The dataset is meant for training and evaluating models for query-document ranking and relevance prediction. Because it is much smaller than the full training set, it is useful for quick experimentation and prototyping.

Note that the file used here contains 11,000 triples (see the count below).

In [8]:
# Code from ChatGPT
df = pd.read_csv('msmarco_triples.train.tiny.tsv', 
                 delimiter = '\t', 
                 header = None, 
                 names = ['query', 'pos_doc', 'neg_doc'])

df.head()

| | query | pos_doc | neg_doc |
|---|---|---|---|
| 0 | is a little caffeine ok during pregnancy | We don’t know a lot about the effects of caf... | It is generally safe for pregnant women to eat... |
| 1 | what fruit is native to australia | Passiflora herbertiana. A rare passion fruit n... | The kola nut is the fruit of the kola tree, a ... |
| 2 | how large is the canadian military | The Canadian Armed Forces. 1 The first large-... | The Canadian Physician Health Institute (CPHI)... |
| 3 | types of fruit trees | Cherry. Cherry trees are found throughout the ... | The kola nut is the fruit of the kola tree, a ... |
| 4 | how many calories a day are lost breastfeeding | Not only is breastfeeding better for the baby,... | However, you still need some niacin each day; ... |


In [9]:
distinct_queries = len(pd.unique(df['query']))
total_queries = len(df['query'])

print(f'The train dataset has {distinct_queries} distinct queries of {total_queries} in total. That is: there are {total_queries - distinct_queries} repeated queries.')

The train dataset has 10779 distinct queries of 11000 in total. That is: there are 221 repeated queries.


In order to use the dataset to train a binary classifier, we should change the format of the dataset to something that has:
- query
- document
- relevant or not

Thanks, Mirelle, for the idea to use `pandas.melt` to solve this problem!
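The reshaping that `pd.melt` performs in the next cell can be written out in plain Python, which makes the resulting row order explicit: all positive pairs come first, then all negative pairs. That is why the two rows of the first query end up at indices 0 and 11000 further down. This is only an illustrative sketch; `triples_to_pointwise` is a hypothetical helper, not part of the notebook.

```python
def triples_to_pointwise(triples):
    """Turn (query, pos_doc, neg_doc) triples into (query, relevance, document) rows.

    Mirrors the row order produced by pd.melt with value_vars=[1, 0]:
    every positive row first, then every negative row.
    """
    positives = [(query, True, pos_doc) for query, pos_doc, _ in triples]
    negatives = [(query, False, neg_doc) for query, _, neg_doc in triples]
    return positives + negatives


# Made-up passages, just to show the shape of the output.
triples = [
    ("is a little caffeine ok during pregnancy", "pos passage A", "neg passage A"),
    ("what fruit is native to australia", "pos passage B", "neg passage B"),
]
rows = triples_to_pointwise(triples)
for row in rows:
    print(row)
```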



In [10]:
# pandas.melt unpivots the dataframe and puts the old column names in the new
# `relevance` column, so renaming pos_doc -> 1 and neg_doc -> 0 yields the labels directly
df.rename(columns = {'pos_doc': 1, 'neg_doc': 0}, inplace = True)

train_df = pd.melt(
    df,
    id_vars = ['query'],
    value_vars = [1, 0],
    var_name = 'relevance',
    value_name = 'document'
)

train_df['relevance'] = train_df['relevance'].astype('bool')

In [11]:
train_df[train_df['query']=='is a little caffeine ok during pregnancy']

| | query | relevance | document |
|---|---|---|---|
| 0 | is a little caffeine ok during pregnancy | True | We don’t know a lot about the effects of caf... |
| 11000 | is a little caffeine ok during pregnancy | False | It is generally safe for pregnant women to eat... |


In [12]:
df.head()

| | query | 1 | 0 |
|---|---|---|---|
| 0 | is a little caffeine ok during pregnancy | We don’t know a lot about the effects of caf... | It is generally safe for pregnant women to eat... |
| 1 | what fruit is native to australia | Passiflora herbertiana. A rare passion fruit n... | The kola nut is the fruit of the kola tree, a ... |
| 2 | how large is the canadian military | The Canadian Armed Forces. 1 The first large-... | The Canadian Physician Health Institute (CPHI)... |
| 3 | types of fruit trees | Cherry. Cherry trees are found throughout the ... | The kola nut is the fruit of the kola tree, a ... |
| 4 | how many calories a day are lost breastfeeding | Not only is breastfeeding better for the baby,... | However, you still need some niacin each day; ... |


In [13]:
train_df.head()

| | query | relevance | document |
|---|---|---|---|
| 0 | is a little caffeine ok during pregnancy | True | We don’t know a lot about the effects of caf... |
| 1 | what fruit is native to australia | True | Passiflora herbertiana. A rare passion fruit n... |
| 2 | how large is the canadian military | True | The Canadian Armed Forces. 1 The first large-... |
| 3 | types of fruit trees | True | Cherry. Cherry trees are found throughout the ... |
| 4 | how many calories a day are lost breastfeeding | True | Not only is breastfeeding better for the baby,... |


In [14]:
train_df.shape

(22000, 3)

In [15]:
# Saving the new df for posterior use
train_df.to_csv('format_msmarco_triples.train.tiny.tsv',
                sep = '\t',
                index = False)

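I don't expect "leakage" between the training and test sets to be a real problem here, even though the split below is made per row, so the positive and negative rows of the same query can end up in different splits. I hold out 5% of the rows for testing and 15% for validation, and use the remaining 80% for training.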

In [16]:
df_train = train_df.sample(frac = 0.8, random_state = 42)                       # 80%
df_valid = train_df.drop(df_train.index).sample(frac = 0.75, random_state = 42) # 15%
df_test = train_df.drop(df_train.index).drop(df_valid.index)                    # 5%

In [17]:
df_train.shape, df_valid.shape, df_test.shape

((17600, 3), (3300, 3), (1100, 3))
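Because `DataFrame.sample` splits individual rows, a query can appear in more than one split, both through its positive/negative pair and through the 221 repeated queries counted earlier. If you would rather rule this out, one option is to split by query instead of by row. The sketch below is a hypothetical helper, `split_by_query`, working on plain `(query, relevance, document)` tuples; with this notebook's data they could be built with `list(train_df.itertuples(index=False, name=None))`, since `train_df` has the columns `query`, `relevance`, `document` in that order.

```python
import random


def split_by_query(rows, fractions=(0.8, 0.15, 0.05), seed=42):
    """Assign whole queries (not individual rows) to train/valid/test.

    rows: list of (query, relevance, document) tuples.
    Returns three lists of rows; no query appears in more than one list.
    The test split receives whatever queries remain after train and valid.
    """
    queries = sorted({query for query, _, _ in rows})
    random.Random(seed).shuffle(queries)
    n_train = int(len(queries) * fractions[0])
    n_valid = int(len(queries) * fractions[1])
    train_q = set(queries[:n_train])
    valid_q = set(queries[n_train:n_train + n_valid])

    splits = ([], [], [])
    for row in rows:
        query = row[0]
        if query in train_q:
            splits[0].append(row)
        elif query in valid_q:
            splits[1].append(row)
        else:
            splits[2].append(row)
    return splits


# Toy data: 20 queries, each with one positive and one negative row.
rows = [(f"q{i}", label, f"doc{i}-{label}") for i in range(20) for label in (True, False)]
train, valid, test = split_by_query(rows)
print(len(train), len(valid), len(test))
```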

In [18]:
df_train.head()

| | query | relevance | document |
|---|---|---|---|
| 13035 | where is fitbit connect | False | Charging your tracker ........................... |
| 3115 | defense messaging system | True | Defense Message System. The Defense Message Sy... |
| 8732 | What are the four main ocean basins? which is ... | True | Although both the Northern and Southern Hemisp... |
| 7591 | how many hours a day to work to get lunch in ,a | True | In the United States, most states require a mi... |
| 221 | can pizza boxes be recycled | True | Pizza box suitable for recycling. Pizza box no... |


In [19]:
df_train.relevance.value_counts()

True     8847
False    8753
Name: relevance, dtype: int64

In [20]:
df_valid.relevance.value_counts()

False    1687
True     1613
Name: relevance, dtype: int64

In [21]:
df_test.relevance.value_counts()

False    560
True     540
Name: relevance, dtype: int64

In [22]:
# y_train = df_train.pop('relevance')
# X_train = df_train

# y_test = df_test.pop('relevance')
# X_test = df_test

In [23]:
# X_train.head()

### Creating Dataset and Dataloader artifacts


In [24]:
max_length = 356

In [25]:
class MSMARCODataset(data.Dataset):
    def __init__(self, tokenizer, query, documents, targets, max_length = 356):
        self.tokenizer = tokenizer
        self.query = query
        self.documents = documents
        self.targets = targets
        self.max_length = max_length
    
    def __len__(self):
        return len(self.query)

    def __getitem__(self, idx):
        # Query and document are joined into a single string: "<query>, <document>"
        query_doc_text = self.query[idx] + ', ' + self.documents[idx]
        query_doc_token = self.tokenizer(query_doc_text,
                                         max_length = self.max_length,
                                         truncation = True,
                                         padding = "max_length",
                                         return_tensors = 'pt')

        # Tensors stay on the CPU here; the training and evaluation loops move
        # each batch to `device`, which also keeps multi-worker DataLoaders safe.
        return {'input_ids': torch.squeeze(query_doc_token['input_ids']).long(),
                'attention_mask': torch.squeeze(query_doc_token['attention_mask']).long(),
                'labels': int(self.targets[idx])}

In [26]:
model_name = 'microsoft/MiniLM-L12-H384-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [27]:
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
print(device)

cuda


In [28]:
# This cell uses the suggested notebook
# This function adds "pad" tokens to examples in the batch that are shorter than the longest one.
# Since MSMARCODataset already pads every example to max_length, this is effectively a no-op here.
def collate_fn(batch):
    return BatchEncoding(tokenizer.pad(batch, return_tensors = 'pt'))

# Convert examples to PyTorch's Dataset.
train_dataset = MSMARCODataset(
    tokenizer, 
    df_train['query'].to_list(), 
    df_train['document'].to_list(), 
    df_train['relevance'].to_list()
)

valid_dataset = MSMARCODataset(
    tokenizer, 
    df_valid['query'].to_list(), 
    df_valid['document'].to_list(), 
    df_valid['relevance'].to_list()
)


test_dataset = MSMARCODataset(
    tokenizer, 
    df_test['query'].to_list(), 
    df_test['document'].to_list(), 
    df_test['relevance'].to_list()
)

# Convert examples to PyTorch's DataLoader (only the training set needs shuffling).
dataloader_train = data.DataLoader(train_dataset, batch_size = 32, shuffle = True, collate_fn = collate_fn)
dataloader_valid = data.DataLoader(valid_dataset, batch_size = 32, shuffle = False, collate_fn = collate_fn)
dataloader_test = data.DataLoader(test_dataset, batch_size = 32, shuffle = False, collate_fn = collate_fn)

In [29]:
# Uses the suggested notebook
for batch in dataloader_train:
    assert batch['input_ids'].shape[0] <= dataloader_train.batch_size
    assert batch['input_ids'].shape[1] <= max_length
    break

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
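The `collate_fn` above delegates to `tokenizer.pad`, which pads each example to the longest one in the batch (dynamic padding). In this notebook `MSMARCODataset` already pads every example to `max_length`, so all sequences in a batch have the same length and the collate step leaves them unchanged. A plain-Python sketch of what dynamic padding does; `pad_batch` is hypothetical and `pad_id=0` is an assumption (in practice use `tokenizer.pad_token_id`).

```python
def pad_batch(batch, pad_id=0):
    """Pad input_ids/attention_mask of each example to the longest one in the batch.

    batch: list of dicts with 'input_ids' and 'attention_mask' lists of equal length.
    pad_id: assumed padding token id (use tokenizer.pad_token_id in practice).
    """
    longest = max(len(example["input_ids"]) for example in batch)
    padded = {"input_ids": [], "attention_mask": []}
    for example in batch:
        missing = longest - len(example["input_ids"])
        padded["input_ids"].append(example["input_ids"] + [pad_id] * missing)
        # Padding positions get mask 0 so that attention ignores them.
        padded["attention_mask"].append(example["attention_mask"] + [0] * missing)
    return padded


batch = [
    {"input_ids": [101, 2054, 102], "attention_mask": [1, 1, 1]},
    {"input_ids": [101, 102], "attention_mask": [1, 1]},
]
padded = pad_batch(batch)
print(padded)
```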


### Training loop

In [30]:
# Uses the suggested notebook
# We first define the evaluation function to measure accuracy and loss

def evaluate(model, dataloader, set_name):
    losses = []
    correct = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, mininterval = 0.5, desc = set_name, disable = False):
            outputs = model(**batch.to(device))
            loss_val = outputs.loss
            losses.append(loss_val.cpu().item())
            preds = outputs.logits.argmax(dim = 1)
            correct += (preds == batch['labels']).sum().item()

    print(f'{set_name} loss: {mean(losses):0.3f}; {set_name} accuracy: {correct / len(dataloader.dataset):0.3f}')

In [31]:
def train_loop(model_name, epochs = 5, lr = 5e-5):
    """Fine-tunes `model_name` as a binary relevance classifier and returns the model."""
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    print('Parameters', model.num_parameters())

    optimizer = optim.AdamW(model.parameters(), lr = lr)
    num_training_steps = epochs * len(dataloader_train)
    # Warm up (10% of the steps) is important to stabilize training.
    num_warmup_steps = int(num_training_steps * 0.1)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    # First validation to check that the evaluation code works and that accuracy is close to random, as expected
    evaluate(model = model, dataloader = dataloader_valid, set_name = 'Valid')

    # Training loop
    for epoch in tqdm(range(epochs), desc = 'Epochs'):
        model.train()
        train_losses = []
        for batch in tqdm(dataloader_train, mininterval = 0.5, desc = 'Train', disable = False):
            optimizer.zero_grad()
            outputs = model(**batch.to(device))
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_losses.append(loss.cpu().item())
        # Got those from Monique's notebook, thanks!
        print(f'Epoch: {epoch + 1} Training loss: {mean(train_losses):0.2f}')
        evaluate(model = model, dataloader = dataloader_valid, set_name = 'Valid')

    return model
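`get_linear_schedule_with_warmup` scales the learning rate by a multiplier that grows linearly from 0 to 1 during the warmup steps and then decays linearly back to 0 at the last training step. The sketch below writes that multiplier out as I understand the `transformers` implementation (`linear_warmup_multiplier` is a hypothetical name), using this notebook's numbers: 550 batches per epoch times 5 epochs gives 2,750 steps, of which 275 are warmup.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier for linear warmup followed by linear decay to zero.

    The effective learning rate at a given optimizer step is lr * multiplier.
    """
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


# Values from this notebook: 5 epochs x 550 batches, 10% warmup.
num_training_steps = 5 * 550
num_warmup_steps = int(num_training_steps * 0.1)
for step in (0, 137, num_warmup_steps, 1500, num_training_steps):
    print(step, round(linear_warmup_multiplier(step, num_warmup_steps, num_training_steps), 3))
```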

In [32]:
import os

model_path = '/content/drive/MyDrive/Unicamp/'
os.makedirs(model_path, exist_ok = True)  # create the output directory if it does not exist yet

In [33]:
model = train_loop(model_name, 5)
# Use a separate variable so that `model_name` keeps its original value
model_dir_name = model_name.replace('/', '_')
model.save_pretrained(f'{model_path}/models_ranker_{model_dir_name}')
tokenizer.save_pretrained(f'{model_path}/tokenizer_ranker')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameters 33360770

Valid loss: 0.693; Valid accuracy: 0.511

Epoch: 1 Training loss: 0.44
Valid loss: 0.277; Valid accuracy: 0.891

Epoch: 2 Training loss: 0.21
Valid loss: 0.256; Valid accuracy: 0.906

Epoch: 3 Training loss: 0.12
Valid loss: 0.250; Valid accuracy: 0.913

Epoch: 4 Training loss: 0.06
Valid loss: 0.254; Valid accuracy: 0.920

Epoch: 5 Training loss: 0.03
Valid loss: 0.296; Valid accuracy: 0.920


('/content/drive/MyDrive/Unicamp//tokenizer_ranker/tokenizer_config.json',
 '/content/drive/MyDrive/Unicamp//tokenizer_ranker/special_tokens_map.json',
 '/content/drive/MyDrive/Unicamp//tokenizer_ranker/vocab.txt',
 '/content/drive/MyDrive/Unicamp//tokenizer_ranker/added_tokens.json',
 '/content/drive/MyDrive/Unicamp//tokenizer_ranker/tokenizer.json')

## BM25 base search
For the index, we'll use one of [Pyserini's Prebuilt indexes](https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md). Following the steps from the documentation:

In [34]:
searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage')
searcher.set_bm25(k1 = 0.9, b = 0.4)
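`set_bm25(k1 = 0.9, b = 0.4)` sets the two BM25 parameters: `k1` controls how quickly repeated occurrences of a term stop adding to the score, and `b` controls how strongly long documents are penalized. Below is a toy sketch of the textbook BM25 formula, for intuition only. `bm25_scores` is hypothetical, documents are assumed to be pre-tokenized lists of words, and Lucene (used by Pyserini) differs in details such as quantized document lengths and, in recent versions, dropping the constant `(k1 + 1)` factor, so absolute scores will not match Pyserini's.

```python
import math
from collections import Counter


def bm25_scores(query_terms, docs, k1=0.9, b=0.4):
    """Textbook BM25 score of each tokenized doc for a tokenized query.

    Uses the Lucene-style IDF log(1 + (N - df + 0.5) / (df + 0.5)).
    """
    n_docs = len(docs)
    avgdl = sum(len(doc) for doc in docs) / n_docs
    # Document frequency: in how many docs each term appears at least once.
    df = Counter(term for doc in docs for term in set(doc))
    scores = []
    for doc in docs:
        tf = Counter(doc)
        # Length normalization: docs longer than average get a larger denominator.
        norm = k1 * (1 - b + b * len(doc) / avgdl)
        score = 0.0
        for term in query_terms:
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Term frequency saturates: more occurrences add less and less.
            score += idf * tf[term] * (k1 + 1) / (tf[term] + norm)
        scores.append(score)
    return scores


docs = [
    ["crimp", "oil", "heals", "hands"],
    ["crimp", "definition"],
    ["bake", "at", "225", "degrees"],
]
scores = bm25_scores(["crimp", "oil"], docs)
for i, score in enumerate(scores):
    print(i, round(score, 3))
```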

In [35]:
index_reader = IndexReader.from_prebuilt_index('msmarco-v1-passage')
index_reader.stats()

{'total_terms': 352316036,
 'documents': 8841823,
 'non_empty_documents': 8841823,
 'unique_terms': 2660824}

In [36]:
# Code adapted from Pyserini Demo: https://colab.research.google.com/github/castorini/anserini-notebooks/blob/master/pyserini_msmarco_passage_demo.ipynb
topics = get_topics('dl20')
print(f'{len(topics)} queries total')

200 queries total


In [37]:
qrels = get_qrels('dl20-passage')

In [38]:
# Code adapted from Pyserini Demo: https://colab.research.google.com/github/castorini/anserini-notebooks/blob/master/pyserini_msmarco_passage_demo.ipynb

# Run all queries in topics and retrieve the top 1,000 passages for each one,
# writing them in TREC run format: qid Q0 docid rank score tag
def run_all_queries(file, topics, searcher):
    with open(file, 'w') as runfile:
        cnt = 0
        print('Running {} queries in total'.format(len(topics)))
        for qid in tqdm(topics, desc='Running Queries'):
            query = topics[qid]['title']
            hits = searcher.search(query, 1000)
            for i in range(0, len(hits)):
                _ = runfile.write('{} Q0 {} {} {:.6f} Pyserini\n'.format(qid, hits[i].docid, i+1, hits[i].score))
            cnt += 1
            if cnt % 100 == 0:
                print(f'{cnt} queries completed')

run_all_queries('run-msmarco-passage-bm25.txt', topics, searcher)

Running 200 queries in total



100 queries completed
200 queries completed


In [39]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -l 2 dl20-passage run-msmarco-passage-bm25.txt

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 'run-msmarco-passage-bm25.txt']
Results:
ndcg_cut_10           	all	0.4796


BM25 NDCG@10 = 47.96%
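nDCG@10 compares the graded relevance of the top 10 retrieved passages against the best possible ordering of the judged passages. A minimal per-query sketch, assuming linear gains (a passage's gain is its relevance label) and a `log2(rank + 1)` discount, which is my understanding of what `trec_eval` computes; `ndcg_at_k` is a hypothetical helper. `trec_eval` then averages the per-query values, and `-c` makes it average over every query in the qrels, counting queries missing from the run as 0.

```python
import math


def ndcg_at_k(ranked_doc_ids, qrels, k=10):
    """nDCG@k for one query with linear gains (gain = graded relevance label).

    ranked_doc_ids: doc ids in ranked order.
    qrels: dict doc_id -> graded relevance (unjudged docs count as 0).
    """
    def dcg(gains):
        # The passage at 1-based rank r is discounted by log2(r + 1).
        return sum(gain / math.log2(rank + 1) for rank, gain in enumerate(gains, start=1))

    gains = [qrels.get(doc_id, 0) for doc_id in ranked_doc_ids[:k]]
    ideal_dcg = dcg(sorted(qrels.values(), reverse=True)[:k])
    return dcg(gains) / ideal_dcg if ideal_dcg > 0 else 0.0


# Toy judgments: d1 is highly relevant, d2 relevant, d3 judged not relevant.
qrels = {"d1": 3, "d2": 2, "d3": 0}
print(ndcg_at_k(["d1", "d2", "d3"], qrels))  # ideal order
print(ndcg_at_k(["d2", "d1", "d9"], qrels))  # top two swapped, d9 unjudged
```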

In [40]:
# From the Pyserini Demo
import json

hits = searcher.search(topics[735922]['title'])
hits[1].raw

'{\n  "id" : "2766952",\n  "contents" : "Description: There has always been a debate as to what is the best form of crimping - hexagonal or indent. There is no straight answer as the crimping method depends on the type of cable, the type of lugs used ... View More. There has always been a debate as to what is the best form of crimping - hexagonal or indent. There is no straight answer as the crimping method depends on the type of cable, the type of lugs used and lot of other external attributes. For further questions: bpaul01@klauke.textron.com."\n}'

In [41]:
jsondoc = json.loads(hits[1].raw)
jsondoc

{'id': '2766952',
 'contents': 'Description: There has always been a debate as to what is the best form of crimping - hexagonal or indent. There is no straight answer as the crimping method depends on the type of cable, the type of lugs used ... View More. There has always been a debate as to what is the best form of crimping - hexagonal or indent. There is no straight answer as the crimping method depends on the type of cable, the type of lugs used and lot of other external attributes. For further questions: bpaul01@klauke.textron.com.'}

In [42]:
# Prints the first 10 hits
for i in range(0, 10):
    jsondoc = json.loads(hits[i].raw)
    print(f'{i+1:2} {hits[i].score:.5f} {jsondoc["contents"][:80]}...')

 1 9.60990 Definition of crimped in the Definitions.net dictionary. Meaning of crimped. Wha...
 2 9.39730 Description: There has always been a debate as to what is the best form of crimp...
 3 9.39330 What does crimped mean? Definitions for crimped Here are all the possible meanin...
 4 9.20850 Effect of Crimp Depth on Shotshells. Crimp depth of a finished shotshell reload ...
 5 8.95510 Crimp Oil is a 100% natural blend of essential oils and plant extracts that aids...
 6 8.66260 Directions. 1  Heat your oven to 225 degrees Fahrenheit. 2  On a large sheet of ...
 7 8.56560 1 Heat your oven to 225 degrees Fahrenheit. 2  On a large sheet of aluminum foil...
 8 8.53480 Metolius Crimp Oil. 1  A healing massage oil for climbers hands and muscles. 2  ...
 9 8.49770 crimped. simple past tense and past participle of crimp The sharp bend had crimp...
10 8.47610 Metolius Crimp Oil aids in quick and fast healing of the finger, hand, and wrist...


## Reranking
Copied from Arthur de Andrade Almeida!

In [43]:
# Get corpus dataset
!mkdir -p collections/msmarco-passage
!wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz -P collections/msmarco-passage
!tar xvfz collections/msmarco-passage/collectionandqueries.tar.gz -C collections/msmarco-passage

--2023-03-16 01:23:28--  https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
Resolving msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)... 20.150.34.4
Connecting to msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)|20.150.34.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1057717952 (1009M) [application/gzip]
Saving to: ‘collections/msmarco-passage/collectionandqueries.tar.gz.2’


2023-03-16 01:26:23 (5.80 MB/s) - ‘collections/msmarco-passage/collectionandqueries.tar.gz.2’ saved [1057717952/1057717952]

collection.tsv
qrels.dev.small.tsv
qrels.train.tsv
queries.dev.small.tsv
queries.dev.tsv
queries.eval.small.tsv
queries.eval.tsv
queries.train.tsv


In [44]:
!wget https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz
!gunzip -v /content/msmarco-test2020-queries.tsv.gz

--2023-03-16 01:26:58--  https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-test2020-queries.tsv.gz
Resolving msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)... 20.150.34.4
Connecting to msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)|20.150.34.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4131 (4.0K) [application/x-gzip]
Saving to: ‘msmarco-test2020-queries.tsv.gz’


2023-03-16 01:26:59 (389 MB/s) - ‘msmarco-test2020-queries.tsv.gz’ saved [4131/4131]

gzip: /content/msmarco-test2020-queries.tsv already exists; do you wish to overwrite (y or n)? n
	not overwritten


In [45]:
ranked_df = pd.read_csv('run-msmarco-passage-bm25.txt', sep=' ', header=None)
ranked_df.columns = ['query_id', 'Q0', 'doc_id', 'ranking', 'score', 'run']
ranked_df = ranked_df[["query_id", "doc_id"]]
ranked_df.head()

| | query_id | doc_id |
|---|---|---|
| 0 | 735922 | 7307871 |
| 1 | 735922 | 2766952 |
| 2 | 735922 | 7307863 |
| 3 | 735922 | 8734268 |
| 4 | 735922 | 8626892 |


In [46]:
queries_df = pd.read_csv("/content/msmarco-test2020-queries.tsv", header=None, sep='\t')
queries_df.columns = ['query_id', 'query']
queries_df.head()

| | query_id | query |
|---|---|---|
| 0 | 1030303 | who is aziz hashim |
| 1 | 1037496 | who is rep scalise? |
| 2 | 1043135 | who killed nicholas ii of russia |
| 3 | 1045109 | who owns barnhart crane |
| 4 | 1049519 | who said no one can make you feel inferior |


In [47]:
docs_df = pd.read_csv('/content/collections/msmarco-passage/collection.tsv', header=None, sep='\t')
docs_df.columns = ["doc_id", "doc"]
docs_df.head()

| | doc_id | doc |
|---|---|---|
| 0 | 0 | The presence of communication amid scientific ... |
| 1 | 1 | The Manhattan Project and its atomic bomb help... |
| 2 | 2 | Essay on The Manhattan Project - The Manhattan... |
| 3 | 3 | The Manhattan Project was the name for a proje... |
| 4 | 4 | versions of each volume as well as complementa... |


In [48]:
merged_df = pd.merge(queries_df, ranked_df, on='query_id')
merged_df = pd.merge(merged_df, docs_df, on='doc_id')
merged_df.head()

| | query_id | query | doc_id | doc |
|---|---|---|---|---|
| 0 | 1030303 | who is aziz hashim | 8726436 | Share on LinkedInShare on FacebookShare on Twi... |
| 1 | 1030303 | who is aziz hashim | 8726435 | Mr. Aziz Hashim has been the President and Sec... |
| 2 | 1030303 | who is aziz hashim | 8726429 | The crew at NRD Holdings, left to right: Karim... |
| 3 | 1030303 | who is aziz hashim | 8726437 | Aziz Hashim is one of the world’s leading ex... |
| 4 | 1030303 | who is aziz hashim | 7156982 | Rounding out the IFA leadership team is Aziz H... |


In [49]:
def get_test_dataloader(instances, targets, tokenizer, batch_size=32, num_workers=0):
    """Builds a DataLoader from a list of (query, passage) tuples, keeping their order."""
    queries = [query for query, _ in instances]
    passages = [passage for _, passage in instances]
    dataset = MSMARCODataset(tokenizer, queries, passages, targets, max_length)
    dataloader = data.DataLoader(dataset, shuffle=False,
                                 batch_size=batch_size, num_workers=num_workers)
    return dataloader

In [50]:
def get_probs(model, loader, set_name):
    """Returns the predicted probability of class 1 (relevant) for each example, in loader order."""
    scores = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc=set_name, disable=False):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            # Two-class head trained with cross-entropy: softmax over both logits
            probs = torch.softmax(outputs.logits, dim=1)[:, 1]
            scores.extend(probs.tolist())
    return scores
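The classifier head outputs two logits trained with cross-entropy, so the probability of the relevant class is the softmax over both logits, which equals the sigmoid of the logit difference. Applying a sigmoid to the class-1 logit alone ignores the class-0 logit and can order passages differently. A small numeric check in plain Python; the logit values are made up and the helper names are hypothetical.

```python
import math


def softmax_relevant(logit_0, logit_1):
    """P(class 1) from a two-logit softmax."""
    m = max(logit_0, logit_1)  # subtract the max for numerical stability
    e0, e1 = math.exp(logit_0 - m), math.exp(logit_1 - m)
    return e1 / (e0 + e1)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# Two hypothetical passages as (logit_0, logit_1) pairs.
passage_a = (-2.0, 1.0)  # margin of 3.0 in favour of "relevant"
passage_b = (0.5, 1.5)   # margin of 1.0
for name, (l0, l1) in {"a": passage_a, "b": passage_b}.items():
    print(name, "softmax:", round(softmax_relevant(l0, l1), 3), "sigmoid(logit_1):", round(sigmoid(l1), 3))
```

Softmax ranks passage `a` above `b` (larger margin), while the sigmoid of the class-1 logit alone ranks `b` above `a`.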

In [None]:
test_data = list(zip(merged_df["query"].values.tolist(), merged_df["doc"].values.tolist()))
test_loader = get_test_dataloader(test_data, [True] * len(test_data), tokenizer)
scores = get_probs(model, test_loader, set_name="Test")

In [None]:
df_test.head()

In [None]:
# Attach the reranker probabilities to each (query, passage) pair
merged_df["scores"] = scores
merged_df.head()

In [None]:
# Sort rows by query and, within each query, by descending reranker score
merged_df = merged_df.sort_values(by=["query_id", "scores"], ascending=[True, False])
merged_df.head()
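The evaluation cell below reads a run file (`run.dl20.microsoft.MiniLM-L12-H384-uncased.txt`), but no cell in this notebook writes it. A minimal sketch of writing reranked scores in the same TREC run format used for the BM25 run (`qid Q0 docid rank score tag`); `write_trec_run` and the tag `MiniLM-reranker` are hypothetical. With the DataFrame above it could be called as `write_trec_run(merged_df[["query_id", "doc_id", "scores"]].itertuples(index=False), f)` inside `with open(path, "w") as f:`.

```python
import io
from collections import defaultdict


def write_trec_run(rows, out, run_tag="MiniLM-reranker"):
    """Write (query_id, doc_id, score) rows as TREC run lines.

    Rows are grouped per query and ranked by descending score, so the
    input does not need to be sorted. `out` is any writable text file object.
    """
    per_query = defaultdict(list)
    for query_id, doc_id, score in rows:
        per_query[query_id].append((doc_id, score))
    for query_id in sorted(per_query):
        ranked = sorted(per_query[query_id], key=lambda pair: pair[1], reverse=True)
        for rank, (doc_id, score) in enumerate(ranked, start=1):
            out.write(f"{query_id} Q0 {doc_id} {rank} {score:.6f} {run_tag}\n")


# Made-up ids and scores, written to an in-memory buffer.
buffer = io.StringIO()
write_trec_run([(1, "d1", 0.12), (1, "d2", 0.91), (2, "d7", 0.5)], buffer)
print(buffer.getvalue())
```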

In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -m map dl20-passage /content/run.dl20.microsoft.MiniLM-L12-H384-uncased.txt