In [2]:
import os
import shutil
import pandas as pd
import logging
import json
import getpass
from dotenv import load_dotenv
from datetime import datetime
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.schema import Document
from langchain.prompts.chat import ChatPromptTemplate
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
load_dotenv()

True

In [4]:
if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Provide your Google API key here: ")

In [5]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [6]:
MERGED_DATA_PATH = os.getenv("MERGED_DATA_PATH", "merged_hospital_data.csv")
CHROMA_PATH = os.getenv("CHROMA_PATH", "chroma_db")
DATA_DICTIONARY_PATH = os.getenv("DATA_DICTIONARY_PATH", "data/dictionary/data_dictionary.csv")

In [7]:
# Load the merged CSV file into a DataFrame with validation and logging
def load_merged_document():
    try:
        df = pd.read_csv(MERGED_DATA_PATH, dtype={'ZIP': str})  # Ensure ZIP codes are loaded as strings
        # Handle missing values by filling appropriately based on dtype
        for column in df.columns:
            if df[column].dtype == 'float64' or df[column].dtype == 'int64':
                df[column] = df[column].fillna(0)  # Fill numeric columns with 0 or another default value
            else:
                df[column] = df[column].fillna("N/A")  # Fill non-numeric columns with "N/A"
        logging.info(f"Loaded merged CSV successfully with {len(df)} rows.")
        return df
    except Exception as e:
        logging.error(f"Error loading merged CSV: {e}")
        return None

In [8]:
merged_data_df = load_merged_document()

2024-11-06 19:42:55,341 - INFO - Loaded merged CSV successfully with 60922 rows.


In [9]:
merged_data_df.head()

Unnamed: 0,Id_encounter,START,STOP,PATIENT,ORGANIZATION,PAYER,ENCOUNTERCLASS,CODE,DESCRIPTION,BASE_ENCOUNTER_COST,...,ZIP_payer,PHONE,START_procedure,STOP_procedure,ENCOUNTER,CODE_procedure,DESCRIPTION_procedure,BASE_COST,REASONCODE_procedure,REASONDESCRIPTION_procedure
0,32c84703-2481-49cd-d571-3899d5820253,2011-01-02T09:26:36Z,2011-01-02T12:58:36Z,3de74169-7f67-9304-91d4-757e0f3a14d2,d78e84ec-30aa-3bba-a33a-f29a3a454662,b1c428d6-4f07-31e0-90f0-68ffa6ff8c76,ambulatory,185347001,Encounter for problem (procedure),85.55,...,0.0,,2011-01-02T09:26:36Z,2011-01-02T12:58:36Z,32c84703-2481-49cd-d571-3899d5820253,265764009.0,Renal dialysis (procedure),903.0,0.0,
1,c98059da-320a-c0a6-fced-c8815f3e3f39,2011-01-03T05:44:39Z,2011-01-03T06:01:42Z,d9ec2e44-32e9-9148-179a-1653348cc4e2,d78e84ec-30aa-3bba-a33a-f29a3a454662,b1c428d6-4f07-31e0-90f0-68ffa6ff8c76,outpatient,308335008,Patient encounter procedure,142.58,...,0.0,,2011-01-03T05:44:39Z,2011-01-03T06:01:42Z,c98059da-320a-c0a6-fced-c8815f3e3f39,76601001.0,Intramuscular injection,2477.0,0.0,
2,4ad28a3a-2479-782b-f29c-d5b3f41a001e,2011-01-03T14:32:11Z,2011-01-03T14:47:11Z,73babadf-5b2b-fee7-189e-6f41ff213e01,d78e84ec-30aa-3bba-a33a-f29a3a454662,7caa7254-5050-3b5e-9eae-bd5ea30e809c,outpatient,185349003,Encounter for check up (procedure),85.55,...,21244.0,1-800-633-4227,,,,0.0,,0.0,0.0,
3,c3f4da61-e4b4-21d5-587a-fbc89943bc19,2011-01-03T16:24:45Z,2011-01-03T16:39:45Z,3b46a0b7-0f34-9b9a-c319-ace4a1f58c0b,d78e84ec-30aa-3bba-a33a-f29a3a454662,b1c428d6-4f07-31e0-90f0-68ffa6ff8c76,wellness,162673000,General examination of patient (procedure),136.8,...,0.0,,,,,0.0,,0.0,0.0,
4,a9183b4f-2572-72ea-54c2-b3cd038b4be7,2011-01-03T17:36:53Z,2011-01-03T17:51:53Z,fa006887-d93c-d302-8b89-f3c25f88c0e1,d78e84ec-30aa-3bba-a33a-f29a3a454662,42c4fca7-f8a9-3cd1-982a-dd9751bf3e2a,ambulatory,390906007,Follow-up encounter,85.55,...,46204.0,1-800-331-1476,,,,0.0,,0.0,0.0,


