In [1]:
import numpy as np
import pandas as pd
import torch

# True  -> score the pipeline against the labelled dev tweets.
# False -> run on the test tweets and write a submission file instead.
DEV_MODE = True

PATH_COLLECTION_DATA = 'data/subtask_4b/subtask4b_collection_data.pkl'

# Candidate papers that every query tweet is ranked against.
# NOTE: read_pickle can execute arbitrary code -- only load trusted local files.
df_collection = pd.read_pickle(filepath_or_buffer=PATH_COLLECTION_DATA)
df_collection.info()

<class 'pandas.core.frame.DataFrame'>
Index: 7718 entries, 162 to 1056448
Data columns (total 17 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   cord_uid          7718 non-null   object        
 1   source_x          7718 non-null   object        
 2   title             7718 non-null   object        
 3   doi               7677 non-null   object        
 4   pmcid             4959 non-null   object        
 5   pubmed_id         6233 non-null   object        
 6   license           7718 non-null   object        
 7   abstract          7718 non-null   object        
 8   publish_time      7715 non-null   object        
 9   authors           7674 non-null   object        
 10  journal           6668 non-null   object        
 11  mag_id            0 non-null      float64       
 12  who_covidence_id  528 non-null    object        
 13  arxiv_id          20 non-null     object        
 14  label             7718 n

In [2]:
# Dev tweets carry the gold `cord_uid` used for scoring; in test mode the
# pipeline only produces a prediction file.
PATH_QUERY_DATA = (
    'data/dev-tweets/subtask4b_query_tweets_dev.tsv'
    if DEV_MODE
    else 'data/subtask_4b/subtask4b_query_tweets_test.tsv'
)

# NOTE(review): the TSV carries a leftover unnamed index column
# (shown as 'Unnamed: 0' in the preview).
df_query = pd.read_csv(PATH_QUERY_DATA, sep='\t')
df_query.head()

Unnamed: 0,post_id,tweet_text,cord_uid
0,16,covid recovery: this study from the usa reveal...,3qvh482o
1,69,"""Among 139 clients exposed to two symptomatic ...",r58aohnu
2,73,I recall early on reading that researchers who...,sts48u9i
3,93,You know you're credible when NIH website has ...,3sr2exq9
4,96,Resistance to antifungal medications is a grow...,ybwwmyqy


In [3]:
# Data prep for bi-encoder
from tqdm import tqdm
from rankers.bi_encoder_ranker import BiEncoderRanker

# Register the tqdm-wrapped `progress_apply` on pandas objects; it is used
# below to build the corpus and to run the pipeline over all tweets.
tqdm.pandas()

def _field_text(value):
    """Return ``value`` as text, mapping missing values (None/NaN/pd.NA) to ''.

    The collection has nulls in ``authors`` and ``journal``; formatting them
    directly would inject the literal strings 'None' / 'nan' into the text
    the encoder sees, as if they were real author or journal names.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value)


def flatten_corpus(entry):
    """Flatten one collection row into a single '[SEP]'-delimited string.

    Args:
        entry: A row of ``df_collection`` (anything indexable by column name)
            providing ``title``, ``authors``, ``abstract`` and ``journal``.

    Returns:
        ``"<title> [SEP] <authors> [SEP] <abstract> [SEP] <journal>"``. A
        missing field becomes an empty string but keeps its separator, so
        every paper has the same field layout.
    """
    fields = (entry[col] for col in ("title", "authors", "abstract", "journal"))
    return " [SEP] ".join(_field_text(value) for value in fields)

# Row-aligned views of the collection: position i in `cord_uids` and in
# `corpus` both refer to row i of df_collection.
cord_uids = df_collection['cord_uid'].tolist()
corpus = df_collection.progress_apply(flatten_corpus, axis=1).tolist()

100%|██████████| 7718/7718 [00:00<00:00, 86704.71it/s]


In [4]:
# Pretrained sentence-transformers checkpoint, loaded by name (see the log
# output below) for first-stage dense retrieval.
bi_model_name = "multi-qa-mpnet-base-dot-v1"
# The ranker is handed the whole flattened corpus; the "Batches" progress bar
# in the output suggests it embeds the corpus at construction time.
bi_enc_ranker = BiEncoderRanker(bi_model_name, corpus)

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cuda:0
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: multi-qa-mpnet-base-dot-v1


Batches:   0%|          | 0/242 [00:00<?, ?it/s]

In [5]:
from os import listdir
from rankers.pairwise_ranker import PairwiseRanker

def _checkpoint_step(name):
    """Sort key for checkpoint dirs named '<prefix>-<step>' (e.g. 'checkpoint-500').

    Returns ``(step, name)`` so checkpoints compare by numeric step. Entries
    without a numeric suffix get step -1 and sort before every real
    checkpoint, so stray files or log dirs are never picked as "latest".
    """
    _, _, tail = name.rpartition("-")
    return (int(tail), name) if tail.isdecimal() else (-1, name)


