In [None]:
#Packages needed for retrieval and reranking with T5

!pip install bm25_pt
!pip install rerankers
!pip install "rerankers[all]"

In [None]:
class Rerank:
    """Rerank retrieved candidate papers for a tweet with a `rerankers` model.

    Thin wrapper around ``rerankers.Reranker`` meant to be applied row by row
    to the query DataFrame produced by the BM25 retrieval step.
    """

    def __init__(self, device=None, model_name="t5"):
        """Load the reranking model.

        Args:
            device: Forwarded to ``rerankers.Reranker`` as ``device`` when given.
                When None, no device is passed and the library's own default applies.
            model_name: Model name/identifier passed to ``rerankers.Reranker``.
        """
        # Local import so this class does not depend on a later notebook cell
        # having been executed before it is instantiated.
        from rerankers import Reranker

        self.model_name = model_name
        # Only forward `device` when the caller actually chose one.
        extra_kwargs = {} if device is None else {"device": device}
        self.reranker = Reranker(self.model_name, **extra_kwargs)

    def rerank_with_rerankers(self, row, k=10):
        """Return the candidate paper ids for one query, best first.

        Args:
            row: Query row (e.g. a DataFrame row or dict) providing
                ``'tweet_text'`` and ``'title_abstract'``. ``'title_abstract'``
                maps paper id -> dict with ``'title'`` and ``'abstract'`` keys.
            k: Maximum number of ids to return. ``None`` returns every candidate.

        Returns:
            List of paper ids in the order produced by the reranker, truncated
            to at most ``k`` entries.
        """
        candidates = row["title_abstract"]
        if not candidates:
            # Nothing was retrieved for this query; skip the model call entirely.
            return []

        query = row['tweet_text']
        doc_ids = list(candidates.keys())
        # Each document is "<title> <abstract>", aligned index-for-index with doc_ids.
        docs = [f"{info['title']} {info['abstract']}" for info in candidates.values()]

        results = self.reranker.rank(query=query, docs=docs, doc_ids=doc_ids)
        ranked_paper_ids = [res.doc_id for res in results]
        return ranked_paper_ids if k is None else ranked_paper_ids[:k]

In [None]:
import argparse
import numpy as np
import pandas as pd
import time
import torch
import rerankers
from bm25_pytorch import BM25_Pytorch

#from rerank import Rerank
#from rerank_t5 import Rerank
from util import get_performance_mrr, retrieve_paper, output_file, preprocess
from rerankers import Reranker, Document

# Experiment tag: used in the results-log header and in the output file names.
EXPERIMENT: str = "bm25pt-t5"
# Marks a run on the test queries (no gold labels); MRR evaluation is intended
# to run only when this is False.
TEST_SET: bool = True


def retrieve(df_collection, df_query, f, device=None, mrr_k=None):
  """First-stage BM25 retrieval over the paper collection.

  Builds one text per paper ("<title> <abstract>"), indexes it with
  BM25_Pytorch, and stores each tweet's candidates (from
  ``bm25.get_top_cord_uids``) in ``df_query['bm25_topk']``. When ``TEST_SET``
  is False, the BM25 baseline MRR against the ``'cord_uid'`` column is
  written to ``f``.

  Args:
    df_collection: Papers with 'title', 'abstract' and 'cord_uid' columns.
    df_query: Queries with a 'tweet_text' column; modified in place.
    f: Writable text file for evaluation results (used only when TEST_SET is False).
    device: Forwarded to BM25_Pytorch.
    mrr_k: MRR@k cut-offs; defaults to [1, 5, 10].

  Returns:
    The same ``df_query`` object, now carrying a 'bm25_topk' column.
  """
  if mrr_k is None:
    mrr_k = [1, 5, 10]

  # Column-wise concatenation instead of a row-wise apply: faster, and it still
  # yields a list on an empty collection. astype(str) formats each value the
  # way str()/f-strings do (NaN -> 'nan').
  corpus = (df_collection['title'].astype(str) + ' ' + df_collection['abstract'].astype(str)).tolist()
  cord_uids = df_collection['cord_uid'].tolist()

  bm25 = BM25_Pytorch(corpus=corpus, cord_uids=cord_uids, device=device)

  # Retrieve the candidate papers for every tweet.
  df_query['bm25_topk'] = df_query['tweet_text'].apply(bm25.get_top_cord_uids)

  # Gold labels are only available outside the test set.
  if not TEST_SET:
    results = get_performance_mrr(df_query,
                                  col_gold='cord_uid',
                                  col_pred='bm25_topk',
                                  list_k=mrr_k)

    f.write("EXPERIMENT {} BASELINE (BM25) RESULTS:\n".format(EXPERIMENT))
    f.write(str(results))
    f.write("\n\n")

  return df_query


