In [2]:
import pandas as pd
from sqlalchemy import create_engine
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import Chroma
from tqdm import tqdm
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
import numpy as np


In [3]:

# --- 0. CUSTOM EMBEDDING CLASS ---
# LangChain adapter so the vector store can call a locally stored
# SentenceTransformer checkpoint (BGE-M3 in this notebook).
class LocalHuggingFaceEmbeddings(Embeddings):
    """Wrap a SentenceTransformer model behind LangChain's ``Embeddings`` interface.

    Document and query vectors are both produced with
    ``normalize_embeddings=True`` and returned as plain Python lists.
    """

    def __init__(self, model_id):
        # model_id is handed straight to SentenceTransformer (a local directory here).
        self.model = SentenceTransformer(model_id)

    def _encode(self, payload, **extra_kwargs):
        """Encode ``payload`` with normalisation enabled and convert the array to lists."""
        vectors = self.model.encode(payload, normalize_embeddings=True, **extra_kwargs)
        return vectors.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per input text, in input order."""
        # Silence the model's own progress bar so the outer tqdm bar owns the output.
        return self._encode(texts, show_progress_bar=False)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string."""
        return self._encode(text)


In [4]:

# --- 1. DATABASE CONFIGURATION ---
# Connection settings are read from environment variables so credentials are not
# committed with the notebook; the previous hard-coded values remain as
# local-development defaults.
import os
from urllib.parse import quote_plus

