In [1]:
import os
import openai
import langchain
from langchain.prompts import ChatPromptTemplate
from langchain.vectorstores import FAISS
from langchain.document_loaders import CSVLoader
from langchain.embeddings.base import Embeddings
from langchain.chains import RetrievalQA
from langchain.indexes import VectorstoreIndexCreator
import pandas as pd
import tiktoken
import json
import warnings
warnings.filterwarnings("ignore")

In [2]:
index_name = "ICD_index"
file_name = "./data/ICD.csv"

In [3]:
openai_api_key = os.environ['OPENAI_API_KEY']
client = openai.OpenAI(
    api_key=openai_api_key,
    base_url="https://cmu.litellm.ai",
)

In [4]:
class CustomOpenAIEmbeddings(Embeddings):
    def __init__(self, client):
        self.client = client

    def embed_documents(self, texts):
        embeddings = []
        counter = 0
        for text in texts:
            if(counter%500==0):
                print(counter)
            counter +=1
            response = self.client.embeddings.create(input=text, model="text-embedding-3-small")
            embedding = response.data[0].embedding
            embeddings.append(embedding)
        return embeddings

    def embed_query(self, text):
        response = self.client.embeddings.create(input=text, model="text-embedding-3-small")
        return response.data[0].embedding

embedding_model = CustomOpenAIEmbeddings(client)

In [5]:
loader = CSVLoader(file_path=file_name)
documents = loader.load()
print(len(documents))

73201


In [6]:
if not os.path.exists(index_name):
    document_texts = [doc.page_content for doc in documents]
    document_embeddings = embedding_model.embed_documents(document_texts)

    text_embedding_pairs = zip(document_texts, document_embeddings)
    vector_store = FAISS.from_embeddings(text_embedding_pairs, embedding_model.embed_query)

    vector_store.save_local(index_name)

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
10000
10500
11000
11500
12000
12500
13000
13500
14000
14500
15000
15500
16000
16500
17000
17500
18000
18500
19000
19500
20000
20500
21000
21500
22000
22500
23000
23500
24000
24500
25000
25500
26000
26500
27000
27500
28000
28500
29000
29500
30000
30500
31000
31500
32000
32500
33000
33500
34000
34500
35000
35500
36000
36500
37000
37500
38000
38500
39000
39500
40000
40500
41000
41500
42000
42500
43000
43500
44000
44500
45000
45500
46000
46500
47000
47500
48000
48500
49000
49500
50000
50500
51000
51500
52000
52500
53000
53500
54000
54500
55000
55500
56000
56500
57000
57500
58000
58500
59000
59500
60000
60500
61000
61500
62000
62500
63000
63500
64000
64500
65000
65500
66000
66500
67000
67500
68000
68500
69000
69500
70000
70500
71000
71500
72000
72500
73000


`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.
