In [None]:
import os
import pandas as pd
from pinecone import Pinecone, ServerlessSpec
import sys
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np 

sys.path.append('..')

from main import VectorDatabase, BiEncoder, SimpleSentenceChunker

## Initiliaze Database

In [None]:
API_KEY = "c4ac140e-932e-40c3-84e5-e407580eef2a"
pc = Pinecone(api_key=API_KEY)
indexes = pc.list_indexes()
print(indexes)

In [3]:
INDEX_NAME = 'test'
DIMENSION = 312  #
CLOUD = 'aws'
REGION = 'us-west-1'

In [4]:
vector_db = VectorDatabase(api_key=API_KEY)
handler = vector_db.start_db(index_name=INDEX_NAME, dimension=DIMENSION, cloud=CLOUD, region=REGION)

## Read Data

In [None]:
data_folder_path = os.path.join('..', 'data')
query_df = pd.read_csv(os.path.join(data_folder_path, "FinDER/queries.csv"), index_col=0)
document_df = pd.read_csv(os.path.join(data_folder_path, "FinDER/corpus.csv"), index_col=0)
document_df

## Embed Documents and Save to DB

In [6]:
tokenizer = AutoTokenizer.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')
model = AutoModel.from_pretrained('huawei-noah/TinyBERT_General_4L_312D')

In [7]:
encoder = BiEncoder(tokenizer, model)

In [8]:
document_df["text"] = document_df["text"].fillna("") 
texts = document_df["text"].astype(str).tolist() 

In [9]:
try:
    handler.delete_all()
except:
    print("No index to delete")

In [None]:
batch_size = 10  

document_df = document_df[document_df.index.str.startswith('MSFT') | document_df.index.str.startswith('AAPL')]
document_df["text"] = document_df["text"].fillna("")
texts = document_df["text"].astype(str).tolist()

# Should be len(texts) in the final versipn
limit = len(texts)

def batch_upsert(texts, batch_size):
    # Iterate through batches of texts
    for i in range(0, limit, batch_size):

        try:
            batch_texts = texts[i:i+batch_size]
            batch_indexes = document_df.index[i:i+batch_size]

            # Encode the current batch
            encoded_documents = encoder.encode_batch(batch_texts)

            # Prepare the batch data for upsert
            batch_data = [(str(idx), embedding.tolist()) for idx, embedding in zip(batch_indexes, encoded_documents)]

            # Perform batch upsert
            handler.index.upsert(vectors=batch_data)
            print(f"Upserted batch {i//batch_size + 1}")

        except Exception as e:
            print(f"Error: {e}")

batch_upsert(texts, batch_size=batch_size)

## Retrieve

In [27]:
retrieved_df = pd.DataFrame([[[] ] for _ in query_df.index], index=query_df.index, columns=["Documents"])
retrieved_df

Unnamed: 0,Documents
q00001,[]
q00002,[]
q00003,[]
q00004,[]
q00005,[]
...,...
q00214,[]
q00215,[]
q00216,[]
q00217,[]


In [28]:
for idx, row in query_df.iterrows():
    query = row["text"]
    query = encoder.encode(query)
    query = np.array(query, dtype=np.float32)

    query_list = query.tolist()

    results = handler.query_vector(query_list, top_k=1)
    retrieved_df.at[idx, "Documents"] = [ result["id"] for result in results]

In [29]:
retrieved_df

Unnamed: 0,Documents
q00001,[AAPL20230011]
q00002,[MSFT20230741]
q00003,[AAPL20230021]
q00004,[AAPL20230011]
q00005,[AAPL20230011]
...,...
q00214,[AAPL20230045]
q00215,[AAPL20230908]
q00216,[MSFT20230429]
q00217,[AAPL20230021]


## Evaluate

In [30]:
def evaluate_retrieval(actual_related_ids, retrieved_docs_ids, top_k=None):
    precisions = []
    recalls = []
    
    for actual_ids, retrieved_ids in zip(actual_related_ids, retrieved_docs_ids):
        if top_k:
            retrieved_ids = retrieved_ids[:top_k]
        
        actual_set = set(actual_ids)
        retrieved_set = set(retrieved_ids)
        
        true_positives = len(actual_set & retrieved_set)
        precision = true_positives / len(retrieved_set) if retrieved_set else 0
        recall = true_positives / len(actual_set) if actual_set else 0
        
        precisions.append(precision)
        recalls.append(recall)
    
    avg_precision = sum(precisions) / len(precisions) if precisions else 0
    avg_recall = sum(recalls) / len(recalls) if recalls else 0
    
    avg_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0
    
    return avg_precision, avg_recall, avg_f1

In [31]:
evaluate_retrieval(query_df["Related Documents"], retrieved_df["Documents"])

(0.0, 0.0, 0)

In [40]:
not_empty = query_df[query_df["Related Documents"].apply(lambda x: len(x) > 2)]["Related Documents"]
index = not_empty.index
not_empty.values

array(["['MSFT20230014', 'MSFT20230015']", "['MSFT20231529']",
       "['MSFT20231529']",
       "['ADBE20231571', 'ADBE20231572', 'ADBE20230728', 'ADBE20231573']",
       "['CPNG20230732']", "['CPNG20230658']", "['CPNG20230553']",
       "['LIN20231133']", "['LIN20231195']",
       "['LIN20230064', 'LIN20230065', 'LIN20230066', 'LIN20230067']",
       "['LIN20230551']", "['ORCL20230738', 'ORCL20230739']",
       "['ORCL20230129', 'ORCL20230130', 'ORCL20230131']",
       "['ORCL20231527', 'ORCL20231529']", "['ORCL20231505']",
       "['NVDA20231260']", "['PG20230221', 'PG20230805']",
       "['PG20230429']", "['PG20230438']", "['PG20230440']",
       "['DAL20230459']", "['DAL20230513']", "['TSLA20230391']",
       "['TSLA20231453', 'TSLA20231454']",
       "['NFLX20230692', 'NFLX20230006']", "['NFLX20230387']",
       "['NFLX20230380']", "['HD20230012']",
       "['AAPL20230236', 'AAPL20230557']", "['AAPL20230251']",
       "['AAPL20230871', 'AAPL20230222']", "['AAPL20230966']",
      

In [42]:
retrieved_df.loc[index]["Documents"]

q00001    [AAPL20230011]
q00007    [AAPL20230694]
q00008    [AAPL20230694]
q00010    [AAPL20230011]
q00019    [AAPL20230694]
               ...      
q00197    [AAPL20230071]
q00200    [AAPL20230011]
q00204    [AAPL20230011]
q00210    [AAPL20230011]
q00212    [AAPL20230196]
Name: Documents, Length: 64, dtype: object

In [47]:
document_df.loc['AAPL20230694']["text"]

'As of September 24, 2022, the Company had one customer that represented 10% or more of total trade receivables, which accounted for 10%. The Company’s third-party cellular network carriers accounted for 41% and 44% of total trade receivables as of September 30, 2023 and September 24, 2022, respectively. The Company requires third-party credit support or collateral from certain customers to limit credit risk.'

In [49]:
query_df.iloc[7]["text"]

'MSFT remaining performance obligation'