base_name = "models/pairwise-classifier-large"
# A plain string sort is not chronological for step-numbered names
# ('checkpoint-500' > 'checkpoint-1000'), so sort on the numeric suffix.
# Assumes '<prefix>-<step>' naming -- TODO confirm against the training run.
dir_list = sorted(listdir(base_name), key=_checkpoint_step)
if not dir_list:
    raise FileNotFoundError(f"No checkpoints found in {base_name!r}")

latest_checkpoint = dir_list[-1]

model_name = f"{base_name}/{latest_checkpoint}"

pairwise_ranker = PairwiseRanker(model_name)

In [None]:
from rankers.cross_embedding_ranker import CrossRanker

def _checkpoint_step(name):
    """Sort key for checkpoint dirs named '<prefix>-<step>' (e.g. 'checkpoint-500').

    Returns ``(step, name)`` so checkpoints compare by numeric step. Entries
    without a numeric suffix get step -1 and sort before every real
    checkpoint, so stray files or log dirs are never picked as "latest".
    """
    _, _, tail = name.rpartition("-")
    return (int(tail), name) if tail.isdecimal() else (-1, name)


cross_base_name = "models/cross-embedding"
# A plain string sort is not chronological for step-numbered names
# ('checkpoint-500' > 'checkpoint-1000'), so sort on the numeric suffix.
# Assumes '<prefix>-<step>' naming -- TODO confirm against the training run.
dir_list = sorted(listdir(cross_base_name), key=_checkpoint_step)
if not dir_list:
    raise FileNotFoundError(f"No checkpoints found in {cross_base_name!r}")

latest_checkpoint = dir_list[-1]

cross_model_name = f"{cross_base_name}/{latest_checkpoint}"
cross_ranker = CrossRanker(cross_model_name)

def get_top_cord_uids_bi(query, k=100):
    """Stage 1: keep the ``k`` best bi-encoder candidates for ``query``.

    Args:
        query: Tweet text to score against the pre-encoded corpus.
        k: Number of candidates to keep (default 100).

    Returns:
        The selected rows of ``df_collection``, ordered by descending
        bi-encoder score (best first).
    """
    doc_scores = bi_enc_ranker.get_scores(query)
    # Indexed with [0]: assumes one row of scores per query, i.e. a
    # (1, n_docs) array -- TODO confirm against BiEncoderRanker.get_scores.
    indices = np.argsort(-doc_scores[0])[:k]
    # `cord_uids` and `corpus` were built from df_collection in row order, so
    # these are positional indices into df_collection. iloc keeps the score
    # order; the previous isin() filter returned rows in collection order and
    # silently threw the ranking away.
    return df_collection.iloc[[int(i) for i in indices]]

def get_top_cord_uids_cross(query, k=10):
    """Stage 2: re-score the bi-encoder candidates and keep the ``k`` best.

    Args:
        query: Tweet text.
        k: Number of candidates to keep after cross-encoder scoring (default 10).

    Returns:
        The selected rows of ``df_collection``, ordered by descending
        cross-encoder score (best first).
    """
    candidates = get_top_cord_uids_bi(query)

    # Scores are assumed to be positionally aligned with the rows of
    # `candidates` (the original lookup via candidates.iloc[x] relied on it).
    doc_scores = cross_ranker.get_scores(query, candidates)
    indices = np.argsort(-doc_scores)[:k]

    # Select directly from the scored candidates: this keeps the
    # cross-encoder order instead of re-filtering the full collection with
    # isin(), which returned rows in collection order.
    return candidates.iloc[[int(i) for i in indices]]

def get_pair_sorted_uids(query):
    """Stage 3: pairwise re-ranking of the cross-encoder shortlist for ``query``.

    Returns whatever ``pairwise_ranker.rank_avg_prob`` produces for the
    shortlist (called with ``use_cache=True``); the result is stored as this
    query's prediction.
    """
    return pairwise_ranker.rank_avg_prob(
        query,
        get_top_cord_uids_cross(query),
        use_cache=True,
    )

# Rank every tweet with the three-stage pipeline:
# bi-encoder shortlist -> cross-encoder shortlist -> pairwise re-ranking.
# NOTE: the recorded run took ~7-8 s per tweet (~3 h for 1400 tweets) and
# was interrupted; budget for that before Run All.
df_query['bi_cross_pair'] = df_query['tweet_text'].progress_apply(get_pair_sorted_uids)
df_query.head()

  0%|          | 6/1400 [00:45<2:56:34,  7.60s/it]


KeyboardInterrupt: 

In [None]:
from eval_scripts.eval import get_performance_mrr, get_avg_gold_in_pred, create_pred_file

# Test mode has nothing to score against: just write the prediction file.
# Dev mode scores the 'bi_cross_pair' predictions against the gold 'cord_uid'.
if not DEV_MODE:
    create_pred_file(df_query, 'bi_cross_pair')
else:
    # mrr_results is printed in the form {k: MRR@k}.
    mrr_results = get_performance_mrr(df_query, 'cord_uid', 'bi_cross_pair')
    gold_results = get_avg_gold_in_pred(
        df_query, 'cord_uid', 'bi_cross_pair', list_k=[1, 3, 5, 10]
    )
    print(">>>", mrr_results, gold_results, "<<<", sep="\n")