In [1]:
import os
import shutil
import pandas as pd
import logging
import json
import getpass
from dotenv import load_dotenv
from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.schema import Document
from langchain.prompts.chat import ChatPromptTemplate
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()

True

In [3]:
if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Provide your Google API key here")

In [4]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [5]:
DATA_PATH = os.getenv("DATA_PATH", "hospital_data/data_tables")
CHROMA_PATH = os.getenv("CHROMA_PATH", "chroma_db")
DATA_DICTIONARY_PATH = os.getenv("DATA_DICTIONARY_PATH", "hospital_data/dictionary/data_dictionary.csv")

In [7]:
# Load CSV files into DataFrames with validation and logging
def load_csv_documents():
    csv_files = [f for f in os.listdir(DATA_PATH) if f.endswith('.csv')]
    dataframes = []
    for file in csv_files:
        file_path = os.path.join(DATA_PATH, file)
        try:
            df = pd.read_csv(file_path, dtype={'ZIP': str})  # Ensure ZIP codes are loaded as strings
            # Handle missing values by filling appropriately based on dtype
            for column in df.columns:
                if df[column].dtype == 'float64' or df[column].dtype == 'int64':
                    df[column] = df[column].fillna(0)  # Fill numeric columns with 0 or another default value
                else:
                    df[column] = df[column].fillna("N/A")  # Fill non-numeric columns with "N/A"
            dataframes.append(df)
            logging.info(f"Loaded {file} successfully with {len(df)} rows.")
        except Exception as e:
            logging.error(f"Error loading {file}: {e}")
    return dataframes

In [8]:
dataframes = load_csv_documents()

2024-11-07 02:31:34,984 - INFO - Loaded payers.csv successfully with 10 rows.
2024-11-07 02:31:35,178 - INFO - Loaded procedures.csv successfully with 47701 rows.
2024-11-07 02:31:35,184 - INFO - Loaded organizations.csv successfully with 1 rows.
2024-11-07 02:31:35,324 - INFO - Loaded encounters.csv successfully with 27891 rows.
2024-11-07 02:31:35,341 - INFO - Loaded patients.csv successfully with 974 rows.


In [10]:
dataframes[0]

Unnamed: 0,Id,NAME,ADDRESS,CITY,STATE_HEADQUARTERED,ZIP,PHONE
0,b3221cfc-24fb-339e-823d-bc4136cbc4ed,Dual Eligible,7500 Security Blvd,Baltimore,MD,21244.0,1-877-267-2323
1,7caa7254-5050-3b5e-9eae-bd5ea30e809c,Medicare,7500 Security Blvd,Baltimore,MD,21244.0,1-800-633-4227
2,7c4411ce-02f1-39b5-b9ec-dfbea9ad3c1a,Medicaid,7500 Security Blvd,Baltimore,MD,21244.0,1-877-267-2323
3,d47b3510-2895-3b70-9897-342d681c769d,Humana,500 West Main St,Louisville,KY,40018.0,1-844-330-7799
4,6e2f1a2d-27bd-3701-8d08-dae202c58632,Blue Cross Blue Shield,Michigan Plaza,Chicago,IL,60007.0,1-800-262-2583
5,5059a55e-5d6e-34d1-b6cb-d83d16e57bcf,UnitedHealthcare,9800 Healthcare Lane,Minnetonka,MN,55436.0,1-888-545-5205
6,4d71f845-a6a9-3c39-b242-14d25ef86a8d,Aetna,151 Farmington Ave,Hartford,CT,6156.0,1-800-872-3862
7,047f6ec3-6215-35eb-9608-f9dda363a44c,Cigna Health,900 Cottage Grove Rd,Bloomfield,CT,6002.0,1-800-997-1654
8,42c4fca7-f8a9-3cd1-982a-dd9751bf3e2a,Anthem,220 Virginia Ave,Indianapolis,IN,46204.0,1-800-331-1476
9,b1c428d6-4f07-31e0-90f0-68ffa6ff8c76,NO_INSURANCE,,,,,


In [11]:
# Load Data Dictionary
def load_data_dictionary():
    if os.path.exists(DATA_DICTIONARY_PATH):
        try:
            data_dict = pd.read_csv(DATA_DICTIONARY_PATH)
            logging.info("Data dictionary loaded successfully.")
            return data_dict
        except Exception as e:
            logging.error(f"Error loading hospital_data dictionary: {e}")
    return None

In [12]:
data_dictionary_df = load_data_dictionary()

