## Reranking using SciBERT

The goal is to take the best run so far (Elasticsearch on abstracts), use SciBERT to rerank the first `k` abstracts of each topic, and evaluate the reranked run with the usual TREC metrics.

SciBERT encodes each abstract and each query into a matrix of size `[num_of_tokens, 768]`. If a text is longer than the maximum input size of BERT, it is split into consecutive chunks of tokens that are encoded one after another and concatenated. The last hidden states of the BERT model are used as the token embeddings of the text.

In [19]:
from collections import defaultdict
import warnings
import os
from pathlib import Path
warnings.simplefilter(action='ignore', category=FutureWarning)

# ----------------- Classics -------------------- #
import numpy as np
import pandas as pd

# ---------------- Pandas settings --------------- #
# Removes rows and columns truncation of '...'
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

import pickle

from tqdm import tqdm
tqdm.pandas()

# ------------------- NLP libs ---------------------- #
import torch
from transformers import *

## 1. Load data set

In [20]:
# CSV file with the CORD-19 documents; passed to load_cord19 in this cell.
CORD19_PATH = Path('../data/input/trec_cord19_v1.csv')

def load_cord19(input_fpath: Path, dtype: str = 'csv', cols_to_keep: list = ['cord_uid', 'abstract'], index_col = 'cord_uid') -> pd.DataFrame:
    """Load the CORD19 data and return it as a cleaned pandas data frame.

    input_fpath - path of the file to read
    dtype - file format; only 'csv' is implemented
    cols_to_keep - columns to read from the file (the default list is never mutated)
    index_col - column used as the index of the returned frame

    Every column that pandas reports as string data is stripped of leading and
    trailing white space, has runs of two or more white-space characters
    collapsed to a single space, and has all non-ASCII characters dropped.

    Raises ValueError if dtype is not 'csv'.
    """
    if dtype != 'csv':
        # Previously this fell through to `return df` and failed with an
        # UnboundLocalError that did not name the real problem.
        raise ValueError(f"Unsupported dtype {dtype!r}; only 'csv' is implemented")

    df = pd.read_csv(input_fpath, quotechar='"', index_col=index_col, usecols=cols_to_keep)
    for col in df.columns:
        # only text columns have a .str accessor worth cleaning
        if pd.api.types.is_string_dtype(df[col]):
            cleaned = df[col].str.strip()  # removes front and end white spaces
            # regex=True is required: since pandas 2.0 str.replace treats the
            # pattern as a literal string by default, which collapses nothing.
            cleaned = cleaned.str.replace(r'\s{2,}', ' ', regex=True)
            # drop every character that cannot be represented in ASCII
            df[col] = cleaned.str.encode('ascii', 'ignore').str.decode('ascii')
    return df

# Read the collection with both the plain abstract and the combined
# title+abstract column, indexed by cord_uid, and preview the first rows.
cord19 = load_cord19(CORD19_PATH, cols_to_keep = ['cord_uid', 'abstract', 'title+abstract'], index_col='cord_uid')
cord19.head()

Unnamed: 0_level_0,abstract,title+abstract
cord_uid,Unnamed: 1_level_1,Unnamed: 2_level_1
ug7v899j,OBJECTIVE: This retrospective chart review des...,Clinical features of culture-proven Mycoplasma...
02tnwd4m,Inflammatory diseases of the respiratory tract...,Nitric oxide: a pro-inflammatory mediator in l...
ejv2xln0,Surfactant protein-D (SP-D) participates in th...,Surfactant protein-D and pulmonary host defens...
2b73a28n,Endothelin-1 (ET-1) is a 21 amino acid peptide...,Role of endothelin-1 in lung disease Endotheli...
9785vg6d,Respiratory syncytial virus (RSV) and pneumoni...,Gene expression in epithelial cells in respons...


## 2. Load queries

