In [2]:
import os
import dotenv
import pickle
import itertools
import numpy as np
import pandas as pd
from collections import OrderedDict
from concurrent.futures import ProcessPoolExecutor
import openai
import schema
import save
from prompts.patient_prompt import prompt as pp
from prompts.doctor_prompt_structured import prompt as dp
from prompts.symptom_check_prompt import prompt as scp
from prompts.symptom_check_prompt import reply as scr
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from openai import OpenAIError, Timeout

# Load environment variables
env_file = ".env"
dotenv.load_dotenv(env_file, override=True)

# OpenAI client setup (the key is read from the environment, never hard-coded)
client = openai.OpenAI(api_key=os.getenv("CORRELL_API_KEY"))

# Load patient profiles.
# NOTE(security): pickle.load executes arbitrary code embedded in the file, so
# only point this at a trusted, locally produced patient_profiles.pkl.
# The context manager closes the file handle, which a bare open() call leaks.
with open("patient_profiles.pkl", "rb") as profiles_file:
    patient_profiles = pickle.load(profiles_file)

# Values fed to the doctor prompt as "confidence_threshold" / "interaction_steps"
# and used by the interaction loop as stop criterion / maximum number of turns.
threshold = 0.8
steps = 5

# Convert to an ordered dictionary so batching below follows a stable order
patient_profiles: dict[int, schema.Profile] = OrderedDict(patient_profiles)

# Total number of profiles to process (e.g. 240 profiles -> 24 batches of 10)
num_profiles = len(patient_profiles)
batch_size = 10  # profiles per process

# Split into batches of `batch_size` profiles. Stepping over the item list keeps
# a trailing partial batch, so no profile is dropped when num_profiles is not a
# multiple of batch_size (floor division would silently skip the remainder).
profile_items = list(patient_profiles.items())
profile_batches = [
    dict(profile_items[start:start + batch_size])
    for start in range(0, num_profiles, batch_size)
]

# One process per batch. ProcessPoolExecutor rejects max_workers <= 0, so keep
# at least one worker even when there are no profiles at all.
num_workers = max(1, len(profile_batches))

In [3]:
# Retry logic for API calls
@retry(
    retry=retry_if_exception_type((Timeout, OpenAIError)),  # exception types that trigger another attempt
    wait=wait_exponential(min=1, max=10),  # exponential backoff, bounded to 1s..10s
    stop=stop_after_attempt(5),  # stop after the fifth attempt
)
def call_openai_doctor(messages: list[schema.Message]) -> schema.DoctorResponse:
    """Request a structured doctor turn from the model.

    Args:
        messages: Chat history to send (system prompt plus prior turns).

    Returns:
        The ``parsed`` attribute of the first choice's message, produced by the
        structured-output endpoint with ``schema.DoctorResponse`` as the
        response format.

    The call is retried by the decorator above whenever one of the listed
    exception types is raised.
    """
    # NOTE(review): `Timeout` as imported from `openai` may be a timeout
    # configuration class rather than an exception type; timeouts would then be
    # covered only through `OpenAIError` - confirm against the installed SDK.
    request = {
        "model": "gpt-4o-mini",
        "messages": messages,
        "temperature": 0.4,
        "response_format": schema.DoctorResponse,
    }
    completion = client.beta.chat.completions.parse(**request)
    first_choice = completion.choices[0]
    return first_choice.message.parsed

@retry(
    retry=retry_if_exception_type((Timeout, OpenAIError)),  # exception types that trigger another attempt
    wait=wait_exponential(min=1, max=10),  # exponential backoff, bounded to 1s..10s
    stop=stop_after_attempt(5),  # stop after the fifth attempt
)
def call_openai(messages: list[schema.Message]) -> str:
    """Request a free-text chat completion.

    Args:
        messages: Chat history to send (system prompt plus prior turns).

    Returns:
        The text content of the first choice's message.

    The call is retried by the decorator above whenever one of the listed
    exception types is raised.
    """
    request = {
        "model": "gpt-4o-mini",
        "messages": messages,
        "temperature": 0.7,
    }
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content

