In [1]:
from collections import OrderedDict
import dotenv
import random
import numpy as np
import json
import pickle
import os
import numpy as np
import pandas as pd
import itertools
import openai
import schema
# --- Environment / API client ---
env_file = '.env'
# override=True: values in .env replace any already-set environment variables.
dotenv.load_dotenv(env_file, override=True)
client = openai.OpenAI(api_key=os.getenv("CORRELL_API_KEY"))
from patient_prompt import prompt as pp
from doctor_prompt_structured import prompt as dp

# --- Data ---
# NOTE(security): pickle.load can execute arbitrary code; only load this file from a trusted source.
# `with` guarantees the file handle is closed (the bare open() previously leaked it).
with open('patient_profiles.pkl', 'rb') as profiles_file:
    all_patient_profiles = pickle.load(profiles_file)

# --- Simulation configuration ---
threshold = 0.8  # confidence threshold passed to the doctor prompt as `confidence_threshold`
steps = 8  # maximum patient/doctor exchanges per interaction (also passed to the doctor prompt)
num_profiles = 1 # how many interactions to run out of 240 profiles
# Keep only the first `num_profiles` entries, in the pickled mapping's iteration order.
patient_profiles: dict[int, schema.Profile] = OrderedDict(itertools.islice(all_patient_profiles.items(), num_profiles))

In [2]:
# Function to call OpenAI API
def call_openai_doctor(
    messages: list[schema.Message],
    model: str = "gpt-4o-mini",
    temperature: float = 0.4,
) -> schema.DoctorResponse:
    """Request the doctor's next turn and parse it into a structured response.

    Args:
        messages: Doctor-side chat history (system prompt followed by conversation turns).
        model: OpenAI chat model name.
        temperature: Sampling temperature.

    Returns:
        The completion parsed into ``schema.DoctorResponse``.

    Raises:
        RuntimeError: If the model returned no parsed structured output (e.g. a refusal).
    """
    response = client.beta.chat.completions.parse(
        model=model,
        messages=messages,
        temperature=temperature,
        response_format=schema.DoctorResponse,
    )
    message = response.choices[0].message
    # `parsed` is None when the model refuses; fail here with a clear message rather than
    # with an AttributeError later when the caller reads fields off None.
    if message.parsed is None:
        raise RuntimeError(
            f"Doctor model returned no structured response (refusal: {getattr(message, 'refusal', None)!r})"
        )
    return message.parsed

def call_openai(
    messages: list[schema.Message],
    model: str = "gpt-4o-mini",
    temperature: float = 0.7,
) -> str:
    """Request a free-text chat completion (used for the simulated patient).

    Args:
        messages: Chat history (system prompt followed by conversation turns).
        model: OpenAI chat model name.
        temperature: Sampling temperature.

    Returns:
        The text content of the first choice.

    Raises:
        RuntimeError: If the completion carries no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    # `content` is Optional in the SDK; a None here would otherwise be stored in the
    # histories and break substring checks on the reply downstream.
    if content is None:
        raise RuntimeError("Chat completion returned no text content")
    return content

# Per-profile conversation transcripts, keyed by the patient profile id.
doctor_histories: dict[int, list[schema.Message]] = {}
patient_histories: dict[int, list[schema.Message]] = {}

# Simulate one doctor/patient conversation per selected profile.
for i, profile in patient_profiles.items():
    # Values substituted into the doctor system prompt.
    doctor_config = {
        "gender": profile["gender"],
        "ethnicity": profile["ethnicity"],
        "confidence_threshold": threshold,
        "interaction_steps": steps
    }

    # Reset the result fields of this profile's metadata dict (mutated in place, so the
    # results are visible through `patient_profiles` to later cells).
    metadata = profile['interaction_metadata']
    metadata.update({
        "diagnosis": None,
        "diagnosis_success": False,
        "interaction_duration": 0,
        "num_symptoms_recovered": 0,
        "confidence_history": []
    })

    # Format prompts
    pp_copy = pp.format(**profile)
    dp_copy = dp.format(**doctor_config)

    # Each model sees itself as "assistant" and the other party as "user".
    doctor_history: list[schema.Message] = [{"role": "system", "content": dp_copy}]
    patient_history: list[schema.Message] = [{"role": "system", "content": pp_copy}]

    # The patient speaks first, so it is prompted with an empty doctor turn.
    doctor_reply = ''
    patient_history.append({"role": "user", "content": doctor_reply})

    for step in range(steps):
        metadata["interaction_duration"] += 1

        # Patient turn
        patient_reply = call_openai(patient_history)
        patient_history.append({"role": "assistant", "content": patient_reply})
        doctor_history.append({"role": "user", "content": patient_reply})

        # Doctor turn; store the full structured reply in the doctor's own history.
        doctor_response = call_openai_doctor(doctor_history)
        doctor_history.append({"role": "assistant", "content": doctor_response.model_dump_json()})

        # Track the highest-confidence candidate diagnosis for this turn.
        if doctor_response.diagnosis_rankings:
            diagnosis = max(doctor_response.diagnosis_rankings, key=lambda x: x.confidence)
            metadata['diagnosis'] = diagnosis.diagnosis
            metadata['confidence_history'].append(diagnosis.confidence)
        # TODO: Counting could use improvement — this adds at most 1 per turn, using a
        # case-sensitive substring match of the doctor's listed symptoms against the patient reply.
        metadata['num_symptoms_recovered'] += 1 if any(symptom in patient_reply for symptom in doctor_response.symptoms_to_verify_or_refute) else 0

        # Record the doctor's words on the patient side *before* a possible break; previously
        # the final doctor message (including a relayed diagnosis) never reached patient_history.
        doctor_reply = doctor_response.response_to_patient
        patient_history.append({"role": "user", "content": doctor_reply})

        if doctor_response.diagnosis_relayed_to_patient:
            metadata["diagnosis_success"] = True
            break

    # `metadata` is the same dict object as profile['interaction_metadata']; this
    # re-assignment is kept only to make the write-back explicit.
    profile['interaction_metadata'] = metadata

    # Store conversation histories
    doctor_histories[i] = doctor_history
    patient_histories[i] = patient_history

In [3]:
import save

# Persist every simulated conversation (previously hard-coded to profile 0 only,
# which skipped other profiles and raised KeyError if id 0 was not simulated).
for patient_id in patient_histories:
    filename = save.save_history(patient_profiles[patient_id], patient_histories[patient_id], doctor_histories[patient_id])
    print(f'Conversation saved to {filename}')

Conversation saved to conversations/0_1690.json