In [10]:
# Load Data Dictionary
def load_data_dictionary():
    if os.path.exists(DATA_DICTIONARY_PATH):
        try:
            data_dict = pd.read_csv(DATA_DICTIONARY_PATH)
            logging.info("Data dictionary loaded successfully.")
            return data_dict
        except Exception as e:
            logging.error(f"Error loading data dictionary: {e}")
    return None

In [11]:
data_dictionary_df = load_data_dictionary()

2024-11-06 19:43:47,837 - INFO - Data dictionary loaded successfully.


In [12]:
data_dictionary_df

Unnamed: 0,Table,Field,Description
0,encounters,,Patient encounter data
1,encounters,Id,Primary Key. Unique Identifier of the encounter.
2,encounters,Start,The date and time (iso8601 UTC Date (yyyy-MM-d...
3,encounters,Stop,The date and time (iso8601 UTC Date (yyyy-MM-d...
4,encounters,Patient,Foreign key to the Patient.
...,...,...,...
60,procedures,Code,Procedure code from SNOMED-CT
61,procedures,Description,Description of the procedure.
62,procedures,Base_Cost,The line item cost of the procedure.
63,procedures,ReasonCode,Diagnosis code from SNOMED-CT specifying why t...


In [13]:
# Convert DataFrame into Documents with enriched metadata
def convert_to_documents(df: pd.DataFrame, data_dictionary: pd.DataFrame = None):
    documents = []
    for _, row in df.iterrows():
        text = "\n".join([f"{col}: {row[col]}" for col in df.columns])  # Use newline for better readability
        metadata = row.to_dict()
        metadata.update({
            "document_type": "organization_info" if "NAME" in df.columns else "unknown",
            "load_timestamp": datetime.now().isoformat(),
            "data_source": "hospital_data"
        })
        # Add field descriptions if data dictionary is available
        if data_dictionary is not None:
            field_descriptions = {}
            for col in df.columns:
                desc_row = data_dictionary[data_dictionary['Field'] == col]
                if not desc_row.empty:
                    field_descriptions[col] = desc_row['Description'].values[0]
            # Serialize field_descriptions to a JSON string
            metadata["field_descriptions"] = json.dumps(field_descriptions)

        document = Document(page_content=text, metadata=metadata)
        documents.append(document)

    if data_dictionary is not None:
        for _, row in data_dictionary.iterrows():
            text = "\n".join([f"{col}: {row[col]}" for col in data_dictionary.columns])
            document = Document(page_content=text, metadata={
                "source": "data_dictionary",
                "document_type": "data_dictionary",
                "load_timestamp": datetime.now().isoformat(),
                "data_source": "data_dictionary"
            })
            documents.append(document)

    return documents

In [14]:
data = convert_to_documents(merged_data_df, data_dictionary_df)

In [15]:
# Split documents into chunks with dynamic chunk sizing
def split_text(documents: list):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50,
        length_function=len,
        add_start_index=True,
    )
    chunks = text_splitter.split_documents(documents)
    logging.info(f"Split {len(documents)} documents into {len(chunks)} chunks.")

    if len(chunks) > 10:
        document = chunks[10]
        logging.info(f"Sample chunk content: {document.page_content}")
        logging.info(f"Sample chunk metadata: {document.metadata}")

    return chunks

In [16]:
chunks = split_text(data)

2024-11-06 19:56:58,614 - INFO - Split 60987 documents into 182930 chunks.
2024-11-06 19:56:58,615 - INFO - Sample chunk content: PREFIX: Mr.
FIRST: Efrain317
LAST: Dibbert990
SUFFIX: N/A
MAIDEN: N/A
MARITAL: M
RACE: white
ETHNICITY: nonhispanic
GENDER: M
BIRTHPLACE: Lowell  Massachusetts  US
ADDRESS: 475 Wunsch Overpass
CITY: Boston
STATE: Massachusetts
COUNTY: Suffolk County
ZIP: 2121.0
LAT: 42.36292149492703
LON: -71.01306739
Id: d78e84ec-30aa-3bba-a33a-f29a3a454662
NAME: MASSACHUSETTS GENERAL HOSPITAL
ADDRESS_organization: 55 FRUIT STREET
CITY_organization: BOSTON
STATE_organization: MA
ZIP_organization: 2114
LAT_organization: 42.362813
LON_organization: -71.069187
Id_payer: b1c428d6-4f07-31e0-90f0-68ffa6ff8c76
2024-11-06 19:56:58,616 - INFO - Sample chunk metadata: {'Id_encounter': 'c3f4da61-e4b4-21d5-587a-fbc89943bc19', 'START': '2011-01-03T16:24:45Z', 'STOP': '2011-01-03T16:39:45Z', 'PATIENT': '3b46a0b7-0f34-9b9a-c319-ace4a1f58c0b', 'ORGANIZATION': 'd78e84ec-30aa-3bba-a33a-f29a3

In [17]:
# Save chunks to Chroma vector store with batch saving and logging
def save_to_chroma(chunks: list):
    # Clear out the database first.
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)
        logging.info(f"Cleared existing database at {CHROMA_PATH}.")

    # Use Google Generative AI Embeddings with API Key
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")

    # Create a new DB from the documents with batch processing
    batch_size = 1000  # Save in batches to handle large datasets
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i + batch_size]
        db = Chroma.from_documents(
            batch, embeddings, persist_directory=CHROMA_PATH
        )
        logging.info(f"Saved batch {i // batch_size + 1} to {CHROMA_PATH}.")

    logging.info(f"Saved all chunks to {CHROMA_PATH}.")

In [19]:
save_to_chroma(chunks)

2024-11-06 19:59:23,734 - INFO - Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.
2024-11-06 19:59:42,814 - INFO - Saved batch 1 to chroma_db.
2024-11-06 19:59:55,794 - INFO - Saved batch 2 to chroma_db.
2024-11-06 20:00:08,450 - INFO - Saved batch 3 to chroma_db.
2024-11-06 20:00:20,631 - INFO - Saved batch 4 to chroma_db.
2024-11-06 20:00:32,847 - INFO - Saved batch 5 to chroma_db.
2024-11-06 20:00:45,990 - INFO - Saved batch 6 to chroma_db.
2024-11-06 20:00:58,627 - INFO - Saved batch 7 to chroma_db.
2024-11-06 20:01:10,928 - INFO - Saved batch 8 to chroma_db.
2024-11-06 20:01:23,556 - INFO - Saved batch 9 to chroma_db.
2024-11-06 20:01:36,793 - INFO - Saved batch 10 to chroma_db.
2024-11-06 20:01:50,364 - INFO - Saved batch 11 to chroma_db.
2024-11-06 20:02:02,905 - INFO - Saved batch 12 to chroma_db.
2024-11-06 20:02:15,443 - INFO - Saved batch 13 to chroma_db.
2024-11-06 20:02:27,453 - INFO - Saved batch 14 to chroma

In [35]:
def chat(query_text):
    # Prepare the DB
    embedding_function = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # Search the DB with similarity filtering
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    print(f"Retrieved results: {results}")

    # Adjusted Similarity Threshold
    if len(results) == 0:
        print("Unable to find matching results.")
        return

    # Check if any results exceed a set threshold or if the best available should be used
    threshold = 0.55  # change as necessary 
    relevant_results = [result for result in results if result[1] >= threshold]

    if len(relevant_results) == 0:
        print("No results exceed the similarity threshold. Returning the best available match.")
        relevant_results = [results[0]]  # Return the best available match if no results exceed threshold

    # Prepare context from the results
    context_texts = []
    document_types = []
    data_sources = []
    field_descriptions_list = []
    patient_ids = []

    for doc, _score in relevant_results:
        context_texts.append(doc.page_content)
        document_types.append(doc.metadata.get("document_type", "unknown"))
        data_sources.append(doc.metadata.get("data_source", "unknown"))
        if "field_descriptions" in doc.metadata:
            field_descriptions_list.append(json.loads(doc.metadata["field_descriptions"]))
        if "PATIENT" in doc.metadata:
            patient_ids.append(doc.metadata["PATIENT"])

    # Compile context with separators
    context_text = "\n\n---\n\n".join(context_texts)
    document_type_context = ", ".join(set(document_types))
    data_source_context = ", ".join(set(data_sources))
    field_descriptions_context = json.dumps(field_descriptions_list, indent=2) if field_descriptions_list else "None"
    patient_ids_context = ", ".join(set(patient_ids)) if patient_ids else "None"

    # Update the prompt template with new metadata
    PROMPT_TEMPLATE = """
        Answer the question based only on the following context:
        
        {context}
        
        ---
        
        Metadata Information:
        Document Type: {document_type}
        Data Source: {data_source}
        Field Descriptions: {field_descriptions}
        Patient Identifiers: {patient_ids}
        
        ---
        
        Answer the question based on the above context: {question}
        """

    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(
        context=context_text,
        question=query_text,
        document_type=document_type_context,
        data_source=data_source_context,
        field_descriptions=field_descriptions_context,
        patient_ids=patient_ids_context
    )

    # Query the model
    model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
    response_text = model.predict(prompt)

    # Compile and print response with sources
    sources = [doc.metadata.get("source", "Unknown") for doc, _score in relevant_results if doc.metadata]
    formatted_response = f"Response: {response_text}\n\n\nSources: {sources}"
    print(prompt)
    print(formatted_response)


In [34]:
query_text = "what types of data are there?"
chat(query_text)

Retrieved results: [(Document(metadata={'data_source': 'data_dictionary', 'document_type': 'data_dictionary', 'load_timestamp': '2024-11-06T19:56:03.231338', 'source': 'data_dictionary', 'start_index': 0}, page_content='Table: procedures\nField: nan\nDescription: Patient procedure data including surgeries.'), 0.4486088054416878), (Document(metadata={'data_source': 'data_dictionary', 'document_type': 'data_dictionary', 'load_timestamp': '2024-11-06T19:56:03.227057', 'source': 'data_dictionary', 'start_index': 0}, page_content='Table: patients\nField: Id\nDescription: Primary Key. Unique Identifier of the patient.'), 0.434435148754982), (Document(metadata={'data_source': 'data_dictionary', 'document_type': 'data_dictionary', 'load_timestamp': '2024-11-06T19:56:03.230684', 'source': 'data_dictionary', 'start_index': 0}, page_content='Table: payers\nField: nan\nDescription: Insurance payer data.'), 0.4294189075013217), (Document(metadata={'data_source': 'data_dictionary', 'document_type': 