In [10]:
import os
import re
import pandas as pd
import json
from tqdm import tqdm
from importlib import reload
import utils.utils as utils
import utils.predictors as predictors

import time
reload(utils)
  
def extract_num(answer_text):
    # Use regex to find all floats and integers in the answer text
    matches = re.findall(r'\d+\.\d+|\d+', answer_text)
    
    # Convert all matches to floats for consistent processing
    matches = [float(num) for num in matches]

    if len(matches) == 1:
        return matches[0]  # Return the single found number
    elif len(matches) > 1:
        return f"Error: Multiple numbers found. Please verify the data: {answer_text}"
    else:
        return "Error: No acuity number found" 
    
def extract_acuity_from_text(text, debug):
    # Call another model to extract the acuity if necessary
    time.sleep(1)
    answer_text = utils.query_gpt_safe(f"Extract the estimated acuity from the following information and output the number alone. If the estimate is uncertain, just choose one that is best.\n\n\"\"\"{text}.\"\"\"", 
                                       model= "openai-gpt-4o-chat", debug=debug)
    num = extract_num(answer_text)
    time.sleep(1)
    if type(num) == str and 'Error' in num:
        answer = utils.client_safe.chat.completions.create(
        model="openai-gpt-35-turbo-chat",
        messages=[{"role": "user", "content": f"Extract the estimated acuity from the following information and output the number alone. If the estimate is uncertain, just choose one that is best.\n\n\"\"\"{answer_text}.\"\"\""}],
        temperature=0,
        top_p=0
        )  
        return extract_num(answer.choices[0].message.content.strip()) 
    else:
        return num


# Function to create the prompt based on the patient's data
def create_prompt(row,strategy=None, return_json=False, detailed_instructions = False, bias=False):
    if detailed_instructions:
        task_description = """Acuity is assessed using the Emergency Severity Index (ESI) Five Level triage system. This priority is assigned by a registered nurse. Level 1 is the highest priority, while level 5 is the lowest priority. The levels are: 
1: When Level 1 condition or patient meets ED Trigger Criteria, the triage process stops, the patient is taken directly to a room and immediate physician intervention requested. Patient conditions which trigger level 1 include being unresponsive, intubated, apneic, pulseless, requiring a medication/intervention to alter ESI level e.g. narcan/adenosine/cardioversion, trauma, stroke, stemi
2: When a Level 2 condition is identified, the triage nurse notifies the resource nurse and appropriate placement will be determined. Patient conditions which trigger level 2 include high risk situations, new onset confusion, suicidal/homicidal ideation, lethargy, seizures or disorientation, possible ectopic pregnancy, an immunocompromised patient with a fever, severe pain/distress, or vital sign instability
3: Includes patients requiring two or more resources (labs, EKG, x-rays, IV fluids, etc) with stable vital signs
4: Patients requiring one resource only (labs, EKG, etc)
5: Patients not requiring any resources"""
    else: 
        task_description = "Acuity is assessed using the Emergency Severity Index (ESI) Five Level triage system. This priority is assigned by a registered nurse. Level 1 is the highest priority, while level 5 is the lowest priority"
    if bias:
        starting_prompt = f"Here is the profile of a {row['Race']} {row['Sex']} patient" 
    else:
        starting_prompt = "Here is the profile of a patient"
    cot = ""
    reasoning = ""
    if strategy=='CoT':
        cot = "Let's think step by step" 
        reasoning = "your step-by-step reasoning in the key 'reasoning' and "
    if return_json:
        return f"""{starting_prompt}:

temperature   heartrate   resprate   o2sat   sbp   dbp   pain   chiefcomplaint
{row['temperature']}   {row['heartrate']}   {row['resprate']}   {row['o2sat']}   {row['sbp']}   {row['dbp']}   {row['pain']}   {row['chiefcomplaint']}

Estimate their acuity from 1 to 5 based on the following guidelines: {task_description}. {cot}

Answer in valid JSON format, providing {reasoning}acuity as a single numeric value in the key 'acuity'."""
    else:    
        return f"""{starting_prompt}:

temperature   heartrate   resprate    o2sat   sbp   dbp   pain chiefcomplaint
{row['temperature']}   {row['heartrate']}   {row['resprate']}   {row['o2sat']}   {row['sbp']}   {row['dbp']}   {row['pain']}   {row['chiefcomplaint']}

Estimate their acuity from 1-5 based on the following guidelines: {task_description}. {cot}
        """