DB_USER = os.environ.get("SPICE_DB_USER", "admin")
DB_PASSWORD = os.environ.get("SPICE_DB_PASSWORD", "admin")
DB_HOST = os.environ.get("SPICE_DB_HOST", "localhost")
DB_PORT = os.environ.get("SPICE_DB_PORT", "5432")
DB_NAME = os.environ.get("SPICE_DB_NAME", "Spice_BD")  # MODIFIED: Changed to your database name
# Percent-encode user/password so characters such as '@', ':' or '/' in a
# password cannot corrupt the connection URL.
DB_URI = (
    f"postgresql://{quote_plus(DB_USER)}:{quote_plus(DB_PASSWORD)}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

try:
    engine = create_engine(DB_URI)
    # create_engine() is lazy and does not contact the server, so open (and
    # immediately release) one connection to prove the database is reachable.
    with engine.connect():
        pass
except Exception as e:
    # Raise instead of calling exit(): inside Jupyter, exit() shuts the kernel
    # down rather than just stopping this cell.
    raise RuntimeError(f"Failed to connect to the {DB_NAME} database. Error: {e}") from e
print(f"Successfully connected to the {DB_NAME} database.")

# --- 2. SQL QUERIES FOR SPICE_BD SCHEMA (MODIFIED) ---
# These queries are designed to fetch non-PII data for generating insights.
# NOTE(review): no names/contact columns are selected, but patient_tracker_id +
# age + gender are still quasi-identifiers - confirm this is acceptable before
# sharing the resulting vector store.
# Each key becomes the dataframe name in `dataframes` and drives one document type.
sql_queries = {
    # Patients with known gender and age, with their BP and glucose log rows.
    # NOTE(review): bp_log and glucose_log are both LEFT JOINed on
    # patient_track_id with no further predicate, so a tracker with m BP rows and
    # n glucose rows yields m*n result rows (every BP reading paired with every
    # glucose reading). Confirm whether readings should be paired by visit/date
    # or loaded as separate tables.
    "patient_vitals": """
        SELECT 
            pt.id AS patient_tracker_id,
            p.gender, 
            p.age,
            bp.avg_systolic, 
            bp.avg_diastolic, 
            gl.glucose_value, 
            gl.glucose_type
        FROM patient_tracker pt
        JOIN patient p ON pt.patient_id = p.id
        LEFT JOIN bp_log bp ON pt.id = bp.patient_track_id
        LEFT JOIN glucose_log gl ON pt.id = gl.patient_track_id
        WHERE p.gender IS NOT NULL AND p.age IS NOT NULL;
    """,
    # Every tracker with its diagnosis flags, comorbidity ids and complication ids.
    # NOTE(review): the three LEFT JOINs fan out the same way
    # (diagnosis x comorbidity x complication rows per tracker); trackers with no
    # matches still appear with NULLs.
    "patient_conditions": """
        SELECT 
            pt.id AS patient_tracker_id,
            pd.is_htn_diagnosis, 
            pd.is_diabetes_diagnosis,
            pc.comorbidity_id,
            comp.complication_id
        FROM patient_tracker pt
        LEFT JOIN patient_diagnosis pd ON pt.id = pd.patient_track_id
        LEFT JOIN patient_comorbidity pc ON pt.id = pc.patient_track_id
        LEFT JOIN patient_complication comp ON pt.id = comp.patient_track_id;
    """,
    # One row per user/role assignment (inner joins: users without a role are
    # dropped). "user" is quoted because it is a reserved word in PostgreSQL.
    "user_roles_permissions": """
        SELECT
            u.id AS user_id,
            r.name AS role_name
        FROM "user" u
        JOIN user_role ur ON u.id = ur.user_id
        JOIN role r ON ur.role_id = r.id;
    """,
    # Every prescription row, keyed by patient_track_id.
    "prescriptions": """
        SELECT
            p.patient_track_id,
            p.medication_name,
            p.prescribed_days
        FROM prescription p;
    """,
    # One row per call_register_detail record (inner join: registers without
    # detail rows are dropped).
    "call_logs": """
        SELECT
            cr.id AS call_register_id,
            cr.call_type,
            crd.status AS call_status,
            crd.duration AS call_duration
        FROM call_register cr
        JOIN call_register_detail crd ON cr.id = crd.call_register_id;
    """
}


Successfully connected to the Spice_BD database.


In [5]:

# --- 3. LOAD DATA FROM ALL TABLES ---
# Run every query in `sql_queries`. A failing query is reported and skipped so the
# remaining tables still load; later cells only use tables present in `dataframes`.
dataframes = {}
for table_label, sql_text in sql_queries.items():
    print(f"Loading data for '{table_label}'...")
    try:
        loaded = pd.read_sql(sql_text, engine)
    except Exception as err:
        print(f"--- QUERY FAILED for '{table_label}' ---")
        print(f"The actual database error is: {err}")
        continue
    dataframes[table_label] = loaded
    print(f"Successfully loaded {len(loaded)} records for '{table_label}'.")

# --- 4. DOCUMENT CREATION FOR SPICE_BD (MODIFIED) ---
# Accumulator for every LangChain Document built from the loaded dataframes.
documents = []

# Character-count splitter (length_function=len) targeting 512-character chunks
# with 100 characters of overlap. In this notebook it is applied to the
# prescription texts.
_splitter_kwargs = {"chunk_size": 512, "chunk_overlap": 100, "length_function": len}
text_splitter = RecursiveCharacterTextSplitter(**_splitter_kwargs)

print("\nProcessing loaded data for document creation...")


Loading data for 'patient_vitals'...
Successfully loaded 9123 records for 'patient_vitals'.
Loading data for 'patient_conditions'...
Successfully loaded 8964 records for 'patient_conditions'.
Loading data for 'user_roles_permissions'...
Successfully loaded 1373 records for 'user_roles_permissions'.
Loading data for 'prescriptions'...
Successfully loaded 2533 records for 'prescriptions'.
Loading data for 'call_logs'...
Successfully loaded 1977 records for 'call_logs'.

Processing loaded data for document creation...


In [6]:

# --- Build LangChain Documents from the loaded tables ---
# Every loaded table is turned into short natural-language sentences plus the
# metadata stored next to each embedding. Tables that failed to load are skipped.
#
# SQL NULLs (common here because of the LEFT JOINs) arrive in pandas as None/NaN.
# `row.get(col, 'N/A')` only falls back when the *column* is absent, so NULL
# values used to leak into the embedded text as "nan"/"None". `_fmt` renders
# them as 'N/A' instead.


def _fmt(value, default='N/A'):
    """Return ``default`` when ``value`` is a missing scalar (None/NaN/NaT/NA), else ``value``."""
    # pd.isna is only a scalar test for 0-d inputs; leave list-like cells alone.
    if np.ndim(value) == 0 and pd.isna(value):
        return default
    return value


def _meta(value, default='N/A'):
    """Convert ``value`` to a plain Python str/int/float/bool for vector-store metadata."""
    value = _fmt(value, default)
    # numpy scalars (e.g. np.int64 from an all-integer frame) are not Python ints
    # and may be rejected by metadata validation; unwrap them.
    if isinstance(value, np.generic):
        return value.item()
    return value


def _add_row_documents(table_key, desc, to_text, to_metadata, splitter=None):
    """Append Documents for each row of ``dataframes[table_key]`` to ``documents``.

    Args:
        table_key: key into ``dataframes``; nothing happens if it is missing.
        desc: tqdm progress-bar label.
        to_text: ``row -> str`` building the page content.
        to_metadata: ``row -> dict`` building the metadata.
        splitter: optional text splitter; when given, each row's text is split and
            every chunk becomes its own Document carrying the row's metadata.
    """
    frame = dataframes.get(table_key)
    if frame is None:
        return
    for _, row in tqdm(frame.iterrows(), total=frame.shape[0], desc=desc):
        text = to_text(row)
        metadata = to_metadata(row)
        chunks = splitter.split_text(text) if splitter is not None else [text]
        for chunk in chunks:
            # Give each Document its own dict so chunks never share metadata.
            documents.append(Document(page_content=chunk, metadata=dict(metadata)))


# --- Patient Vitals Documents ---
_add_row_documents(
    'patient_vitals',
    "Creating Vitals Documents",
    lambda row: (
        f"Patient-ID {row['patient_tracker_id']} (Age: {_fmt(row.get('age'))}, Gender: {_fmt(row.get('gender'))}) has a recorded "
        f"systolic pressure of {_fmt(row.get('avg_systolic'))}, "
        f"diastolic pressure of {_fmt(row.get('avg_diastolic'))}, "
        f"and a {_fmt(row.get('glucose_type'))} glucose value of {_fmt(row.get('glucose_value'))}."
    ),
    lambda row: {"source_patient_tracker_id": _meta(row['patient_tracker_id']), "source_table": "patient_vitals"},
)

# --- Patient Conditions Documents ---
_add_row_documents(
    'patient_conditions',
    "Creating Conditions Documents",
    lambda row: (
        f"Patient-ID {row['patient_tracker_id']} has the following conditions: "
        f"Hypertension Diagnosis: {_fmt(row.get('is_htn_diagnosis'))}, "
        f"Diabetes Diagnosis: {_fmt(row.get('is_diabetes_diagnosis'))}, "
        f"Comorbidity ID: {_fmt(row.get('comorbidity_id'))}, "
        f"Complication ID: {_fmt(row.get('complication_id'))}."
    ),
    lambda row: {"source_patient_tracker_id": _meta(row['patient_tracker_id']), "source_table": "patient_conditions"},
)

# --- User Roles Documents ---
_add_row_documents(
    'user_roles_permissions',
    "Creating User Role Documents",
    lambda row: f"User-ID {row['user_id']} has the role of '{_fmt(row.get('role_name'))}'.",
    lambda row: {"source_user_id": _meta(row['user_id']), "source_table": "user_roles"},
)

# --- Prescriptions Documents ---
_add_row_documents(
    'prescriptions',
    "Creating Prescription Documents",
    lambda row: (
        f"Medication '{_fmt(row.get('medication_name'))}' was prescribed for "
        f"{_fmt(row.get('prescribed_days'))} days for Patient-ID {_fmt(row.get('patient_track_id'))}."
    ),
    lambda row: {"source_patient_tracker_id": _meta(row.get('patient_track_id')), "source_table": "prescription"},
    splitter=text_splitter,
)

# --- Call Logs Documents ---
_add_row_documents(
    'call_logs',
    "Creating Call Log Documents",
    lambda row: (
        f"Call Log ID {_fmt(row.get('call_register_id'))}: Type of call was '{_fmt(row.get('call_type'))}' "
        f"with a final status of '{_fmt(row.get('call_status'))}' and duration {_fmt(row.get('call_duration'))}."
    ),
    lambda row: {"source_call_id": _meta(row.get('call_register_id')), "source_table": "call_logs"},
)

print(f"\nCreated a total of {len(documents)} documents for indexing.")


Creating Vitals Documents: 100%|██████████| 9123/9123 [00:00<00:00, 47097.66it/s]
Creating Conditions Documents: 100%|██████████| 8964/8964 [00:00<00:00, 54380.33it/s]
Creating User Role Documents: 100%|██████████| 1373/1373 [00:00<00:00, 65997.15it/s]
Creating Prescription Documents: 100%|██████████| 2533/2533 [00:00<00:00, 39427.06it/s]
Creating Call Log Documents: 100%|██████████| 1977/1977 [00:00<00:00, 46037.77it/s]


Created a total of 23970 documents for indexing.





In [7]:

# --- 5. EMBED AND STORE WITH BGE-M3 ---
# Embed every Document and write it to a persistent Chroma collection.
# Each Document gets a deterministic id (its position in `documents`), and the
# LangChain Chroma wrapper upserts by id, so re-running this cell overwrites the
# same entries instead of appending another full copy of the corpus.
# NOTE(review): ids are positional. Delete `persist_directory` once before the first
# run of this version, because earlier runs stored random ids. Delete it again
# whenever the source data changes shape, so stale entries do not linger.
if documents:
    print("Initializing BGE-M3 embedding model from local path...")

    # Ensure you have cloned the model into this directory
    model_path = "./models/bge-m3"

    embedding_model = LocalHuggingFaceEmbeddings(model_id=model_path)
    print("Model initialized.")
    persist_directory = 'chroma_db_spice_bd'  # MODIFIED: New directory for this DB (must match the query cell)

    print("Creating and persisting the vector store (this may take a long time)...")

    # Open (or create) the persisted collection, then fill it batch by batch;
    # no need to seed the store with documents[0] separately.
    vector_db = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_model,
    )

    batch_size = 128  # Documents per add_documents call (one embed_documents call each)

    for start in tqdm(range(0, len(documents), batch_size),
                      desc="Embedding and Storing Batches",
                      unit="batch"):
        batch = documents[start:start + batch_size]
        batch_ids = [f"doc-{idx}" for idx in range(start, start + len(batch))]
        vector_db.add_documents(documents=batch, ids=batch_ids)

    print("\n--- RAG Indexing Complete! ---")