@retry(
    retry=retry_if_exception_type((Timeout, OpenAIError)),  # exception types that trigger another attempt
    wait=wait_exponential(min=1, max=10),  # exponential backoff, bounded to 1s..10s
    stop=stop_after_attempt(5),  # stop after the fifth attempt
)
def call_openai_symptom_check(messages: list[schema.Message]) -> schema.SymptomCheck:
    """Request a structured symptom-check result from the model.

    Args:
        messages: Chat messages to send (system prompt plus the check request).

    Returns:
        The ``parsed`` attribute of the first choice's message, produced by the
        structured-output endpoint with ``schema.SymptomCheck`` as the response
        format.

    The call is retried by the decorator above whenever one of the listed
    exception types is raised.
    """
    request = {
        "model": "gpt-4o-mini",
        "messages": messages,
        "temperature": 0.4,
        "response_format": schema.SymptomCheck,
    }
    completion = client.beta.chat.completions.parse(**request)
    first_choice = completion.choices[0]
    return first_choice.message.parsed

# Diagnosis confidence formatting
def get_diagnosis_confidence(diagnosis_history: list) -> dict[str, list[float]]:
    """Pivot per-step diagnosis rankings into per-diagnosis confidence series.

    Args:
        diagnosis_history: One entry per interaction step. Each entry is an
            iterable of objects exposing ``.diagnosis`` and ``.confidence``;
            an entry may also be ``None`` or empty when a step produced no
            ranking.

    Returns:
        A dict mapping every diagnosis seen in any step to a list with one
        confidence per step (same length as ``diagnosis_history``). Steps in
        which the diagnosis does not appear hold ``np.nan``. Keys are ordered
        by first appearance.
    """
    # `step or ()` treats a missing ranking (None) like an empty one instead of
    # raising TypeError. If a diagnosis repeats within a step, the last wins.
    per_step = [
        {entry.diagnosis: entry.confidence for entry in (step or ())}
        for step in diagnosis_history
    ]

    # dict.fromkeys keeps first-seen order, so the key order is deterministic
    # (iterating a set of strings varies between interpreter processes).
    all_diagnoses = dict.fromkeys(
        name for step_conf in per_step for name in step_conf
    )

    return {
        name: [step_conf.get(name, np.nan) for step_conf in per_step]  # NaN if missing
        for name in all_diagnoses
    }