2024-11-07 02:32:09,115 - INFO - Data dictionary loaded successfully.


In [13]:
data_dictionary_df

Unnamed: 0,Table,Field,Description
0,encounters,,Patient encounter data
1,encounters,Id,Primary Key. Unique Identifier of the encounter.
2,encounters,Start,The date and time (iso8601 UTC Date (yyyy-MM-d...
3,encounters,Stop,The date and time (iso8601 UTC Date (yyyy-MM-d...
4,encounters,Patient,Foreign key to the Patient.
...,...,...,...
60,procedures,Code,Procedure code from SNOMED-CT
61,procedures,Description,Description of the procedure.
62,procedures,Base_Cost,The line item cost of the procedure.
63,procedures,ReasonCode,Diagnosis code from SNOMED-CT specifying why t...


In [14]:
# Convert DataFrames into Documents with enriched metadata
def convert_to_documents(dataframes: list[pd.DataFrame], data_dictionary: pd.DataFrame = None):
    documents = []
    for df in dataframes:
        for _, row in df.iterrows():
            text = "\n".join([f"{col}: {row[col]}" for col in df.columns])  # Use newline for better readability
            metadata = row.to_dict()
            metadata.update({
                "document_type": "organization_info" if "NAME" in df.columns else "unknown",
                "load_timestamp": datetime.now().isoformat(),
                "data_source": "hospital_data"
            })
            # Add field descriptions if hospital_data dictionary is available
            if data_dictionary is not None:
                field_descriptions = {}
                for col in df.columns:
                    desc_row = data_dictionary[data_dictionary['Field'] == col]
                    if not desc_row.empty:
                        field_descriptions[col] = desc_row['Description'].values[0]
                # Serialize field_descriptions to a JSON string
                metadata["field_descriptions"] = json.dumps(field_descriptions)

            document = Document(page_content=text, metadata=metadata)
            documents.append(document)

    if data_dictionary is not None:
        for _, row in data_dictionary.iterrows():
            text = "\n".join([f"{col}: {row[col]}" for col in data_dictionary.columns])
            document = Document(page_content=text, metadata={
                "source": "data_dictionary",
                "document_type": "data_dictionary",
                "load_timestamp": datetime.now().isoformat(),
                "data_source": "data_dictionary"
            })
            documents.append(document)

    return documents

In [15]:
data = convert_to_documents(dataframes, data_dictionary_df)

In [16]:
data

[Document(metadata={'Id': 'b3221cfc-24fb-339e-823d-bc4136cbc4ed', 'NAME': 'Dual Eligible', 'ADDRESS': '7500 Security Blvd', 'CITY': 'Baltimore', 'STATE_HEADQUARTERED': 'MD', 'ZIP': '21244', 'PHONE': '1-877-267-2323', 'document_type': 'organization_info', 'load_timestamp': '2024-11-07T02:33:52.011337', 'data_source': 'hospital_data', 'field_descriptions': '{"Id": "Primary Key. Unique Identifier of the encounter."}'}, page_content='Id: b3221cfc-24fb-339e-823d-bc4136cbc4ed\nNAME: Dual Eligible\nADDRESS: 7500 Security Blvd\nCITY: Baltimore\nSTATE_HEADQUARTERED: MD\nZIP: 21244\nPHONE: 1-877-267-2323'),
 Document(metadata={'Id': '7caa7254-5050-3b5e-9eae-bd5ea30e809c', 'NAME': 'Medicare', 'ADDRESS': '7500 Security Blvd', 'CITY': 'Baltimore', 'STATE_HEADQUARTERED': 'MD', 'ZIP': '21244', 'PHONE': '1-800-633-4227', 'document_type': 'organization_info', 'load_timestamp': '2024-11-07T02:33:52.024453', 'data_source': 'hospital_data', 'field_descriptions': '{"Id": "Primary Key. Unique Identifier of 

In [23]:
# Split documents into chunks with dynamic chunk sizing
def split_text(documents: list):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50,
        length_function=len,
        add_start_index=True,
    )
    chunks = text_splitter.split_documents(documents)
    logging.info(f"Split {len(documents)} documents into {len(chunks)} chunks.")

    if len(chunks) > 10:
        document = chunks[10]
        logging.info(f"Sample chunk content: {document.page_content}")
        logging.info(f"Sample chunk metadata: {document.metadata}")

    return chunks

In [24]:
chunks = split_text(data)

2024-11-07 02:48:37,839 - INFO - Split 76642 documents into 76642 chunks.
2024-11-07 02:48:37,840 - INFO - Sample chunk content: START: 2011-01-02T09:26:36Z
STOP: 2011-01-02T12:58:36Z
PATIENT: 3de74169-7f67-9304-91d4-757e0f3a14d2
ENCOUNTER: 32c84703-2481-49cd-d571-3899d5820253
CODE: 265764009
DESCRIPTION: Renal dialysis (procedure)
BASE_COST: 903
REASONCODE: 0.0
REASONDESCRIPTION: N/A
2024-11-07 02:48:37,840 - INFO - Sample chunk metadata: {'START': '2011-01-02T09:26:36Z', 'STOP': '2011-01-02T12:58:36Z', 'PATIENT': '3de74169-7f67-9304-91d4-757e0f3a14d2', 'ENCOUNTER': '32c84703-2481-49cd-d571-3899d5820253', 'CODE': 265764009, 'DESCRIPTION': 'Renal dialysis (procedure)', 'BASE_COST': 903, 'REASONCODE': 0.0, 'REASONDESCRIPTION': 'N/A', 'document_type': 'unknown', 'load_timestamp': '2024-11-07T02:33:52.098330', 'data_source': 'hospital_data', 'field_descriptions': '{}', 'start_index': 0}


In [25]:
chunks

[Document(metadata={'Id': 'b3221cfc-24fb-339e-823d-bc4136cbc4ed', 'NAME': 'Dual Eligible', 'ADDRESS': '7500 Security Blvd', 'CITY': 'Baltimore', 'STATE_HEADQUARTERED': 'MD', 'ZIP': '21244', 'PHONE': '1-877-267-2323', 'document_type': 'organization_info', 'load_timestamp': '2024-11-07T02:33:52.011337', 'data_source': 'hospital_data', 'field_descriptions': '{"Id": "Primary Key. Unique Identifier of the encounter."}', 'start_index': 0}, page_content='Id: b3221cfc-24fb-339e-823d-bc4136cbc4ed\nNAME: Dual Eligible\nADDRESS: 7500 Security Blvd\nCITY: Baltimore\nSTATE_HEADQUARTERED: MD\nZIP: 21244\nPHONE: 1-877-267-2323'),
 Document(metadata={'Id': '7caa7254-5050-3b5e-9eae-bd5ea30e809c', 'NAME': 'Medicare', 'ADDRESS': '7500 Security Blvd', 'CITY': 'Baltimore', 'STATE_HEADQUARTERED': 'MD', 'ZIP': '21244', 'PHONE': '1-800-633-4227', 'document_type': 'organization_info', 'load_timestamp': '2024-11-07T02:33:52.024453', 'data_source': 'hospital_data', 'field_descriptions': '{"Id": "Primary Key. Uni

In [26]:
# Save chunks to Chroma vector store with batch saving and logging
def save_to_chroma(chunks: list):
    # Clear out the database first.
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)
        logging.info(f"Cleared existing database at {CHROMA_PATH}.")

    # Use Google Generative AI Embeddings with API Key
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

    # Create a new DB from the documents with batch processing
    batch_size = 1000  # Save in batches to handle large datasets
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        db = Chroma.from_documents(
            batch, embeddings, persist_directory=CHROMA_PATH
        )
        logging.info(f"Saved batch {i // batch_size + 1} to {CHROMA_PATH}.")

    logging.info(f"Saved all chunks to {CHROMA_PATH}.")

In [27]:
save_to_chroma(chunks)

2024-11-07 02:52:48,672 - INFO - Cleared existing database at chroma_db.
2024-11-07 02:52:49,957 - INFO - Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.
2024-11-07 02:52:59,978 - INFO - Saved batch 1 to chroma_db.
2024-11-07 02:53:11,559 - INFO - Saved batch 2 to chroma_db.
2024-11-07 02:53:21,003 - INFO - Saved batch 3 to chroma_db.
2024-11-07 02:53:31,027 - INFO - Saved batch 4 to chroma_db.
2024-11-07 02:53:40,577 - INFO - Saved batch 5 to chroma_db.
2024-11-07 02:53:51,009 - INFO - Saved batch 6 to chroma_db.
2024-11-07 02:54:00,761 - INFO - Saved batch 7 to chroma_db.
2024-11-07 02:54:09,943 - INFO - Saved batch 8 to chroma_db.
2024-11-07 02:54:19,061 - INFO - Saved batch 9 to chroma_db.
2024-11-07 02:54:28,304 - INFO - Saved batch 10 to chroma_db.
2024-11-07 02:54:37,324 - INFO - Saved batch 11 to chroma_db.
2024-11-07 02:54:45,787 - INFO - Saved batch 12 to chroma_db.
2024-11-07 02:54:54,970 - INFO - Saved batch 1

In [32]:
def chat(query_text):
    # Prepare the DB
    embedding_function = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # Search the DB with similarity filtering
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    print(f"Retrieved results: {results}")

    # Adjusted Similarity Threshold
    if len(results) == 0:
        print("Unable to find matching results.")
        return

    # Check if any results exceed a set threshold or if the best available should be used
    threshold = 0.5  # Lower the threshold for partial matches
    relevant_results = [result for result in results if result[1] >= threshold]

    if len(relevant_results) == 0:
        print("No results exceed the similarity threshold. Returning the best available match.")
        relevant_results = [results[0]]  # Return the best available match if no results exceed threshold

    # Prepare context from the results
    context_texts = []
    document_types = []
    data_sources = []
    field_descriptions_list = []
    patient_ids = []

    for doc, _score in relevant_results:
        context_texts.append(doc.page_content)
        document_types.append(doc.metadata.get("document_type", "unknown"))
        data_sources.append(doc.metadata.get("data_source", "unknown"))
        if "field_descriptions" in doc.metadata:
            field_descriptions_list.append(json.loads(doc.metadata["field_descriptions"]))
        if "PATIENT" in doc.metadata:
            patient_ids.append(doc.metadata["PATIENT"])

    # Compile context with separators
    context_text = "\n\n---\n\n".join(context_texts)
    document_type_context = ", ".join(set(document_types))
    data_source_context = ", ".join(set(data_sources))
    field_descriptions_context = json.dumps(field_descriptions_list, indent=2) if field_descriptions_list else "None"
    patient_ids_context = ", ".join(set(patient_ids)) if patient_ids else "None"

    # Update the prompt template with new metadata
    PROMPT_TEMPLATE = """
        Answer the question based only on the following context:
        
        {context}
        
        ---
        
        Metadata Information:
        Document Type: {document_type}
        Data Source: {data_source}
        Field Descriptions: {field_descriptions}
        Patient Identifiers: {patient_ids}
        
        ---
        
        Answer the question based on the above context: {question}
        """

    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(
        context=context_text,
        question=query_text,
        document_type=document_type_context,
        data_source=data_source_context,
        field_descriptions=field_descriptions_context,
        patient_ids=patient_ids_context
    )

    # Query the model
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
    response_text = model.predict(prompt)

    # Compile and print response with sources
    sources = [doc.metadata.get("source", "Unknown") for doc, _score in relevant_results if doc.metadata]
    formatted_response = f"Response: {response_text}\n\n\nSources: {sources}"
    print(prompt)
    print(formatted_response)


In [37]:
query_text = "show me sample hospital_data"
chat(query_text)

Retrieved results: [(Document(metadata={'BASE_COST': 1780, 'CODE': 104091002, 'DESCRIPTION': 'Hemoglobin / Hematocrit / Platelet count', 'ENCOUNTER': '8a864187-1568-057f-cb9c-e98e37e38e70', 'PATIENT': '20f86dda-a492-2acc-63fb-2c361b0367ed', 'REASONCODE': 72892002.0, 'REASONDESCRIPTION': 'Normal pregnancy', 'START': '2013-06-01T13:00:16Z', 'STOP': '2013-06-01T13:15:16Z', 'data_source': 'hospital_data', 'document_type': 'unknown', 'field_descriptions': '{}', 'load_timestamp': '2024-11-07T02:34:05.309315', 'start_index': 0}, page_content='START: 2013-06-01T13:00:16Z\nSTOP: 2013-06-01T13:15:16Z\nPATIENT: 20f86dda-a492-2acc-63fb-2c361b0367ed\nENCOUNTER: 8a864187-1568-057f-cb9c-e98e37e38e70\nCODE: 104091002\nDESCRIPTION: Hemoglobin / Hematocrit / Platelet count\nBASE_COST: 1780\nREASONCODE: 72892002.0\nREASONDESCRIPTION: Normal pregnancy'), 0.5438163358843946), (Document(metadata={'BASE_COST': 1224, 'CODE': 104091002, 'DESCRIPTION': 'Hemoglobin / Hematocrit / Platelet count', 'ENCOUNTER': '9