In [48]:
import ast
import os

import dotenv
from langchain_community.utilities import SQLDatabase
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chat_models import init_chat_model
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.docstore import InMemoryDocstore
from langchain.chains import RetrievalQA
from langchain.prompts import (
    PromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
    ChatPromptTemplate,
)
import faiss
import tqdm
import numpy as np

In [49]:
# Load variables from a local .env file (if present) into the process environment.
dotenv.load_dotenv()

# The connection string embeds the database password, so it is read from the
# environment instead of being stored in the notebook. Example .env entry:
#   MYSQL_URL=mysql+mysqlconnector://<user>:<password>@localhost:3306/hospital
mysql_url = os.getenv("MYSQL_URL")
if not mysql_url:
    raise RuntimeError(
        "MYSQL_URL is not set. Define it in the environment or in a .env file, "
        "e.g. mysql+mysqlconnector://<user>:<password>@localhost:3306/hospital"
    )
db = SQLDatabase.from_uri(mysql_url)

# Chat model used by the QA chain further down.
llm = init_chat_model("llama3-70b-8192", model_provider="groq")

In [50]:
# Location on disk where the FAISS vector store is saved and reloaded from
VECTOR_STORE_PATH = "./vector_store.faiss"

# Embedding model: MiniLM sentence-transformer on CPU, with normalized output
# vectors and an encoding batch size of 128
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs=dict(device="cpu"),
    encode_kwargs=dict(normalize_embeddings=True, batch_size=128),
)

In [60]:
# Tool that executes SQL text against `db`
execute_query_tool = QuerySQLDatabaseTool(db=db)

# The ten most recent non-empty review comments
query = """
    SELECT comment
    FROM reviews
    WHERE comment IS NOT NULL AND comment != ''
    ORDER BY created_at DESC
    LIMIT 10
    """

# Raw tool output; it is handed to process_results() for parsing
results = execute_query_tool.invoke(query)

In [52]:
def process_results(results_string: str) -> list:
    """Turn the SQL tool's textual result into a list of comment strings.

    The input is expected to be the Python repr of a list of row tuples,
    e.g. ``"[('Great doctor',), ('Long wait',)]"``. It is parsed with
    ``ast.literal_eval`` so that quoting, escape sequences, and comments
    containing characters such as ``), (`` are handled correctly.

    Args:
        results_string: Text returned by the query tool. ``None`` or an
            empty/whitespace-only string is treated as "no rows".

    Returns:
        One string per row, in the original order. For a single-column row
        this is the column value; for a multi-column row the non-NULL
        columns are joined with ", ". Rows that end up empty are skipped.

    Raises:
        ValueError: If the text is not a literal list/tuple of rows (for
            example when it is an error message rather than a result set).
    """
    text = results_string.strip() if results_string else ""
    if not text:
        return []

    try:
        rows = ast.literal_eval(text)
    except (ValueError, SyntaxError) as exc:
        # Surface the problem instead of silently treating it as a comment.
        raise ValueError(
            f"Could not parse SQL results as a list of rows: {text[:200]!r}"
        ) from exc

    if not isinstance(rows, (list, tuple)):
        raise ValueError(
            f"Expected a list of rows, got {type(rows).__name__}: {text[:200]!r}"
        )

    comments = []
    for row in rows:
        # A row is normally a tuple of column values; tolerate a bare scalar too.
        columns = row if isinstance(row, (list, tuple)) else (row,)
        comment = ", ".join(str(value) for value in columns if value is not None)
        if comment:
            comments.append(comment)

    return comments

In [61]:
# Parse the string returned by the SQL query tool into a list of comment strings
comments = process_results(results)

In [54]:
# Build the FAISS vector store once and reuse the saved copy on later runs.
# The saved copy is never refreshed automatically: delete VECTOR_STORE_PATH
# to rebuild it after the comments change.
if not os.path.exists(VECTOR_STORE_PATH):
    if not comments:
        raise ValueError("No review comments to index; cannot build the vector store.")

    # from_texts embeds the comments, creates the index, stores each text as a
    # Document, and keeps the docstore keys consistent with index_to_docstore_id
    # (string ids "0", "1", ... are passed explicitly).
    vector_store = FAISS.from_texts(
        comments,
        embeddings,
        ids=[str(i) for i in range(len(comments))],
    )

    # Save the vector store
    vector_store.save_local(VECTOR_STORE_PATH)
else:
    # SECURITY: load_local unpickles the stored docstore, which can execute
    # arbitrary code. Only load a store that this notebook wrote itself.
    vector_store = FAISS.load_local(
        VECTOR_STORE_PATH,
        embeddings,
        allow_dangerous_deserialization=True,
    )

In [55]:
# System message: answering instructions followed by the {context} placeholder
review_template = """Answer the question about doctor reviews based on the following information. Be brief and accurate.
If you don't know an answer, say you don't know.
{context}
"""
review_system_prompt = SystemMessagePromptTemplate.from_template(review_template)

# Human message: just the {question} placeholder
review_human_prompt = HumanMessagePromptTemplate.from_template("{question}")

# Chat prompt combining both messages; it takes "context" and "question"
review_prompt = ChatPromptTemplate.from_messages(
    [review_system_prompt, review_human_prompt]
)

In [56]:
from langchain.schema import Document
from langchain_community.retrievers import TFIDFRetriever 

# Wrap each comment in a Document and build a TF-IDF retriever over them
documents = [Document(page_content=text) for text in comments]
retriever = TFIDFRetriever.from_documents(documents)

# Retrieval QA chain wiring together the chat model, the TF-IDF retriever,
# and the review prompt
reviews_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=review_prompt),
)

In [59]:
# Run the review QA chain on one question and print only the answer text
question = "How many are there doctors in hospital ?"
vector_response = reviews_chain.invoke({"query": question})
print(vector_response["result"])

I don't know. The text only mentions one doctor, Dr. John Doe (also referred to as Dr. A). It does not provide information about the total number of doctors in the hospital.
