Adapted from https://github.com/Tomjg14/Master_Thesis_MSMARCO_Passage_Reranking_BERT

In [None]:
!pip install torch==1.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html

# MSMARCO

In [None]:
import os
from os.path import exists
import urllib.request
import tarfile

folder = 'data/msmarco_passage/'
tar_file = 'collectionandqueries.tar.gz'
tar_file_path = folder + tar_file

# urlretrieve does not create missing directories, so make sure the target exists.
os.makedirs(folder, exist_ok=True)

# Download and unpack the archive only when it is not already on disk.
if not exists(tar_file_path):
    print('Downloading ' + tar_file + ' ...')
    url = 'https://msmarco.blob.core.windows.net/msmarcoranking/' + tar_file
    # Download under a temporary name and rename on success, so that an
    # interrupted download is not mistaken for a complete archive next run.
    partial_path = tar_file_path + '.part'
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, tar_file_path)

    print('Extracting ' + tar_file + ' ...')
    with tarfile.open(tar_file_path) as tar:
        # The archive comes from the network: use the safe 'data' extraction
        # filter where this Python version provides it.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(folder, filter='data')
        else:
            tar.extractall(folder)

# Anserini
Download/install C++ build tools: https://visualstudio.microsoft.com/visual-cpp-build-tools/

In [None]:
!pip install pyserini
!pip install faiss-cpu
!git clone https://github.com/castorini/anserini.git --recurse-submodules

## Convert to jsonl

In [None]:
# Run Anserini's convert_collection_to_jsonl.py on data/msmarco_passage/collection.tsv,
# writing its output into data/msmarco_passage/collection_jsonl.
!python anserini/tools/scripts/msmarco/convert_collection_to_jsonl.py --collection-path data/msmarco_passage/collection.tsv --output-folder data/msmarco_passage/collection_jsonl

## Generate index

In [None]:
# Index the JSON collection in data/msmarco_passage/collection_jsonl with
# Pyserini's Lucene indexer (9 threads), writing the index to
# data/msmarco_passage/lucene-index-msmarco-passage. Positions, document
# vectors and raw documents are stored as requested by the -store* flags.
!python -m pyserini.index.lucene -collection JsonCollection -generator DefaultLuceneDocumentGenerator -threads 9 -input data/msmarco_passage/collection_jsonl -index data/msmarco_passage/lucene-index-msmarco-passage -storePositions -storeDocvectors -storeRaw

## Filter queries

In [None]:
# Run Anserini's filter_queries.py with qrels.dev.small.tsv and queries.dev.tsv
# as inputs, producing queries.dev.small.tsv.
# NOTE(review): presumably this keeps only the queries that have judgments in
# the qrels file - confirm against the script.
!python anserini/tools/scripts/msmarco/filter_queries.py --qrels data/msmarco_passage/qrels.dev.small.tsv --queries data/msmarco_passage/queries.dev.tsv --output data/msmarco_passage/queries.dev.small.tsv

## Retrieve the top 1000 passages per query

In [None]:
# Query the Lucene index with the filtered dev queries (--hits 1000, 1 thread)
# and write the run file to data/output/run.anserini.dev.small.tsv.
# NOTE(review): it is not visible here whether retrieve.py creates data/output/
# itself - create the directory beforehand if it does not.
!python anserini/tools/scripts/msmarco/retrieve.py --hits 1000 --threads 1 --index data/msmarco_passage/lucene-index-msmarco-passage --queries data/msmarco_passage/queries.dev.small.tsv --output data/output/run.anserini.dev.small.tsv

# BERT

In [None]:
!pip install pytorch_pretrained_bert
!pip install livelossplot
!pip install nvidia-ml-py3
!pip install unidecode

!pip install ipywidgets==7.* --user
!pip install widgetsnbextension jupyter_contrib_nbextensions --user
!jupyter contrib nbextension install --user
!jupyter nbextension enable --py widgetsnbextension

In [1]:
import pandas as pd
import numpy as np
import os
import json
import unidecode
import re
import torch

from tqdm.notebook import tqdm

from pytorch_pretrained_bert import BertTokenizer, BertModel
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME, BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  whitespace_tokenize)

## Convert TensorFlow model to PyTorch (only once)

In [None]:
!pip install tensorflow
!python bert_convert_tensorflow_to_pytorch.py --tf_checkpoint_path=./model/BERT_Base_trained_on_MSMARCO/model.ckpt-100000 --bert_config_file=./model/BERT_Base_trained_on_MSMARCO/bert_config.json --pytorch_dump_path=./model/BERT_Base_trained_on_MSMARCO/pytorch.bin

