In [None]:
from sentence_transformers import CrossEncoder
import logging
from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder
import financerag.tasks as tasks_module

from transformers import AutoConfig

import importlib
import inspect
import os
import json
import pandas as pd
  
# Configure the root logger so messages at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Names of the task classes to run; each name is looked up as an attribute of
# `financerag.tasks` (imported as `tasks_module`) in the retrieval loop.
passage_retrieval = ['FinDER', 'FinQABench', 'FinanceBench']

In [None]:
# Model identifier passed to the sentence encoder.
ENCODER = 'bennegeek/stella_en_1.5B_v5'
# Alternative encoder; uncomment to switch models.
#ENCODER = 'rbhatia46/financial-rag-matryoshka'

# Last path component of the model id; used to name the output directory.
encoder_name = ENCODER.rpartition('/')[2]

# NOTE(review): the instruction sentence in the query prompt has no verb
# (it probably meant "... retrieve relevant passages ..."). The text is kept
# exactly as-is here because changing it changes the produced embeddings;
# confirm with the author before editing.
encoder_model = SentenceTransformerEncoder(
    model_name_or_path=ENCODER,
    query_prompt=(
        "Instruct: Given a financial question, relevant passages that best answer the question. \n"
        "Query: "
    ),
    doc_prompt="Passage: ",
    trust_remote_code=True,
)

# Dense retriever wrapping the encoder; texts are encoded 8 at a time.
retrieval_model = DenseRetrieval(model=encoder_model, batch_size=8)

In [None]:
# For every configured task: run dense retrieval, save the raw rankings as JSON,
# score them against the task's qrels file, and save the metrics as a one-row CSV.
# All artefacts go into a directory named after the encoder.
output_dir = f'./{encoder_name}'
os.makedirs(output_dir, exist_ok=True)

for task_class in passage_retrieval:
    # Resolve the task class by name on `financerag.tasks` and instantiate it.
    task_class_obj = getattr(tasks_module, task_class)
    # NOTE: despite its name, `finder_task` holds whichever task is currently
    # being processed, not only FinDER. The name is kept for later cells.
    finder_task = task_class_obj()

    retrieval_result = finder_task.retrieve(
        retriever=retrieval_model,
        top_k=500,
    )

    # Persist the raw retrieval output.
    file_name = f"{output_dir}/{task_class}.json"
    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(retrieval_result, json_file, indent=4)

    # Load the relevance judgements. The id columns are forced to `str` so that
    # numeric-looking ids are not parsed as integers by pandas' type inference;
    # presumably `retrieval_result` is keyed by string ids -- TODO confirm.
    df = pd.read_csv(
        f'./eval/{task_class}_qrels.tsv',
        sep='\t',
        dtype={'query_id': str, 'corpus_id': str},
    )

    # Build {query_id: {corpus_id: score}} with an explicit loop instead of
    # `groupby(...).apply(...)`, which triggers the grouping-columns deprecation
    # warning on recent pandas. Scores are cast to plain Python ints.
    qrels_dict = {}
    for query_id, corpus_id, score in zip(df['query_id'], df['corpus_id'], df['score']):
        qrels_dict.setdefault(query_id, {})[corpus_id] = int(score)

    # Evaluate at cutoffs 1, 5 and 10. The first four elements of the result are
    # merged into one flat dict (presumably one dict per metric family -- verify
    # against `evaluate`'s return value).
    eval_result = finder_task.evaluate(qrels_dict, retrieval_result, [1, 5, 10])
    combined_result = {**eval_result[0], **eval_result[1], **eval_result[2], **eval_result[3]}

    # Save the metrics as a single-row CSV next to the rankings.
    df_eval = pd.DataFrame([combined_result])
    df_eval.to_csv(f'{output_dir}/{task_class}_eval.csv', index=False)