else:
    print("\nNo data was loaded or no documents were created. RAG indexing was skipped.")

Initializing BGE-M3 embedding model from local path...
Model initialized.
Creating and persisting the vector store (this may take a long time)...


Embedding and Storing Batches: 100%|██████████| 188/188 [02:09<00:00,  1.45batch/s]


--- RAG Indexing Complete! ---





In [None]:
from langchain_community.vectorstores import Chroma
from sentence_transformers import SentenceTransformer
from langchain_core.embeddings import Embeddings
from typing import List

# Reuses LocalHuggingFaceEmbeddings from the class cell at the top of the notebook.
# A redefinition here would shadow it, and the one previously in this cell dropped
# normalize_embeddings=True. Queries must be embedded exactly like the indexed documents.

# --- 1. SETTINGS ---
model_path = "./models/bge-m3"
persist_directory = 'chroma_db_spice_bd'  # Ensure this matches the directory used during indexing
TOP_K = 3  # Number of nearest documents to retrieve and display

# --- 2. LOAD THE EMBEDDING MODEL AND VECTOR STORE ---
print("Loading embedding model and vector store...")
embedding_model = LocalHuggingFaceEmbeddings(model_id=model_path)
vector_db = Chroma(persist_directory=persist_directory, embedding_function=embedding_model)
print("Vector store loaded successfully.")

