In [32]:
import numpy as np
import pandas as pd

PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [33]:
PATH_QUERY_DATA = 'data/dev-tweets/subtask4b_query_tweets_dev.tsv'
df_query = pd.read_csv(PATH_QUERY_DATA, sep = '\t')
df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [34]:
sub_df_query = df_query
sub_df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [35]:

from tqdm import tqdm
from rankers.bi_encoder_ranker import BiEncoderRanker
from sentence_transformers import SimilarityFunction

tqdm.pandas()

def flatten_corpus(entry):    
    title = entry["title"]
    authors = entry["authors"]
    abstract = entry["abstract"]
    journal = entry["journal"]

    paper_data = f"{title} [SEP] {authors} [SEP] {abstract} [SEP] {journal}"
    return paper_data

cord_uids = df_collection[:]['cord_uid'].tolist()
corpus = df_collection.apply(lambda x: flatten_corpus(x), axis = 1)
corpus = corpus.tolist()


In [36]:
model_name = "multi-qa-mpnet-base-cos-v1"
bi_enc_ranker = BiEncoderRanker(model_name, corpus, similarity_function=SimilarityFunction.DOT_PRODUCT)

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: multi-qa-mpnet-base-cos-v1


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

In [37]:

def get_top_cord_uids(query):
    doc_scores = bi_enc_ranker.get_scores(query)
    indices = np.argsort(-doc_scores[0])[:1000]
    bi_enc_topk = [cord_uids[x] for x in indices]
    return bi_enc_topk

# Retrieve topk candidates using the BM25 model
sub_df_query['bi_enc'] = sub_df_query['tweet_text'].progress_apply(lambda x: get_top_cord_uids(x))

100%|██████████| 1400/1400 [00:21<00:00, 65.03it/s]


In [38]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred
mrr_results = get_performance_mrr(sub_df_query, 'cord_uid', 'bi_enc')
gold_results = get_avg_gold_in_pred(sub_df_query, 'cord_uid', 'bi_enc', list_k=[100, 125, 150, 200])
# Printed MRR@k results in the following format: {k: MRR@k}
print(">>>")
print(mrr_results)
print(gold_results)
print("<<<")

>>>
{1: np.float64(0.4664285714285714), 5: np.float64(0.5247857142857143), 10: np.float64(0.5326754535147393)}
{100: np.float64(0.8685714285714285), 125: np.float64(0.8807142857142857), 150: np.float64(0.8942857142857142), 200: np.float64(0.9078571428571428)}
<<<
