In [None]:
from collections import OrderedDict
import dotenv
import random
import numpy as np
import json
import pickle
import os
import enum
import numpy as np
import pandas as pd
import itertools
import openai
import schema
import anthropic
import instructor
from google import generativeai as gemini
from prompts.patient_prompt import prompt as pp
from prompts.doctor_prompt_structured import prompt as dp

@enum.unique
class Model(enum.Enum):
    """LLM providers the notebook can be pointed at via `model_type`."""

    # Values are spelled out; they equal what enum.auto() would assign (1, 2, 3).
    OPENAI = 1
    GEMINI = 2
    # NOTE: Super slow and throws "overloaded" errors
    CLAUDE = 3

# Change me to test another model
model_type = Model.CLAUDE

# Single source of truth for the Gemini model name (used both to build the
# SDK client and to label the run).
GEMINI_MODEL_NAME = 'gemini-2.0-flash-lite'

# Claude regularly answers 529 "Overloaded". The Anthropic SDK retries such
# 5xx responses with exponential backoff, but only twice by default, so give
# it more attempts before the error reaches the notebook.
CLAUDE_MAX_RETRIES = 5

# Load API keys from the local .env file.
env_file = '.env'
dotenv.load_dotenv(env_file, override=True)
gemini.configure(api_key=os.getenv("GEMINI_API_KEY"))

# Raw provider SDK clients.
# NOTE(review): all three clients are constructed even though only one is
# used per run, so a missing key for an unused provider may still fail here
# -- confirm this is acceptable.
openai_api = openai.OpenAI(api_key=os.getenv("CORRELL_API_KEY"))
gemini_api = gemini.GenerativeModel(model_name=GEMINI_MODEL_NAME)
claude_api = anthropic.Anthropic(
    api_key=os.getenv('CLAUDE_API_KEY'),
    max_retries=CLAUDE_MAX_RETRIES,
)

# instructor-wrapped clients that return validated `response_model` objects.
openai_client = instructor.from_openai(client=openai_api)
gemini_client = instructor.from_gemini(client=gemini_api, mode=instructor.Mode.GEMINI_JSON)
claude_client = instructor.from_anthropic(client=claude_api)

# Select the active client and the per-call keyword arguments that are
# forwarded to `client.completions.create(...)` by call_doctor / call_patient.
if model_type == Model.OPENAI:
    unstructured_client = openai_api
    client = openai_client
    model = 'gpt-4o-mini'
    call_doctor_kwargs = {
        'model': model,
        'temperature': 0.7,
    }
    call_patient_kwargs = {
        'model': model,
        'temperature': 0.7,
    }
elif model_type == Model.GEMINI:
    unstructured_client = gemini_api
    client = gemini_client
    model = GEMINI_MODEL_NAME
    # The model is already bound to the GenerativeModel instance, so no
    # per-call arguments are passed.
    call_doctor_kwargs = {}
    call_patient_kwargs = {}
elif model_type == Model.CLAUDE:
    unstructured_client = claude_api
    client = claude_client
    model = 'claude-3-5-haiku-20241022'
    call_doctor_kwargs = {
        'model': model,
        # 'temperature': 0.7,
        'max_tokens': 1024,
    }
    call_patient_kwargs = {
        'model': model,
        # 'temperature': 0.7,
        'max_tokens': 1024,
    }
else:
    # Fail here rather than with a NameError on `client` in a later cell if a
    # new Model member is added without a matching branch.
    raise ValueError(f'Unsupported model_type: {model_type!r}')

# Load the pre-generated patient profiles. The context manager closes the
# file handle deterministically instead of leaving it to garbage collection.
# NOTE: unpickling can execute arbitrary code -- only load profile files that
# were produced by a trusted source.
with open('patient_profiles.pkl', 'rb') as profiles_file:
    patient_profiles = pickle.load(profiles_file)

# Passed to the doctor prompt as "confidence_threshold".
threshold = 0.8
# Passed to the doctor prompt as "interaction_steps" and used as the maximum
# number of patient/doctor exchanges per profile.
steps = 8
num_profiles = 1 # how many interactions to run out of 240 profiles

# Keep only the first `num_profiles` entries, preserving their original order.
patient_profiles: dict[int, schema.Profile] = OrderedDict(itertools.islice(patient_profiles.items(), num_profiles))