def rerank(df_collection, df_query, f, rerank_model, rerank_k, device=None, mrr_k=None):
  """Rerank each query's BM25 candidates with a `rerankers` model.

  Looks up the papers listed in ``df_query['bm25_topk']`` via
  ``retrieve_paper`` into a new ``'title_abstract'`` column. It then reranks
  them with ``Rerank`` and stores the ordered ids in
  ``'bm25_reranker_topk'``. When ``TEST_SET`` is False, MRR of the reranked
  column against ``'cord_uid'`` is written to ``f``.

  Args:
    df_collection: Paper collection passed to retrieve_paper.
    df_query: Query frame with 'tweet_text' and 'bm25_topk' columns; modified in place.
    f: Writable text file for evaluation results (used only when TEST_SET is False).
    rerank_model: Model name passed to Rerank.
    rerank_k: Passed to Rerank.rerank_with_rerankers as ``k``.
    device: Forwarded to Rerank.
    mrr_k: MRR@k cut-offs; defaults to [1, 5, 10].

  Returns:
    The same ``df_query`` object with the two new columns.
  """
  if mrr_k is None:
    mrr_k = [1, 5, 10]

  # Named `reranker` so it does not shadow this function's own name.
  reranker = Rerank(model_name=rerank_model, device=device)

  df_query['title_abstract'] = df_query['bm25_topk'].apply(
      lambda paper_ids: retrieve_paper(df_collection=df_collection, paper_ids=paper_ids))
  df_query['bm25_reranker_topk'] = df_query.apply(
      lambda row: reranker.rerank_with_rerankers(row, k=rerank_k), axis=1)

  # Gold labels are only available outside the test set; score the column
  # produced just above.
  if not TEST_SET:
    results = get_performance_mrr(df_query,
                                  col_gold='cord_uid',
                                  col_pred='bm25_reranker_topk',
                                  list_k=mrr_k)

    f.write("EXPERIMENT {} RERANK RESULTS:\n".format(EXPERIMENT))
    f.write(str(results))
    f.write("\n")

  return df_query


def main(path_collection_data, path_query_data, output_dir, rerank_model, rerank_k, list_mrr_k):
  """Run BM25 retrieval followed by reranking and write the outputs.

  Loads the collection (pickle) and the query tweets (TSV) and preprocesses
  both with ``preprocess``. It retrieves BM25 candidates, reranks them, and
  logs the run parameters to the results file created by ``output_file``.
  Finally it writes the top-5 reranked paper ids per post to
  ``predictions_<EXPERIMENT>_<model>.tsv`` in the current working directory
  (not inside ``output_dir``).

  Args:
    path_collection_data: Path to the collection pickle.
    path_query_data: Path to the tab-separated query file.
    output_dir: Directory passed to output_file for the results log.
    rerank_model: Reranking model name; '/' is replaced with '-' in file names.
    rerank_k: Forwarded to rerank().
    list_mrr_k: MRR@k cut-offs forwarded to retrieve() and rerank().
  """
  # NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
  df_collection = preprocess(df=pd.read_pickle(path_collection_data))
  df_query = preprocess(df=pd.read_csv(path_query_data, sep='\t'))

  cuda_available = torch.cuda.is_available()
  print("CUDA", cuda_available)

  # torch.device(None) raises TypeError, so only build a device object when CUDA
  # exists; otherwise pass None and let downstream components use their default.
  device = torch.device("cuda:0") if cuda_available else None
  torch._dynamo.config.suppress_errors = True

  model_tag = rerank_model.replace("/", '-')

  # Opened before the try block so `finally` never refers to an unbound name
  # (which would mask the real error) if output_file itself fails.
  results_file = output_file(experiment_name="{}_{}".format(EXPERIMENT, model_tag),
                             output_dir=output_dir)
  try:
    results_file.write("EXPERIMENT: {}\n\n".format(EXPERIMENT))
    results_file.write("INPUT PARAMETERS:\n")
    results_file.write("path_cord_data: {}\n".format(path_collection_data))
    results_file.write("path_tweet_data: {}\n".format(path_query_data))
    results_file.write("output_dir: {}\n".format(output_dir))
    results_file.write("rerank_model: {}\n".format(rerank_model))
    results_file.write("rerank_k: {}\n".format(rerank_k))
    results_file.write("mrr_k: {}\n\n".format(list_mrr_k))

    ## Retrieve
    df_query = retrieve(df_collection=df_collection,
                        df_query=df_query,
                        f=results_file,
                        device=device,
                        mrr_k=list_mrr_k)

    ## Rerank
    df_query = rerank(df_collection=df_collection,
                      df_query=df_query,
                      f=results_file,
                      device=device,
                      rerank_model=rerank_model,
                      rerank_k=rerank_k,
                      mrr_k=list_mrr_k)

    # Keep the top 5 reranked ids per post for the predictions file.
    df_query['preds'] = df_query['bm25_reranker_topk'].apply(lambda ids: ids[:5])
    df_query[['post_id', 'preds']].to_csv('predictions_{}_{}.tsv'.format(EXPERIMENT, model_tag),
                                          index=None, sep='\t')

  finally:
    results_file.close()


if __name__ == '__main__':
  # CLEF CheckThat! Task 4B run configuration. Values are set inline rather
  # than parsed from the command line; replace the placeholder paths before running.

  # CORD collection (.pkl file).
  path_collection_data = "file path to subtask4b_collection_data.pkl"
  # Tweet queries (.tsv file).
  path_query_data = "file path to subtask4b_query_tweets_test.tsv"
  # Directory handed to main() as output_dir.
  output_dir = "file path to output directory"
  # Reranking model name handed to main() as rerank_model.
  rerank_model = "t5"
  # Passed through main() to rerank() as rerank_k.
  rerank_k = 10
  # MRR@K cut-offs.
  list_mrr_k = [1, 5, 10]

  main(
      path_collection_data=path_collection_data,
      path_query_data=path_query_data,
      output_dir=output_dir,
      rerank_model=rerank_model,
      rerank_k=rerank_k,
      list_mrr_k=list_mrr_k,
  )
