# 6th Pipeline: Pyserini BM25 + SPLADE Reranking + MonoT5 Reranking

Author: Monique Monteiro (moniquelouise@gmail.com)

Inspired by Leonardo Avila's notebook (https://colab.research.google.com/drive/1o-aMaptESHNLH9w9wUcO5Wz0N_VCChlG?usp=sharing#scrollTo=fYzK8SB9QG7l)

In [1]:
%%shell
pip install transformers -q
pip install datasets -q
pip install evaluate -q
pip install trectools -q
pip install torch -q
pip install faiss-cpu -q
pip install pyserini -q
pip install beir -q
pip install sentence-transformers -q
pip install git+https://github.com/naver/splade.git -q
pip install git+https://github.com/zetaalphavector/InPars.git -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m88.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m16.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m50.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━



In [2]:
from google.colab import drive
# Mount Google Drive under /content/gdrive; the dataset and index paths used
# later in this notebook (see main_dir) live below this mount point.
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [3]:
import os
import json
import numpy as np
import pandas as pd
import random
import torch
from time import time
import collections
import evaluate
import datasets
import shutil
import pickle
import numba
import inpars

from collections import defaultdict, Counter
from datasets import load_dataset
from tqdm import tqdm
from operator import itemgetter

from torch import nn
from torch import optim
from transformers import BatchEncoding, get_linear_schedule_with_warmup

from pyserini.search.lucene import LuceneSearcher

In [4]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade
from torch.utils.data import Dataset, DataLoader

In [5]:
from inpars import rerank, utils

In [6]:
# Checkpoint name handed to rerank.Reranker.from_pretrained in the next cell.
model_name ='castorini/monot5-3b-med-msmarco'
# The three values below are forwarded unchanged to the reranker constructor.
fp16 = True
torch_compile = True
batch_size = 128
# NOTE(review): no visible cell reads this global. The pipeline function takes
# its own top_k parameter and passes top_k=50 to the MonoT5 stage, so this
# value looks unused -- confirm before relying on it.
top_k = 30

In [7]:
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the MonoT5 reranker from the configuration values set in the
# previous cell (checkpoint name, batch size, fp16 and torch_compile flags).
model_t5 = rerank.Reranker.from_pretrained(
    model_name_or_path=model_name,
    batch_size=batch_size,
    fp16=fp16,
    device=device,
    torch_compile=torch_compile,
)

Downloading:   0%|          | 0.00/1.16k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.6G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

In [8]:
# Google Drive folder that holds the trec-covid files (queries, corpus, qrels)
# and the pickled inverted index read by later cells.
main_dir = "/content/gdrive/MyDrive/Unicamp-aula-10"