In [21]:
def load_queries(input_fpath: Path, dtype: str = 'csv', cols_to_keep=['topic-id', 'query', 'question'], index_col=['topic-id']) -> pd.DataFrame:
    """Load the queries file and return it as a cleaned pandas data frame.

    input_fpath - path of the file to read
    dtype - file format; only 'csv' is implemented
    cols_to_keep - columns to read from the file (the default list is never mutated)
    index_col - column(s) used as the index of the returned frame

    Every column that pandas reports as string data is stripped of leading and
    trailing white space and has runs of two or more white-space characters
    collapsed to a single space.

    Raises ValueError if dtype is not 'csv'.
    """
    if dtype != 'csv':
        # Previously this fell through to `return df` and failed with an
        # UnboundLocalError that did not name the real problem.
        raise ValueError(f"Unsupported dtype {dtype!r}; only 'csv' is implemented")

    df = pd.read_csv(input_fpath, quotechar='"', index_col=index_col, usecols=cols_to_keep)
    for col in df.columns:
        # only text columns have a .str accessor worth cleaning
        if pd.api.types.is_string_dtype(df[col]):
            cleaned = df[col].str.strip()  # removes front and end white spaces
            # regex=True is required: since pandas 2.0 str.replace treats the
            # pattern as a literal string by default, which collapses nothing.
            df[col] = cleaned.str.replace(r'\s{2,}', ' ', regex=True)
    return df

# Topics file to load; its rows are indexed by topic-id (load_queries default).
QUERY_FPATH = Path('../data/CORD-19/CORD-19/topics-rnd3.csv')
query_df = load_queries(QUERY_FPATH)
# Extra column: the short query and the full question joined by one space.
query_df['query+question'] = query_df['query'] + ' ' + query_df['question']
query_df.head()

Unnamed: 0_level_0,query,question,query+question
topic-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,coronavirus origin,what is the origin of COVID-19,coronavirus origin what is the origin of COVID-19
2,coronavirus response to weather changes,how does the coronavirus respond to changes in...,coronavirus response to weather changes how do...
3,coronavirus immunity,will SARS-CoV2 infected people develop immunit...,coronavirus immunity will SARS-CoV2 infected p...
4,how do people die from the coronavirus,what causes death from Covid-19?,how do people die from the coronavirus what ca...
5,animal models of COVID-19,what drugs have been active against SARS-CoV o...,animal models of COVID-19 what drugs have been...


In [22]:
# Evaluate the baseline run with trec_eval and show the stored report.
run_name = "elasticsearch_baseline_abstract_query_question"
path_to_qrel_file = "../data/qrels/qrels-covid_d3_j0.5-3.txt"
path_to_result_file = f"../data/output/{run_name}.txt"
output_result_path = f"../data/results/{run_name}_trec_eval.txt"

# The shell redirects trec_eval's report into output_result_path.
os.system(f"trec_eval -c -m all_trec {path_to_qrel_file} {path_to_result_file} > {output_result_path}")

with open(output_result_path, encoding='utf-8') as f:
    report = f.read()
print(report)

runid                 	all	elasticsearch_baseline_abstract_query_question
num_q                 	all	40
num_ret               	all	40000
num_rel               	all	10001
num_rel_ret           	all	4113
map                   	all	0.1679
gm_map                	all	0.1180
Rprec                 	all	0.2522
bpref                 	all	0.3556
recip_rank            	all	0.7923
iprec_at_recall_0.00  	all	0.8408
iprec_at_recall_0.10  	all	0.4761
iprec_at_recall_0.20  	all	0.3490
iprec_at_recall_0.30  	all	0.2489
iprec_at_recall_0.40  	all	0.1536
iprec_at_recall_0.50  	all	0.0769
iprec_at_recall_0.60  	all	0.0339
iprec_at_recall_0.70  	all	0.0040
iprec_at_recall_0.80  	all	0.0000
iprec_at_recall_0.90  	all	0.0000
iprec_at_recall_1.00  	all	0.0000
P_5                   	all	0.6300
P_10                  	all	0.5975
P_15                  	all	0.5817
P_20                  	all	0.5538
P_30                  	all	0.5158
P_100                 	all	0.3552
P_200                 	all	0.2650
P_500           

In [23]:
# Keep the raw run lines in es_results for later cells.
with open(path_to_result_file, 'r') as f:
    es_results = f.readlines()

# Every run line has six white-space separated fields; only the document id
# (3rd field) and the rank (4th field) are used here.
set_of_top_k_uid_all_topics = set()
for line in es_results:
    qid, _, uid, rank, _, _ = line.split()
    if int(rank) > 200:
        continue
    set_of_top_k_uid_all_topics.add(uid)

# 40 topics x 200 hits would be about 8000 ids; a document retrieved for
# several topics is stored only once, so the set can be smaller.
len(set_of_top_k_uid_all_topics)

6549

In [26]:
# Restrict the collection to the documents collected from the top hits.
# .loc needs a list-like indexer: pandas >= 2.0 raises a TypeError when a set
# is passed, so the set is converted to a list first.
filtered_cord19 = cord19.loc[list(set_of_top_k_uid_all_topics)]
filtered_cord19.info()