## Utilities

In [2]:
def get_lower_ids(session_df, query_id):
    """Return the IDs of the queries that precede `query_id` in its session.

    Query IDs have the form '<session>_<position>'. Every ID found in
    session_df['query_id'] whose position is strictly smaller than the
    position of `query_id` is returned, rebuilt with the session number of
    `query_id`, in the order in which it appears in `session_df`.
    """
    id_parts = query_id.split('_')
    session_prefix = str(int(id_parts[0]))
    current_position = int(id_parts[1])

    # Parse every position up front so malformed IDs fail before filtering.
    positions = [int(other_id.split('_')[1])
                 for other_id in session_df['query_id'].tolist()]

    earlier_ids = []
    for position in positions:
        if position < current_position:
            earlier_ids.append(session_prefix + '_' + str(position))
    return earlier_ids

def remove_non_alphanumeric(text):
    """Transliterate `text` to ASCII and blank out non-alphanumeric characters.

    The input is converted with str(), passed through unidecode, and every
    character outside a-z, A-Z and 0-9 is replaced by a single space (so the
    length of the transliterated text is preserved).
    """
    ascii_text = unidecode.unidecode(str(text))
    return re.sub(r'[^a-zA-Z0-9]', ' ', ascii_text)

def get_segment_ids_from_index_tokens(indexed_tokens, sep_token_id=102):
    """Build BERT segment (token type) ids for a list of token ids.

    Tokens up to and including the first separator get segment 0; every
    token after it gets segment 1. This follows the BERT convention for
    "[CLS] A [SEP] B [SEP]", where the first [SEP] belongs to sentence A.

    Args:
        indexed_tokens: iterable of vocabulary ids.
        sep_token_id: vocabulary id of the separator token (defaults to 102).

    Returns:
        A list of 0/1 segment ids with the same length as `indexed_tokens`.
    """
    segment_ids = []
    current_segment = 0
    for token in indexed_tokens:
        segment_ids.append(current_segment)
        # Switch only after the separator itself has been labelled.
        if token == sep_token_id:
            current_segment = 1
    return segment_ids

def run_bert(data, max_seq_length=512):
    """Run the global `bertmodel` on every row of `data`.

    Args:
        data: DataFrame with the columns 'indexed_tokens' and 'segment_ids',
            each holding one list of ints per row.
        max_seq_length: maximum number of tokens fed to the model. Longer
            inputs keep their first `max_seq_length - 1` tokens plus their
            final token, so the query at the front and the closing separator
            survive and only the tail of the passage is cut.

    Returns:
        The same `data` object, with a new column 'pooled_output' holding the
        model output of each row as a CPU tensor.
    """
    # Follow whatever device the model lives on instead of assuming CUDA.
    device = next(bertmodel.parameters()).device

    activations = []
    rows = zip(data['indexed_tokens'], data['segment_ids'])
    for tokens, segment_ids in tqdm(rows, total=len(data)):
        tokens = list(tokens)
        segment_ids = list(segment_ids)

        # make sure the input fits (result has exactly max_seq_length items)
        if len(tokens) > max_seq_length:
            tokens = tokens[:max_seq_length - 1] + tokens[-1:]
            segment_ids = segment_ids[:max_seq_length - 1] + segment_ids[-1:]

        # batch of size one, placed on the model's device
        tokens_tensor = torch.tensor([tokens]).to(device)
        segments_tensors = torch.tensor([segment_ids]).to(device)

        with torch.no_grad():
            prediction = bertmodel(tokens_tensor, segments_tensors)
            activations.append(prediction.cpu())

    data['pooled_output'] = activations
    return data

# https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
def split(a, n):
    """Lazily split the sequence `a` into `n` consecutive slices.

    The slice sizes differ by at most one; the first len(a) % n slices are
    the longer ones. A generator over the slices is returned.
    """
    base_size, n_longer = divmod(len(a), n)

    def _slices():
        start = 0
        for part in range(n):
            stop = start + base_size + (1 if part < n_longer else 0)
            yield a[start:stop]
            start = stop

    return _slices()

## Config

In [28]:
# File names: the query file and BM25 run read by "Load data", and the base
# name of the per-chunk BERT run files written by "Run".
queries_filename = 'queries.dev.small.tsv'
anserini_output_filename = 'run.anserini.dev.small.tsv'
output_filename = 'run.bert.dev.small.tsv'

