In [None]:
import json
import logging
import os
import time

import pandas as pd
import torch
import voyageai
from tqdm import tqdm

import financerag.tasks as tasks_module

In [None]:
# Task names processed by this notebook. Each name is resolved to a class on
# `financerag.tasks` via getattr and is also used to build the data/result paths.
tabular_retrieval = [
    'TATQA',
    'FinQA',
    'ConvFinQA',
    'MultiHiertt',
]

In [None]:
# Read the Voyage AI key from the environment so that no credential is ever
# stored in (or committed with) the notebook; fail early if it is not set.
_voyage_api_key = os.environ.get("VOYAGE_API_KEY")
if not _voyage_api_key:
    raise RuntimeError("Set the VOYAGE_API_KEY environment variable before running this notebook.")
vo = voyageai.Client(api_key=_voyage_api_key)

In [None]:
## Rerank
# For every task in `tabular_retrieval`:
#   1. load the retrieval scores from ./hybrid_search,
#   2. rerank the best-scoring candidates of each query with `vo.rerank`,
#   3. save the reranked scores to ./rerank,
#   4. evaluate both result sets against the qrels and write whichever has the
#      higher NDCG@10 (rerank wins ties) to ./final.

# For Request Limits
BATCH_SIZE = 7        # number of queries sent between two pauses
SLEEP_DURATION = 55   # seconds to pause after each batch of queries

RERANK_MODEL = "rerank-2"
NUM_CANDIDATES = 100  # top retrieval hits per query handed to the reranker
RERANK_TOP_K = 10     # reranked documents kept per query

# Create the output folders up front: without this, a missing folder would only
# raise after the whole (slow, rate-limited) rerank loop has already finished.
for output_dir in ("./rerank", "./final"):
    os.makedirs(output_dir, exist_ok=True)

for task_class in tabular_retrieval:
    task_class_obj = getattr(tasks_module, task_class)
    finder_task = task_class_obj()

    retrieval_path = f"./hybrid_search/{task_class}_mm_best_cc.json"

    with open(retrieval_path, "r") as f:
        retrieval_result = json.load(f)
    query_ids = list(retrieval_result.keys())

    queries = pd.read_csv(f"./data/{task_class}_queries.csv")
    corpus = pd.read_csv(f"./data/{task_class}_corpus.csv")
    corpus = corpus.drop_duplicates(subset='_id').reset_index(drop=True)

    # Build the id -> text lookups once per task instead of once per query.
    # Keeping the first row per query id matches the original `.values[0]` choice.
    corpus_text = corpus.set_index('_id')['text']
    query_text = queries.drop_duplicates(subset='_id').set_index('_id')['text']

    rerank_result = {}

    for i, query_id in enumerate(tqdm(query_ids, desc=f"Processing task {task_class}")):
        # Pause after every BATCH_SIZE queries to stay within the request limits.
        if i > 0 and i % BATCH_SIZE == 0:
            print(f"Processed {i} queries, resting for {SLEEP_DURATION} seconds...")
            time.sleep(SLEEP_DURATION)

        # Keep only the highest-scoring retrieval candidates for this query.
        doc_scores = sorted(
            retrieval_result[query_id].items(), key=lambda x: x[1], reverse=True
        )[:NUM_CANDIDATES]
        doc_ids = [doc_id for doc_id, _ in doc_scores]

        documents = corpus_text.loc[doc_ids].tolist()
        query = query_text.loc[query_id]
        rerank = vo.rerank(query, documents, model=RERANK_MODEL, top_k=RERANK_TOP_K, truncation=False)

        # `r.index` is the position of the document in `documents`, which is
        # aligned with `doc_ids`, so it maps each result back to its corpus id.
        rerank_result[query_id] = {doc_ids[r.index]: r.relevance_score for r in rerank.results}

    # save rerank
    output_path = f"./rerank/{task_class}_rerank.json"
    with open(output_path, "w") as f:
        json.dump(rerank_result, f, indent=4)

    # Retrieval vs Rerank
    # Comparison and evaluation to determine the final outcome.
    df = pd.read_csv(f'./eval/{task_class}_qrels.tsv', sep='\t')
    # {query_id: {corpus_id: score}}; iterating the groups directly avoids the
    # pandas deprecation warning raised by groupby(...).apply on grouping columns.
    qrels_dict = {
        qid: dict(zip(group['corpus_id'], group['score']))
        for qid, group in df.groupby('query_id')
    }

    # The first element of the evaluation output is indexed by 'NDCG@k' keys.
    retrieval_score = finder_task.evaluate(qrels_dict, retrieval_result, [1, 5, 10])
    retrieval_ndcg = retrieval_score[0]['NDCG@10']

    rerank_score = finder_task.evaluate(qrels_dict, rerank_result, [1, 5, 10])
    rerank_ndcg = rerank_score[0]['NDCG@10']

    # Keep whichever result set scores higher; the reranked output wins ties.
    final_result = rerank_result if retrieval_ndcg <= rerank_ndcg else retrieval_result
    with open(f'./final/{task_class}_final.json', "w") as json_file:
        json.dump(final_result, json_file, indent=4)