# LLM-based Diabetes Risk Prediction

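This notebook compares several LLM prompting techniques for diabetes risk prediction: basic prompting, in-context learning, few-shot learning, chain-of-thought, tree-of-thought, and a combined approach. For each synthetic (Synthea) patient, the model predicts a risk category (Low / Medium / High). The prediction is then scored against the category computed from the patient data.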

In [None]:
import os
import sys
import pandas as pd
import json
import openai
from typing import Dict, List, Any
from dotenv import load_dotenv
from tqdm.notebook import tqdm

# Add the parent directory to the path so we can import our modules
sys.path.append('..')
from src.data_processing import load_patient_data, create_patient_dataframe, prepare_patient_text
from src.prompt_engineering import (
    basic_prompt, in_context_learning_prompt, few_shot_learning_prompt,
    chain_of_thought_prompt, tree_of_thought_prompt, combined_approach_prompt,
    create_examples
)
from src.evaluation import evaluate_response, aggregate_evaluations
from src.visualization import plot_all_metrics, create_comparison_table

# Load environment variables from a .env file if it exists (no error if it doesn't)
load_dotenv()

# Set OpenAI API key from the environment.
# os.environ.get returns None when the variable is missing; warn here rather
# than letting the first API call in the experiment fail with an auth error.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    print("Warning: OPENAI_API_KEY is not set; API calls below will fail. "
          "Add it to a .env file or export it in your shell.")

## 1. Load and Process Data

In [None]:
# Set data directory (relative to the notebook's working directory)
data_dir = "../data/synthea"

# Load patient data
print("Loading patient data...")
patients = load_patient_data(data_dir)
print(f"Loaded {len(patients)} patients")

# Convert to DataFrame
df = create_patient_dataframe(patients)

# Every later cell indexes into df (e.g. df.iloc[0]); fail fast with a clear message
if df.empty:
    raise ValueError(f"No patient records were loaded from {data_dir!r}.")

print("\nDataFrame created with shape:", df.shape)
print("\nRisk category distribution:")
print(df['diabetes_risk_category'].value_counts())

## 2. Explore Patient Data

In [None]:
# Display a sample patient: take the first DataFrame row (raises IndexError if
# df is empty), convert it to a plain dict, and print the text produced by
# prepare_patient_text -- the same helper later cells use to build prompt input.
sample_patient = df.iloc[0].to_dict()
print("Sample patient data:")
sample_patient_text = prepare_patient_text(sample_patient)
print(sample_patient_text)

## 3. Create Examples for In-Context Learning

In [None]:
# Create in-context examples from the patient DataFrame
examples = create_examples(df)
print(f"Created {len(examples)} examples")

# Display every example returned by create_examples.
# NOTE(review): presumably one per risk category -- confirm in src.prompt_engineering.
for example in examples:
    print(f"\nRisk Level: {example['risk_level']}")
    print(f"Patient:\n{example['patient_text']}")
    print(f"Reasoning: {example['reasoning']}")

## 4. Generate Prompts

In [None]:
# Select test patients (one from each risk category for demonstration)
test_patients = []
for risk_level in ['Low', 'Medium', 'High']:
    category_patients = df[df['diabetes_risk_category'] == risk_level]
    if not category_patients.empty:
        # Take the last patient in each category. The intent is to avoid the
        # patient(s) used as examples -- TODO confirm create_examples draws from
        # earlier rows, otherwise an example may leak into this preview.
        test_patients.append(category_patients.iloc[-1])

# Convert to DataFrame
test_df = pd.DataFrame(test_patients)
print(f"Selected {len(test_df)} test patients")

# Prompt builders to preview, keyed by the heading printed above each prompt.
# Builders that need demonstrations close over the `examples` created earlier.
prompt_previews = {
    "Basic Prompt": basic_prompt,
    "In-Context Learning Prompt": lambda p: in_context_learning_prompt(p, examples),
    "Few-Shot Learning Prompt": lambda p: few_shot_learning_prompt(p, examples),
    "Chain-of-Thought Prompt": chain_of_thought_prompt,
    "Tree-of-Thought Prompt": tree_of_thought_prompt,
    "Combined Approach Prompt": lambda p: combined_approach_prompt(p, examples),
}

# Generate and display each prompt type for the first test patient
if not test_df.empty:
    test_patient = test_df.iloc[0].to_dict()
    patient_text = prepare_patient_text(test_patient)

    for heading, build_prompt in prompt_previews.items():
        print(f"\n{heading}:")
        print(build_prompt(patient_text))

## 5. Call OpenAI API

In [None]:
def call_llm(prompt: str, model: str = "gpt-4", max_tokens: int = 1000,
             **request_kwargs: Any) -> str:
    """Call the LLM with a prompt and return the response text.

    Works with both the legacy ``openai<1.0`` API (``openai.ChatCompletion``)
    and ``openai>=1.0``, where ``ChatCompletion`` was removed and the
    module-level client (``openai.chat.completions``) is used instead. Both
    paths pick up the key assigned to ``openai.api_key`` in the setup cell.

    Args:
        prompt: User message sent to the model.
        model: Chat model name.
        max_tokens: Upper bound on tokens in the completion.
        **request_kwargs: Extra request parameters forwarded unchanged
            (e.g. ``temperature=0`` for more repeatable experiments).

    Returns:
        The assistant message content, or "" if the API returned no content.
    """
    messages = [
        {"role": "system", "content": "You are a medical assistant skilled in diabetes risk assessment."},
        {"role": "user", "content": prompt}
    ]
    if hasattr(openai, "OpenAI"):
        # openai>=1.0: the module-level client exposes chat.completions
        response = openai.chat.completions.create(
            model=model, messages=messages, max_tokens=max_tokens, **request_kwargs
        )
    else:
        # openai<1.0 legacy interface
        response = openai.ChatCompletion.create(
            model=model, messages=messages, max_tokens=max_tokens, **request_kwargs
        )
    # In openai>=1.0 message.content is Optional[str]; downstream evaluation expects a str
    return response.choices[0].message.content or ""

# Smoke-test the API with one trivial prompt before spending budget on the
# experiment. Any failure is reported rather than raised, so the notebook can
# continue to the (offline) cells that follow.
try:
    test_response = call_llm("Hello, can you help with diabetes risk assessment?")
except Exception as e:
    print(f"Error connecting to OpenAI API: {e}")
    print("Please check your API key in the .env file.")
else:
    print("API test successful. Response:")
    print(test_response)

## 6. Run Experiment

In [None]:
# Select a smaller number of test patients (adjust based on your budget)
# For a real experiment, use at least 10 patients per category
test_size = 3  # Use a small number for initial testing
test_patients = []
for risk_level in ['Low', 'Medium', 'High']:
    category_patients = df[df['diabetes_risk_category'] == risk_level]
    if not category_patients.empty:
        # Randomly sample patients from each category
        sample_size = min(test_size, len(category_patients))
        sampled = category_patients.sample(sample_size)
        test_patients.append(sampled)

# Combine into one DataFrame
test_df = pd.concat(test_patients)
print(f"Selected {len(test_df)} test patients for the experiment")
print(test_df['diabetes_risk_category'].value_counts())

# Define prompting methods
prompting_methods = {
    'Basic': lambda p: basic_prompt(p),
    'In-Context Learning': lambda p: in_context_learning_prompt(p, examples),
    'Few-Shot Learning': lambda p: few_shot_learning_prompt(p, examples),
    'Chain of Thought': lambda p: chain_of_thought_prompt(p),
    'Tree of Thought': lambda p: tree_of_thought_prompt(p),
    'Combined Approach': lambda p: combined_approach_prompt(p, examples)
}

# Run experiments
results = {}
responses = {}

for method_name, prompt_fn in prompting_methods.items():
    print(f"Testing {method_name} method...")
    method_evaluations = []
    method_responses = []
    
    for i, (_, patient) in enumerate(tqdm(test_df.iterrows(), total=len(test_df))):
        patient_text = prepare_patient_text(patient)
        prompt = prompt_fn(patient_text)
        
        # Call the API
        response = call_llm(prompt)
        
        # Evaluate the response
        evaluation = evaluate_response(response, patient['diabetes_risk_category'])
        method_evaluations.append(evaluation)
        
        # Store the response
        method_responses.append({
            'patient_id': patient['id'],
            'true_risk': patient['diabetes_risk_category'],
            'predicted_risk': evaluation['predicted_risk'],
            'prompt': prompt,
            'response': response
        })
    
    # Aggregate results for this method
    results[method_name] = aggregate_evaluations(method_evaluations)
    responses[method_name] = method_responses
    
    # Show interim results
    print(f"Accuracy: {results[method_name]['accuracy']:.2f}")
    print(f"Avg. Reasoning Depth: {results[method_name]['avg_reasoning_depth']:.2f}")
    print(f"Avg. Factors Mentioned: {results[method_name]['avg_factors_mentioned']:.2f}")
    print()

## 7. Save Results

In [None]:
# All outputs of this notebook go under one results directory
RESULTS_DIR = "../results"

# Create results directory if it doesn't exist
os.makedirs(RESULTS_DIR, exist_ok=True)

# Save raw responses.
# default=str is a fallback for values json cannot encode natively (e.g. numpy
# scalars from DataFrame rows), so one odd value doesn't abort the save.
with open(os.path.join(RESULTS_DIR, "responses.json"), "w", encoding="utf-8") as f:
    json.dump(responses, f, indent=2, default=str)

# Save aggregated metrics
metrics_df = create_comparison_table(results)
metrics_df.to_csv(os.path.join(RESULTS_DIR, "metrics.csv"))
print(f"Results saved to {RESULTS_DIR}/")

# Display the comparison table (display() is provided by IPython/Jupyter)
print("\nComparison of Methods:")
display(metrics_df)

## 8. Visualize Results

In [None]:
# Generate all plots
plot_all_metrics(results, save_dir="../results")
print("Visualizations saved to ../results/")

# Display the first stored response from each method (skipping methods with none)
for method_name, method_responses in responses.items():
    if not method_responses:
        continue
    sample = method_responses[0]
    print(f"\n=== Sample Response from {method_name} Method ===")
    print(f"Patient ID: {sample['patient_id']}")
    print(f"True Risk: {sample['true_risk']}")
    print(f"Predicted Risk: {sample['predicted_risk']}")
    print(f"Response:\n{sample['response']}")
    print("=" * 50)

## 9. Analysis and Conclusions

In [None]:
# Every comparison below takes max() over the methods, which raises on empty input
if not results:
    raise ValueError("No results to analyse - run the 'Run Experiment' cell first.")

# Compare accuracies
methods = list(results.keys())
best_method = max(methods, key=lambda m: results[m]['accuracy'])
best_accuracy = results[best_method]['accuracy']

print(f"Best performing method: {best_method} with accuracy {best_accuracy:.2f}")

# Compare reasoning depth
best_reasoning = max(methods, key=lambda m: results[m]['avg_reasoning_depth'])
best_depth = results[best_reasoning]['avg_reasoning_depth']

print(f"Method with deepest reasoning: {best_reasoning} with average depth {best_depth:.2f}")

# Compare factors mentioned
best_factors = max(methods, key=lambda m: results[m]['avg_factors_mentioned'])
most_factors = results[best_factors]['avg_factors_mentioned']

print(f"Method mentioning most factors: {best_factors} with average {most_factors:.2f} factors")

# Overall conclusions.
# NOTE: the numbered takeaways below are fixed text printed regardless of the
# metrics above, so explicitly flag when this run disagrees with them.
print("\nConclusions:")
if best_method != 'Combined Approach':
    n_evaluated = len(responses.get(best_method, []))
    print(f"(Note: in this run '{best_method}' had the highest accuracy, not the Combined Approach;")
    print(f" with {n_evaluated} test patients per method, differences may not be meaningful.)")
print("1. The Combined Approach method that leverages both in-context learning and ")
print("   chain-of-thought reasoning tends to perform best for diabetes risk prediction.")
print("2. Methods that encourage structured reasoning (Chain-of-Thought, Tree-of-Thought)")
print("   generally provide more comprehensive analyses with deeper reasoning.")
print("3. In-context learning with examples helps the model understand the task better,")
print("   especially for edge cases or patients with complex profiles.")
print("\nFuture improvements:")
print("1. Use larger and more diverse datasets with more balanced risk categories")
print("2. Fine-tune models specifically for medical risk assessment")
print("3. Incorporate medical guidelines more explicitly in prompts")
print("4. Explore hybrid approaches that combine LLM assessments with traditional")
print("   risk calculators for diabetes")