In [1]:
import ast
import re

import numpy as np
import pandas as pd
import torch

# 'high' allows PyTorch to use faster reduced-precision internals (e.g. TF32)
# for float32 matrix multiplications where the hardware supports them.
torch.set_float32_matmul_precision('high')

# True: use the labelled dev queries and print evaluation scores further down.
# False: run on the test queries instead.
DEV_MODE = True
PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

# Paper collection (cord_uid, title, abstract, ... columns).
# NOTE: read_pickle can execute arbitrary code -- only load trusted files.
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [2]:
# Query tweets: the labelled dev split while developing, otherwise the test split.
if DEV_MODE:
    query_variant = 'dev'
else:
    query_variant = 'test'

PATH_QUERY_DATA = f'data/subtask_4b/subtask4b_query_tweets_{query_variant}.tsv'
df_query = pd.read_csv(PATH_QUERY_DATA, sep='\t')
df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [3]:
# First-stage (SBERT) retrieval output. Each `preds` cell is a candidate list of
# cord_uids stored as its Python repr string (presumably ranked best-first),
# so it has to be parsed before use.
PARTIAL_PREDICTION_FILE = "partial-predictions/sbert/predictions.tsv"
partial_predictions = pd.read_csv(PARTIAL_PREDICTION_FILE, sep='\t')
partial_predictions.head()

Unnamed: 0,post_id,preds,cord_uid
0,16,"['3qvh482o', 'jrqlhjsm', 'hg3xpej0', 'styavbvi...",3qvh482o
1,69,"['r58aohnu', 'cpbu3fv3', '5u63aqo5', 'u1q6wl45...",r58aohnu
2,73,"['sts48u9i', 'u5nxm9tu', 'rytzyf1j', 'a7frertc...",sts48u9i
3,93,"['3sr2exq9', 'xl0zb9zj', 'k0f4cwig', 'e8mk04uf...",3sr2exq9
4,96,"['ybwwmyqy', 'ierqfgo5', 'sxx3yid9', 'qh6rif48...",ybwwmyqy


In [4]:
from tqdm import tqdm
from rankers.cross_embedding_ranker import CrossRanker
from os import listdir

tqdm.pandas()

CROSS_MODEL_DIR = "models/cross-embedding"


def _checkpoint_sort_key(name):
    """Natural-sort key: digit runs compare numerically, so 'ckpt-900' < 'ckpt-1000'."""
    return [(0, int(tok)) if tok.isdecimal() else (1, tok)
            for tok in re.split(r"(\d+)", name)]


# A plain lexicographic sort would rank 'checkpoint-900' after 'checkpoint-1000'.
# Hidden entries (.ipynb_checkpoints, .DS_Store, ...) are skipped so they can never
# be mistaken for the latest checkpoint.
dir_list = sorted(
    (name for name in listdir(CROSS_MODEL_DIR) if not name.startswith(".")),
    key=_checkpoint_sort_key,
)
if not dir_list:
    raise FileNotFoundError(f"No checkpoints found under {CROSS_MODEL_DIR!r}")

# The highest-numbered (naturally sorted last) entry is treated as the latest checkpoint.
latest_checkpoint = dir_list[-1]

cross_model_name = f"{CROSS_MODEL_DIR}/{latest_checkpoint}"
cross_ranker = CrossRanker(cross_model_name)


def get_top_cord_uids(query, n_candidates=200, top_k=200,
                      predictions=None, collection=None, ranker=None):
    """Re-rank one tweet's first-stage (SBERT) candidates with the cross-encoder.

    Args:
        query: A row of ``df_query``; must provide ``"post_id"`` and ``"tweet_text"``.
        n_candidates: How many candidates to take from the front of the stored
            ``preds`` list before re-ranking.
        top_k: Maximum number of re-ranked cord_uids to return.
        predictions: Frame with ``post_id`` and ``preds`` columns. Defaults to the
            notebook-level ``partial_predictions``.
        collection: Document frame with a ``cord_uid`` column. Defaults to
            ``df_collection``.
        ranker: Object exposing ``get_scores(text, corpus_frame)``. Defaults to
            ``cross_ranker``.

    Returns:
        List of cord_uids ordered by descending cross-encoder score.

    Raises:
        KeyError: If ``query["post_id"]`` has no row in ``predictions``.
        ValueError: If the stored ``preds`` string is not a plain Python literal.
    """
    # Resolve notebook globals lazily so callers (and tests) can inject their own.
    predictions = partial_predictions if predictions is None else predictions
    collection = df_collection if collection is None else collection
    ranker = cross_ranker if ranker is None else ranker

    tweet_id = query["post_id"]
    matches = predictions.loc[predictions["post_id"] == tweet_id, "preds"]
    if matches.empty:
        raise KeyError(f"No first-stage predictions for post_id {tweet_id!r}")

    raw_preds = matches.iloc[0]
    # The TSV stores each list as its Python repr. Parse it as a literal instead of
    # eval()-ing file contents, which could execute arbitrary code. Already-parsed
    # list-likes are accepted as-is.
    if isinstance(raw_preds, str):
        candidate_uids = ast.literal_eval(raw_preds)
    else:
        candidate_uids = list(raw_preds)
    candidate_uids = candidate_uids[:n_candidates]

    # Filtering keeps the collection's row order (not the first-stage rank order);
    # cord_uids[i] lines up with the i-th score returned by the ranker.
    reduced_corpus = collection[collection["cord_uid"].isin(candidate_uids)]
    cord_uids = reduced_corpus["cord_uid"].tolist()

    doc_scores = np.asarray(ranker.get_scores(query["tweet_text"], reduced_corpus))
    # Stable sort: equal scores keep a deterministic (collection) order.
    order = np.argsort(-doc_scores, kind="stable")[:top_k]
    return [cord_uids[i] for i in order]

# Re-rank each tweet's first-stage (SBERT) candidates with the cross-encoder;
# the ranked cord_uid lists are stored in a new "cross" column.
df_query['cross'] = df_query.progress_apply(get_top_cord_uids, axis=1)

  0%|          | 0/1400 [00:00<?, ?it/s]W0514 16:27:02.508000 1711 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode
100%|██████████| 1400/1400 [21:15<00:00,  1.10it/s]


In [5]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred, create_pred_file
# Scoring against the gold cord_uid only happens in DEV_MODE.
if DEV_MODE:
    mrr_results = get_performance_mrr(df_query, 'cord_uid', 'cross')
    # Presumably the share of queries whose gold document appears in the top-k.
    gold_results = get_avg_gold_in_pred(df_query, 'cord_uid', 'cross', list_k=[5, 10, 15])
    # One item per line: ">>>", {k: MRR@k}, the gold-in-top-k results, "<<<".
    print(">>>", mrr_results, gold_results, "<<<", sep="\n")

# Write the "cross" rankings to a prediction file (prediction_size=10); include_gold
# follows DEV_MODE -- presumably it adds the gold cord_uid column.
create_pred_file(
    df_query,
    "cross",
    prediction_size=10,
    include_gold=DEV_MODE,
    base_folder="partial-predictions/classifier",
)

>>>
{1: np.float64(0.6128571428571429), 5: np.float64(0.6760833333333334), 10: np.float64(0.6831757369614513)}
{5: np.float64(0.775), 10: np.float64(0.8278571428571428), 15: np.float64(0.8478571428571429)}
<<<
