In [1]:
# Install PyGaggle from the tip of its GitHub repository. No commit or tag is given,
# so the installed code (and the dependencies it resolves) can differ between runs.
!pip install --upgrade --force-reinstall git+https://github.com/castorini/pygaggle
# The remaining packages are pinned to exact versions and installed quietly.
!pip install faiss-cpu==1.7.2 --quiet
!pip install jsonlines==3.0.0 --quiet
!pip install beir==1.0.0 --quiet
!pip install protobuf==3.20.1 --quiet

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/castorini/pygaggle
  Cloning https://github.com/castorini/pygaggle to /tmp/pip-req-build-84mqw49m
  Running command git clone -q https://github.com/castorini/pygaggle /tmp/pip-req-build-84mqw49m
  Running command git submodule update --init --recursive -q
Collecting coloredlogs==14.0
  Downloading coloredlogs-14.0-py2.py3-none-any.whl (43 kB)
[K     |████████████████████████████████| 43 kB 1.5 MB/s 
[?25hCollecting numpy>=1.18
  Downloading numpy-1.21.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.7 MB)
[K     |████████████████████████████████| 15.7 MB 7.6 MB/s 
[?25hCollecting pydantic==1.7.4
  Downloading pydantic-1.7.4-cp37-cp37m-manylinux2014_x86_64.whl (9.1 MB)
[K     |████████████████████████████████| 9.1 MB 30.2 MB/s 
[?25hCollecting pyserini>=0.16.0
  Downloading pyserini-0.17.0-py3-none-any.whl (109.5 MB)
[K     |█████████

[K     |████████████████████████████████| 8.6 MB 3.9 MB/s 
[K     |████████████████████████████████| 64 kB 1.6 MB/s 
[K     |████████████████████████████████| 219 kB 9.1 MB/s 
[?25h  Building wheel for beir (setup.py) ... [?25l[?25hdone
  Building wheel for pytrec-eval (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 1.0 MB 4.3 MB/s 
[?25h

In [5]:
import torch

# Pick the compute device. `dev` is the device string and `device` the matching
# torch.device; `device` is created after the branch so that it is defined on
# CPU-only runtimes as well as on GPU runtimes.
if torch.cuda.is_available():
    dev = "cuda:0"
    print(dev, torch.cuda.get_device_name(0))
else:
    dev = "cpu"
    print(dev)
device = torch.device(dev)

from pyserini.search import SimpleSearcher
from pygaggle.rerank.base import hits_to_texts
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoT5
from transformers import T5ForConditionalGeneration
from pygaggle.rerank.transformer import SentenceTransformersReranker
#import json
#import spacy
import jsonlines
import os
import math
from tqdm.notebook import tqdm
#import numpy as np
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import pandas as pd
from IPython.display import clear_output

cuda:0 Tesla P100-PCIE-16GB


In [14]:
def download_dataset(dataset):
    """
    Fetch a BEIR dataset and load its test split.

    The zip archive named after `dataset` is downloaded and unpacked into the
    local "datasets" folder, then read with GenericDataLoader.

    Args:
      dataset: Dataset name (string)

    Returns:
      A (corpus, queries, qrels) tuple for the test split
    """
    print('Downloading', dataset)
    archive_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    extracted_dir = util.download_and_unzip(archive_url, "datasets")

    loader = GenericDataLoader(data_folder=extracted_dir)
    corpus, queries, qrels = loader.load(split="test")
    return corpus, queries, qrels


def prepare_qrels(dataset):
    """
    Rewrite the test qrels of a dataset in TREC eval format.

    Reads /content/datasets/<dataset>/qrels/test.tsv and writes qrel.tsv (in the
    current working directory) with the columns query-id, a constant '0',
    corpus-id and score, tab separated, without header or index.

    Args:
      dataset: Dataset name (string)

    """
    source_path = '/content/datasets/{}/qrels/test.tsv'.format(dataset)
    judgments = pd.read_csv(source_path, sep='\t')

    # Add the constant second column, then select the columns in output order.
    trec_frame = judgments.assign(zero='0')[['query-id', 'zero', 'corpus-id', 'score']]
    trec_frame.to_csv('qrel.tsv', sep='\t', header=None, index=False)


def index_corpus(corpus):
    """
    Index corpus to be retrieved by BM25

    Args:
      corpus: Corpus (dict)
      
    Returns:
      Searcher object to initialize BM25
    """

    !rm -r candidates
    !rm -r tmp_candidates
    !mkdir tmp_candidates

    for key, val in tqdm(corpus.items()):  
        indexed_dict = { "id": str(key), "contents": val['title'] + ' ' + val['text']}
        with jsonlines.open('/content/tmp_candidates/candidate.jsonl', mode='a') as writer:
            writer.write(indexed_dict)
         
    !python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
    -threads 1 -input /content/tmp_candidates \
    -index /content/candidates/indexes -storePositions -storeDocvectors -storeRaw

    searcher = SimpleSearcher('/content/candidates/indexes')

    return searcher


def run_retrieval(queries, searcher, model_name):
    """
    Run BM25 and reranker retrieval. Save the outputs as txt files

    For every query the top-1000 BM25 hits are retrieved and reranked; both
    rankings are written, one line per document, to
    /content/run_BM25_<dataset>.txt and /content/run_<model_name>_<dataset>.txt
    as "<query id> Q0 <doc id> <rank> <score> <tag>".

    Note: the dataset name and the reranker are not parameters. They are read
    from the notebook-level names `dataset` and `reranker`, which must be
    defined before this function is called.

    Args:
      queries: Queries (dict) mapping a query id to the query text
      searcher: Pyserini object to perform retrieval
      model_name: model name (string), used in the output file name and as the
        tag of the reranker run

    """
    bm25_path = "/content/run_BM25_{}.txt".format(dataset)
    reranker_path = "/content/run_{}_{}.txt".format(model_name, dataset)

    # 'w' instead of 'a': rerunning a dataset must replace its run files, not add a
    # second copy of every line. The with-block closes both files even when the
    # loop is interrupted.
    with open(bm25_path, 'w') as run_bm25, open(reranker_path, 'w') as run_reranker:
        for query_id, query_text in tqdm(queries.items()):

            # Only the string sent to BM25 is cut to 1024 characters; the reranker
            # receives the full query text.
            hits = searcher.search(query_text[0:1024], k=1000)
            texts = hits_to_texts(hits)
            reranked = reranker.rerank(Query(query_text), texts)
            reranked.sort(key=lambda x: x.score, reverse=True)

            for rank, (hit, scored) in enumerate(zip(hits, reranked), start=1):
                run_bm25.write(str(query_id) + ' Q0 ' + str(hit.docid) + ' ' + str(rank) + ' ' + str(hit.score) + ' BM25\n')
                # The reranker score is exponentiated and scaled by 100 before it is written.
                run_reranker.write(str(query_id) + ' Q0 ' + str(scored.metadata["docid"]) + ' ' + str(rank) + ' ' + str(math.exp(scored.score) * 100) + ' ' + model_name + '\n')


def evaluation(dataset, df_final, model_name):
    """
    Run evaluation and prepare the dataframe results

    trec_eval (through pyserini) is run on /content/qrel.tsv against the BM25 run
    and the reranker run of `dataset`. One row per run is added to `df_final`
    and the whole table is saved to BEIR_results.csv in the working directory.

    Note: the metrics to keep are read from the notebook-level dict
    `metrics_map` (trec_eval metric name -> column label), which must be defined
    before this function is called.

    Args:
      dataset: Dataset name (string)
      df_final: Dataframe containing final results (dataframe)
      model_name: model name (string)

    Returns:
      Returns the df_final dataframe containing run results
    """
    import subprocess
    import sys

    columns = ['Model', 'Dataset', 'mAP', 'MRR', 'nDCG@5', 'nDCG@10']
    rows = []
    for model in ['BM25', model_name]:
        command = [
            sys.executable, '-m', 'pyserini.eval.trec_eval', '-c', '-m', 'all_trec',
            '/content/qrel.tsv', '/content/run_{}_{}.txt'.format(model, dataset),
        ]
        # stderr is merged into stdout so every line the tool prints is scanned.
        completed = subprocess.run(command, stdout=subprocess.PIPE,
                                   stderr=subprocess.STDOUT, text=True)

        # Values are stored under their column label, so the row does not depend on
        # the order in which trec_eval prints the metrics; a metric that is not
        # reported is left empty in the row.
        row = {'Model': model, 'Dataset': dataset}
        for result in completed.stdout.splitlines():
            line = result.split('\t')
            if len(line) != 3:
                continue
            metric_name, _, value = line
            column = metrics_map.get(metric_name.strip())
            if column is not None:
                row[column] = float(value)
        rows.append(row)

    df_res = pd.DataFrame(rows, columns=columns)
    df_final = pd.concat([df_final, df_res], ignore_index=True)

    df_final.to_csv('BEIR_results.csv', index=False)
    return df_final

In [16]:
# Name of the reranker to evaluate. The trailing #@param annotation lists the
# selectable values.
model_name = 'monot5-small' #@param ["monot5-small", "monot5-base", "monot5-3B", "MiniLM"]

In [17]:
## BEIR datasets to evaluate, in the order they are processed
lista_datasets = [
    "trec-covid",
    "nfcorpus",
    "fiqa",
    "scifact",
    "webis-touche2020",
    "dbpedia-entity",
    "scidocs",
    "arguana",
    "climate-fever",
    "quora",
    "nq",
    "fever",
    "hotpotqa",
]

## Load the reranker selected through model_name
if model_name != "MiniLM":
    # Every choice other than MiniLM is loaded as a MonoT5 checkpoint named after model_name.
    reranker = MonoT5(
        pretrained_model_name_or_path='castorini/{}-msmarco-10k'.format(model_name),
        token_false='▁false',
        token_true='▁true',
    )
else:
    reranker = SentenceTransformersReranker(
        pretrained_model_name_or_path='cross-encoder/ms-marco-MiniLM-L-6-v2',
    )

## trec_eval metric name -> column label of the results table
metrics_map = dict([
    ('recip_rank', 'MRR'),
    ('ndcg_cut_5', 'nDCG@5'),
    ('ndcg_cut_10', 'nDCG@10'),
    ('map', 'mAP'),
])

####### START EXPERIMENTS ##########
if not os.path.exists('/content/BEIR_results.csv'):
    df_final = pd.DataFrame()
else:
    df_final = pd.read_csv('/content/BEIR_results.csv')

for dataset in tqdm(lista_datasets):
    
    # clear previous iteration files
    !rm /content/qrel.tsv
    clear_output(wait=True)
    # Dowload data
    corpus, queries, qrels = download_dataset(dataset)
    # Index corpus
    searcher = index_corpus(corpus)
    # Convert qrels to TREC format
    prepare_qrels(dataset)
    # Run retrieval models
    run_retrieval(queries, searcher, model_name)
    # Evaluate
    df_final = evaluation(dataset, df_final, model_name)

2022-06-09 20:00:55 [INFO] util: Downloading nfcorpus.zip ...


Downloading nfcorpus


datasets/nfcorpus.zip:   0%|          | 0.00/2.34M [00:00<?, ?iB/s]

2022-06-09 20:01:02 [INFO] util: Unzipping nfcorpus.zip ...
2022-06-09 20:01:02 [INFO] data_loader: Loading Corpus...


  0%|          | 0/3633 [00:00<?, ?it/s]

2022-06-09 20:01:02 [INFO] data_loader: Loaded 3633 TEST Documents.
2022-06-09 20:01:02 [INFO] data_loader: Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.

  0%|          | 0/3633 [00:00<?, ?it/s]

pyserini.index is deprecated, please use pyserini.index.lucene.
2022-06-09 20:01:05,857 INFO  [main] index.IndexCollection (IndexCollection.java:645) - Setting log level to INFO
2022-06-09 20:01:05,859 INFO  [main] index.IndexCollection (IndexCollection.java:648) - Starting indexer...
2022-06-09 20:01:05,859 INFO  [main] index.IndexCollection (IndexCollection.java:650) - DocumentCollection path: /content/tmp_candidates
2022-06-09 20:01:05,860 INFO  [main] index.IndexCollection (IndexCollection.java:651) - CollectionClass: JsonCollection
2022-06-09 20:01:05,860 INFO  [main] index.IndexCollection (IndexCollection.java:652) - Generator: DefaultLuceneDocumentGenerator
2022-06-09 20:01:05,861 INFO  [main] index.IndexCollection (IndexCollection.java:653) - Threads: 1
2022-06-09 20:01:05,861 INFO  [main] index.IndexCollection (IndexCollection.java:654) - Language: en
2022-06-09 20:01:05,861 INFO  [main] index.IndexCollection (IndexCollection.java:655) - Stemmer: porter
2022-06-09 20:01:05,862

  0%|          | 0/323 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [7]:
  # Scratch copy of the loop inside evaluation(): it collects the selected trec_eval
  # values in list_results but never builds or saves a results row. It relies on
  # `dataset`, `model_name` and `metrics_map` already existing in the kernel.
  for model in ['BM25', model_name]: 
        list_results = []
        results = !python -m pyserini.eval.trec_eval -c -m all_trec /content/qrel.tsv /content/run_{model}_{dataset}.txt  
        list_results.append(model)
        list_results.append(dataset)      
        # Keep only tab-separated lines with three fields whose metric name is in metrics_map.
        for result in results:
            line = result.split('\t')
            if len(line) == 3:
                metric_name, _, value = line
                metric_name = metric_name.strip()
                if metric_name in metrics_map:
                    list_results.append(value)

TypeError: ignored