<class 'pandas.core.frame.DataFrame'>
Index: 6549 entries, keaxietu to 3twud97m
Data columns (total 2 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   abstract        6549 non-null   object
 1   title+abstract  6549 non-null   object
dtypes: object(2)
memory usage: 153.5+ KB


In [27]:
# Plain dicts for the encoder: cord_uid -> title+abstract text and
# topic-id -> query+question text.
abstracts_dict = filtered_cord19['title+abstract'].to_dict()
query_dict = query_df['query+question'].to_dict()
len(abstracts_dict), len(query_dict)

(6549, 40)

In [28]:
# load SciBert (alternative: monologg/biobert_v1.1_pubmed)
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', do_lower_case=False)
model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')

In [37]:
# Number of content tokens sent to the model per forward pass; every chunk is
# additionally wrapped in the tokenizer's first and last special tokens.
CHUNK_SIZE = 510

def extract_scibert(text, tokenizer, model, chunk_size=CHUNK_SIZE):
    """
    Encode text to vectors
        text -  string to be encoded
        tokenizer - object whose encode(text, add_special_tokens=True) returns
            a list of token ids starting and ending with a special token
        model - callable taking a [1, seq_len] id tensor and returning a
            sequence whose first item is the last hidden state [1, seq_len, hidden]
        chunk_size - maximum number of content tokens per forward pass
        return - tensor of size [num_tokens, hidden] (last hidden state of
            BERT); the vectors of the special tokens are not included

    Texts with more than chunk_size content tokens are split into consecutive,
    non-overlapping chunks. Each chunk is encoded on its own and the results
    are concatenated in order.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    token_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True))
    first_special, last_special = token_ids[:1], token_ids[-1:]
    content_ids = token_ids[1:-1]

    states = []
    # Chunks are counted over the content tokens only, so no forward pass is
    # spent on a chunk holding nothing but the two special tokens. max(..., 1)
    # keeps a single pass for empty text so the result is still [0, hidden].
    for start in range(0, max(len(content_ids), 1), chunk_size):
        chunk = torch.cat([first_special, content_ids[start:start + chunk_size], last_special])
        with torch.no_grad():
            state = model(chunk.unsqueeze(0))[0]  # last hidden states
        states.append(state[:, 1:-1, :])  # drop the two special-token vectors

    return torch.cat(states, dim=1)[0]

In [38]:
def encode_abstract_query_narrative_and_save(query_dict, abstracts_dict, extract_scibert, tokenizer, model, fname, **kwargs):
    """
    Encode queries and abstracts using given encoding function, pickle the result
        query_dict - dict where key=qid and value=query text
        abstracts_dict - dict where key=uid and value=text to encode
        extract_scibert - BERT encoding function,
            input: text, tokenizer, model, **kwargs
            output: the encoded text (stored as-is)
        tokenizer - to pass to extract_scibert
        model - to pass to extract_scibert
        fname - name of the pickle file written below ../data/embeddings/
        **kwargs - forwarded to extract_scibert for abstracts only; queries
            are encoded with the encoder's defaults

    The pickle holds {"abstract": {uid: encoding}, "query": {qid: encoding}}.
    Nothing is returned.
    """
    # Prepare the destination first: a missing directory should fail (or be
    # created) before the long encoding loop, not after it.
    out_path = Path("../data/embeddings/" + fname)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # encode abstract
    encoded_abstract = dict()
    for uid, text in tqdm(abstracts_dict.items()):
        encoded_abstract[uid] = extract_scibert(text, tokenizer, model, **kwargs)

    # encode queries (was iterating the undefined name `topics`)
    encoded_queries = dict()
    for qid, query in query_dict.items():
        encoded_queries[qid] = extract_scibert(query, tokenizer, model)

    # save for future use
    bert_vectors = {
            "abstract": encoded_abstract,
            "query": encoded_queries
    }

    with open(out_path, "wb") as f:
        pickle.dump(bert_vectors, f)

In [None]:
# Encode the filtered abstracts and the queries and store them in
# ../data/embeddings/scibert.pkl. The progress bar below shows roughly one
# abstract per second.
encode_abstract_query_narrative_and_save(
    query_dict,
    abstracts_dict,
    extract_scibert,
    tokenizer,
    model,
    "scibert.pkl"
)

 12%|█▏        | 809/6549 [19:28<1:28:01,  1.09it/s]

In [None]:
# load
with open("../data/embeddings/scibert.pkl", "rb") as f:
    bert_vectors = pickle.load(f)

In [None]:
def cross_match(state1, state2):
    """
    Cosine similarity between every row of state1 and every row of state2
        state1 - tensor of size [n, dim]
        state2 - tensor of size [m, dim]
        return - tensor of size [n, m]; an all-zero row gives similarity 0
            instead of NaN
    """
    unit1 = torch.nn.functional.normalize(state1, p=2, dim=1)
    unit2 = torch.nn.functional.normalize(state2, p=2, dim=1)
    # A matrix product gives the same pairwise dot products as broadcasting to
    # an [n, m, dim] tensor and summing, without building that tensor.
    return unit1 @ unit2.transpose(0, 1)

def _lookup_query_vectors(encoded_queries, qid):
    """Return the vectors stored for qid, also trying its integer form."""
    if qid in encoded_queries:
        return encoded_queries[qid]
    # Run files carry the topic id as text while the embeddings may be keyed
    # by the integer topic id.
    if isinstance(qid, str) and qid.isdigit() and int(qid) in encoded_queries:
        return encoded_queries[int(qid)]
    raise KeyError(f"no query embedding found for topic {qid!r}")

def rerank(topics, search_run, abstracts_dict, top_k, run_name, topics_field, bert_vectors):
    """
    Rerank the original run and save the reranked run in ../data/output/<run_name>.txt
        topics - not read; kept so existing calls keep working
        search_run - python list of run lines, each with six white-space
            separated fields: qid, Q0, uid, rank, score, run name
        abstracts_dict - dict where key=uid and value=abstract; a hit without
            an abstract is dropped while the first top_k slots are being filled
        top_k - number of hits per topic to rerank, the rest keep their order
        run_name - reranked run name (also the output file name)
        topics_field - not read; kept so existing calls keep working
        bert_vectors - dict with "query": {qid: vectors} and
            "abstract": {uid: vectors}, vectors of size [num_tokens, dim]

    The score of a reranked hit is the maximum cosine similarity between any
    query token vector and any abstract token vector. At most 1000 hits are
    written per topic. Raises KeyError if an embedding needed for a reranked
    hit is missing.
    """
    max_hits = 1000
    encoded_queries = bert_vectors["query"]
    encoded_abstract = bert_vectors["abstract"]

    head = defaultdict(list)  # first k hits: (uid, similarity)
    tail = defaultdict(list)  # following hits: (uid, original score)

    # calculate similarity
    for line in search_run:
        if not line.strip():
            continue  # tolerate blank lines, e.g. at the end of a file
        qid, _, uid, _, score, _ = line.split()
        if len(head[qid]) < top_k:
            if not abstracts_dict.get(uid):
                continue  # some uid don't have an abstract

            # Embeddings are keyed by id (uid / qid), not by the text itself.
            enc_abs = encoded_abstract[uid]
            enc_query = _lookup_query_vectors(encoded_queries, qid)
            sim = cross_match(enc_query, enc_abs)
            if sim.numel() == 0:
                continue  # nothing to compare when either side has no tokens

            head[qid].append((uid, torch.max(sim).item()))

        elif len(tail[qid]) < max_hits - top_k:
            tail[qid].append((uid, score))

    # create reranked run
    template = "{} Q0 {} {} {} {}"
    run = list()

    for qid, scored in head.items():
        kept = tail.get(qid, [])
        # trec_eval orders documents by score, not by the rank column, so every
        # reranked hit must score above the best kept hit. Cosine similarity is
        # at least -1, hence "+ 2"; the previous fixed offset of 10 is used
        # whenever it is already large enough.
        best_kept = max((float(s) for _, s in kept), default=None)
        offset = 10 if best_kept is None else max(10, best_kept + 2)

        rank = 1
        for uid, sim_score in sorted(scored, key=lambda hit: -hit[1]):
            run.append(template.format(qid, uid, rank, sim_score + offset, run_name))
            rank += 1

        for uid, score in kept:
            run.append(template.format(qid, uid, rank, score, run_name))
            rank += 1
        # A topic may end up with fewer than 1000 lines when hits were dropped
        # or the original run is shorter, so no exact count is enforced here.

    # The f-string matters: without it the file was literally "{run_name}.txt".
    out_path = Path(f"../data/output/{run_name}.txt")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding='utf-8') as f:
        f.write("\n".join(run))