In [None]:
from FlagEmbedding import LightWeightFlagLLMReranker
from financerag.rerank import CrossEncoderReranker
import financerag.tasks as tasks_module

import importlib
import inspect
import os
import json
import pandas as pd
import logging

# Configure the root logger (stdlib default handler: stderr) to emit records at
# INFO level and above; the level is given by name, which basicConfig accepts.
logging.basicConfig(level="INFO")

In [None]:
# Names of the FinanceRAG task classes to process, in processing order.
passage_retrieval = [
    'FinDER',
    'FinQABench',
    'FinanceBench',
]

In [None]:
RERANKER = 'BAAI/bge-reranker-v2.5-gemma2-lightweight'

# Short model identifier: the last path component of the hub id.
reranker_name = RERANKER.rsplit('/', 1)[-1]

# Wrap the FlagEmbedding LLM reranker in FinanceRAG's CrossEncoderReranker so it
# can be passed to a task's `rerank(...)`. fp16 speeds up inference at the cost
# of a slight accuracy drop.
reranker = CrossEncoderReranker(
    model=LightWeightFlagLLMReranker(RERANKER, use_fp16=True)
)

In [None]:
# For each task: rerank the saved hybrid-retrieval run, save it, then keep
# whichever run (retrieval or rerank) scores the higher NDCG@10 on the qrels.
RETRIEVAL_DIR = './hybrid'
RERANK_DIR = './rerank'
FINAL_DIR = './final'
EVAL_DIR = './eval'
RERANK_TOP_K = 100
RERANK_BATCH_SIZE = 4
EVAL_K_VALUES = [1, 5, 10]

# Nothing in this notebook creates the output directories; make sure they
# exist so the json dumps below do not fail with FileNotFoundError.
os.makedirs(RERANK_DIR, exist_ok=True)
os.makedirs(FINAL_DIR, exist_ok=True)


def _load_qrels(path):
    """Load a tab-separated qrels file into ``{query_id: {corpus_id: score}}``.

    The file must have ``query_id``, ``corpus_id`` and ``score`` columns.
    IDs are read as ``str`` so they match the keys of the JSON-loaded result
    dicts (JSON object keys are always strings) and so zero-padded IDs are not
    mangled into integers. Scores are converted to plain ``int``. If a
    (query_id, corpus_id) pair repeats, the last row wins.
    """
    qrels_df = pd.read_csv(path, sep='\t', dtype={'query_id': str, 'corpus_id': str})
    qrels = {}
    rows = qrels_df[['query_id', 'corpus_id', 'score']].itertuples(index=False)
    for query_id, corpus_id, score in rows:
        qrels.setdefault(query_id, {})[corpus_id] = int(score)
    return qrels


def _dump_json(obj, path):
    """Write ``obj`` to ``path`` as JSON indented by 4 spaces."""
    with open(path, "w") as json_file:
        json.dump(obj, json_file, indent=4)


for task_class in passage_retrieval:
    finder_task = getattr(tasks_module, task_class)()

    with open(f"{RETRIEVAL_DIR}/{task_class}_mm_best_cc.json", "r") as json_file:
        retrieval_result = json.load(json_file)

    rerank_result = finder_task.rerank(
        reranker=reranker,
        results=retrieval_result,
        top_k=RERANK_TOP_K,
        batch_size=RERANK_BATCH_SIZE,
    )
    _dump_json(rerank_result, f'{RERANK_DIR}/{task_class}_rerank.json')

    # Retrieval vs. rerank: compare NDCG@10 to decide which run is final.
    # NOTE(review): the choice is made on these same qrels; if they are also
    # the held-out evaluation set, the chosen run's score is optimistic.
    qrels_dict = _load_qrels(f'{EVAL_DIR}/{task_class}_qrels.tsv')

    retrieval_score = finder_task.evaluate(qrels_dict, retrieval_result, EVAL_K_VALUES)
    retrieval_ndcg = retrieval_score[0]['NDCG@10']

    rerank_score = finder_task.evaluate(qrels_dict, rerank_result, EVAL_K_VALUES)
    rerank_ndcg = rerank_score[0]['NDCG@10']

    # Ties go to the reranked run.
    final_result = rerank_result if retrieval_ndcg <= rerank_ndcg else retrieval_result
    _dump_json(final_result, f'{FINAL_DIR}/{task_class}_final.json')