<a href="https://colab.research.google.com/github/eduseiti/ia368v_dd_class_03/blob/main/DL_reranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Prepare the environment

In [None]:
import os
# Point JAVA_HOME at the Colab JDK 11 install; presumably required by Pyserini's Java
# backend (it pulls in pyjnius, see the pip log below) — set before Pyserini is used.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

In [2]:
%%shell
# Install Pyserini (used for the BM25 run) and the CPU build of FAISS from PyPI.
# NOTE(review): versions are unpinned, so results may drift between sessions.
pip install pyserini
pip install faiss-cpu
# NOTE(review): Maven is presumably only needed for the repo's Java tooling — confirm it is still required.
apt-get install maven -qq
# Clone Pyserini (with submodules) to build its bundled evaluation tools.
git clone --recurse-submodules https://github.com/castorini/pyserini.git
cd pyserini
# Build trec_eval (used at the end of the notebook via TREC_EVAL_FULLPATH), then return to the repo root.
cd tools/eval && tar xvfz trec_eval.9.0.4.tar.gz && cd trec_eval.9.0.4 && make && cd ../../..
# Build ndeval, then return to the repo root.
cd tools/eval/ndeval && make && cd ../../..

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyserini
  Downloading pyserini-0.20.0-py3-none-any.whl (137.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.1/137.1 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyjnius>=1.4.0
  Downloading pyjnius-1.4.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m72.9 MB/s[0m eta [36m0:00:00[0m
Collecting onnxruntime>=1.8.1
  Downloading onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_x86_64.whl (5.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m97.7 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece>=0.1.95
  Downloading sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m69.4 MB/s[0



In [3]:
# Hugging Face Transformers for the MiniLM cross-encoder. NOTE(review): version is unpinned.
!pip install transformers -q

In [4]:
import csv
import pickle
import random
from statistics import mean, stdev

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from scipy import stats

from google.colab import drive

In [5]:
# Seed every random number generator used below (Python, NumPy, PyTorch) with one value,
# so the shuffle, the data split and the classifier-head initialization are reproducible.
SEED = 0xDEADBEEF
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fbcf77713d0>

In [6]:
# Raw MS MARCO training triples: TSV of (topic, positive passage, negative passage).
TRAINING_DATA="https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv"
# Google Drive folder, relative to /content (where Drive is mounted), used as the working directory.
WORKING_FOLDER="drive/MyDrive/unicamp/ia368v_dd/aula_03"
# Local CSV cache of the triples after the CHAR_FIXES encoding repairs.
FIXED_TRAINING_DATA="msmarco_triples.train.tiny_fixed.tsv"

In [7]:
# Repairs for mis-decoded characters (these look like UTF-8 bytes read as Latin-1/cp1252,
# e.g. "â\x80\x99" for a right single quote).
# The replacements are applied one after another in dict order. A pattern must therefore
# come before any shorter pattern that is a substring of it.
CHAR_FIXES={
    "â\x80\x99": "'",
    "â\x80\x98": "'",
    "â\x80²": "\'",
    " â\x80¦ ": "",
    "â\x80¦": "",
    " â\x80º": "",
    " â\x80¢ ": "",
    "â\x80º": "",
    "â\x80¢ ": "",
    "â\x84¢": "",
    "â\x80\x91": "-",
    "â\x80\x94": "-",
    "â\x80\x93": "-",
    "â\x80\x9c": "\"",
    "â\x80\x9d": "\"",
    "â\x80³": "\"",
    "Âº": "°",
    # Catch-all for any remaining stray "â". It must stay last: earlier in the dict it would
    # consume the leading "â" of the multi-character sequences above before they can match.
    "â": "-",
}

## Set the Google Drive connection

In [8]:
# Mount Google Drive under /content/drive; cached data and checkpoints are stored there.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [9]:
# WORKING_FOLDER is relative to /content, where Drive was mounted. Use the absolute path so
# re-running this cell still works after the working directory has already been changed.
os.chdir(os.path.join("/content", WORKING_FOLDER))

## Initialize some model structures before doing anything

In [10]:
from torch import nn
from torch import optim
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils import data
from transformers import BatchEncoding
from torch.utils import data

In [11]:
# Base checkpoint for the cross-encoder reranker and its matching tokenizer.
model_name = 'microsoft/MiniLM-L12-H384-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [12]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# Show which device was selected.
device

device(type='cuda')

In [14]:
def collate_fn(batch):
    """Pad a list of tokenized examples to the longest one and stack them as PyTorch tensors.

    Each element of ``batch`` is a dict as produced by ``Dataset.__getitem__``. The result
    is a ``BatchEncoding``, so a whole batch can be moved with ``.to(device)``.
    """
    padded_batch = tokenizer.pad(batch, return_tensors='pt')
    return BatchEncoding(padded_batch)


class Dataset(data.Dataset):
    """Map-style dataset over pre-tokenized examples and their binary labels.

    ``examples`` maps field names to per-example token lists. Only 'input_ids' and
    'attention_mask' are passed on to the model; any 'token_type_ids' entry is ignored.
    """

    def __init__(self, examples, targets):
        self.examples = examples
        self.targets = targets

    def __len__(self):
        # The dataset size is the number of tokenized input sequences.
        return len(self.examples['input_ids'])

    def __getitem__(self, idx):
        item = {field: self.examples[field][idx] for field in ('input_ids', 'attention_mask')}
        # Convert the (possibly numpy bool) target into a plain int class index for the loss.
        item['labels'] = int(self.targets[idx])
        return item

## Download the finetuning dataset and clean the encoding errors

### First, check whether a cleaned version of the dataset already exists

Otherwise, download the raw triples and fix their encoding errors.

In [15]:
if os.path.exists(FIXED_TRAINING_DATA):
    
    print("The data has already been cleaned...")

    df = pd.read_csv(FIXED_TRAINING_DATA)
else:
    !wget https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv

    os.path.basename(TRAINING_DATA)
    
    df = pd.read_csv(os.path.basename(TRAINING_DATA), sep='\t', header=None, names=['topic', 'positive', 'negative'])

    #
    # Fix some bad encodings...
    #

    for to_be_replaced, replacement in CHAR_FIXES.items():
        df['positive'] = df['positive'].str.replace(to_be_replaced, replacement)
        df['negative'] = df['negative'].str.replace(to_be_replaced, replacement)

    df.to_csv(FIXED_TRAINING_DATA, index=False)

The data has already been cleaned...


## Prepare the dataset

In [16]:
# (rows, columns) of the training triples.
df.shape

(11000, 3)

In [17]:
# Peek at a few triples.
df.head()

Unnamed: 0,topic,positive,negative
0,is a little caffeine ok during pregnancy,We don't know a lot about the effects of caffe...,It is generally safe for pregnant women to eat...
1,what fruit is native to australia,Passiflora herbertiana. A rare passion fruit n...,"The kola nut is the fruit of the kola tree, a ..."
2,how large is the canadian military,The Canadian Armed Forces. 1 The first large-...,The Canadian Physician Health Institute (CPHI)...
3,types of fruit trees,Cherry. Cherry trees are found throughout the ...,"The kola nut is the fruit of the kola tree, a ..."
4,how many calories a day are lost breastfeeding,"Not only is breastfeeding better for the baby,...","However, you still need some niacin each day; ..."


In [18]:
# Character-length distribution of the topics (queries).
df['topic'].str.len().describe()

count    11000.000000
mean        34.225636
std         13.130216
min          6.000000
25%         26.000000
50%         32.000000
75%         40.000000
max        215.000000
Name: topic, dtype: float64

In [19]:
# (topic, passage) pairs for the positive and the negative side of each triple, in df row order.
positive_examples = df[['topic', 'positive']].to_numpy()
negative_examples = df[['topic', 'negative']].to_numpy()

#### Concatenate the positive and negative examples, and create the corresponding label

In [20]:
# All positive pairs first, then all negative pairs (each in df row order).
# Reuses the arrays built in the previous cell instead of projecting df a second time.
all_examples = np.concatenate([positive_examples, negative_examples])

In [21]:
# Labels aligned with `all_examples`: True for the first df.shape[0] rows (positives),
# False for the next df.shape[0] rows (negatives).
example_class = np.repeat(np.array([True, False]), df.shape[0])

In [22]:
# One label per example: twice the number of triples.
example_class.shape

(22000,)

#### Tokenize the topics and the related examples (positive and negative)

In [23]:
# NOTE(review): duplicates the import and tokenizer loading from the model-setup section;
# the tokenizer is re-created here from the same checkpoint.
from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_name = 'microsoft/MiniLM-L12-H384-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [24]:
# Tokenize topics separately from passages (they are concatenated later); keep lengths for the stats below.
topics_tokens = tokenizer(list(all_examples[:,0]), return_length=True)

In [25]:
%%time

# Tokenize the passage side of every example (positives, then negatives).
examples_tokens = tokenizer(list(all_examples[:,1]), return_length=True)

CPU times: user 9.22 s, sys: 266 ms, total: 9.48 s
Wall time: 1.34 s


In [26]:
# Token-length statistics of the tokenized topics.
stats.describe(topics_tokens['length'])

DescribeResult(nobs=22000, minmax=(4, 43), mean=9.103909090909092, variance=8.122390282202911, skewness=2.2074659701047348, kurtosis=14.43184511406837)

In [27]:
# Token-length statistics of the tokenized passages.
stats.describe(examples_tokens['length'])

DescribeResult(nobs=22000, minmax=(13, 280), mean=79.1235, variance=1016.1129801581889, skewness=1.1803703067189946, kurtosis=1.5748246785912885)

In [28]:
# Random permutation of example indexes (NumPy was seeded above), used for the train/validation split.
# NOTE(review): each topic appears twice (positive and negative row) and the two rows are shuffled
# independently, so the same topic can land in both the training and the validation split.
shuffled_examples_indexes = list(range(example_class.shape[0]))

np.random.shuffle(shuffled_examples_indexes)

In [29]:
# First few shuffled indexes (sanity check).
shuffled_examples_indexes[:10]

[11069, 14593, 9415, 7511, 3623, 11244, 12033, 2374, 14850, 9175]

Split the data into train and validation sets: the last `VALIDATION_SIZE` shuffled examples are held out for validation.

In [30]:
# Number of shuffled examples held out for validation.
VALIDATION_SIZE=1000

In [31]:
# Fields produced by the tokenizer.
topics_tokens.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'length'])

#### Merge topic + example in a single sequence to feed the model

When merging, drop the leading [CLS] token from each tokenized passage, so every merged sequence reads `[CLS] topic [SEP] passage [SEP]`.

In [32]:
train_input_ids = []
train_token_type_ids = []
train_attention_mask = []

# The first (N - VALIDATION_SIZE) shuffled positions form the training split.
train_positions = shuffled_examples_indexes[:all_examples.shape[0] - VALIDATION_SIZE]

for example_idx in train_positions:
    # Skip the passage's leading [CLS] so each pair reads [CLS] topic [SEP] passage [SEP].
    for field, merged in (('input_ids', train_input_ids),
                          ('token_type_ids', train_token_type_ids),
                          ('attention_mask', train_attention_mask)):
        merged.append(topics_tokens[field][example_idx] + examples_tokens[field][example_idx][1:])

In [33]:
# Training inputs as a dict of per-example token lists (the layout Dataset expects).
x_train = {'input_ids': train_input_ids, 
           'token_type_ids': train_token_type_ids, 
           'attention_mask': train_attention_mask}

# Labels for the same shuffled positions, in the same order.
y_train = example_class[shuffled_examples_indexes[:(all_examples.shape[0] - VALIDATION_SIZE)]]

In [34]:
valid_input_ids = []
valid_token_type_ids = []
valid_attention_mask = []

# The last VALIDATION_SIZE shuffled positions form the validation split.
valid_positions = shuffled_examples_indexes[all_examples.shape[0] - VALIDATION_SIZE:]

for example_idx in valid_positions:
    # Skip the passage's leading [CLS] so each pair reads [CLS] topic [SEP] passage [SEP].
    for field, merged in (('input_ids', valid_input_ids),
                          ('token_type_ids', valid_token_type_ids),
                          ('attention_mask', valid_attention_mask)):
        merged.append(topics_tokens[field][example_idx] + examples_tokens[field][example_idx][1:])

In [35]:
# Validation inputs in the same layout as x_train.
x_valid = {'input_ids': valid_input_ids, 
           'token_type_ids': valid_token_type_ids, 
           'attention_mask': valid_attention_mask}

# Labels for the held-out shuffled positions, in the same order.
y_valid = example_class[shuffled_examples_indexes[(all_examples.shape[0] - VALIDATION_SIZE):]]

In [36]:
# Wrap the merged token lists and labels so the DataLoaders can index them.
dataset_train = Dataset(x_train, y_train)
dataset_valid = Dataset(x_valid, y_valid)

In [37]:
batch_size=32

# Convert examples to Pytorch's DataLoader.
# Training batches are reshuffled every epoch; validation order is fixed.
dataloader_train = data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
dataloader_valid = data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

#### Define the evaluation function for the fine-tuning

In [38]:
def evaluate(model, dataloader, set_name):
    """Print the mean loss and accuracy of a sequence classifier over a labelled dataloader.

    Args:
        model: sequence-classification model. Batches must carry 'labels' so that
            ``outputs.loss`` is populated.
        dataloader: yields BatchEncoding batches (see ``collate_fn``).
        set_name: label used for the progress bar and the printed summary.

    Returns:
        ``(mean_loss, accuracy)`` as floats. The loss is averaged per example, not per batch.

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    total_loss = 0.0
    correct = 0
    seen = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
            outputs = model(**batch.to(device))
            n_in_batch = batch['labels'].shape[0]
            # outputs.loss is the mean over this batch. Weighting it by the batch size keeps a
            # smaller final batch from skewing the dataset-level average.
            total_loss += outputs.loss.cpu().item() * n_in_batch
            seen += n_in_batch
            preds = outputs.logits.argmax(dim=1)
            correct += (preds == batch['labels']).sum().item()

    if seen == 0:
        raise ValueError(f'{set_name}: dataloader produced no examples')

    mean_loss = total_loss / seen
    accuracy = correct / seen
    print(f'{set_name} loss: {mean_loss:0.3f}; {set_name} accuracy: {accuracy:0.3f}')
    return mean_loss, accuracy

In [41]:
# Load the base MiniLM with a newly initialized classification head (see the warning below)
# and move it to the selected device.
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
print('Parameters', model.num_parameters())

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameters 33360770


#### Fine-tune the model

In [42]:
# Fine-tuning setup: AdamW at lr 5e-5 with a linear warmup/decay schedule.
epochs = 2
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
# The scheduler is stepped once per batch, so its horizon is the total number of batches.
num_training_steps = epochs * len(dataloader_train)
# Warm up is important to stabilize training.
num_warmup_steps = int(num_training_steps * 0.1)  # 10% of the steps
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# First validation to check the evaluation code works and accuracy starts at chance level.
evaluate(model=model, dataloader=dataloader_valid, set_name='Valid')

# Training loop
for epoch in tqdm(range(epochs), desc='Epochs'):
    # evaluate() switches the model to eval mode, so training mode is re-enabled every epoch.
    model.train()
    train_losses = []
    for batch in tqdm(dataloader_train, mininterval=0.5, desc='Train', disable=False):
        optimizer.zero_grad()
        # Batches carry 'labels', so the model returns the classification loss directly.
        outputs = model(**batch.to(device))
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_losses.append(loss.cpu().item())

    # Mean of the per-batch losses for this epoch.
    print(f'Epoch: {epoch + 1} Training loss: {mean(train_losses):0.2f}')
    evaluate(model=model, dataloader=dataloader_valid, set_name='Valid')

Valid:   0%|          | 0/32 [00:00<?, ?it/s]

Valid loss: 0.693; Valid accuracy: 0.500


Epochs:   0%|          | 0/2 [00:00<?, ?it/s]

Train:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch: 1 Training loss: 0.35


Valid:   0%|          | 0/32 [00:00<?, ?it/s]

Valid loss: 0.214; Valid accuracy: 0.923


Train:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch: 2 Training loss: 0.17


Valid:   0%|          | 0/32 [00:00<?, ?it/s]

Valid loss: 0.185; Valid accuracy: 0.931


In [43]:
from datetime import datetime

In [44]:
# Timestamp used to name the saved checkpoint directory.
training_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

In [45]:
# Confirm the working directory; the checkpoint is saved relative to it.
os.getcwd()

'/content/drive/MyDrive/unicamp/ia368v_dd/aula_03'

Save the fine-tuned model for later use.

In [46]:
# Save the fine-tuned weights under pretrain_<timestamp> in the working directory.
# NOTE(review): only the model is saved; the tokenizer is loaded from `model_name` later.
model.save_pretrained("pretrain_{}".format(training_timestamp))

# Now, test the reranking DL model over a Pyserini BM25 run

In [89]:
# Local MS MARCO passage collection folder and the TREC-DL 2020 evaluation inputs.
MSMARCO_DATASET_FOLDER = "msmarco-passage"
TREC_DL_2020_TOPICS_FILENAME = "msmarco-test2020-queries.tsv"
TREC_DL_2020_QRELS_FILENAME = "2020qrels-pass.txt"

# Pickle cache of the tokenized (topic, passage) pairs of the BM25 run.
TOKENIZED_TREC_DL_2020_DATA = "tokenized_trec-dl_2020_data.pkl"

# The BM25 run and its reranked versions share one file-name stem; only the suffix differs.
_BM25_RUN_STEM = "run.trec-dl_2020-passage.bm25tuned_msmarco-test2020-queries.tsv_20230308_185518"

PYSERINI_TEST_RUN_FILENAME = _BM25_RUN_STEM + ".txt"

PYSERINI_TEST_RUN_RERANKED_FILENAME = _BM25_RUN_STEM + "_reranked.txt"

# The `{}` placeholder is filled with the reranking timestamp when the run is written.
PYSERINI_TEST_RUN_RERANKED_FILENAME_FORMAT = _BM25_RUN_STEM + "_reranked_{}.txt"

## Load the tokenized test data if it is already available...

The **TREC-DL 2020 Passage ranking topics** and documents (**MS MARCO Passage Retrieval dataset**) may already have been fixed (encoding errors) and tokenized in a previous session.

In [48]:
# Reuse the cached tokenization of the BM25 run, if it exists.
# NOTE(review): pickle.load can execute arbitrary code; only load caches this notebook produced.
# NOTE(review): the download/cleaning cells below are not guarded by this check and still run.
if os.path.exists(TOKENIZED_TREC_DL_2020_DATA):
    with open(TOKENIZED_TREC_DL_2020_DATA, "rb") as inputFile:
        tokenized_data = pickle.load(inputFile)

    trec_topics_tokens = tokenized_data['trec_topics_tokens']
    trec_docs_tokens = tokenized_data['trec_docs_tokens']
    bm25_run_with_data_df = tokenized_data['bm25_run_with_data_df']

## ... Otherwise, prepare the dataset to be tokenized.

This part follows the Pyserini's [tutorial](https://github.com/castorini/pyserini/blob/master/docs/experiments-msmarco-passage.md) on the MS MARCO Passage Retrieval task.

### Load the fixed documents dataset, if available...

In [69]:
# Reuse the encoding-fixed passage collection pickled by a previous run, if present.
# NOTE(review): the download/fix cells below still run unconditionally.
if os.path.exists(os.path.join(MSMARCO_DATASET_FOLDER, "fixed_collections.pkl")):
    with open(os.path.join(MSMARCO_DATASET_FOLDER, "fixed_collections.pkl"), 'rb') as inputFile:
        msmarco_passage_df = pickle.load(inputFile)

### ....Otherwise, load the original data and fix it

In [None]:
# exist_ok=True keeps this cell re-runnable once the folder already exists.
os.makedirs(MSMARCO_DATASET_FOLDER, exist_ok=True)

In [None]:
# Download the MS MARCO passage collection + queries archive into the dataset folder.
!wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz -P msmarco-passage

In [None]:
# Extract the archive (it provides collection.tsv) inside the dataset folder.
!tar xvfz msmarco-passage/collectionandqueries.tar.gz -C msmarco-passage

In [None]:
# collection.tsv is plain tab-separated text without quoting. QUOTE_NONE keeps a passage that
# starts with '"' from being parsed as a quoted field, which could swallow delimiters or newlines.
msmarco_passage_df = pd.read_csv(os.path.join(MSMARCO_DATASET_FOLDER, "collection.tsv"), sep='\t', header=None, names=['id', 'text'], quoting=csv.QUOTE_NONE)

In [None]:
# Number of passages and columns.
msmarco_passage_df.shape

In [None]:
# Peek at the first passages.
msmarco_passage_df.head()

In [None]:
# Sample passage before the encoding fixes.
msmarco_passage_df.iloc[3]['text']

In [None]:
# Apply the encoding repairs in CHAR_FIXES order. regex=False treats each pattern as a literal
# string, independent of the pandas version's default for `regex`.
for to_be_replaced, replacement in CHAR_FIXES.items():
    msmarco_passage_df['text'] = msmarco_passage_df['text'].str.replace(to_be_replaced, replacement, regex=False)

In [None]:
# Same passage after the encoding fixes.
msmarco_passage_df.iloc[3]['text']

Save the fixed passage collection for later use.

In [None]:
# Cache the fixed collection so later sessions can skip the download and cleanup.
with open(os.path.join(MSMARCO_DATASET_FOLDER, "fixed_collections.pkl"), "wb") as outputFile:
    pickle.dump(msmarco_passage_df, outputFile, pickle.HIGHEST_PROTOCOL)

### Load the topics

In [49]:
# Plain TSV of (topic id, query text) without quoting; QUOTE_NONE keeps any '"' in a query as literal text.
trec_dl_2020_topics_df = pd.read_csv(TREC_DL_2020_TOPICS_FILENAME, sep='\t', header=None, names=['id', 'text'], quoting=csv.QUOTE_NONE)

In [50]:
# Inspect the TREC-DL 2020 topics.
trec_dl_2020_topics_df

Unnamed: 0,id,text
0,1030303,who is aziz hashim
1,1037496,who is rep scalise?
2,1043135,who killed nicholas ii of russia
3,1045109,who owns barnhart crane
4,1049519,who said no one can make you feel inferior
...,...,...
195,985594,where is kampuchea
196,99005,convert sq meter to sq inch
197,997622,where is the show shameless filmed
198,999466,where is velbert


### Now, load the BM25 run

In [51]:
# Pyserini BM25 run in TREC format (space-separated): topic, 'Q0', doc id, rank, score, run tag.
bm25_run_df = pd.read_csv(PYSERINI_TEST_RUN_FILENAME, sep=" ", header=None, names=['topic', '1', 'doc', 'order', 'score', 'comment'])

In [52]:
# Total number of (topic, passage) hits in the run.
bm25_run_df.shape

(200000, 6)

In [53]:
# First hits of the run.
bm25_run_df.head()

Unnamed: 0,topic,1,doc,order,score,comment
0,3505,Q0,4711746,1,14.2214,Anserini
1,3505,Q0,3859340,2,14.045,Anserini
2,3505,Q0,7207815,3,13.8724,Anserini
3,3505,Q0,6834658,4,13.5842,Anserini
4,3505,Q0,3829534,5,13.5347,Anserini


### Build the test data to be tokenized

#### First, filter the TREC-DL 2020 topics text and the MS-MARCO Passage texts using the corresponding IDs on the run

Pandas' `DataFrame.merge` joins the run with the topic and passage texts much faster than a naive row-by-row lookup.

In [54]:
%time

filtered_topics = trec_dl_2020_topics_df.merge(bm25_run_df[['topic', 'doc']], left_on='id', right_on='topic', how='inner')[['id', 'text', 'doc']]

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.96 µs


In [55]:
# All BM25 hits for one topic, with its text attached.
filtered_topics[filtered_topics['id'] == 3505]

Unnamed: 0,id,text,doc
119000,3505,how do they do open heart surgery,4711746
119001,3505,how do they do open heart surgery,3859340
119002,3505,how do they do open heart surgery,7207815
119003,3505,how do they do open heart surgery,6834658
119004,3505,how do they do open heart surgery,3829534
...,...,...,...
119995,3505,how do they do open heart surgery,6616085
119996,3505,how do they do open heart surgery,7097495
119997,3505,how do they do open heart surgery,8181171
119998,3505,how do they do open heart surgery,1971051


In [70]:
%time

bm25_run_with_data_df = filtered_topics.merge(msmarco_passage_df, left_on='doc', right_on='id', how='inner')

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.68 µs


In [71]:
# Run joined with both the topic text (text_x) and the passage text (text_y).
bm25_run_with_data_df

Unnamed: 0,id_x,text_x,doc,id_y,text_y
0,1030303,who is aziz hashim,8726436,8726436,Share on LinkedInShare on FacebookShare on Twi...
1,1030303,who is aziz hashim,8726435,8726435,Mr. Aziz Hashim has been the President and Sec...
2,1030303,who is aziz hashim,8726429,8726429,"The crew at NRD Holdings, left to right: Karim..."
3,1030303,who is aziz hashim,8726437,8726437,Aziz Hashim is one of the world's leading expe...
4,1030303,who is aziz hashim,7156982,7156982,Rounding out the IFA leadership team is Aziz H...
...,...,...,...,...,...
199995,132622,definition of attempted arson,263255,263255,The Definition of Gambling Disorder. Gambling ...
199996,132622,definition of attempted arson,5825982,5825982,What parents need to know. The film attempts t...
199997,132622,definition of attempted arson,6045582,6045582,Subdivision 1.Flee; definition. For purposes o...
199998,132622,definition of attempted arson,6119686,6119686,Definition of conversion for English Language ...


In [72]:
# Hits of one topic after the merge; note the row positions no longer follow the BM25 rank order.
bm25_run_with_data_df[bm25_run_with_data_df['id_x'] == 3505]

Unnamed: 0,id_x,text_x,doc,id_y,text_y
4254,3505,how do they do open heart surgery,4500749,4500749,(CNN) -- The pediatric surgeon who performed o...
8216,3505,how do they do open heart surgery,4165067,4165067,"1) Why do some people shiver after surgery, ev..."
9266,3505,how do they do open heart surgery,1628597,1628597,"Pete Carroll excited about Thomas Rawls, Chris..."
47560,3505,how do they do open heart surgery,6052850,6052850,There is no obligation to tell your employer a...
51995,3505,how do they do open heart surgery,8296198,8296198,This thread was archived. Please ask a new que...
...,...,...,...,...,...
123769,3505,how do they do open heart surgery,6616085,6616085,0 users have voted. LVAD implantation is an op...
123770,3505,how do they do open heart surgery,7097495,7097495,Daniel Hale Williams - Introduction: African A...
123771,3505,how do they do open heart surgery,8181171,8181171,Melody TPV Therapy does not replace open heart...
123772,3505,how do they do open heart surgery,1971051,1971051,I'd say it's very common in America to say how...


#### Now, tokenize both topics and returned texts

In [73]:
# Tokenize the topic side of every (topic, passage) row, keeping lengths for the stats below.
trec_topics_tokens = tokenizer(list(bm25_run_with_data_df['text_x'].to_numpy()), return_length=True)

In [74]:
%time

trec_docs_tokens = tokenizer(list(bm25_run_with_data_df['text_y'].to_numpy()), return_length=True)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.44 µs


#### Check the tokenized data statistics

Check whether the merged topic + document token sequences fit within the model's maximum input length (512 tokens).

In [75]:
# Token-length statistics of the test topics.
stats.describe(trec_topics_tokens['length'])

DescribeResult(nobs=200000, minmax=(4, 21), mean=8.82, variance=8.267641338206689, skewness=1.007073241118117, kurtosis=1.5977178222274047)

In [76]:
# Token-length statistics of the retrieved passages.
stats.describe(trec_docs_tokens['length'])

DescribeResult(nobs=200000, minmax=(13, 423), mean=80.384575, variance=1096.505789598323, skewness=1.1522800830714979, kurtosis=1.6449787048830995)

Save the tokenized data

In [77]:
# Cache the tokenized pairs together with the joined frame, so the rows stay aligned with the tokens.
with open(TOKENIZED_TREC_DL_2020_DATA, "wb") as outputFile:
    pickle.dump({'trec_topics_tokens': trec_topics_tokens,
                 'trec_docs_tokens': trec_docs_tokens,
                 'bm25_run_with_data_df': bm25_run_with_data_df}, outputFile, pickle.HIGHEST_PROTOCOL)

### Build the concatenated topic + document to feed the model

Once again, drop the leading [CLS] token from each document token sequence.

In [78]:
# One merged sequence per (topic, passage) row, in the same order as bm25_run_with_data_df.
# Each passage's leading [CLS] is dropped so pairs read [CLS] topic [SEP] passage [SEP].
num_test_pairs = len(trec_topics_tokens['input_ids'])

test_input_ids = [trec_topics_tokens['input_ids'][pair] + trec_docs_tokens['input_ids'][pair][1:]
                  for pair in range(num_test_pairs)]
test_token_type_ids = [trec_topics_tokens['token_type_ids'][pair] + trec_docs_tokens['token_type_ids'][pair][1:]
                       for pair in range(num_test_pairs)]
test_attention_mask = [trec_topics_tokens['attention_mask'][pair] + trec_docs_tokens['attention_mask'][pair][1:]
                       for pair in range(num_test_pairs)]

In [79]:
# Test inputs in the layout Dataset expects.
x_test = {'input_ids': test_input_ids, 
          'token_type_ids': test_token_type_ids, 
          'attention_mask': test_attention_mask}

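The `Dataset` class requires targets, so provide placeholder labels (all `True`). They are not relevance judgments, so the loss reported during reranking is not meaningful.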

In [80]:
# Placeholder labels (all True) only because Dataset requires targets; they are not relevance judgments.
y_test = np.ones(len(trec_topics_tokens['input_ids']), dtype=bool)

Create the dataset and the dataloader

In [81]:
# Wrap the test pairs; the labels are placeholders.
dataset_test = Dataset(x_test, y_test)

Make sure the dataloader preserves the sample order (no shuffling!): the scores are matched back to the dataframe rows by position.

In [82]:
batch_size=32

# shuffle=False: the scores are later matched back to bm25_run_with_data_df rows by position.
dataloader_test = data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

Define an inference function for reranking that collects the model logits for every (topic, passage) pair.

In [83]:
def collect_reranking(model, dataloader, set_name):
    """Run ``model`` over ``dataloader`` and return its raw logits, batch by batch.

    Returns a list with one CPU logits tensor per batch, in dataloader order, so the rows
    can be matched back to the inputs by position. It also prints the mean per-batch loss.
    NOTE(review): when the dataloader carries placeholder labels (as the TREC-DL test set
    in this notebook does), that loss is not a quality metric.
    """
    model.eval()
    losses = []
    batch_logits = []

    with torch.no_grad():
        for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
            model_output = model(**batch.to(device))
            losses.append(model_output.loss.cpu().item())
            batch_logits.append(model_output.logits.cpu())

    print(f"{set_name} loss: {mean(losses):0.3f}")

    return batch_logits

Load the fine-tuned model from disk (only if it is not already in memory).

In [None]:
# Checkpoint to use when this kernel did not run the fine-tuning cells (so `training_timestamp`
# is undefined); presumably written by an earlier session — update it if you retrain.
FALLBACK_PRETRAINED_MODEL = "pretrain_20230313_030341"

if 'training_timestamp' in vars():
    PRETRAINED_MODEL = "pretrain_{}".format(training_timestamp)
else:
    PRETRAINED_MODEL = FALLBACK_PRETRAINED_MODEL

# Only load from disk when the fine-tuned model is not already in memory.
if 'model' not in vars():
    model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL).to(device)
    print('Parameters', model.num_parameters())

Rerank the BM25 retrieved texts

In [84]:
# Logits for every test pair, in dataloader (= dataframe row) order. The printed loss is
# computed against placeholder labels, so it is not a quality metric.
reranking_scores = collect_reranking(model=model, dataloader=dataloader_test, set_name='trec-dl_2020')

trec-dl_2020:   0%|          | 0/6250 [00:00<?, ?it/s]

trec-dl_2020 loss: 3.893


#### Consider the logit for class 1 (True) as the relevance score

In [85]:
# Stack the per-batch logit tensors and keep the class-1 ("relevant") column as the score.
# NOTE(review): P(relevant) under softmax depends on logit[1] - logit[0], not on logit[1] alone;
# consider comparing both rankings.
matches_relevance_score = torch.cat(reranking_scores)[:, 1].numpy()

In [86]:
# One score per (topic, passage) row.
matches_relevance_score.shape

(200000,)

#### Merge the results in the topics x docs dataframe

In [87]:
# Positional assignment: relies on the test dataloader preserving row order (shuffle=False).
bm25_run_with_data_df['reranking_scores'] = matches_relevance_score

In [88]:
# Joined run with the new reranker scores.
bm25_run_with_data_df

Unnamed: 0,id_x,text_x,doc,id_y,text_y,reranking_scores
0,1030303,who is aziz hashim,8726436,8726436,Share on LinkedInShare on FacebookShare on Twi...,1.069175
1,1030303,who is aziz hashim,8726435,8726435,Mr. Aziz Hashim has been the President and Sec...,2.074721
2,1030303,who is aziz hashim,8726429,8726429,"The crew at NRD Holdings, left to right: Karim...",-0.152663
3,1030303,who is aziz hashim,8726437,8726437,Aziz Hashim is one of the world's leading expe...,2.129169
4,1030303,who is aziz hashim,7156982,7156982,Rounding out the IFA leadership team is Aziz H...,1.355105
...,...,...,...,...,...,...
199995,132622,definition of attempted arson,263255,263255,The Definition of Gambling Disorder. Gambling ...,-2.483746
199996,132622,definition of attempted arson,5825982,5825982,What parents need to know. The film attempts t...,-2.477626
199997,132622,definition of attempted arson,6045582,6045582,Subdivision 1.Flee; definition. For purposes o...,-2.426031
199998,132622,definition of attempted arson,6119686,6119686,Definition of conversion for English Language ...,-2.488512


#### Save the result in the TREC format

In [90]:
# One tab-separated TREC run line: topic, literal Q0, doc id, rank, score, run tag.
TREC_RESULT_LINE_FORMAT="{}\tQ0\t{}\t{}\t{}\tminiLM_reranking\n"

In [91]:
# Timestamp that makes each reranked run file name unique.
test_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

In [92]:
# Write the reranked run: within each topic, passages are sorted by descending reranker
# score and given 1-based ranks.
with open(PYSERINI_TEST_RUN_RERANKED_FILENAME_FORMAT.format(test_timestamp), 'w') as outputFile:
    for group_name, group_df in bm25_run_with_data_df.groupby('id_x'):
        group_df = group_df.sort_values('reranking_scores', ascending=False).reset_index(drop=True)

        # After reset_index, the positional index doubles as the 0-based rank.
        for i, row in group_df.iterrows():
            outputFile.write(TREC_RESULT_LINE_FORMAT.format(group_name, row['doc'], i + 1, row['reranking_scores']))

### Apply TREC metrics

In [96]:
# trec_eval binary built by the setup %%shell cell inside the Pyserini clone under /content.
TREC_EVAL_FULLPATH="/content/pyserini/tools/eval/trec_eval.9.0.4/trec_eval"

In [97]:
!{TREC_EVAL_FULLPATH} -c -mrecall.1000 -mmap -mndcg_cut.10 -mrecip_rank \
    2020qrels-pass.txt {PYSERINI_TEST_RUN_RERANKED_FILENAME_FORMAT.format(test_timestamp)}

map                   	all	0.4580
recip_rank            	all	0.9049
recall_1000           	all	0.7331
ndcg_cut_10           	all	0.6615