In [5]:
def call_doctor(messages: list[schema.Message]) -> schema.DoctorResponse:
    """Request the doctor's next structured turn from the active model.

    Args:
        messages: The doctor-side conversation so far (system prompt first).

    Returns:
        The reply parsed into a `schema.DoctorResponse` by the module-level
        `client`.
    """
    # Provider-specific options (model name, token limits, ...) come from
    # the module-level `call_doctor_kwargs`.
    return client.completions.create(
        response_model=schema.DoctorResponse,
        messages=messages,
        **call_doctor_kwargs,
    )

def call_patient(messages: list[schema.Message]) -> str:
    """Request the simulated patient's next utterance from the active model.

    Args:
        messages: The patient-side conversation so far (system prompt first).

    Returns:
        The `response` field of the parsed `schema.PatientResponse`.
    """
    # Provider-specific options come from the module-level
    # `call_patient_kwargs`.
    patient_turn = client.completions.create(
        response_model=schema.PatientResponse,
        messages=messages,
        **call_patient_kwargs,
    )
    return patient_turn.response

# Conversation transcripts, keyed by patient profile id.
doctor_histories: dict[int, list[schema.Message]] = {}
patient_histories: dict[int, list[schema.Message]] = {}

# Iterate through patient profiles
for i, profile in patient_profiles.items():
    doctor_config = {
        "gender": profile["gender"],
        "ethnicity": profile["ethnicity"],
        "confidence_threshold": threshold,
        "interaction_steps": steps
    }

    # Initialize metadata (updated in place on the profile's own dict, so
    # re-running the cell resets these fields)
    metadata = profile['interaction_metadata']
    metadata.update({
        "diagnosis": None,
        "diagnosis_success": False,
        "interaction_duration": 0,
        "num_symptoms_recovered": 0,
        "confidence_history": [],
        'model': model,
    })

    # Format prompts
    pp_copy = pp.format(**profile)
    dp_copy = dp.format(**doctor_config)

    # Initialize conversation histories
    doctor_history: list[schema.Message] = [{"role": "system", "content": dp_copy}]
    patient_history: list[schema.Message] = [{"role": "system", "content": pp_copy}]

    # Register the histories before any API call. The lists are appended to
    # in place, so if a call raises mid-conversation (e.g. a provider
    # "overloaded" error) the partial transcript is still available under
    # this profile's id instead of being lost.
    doctor_histories[i] = doctor_history
    patient_histories[i] = patient_history

    # Placeholder "doctor" message that prompts the patient's first turn.
    doctor_reply = '<initial prompt>'

    # Run interaction loop: the patient speaks first, then the doctor answers.
    for step in range(steps):
        metadata["interaction_duration"] += 1

        # Update the patient history
        patient_history.append({"role": "user", "content": doctor_reply})

        # Patient response
        patient_reply = call_patient(patient_history)

        # Update patient conversation history
        patient_history.append({"role": "assistant", "content": patient_reply})

        # Update doctor conversation history
        doctor_history.append({"role": "user", "content": patient_reply})

        # Doctor response
        doctor_response = call_doctor(doctor_history)

        # Update doctor conversation history with its full reply (the whole
        # structured response as JSON, not only the text shown to the patient)
        doctor_history.append({"role": "assistant", "content": doctor_response.model_dump_json()})

        # Track the currently most confident diagnosis, if any were ranked.
        if doctor_response.diagnosis_rankings:
            diagnosis = max(doctor_response.diagnosis_rankings, key=lambda x: x.confidence)
            metadata['diagnosis'] = diagnosis.diagnosis
            metadata['confidence_history'].append(diagnosis.confidence)

        # TODO: Counting could use improvement
        # Counts at most one per step: whether any symptom the doctor listed
        # appears verbatim (case-sensitive substring) in the patient's reply.
        if any(symptom in patient_reply for symptom in doctor_response.symptoms_to_verify_or_refute):
            metadata['num_symptoms_recovered'] += 1

        # Stop as soon as the doctor reports having relayed a diagnosis.
        if doctor_response.diagnosis_relayed_to_patient:
            metadata["diagnosis_success"] = True
            break

        doctor_reply = doctor_response.response_to_patient

    # Store updated metadata
    profile['interaction_metadata'] = metadata

InstructorRetryException: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}

In [9]:
import save

# Ids of the profiles whose transcripts should be written to disk.
ids_to_save = {0}

for pid in ids_to_save:
    saved_path = save.save_history(
        patient_profiles[pid],
        patient_histories[pid],
        doctor_histories[pid],
    )
    print(f'Conversation saved to {saved_path}')

Conversation saved to conversations/0_6550.json
