In [None]:
import sys
import pandas as pd
import torch 
import os 

sys.path.append("../..")
# For retrieval

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from sentence_transformers import CrossEncoder
from langchain.text_splitter import RecursiveCharacterTextSplitter

from financerag.tasks import FinQABench

In [2]:
import warnings

warnings.filterwarnings('ignore')

## Read Data

In [None]:
# Instantiate the FinQABench task. Later cells read its `queries` and `corpus`
# attributes and call its `evaluate` method.
task = FinQABench()

In [4]:
# Mapping of query id -> query text, turned into a one-column frame indexed by query id.
queries = task.queries
query_df = pd.DataFrame(
    data=list(queries.values()),
    index=list(queries.keys()),
    columns=["query"],
)

In [5]:
# Corpus mapping keyed by document id; each entry supplies the "title" and "text" columns.
documents = task.corpus
documents_df = pd.DataFrame(
    data=list(documents.values()),
    index=list(documents.keys()),
    columns=["title", "text"],
)

In [None]:
import matplotlib.pyplot as plt

# Character count of every document body, stored as a new column
documents_df['text_length'] = documents_df['text'].map(len)

# Distribution of document lengths, drawn through the explicit Axes interface
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(documents_df['text_length'], bins=30, edgecolor='k', alpha=0.7)
ax.set_title('Histogram of Document Lengths')
ax.set_xlabel('Text Length')
ax.set_ylabel('Frequency')
plt.show()

## Initialize Database

In [None]:
# Embedding model handed to Chroma, both when building the index and when loading it.
embedder = HuggingFaceEmbeddings(model_name="msmarco-distilbert-base-v4")

persist_directory = ".chroma"

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,  # Define chunk size
    chunk_overlap=100  # Define overlap to maintain context between chunks
)

# One Document per chunk. The parent document id is kept in the metadata so that
# retrieved chunks can be mapped back to their source document.
# The loop variable is `doc_id` rather than `id` so the builtin `id()` is not
# shadowed for the rest of the kernel session, and the comprehension keeps the
# loop variables out of the notebook namespace.
docs = [
    Document(
        page_content=chunk,
        metadata={"id": str(doc_id), "chunk_index": chunk_index},
    )
    for doc_id, text in documents_df["text"].items()
    for chunk_index, chunk in enumerate(text_splitter.split_text(text))
]

# NOTE: an existing directory is loaded as-is and `docs` is not consulted in that
# branch. Delete the directory after changing the splitter or embedder settings,
# otherwise the previously persisted index is reused.
if os.path.exists(persist_directory):
    # Load the existing ChromaDB
    chroma_db = Chroma(persist_directory=persist_directory, embedding_function=embedder)
    print("Loaded existing ChromaDB from .chroma")
else:
    # Create ChromaDB and store the documents
    chroma_db = Chroma.from_documents(
        documents=docs,
        embedding=embedder,
        persist_directory=persist_directory,
    )
    print("Created new ChromaDB and saved to .chroma")


## Retrieve

In [8]:
# Wrap the vector store as a retriever, passing k=10 as its search setting.
# The store holds chunks, so one document can show up more than once per query.
retriever = chroma_db.as_retriever(search_kwargs={"k": 10})

In [9]:
# One row per query, each starting with its own empty dict of retrieved documents.
retrieved_df = pd.DataFrame(
    {"Documents": [{} for _ in query_df.index]},
    index=query_df.index,
)

In [10]:
# Run every query through the retriever and record which parent documents were hit.
# Results go straight into a plain dict: the previous `retrieved_df.loc[idx]["Documents"] = ...`
# was chained assignment, which only works when `.loc[idx]` returns a view and
# silently writes to a temporary under pandas Copy-on-Write.
retrieved_results = {}
for idx, query in query_df["query"].items():
    hits = retriever.invoke(query)

    # Several chunks of the same document collapse onto one key. The value 1 is a
    # placeholder score that the re-ranking step overwrites.
    retrieved_results[idx] = {
        str(doc.metadata["id"]): 1
        for doc in hits
    }

# Keep `retrieved_df` in sync using a whole-column assignment.
retrieved_df["Documents"] = [retrieved_results[idx] for idx in retrieved_df.index]

## Re-Rank

In [11]:
# Cross-encoder model used by the re-ranking loop to score (query, document text) pairs.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def sigmoid(x):
    """Apply the logistic function 1 / (1 + exp(-x)) element-wise.

    Args:
        x: A number or array-like accepted by ``torch.tensor``.

    Returns:
        A ``torch.Tensor`` with the same shape as ``x``.
    """
    logits = torch.tensor(x)
    denominator = 1 + torch.exp(-logits)
    return 1 / denominator

# Re-score the documents retrieved for each query with the cross-encoder and
# replace the placeholder scores with sigmoid-normalised relevance scores.
for idx, query in query_df["query"].items():
    doc_ids = list(retrieved_results[idx])
    if not doc_ids:
        # Nothing was retrieved for this query; keep the empty mapping.
        continue

    # Score all (query, document) pairs of this query in one batched call instead
    # of one model invocation per document. The full document text is scored,
    # not the individual chunks.
    pairs = [(query, documents_df.loc[doc_id, "text"]) for doc_id in doc_ids]
    raw_scores = cross_encoder.predict(pairs)
    normalized_scores = sigmoid(raw_scores).tolist()

    # Highest score first; dicts preserve insertion order, so this fixes the ranking.
    retrieved_results[idx] = dict(
        sorted(zip(doc_ids, normalized_scores), key=lambda item: item[1], reverse=True)
    )

## Evaluate

In [12]:
# Tab-separated relevance judgments. The `query_id`, `corpus_id` and `score`
# columns are read when the nested qrels dict is built.
qrels = pd.read_csv('../../data/resources/finq_bench_qrels.tsv', sep='\t')

In [13]:
# Nested mapping: query_id -> {corpus_id: score}
qrels_dict = {}
for row_label, row in qrels.iterrows():
    # setdefault creates the inner dict the first time a query id is seen
    qrels_dict.setdefault(row['query_id'], {})[row['corpus_id']] = row['score']

In [14]:
# Rank cut-offs at which the metrics are reported
k_values = [1, 3, 10]

# Score the re-ranked results against the relevance judgments
results = task.evaluate(
    qrels=qrels_dict,
    results=retrieved_results,
    k_values=k_values,
)

In [None]:
# `results` is indexed positionally: 0 -> NDCG, 1 -> MAP, 2 -> Recall, 3 -> Precision
ndcg_scores, map_scores = results[0], results[1]
recall_scores, precision_scores = results[2], results[3]

# One row per cut-off k, one column per metric
metrics_df = pd.DataFrame(
    {
        "MAP": [map_scores[f"MAP@{k}"] for k in k_values],
        "NDCG": [ndcg_scores[f"NDCG@{k}"] for k in k_values],
        "P@K": [precision_scores[f"P@{k}"] for k in k_values],
        "R@K": [recall_scores[f"Recall@{k}"] for k in k_values],
    },
    index=k_values,
)

metrics_df