In [1]:
import numpy as np
import pandas as pd

PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [2]:
PATH_QUERY_DATA = 'data/dev-tweets/subtask4b_query_tweets_dev.tsv'
df_query = pd.read_csv(PATH_QUERY_DATA, sep = '\t')
df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [3]:
import os

dir_list = os.listdir("models/cross-embedding")
dir_list.sort()

latest_checkpoint = dir_list[-1]

model_name = f"models/cross-embedding/{latest_checkpoint}"

print(model_name)

models/cross-embedding/checkpoint-4276


In [4]:
from tqdm import tqdm
from rankers.cross_embedding_ranker import CrossRanker
from rank_bm25 import BM25Okapi

tqdm.pandas()

corpus = df_collection[:][['title', 'abstract']].apply(lambda x: f"{x['title']} {x['abstract']}", axis=1).tolist()

tokenized_corpus = [doc.split(' ') for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)

cord_uids = df_collection[:]['cord_uid'].tolist()
cross_ranker = CrossRanker(model_name)

def get_top_cord_uids_bm25(query):
    tokenized_query = query.split(' ')
    doc_scores = bm25.get_scores(tokenized_query)
    indices = np.argsort(-doc_scores)[:100]
    bm25_topk = [cord_uids[x] for x in indices]

    reduced_corpus = df_collection[df_collection['cord_uid'].isin(bm25_topk)]

    return reduced_corpus


def get_top_cord_uids(query):
    reduced_corpus = get_top_cord_uids_bm25(query)
    
    doc_scores = cross_ranker.get_scores(query, reduced_corpus)
    
    indices = np.argsort(-doc_scores)[:10]

    cross_topk = [reduced_corpus.iloc[x]["cord_uid"] for x in indices]

    return cross_topk

# Retrieve topk candidates using the BM25 model
df_query.loc[:,'bm25-cross'] = df_query.loc[:, 'tweet_text'].progress_apply(lambda x: get_top_cord_uids(x))
df_query.head()


DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /allenai/scibert_scivocab_cased/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /allenai/scibert_scivocab_cased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /allenai/scibert_scivocab_cased/resolve/main/tokenizer_config.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /allenai/scibert_scivocab_cased/resolve/main/vocab.txt HTTP/1.1" 200 0
100%|██████████| 1400/1400 [1:12:59<00:00,  3.13s/it]


Unnamed: 0,post_id,tweet_text,cord_uid,bm25-cross
0,16,covid recovery: this study from the usa reveal...,3qvh482o,"[hg3xpej0, styavbvi, trrg1mnw, jrqlhjsm, is9a7..."
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu,"[r58aohnu, icgsbelo, kiq6xb6k, yrowv62k, d06np..."
2,73,I recall early on reading that researchers who...,sts48u9i,"[sts48u9i, gruir7aw, lpbb4rga, sgo76prc, 4aps0..."
3,93,You know you're credible when NIH website has ...,3sr2exq9,"[3sr2exq9, k0f4cwig, 8j3bb6zx, sv48gjkk, kca5r..."
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy,"[ybwwmyqy, rs3umc1x, ouvq2wpq, fiicxnty, vabb2..."


In [5]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred

mrr_results = get_performance_mrr(df_query, 'cord_uid', 'bm25-cross')
gold_results = get_avg_gold_in_pred(df_query, 'cord_uid', 'bm25-cross', list_k=[1, 5, 10, 100])
# Printed MRR@k results in the following format: {k: MRR@k}
print(">>>")
print(mrr_results)
print(gold_results)
print("<<<")

>>>
{1: np.float64(0.5935714285714285), 5: np.float64(0.6460595238095238), 10: np.float64(0.6490107709750567)}
{1: np.float64(0.5935714285714285), 5: np.float64(0.7235714285714285), 10: np.float64(0.7457142857142857), 100: np.float64(0.7457142857142857)}
<<<