# Directories (string prefixes, so each one ends with a slash).
models_dir = "model/"
msmarco_dir = "data/msmarco_passage/"
anserini_output_dir = "data/output/"
output_dir = "data/output/"

# Only passages with bm25_rank <= top_n are kept for BERT re-ranking.
top_n = 100

# Number of query chunks processed (and saved) separately by the "Run" cell.
n_chunks = 10

## Load data

In [29]:
def _read_headerless_tsv(path, column_names):
    """Read a header-less, tab-separated UTF-8 file and label its columns."""
    frame = pd.read_csv(path, delimiter='\t', encoding='utf-8', header=None)
    frame.columns = column_names
    return frame

# MSMARCO collection
msmarco_collection = _read_headerless_tsv(
    msmarco_dir + 'collection.tsv', ['passage_id', 'passage'])

# queries to re-rank
query_subset = _read_headerless_tsv(
    msmarco_dir + queries_filename, ['query_id', 'query'])

# BM25 run produced by Anserini
query_anserini_output = _read_headerless_tsv(
    anserini_output_dir + anserini_output_filename,
    ['query_id', 'passage_id', 'bm25_rank'])

# distinct query ids that occur in the BM25 run
unique_run_query_ids = np.unique(query_anserini_output['query_id'].tolist())
top1000_query_ids = pd.DataFrame(list(unique_run_query_ids))
top1000_query_ids.columns = ['query_id']

## Preprocess

In [None]:
# register progress_apply on pandas objects
tqdm.pandas()

def _normalize_text(raw_text):
    """Lower-case the text, then pass it through remove_non_alphanumeric."""
    return remove_non_alphanumeric(raw_text.lower())

# BM25 hits that are ranked high enough to be re-ranked
top_n_hits = query_anserini_output[query_anserini_output['bm25_rank'] <= top_n]

# one row per (query, candidate passage), with the query and passage text attached
bert_df = (
    top1000_query_ids
    .merge(top_n_hits, how='left', on=['query_id'])
    .merge(query_subset, how='left', on=['query_id'])
    .merge(msmarco_collection, how='left', on=['passage_id'])
)

for text_column in ('query', 'passage'):
    bert_df[text_column] = bert_df[text_column].progress_apply(_normalize_text)

# model input: query and passage wrapped in BERT's special tokens
bert_df['input_text'] = "[CLS] " + bert_df['query'] + " [SEP] " + bert_df['passage'] + " [SEP]"

## Model

In [31]:
# Build a bert-base-uncased sequence classifier with 2 labels, then overwrite
# its weights with the converted MS MARCO checkpoint.
msmarco_checkpoint_path = models_dir + 'BERT_Base_trained_on_MSMARCO/pytorch.bin'
bertmodel = BertForSequenceClassification.from_pretrained('bert-base-uncased', 2)
msmarco_state_dict = torch.load(msmarco_checkpoint_path)
bertmodel.load_state_dict(msmarco_state_dict)

# inference only, on the GPU
bertmodel.eval()
bertmodel.to('cuda')

# progress_apply support plus the matching WordPiece tokenizer
tqdm.pandas()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

## Run

In [None]:
def _input_text_to_ids(row):
    """Tokenize the row's input_text and map the tokens to vocabulary ids."""
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(row['input_text']))

def _segment_ids_for_row(row):
    """Derive the segment ids from the row's token ids."""
    return get_segment_ids_from_index_tokens(row['indexed_tokens'])

def _label_1_score(row):
    """Pick the value at index [0][1] of the model output as the BERT score."""
    return row['pooled_output'].data[0][1].item()

query_ids = list(query_subset['query_id'])
query_id_chunks = list(split(query_ids, n_chunks))

