In [1]:
import numpy as np
import pandas as pd

PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [2]:
PATH_QUERY_DATA = 'data/subtask_4b/subtask4b_query_tweets_train.tsv'

df_query = pd.read_csv(PATH_QUERY_DATA, sep = '\t')
df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,0,Oral care in rehabilitation medicine: oral vul...,htlvpvz5
1,1,this study isn't receiving sufficient attentio...,4kfl29ul
2,2,"thanks, xi jinping. a reminder that this study...",jtwb17u8
3,3,Taiwan - a population of 23 million has had ju...,0w9k8iy1
4,4,Obtaining a diagnosis of autism in lower incom...,tiqksd69


In [3]:
# Data prep for bi-encoder
from tqdm import tqdm
from rankers.bi_encoder_ranker import BiEncoderRanker

tqdm.pandas()

def flatten_corpus(entry):    
    title = entry["title"]
    authors = entry["authors"]
    abstract = entry["abstract"]
    journal = entry["journal"]

    paper_data = f"{title} [SEP] {authors} [SEP] {abstract} [SEP] {journal}"
    return paper_data

cord_uids = df_collection[:]['cord_uid'].tolist()
corpus = df_collection.progress_apply(lambda x: flatten_corpus(x), axis = 1)
corpus = corpus.tolist()

100%|██████████| 7718/7718 [00:00<00:00, 88680.92it/s]


In [4]:
bi_model_name = "multi-qa-mpnet-base-cos-v1"
bi_enc_ranker = BiEncoderRanker(bi_model_name, corpus)

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: multi-qa-mpnet-base-cos-v1


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

In [5]:
from os import listdir
from rankers.cross_embedding_ranker import CrossRanker

dir_list = listdir("models/cross-embedding")
dir_list.sort()

latest_checkpoint = dir_list[-1]

cross_model_name = f"models/cross-embedding/{latest_checkpoint}"
cross_ranker = CrossRanker(cross_model_name)

def get_top_cord_uids_bi(query):
    doc_scores = bi_enc_ranker.get_scores(query)
    indices = np.argsort(-doc_scores[0])[:50]
    bi_topk = [cord_uids[x] for x in indices]

    reduced_corpus = df_collection[df_collection['cord_uid'].isin(bi_topk)]

    return reduced_corpus

def get_top_cord_uids(query):
    reduced_corpus = get_top_cord_uids_bi(query)
    
    doc_scores = cross_ranker.get_scores(query, reduced_corpus)
    
    indices = np.argsort(-doc_scores)[:5]

    cross_topk = [reduced_corpus.iloc[x]["cord_uid"] for x in indices]

    return cross_topk

# Retrieve topk candidates using the BM25 model
df_query.loc[:,'bi_cross'] = df_query.loc[:, 'tweet_text'].progress_apply(lambda x: get_top_cord_uids(x))
df_query.head()


100%|██████████| 12853/12853 [5:14:54<00:00,  1.47s/it]  


Unnamed: 0,post_id,tweet_text,cord_uid,bi_cross
0,0,Oral care in rehabilitation medicine: oral vul...,htlvpvz5,"[htlvpvz5, fkwgq5mr, yec87cye, e9uou6rr, b52pn..."
1,1,this study isn't receiving sufficient attentio...,4kfl29ul,"[jveh2w09, 29z4q4fs, 7k8nlea3, maj8r6ti, bjvg2..."
2,2,"thanks, xi jinping. a reminder that this study...",jtwb17u8,"[jtwb17u8, iobpcfs5, veeavho5, 2tu707ng, f1ckv..."
3,3,Taiwan - a population of 23 million has had ju...,0w9k8iy1,"[0w9k8iy1, l4y7v729, q5wiqpcb, gy0kfhy6, zxe95..."
4,4,Obtaining a diagnosis of autism in lower incom...,tiqksd69,"[tiqksd69, aqbhxv1f, lr5lumdr, k7smwz6w, 0u330..."


In [6]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred, create_pred_file

mrr_results = get_performance_mrr(df_query, 'cord_uid', 'bi_cross')
gold_results = get_avg_gold_in_pred(df_query, 'cord_uid', 'bi_cross', list_k=[1, 5, 10, 100])
# Printed MRR@k results in the following format: {k: MRR@k}
print(">>>")
print(mrr_results)
print(gold_results)
print("<<<")

def create_gold_pred_file(query_set, prediction_columns):    
    query_set['preds'] = query_set[prediction_columns].apply(lambda x: x[:5])
    gold_query_set = query_set.loc[query_set.apply(lambda row: row["cord_uid"] in(row["preds"]), axis=1)]
    gold_query_set[['post_id', 'cord_uid', 'preds']].to_csv('data/pairwise-model-data/gold_pred.tsv', index=None, sep='\t')

create_gold_pred_file(df_query, "bi_cross")

>>>
{1: np.float64(0.6115303820119816), 5: np.float64(0.6661778054410125), 10: np.float64(0.6661778054410125)}
{1: np.float64(0.6115303820119816), 5: np.float64(0.7458180969423481), 10: np.float64(0.7458180969423481), 100: np.float64(0.7458180969423481)}
<<<
