In [13]:
import os
import sys
import numpy as np 
import pandas as pd

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document

sys.path.append('..')

In [14]:
import warnings

warnings.filterwarnings('ignore')

## Initiliaze Database

In [15]:
embedder = HuggingFaceEmbeddings(model_name="msmarco-distilbert-base-v4")

In [16]:
vector_store = Chroma(
    embedding_function=embedder,  
    persist_directory="./chroma"  
)

## Read Data

In [None]:
data_folder_path = os.path.join('..', 'data')
query_df = pd.read_csv(os.path.join(data_folder_path, "FinDER/queries.csv"), index_col=0)
documents_df = pd.read_csv(os.path.join(data_folder_path, "FinDER/corpus.csv"), index_col=0)
documents_df.dropna(subset=['text'], inplace=True)
documents_df

## Embed Documents and Save to DB

In [None]:
langchain_documents = [
    Document(page_content=row['text'], metadata={"id": index})
    for index, row in documents_df.iterrows()  
]

vector_store.add_documents(langchain_documents)

## Retrieve

In [None]:
retrieved_df = pd.DataFrame([[[] ] for _ in query_df.index], index=query_df.index, columns=["Documents"])
retrieved_df

In [35]:
for query_index, query_row in query_df.iterrows():
    query = query_row['text']
    retrieved_docs = vector_store.similarity_search(query, k=3)
    retrieved_doc_ids = [doc.metadata['id'] for doc in retrieved_docs]
    retrieved_df.loc[query_index, "Documents"] = retrieved_doc_ids

In [36]:
retrieved_df

Unnamed: 0,Documents
q00001,"[MSFT20230254, MSFT20230966, MSFT20230236]"
q00002,"[UNH20230813, NVDA20230304, JPM20231560]"
q00003,"[MSFT20230025, MSFT20230236, MSFT20230010]"
q00004,"[MSFT20230010, MSFT20230469, MSFT20230092]"
q00005,"[ADBE20231285, JPM20231233, ORCL20231547]"
...,...
q00214,"[BRK.A20230028, BRK.A20230009, BRK.A20230427]"
q00215,"[BRK.A20230404, BRK.A20230920, BRK.A20230035]"
q00216,"[BRK.A20232401, BRK.A20230726, BRK.A20231490]"
q00217,"[BRK.A20230062, BRK.A20230063, BRK.A20230594]"


## Evaluate

In [37]:
def evaluate_retrieval(actual_related_ids, retrieved_docs_ids, top_k=None):
    precisions = []
    recalls = []
    
    for actual_ids, retrieved_ids in zip(actual_related_ids, retrieved_docs_ids):
        if top_k:
            retrieved_ids = retrieved_ids[:top_k]
        
        actual_set = set(actual_ids)
        retrieved_set = set(retrieved_ids)
        
        true_positives = len(actual_set & retrieved_set)
        precision = true_positives / len(retrieved_set) if retrieved_set else 0
        recall = true_positives / len(actual_set) if actual_set else 0
        
        precisions.append(precision)
        recalls.append(recall)
    
    avg_precision = sum(precisions) / len(precisions) if precisions else 0
    avg_recall = sum(recalls) / len(recalls) if recalls else 0
    
    avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0
    
    return avg_precision, avg_recall, avg_f1

In [38]:
evaluate_retrieval(query_df["Related Documents"], retrieved_df["Documents"])

(0.0, 0.0, 0)

In [44]:
non_empty = query_df[query_df["Related Documents"].apply(lambda x: len(x) > 2)]
index = non_empty.index
non_empty.head(30)

Unnamed: 0,text,Related Documents
q00001,What are the service and product offerings fro...,"['MSFT20230014', 'MSFT20230015']"
q00007,How much revenue does Microsoft generate from ...,['MSFT20231529']
q00008,MSFT remaining performance obligation,['MSFT20231529']
q00010,ADBE share repurchase,"['ADBE20231571', 'ADBE20231572', 'ADBE20230728..."
q00019,When did Coupang`s Farfetch consolidation start,['CPNG20230732']
q00021,When did new FLC contract begin CPNG,['CPNG20230658']
q00022,CPNG free cash flow,['CPNG20230553']
q00027,asset divestitures Linde,['LIN20231133']
q00028,What is the total number of leases held by Linde,['LIN20231195']
q00030,the top 3 risks faced by Linde,"['LIN20230064', 'LIN20230065', 'LIN20230066', ..."


In [46]:
retrieved_df["Documents"].loc[index].head(30)

q00001    [MSFT20230254, MSFT20230966, MSFT20230236]
q00007    [MSFT20230482, MSFT20231767, MSFT20230574]
q00008       [ADBE20231120, MSFT20230508, V20231038]
q00010       [V20230586, AAPL20230871, MSFT20230455]
q00019    [CPNG20230539, CPNG20231413, CPNG20230812]
q00021      [ADBE20231613, JPM20231233, JNJ20231897]
q00022     [NFLX20230395, DAL20230783, AMZN20230469]
q00027       [LIN20230102, LIN20231143, LIN20231985]
q00028       [LIN20230179, LIN20230006, LIN20231195]
q00030       [LIN20230759, LIN20231636, LIN20230092]
q00034       [LIN20230769, LIN20230078, LIN20230122]
q00039      [DAL20231254, UNH20230470, TSLA20230555]
q00042    [ORCL20230439, ORCL20230014, ORCL20230147]
q00043       [DAL20230628, JPM20237313, JPM20237314]
q00044    [NVDA20230045, NVDA20230050, MSFT20230489]
q00048       [JPM20233794, UNH20230536, JPM20233800]
q00062       [BRK.A20230821, PG20230240, PG20230230]
q00067      [PG20230760, NFLX20230944, AAPL20230557]
q00070          [PG20230555, PG20231497, PG202