<a href="https://colab.research.google.com/github/juliatessler/1s2023-unicamp-dl-for-search-systems/blob/main/10-trade-offs/10_tradeoffs_splade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SPLADE
by Júlia Ferreira Tessler

Several ideas in this notebook come from [this YouTube video](https://www.youtube.com/watch?v=0FQ2WmM0t3w&ab_channel=JamesBriggs); unfortunately, I wasn't able to find out who shared it :(

In [1]:
!pip install transformers -q
!pip install datasets -q
!pip install pyserini -q
!pip install faiss-cpu -q
!pip install ftfy -q
!pip install evaluate -q

In [2]:
from google.colab import drive

drive.mount('/content/gdrive', force_remount = True)

Mounted at /content/gdrive


In [3]:
workdir = '/content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks'

In [4]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Wed May 10 19:36:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   63C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
import pandas as pd
import numpy as np
import torch.nn.functional as F
import os
import json

from transformers import (
    DistilBertTokenizer, 
    DistilBertForMaskedLM, 
    BertTokenizer, 
    BertForMaskedLM,
    BatchEncoding
)

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from collections import defaultdict
from evaluate import load

## Get data

In [7]:
trec_covid_queries = load_dataset("BeIR/trec-covid", 'queries')
trec_covid_corpus = load_dataset("BeIR/trec-covid", 'corpus')



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
trec_covid_queries

DatasetDict({
    queries: Dataset({
        features: ['_id', 'title', 'text'],
        num_rows: 50
    })
})

In [9]:
trec_covid_queries['queries'][:10]

{'_id': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'],
 'title': ['', '', '', '', '', '', '', '', '', ''],
 'text': ['what is the origin of COVID-19',
  'how does the coronavirus respond to changes in the weather',
  'will SARS-CoV2 infected people develop immunity? Is cross protection possible?',
  'what causes death from Covid-19?',
  'what drugs have been active against SARS-CoV or SARS-CoV-2 in animal studies?',
  'what types of rapid testing for Covid-19 have been developed?',
  'are there serological tests that detect antibodies to coronavirus?',
  'how has lack of testing availability led to underreporting of true incidence of Covid-19?',
  'how has COVID-19 affected Canada',
  'has social distancing had an impact on slowing the spread of COVID-19?']}

In [10]:
trec_covid_corpus['corpus'][:10]

{'_id': ['ug7v899j',
  '02tnwd4m',
  'ejv2xln0',
  '2b73a28n',
  '9785vg6d',
  'zjufx4fo',
  '5yhe786e',
  '8zchiykl',
  '8qnrcgnk',
  'jg13scgo'],
 'title': ['Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia',
  'Nitric oxide: a pro-inflammatory mediator in lung disease?',
  'Surfactant protein-D and pulmonary host defense',
  'Role of endothelin-1 in lung disease',
  'Gene expression in epithelial cells in response to pneumovirus infection',
  'Sequence requirements for RNA strand transfer during nidovirus discontinuous subgenomic RNA synthesis',
  'Debate: Transfusing to normal haemoglobin levels will not improve outcome',
  'The 21st International Symposium on Intensive Care and Emergency Medicine, Brussels, Belgium, 20-23 March 2001',
  'Heme oxygenase-1 and carbon monoxide in pulmonary medicine',
  'Technical Description of RODS: A Real-time Public Health Surveillance System'],
 'text': ['OBJECTIVE: T

## Get model

In [11]:
model_name = 'naver/splade_v2_distil'

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name).to(device)

## Prepare dataset

In [12]:
class SPLADEDataset(Dataset):
  """Map-style dataset that tokenizes one text per index for SPLADE encoding.

  Tokenization happens lazily in ``__getitem__`` and is truncated to
  ``max_length``. ``padding=True`` has no effect on a single text, so padding
  to a common length must happen at batch time (e.g. in a collate function).
  """

  def __init__(self, text, tokenizer, max_length):
    self.text, self.tokenizer, self.max_length = text, tokenizer, max_length

  def __len__(self):
    return len(self.text)

  def __getitem__(self, idx):
    # The authors keep [CLS] and [SEP] (as pointed out by Leandro Carísio);
    # the special-tokens mask is still returned so callers can tell them apart.
    # https://github.com/naver/splade/blob/main/inference_splade.ipynb
    tokenizer_kwargs = {
        'padding': True,
        'truncation': True,
        'return_special_tokens_mask': True,
        'max_length': self.max_length,
    }
    return self.tokenizer(self.text[idx], **tokenizer_kwargs)

In [13]:
max_length = 256
batch_size = 32

In [14]:
def collate_fn(batch):
  """Pad a list of tokenized examples into one tensor batch (a BatchEncoding).

  NOTE(review): this reads the module-level ``tokenizer`` at call time, so it pads
  with whichever tokenizer is bound globally when a DataLoader iterates. Confirm
  that is the intended one if ``tokenizer`` is reassigned later in the notebook.
  """
  padded = tokenizer.pad(batch, return_tensors = 'pt')
  return BatchEncoding(padded)

In [15]:
trec_covid_dataset = SPLADEDataset(trec_covid_corpus['corpus']['text'], 
                                   tokenizer, max_length)

trec_covid_dataloader = DataLoader(trec_covid_dataset, 
                                   batch_size = batch_size, 
                                   collate_fn = collate_fn)

## Sparse representation

In [16]:
def get_sparse_representation(model, tokenizer, text, agg='sum'):
    """Encode a single text into a SPLADE sparse vector over the vocabulary.

    Tokenizes ``text`` (truncated to the module-level ``max_length``) and runs the
    masked-LM ``model`` under fp16 autocast without gradients on the module-level
    ``device``. It then applies ``log(1 + relu(logits))`` and pools over positions.
    Positions flagged by the tokenizer's special-tokens mask ([CLS]/[SEP]) are
    zeroed before the activation.

    Args:
        model: masked-LM model whose output exposes ``.logits``.
        tokenizer: tokenizer matching ``model``.
        text: a single input string (only the first batch row is used).
        agg: ``'sum'`` sums term weights over positions; ``'max'`` takes the
            per-term maximum.

    Returns:
        A 1-D sparse COO tensor indexed by vocabulary id.

    Raises:
        ValueError: if ``agg`` is neither ``'sum'`` nor ``'max'``.
    """
    # Validate before the (expensive) forward pass; unknown values used to
    # silently fall through to max pooling.
    if agg not in ('sum', 'max'):
        raise ValueError(f"agg must be 'sum' or 'max', got {agg!r}")

    # tokenize and get token_ids
    tokenizer_output = tokenizer(
        text,
        return_special_tokens_mask = True,
        truncation                 = True,
        return_tensors             = 'pt',
        max_length                 = max_length
    )

    # autocast expects the device *type* ('cuda'/'cpu'), not e.g. 'cuda:0'
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
        with torch.no_grad():
            model_output = model(
                input_ids      = tokenizer_output['input_ids'].to(device),
                attention_mask = tokenizer_output['attention_mask'].to(device)
            )
    logits = model_output.logits[0, :]  # (seq_len, vocab_size) for the single input

    # 1 for regular tokens, 0 for special tokens; take the single batch row explicitly
    mask_valid_tokens = 1 - tokenizer_output['special_tokens_mask'][0].to(device)
    mask = mask_valid_tokens.unsqueeze(-1).expand(logits.size())

    # SPLADE term weights per position, then pooled over the sequence
    activations = torch.log(1 + torch.relu(logits * mask))
    if agg == 'sum':
        wj = torch.sum(activations, dim=0)
    else:
        wj, _ = torch.max(activations, dim=0)

    return wj.to_sparse()

def get_tokens_and_values(wj):
    """Pair each stored entry of a 1-D sparse term vector with its token string.

    Vocabulary ids are mapped back to tokens with the module-level ``tokenizer``.

    Returns:
        List of ``(token, weight)`` tuples in the sparse tensor's index order;
        each weight is a 0-dim tensor taken from ``wj.values()``.
    """
    term_ids = wj.indices()[0]
    weights = wj.values()
    token_strings = tokenizer.convert_ids_to_tokens(term_ids)
    return [(token, weight) for token, weight in zip(token_strings, weights)]

### Testing

In [17]:
text = "My name is Júlia. I really like dinosaurs."

In [18]:
wj = get_sparse_representation(model, tokenizer, text)
wj

tensor(indices=tensor([[ 1011,  1045,  2003,  2014,  2016,  2017,  2026,  2032,
                         2033,  2040,  2066,  2079,  2171,  2228,  2256,  2293,
                         2360,  2428,  2450,  2643,  2653,  2767,  2941,  3566,
                         3814,  4489,  5203,  5223,  6127,  6423,  7059,  7777,
                        12120, 15799, 18148, 19958]]),
       values=tensor([4.1577e-01, 7.4902e-01, 4.5679e-01, 6.4883e+00,
                      1.2036e-01, 9.1602e-01, 1.9004e+00, 3.9246e-02,
                      4.9878e-01, 3.5820e+00, 1.5938e+00, 2.8076e-01,
                      3.9316e+00, 3.2666e-01, 4.0192e-02, 1.4717e+00,
                      4.0161e-01, 1.2832e+00, 3.2349e-01, 1.1504e+00,
                      1.9512e-03, 6.2469e-02, 3.4375e-01, 7.0947e-01,
                      5.9131e-01, 1.2646e-01, 1.1523e+00, 6.2549e-01,
                      9.1162e-01, 2.8477e+00, 1.7324e+00, 2.6074e-01,
                      7.4658e-01, 2.2598e+00, 2.3691e+00, 5.7178e

In [19]:
tokens_and_values = get_tokens_and_values(wj)
tokens_and_values

[('-', tensor(0.4158, device='cuda:0', dtype=torch.float16)),
 ('i', tensor(0.7490, device='cuda:0', dtype=torch.float16)),
 ('is', tensor(0.4568, device='cuda:0', dtype=torch.float16)),
 ('her', tensor(6.4883, device='cuda:0', dtype=torch.float16)),
 ('she', tensor(0.1204, device='cuda:0', dtype=torch.float16)),
 ('you', tensor(0.9160, device='cuda:0', dtype=torch.float16)),
 ('my', tensor(1.9004, device='cuda:0', dtype=torch.float16)),
 ('him', tensor(0.0392, device='cuda:0', dtype=torch.float16)),
 ('me', tensor(0.4988, device='cuda:0', dtype=torch.float16)),
 ('who', tensor(3.5820, device='cuda:0', dtype=torch.float16)),
 ('like', tensor(1.5938, device='cuda:0', dtype=torch.float16)),
 ('do', tensor(0.2808, device='cuda:0', dtype=torch.float16)),
 ('name', tensor(3.9316, device='cuda:0', dtype=torch.float16)),
 ('think', tensor(0.3267, device='cuda:0', dtype=torch.float16)),
 ('our', tensor(0.0402, device='cuda:0', dtype=torch.float16)),
 ('love', tensor(1.4717, device='cuda:0', dt

### Batch sparse representation

In [20]:
def get_sparse_representation_batch(model, tokenizer, dataloader):
    """Encode every batch from ``dataloader`` into SPLADE sparse vectors.

    For each batch, runs ``model`` under fp16 autocast without gradients on the
    module-level ``device``. Positions with ``attention_mask == 0`` (padding) are
    zeroed; special tokens such as [CLS]/[SEP] are kept. It then applies
    ``log(1 + relu(logits))`` and max-pools over the sequence dimension.

    Args:
        model: masked-LM model whose output exposes ``.logits``.
        tokenizer: unused; kept for signature compatibility.
        dataloader: yields dict-like batches with ``input_ids`` and ``attention_mask``.

    Returns:
        A sparse COO tensor with one row per input text (in dataloader order),
        or ``None`` if the dataloader yields no batches.
    """
    batch_vectors = []
    for batch in tqdm(dataloader):
        # autocast expects the device *type* ('cuda'/'cpu'), not e.g. 'cuda:0'
        with torch.autocast(device_type = device.type, dtype = torch.float16, enabled = True):
            with torch.no_grad():
                model_output = model(
                    input_ids      = batch['input_ids'].to(device),
                    attention_mask = batch['attention_mask'].to(device)
                )

        logits = model_output.logits

        # Mask that zeroes padded positions before pooling
        mask = batch['attention_mask'].unsqueeze(-1).expand(logits.size()).to(device)

        # Max-pool the SPLADE activation over positions -> one vector per text
        wj, _ = torch.max(torch.log(1 + F.relu(logits * mask)), dim=1)
        batch_vectors.append(wj.to_sparse())

    if not batch_vectors:
        return None
    # Concatenate once: calling torch.cat on the growing matrix inside the loop
    # re-copies everything accumulated so far on every batch.
    return torch.cat(batch_vectors, dim=0)

In [21]:
if not os.path.exists(f'{workdir}/trec_covid_wjs_with_special_tokens.pt'):
    trec_covid_wjs = get_sparse_representation_batch(model, tokenizer, trec_covid_dataloader)
    torch.save(trec_covid_wjs, f'{workdir}/trec_covid_wjs_with_special_tokens.pt')
else:
    trec_covid_wjs = torch.load(f'{workdir}/trec_covid_wjs_with_special_tokens.pt').to(device)

In [22]:
trec_covid_wjs.size()

torch.Size([171332, 30522])

## Search

In [23]:
row_sums = torch.sparse.sum(trec_covid_wjs, dim=1).to_dense()
row_sums

tensor([222.8063, 179.4055, 235.0530,  ..., 191.2339, 223.8661,  11.9450],
       device='cuda:0')

In [24]:
ids_docs = trec_covid_corpus['corpus']['_id']

In [25]:
def get_scores_sparse_vectors(query_vec, 
                              wj_matrix_doc, 
                              normalize_doc = False, 
                              normalize_query = True,
                              hits = 1000):
    """Score every document against a sparse query vector and return the top hits.

    Scores are dot products computed in fp16 with ``torch.sparse.mm``.

    Args:
        query_vec: 1-D sparse vector over the vocabulary.
        wj_matrix_doc: sparse document matrix whose rows align with the
            module-level ``ids_docs``.
        normalize_doc: divide each score by the module-level ``row_sums``.
        normalize_query: divide all scores by the sum of ``query_vec``.
        hits: maximum number of results to return (slice semantics).

    Returns:
        List of ``(doc_id, score)`` pairs by descending score, where ``score`` is
        a 0-dim tensor. Tied scores keep the original document order.
    """
    scores = torch.sparse.mm(wj_matrix_doc.to(torch.half), query_vec.unsqueeze(-1).to(torch.half)).to_dense()
    scores = scores.squeeze(1)

    # Normalizations
    if normalize_doc:
        scores = torch.div(scores, row_sums)
    if normalize_query:
        scores = scores / query_vec.sum().item()

    # Only documents that have an id can be ranked (same truncation as zip()).
    scores = scores[:len(ids_docs)]

    # Rank on the device with one stable descending sort. A Python sort keyed on
    # .item() forces a host/device sync per document. Stability keeps tied
    # documents in corpus order.
    sorted_scores, order = torch.sort(scores, descending=True, stable=True)
    top_ids = [ids_docs[i] for i in order[:hits].tolist()]
    return list(zip(top_ids, sorted_scores[:hits]))

In [26]:
trec_queries_df = pd.DataFrame()
trec_queries_df['_id'] = trec_covid_queries['queries']['_id']
trec_queries_df['text'] = trec_covid_queries['queries']['text']

In [27]:
run = {
    'query_id':   [],
    'passage_id': [],
    'score':      [],
    'rank':       [],
    'Q0':         [],
    'run':        []
}

# I couldn't figure out how to iterate over the JSON data directly, so the queries are iterated through a DataFrame
for i, row in trec_queries_df.iterrows():

    # Vectorizing query
    vec_query = get_sparse_representation(model, tokenizer, row['text'], agg = 'max')

    # getting score
    docs_score = get_scores_sparse_vectors(vec_query, trec_covid_wjs)

    passage_ids = [x[0]        for x in docs_score]
    scores      = [x[1].item() for x in docs_score]

    # appending in run
    run['query_id'].extend([row['_id']] * 1000)
    run['passage_id'].extend(passage_ids)
    run['score'].extend(scores)
    run['rank'].extend(list(range(1,1001)))
    run['Q0'].extend(['Q0'] * 1000)
    run['run'].extend(['vec_query_vs_wj_passages'] * 1000)

df_run = pd.DataFrame(run)
df_run = df_run[['query_id', 'Q0', 'passage_id', 'rank', 'score', 'run']]

print(df_run.shape)
df_run.head()

(50000, 6)


Unnamed: 0,query_id,Q0,passage_id,rank,score,run
0,1,Q0,1mjaycee,1,1.297852,vec_query_vs_wj_passages
1,1,Q0,zrycqlvs,2,1.290039,vec_query_vs_wj_passages
2,1,Q0,cgvj10r2,3,1.287109,vec_query_vs_wj_passages
3,1,Q0,uv4gbhbb,4,1.287109,vec_query_vs_wj_passages
4,1,Q0,0wm6u10a,5,1.286133,vec_query_vs_wj_passages


In [28]:
df_run.to_csv(f'{workdir}/run.sparse_rank.txt', sep='\t', header=None, index=None)

In [29]:
# Qrels with all relevances
!wget https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-05-10 19:39:39--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 18.160.249.78, 18.160.249.31, 18.160.249.70, ...
Connecting to huggingface.co (huggingface.co)|18.160.249.78|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv.1’


2023-05-10 19:39:39 (17.1 MB/s) - ‘test.tsv.1’ saved [980831/980831]



In [30]:
with open('test.tsv', 'r') as fin:
  data = fin.read().splitlines(True)
with open('qrels_format.tsv', 'w') as fout:
  for linha in data[1:]:
    campos = linha.split()
    fout.write(f'{campos[0]}\t0\t{campos[1]}\t{campos[2]}\n')


In [31]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 qrels_format.tsv {workdir}/run.sparse_rank.txt #type: ign

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', 'qrels_format.tsv', '/content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks/run.sparse_rank.txt']
Results:
ndcg_cut_10           	all	0.6082


---
# New code

## Using InPars to generate a training dataset for TREC-COVID

## Getting dataset from all students

This code is courtesy of [Marcos Piau](https://huggingface.co/datasets/unicamp-dl/trec-covid-experiment/blob/main/sugestao_uso_dataset.ipynb).

In [94]:
import pandas as pd
import numpy as np
# import pytorch_lightning as pl
import json
import time
import ftfy

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    BatchEncoding,
    AdamW,
    get_linear_schedule_with_warmup
)

from datasets import load_dataset
# from langchain import (
#     HuggingFacePipeline,
#     HuggingFaceHub,
#     PromptTemplate, 
#     LLMChain
# )
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
# from langchain.llms import OpenAI
# from langchain.chat_models import ChatOpenAI
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from collections import defaultdict
from pyserini.search.lucene import LuceneSearcher
from sklearn.model_selection import train_test_split
from statistics import mean

In [67]:
!pip install huggingface_hub -q

In [72]:
from getpass import getpass

HUGGINGFACEHUB_API_TOKEN = getpass()

··········


In [73]:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN

In [76]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [77]:
ds = load_dataset('unicamp-dl/trec-covid-experiment')

Downloading builder script:   0%|          | 0.00/2.71k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/21 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/309 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/346 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/185k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/255k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/152k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/311k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/627k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/280k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/307k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/238k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/237k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/356k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/266k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/249k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/224k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/11.2M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/188k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/244k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/234k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.98M [00:00<?, ?B/s]

Generating example split: 0 examples [00:00, ? examples/s]

Generating example2 split: 0 examples [00:00, ? examples/s]

Generating eduseiti_100_queries_expansion_20230501_01 split: 0 examples [00:00, ? examples/s]

Generating leandro_carisio_01 split: 0 examples [00:00, ? examples/s]

Generating thales_1k_generated_queries_20230429 split: 0 examples [00:00, ? examples/s]

Generating manoel_1k_generated_queries_20230430 split: 0 examples [00:00, ? examples/s]

Generating manoel_2k_generated_queries_20230501 split: 0 examples [00:00, ? examples/s]

Generating thiago_laitz_1k_queries split: 0 examples [00:00, ? examples/s]

Generating mirelle_1k_generated_queries_20230501 split: 0 examples [00:00, ? examples/s]

Generating hugo_padovani_query_generation split: 0 examples [00:00, ? examples/s]

Generating marcus_borela_1k_gptj6b_20230501 split: 0 examples [00:00, ? examples/s]

Generating juliatessler_1000_queries split: 0 examples [00:00, ? examples/s]

Generating pedro_holanda_1k_generated_queries_20230502 split: 0 examples [00:00, ? examples/s]

Generating leonardo_avila_queries_v1 split: 0 examples [00:00, ? examples/s]

Generating marcus_borela_1k_gptj6b_20230501_v2 split: 0 examples [00:00, ? examples/s]

Generating gustavo_1k_cohere split: 0 examples [00:00, ? examples/s]

Generating marcospiau_1k_v1 split: 0 examples [00:00, ? examples/s]

Generating pedrogengo_queries_inparsv1 split: 0 examples [00:00, ? examples/s]

Generating ricardo_primi_1k split: 0 examples [00:00, ? examples/s]

Generating thiago_vieira_1k_queries split: 0 examples [00:00, ? examples/s]

Generating eduseiti_1000_queries_expansion_20230502_02 split: 0 examples [00:00, ? examples/s]

  0%|          | 0/21 [00:00<?, ?it/s]

In [78]:
df = pd.concat((v.to_pandas().assign(origin=k) for k,v in ds.items()),
               ignore_index=True)
df.head()

Unnamed: 0,query,positive_doc_id,negative_doc_ids,origin
0,This is a example query 1,doc1,"[xxx, yyy, zzz]",example
1,This is another example query,doc2,"[aaa, bbb, ccc]",example
2,Example of query with no negative doc_ids,doc2,[],example
3,This is a example query 1 (file 2),doc12222,"[xxx, yyy, zzz]",example2
4,This is another example query (file 2),doc12345,"[aaa, bbb, ccc]",example2


In [79]:
df = df.drop(df[(df.origin == 'example') | (df.origin == 'example2')].index)
df.head()

Unnamed: 0,query,positive_doc_id,negative_doc_ids,origin
6,How can chatbots be designed to effectively sh...,70hskj1o,"[mt00852w, x7ol32mz, b54dymlu, h5vh6px7, bza9a...",eduseiti_100_queries_expansion_20230501_01
7,What strategies can be used to encourage desir...,70hskj1o,"[et84j0qi, xsfolppr, 5t2o287y, kj2tnw8q, j68x0...",eduseiti_100_queries_expansion_20230501_01
8,What are the risks associated with amplifying ...,70hskj1o,"[2c1m04je, rd93y7hu, vlmvi0tf, dbq3z982, 848fs...",eduseiti_100_queries_expansion_20230501_01
9,What research has been conducted on the effect...,70hskj1o,"[49zlztqu, amjqr9hr, hpx4723v, e790rxq9, 95bso...",eduseiti_100_queries_expansion_20230501_01
10,How can collaborations between healthcare work...,70hskj1o,"[eg2lj9zc, prmf9yob, ara8bsws, zjmshwl3, apvc5...",eduseiti_100_queries_expansion_20230501_01


In [80]:
df.shape

(23585, 4)

In [81]:
def compute_len(negative_docs_list_size):
  """Return how many entries a row's ``negative_doc_ids`` value holds.

  Despite its name, the argument is the sequence of ids itself, not its size.
  """
  size = len(negative_docs_list_size)
  return size

df['negative_docs_list_size'] = df['negative_doc_ids'].map(compute_len)

In [82]:
df = df[df['negative_docs_list_size'] > 0]
df.shape

(22585, 5)

In [83]:
# These ready-made helper functions came from Mirelle

def search_in_corpus(doc_id, corpus):
  """Return ``"<title> <text>"`` for the first ``corpus`` row whose ``_id`` equals ``doc_id``.

  Raises IndexError if no row matches.
  """
  hit = corpus[corpus['_id'].eq(doc_id)]
  title = hit['title'].iloc[0]
  text = hit['text'].iloc[0]
  return title + ' ' + text

# Output format: query, label, passage (the "hypothesis")
def format_data(df_all, corpus):
  """Expand generated-query rows into labelled (query, passage) training pairs.

  Each row of ``df_all`` (columns ``query``, ``positive_doc_id``,
  ``negative_doc_ids``) yields two examples: first the positive document with
  label True, then the first negative document with label False. A passage is
  ``"<title> <text>"`` of the first ``corpus`` row whose ``_id`` matches.

  Args:
    df_all: DataFrame of queries with their positive/negative doc ids.
    corpus: DataFrame with ``_id``, ``title`` and ``text`` columns.

  Returns:
    Dict of parallel lists under ``'query'``, ``'label'`` and ``'passage'``.

  Raises:
    KeyError: if a referenced doc id is not present in ``corpus``.
  """
  # Index the corpus once (first occurrence of each _id wins). Scanning the whole
  # DataFrame for every lookup made this O(rows x corpus size).
  docs_by_id = {}
  for doc_id, title, text in zip(corpus['_id'], corpus['title'], corpus['text']):
    docs_by_id.setdefault(doc_id, (title, text))

  def passage_for(doc_id):
    title, text = docs_by_id[doc_id]
    return title + ' ' + text

  data = {
      'query': [],
      'label': [],
      'passage': []
  }

  rows = zip(df_all['query'], df_all['positive_doc_id'], df_all['negative_doc_ids'])
  for query, positive_id, negative_ids in tqdm(rows, total=len(df_all)):
    # Positive pair
    data['query'].append(query)
    data['passage'].append(passage_for(positive_id))
    data['label'].append(True)

    # Negative pair: only the first negative doc is used
    data['query'].append(query)
    data['passage'].append(passage_for(negative_ids[0]))
    data['label'].append(False)
  return data

In [84]:
df_trec_covid_corpus = pd.DataFrame(trec_covid_corpus['corpus'])
df_trec_covid_corpus.head()

Unnamed: 0,_id,title,text
0,ug7v899j,Clinical features of culture-proven Mycoplasma...,OBJECTIVE: This retrospective chart review des...
1,02tnwd4m,Nitric oxide: a pro-inflammatory mediator in l...,Inflammatory diseases of the respiratory tract...
2,ejv2xln0,Surfactant protein-D and pulmonary host defense,Surfactant protein-D (SP-D) participates in th...
3,2b73a28n,Role of endothelin-1 in lung disease,Endothelin-1 (ET-1) is a 21 amino acid peptide...
4,9785vg6d,Gene expression in epithelial cells in respons...,Respiratory syncytial virus (RSV) and pneumoni...


In [85]:
df_all = pd.DataFrame(format_data(df, df_trec_covid_corpus))
df_all

  0%|          | 0/22585 [00:00<?, ?it/s]

Unnamed: 0,query,label,passage
0,How can chatbots be designed to effectively sh...,True,Chatbots in the fight against the COVID-19 pan...
1,How can chatbots be designed to effectively sh...,False,You Need a Plan: A Stepwise Protocol for Opera...
2,What strategies can be used to encourage desir...,True,Chatbots in the fight against the COVID-19 pan...
3,What strategies can be used to encourage desir...,False,Using Thinkalouds to Understand Rule Learning ...
4,What are the risks associated with amplifying ...,True,Chatbots in the fight against the COVID-19 pan...
...,...,...,...
45165,Does the N-terminus domain of GBF1 have any ro...,False,A Self-stabilizing One-To-Many Node Disjoint P...
45166,How does the absence of p115 and Rab1b influen...,True,Poliovirus Replication Requires the N-terminus...
45167,How does the absence of p115 and Rab1b influen...,False,Time-varying human mobility patterns with meta...
45168,Does the Sec7 domain of GBF1 contribute to pol...,True,Poliovirus Replication Requires the N-terminus...


## Dataset Preparation

### Train/test split

In [87]:
df_train, df_test = train_test_split(df_all, 
                                     train_size = 0.9)

df_train.shape, df_test.shape

((40653, 3), (4517, 3))

### PyTorch Dataset, DataLoader & Trainer classes

In [88]:
max_seq_length = 512
batch_size = 16        # T4: 16, V100: 32, A100: 64
lr = 5e-5
epochs = 10

In [95]:
model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

optimizer = AdamW(model.parameters(), lr = lr)




In [96]:
# Adapted from Carísio's code
class Dataset(torch.utils.data.Dataset):
  """Cross-encoder training pairs built from a pandas DataFrame.

  The DataFrame must have ``query``, ``passage`` and ``label`` (bool or 0/1)
  columns. Each item contains:

  - the tokenizer output for ``"<query> [SEP] <passage>"``;
  - a float ``labels`` value;
  - ``token_type_ids`` that are 0 up to and including the first [SEP] token and 1 afterwards.

  An integer index returns one example; a slice returns lists for the selected rows.

  Tokenized items are cached by index, so later epochs skip tokenization.
  The base class is written as ``torch.utils.data.Dataset`` because this class
  shadows the name ``Dataset``. Re-running the cell would otherwise subclass the
  previous definition of this class.
  """

  def __init__(self, tokenizer, df, max_seq_length):
    self.max_seq_length = max_seq_length
    self.tokenizer = tokenizer

    # Pre-concatenate each query with its passage, separated by a literal [SEP]
    query_passage = df['query'] + ' [SEP] ' + df['passage']
    self.query_passage = query_passage.tolist()
    # Store labels as floats (True/False -> 1.0/0.0)
    self.labels = [float(x) for x in df.label.tolist()]

    # Tokenized items keyed by str(idx); filled lazily by __getitem__
    self.cache = {}

  def __len__(self):
    return len(self.query_passage)

  def get_token_type_ids(self, input_ids):
    """Return segment ids: 0 through the first [SEP] token, 1 for the rest.

    Raises ValueError if the [SEP] id does not occur in ``input_ids``.
    """
    # Prefer the tokenizer's own [SEP] id; 102 (BERT uncased) is the fallback
    sep_id = getattr(self.tokenizer, 'sep_token_id', None)
    if sep_id is None:
      sep_id = 102
    idx_sep = input_ids.index(sep_id) + 1
    # Padding positions (if any) also get 1 here; attention_mask already masks them
    return [0] * idx_sep + [1] * (len(input_ids) - idx_sep)

  def get_token_type_ids_from_slice(self, idx, matriz_input_ids):
    """Compute token_type_ids for one example, or for each row when ``idx`` is a slice."""
    if isinstance(idx, slice):
      # matriz_input_ids already holds exactly the sliced rows, so iterate over it
      # directly instead of re-applying the slice bounds (which overruns it).
      return [self.get_token_type_ids(row) for row in matriz_input_ids]
    return self.get_token_type_ids(matriz_input_ids)

  def get_input_ids_e_labels(self, idx):
    """Tokenize the pair(s) at ``idx`` and attach ``labels`` and ``token_type_ids``."""
    input_ids_e_labels = self.tokenizer(self.query_passage[idx],
                                        padding=True,
                                        truncation=True,
                                        max_length=self.max_seq_length)
    input_ids_e_labels['labels'] = self.labels[idx]

    input_ids_e_labels['token_type_ids'] = self.get_token_type_ids_from_slice(idx, input_ids_e_labels['input_ids'])

    return input_ids_e_labels

  def __getitem__(self, idx):
    # str(idx) makes slice indices usable as dict keys
    key = str(idx)
    # Tokenize only on a cache miss. dict.get(key, default) would evaluate the
    # default (i.e. tokenize) on every call and defeat the cache.
    if key not in self.cache:
      self.cache[key] = self.get_input_ids_e_labels(idx)
    return self.cache[key]
    

In [97]:
dataset_train = Dataset(tokenizer, df_train, max_seq_length)
dataset_val = Dataset(tokenizer, df_test, max_seq_length)

collate_fn = lambda batch: BatchEncoding(tokenizer.pad(batch, return_tensors = 'pt'))
dataloader_train = DataLoader(dataset_train, 
                              batch_size = batch_size, 
                              shuffle = False, 
                              collate_fn = collate_fn)
dataloader_val = DataLoader(dataset_val, 
                            batch_size = batch_size, 
                            shuffle = False, 
                            collate_fn = collate_fn)

## Train model

In [98]:
def evaluate(model, dataloader, set_name):
  """Print the mean batch loss and accuracy of ``model`` over ``dataloader``.

  Switches the model to eval mode and runs every batch without gradients on the
  module-level ``device``. A prediction counts as correct when
  ``round(sigmoid(logit))`` equals the batch label. Returns nothing.

  NOTE(review): the loss is whatever the model computes from ``labels``, while
  accuracy thresholds sigmoid(logit) at 0.5. Confirm this threshold matches the
  objective the model is actually trained with.
  """
  model.eval()
  batch_losses = []
  n_correct = 0
  with torch.no_grad():
    for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
      outputs = model(**batch.to(device))
      batch_losses.append(outputs.loss.cpu().item())
      # Single-logit head: squash with sigmoid, round to a 0/1 prediction
      predictions = torch.round(torch.sigmoid(outputs.logits)).squeeze()
      n_correct += (predictions == batch['labels']).sum().item()

  print(f'{set_name} loss: {mean(batch_losses):0.3f}; {set_name} accuracy: {n_correct / len(dataloader.dataset):0.3f}')

def automodel_train(model, optimizer, dataloader_train, dataloader_val, epoch_inicial, epochs):
  """Fine-tune ``model`` with a linear warmup/decay LR schedule, validating every epoch.

  Runs one validation pass, then epochs ``epoch_inicial`` .. ``epochs - 1``.
  After each epoch it prints the mean training loss, saves the model to
  ``{workdir}/inpars-model/<epoch + 1>/`` and validates again.

  Args:
    model: model whose output exposes ``.loss`` when called on a batch.
    optimizer: optimizer over ``model``'s parameters.
    dataloader_train: training batches (BatchEncoding-like, supporting ``.to``).
    dataloader_val: validation batches, passed to ``evaluate``.
    epoch_inicial: first epoch index to run (0 for a fresh start).
    epochs: one past the last epoch index to run.
  """
  # Size the schedule to the steps this call actually runs. Using
  # ``epochs * len(...)`` when resuming (epoch_inicial > 0) stops training
  # partway through the linear decay.
  num_training_steps = max(epochs - epoch_inicial, 0) * len(dataloader_train)
  # Warm up is important to stabilize training.
  num_warmup_steps = int(num_training_steps * 0.1)
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

  evaluate(model=model, dataloader=dataloader_val, set_name='Validation')

  # Training loop
  for epoch in tqdm(range(epoch_inicial, epochs), desc='Epochs'):
    model.train()
    train_losses = []
    for batch in tqdm(dataloader_train, mininterval=0.5, desc='Train', disable=False):
      optimizer.zero_grad()
      outputs = model(**batch.to(device))
      loss = outputs.loss
      loss.backward()
      optimizer.step()
      scheduler.step()
      train_losses.append(loss.cpu().item())

    print(f'Epoch: {epoch + 1} Training loss: {mean(train_losses):0.2f}')
    model.save_pretrained(f'{workdir}/inpars-model/{epoch+1}/')
    evaluate(model=model, dataloader=dataloader_val, set_name='Validation')
    print('---------------------------------------------------------------------')

In [None]:
%%time
automodel_train(model, optimizer, dataloader_train, dataloader_val, 0, epochs = 20)

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

Validation:   0%|          | 0/283 [00:00<?, ?it/s]

Train:   0%|          | 0/2541 [00:00<?, ?it/s]

## BM25 + Rerank
Largely based on Mirelle's code.

In [None]:
%%shell
cd /content/ &&  git clone --recurse-submodules https://github.com/castorini/pyserini.git
cd pyserini
cd tools/eval && tar xvfz trec_eval.9.0.4.tar.gz && cd trec_eval.9.0.4 && make && cd ../../..
cd tools/eval/ndeval && make && cd ../../..

In [None]:
!pip install pyserini faiss intel-openmp nltk --quiet
!apt install libomp-dev
%cd /content
!rm -rf pygaggle && pip uninstall -y pygaggle
!git clone  --recursive https://github.com/castorini/pygaggle.git
%cd pygaggle
! pip install --editable . --quiet
! pip install gensim==4.2.0 jsonlines --quiet
! pip install faiss-cpu --no-cache --quiet

In [None]:
!cd /content/ && ls

gdrive	pygaggle  pyserini  sample_data


In [None]:
# Export the corpus as JSON Lines: one {"id", "contents"} object per line, the
# format Pyserini's JsonCollection indexes. Create the target directory if missing.
corpus_dir = '/content/trec-covid-corpus'
os.makedirs(corpus_dir, exist_ok=True)

corpus_split = trec_covid_corpus['corpus']
data_out = [
    {'id': doc_id, 'contents': title + ' ' + text}
    for doc_id, title, text in zip(corpus_split['_id'], corpus_split['title'], corpus_split['text'])
]
with open(f'{corpus_dir}/corpus.jsonl', 'w') as fout:
    for doc in data_out:
        fout.write(json.dumps(doc) + '\n')

In [None]:
! python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input /content/trec-covid-corpus \
  --index /content/Index_BM25 \
  --generator DefaultLuceneDocumentGenerator \
  --threads 9 \
  --storePositions --storeDocvectors --storeRaw

2023-05-04 01:25:44,149 INFO  [main] index.IndexCollection (IndexCollection.java:380) - Setting log level to INFO
2023-05-04 01:25:44,151 INFO  [main] index.IndexCollection (IndexCollection.java:383) - Starting indexer...
2023-05-04 01:25:44,151 INFO  [main] index.IndexCollection (IndexCollection.java:385) - DocumentCollection path: /content/trec-covid-corpus
2023-05-04 01:25:44,151 INFO  [main] index.IndexCollection (IndexCollection.java:386) - CollectionClass: JsonCollection
2023-05-04 01:25:44,152 INFO  [main] index.IndexCollection (IndexCollection.java:387) - Generator: DefaultLuceneDocumentGenerator
2023-05-04 01:25:44,152 INFO  [main] index.IndexCollection (IndexCollection.java:388) - Threads: 9
2023-05-04 01:25:44,153 INFO  [main] index.IndexCollection (IndexCollection.java:389) - Language: en
2023-05-04 01:25:44,153 INFO  [main] index.IndexCollection (IndexCollection.java:390) - Stemmer: porter
2023-05-04 01:25:44,154 INFO  [main] index.IndexCollection (IndexCollection.java:391

In [None]:
from pyserini.search.lucene import LuceneSearcher
from pygaggle.rerank.base import hits_to_texts


def get_results(path_out_bm25='/content/output_bm25', qrys=None, k=100, searcher=None):
  """Retrieve the top-``k`` hits for every query.

  Args:
    path_out_bm25: Path of the Lucene index opened with ``LuceneSearcher``
      when no ``searcher`` is given.
    qrys: Mapping of query id -> query text. ``None`` means "no queries".
    k: Number of hits to retrieve per query.
    searcher: Optional object exposing ``search(query, k)`` (e.g. an already
      open ``LuceneSearcher``). This lets callers reuse an index instead of
      reopening it.

  Returns:
    Dict keyed by ``str(query_id)``. Each value has ``'query'`` (the text),
    ``'founds'`` (doc ids in rank order) and ``'scores'`` (aligned scores).
  """
  qrys = {} if qrys is None else qrys
  if searcher is None:
    searcher = LuceneSearcher(path_out_bm25)

  results = {}
  for key, value in qrys.items():
    # Search once and reuse the hits for both ids and scores (one index pass per query).
    hits = searcher.search(value, k)
    results[str(key)] = {
        'query': value,
        'founds': [hit.docid for hit in hits],
        'scores': [hit.score for hit in hits],
    }
  return results


In [None]:
# Queries as a DataFrame, keeping only the id and the natural-language text.
df_queries = trec_covid_queries['queries'].to_pandas().drop(columns='title')
df_queries.head()

Unnamed: 0,_id,text
0,1,what is the origin of COVID-19
1,2,how does the coronavirus respond to changes in...
2,3,will SARS-CoV2 infected people develop immunit...
3,4,what causes death from Covid-19?
4,5,what drugs have been active against SARS-CoV o...


In [None]:
# (number of queries, number of columns)
(len(df_queries.index), len(df_queries.columns))

(50, 2)

In [None]:
# BM25 top-1000 for every query, keyed by query id.
query_texts = dict(zip(df_queries['_id'], df_queries['text']))
results = get_results(path_out_bm25 = '/content/Index_BM25',
                      qrys = query_texts,
                      k = 1000)

In [None]:
# Sanity check: number of queries answered and hits retrieved for query '1'.
n_queries = len(results)
n_hits_query_1 = len(results['1']['founds'])
print('Total qrels: ', n_queries, ' --Total docs ids founds: ', n_hits_query_1)


Total qrels:  50  --Total docs ids founds:  1000


In [None]:
# Download the BEIR TREC-COVID bundle (corpus, queries, qrels).
# -O always writes trec-covid.zip, so a re-run replaces it instead of saving
# trec-covid.zip.1 next to a stale copy that unzip would then extract.
!wget -O trec-covid.zip https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip
# -o overwrites previously extracted files without an interactive prompt.
!unzip -o trec-covid.zip


--2023-05-04 02:07:32--  https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip
Resolving public.ukp.informatik.tu-darmstadt.de (public.ukp.informatik.tu-darmstadt.de)... 130.83.167.186
Connecting to public.ukp.informatik.tu-darmstadt.de (public.ukp.informatik.tu-darmstadt.de)|130.83.167.186|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 73876720 (70M) [application/zip]
Saving to: ‘trec-covid.zip.1’


2023-05-04 02:07:33 (45.9 MB/s) - ‘trec-covid.zip.1’ saved [73876720/73876720]

Archive:  trec-covid.zip
replace trec-covid/qrels/test.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: trec-covid/qrels/test.tsv  
replace trec-covid/corpus.jsonl? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: trec-covid/corpus.jsonl  
  inflating: trec-covid/queries.jsonl  


In [None]:
def load_file(path):
  """Load a BEIR-style qrels TSV and group the judgments by query.

  Args:
    path: Tab-separated file with a header row and the columns ``query-id``,
      ``corpus-id`` and ``score``.

  Returns:
    Dict mapping each query id (as ``str``) to
    ``{'doc_ids': [...], 'rating': [...]}``. Both lists keep file order and
    are aligned index by index.
  """
  # Read ids as text so numeric-looking ids (e.g. MS MARCO passage ids) stay
  # comparable with the string docids returned by the searcher, and are never
  # coerced to int/float by pandas.
  qrels_file = pd.read_csv(path, sep='\t', dtype={'query-id': str, 'corpus-id': str})

  qrls = {}
  # Column-wise zip avoids iterrows' per-row Series and its dtype upcasting.
  for qid, doc_id, score in zip(qrels_file['query-id'], qrels_file['corpus-id'], qrels_file['score']):
    entry = qrls.setdefault(str(qid), {'doc_ids': [], 'rating': []})
    entry['doc_ids'].append(doc_id)
    entry['rating'].append(score)
  return qrls

In [None]:
# Qrels from the BEIR bundle extracted into the current working directory.
qrels_path = 'trec-covid/qrels/test.tsv'
qrels = load_file(qrels_path)

In [None]:
from rank_eval import Qrels, Run, evaluate

# rank_eval inputs: judgments and BM25 run over the same query ids, in the same order.
query_ids = list(results.keys())

qrels_ = Qrels()
qrels_.add_multi(q_ids=query_ids,
                 doc_ids=[qrels[qid]['doc_ids'] for qid in query_ids],
                 scores=[qrels[qid]['rating'] for qid in query_ids])

run = Run()
run.add_multi(
    q_ids=query_ids,
    doc_ids=[results[qid]["founds"] for qid in query_ids],
    scores=[results[qid]["scores"] for qid in query_ids],
)


In [None]:
bm25_metrics = ["mrr", "ndcg@10"]
evaluate(qrels_, run, bm25_metrics)  # base bm25

{'mrr': 0.8528571428571429, 'ndcg@10': 0.5946917010118077}

## Reranking with BERT

In [32]:
import random
import torch
import torch.nn.functional as F
import numpy as np

from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import IndexReader
from pyserini.search import get_topics, get_qrels
from torch.utils import data

import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup, BatchEncoding

from torch import nn
from torch import optim
from tqdm.auto import tqdm
from statistics import mean

In [33]:
# Default token budget for one concatenated "query, document" string.
max_length = 356

class MSMARCODataset(torch.utils.data.Dataset):
    """Map-style dataset of (query, document, label) examples for a cross-encoder.

    Each item tokenizes ``query + ', ' + document`` as a single text, padded and
    truncated to ``max_lenght`` tokens. It returns the ``input_ids`` and
    ``attention_mask`` tensors together with an integer ``labels`` value.

    Args:
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer), invoked as
            ``tokenizer(text, max_length=..., truncation=True,
            padding="max_length", return_tensors='pt')``.
        query: Sequence of query strings; its length is the dataset length.
        documents: Sequence of document strings aligned with ``query``.
        targets: Sequence of labels aligned with ``query``; each is passed
            through ``int``.
        max_lenght: Fixed token length per example. The misspelled name is kept
            so existing keyword callers keep working.
        device: Device the returned tensors are moved to. ``None`` (default)
            falls back to a notebook-level ``device`` variable, which is not
            defined in this part of the notebook (assumed to exist elsewhere).
    """

    def __init__(self, tokenizer, query, documents, targets, max_lenght = max_length, device = None):
        self.tokenizer = tokenizer
        self.query = query
        self.documents = documents
        self.targets = targets
        self.max_lenght = max_lenght
        self.device = device

    def __len__(self):
        return len(self.query)

    def __getitem__(self, idx):
        query_doc_text = self.query[idx] + ', ' + self.documents[idx]
        query_doc_token = self.tokenizer(query_doc_text,
                                         max_length = self.max_lenght,
                                         truncation = True,
                                         padding = "max_length",
                                         return_tensors = 'pt')

        # Resolved lazily: the global is only needed when no device was passed in.
        target_device = self.device if self.device is not None else device

        # squeeze(0) drops only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence axis.
        # NOTE(review): collate_fn pads batches with tokenizer.pad, which (per the
        # transformers docs) converts tensors to Python objects and loses their
        # device, so this move may be undone downstream — confirm.
        return {'input_ids': query_doc_token['input_ids'].squeeze(0).long().to(target_device),
                'attention_mask': query_doc_token['attention_mask'].squeeze(0).long().to(target_device),
                'labels': int(self.targets[idx])}

In [34]:
# MiniLM checkpoint whose tokenizer encodes every query/document pair below.
model_name = 'microsoft/MiniLM-L12-H384-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_name)

### Test dataset

In [35]:
# -O always writes test.tsv, so a re-run replaces it instead of saving test.tsv.1,
# test.tsv.2, ... while the next cell moves the stale test.tsv.
!wget -O test.tsv https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-05-10 19:39:48--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 13.249.85.16, 13.249.85.92, 13.249.85.127, ...
Connecting to huggingface.co (huggingface.co)|13.249.85.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv.2’


2023-05-10 19:39:49 (15.0 MB/s) - ‘test.tsv.2’ saved [980831/980831]



In [36]:
# Ensure the target directory exists; otherwise mv would rename test.tsv to a file called "trec-covid".
!mkdir -p "{workdir}/trec-covid"
!mv test.tsv "{workdir}/trec-covid"

In [37]:
# Test-split relevance judgments as column lists (query, docid, rel) plus a constant "q0" column.
qrel_path = f"{workdir}/trec-covid/test.tsv"
qrel = (
    pd.read_csv(qrel_path, sep = "\t", header = 0, names = ["query", "docid", "rel"])
    .assign(q0 = "q0")
    .to_dict(orient = "list")
)

### Queries

In [38]:
# -O always writes queries.jsonl.gz (no queries.jsonl.gz.1 duplicates on re-run).
!wget -O queries.jsonl.gz https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz

--2023-05-10 19:39:49--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz
Resolving huggingface.co (huggingface.co)... 13.249.85.16, 13.249.85.92, 13.249.85.127, ...
Connecting to huggingface.co (huggingface.co)|13.249.85.16|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/9eadcc2cdf140addc9dae83648bb2c6611f5e4b66eaed7475fa5a0ca48eda371?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27queries.jsonl.gz%3B+filename%3D%22queries.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1684006789&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvOWVhZGNjMmNkZjE0MGFkZGM5ZGFlODM2NDhiYjJjNjYxMWY1ZTRiNjZlYWVkNzQ3NWZhNWEwY2E0OGVkYTM3MT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9u

In [39]:
!mkdir -p "{workdir}/trec-covid"
!mv queries.jsonl.gz "{workdir}/trec-covid"
# -f overwrites an existing queries.jsonl instead of stopping at gunzip's interactive prompt.
!gunzip -f "{workdir}/trec-covid/queries.jsonl.gz"

gzip: /content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks/trec-covid/queries.jsonl already exists; do you wish to overwrite (y or n)? y


In [63]:
# One JSON object per line: _id, text, metadata.
queries_path = f"{workdir}/trec-covid/queries.jsonl"
queries = pd.read_json(queries_path, lines = True)
queries.head()

Unnamed: 0,_id,text,metadata
0,1,what is the origin of COVID-19,"{'query': 'coronavirus origin', 'narrative': '..."
1,2,how does the coronavirus respond to changes in...,{'query': 'coronavirus response to weather cha...
2,3,will SARS-CoV2 infected people develop immunit...,"{'query': 'coronavirus immunity', 'narrative':..."
3,4,what causes death from Covid-19?,{'query': 'how do people die from the coronavi...
4,5,what drugs have been active against SARS-CoV o...,"{'query': 'animal models of COVID-19', 'narrat..."


### Corpus

In [44]:
# -O always writes corpus.jsonl.gz (no corpus.jsonl.gz.1 duplicates on re-run).
!wget -O corpus.jsonl.gz https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz

--2023-05-10 19:44:23--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz
Resolving huggingface.co (huggingface.co)... 18.160.249.31, 18.160.249.9, 18.160.249.70, ...
Connecting to huggingface.co (huggingface.co)|18.160.249.31|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/e9e97686e3138eaff989f67c04cd32e8f8f4c0d4857187e3f180275b23e24e85?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27corpus.jsonl.gz%3B+filename%3D%22corpus.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1684007063&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvZTllOTc2ODZlMzEzOGVhZmY5ODlmNjdjMDRjZDMyZThmOGY0YzBkNDg1NzE4N2UzZjE4MDI3NWIyM2UyNGU4NT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uP

In [45]:
!mkdir -p "{workdir}/trec-covid"
!mv corpus.jsonl.gz "{workdir}/trec-covid"
# -f overwrites an existing corpus.jsonl instead of stopping at gunzip's interactive prompt.
!gunzip -f "{workdir}/trec-covid/corpus.jsonl.gz"

gzip: /content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks/trec-covid/corpus.jsonl already exists; do you wish to overwrite (y or n)? y


In [64]:
# One JSON object per line: _id, title, text, metadata.
corpus_path = f"{workdir}/trec-covid/corpus.jsonl"
corpus = pd.read_json(corpus_path, lines = True)
corpus.head()

Unnamed: 0,_id,title,text,metadata
0,ug7v899j,Clinical features of culture-proven Mycoplasma...,OBJECTIVE: This retrospective chart review des...,{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
1,02tnwd4m,Nitric oxide: a pro-inflammatory mediator in l...,Inflammatory diseases of the respiratory tract...,{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
2,ejv2xln0,Surfactant protein-D and pulmonary host defense,Surfactant protein-D (SP-D) participates in th...,{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
3,2b73a28n,Role of endothelin-1 in lung disease,Endothelin-1 (ET-1) is a 21 amino acid peptide...,{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
4,9785vg6d,Gene expression in epithelial cells in respons...,Respiratory syncytial virus (RSV) and pneumoni...,{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...


In [None]:
def collate_fn(batch):
    """Pad a list of dataset items into one ``BatchEncoding`` of PyTorch tensors."""
    return BatchEncoding(tokenizer.pad(batch, return_tensors = 'pt'))


def _make_dataset(frame):
    """Build an ``MSMARCODataset`` from a frame with ``query``, ``document`` and ``relevance`` columns."""
    return MSMARCODataset(
        tokenizer,
        frame['query'].to_list(),
        frame['document'].to_list(),
        frame['relevance'].to_list(),
    )


BATCH_SIZE = 32

# Convert examples to PyTorch Datasets.
# NOTE(review): df_train / df_valid / df_test are not built in this part of the
# notebook; confirm they are defined before this cell runs.
train_dataset = _make_dataset(df_train)
valid_dataset = _make_dataset(df_valid)
test_dataset = _make_dataset(df_test)

# Convert datasets to PyTorch DataLoaders. Only the training split is shuffled;
# evaluation splits keep a fixed order so validation runs are reproducible.
dataloader_train = torch.utils.data.DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_fn)
dataloader_valid = torch.utils.data.DataLoader(valid_dataset, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_fn)
dataloader_test = torch.utils.data.DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_fn)