# --- 3. PERFORM A TEST QUERY ---
# NOTE(review): similarity search ranks by semantic closeness only. It cannot apply a
# numeric/date condition such as "born after 2016", and the indexed vitals text
# stores age rather than birth date. Answer such questions with SQL or a
# metadata filter instead.
query = "any patients born after 2016??"
print(f"\nPerforming similarity search for: '{query}'\n")

retrieved_docs = vector_db.similarity_search(query, k=TOP_K)

# --- 4. DISPLAY THE RESULTS ---
if retrieved_docs:
    print(f"--- Top {len(retrieved_docs)} Retrieved Documents ---")
    for i, doc in enumerate(retrieved_docs, start=1):
        print(f"\n--- Document {i} ---")
        print(f"Content: {doc.page_content}")
        print(f"Metadata: {doc.metadata}")
else:
    print("No relevant documents were found.")
    
# --- 5. VERIFY SOURCE DATA ---
# Cross-check retrieval against the structured source.
# Count trackers flagged with hypertension and/or diabetes in the 'patient_conditions' frame.
# The previous version looked up a 'condition' frame with a 'condition_text' column,
# neither of which this notebook loads, so the check never ran.
print("Checking for 'diabetes' in the source condition data...")

condition_df = dataframes.get('patient_conditions')
flag_columns = {'patient_tracker_id', 'is_htn_diagnosis', 'is_diabetes_diagnosis'}

