In [25]:
#pip install sentence-transformers faiss-cpu langchain
#!pip install langchain-text-splitters
# pip install -qU langchain-community
# pip install langchain_openai
#pip install langchain_huggingface

In [1]:
import pandas as pd
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
#from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings

from sentence_transformers import SentenceTransformer
from langchain_text_splitters import RecursiveCharacterTextSplitter
import pickle
import ast

In [6]:
test_df = pd.read_csv("test_data.csv")
test_df['patient_id'] = test_df['Patient Note'].astype('category').cat.codes + 1
test_df

Unnamed: 0,Row Number,Calculator ID,Calculator Name,Category,Output Type,Note ID,Note Type,Patient Note,Question,Relevant Entities,Ground Truth Answer,Lower Limit,Upper Limit,Ground Truth Explanation,patient_id
0,1,2,Creatinine Clearance (Cockcroft-Gault Equation),lab test,decimal,pmc-7671985-1,Extracted,An 87-year-old man was admitted to our hospita...,What is the patient's Creatinine Clearance usi...,"{'sex': 'Male', 'age': [87, 'years'], 'weight'...",25.2381,23.97619,26.50001,The formula for computing Cockcroft-Gault is g...,880
1,2,2,Creatinine Clearance (Cockcroft-Gault Equation),lab test,decimal,pmc-8605939-1,Extracted,An 83-year-old man with a past medical history...,What is the patient's Creatinine Clearance usi...,"{'sex': 'Male', 'age': [83, 'years'], 'weight'...",38.0,36.1,39.9,The formula for computing Cockcroft-Gault is g...,874
2,3,2,Creatinine Clearance (Cockcroft-Gault Equation),lab test,decimal,pmc-4359532-1,Extracted,A 62 year old Caucasian female patient (height...,What is the patient's Creatinine Clearance usi...,"{'sex': 'Female', 'age': [62, 'years'], 'weigh...",57.65181,54.76922,60.5344,The formula for computing Cockcroft-Gault is g...,467
3,4,2,Creatinine Clearance (Cockcroft-Gault Equation),lab test,decimal,usmle-1002,Extracted,A 42-year-old woman comes to the physician for...,What is the patient's Creatinine Clearance usi...,"{'sex': 'Female', 'age': [42, 'years'], 'weigh...",106.19185,100.88226,111.50144,The formula for computing Cockcroft-Gault is g...,188
4,5,2,Creatinine Clearance (Cockcroft-Gault Equation),lab test,decimal,pmc-4459668-1,Extracted,"A 45-year-old, 58 kg, 156 cm woman presented w...",What is the patient's Creatinine Clearance usi...,"{'sex': 'Female', 'age': [45, 'years'], 'weigh...",78.12231,74.21619,82.02843,The formula for computing Cockcroft-Gault is g...,212
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1095,1096,69,Estimated Gestational Age,date,date,16,Template,The patient's last menstrual period was on 01/...,"Based on the patient's last menstrual period, ...","{'Current Date': '07/31/2006', 'Last menstrual...","('26 weeks', '3 days')","('26 weeks', '3 days')","('26 weeks', '3 days')","To compute the estimated gestational age, we c...",991
1096,1097,69,Estimated Gestational Age,date,date,17,Template,The patient's last menstrual period was on 10/...,"Based on the patient's last menstrual period, ...","{'Current Date': '06/13/2023', 'Last menstrual...","('36 weeks', '3 days')","('36 weeks', '3 days')","('36 weeks', '3 days')","To compute the estimated gestational age, we c...",1034
1097,1098,69,Estimated Gestational Age,date,date,18,Template,The patient's last menstrual period was on 08/...,"Based on the patient's last menstrual period, ...","{'Current Date': '02/03/2001', 'Last menstrual...","('22 weeks', '2 days')","('22 weeks', '2 days')","('22 weeks', '2 days')","To compute the estimated gestational age, we c...",1030
1098,1099,69,Estimated Gestational Age,date,date,19,Template,The patient's last menstrual period was on 07/...,"Based on the patient's last menstrual period, ...","{'Current Date': '03/19/2014', 'Last menstrual...","('34 weeks', '6 days')","('34 weeks', '6 days')","('34 weeks', '6 days')","To compute the estimated gestational age, we c...",1025


# Store Documents

In [3]:
def chunk_note(raw_text, patient_id, chunk_size=1500, chunk_overlap=300):
    """
    Splits a clinical note into overlapping chunks
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", ".", " "]
    )

    chunks = splitter.split_text(raw_text)

    docs = []
    for i, chunk in enumerate(chunks):
        docs.append(
            Document(
                page_content=chunk,
                metadata={
                    "patient_id": patient_id,
                    "chunk_id": f"{patient_id}_{i}"
                }
            )
        )
    return docs

def store_documents(chunk_size, chunk_overlap):
    """
    Splits all patient notes into chunks
    and stores into vector store.
    """
    all_docs = []

    # chunk all notes first
    for _, row in test_df.iterrows():
        patient_docs = chunk_note(
            raw_text=row["Patient Note"],
            patient_id=row["patient_id"],
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )
        all_docs.extend(patient_docs)

    # build index safely
    vector_store = FAISS.from_documents(all_docs, embeddings)

    print("stored", vector_store.index.ntotal, "chunks")
    return vector_store

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")

vector_store_300 = store_documents(chunk_size=300, chunk_overlap=75)

vector_store_300.save_local("faiss_300")

vector_store_300 = FAISS.load_local(
    "faiss_300",
    embeddings,
    allow_dangerous_deserialization=True
)

stored 14940 chunks


In [103]:
vector_store_1500 = store_documents(chunk_size=1500, chunk_overlap=300)
vector_store_1500.save_local("faiss_1500")

stored 3148 chunks


# Retrieve

### Load RAG

In [2]:
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")
vector_store_1500 = FAISS.load_local(
    "faiss_1500", embeddings, allow_dangerous_deserialization=True
)

In [3]:
def display_query(query, patient_id, rag, k=5): 
    print(query)
    print("Patient ID: ", patient_id)
    print()
    results = rag.similarity_search_with_score(
        query, 
        k=k, 
        fetch_k=14940,
        filter={'patient_id':patient_id}
    )
    for res, score in results: 
        print(f"Similarity score: {score:3f}")
        print(res.page_content)
        # print(f"[SIM={score:3f}], [{res.metadata}]")
        # print({res.page_content})
        print()

display_query("What is the patient's Creatinine Clearance using the Cockroft-Gault Equation in terms of mL/min?", patient_id=880, rag=vector_store_1500)

What is the patient's Creatinine Clearance using the Cockroft-Gault Equation in terms of mL/min?
Patient ID:  880

Similarity score: 0.808021
An 87-year-old man was admitted to our hospital for anorexia for several days, high-grade fever from the previous day, and liver dysfunction. Of note, he had a history of hypertension, diabetes mellitus (DM), and angina. Physical examination findings included: clear consciousness; height, 163 cm; weight, 48 kg; blood pressure, 66/40 mmHg; heart rate, 75/min; respiratory rate, 22/min; oxygen saturation of peripheral artery, 96%; and body temperature, 38.1 °C. He had no surface lymphadenopathy. Laboratory findings included: white blood cell (WBC) count, 4.2 × 109/L; hemoglobin, 9.6 g/dL; platelet count, 106 × 109/L; lactate dehydrogenase (LDH), 1662 IU/L; aspartate aminotransferase (AST), 6562 IU/L; alanine aminotransferase (ALT), 1407 IU/L; alkaline phosphatase (ALP), 509 IU/L; γ-glutamyl transpeptidase (γ-GTP), 130 IU/L; total bilirubin, 2.7 mg/d

In [4]:
def query_rag(query, patient_id, rag, k=5): 
    results = rag.similarity_search_with_score(
        query, 
        k=k, 
        fetch_k=14940,
        filter={'patient_id':patient_id}
    )
    return [res.page_content for res, score in results]

query_rag("age", patient_id=518, rag=vector_store_1500, k=3)

['A 64-year-old female was brought to the intensive care unit after requiring an urgent procedure due to acute abdominal distress. This patient has a significant past medical background that renders the immune system less robust than average, raising concerns for increased susceptibility to postoperative and hospital-acquired complications. On arrival, the patient appeared visibly uncomfortable and mildly febrile: the temperature measured 100.4 degrees Fahrenheit. Vital signs revealed a systolic pressure of 100 mm Hg with a diastolic pressure of 60 mm Hg. The pulse was 110 beats per minute, and respiratory efforts were measured at 22 breaths per minute, reflecting some degree of physiologic stress. Arterial blood samples indicated a pH of 7.31, suggesting a mild acidemia at presentation. This was accompanied by concerns regarding oxygenation, as the partial pressure of oxygen was 68 mm Hg while the patient was receiving 40% supplemental oxygen. Laboratory results demonstrated plasma so

# Generate `.csv` for LLM testing

In [7]:
def process_chunks(chunks): 
    final_string = ""
    for i, val in enumerate(chunks): 
        final_string += f"{i + 1}.) {val}\n"

    return final_string

def retrieve_chunks(rag): 
    rows = []
    for i, row in test_df.iterrows(): 
        # for each patient, retrieve chunks for question and relevant entities
        row_number = row['Row Number']
        patient_id = row['patient_id']
        question = row['Question']
        entities = ast.literal_eval(row['Relevant Entities'])
    
        # get question chunks
        question_answer = row['Ground Truth Answer']
        question_chunks = query_rag(question, patient_id=patient_id, rag=rag, k=5)
    
        # add row
        rows.append({
            'Row Number': row_number, 
            'Question': question,
            'Answer': question_answer,
            'RAG Chunks': process_chunks(question_chunks), 
            'question_type': 'original'
        })
    
        # get chunks for each entity
        for ent, val in entities.items():
            ent_chunks = query_rag(ent, patient_id=patient_id, rag=rag, k=3)
            if type(val) == list: 
                ent_ans = val[0]
            else: 
                ent_ans = val
            
            rows.append({
                'Row Number': row_number, 
                'Question': "Find the value of the following attribute of the patient. " + ent,
                'Answer': ent_ans,
                'RAG Chunks': process_chunks(ent_chunks), 
                'question_type': 'entity'
            })

    return pd.DataFrame(rows)

retrieved_1500 = retrieve_chunks(rag=vector_store_1500)

In [8]:
retrieved_1500.to_csv("faiss_1500.csv")

In [9]:
retrieved_1500.shape

(6477, 5)