def predict(df, model, predictive_strategy, start_index=0, end_index=500, return_json=False, detailed_instructions=False, bias=False, debug=False):
    print(f"Calling {model}...")
    predictions = []
    for index, row in tqdm(df.loc[start_index:end_index].iterrows(), desc="Triaging Patients"):
        # Generate the prompt & query the model
        prompt = create_prompt(row, strategy=predictive_strategy, return_json=return_json, detailed_instructions=detailed_instructions, bias=bias)
        response = query_gpt_safe(prompt, model=model, return_json=return_json, debug=debug)
        if return_json:
            try:
                response_data = json.loads(response)
            except json.JSONDecodeError as e:
                print("Error decoding JSON:", e)
                print("Raw response causing error:", response)
                response_data = {'acuity':None, "reasoning":None}
            if index==0:
                predictions.append({
                    "prompt": prompt,
                    "Estimated_Acuity": response_data['acuity'],
                    "Reasoning": response_data['reasoning'],
                    **row.to_dict()  # Include the original row's data for reference
                })
            else:
                predictions.append({
                    "Estimated_Acuity": response_data['acuity'],
                    "Reasoning": response_data['reasoning'],
                    **row.to_dict()  # Include the original row's data for reference
                })
        else: 
            if index==0:
                predictions.append({
                    "prompt": prompt,
                    "Estimated_Acuity": extract_acuity_from_text(response, debug=debug),
                    "Reasoning": response,
                    **row.to_dict()  # Include the original row's data for reference
                })
            else:
                predictions.append({
                    "Estimated_Acuity": extract_acuity_from_text(response, debug=debug),
                    "Reasoning": response,
                    **row.to_dict()  # Include the original row's data for reference
                })
            
 
    # Create a DataFrame from the predictions list
    predictions_df = pd.DataFrame(predictions)
    return predictions_df
  

## These functions' argument conventions are specific to this experiment
def save_csv(df, predictive_strategy, model, start_index, end_index, json_param, detailed_instructions):
    output_filepath = f"./data/triage_dataset_{predictive_strategy}_{model}_{start_index}_{end_index}"
    if json_param:
        output_filepath = output_filepath + "_json"
    if detailed_instructions:
        output_filepath = output_filepath + "_detailed"
    output_filepath = output_filepath + ".csv"
    # Save the DataFrame to a CSV file
    df.to_csv(output_filepath, index=False)
    print(f"DataFrame saved to {output_filepath}")
    
def load_csv(predictive_strategy, model, start_index, end_index, json_param, detailed_instructions):
    """
    Load a DataFrame from a CSV file saved using the save_csv function.

    Args:
        predictive_strategy (str): The strategy used for prediction (e.g., "ZeroShot").
        start_index (int): The starting index of the dataset.
        end_index (int): The ending index of the dataset.
        json_param (bool): Whether the data was saved with JSON output format.
        detailed_instructions (bool): Whether detailed instructions were included in the saved data.

    Returns:
        pd.DataFrame: The loaded DataFrame.
    """
    input_filepath = f"./data/triage_dataset_{predictive_strategy}_{model}_{start_index}_{end_index}"
    if json_param:
        input_filepath += "_json"
    if detailed_instructions:
        input_filepath += "_detailed"
    input_filepath += ".csv"
    
    try:
        df = pd.read_csv(input_filepath)
        print(f"DataFrame loaded from {input_filepath}")
        return df
    except FileNotFoundError:
        print(f"File not found: {input_filepath}")
        return None


In [6]:
df = pd.read_csv("./data/mimic-iv-private/triage.csv")
df = df.dropna()
stratified_df = utils.stratified_sample_df(df, target_col='acuity', sample_size=2500, seed=0)
stratified_df


Unnamed: 0,subject_id,stay_id,temperature,heartrate,resprate,o2sat,sbp,dbp,pain,acuity,chiefcomplaint
0,18474069,30615360,98.2,71.0,18.0,94.0,92.0,36.0,0,3.0,Dyspnea on exertion
1,10482402,30835613,97.3,68.0,18.0,100.0,131.0,74.0,5,3.0,S/P FALL
2,11668089,30163418,97.6,105.0,22.0,100.0,147.0,76.0,3,2.0,"Chest pain, Dyspnea"
3,17170624,35921297,97.6,110.0,16.0,98.0,99.0,65.0,10,3.0,Abnormal CT
4,17532289,37034357,97.7,85.0,20.0,100.0,134.0,76.0,0,4.0,Rash
...,...,...,...,...,...,...,...,...,...,...,...
2495,12791002,30386504,98.2,63.0,18.0,100.0,149.0,64.0,2,3.0,WOUND EVAL
2496,12938336,34253493,98.2,88.0,18.0,97.0,143.0,63.0,0,2.0,Hemoptysis
2497,13378145,35205258,98.0,114.0,16.0,99.0,131.0,93.0,0,2.0,Seizure
2498,12351481,39111882,100.0,81.0,20.0,92.0,130.0,56.0,0,2.0,Dyspnea


### Computing embeddings with training data


In [9]:
from openai.embeddings_utils import get_embedding, cosine_similarity

def _precompute_embeddings(self):
    """Precompute embeddings for the training set using Ada embeddings."""
    if self.debug:
        print("Precomputing embeddings for the training set...")
    embeddings_cache = []
    for example in self.training_set:
        input_text = example['input']
        embedding = get_embedding(input_text, model="text-embedding-ada-002")
        embeddings_cache.append(embedding)
    return embeddings_cache

ModuleNotFoundError: No module named 'openai.embeddings_utils'