if condition_df is not None and flag_columns.issubset(condition_df.columns):
    # NOTE(review): assumes the is_*_diagnosis columns hold booleans - confirm.
    # LEFT JOIN NULLs (None/NaN/NA) compare unequal to True, i.e. "not flagged".
    htn_mask = condition_df['is_htn_diagnosis'].eq(True).fillna(False).astype(bool)
    diabetes_mask = condition_df['is_diabetes_diagnosis'].eq(True).fillna(False).astype(bool)

    diabetes_records = condition_df[htn_mask & diabetes_mask]
    hypertension_records = condition_df[htn_mask]

    # The source query fans out per tracker (comorbidity x complication rows), so
    # report distinct trackers alongside raw row counts.
    if not diabetes_records.empty:
        n_trackers = diabetes_records['patient_tracker_id'].nunique()
        print(f"\nFound {len(diabetes_records)} records ({n_trackers} distinct patient trackers) "
              "related to Hypertension and Diabetics in the source data.")
        print("Here are the first 3:")
        print(diabetes_records.head(3))
    else:
        print("No records related to Hypertension and Diabetics were found.")
        print("This is likely why the similarity search did not return relevant results.")

    if not hypertension_records.empty:
        n_trackers = hypertension_records['patient_tracker_id'].nunique()
        print(f"\nFound {len(hypertension_records)} records ({n_trackers} distinct patient trackers) "
              "related to Hypertension in the source data.")
        print("Here are the first 3:")
        print(hypertension_records.head(3))
    else:
        print("No records related to Hypertension were found.")
else:
    print("Could not find the 'patient_conditions' dataframe with the expected diagnosis columns.")

Loading embedding model and vector store...
Vector store loaded successfully.

Performing similarity search for: 'any patients born after 2016??'

--- Top 3 Retrieved Documents ---

--- Document 1 ---
Content: Patient-ID 20049 (Age: 27, Gender: Male) has a recorded systolic pressure of 150.0, diastolic pressure of 85.0, and a fbs glucose value of 8.0.
Metadata: {'source_table': 'patient_vitals', 'source_patient_tracker_id': 20049}

--- Document 2 ---
Content: Patient-ID 21585 (Age: 45, Gender: Male) has a recorded systolic pressure of 150.0, diastolic pressure of 90.0, and a fbs glucose value of 9.0.
Metadata: {'source_table': 'patient_vitals', 'source_patient_tracker_id': 21585}

--- Document 3 ---
Content: Patient-ID 20156 (Age: 45, Gender: Male) has a recorded systolic pressure of 150.0, diastolic pressure of 95.0, and a fbs glucose value of 9.0.
Metadata: {'source_table': 'patient_vitals', 'source_patient_tracker_id': 20156}
Checking for 'diabetes' in the source condition data...
C