In [1]:
import numpy as np
import pandas as pd
import torch

# --- Configuration -----------------------------------------------------------
# DEV_MODE=True: use the 'dev' query split and compute MRR against gold labels.
# DEV_MODE=False: use the 'test' split and only write prediction files.
DEV_MODE = False
PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

# Allow faster reduced-internal-precision kernels (e.g. TF32) for float32 matmuls
# on hardware that supports them.
torch.set_float32_matmul_precision('high')

# --- Load the candidate paper collection ---------------------------------------
# NOTE(review): pickle loading executes code embedded in the file; only use trusted data.
df_collection = pd.read_pickle(PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [2]:
# Pick the query split. The dev split is evaluated against gold cord_uids;
# the test split only yields a prediction file.
if DEV_MODE:
    query_variant = 'dev'
else:
    query_variant = 'test'

PATH_QUERY_DATA = f'data/subtask_4b/subtask4b_query_tweets_{query_variant}.tsv'
df_query = pd.read_csv(filepath_or_buffer=PATH_QUERY_DATA, sep='\t')
df_query.head()

Unnamed: 0,post_id,tweet_text
0,1,A recent research study published yesterday cl...
1,2,"""We should track the long-term effects of thes..."
2,3,"the agony of ""long haul"" covid-19 symptoms."
3,4,Home and online monitoring and assessment of b...
4,5,"it may be a long one, folks! to avoid exceedin..."


In [3]:
# First-stage candidates for each query.
# `preds` holds each post's candidate cord_uids, serialized as the string repr of a Python list.
PARTIAL_PREDICTION_FILE = "partial-predictions/classifier/predictions.tsv"
partial_predictions = pd.read_csv(filepath_or_buffer=PARTIAL_PREDICTION_FILE, sep="\t")
partial_predictions.head()

Unnamed: 0,post_id,preds
0,1,"['qgwu9fsk', 'x4zuv4jo', 'bv7hvc1e', 'j0bu0upi..."
1,2,"['mm2aotem', '4vkkaqhz', '855atuue', 'h9nzxlaf..."
2,3,"['gtp5daep', '00ugdhvf', 'm3m2n3fw', 'trrg1mnw..."
3,4,"['wabd3b9z', 'z43p1dmf', 'cfi5zhgu', 'c0i45oix..."
4,5,"['nzat41wu', 'i88ccp9w', 'ky5env7t', 'trrg1mnw..."


In [4]:
import re
from os import listdir

from rankers.pairwise_ranker import PairwiseRanker
from tqdm import tqdm

# Register `DataFrame.progress_apply` so the per-query re-ranking shows a progress bar.
tqdm.pandas()

base_name = "models/pairwise-classifier-large"


def _checkpoint_sort_key(entry_name):
    """Order checkpoint directory names by their trailing step number.

    Plain lexicographic sorting places e.g. "checkpoint-500" after
    "checkpoint-1000", so the "latest" pick would be an older checkpoint.
    Entries with no trailing number sort before every numbered entry.
    Ties fall back to the name itself so the order stays deterministic.
    """
    step_match = re.search(r"(\d+)$", entry_name)
    step = int(step_match.group(1)) if step_match else -1
    return (step, entry_name)


dir_list = sorted(listdir(base_name), key=_checkpoint_sort_key)
if not dir_list:
    raise FileNotFoundError(f"No checkpoints found in {base_name!r}")

# Highest training step wins (assumes `checkpoint-<step>`-style names — TODO confirm).
latest_checkpoint = dir_list[-1]

model_name = f"{base_name}/{latest_checkpoint}"

pairwise_ranker = PairwiseRanker(model_name)

def get_top_cord_uids(query):
    """Re-rank one query's first-stage candidates with the pairwise ranker.

    Looks up the query's candidate cord_uids in the module-level
    ``partial_predictions`` frame. It restricts ``df_collection`` to those papers,
    then passes the tweet text and that reduced corpus to
    ``pairwise_ranker.rank_avg_prob``.

    Args:
        query: A row of ``df_query``; anything indexable by ``"post_id"`` and
            ``"tweet_text"`` works.

    Returns:
        Whatever ``pairwise_ranker.rank_avg_prob`` returns for the reduced
        corpus (presumably cord_uids ordered best-first — TODO confirm).

    Raises:
        KeyError: if ``partial_predictions`` has no row for the query's post_id.
    """
    import ast

    tweet_id = query["post_id"]
    candidate_rows = partial_predictions.loc[partial_predictions["post_id"] == tweet_id, "preds"]
    if candidate_rows.empty:
        raise KeyError(f"No first-stage predictions for post_id={tweet_id!r}")
    # `preds` is stored as the string repr of a Python list. literal_eval parses
    # it without executing arbitrary code, unlike eval().
    selected_docs_uids = ast.literal_eval(candidate_rows.iloc[0])

    reduced_corpus = df_collection[df_collection['cord_uid'].isin(selected_docs_uids)]
    pair_topk = pairwise_ranker.rank_avg_prob(query["tweet_text"], reduced_corpus, use_cache=True)

    return pair_topk

# Re-rank each query's first-stage candidates (from PARTIAL_PREDICTION_FILE) with
# the pairwise ranker and store the result per query.
# This is slow: the recorded run logged roughly 15 s per query.
df_query['bi_cross_pair'] = df_query.progress_apply(get_top_cord_uids, axis=1)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /answerdotai/ModernBERT-base/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
  0%|          | 0/1446 [00:00<?, ?it/s]W0510 00:31:21.518000 3133 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode
  1%|          | 9/1446 [02:17<6:05:53, 15.28s/it]


KeyboardInterrupt: 

In [None]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred, create_pred_file

pred_column = 'bi_cross_pair'

# Fail fast with a clear message if the re-ranking cell did not finish (e.g. it was
# interrupted). Otherwise the failure surfaces deep inside the eval helpers.
assert pred_column in df_query.columns, (
    f"df_query has no '{pred_column}' column; run the pairwise re-ranking cell to completion first"
)

if DEV_MODE:
    mrr_results = get_performance_mrr(df_query, 'cord_uid', pred_column)
    gold_results = get_avg_gold_in_pred(df_query, 'cord_uid', pred_column, list_k=[5, 10])
    # Printed MRR@k results in the following format: {k: MRR@k}
    print(">>>")
    print(mrr_results)
    print(gold_results)
    print("<<<")

# Gold labels are only included when running on the dev split.
create_pred_file(df_query, pred_column, prediction_size=5, include_gold=DEV_MODE, base_folder="partial-predictions/pairwise")

>>>
{1: np.float64(0.8), 5: np.float64(0.9), 10: np.float64(0.9)}
{5: np.float64(1.0), 10: np.float64(1.0)}
<<<
