In [None]:
import logging
from financerag.retrieval import BM25Retriever
import financerag.tasks as tasks_module

import importlib
import inspect
import os
import json
import pandas as pd

from nltk.tokenize import word_tokenize, TweetTokenizer
from rank_bm25 import BM25Okapi
import nltk

# Single shared tokenizer instance; process_task uses its `tokenize` method on
# the corpus documents and also hands it to BM25Retriever.
tweet_tokenizer = TweetTokenizer()
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Tasks that the driver loop runs twice: once on the "text" column and once on
# the "convert_text" column of their converted corpus CSV.
tabular_retrieval = ['TATQA', 'FinQA', 'ConvFinQA', 'MultiHiertt']

# Every task to process, in run order: the remaining tasks first, then the
# tabular ones (unpacked so the names are only written once).
retrieval = ['FinDER', 'FinQABench', 'FinanceBench', *tabular_retrieval]

In [None]:
# Directory that receives the retrieval JSON files and evaluation CSVs.
output_dir = './BM25'
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

def process_task(task_class, corpus_documents, output_dir, finder_task):
    """Run BM25 retrieval for one task/corpus variant and save results and metrics.

    Args:
        task_class: Label for this run, e.g. ``"FinDER"`` or ``"FinQA_convert"``.
            The part before the first underscore selects the qrels file
            (``./eval/<base>_qrels.tsv``); the full label names the output files.
        corpus_documents: List of document strings to index with BM25.
        output_dir: Existing directory that receives the output files.
        finder_task: Task object providing ``retrieve(retriever=..., top_k=...)``
            and ``evaluate(qrels, results, k_values)``.

    Side effects:
        Writes ``<output_dir>/<task_class>.json`` (retrieval results) and
        ``<output_dir>/<task_class>_eval.csv`` (one row of evaluation metrics).
    """
    # Tokenize corpus
    tokenized_corpus = [tweet_tokenizer.tokenize(doc) for doc in corpus_documents]

    # Initialize BM25 and retrieval model
    bm25_model = BM25Okapi(tokenized_corpus)
    retrieval_model = BM25Retriever(bm25_model, tweet_tokenizer.tokenize)

    # Retrieve documents
    retrieval_result = finder_task.retrieve(
        retriever=retrieval_model,
        top_k=500
    )

    # Save retrieval result
    file_name = f"{output_dir}/{task_class}.json"
    with open(file_name, "w") as json_file:
        json.dump(retrieval_result, json_file, indent=4)

    # Qrels are shared by all corpus variants of a task, so they are looked up
    # by the base task name (the label without its "_<version>" suffix).
    base_task = task_class.split("_")[0]
    df = pd.read_csv(f'./eval/{base_task}_qrels.tsv', sep='\t')

    # Build {query_id: {corpus_id: score}}. Selecting the two value columns
    # before `apply` keeps the grouping column out of the applied frame, which
    # avoids the pandas deprecation warning for applying over grouping columns.
    qrels_dict = (
        df.groupby('query_id')[['corpus_id', 'score']]
        .apply(lambda group: dict(zip(group['corpus_id'], group['score'])))
        .to_dict()
    )

    # Save evaluation result: merge the four metric dicts returned by
    # `evaluate` into a single-row table.
    eval_result = finder_task.evaluate(qrels_dict, retrieval_result, [1, 5, 10])
    combined_result = {**eval_result[0], **eval_result[1], **eval_result[2], **eval_result[3]}
    df_eval = pd.DataFrame([combined_result])

    # Name the metrics file after the full label, not the base task name:
    # otherwise the "<task>_original" and "<task>_convert" runs both write
    # "<task>_eval.csv" and the second silently overwrites the first.
    df_eval.to_csv(f'{output_dir}/{task_class}_eval.csv', index=False)


# Run BM25 retrieval + evaluation for every configured task.
for task_class in retrieval:
    # Instantiate the task before branching: both branches need `finder_task`.
    task_class_obj = getattr(tasks_module, task_class)
    finder_task = task_class_obj()

    if task_class in tabular_retrieval:
        # Tabular tasks are run twice, once per text column of the converted
        # corpus CSV ("text" -> "<task>_original", "convert_text" -> "<task>_convert").
        # NOTE(review): assumes the CSV rows are in the same order as
        # `finder_task.corpus` so BM25 indices map to the right ids -- confirm.
        corpus = pd.read_csv(f"./data/{task_class}_corpus_convert.csv")

        for version, column_name in [("original", "text"), ("convert", "convert_text")]:
            corpus_documents = corpus[column_name].str.lower().tolist()

            process_task(f"{task_class}_{version}", corpus_documents, output_dir, finder_task)
    else:
        # Other tasks index the lower-cased text of the task's own corpus.
        corpus_documents = [doc['text'].lower() for doc in finder_task.corpus.values()]
        process_task(task_class, corpus_documents, output_dir, finder_task)