In [None]:
# Function to process a batch of profiles
def process_profiles(profiles_batch: dict[int, schema.Profile], batch_id: int):
    """Simulate a doctor-patient interview for every profile in a batch.

    For each profile the function formats the patient and doctor system
    prompts, alternates patient and doctor model calls for at most ``steps``
    doctor turns, records diagnosis/confidence metadata, runs a symptom
    recovery check over the doctor's known symptoms, and persists the result
    through ``save.save_history``.

    Args:
        profiles_batch: Mapping of profile index to profile for this batch.
        batch_id: Identifier used only in progress messages.

    Returns:
        None. Results are written to each profile's ``"interaction_metadata"``
        and saved via ``save.save_history``.

    Relies on the module-level ``threshold``, ``steps``, ``num_profiles``,
    prompt templates and the ``call_openai*`` helpers. Exceptions raised by
    those helpers are not caught here and abort the rest of the batch.
    """
    print(f"Starting batch {batch_id} with {len(profiles_batch)} profiles...")

    for i, profile in profiles_batch.items():
        # Every profile field except the first and the last one is passed to
        # the prompts. NOTE(review): presumably the first key is an identifier
        # and the last is "interaction_metadata" - confirm against schema.Profile.
        patient_data = {k: profile[k] for k in list(OrderedDict(profile))[1:-1]}
        doctor_config = {
            "gender": patient_data["gender"],
            "ethnicity": patient_data["ethnicity"],
            "confidence_threshold": threshold,
            "interaction_steps": steps
        }

        # Initialize metadata (updates the profile's existing dict in place)
        metadata = profile["interaction_metadata"]
        metadata.update({
            "diagnosis": None,
            "diagnosis_success": False,
            "interaction_duration": 0,
            "num_symptoms_recovered": 0,
            "confidence_history": []
        })

        # Format prompts
        pp_copy = pp.format(**patient_data)
        dp_copy = dp.format(**doctor_config)

        # Initialize conversation histories. Each history is written from the
        # point of view of the model that consumes it: that model's own turns
        # are "assistant", the other party's turns are "user".
        doctor_history = [{"role": "system", "content": dp_copy}]
        patient_history = [{"role": "system", "content": pp_copy}]
        patient_reply = ""

        next_response_is_last = False
        doctor_responses = []
        print(f"Processing profile {i + 1}/{num_profiles}...")

        # Run interaction loop
        for step in range(steps):
            if step == 0:
                # The scripted greeting is a doctor turn, so the patient model
                # sees it as "user" and its answer is stored as "assistant",
                # matching the roles used for every later exchange below.
                greeting = "Hi, I'll be your doctor today. What brings you in?"
                patient_history.append({"role": "user", "content": greeting})
                patient_reply = call_openai(patient_history)
                patient_history.append({"role": "assistant", "content": patient_reply})

            metadata["interaction_duration"] += 1
            doctor_history.append({"role": "user", "content": patient_reply})

            # Doctor response
            doctor_response = call_openai_doctor(doctor_history)
            doctor_responses.append(doctor_response)

            # Update doctor conversation history
            doctor_history.append({"role": "assistant", "content": doctor_response.model_dump_json()})

            # Track the top-ranked diagnosis of this turn, if any was given
            if doctor_response.diagnosis_rankings:
                diagnosis = max(doctor_response.diagnosis_rankings, key=lambda x: x.confidence)
                metadata["diagnosis"] = diagnosis.diagnosis
                metadata["confidence_history"].append(diagnosis.confidence)

            if next_response_is_last:
                break
            elif metadata["confidence_history"] and metadata["confidence_history"][-1] >= threshold:
                # Threshold reached: allow exactly one more doctor turn.
                # NOTE(review): success is evaluated on the diagnosis held at
                # this point; the final turn may still overwrite
                # metadata["diagnosis"] without re-evaluating it - confirm intent.
                metadata["diagnosis_success"] = "melanoma" in metadata["diagnosis"].lower()
                next_response_is_last = True

            # Hand the doctor's message to the patient model exactly once; the
            # history already ends with it, so it must not be appended again
            # for the request.
            patient_history.append({"role": "user", "content": doctor_response.response_to_patient})
            patient_reply = call_openai(patient_history)
            patient_history.append({"role": "assistant", "content": patient_reply})

        # Ask the model how many of the patient's symptoms the doctor recovered
        recovered_symptoms_history = [d.known_symptoms for d in doctor_responses]
        symptoms = {**patient_data["revealed_symptoms"], **patient_data["hidden_symptoms"]}
        scr_copy = scr.format(recovered_symptoms_history=recovered_symptoms_history, symptoms=symptoms)
        symptom_check = [{"role": "system", "content": scp}, {"role": "user", "content": scr_copy}]
        symptom_check_response = call_openai_symptom_check(symptom_check)

        metadata["num_symptoms_recovered"] = symptom_check_response.found_symptoms
        metadata["num_symptoms_recovered_history"] = symptom_check_response.found_symptoms_history

        # Per-diagnosis confidence series across all doctor turns
        diagnosis_history = [d.diagnosis_rankings for d in doctor_responses]
        metadata["diagnosis_confidence_history"] = get_diagnosis_confidence(diagnosis_history)

        profile["interaction_metadata"] = metadata
        save.save_history(profile, patient_history, doctor_history)

    print(f"Batch {batch_id} completed.")

In [None]:
# Run parallel execution: one task per batch. Leaving the `with` block waits
# for every submitted task to finish.
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    batch_futures = {
        batch_id: executor.submit(process_profiles, batch, batch_id)
        for batch_id, batch in enumerate(profile_batches)
    }

# A Future stores any exception raised in its worker; unless it is inspected
# the failure is invisible. Report failed batches so they can be re-run.
failed_batches = {}
for batch_id, future in batch_futures.items():
    batch_error = future.exception()
    if batch_error is not None:
        failed_batches[batch_id] = batch_error
        print(f"Batch {batch_id} failed: {batch_error!r}")

print(f"{len(batch_futures) - len(failed_batches)}/{len(batch_futures)} batches completed successfully.")