In [9]:
!pip install jsonlines

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jsonlines
  Downloading jsonlines-3.1.0-py3-none-any.whl (8.6 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-3.1.0


## NDCG@10 Evaluation Code

In [10]:
import jsonlines

# Read the TREC-COVID queries into two parallel lists: query_ids[n] is the
# identifier of the query whose text is query_texts[n].
query_ids = []
query_texts = []

with jsonlines.open(f"{main_dir}/trec-covid/queries.jsonl") as reader:
  # Loop variable deliberately not called `id`/`text`: binding `id` at
  # notebook-global scope would shadow the builtin for every later cell.
  for query_record in reader:
    query_ids.append(query_record["_id"])
    query_texts.append(query_record["text"])
     

In [11]:

import pandas as pd

# Relevance judgments as a dict of column lists: the first line of the TSV is
# skipped, the three columns are named query/docid/rel, and a constant "q0"
# column is appended last.
qrel = (
    pd.read_csv(
        f"{main_dir}/trec-covid/test.tsv",
        sep="\t",
        header=None,
        skiprows=1,
        names=["query", "docid", "rel"],
    )
    .assign(q0="q0")
    .to_dict(orient="list")
)

In [12]:

from evaluate import load

def eval_ndcg10(run):
  """Return the NDCG@10 of `run` against the notebook-level `qrel` judgments.

  Loads the "trec_eval" metric, calls its `compute` with `run` as the single
  prediction and `qrel` as the single reference, and returns the value stored
  under the 'NDCG@10' key of the result.
  """
  metric = load("trec_eval")
  return metric.compute(predictions=[run], references=[qrel])['NDCG@10']

## Pyserini BM25 Search Code

In [13]:
from pyserini.search.lucene import LuceneSearcher

In [14]:
import functools


@functools.lru_cache(maxsize=None)
def _get_bm25_searcher(index_name):
  """Create (once per index name) the LuceneSearcher used for BM25 retrieval."""
  if index_name == 'beir-v1.0.0-trec-covid.flat':
    return LuceneSearcher.from_prebuilt_index(index_name)
  # Any other value is treated as the path of a local Lucene index.
  return LuceneSearcher(index_name)


def search_with_bm25(query,k = 1000, index_name='beir-v1.0.0-trec-covid.flat'):
  """Run a BM25 search and return the hits.

  Args:
    query: query text handed to the searcher.
    k: maximum number of hits to request.
    index_name: 'beir-v1.0.0-trec-covid.flat' selects the prebuilt index;
      any other value is passed to LuceneSearcher as an index location.

  Returns:
    Whatever `LuceneSearcher.search(query, k)` returns.
  """
  # The searcher is cached per index name: building a new one for every query
  # is repeated work inside the per-query loop.
  return _get_bm25_searcher(index_name).search(query, k)

## SPLADE

In [15]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [16]:
from splade.models.transformer_rep import Splade

In [17]:
# Candidate SPLADE checkpoints. Only model_name_3 is loaded in this cell;
# model_name_1 and model_name_2 are not referenced by any visible cell.
model_name_1 = 'naver/splade_v2_distil' 
model_name_2 = 'naver/splade-cocondenser-selfdistil'
model_name_3 = 'naver/splade-cocondenser-ensembledistil' 

# Tokenizer and SPLADE model for the selected checkpoint; the model is moved to
# `device` and switched to eval mode.
tokenizer_3 = AutoTokenizer.from_pretrained(model_name_3)
model_3 = Splade(model_name_3, agg="max").to(device)
model_3.eval()

Downloading:   0%|          | 0.00/466 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/670 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

In [18]:
# Aliases used by the rest of the notebook (they become the default arguments
# of vectorize_to_sparse), so the checkpoint can be swapped in one place.
model_splade = model_3
tokenizer_splade = tokenizer_3

In [19]:
# Truncation limit passed to the SPLADE tokenizer in vectorize_to_sparse.
max_length=256

In [20]:
from torch.nn.functional import relu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def vectorize_to_sparse(text, tokenizer=tokenizer_splade, model=model_splade, remove_special_tokens=False):
  """Encode `text` with a SPLADE model and return its representation as a sparse tensor.

  Args:
    text: input passed to the tokenizer (truncated to the global `max_length`).
    tokenizer: tokenizer to use; defaults to the notebook-level SPLADE tokenizer.
    model: SPLADE model to use; defaults to the notebook-level SPLADE model.
    remove_special_tokens: accepted for interface compatibility; the body does
      not use it.

  Returns:
    The model's "q_rep" output with dims 0 and 1 transposed, squeezed, and
    converted with `to_sparse()`. Presumably a 1-D vocabulary-sized vector for
    a single input text -- TODO confirm against the Splade implementation.
  """
  # Kudos to Marcos Piau
  tokenized_text = tokenizer(text, max_length=max_length, truncation=True,
                             return_tensors='pt',
                             #return_special_tokens_mask=True
                             ).to(device)

  model.to(device)
  model.eval()

  # Mixed precision is only meaningful on CUDA. Requesting CUDA autocast with
  # enabled=True while running on CPU makes torch warn and disable it anyway,
  # so enable it explicitly only when the selected device is a GPU.
  use_amp = device.type == 'cuda'
  with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
    with torch.no_grad():
      output = model(q_kwargs=tokenized_text)["q_rep"]

  return torch.transpose(output, 0, 1).squeeze().to_sparse()

In [21]:
import jsonlines

# Load the corpus into parallel lists plus an id -> text lookup. The text of a
# passage is its title and body joined by a single space.
passage_ids = []
passage_texts = []
id_to_text = dict()

with jsonlines.open(f"{main_dir}/trec-covid/corpus.jsonl") as reader:
  # Loop variables deliberately avoid the name `id`, which would shadow the
  # builtin for every later cell.
  for passage in reader:
    passage_id = passage["_id"]
    passage_text = passage["title"] + ' ' + passage["text"]
    passage_ids.append(passage_id)
    passage_texts.append(passage_text)
    id_to_text[passage_id] = passage_text

# Sort the passages by text length. Both lists are reordered through the same
# permutation so position n always refers to the same passage in each of them.
# Sorting ids and texts independently only stays aligned while every _id is
# unique. For unique ids the resulting order is identical to independent
# stable sorts, so positions stored in an existing pickled index remain valid.
passage_indices = sorted(range(len(passage_texts)),
                         key=lambda pos: len(passage_texts[pos]))
passage_ids = [passage_ids[pos] for pos in passage_indices]
passage_texts = [passage_texts[pos] for pos in passage_indices]

In [22]:
def search_by_query_vector_in_inverted_index(inverted_index, query_vec, k, ids=None,
                                             doc_id_lookup=None):
  """Score documents against a sparse query vector using an inverted index.

  Args:
    inverted_index: mapping token_id -> {"docs": positions, "wj": weights},
      two parallel sequences per token.
    query_vec: sparse tensor; its first index row holds the token ids and its
      values hold the query weights.
    k: maximum number of (doc_id, score) pairs to return.
    ids: optional container of document ids; when given, only documents whose
      id is in it are scored.
    doc_id_lookup: optional sequence mapping a posting position to a document
      id. Defaults to the notebook-level `passage_ids` list.

  Returns:
    A dict doc_id -> score with at most `k` entries, ordered by descending
    score, where score is the sum over shared tokens of
    query weight * document weight.
  """
  if doc_id_lookup is None:
    doc_id_lookup = passage_ids

  query_vec = query_vec.coalesce()
  # Convert once to plain Python numbers instead of calling .item() per element.
  token_ids = query_vec.indices()[0].tolist()
  query_weights = query_vec.values().tolist()

  doc_scores = defaultdict(float)

  for token_id, query_weight in zip(token_ids, query_weights):
    postings = inverted_index.get(token_id)
    if postings is None:
      continue

    for position, doc_weight in zip(postings["docs"], postings["wj"]):
      doc_id = doc_id_lookup[position]
      if ids is None or doc_id in ids:
        doc_scores[doc_id] += query_weight * doc_weight

  ranked = sorted(doc_scores.items(), key=lambda pair: pair[1], reverse=True)
  return dict(ranked[:k])

In [23]:
import array
import pandas as pd
from collections import defaultdict
from collections import Counter
import pickle
import os

def load_or_build_inverted_index(index_path = f"{main_dir}/index.pickle", 
                                 docs_matrix=None):
  """Load the pickled inverted index from `index_path`, or build and save it.

  Args:
    index_path: location of the pickle file.
    docs_matrix: sparse document matrix, one row per passage, in the same
      order as the notebook-level `passage_ids`. Only needed when the pickle
      does not exist yet.

  Returns:
    A dict {"inverted_index": {token_id: {"docs": array("L"), "wj": array("f")}}}
    where "docs" holds row positions and "wj" the matching weights.

  Raises:
    ValueError: if the index has to be built but `docs_matrix` is missing, or
      it has more rows than there are passage ids.
  """
  if os.path.exists(index_path):
    with open(index_path, "rb") as f:
      print("Loading index...")
      # NOTE(security): pickle.load can execute arbitrary code; only point
      # index_path at files written by this notebook.
      return pickle.load(f)

  if docs_matrix is None:
    raise ValueError(
        f"No index found at {index_path!r} and no docs_matrix was given to build one.")

  num_docs = docs_matrix.shape[0]
  # Stored positions are later resolved through passage_ids, so every row
  # must have a corresponding id.
  if num_docs > len(passage_ids):
    raise ValueError(
        f"docs_matrix has {num_docs} rows but only {len(passage_ids)} passage ids exist.")

  print("Building inverted index...")
  inverted_index = dict()

  for idx in tqdm(range(num_docs)):
    doc_vec = docs_matrix[idx].coalesce()
    token_ids = doc_vec.indices()[0].tolist()
    weights = doc_vec.values().tolist()

    for token_id, wj in zip(token_ids, weights):
      postings = inverted_index.get(token_id)
      if postings is None:
        # One posting list per token: parallel arrays of row positions and weights.
        postings = {"docs": array.array("L"), "wj": array.array("f")}
        inverted_index[token_id] = postings
      postings["docs"].append(idx)
      postings["wj"].append(wj)

  index = {"inverted_index": inverted_index}

  with open(index_path, "wb") as f:
    pickle.dump(index, f)

  return index

In [24]:
from time import perf_counter

# Time how long it takes to obtain the inverted index (loaded from disk when
# the pickle exists, built otherwise). perf_counter is imported locally on
# purpose: the name `time` is a function at this point but is rebound to the
# `time` module by a later cell, so calling `time()` here fails when the cell
# is re-run after that one.
start = perf_counter()
index = load_or_build_inverted_index()
end = perf_counter()
print("Time spent to build inverted index = ", end - start)

Loading index...
Time spent to build inverted index =  3.1887118816375732


In [25]:
# Unwrap the token_id -> postings mapping used by the search functions.
inverted_index = index["inverted_index"]

## Pipeline

In [26]:
# Candidate cutoff for the pipeline; the query-loop cell assigns the same
# value again before running.
k=1000

In [27]:
import time

def search_bm25_splade_monot5(query_id, query, top_k=1000, 
                       index_name='beir-v1.0.0-trec-covid.flat',
                       rerank_top_k=50):
  """Three-stage retrieval for one query: BM25, then SPLADE, then MonoT5.

  Args:
    query_id: identifier written to the intermediate run file and used as the
      topic key for the reranker.
    query: query text.
    top_k: number of candidates kept by the BM25 and SPLADE stages (each stage
      requests top_k + 1).
    index_name: BM25 index; the prebuilt BEIR index stores document ids under
      '_id', any other index is expected to store them under 'id'.
    rerank_top_k: number of top SPLADE candidates rescored by MonoT5.

  Returns:
    A dict doc_id -> float score, in the order the reranked run file lists
    them. Empty when no candidate survives the first two stages.

  Side effects:
    Writes "run_bm25_splade.csv" and "run.monot5.3b.msmarco.10k.txt" in the
    current working directory, overwriting them on every call.
  """
  # 1st step: BM25 retrieval; its hits define the candidate set.
  bm25_hits = search_with_bm25(query, k=top_k+1, index_name=index_name)

  id_field = '_id' if index_name == 'beir-v1.0.0-trec-covid.flat' else 'id'
  ids = {json.loads(bm25_hits[i].raw)[id_field] for i in range(len(bm25_hits))}

  # 2nd step: SPLADE scoring restricted to the BM25 candidates. The cutoff
  # follows the top_k argument rather than a notebook-level global, so the
  # function behaves the same regardless of which cells ran before it.
  query_embedding = vectorize_to_sparse(query, tokenizer=tokenizer_splade, model=model_splade)
  doc_scores = search_by_query_vector_in_inverted_index(inverted_index, query_embedding, top_k+1,
                                                        ids=ids)

  if not doc_scores:
    # Nothing to rerank; avoid handing an empty run file to the reranker.
    return dict()

  input_run = "run_bm25_splade.csv"
  with open(input_run, 'w') as f:
    for rank, (doc_id, score) in enumerate(doc_scores.items(), start=1):
      f.write(f'{query_id} Q0 {doc_id} {rank} {score} "bm25_splade"\n')

  run = utils.TRECRun(input_run)
  topics = {query_id: query}

  # 3rd step: MonoT5 reranking. Measurements recorded by the author:
  # BM25(1000) + SPLADE(1000) + top_k = 1000 -> ndcg@10 = 0.7842, 21.1 s per query
  # BM25(1000) + SPLADE(1000) + top_k = 100  -> ndcg@10 = 0.8098, 3.02 s per query
  # BM25(1000) + SPLADE(1000) + top_k = 50   -> ndcg@10 = 0.8128, 2.57 s per query
  run.rerank(model_t5, topics, id_to_text, top_k=rerank_top_k)
  output_run = "run.monot5.3b.msmarco.10k.txt"
  run.save(output_run)

  reranked_scores = dict()
  with open(output_run, 'r') as f:
    for line in f:
      fields = line.strip().split()
      if not fields:
        continue
      # Convert to float: kept as text, scores would be compared
      # lexicographically by any downstream sort.
      reranked_scores[fields[2]] = float(fields[4])

  return reranked_scores

In [28]:
from collections import defaultdict

# Run the full pipeline for every query and collect the results as a dict of
# parallel column lists (query, docid, score, q0, rank, system), the layout
# handed to eval_ndcg10.
run_bm25_splade_monot5 = defaultdict(list)

start = time.time()
k=1000

# zip() has no length, so give tqdm the total explicitly; otherwise it only
# shows a running counter ("0it") instead of a progress bar.
for query_id, query in tqdm(zip(query_ids, query_texts), total=len(query_ids)):
  doc_scores = search_bm25_splade_monot5(query_id, query)
  n = len(doc_scores)
  run_bm25_splade_monot5["query"].extend([query_id] * n)
  run_bm25_splade_monot5["docid"].extend(doc_scores.keys())
  run_bm25_splade_monot5["score"].extend(doc_scores.values())
  run_bm25_splade_monot5["q0"].extend(["q0"] * n)
  # Ranks follow the iteration order of the returned dict.
  run_bm25_splade_monot5["rank"].extend(range(1, n + 1))
  run_bm25_splade_monot5["system"].extend(["doc2query_bm25_splade_inpars"] * n)

end = time.time()
print("Time spent = ", end - start)
# Guard against an empty query list, which would otherwise divide by zero.
if query_ids:
  print("Time spent by query = ", (end - start)/len(query_ids))

0it [00:00, ?it/s]

Downloading index at https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz...



lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz: 0.00B [00:00, ?B/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   0%|          | 8.00k/216M [00:00<3:03:41, 20.5kB/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   0%|          | 48.0k/216M [00:00<32:04, 118kB/s]   [A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   0%|          | 224k/216M [00:00<07:11, 524kB/s] [A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   0%|          | 992k/216M [00:06<01:39, 2.27MB/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   0%|          | 0.98M/216M [00:06<32:04, 117kB/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   2%|▏         | 5.17M/216M [00:06<03:06, 1.19MB/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   2%|▏         | 5.25M/216M [00:06<03:09, 1.16MB/s][A
lucene-index.beir-v1.0.0-trec-covid.flat.20221116.505594.tar.gz:   3%|

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

1it [00:37, 37.86s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

2it [00:39, 16.74s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

3it [00:41,  9.96s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

4it [00:43,  6.86s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

5it [00:45,  4.98s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

6it [00:47,  4.07s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

7it [00:49,  3.42s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

8it [00:51,  2.84s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

9it [00:53,  2.68s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

10it [00:55,  2.49s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

11it [00:57,  2.21s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

12it [00:59,  2.29s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

13it [01:01,  2.14s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

14it [01:03,  2.15s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

15it [01:05,  2.11s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

16it [01:07,  2.07s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

17it [01:09,  2.07s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

18it [01:11,  2.04s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

19it [01:13,  1.97s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

20it [01:15,  2.05s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

21it [01:18,  2.30s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

22it [01:20,  2.18s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

23it [01:22,  2.17s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

24it [01:24,  2.12s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

25it [01:26,  1.96s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

26it [01:28,  1.99s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

27it [01:30,  1.85s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

28it [01:32,  2.12s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

29it [01:34,  2.13s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

30it [01:36,  2.08s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

31it [01:38,  2.08s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

32it [01:41,  2.33s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

33it [01:43,  2.13s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

34it [01:45,  2.14s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

35it [01:47,  2.12s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

36it [01:49,  2.11s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

37it [01:51,  1.92s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

38it [01:53,  2.05s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

39it [01:55,  2.10s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

40it [01:57,  2.05s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

41it [01:59,  2.04s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

42it [02:01,  2.03s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

43it [02:03,  1.89s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

44it [02:05,  2.05s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

45it [02:07,  2.04s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

46it [02:09,  2.04s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

47it [02:11,  1.97s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

48it [02:13,  1.97s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

49it [02:15,  2.04s/it]

Rescoring:   0%|          | 0/1 [00:00<?, ?it/s]

50it [02:17,  2.76s/it]

Time spent =  137.9179825782776
Time spent by query =  2.7583596515655517





In [29]:
# NDCG@10 of the BM25 + SPLADE + MonoT5 run against the loaded qrels.
eval_ndcg10(run_bm25_splade_monot5)

Downloading builder script:   0%|          | 0.00/5.51k [00:00<?, ?B/s]

0.8128078732353461

In [30]:
!nvidia-smi

Thu May 11 02:22:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    56W / 400W |  24623MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces