## Import Libraries

In [26]:
import ast
import os
import time

import boto3
import pandas as pd
from IPython.display import display, HTML
from langchain.prompts import PromptTemplate
from langchain_aws import ChatBedrock
from tqdm import tqdm

## Data Display

In [27]:
def data_display(data):
    """Render a DataFrame as HTML inside a bordered, scrollable box.

    The box has a fixed height of 500px and scrolls both vertically and
    horizontally; its width is left to the notebook layout. All rows and
    columns are rendered (no pandas truncation).

    Parameters
    ----------
    data : pandas.DataFrame
        The table to show.
    """
    # Render every row/column; the wrapping div provides the scrolling.
    table_html = data.to_html(max_rows=None, max_cols=None)
    wrapper = f'''
    <div style="height: 500px; overflow-y: scroll; overflow-x: scroll; border: 1px solid black; padding: 5px;">
        {table_html}
    </div>
    '''
    display(HTML(wrapper))

## Load Data

In [None]:
## Load the diagnosis predictions written by postprocessing.ipynb
df = pd.read_csv(filepath_or_buffer="MIMIC-IV-Ext-Diagnosis-prediction.csv")

## Function: Evaluate the diagnosis predictions with an LLM

In [None]:
## Function to evaluate the diagnosis predictions with retry logic
def get_evaluation_diagnosis(row, key, chain, max_retries=5, initial_wait=1):
    """Ask the judge chain whether the first three predictions in ``row[key]`` match.

    Parameters
    ----------
    row : mapping (e.g. a DataFrame row)
        Must provide ``"primary_diagnosis"`` and ``key``.
    key : str
        Column holding the predicted diagnoses; items 0, 1 and 2 are used,
        any further items are ignored.
    chain : object with an ``invoke(dict)`` method
        Invoked with the keys ``real_diag``, ``diag1``, ``diag2``, ``diag3``;
        the ``.content`` of its result is returned.
    max_retries : int
        Maximum number of invocations made while the service keeps throttling.
    initial_wait : float
        Seconds to wait after the first throttled call; the wait doubles after
        each further throttled call (exponential backoff). No wait happens
        after the final attempt.

    Returns
    -------
    The ``.content`` of the chain's response on success. Otherwise, a string
    starting with ``"Error:"``: fewer than three readable predictions, a
    non-throttling exception, or retries exhausted. Errors are returned rather
    than raised so that one bad row does not abort a long ``progress_apply`` run.
    """
    diagnosis = row["primary_diagnosis"]
    predictions = row[key]
    try:
        diag1, diag2, diag3 = predictions[0], predictions[1], predictions[2]
    except (IndexError, KeyError, TypeError):
        # Too few predictions, or a non-subscriptable cell such as NaN.
        return f"Error: expected 3 predicted diagnoses in {key!r}, got {predictions!r}"

    payload = {"real_diag": diagnosis, "diag1": diag1, "diag2": diag2, "diag3": diag3}
    for attempt in range(max_retries):
        try:
            # Invoke the chain with the real diagnosis and the three predictions
            return chain.invoke(payload).content
        except Exception as e:
            message = str(e)
            # Only throttling errors are worth retrying; report anything else.
            if "ThrottlingException" not in message and "Too many requests" not in message:
                return f"Error: {message}"
            # Back off only if another attempt will actually follow.
            if attempt + 1 < max_retries:
                wait_time = initial_wait * (2 ** attempt)
                print(f"Throttling detected. Retrying after {wait_time} seconds...")
                time.sleep(wait_time)
    # If all retries fail, return an error
    return "Error: Max retries exceeded"

## LLM Evaluation

In [None]:
## Convert the prediction columns into Python lists. The CSV stores each list as its
## string representation (e.g. "['A', 'B', 'C']"), so parse it with ast.literal_eval,
## which accepts only Python literals, rather than eval, which would execute any code
## found in the file. Non-string cells (e.g. NaN) are passed through unchanged instead
## of raising here.
## Earlier versions of this cell read some clinical-prompt columns under the misspelling
## "Clincal". Whichever spelling exists in the CSV is used as the source, and the parsed
## lists are always stored under the "_Clinical" name that the evaluation loop reads.
## NOTE(review): confirm which spelling the CSV from postprocessing.ipynb actually uses.
for column in [
    "diagnosis_Claude3.5",
    "diagnosis_Claude3",
    "diagnosis_Haiku",
    "diagnosis_Claude3.5_Clinical",
    "diagnosis_Claude3_Clinical",
    "diagnosis_Haiku_Clinical",
]:
    misspelled = column.replace("_Clinical", "_Clincal")
    source = column if column in df.columns else misspelled
    df[column] = df[source].apply(
        lambda value: ast.literal_eval(value) if isinstance(value, str) else value
    )

In [None]:
## Define the prompt template for the judge model. The placeholders {real_diag} and
## {diag1}..{diag3} are filled from the dict that get_evaluation_diagnosis passes to
## chain.invoke. The model is asked to answer 'True'/'False' for each prediction inside
## <evaluation> tags.
prompt = """You are an experienced healthcare professional with expertise in medical and clinical domains. I will provide a list of real diagnoses for a patient and 3 predicted diagnoses. For each predicted diagnosis, determine if it has the same meaning as one of the real diagnoses or if the prediction falls under a broader category of one of the real diagnoses (e.g., a specific condition falling under a general diagnosis category). If it matches, return 'True'; otherwise, return 'False'. Return only 'True' or 'False' for each predicted diagnosis within <evaluation> tags and nothing else.
Real Diagnoses: {real_diag}, predicted diagnosis 1: {diag1}, predicted diagnosis 2: {diag2}, and predicted diagnosis 3: {diag3}."""


## Set AWS credentials (optional). Leave these blank to let boto3 resolve credentials
## through its default chain (environment variables, shared credentials/config files,
## IAM role, ...). Values are exported only when both are filled in, so placeholder
## text never overrides credentials that are already configured.
AWS_ACCESS_KEY_ID = ""  # Enter your AWS Access Key ID
AWS_SECRET_ACCESS_KEY = ""  # Enter your AWS Secret Access Key
if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY:
    os.environ["AWS_ACCESS_KEY_ID"] = AWS_ACCESS_KEY_ID
    os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET_ACCESS_KEY

## Bind the prompt's four placeholders and create the Bedrock runtime client.
prompt_chain = PromptTemplate(template=prompt, input_variables=["real_diag", "diag1", "diag2", "diag3"])
client = boto3.client(service_name="bedrock-runtime", region_name="us-east-1")


## Judge model: Claude 3.5 Sonnet on Bedrock with temperature 0, piped after the prompt
llm_claude35 = ChatBedrock(
    client=client,
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    model_kwargs={"temperature": 0},
)
chain_claude35 = prompt_chain | llm_claude35


## Register DataFrame.progress_apply, then evaluate each model's prediction column with
## the judge chain. The CSV is rewritten after every column, so finished columns are
## kept if a later one fails.
tqdm.pandas()
keys = [
    "diagnosis_Claude3.5",
    "diagnosis_Claude3",
    "diagnosis_Haiku",
    "diagnosis_Claude3.5_Clinical",
    "diagnosis_Claude3_Clinical",
    "diagnosis_Haiku_Clinical",
]

for key in keys:
    # Extra keyword arguments are forwarded by DataFrame.apply to the function,
    # which receives each row as its first positional argument.
    df["eval_" + key] = df.progress_apply(
        get_evaluation_diagnosis, axis=1, key=key, chain=chain_claude35
    )
    df.to_csv("MIMIC-IV-Ext-Diagnosis-evaluation.csv", index=False)