# The queries are processed chunk by chunk; every chunk is written to its own file.
for chunk_number, query_id_chunk in enumerate(query_id_chunks, start=1):
    tqdm.write('chunk {}/{}'.format(chunk_number, n_chunks))

    belongs_to_chunk = bert_df['query_id'].isin(query_id_chunk)
    bert_df_chunk = bert_df[belongs_to_chunk].copy()

    # tokenize
    tqdm.write('tokenize')
    bert_df_chunk['indexed_tokens'] = bert_df_chunk.progress_apply(_input_text_to_ids, axis=1)
    bert_df_chunk['segment_ids'] = bert_df_chunk.progress_apply(_segment_ids_for_row, axis=1)

    # run
    tqdm.write('run')
    output_df_chunk = run_bert(bert_df_chunk)

    # score: rank the passages of each query by descending BERT score (dense ranks)
    output_df_chunk['score_bert'] = output_df_chunk.progress_apply(_label_1_score, axis=1)
    output_df_chunk = output_df_chunk.drop(columns=['input_text', 'indexed_tokens', 'segment_ids', 'pooled_output'])
    dense_ranks = output_df_chunk.groupby("query_id")["score_bert"].rank(ascending=False, method='dense')
    output_df_chunk["bert_rank"] = dense_ranks
    output_df_chunk["bert_rank"] = output_df_chunk['bert_rank'].astype(int)

    # save
    chunk_path = output_dir + output_filename + '-{}-of-{}'.format(chunk_number, n_chunks)
    saved_columns = ['query_id', 'passage_id', 'bm25_rank', 'score_bert', 'bert_rank']
    output_df_chunk[saved_columns].to_csv(chunk_path, sep="\t", header=False, index=False)
    
    

## Evaluate

In [None]:
# Relevance judgments: keep only the (query_id, passage_id) pairs.
relevance_df = pd.read_csv(msmarco_dir + 'qrels.dev.small.tsv',delimiter='\t',encoding='utf-8',header=None)
relevance_df.columns = ['query_id','label1','passage_id','label2']
relevance_df = relevance_df.drop(columns=['label1','label2'])

# One run file per chunk, named exactly as the "Run" cell wrote them.
bert_filenames = [output_filename + '-{}-of-{}'.format(i + 1, n_chunks) for i in range(n_chunks)]

# The loop variable deliberately is NOT called `bert_df`: that name holds the
# preprocessed query/passage frame the "Run" cell reads, and rebinding it here
# would break re-running that cell.
bert_dfs = []
for bert_filename in tqdm(bert_filenames):
    chunk_rankings = pd.read_csv(output_dir + bert_filename,delimiter='\t',encoding='utf-8', header=None)
    chunk_rankings.columns = ['query_id', 'passage_id', 'bm25_rank', 'bert_score', 'bert_rank']
    bert_dfs.append(chunk_rankings)

bert_rankings = pd.concat(bert_dfs, ignore_index=True)

In [34]:
def compute_mrr(gt, pred, column, cutoff=10):
    """Reciprocal rank of the best-ranked relevant passage of one query.

    Args:
        gt: collection of relevant passage ids for the query.
        pred: DataFrame with a 'passage_id' column and the rank column.
        column: name of the rank column in `pred` (ranks start at 1).
        cutoff: deepest rank that still counts; ranks below `cutoff + 1`
            are considered (default 10, i.e. MRR@10).

    Returns:
        1 / (smallest qualifying rank of a relevant passage), or 0.0 when no
        relevant passage is ranked within the cutoff.
    """
    relevant_ranks = pred.loc[pred['passage_id'].isin(gt), column]
    relevant_ranks = relevant_ranks[relevant_ranks < cutoff + 1]
    if relevant_ranks.empty:
        return 0.0
    return 1.0 / relevant_ranks.min()

In [None]:
bm25_mrr = 0.0
bert_mrr = 0.0

# Evaluate only the queries that actually have rankings.
query_ids = list(np.unique(bert_rankings['query_id'].tolist()))
relevance_df = relevance_df[relevance_df['query_id'].isin(query_ids)]

# Relevant passage ids per query, looked up once instead of re-filtering the
# qrels frame for every query.
gt_by_query = relevance_df.groupby('query_id')['passage_id'].apply(list).to_dict()

# A single groupby pass replaces one full-frame boolean mask per query; groups
# come in ascending query_id order, the same order as `query_ids`.
for query_id, query_rankings in tqdm(bert_rankings.groupby('query_id'), total=len(query_ids)):
    gt = gt_by_query.get(query_id, [])

    # only the top 10 of each ranking count towards MRR@10
    query_preds_df = query_rankings[query_rankings['bert_rank'] < 11]
    bert_mrr += compute_mrr(gt, query_preds_df, 'bert_rank')

    query_preds_df = query_rankings[query_rankings['bm25_rank'] < 11]
    bm25_mrr += compute_mrr(gt, query_preds_df, 'bm25_rank')

# mean over all evaluated queries, reported as a percentage with one decimal
tqdm.write('BM25: MRR@10: {}'.format(round((bm25_mrr/len(query_ids))*100,1)))
tqdm.write('BERT: MRR@10: {}'.format(round((bert_mrr/len(query_ids))*100,1)))