# Large Language Models for Chronic Condition Classification
## Zero-shot Learning for Clinical Phenotyping

This notebook documents the methodology used in the paper:

**Zero-shot learning for clinical phenotyping: Comparing LLMs and rule-based methods**  
*Computers in Biology and Medicine*  
https://doi.org/10.1016/j.compbiomed.2025.110181

---

### Experiment Overview

**Objective**: Compare the performance of different approaches in classifying 20 chronic conditions from patient clinical notes using zero-shot learning.

**Methods Evaluated**:
1. **Rule-based dictionary approach** (baseline)
2. **Large Language Models** via API calls

**Dataset**: 1,000 patients with aggregated clinical notes
- 57.5% female, 42.5% male
- Mean age: 51.3 years

**Chronic Conditions** (Top 20 by prevalence):
1. Hyperlipidemia
2. Autoimmune diseases
3. Hypertension
4. Benign prostate hypertrophy
5. Cancer
6. Heart Failure
7. Arrhythmias
8. Anxiety disorder
9. Disorders of thyroid gland
10. Ischemic heart disease
11. Osteoporosis
12. Disorder of vertebral column
13. Depression
14. Urinary incontinence
15. Osteoarthritis
16. Diabetes
17. Asthma
18. Anemia
19. Peripheral neuropathy
20. Epilepsy

---
## Part 1: Data Preparation

### 1.1 Ground Truth Generation

Ground truth labels were created using multiple sources:
- **Diagnosis codes** (ICD-10)
- **Medication prescriptions** (drug-condition mappings)
- **Laboratory results** (condition-specific tests)
- **Procedures** (condition-related procedures)

Each source yields a dictionary that maps a chronic condition to the IDs of patients with evidence of that condition.
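
The combined ground truth can be read as a per-condition union of the source-specific dictionaries. The following is a minimal sketch of that idea, assuming each dictionary maps a condition name to an iterable of integer patient IDs; the helper name `union_ground_truth` and the toy IDs are illustrative and not taken from the paper.

```python
def union_ground_truth(*source_dicts):
    """Merge condition -> patient-ID mappings by taking the union per condition."""
    merged = {}
    for source in source_dicts:
        for condition, patient_ids in source.items():
            # Create the set on first sight of a condition, then add this source's IDs
            merged.setdefault(condition, set()).update(patient_ids)
    return merged

# Toy inputs with hypothetical patient IDs
by_diagnosis = {"Diabetes": [1, 2], "Asthma": [3]}
by_medication = {"Diabetes": [2, 4]}

combined = union_ground_truth(by_diagnosis, by_medication)
print(sorted(combined["Diabetes"]))  # [1, 2, 4]
```

Storing sets rather than lists also keeps later membership checks fast when labels are compared against predictions.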

In [None]:
# Example: Loading ground truth dictionaries
import pickle

# Load dictionaries mapping conditions to patient IDs
# These were created from EHR data using rule-based classification

# patients_by_condition.pkl: Based on diagnosis codes only
with open('data/patients_by_condition.pkl', 'rb') as f:
    patients_by_condition = pickle.load(f)

# patients_by_drug_condition.pkl: Based on medication prescriptions
with open('data/patients_by_drug_condition.pkl', 'rb') as f:
    patients_by_drug_condition = pickle.load(f)

# patients_by_lab.pkl: Based on laboratory results
with open('data/patients_by_lab.pkl', 'rb') as f:
    patients_by_lab = pickle.load(f)

# patients_by_procedure.pkl: Based on procedures
with open('data/patients_by_procedure.pkl', 'rb') as f:
    patients_by_procedure = pickle.load(f)

# patients_any.pkl: Union of all sources (comprehensive ground truth)
with open('data/patients_any.pkl', 'rb') as f:
    patients_any = pickle.load(f)

print(f"Total chronic conditions tracked: {len(patients_any)}")
print(f"Available conditions: {list(patients_any.keys())}")

### 1.2 Patient Clinical Note Aggregation

For each patient, clinical data was aggregated into a structured JSON format containing:
- **Demographics**: Age, gender
- **Events by age**: Chronologically ordered clinical events
- **Diagnoses**: All recorded diagnosis codes
- **Medications**: Prescribed medications
- **Laboratory results**: Key lab measurements
- **Procedures**: Medical procedures performed

This structured data serves as the input for LLM evaluation.
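
As an illustration of the aggregation step, the sketch below groups flat event records into the `events_by_age` layout. The flat record shape (`age`, `kind`, `value`) and the helper name `aggregate_events` are assumptions made for this example; the actual EHR extraction is not described in this notebook.

```python
from collections import defaultdict

def aggregate_events(sex, records):
    """Group flat (age, kind, value) records into the events_by_age layout."""
    kinds = ("diagnoses", "medications", "lab_results", "procedures")
    # Each age gets its own dict with one fresh list per event kind
    by_age = defaultdict(lambda: {kind: [] for kind in kinds})
    for record in records:
        by_age[record["age"]][record["kind"]].append(record["value"])
    # Sort ages so events appear in chronological order
    events = [{"age": age, **by_age[age]} for age in sorted(by_age)]
    return {"sex": sex, "events_by_age": events}

# Hypothetical flat records, deliberately out of order
records = [
    {"age": 46, "kind": "medications", "value": "Lisinopril 10mg"},
    {"age": 45, "kind": "diagnoses",
     "value": "E11.9 - Type 2 diabetes mellitus without complications"},
    {"age": 45, "kind": "medications", "value": "Metformin 500mg"},
]
patient = aggregate_events("female", records)
```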

In [None]:
import json

# Example of patient data structure
example_patient = {
    "sex": "female",
    "events_by_age": [
        {
            "age": 45,
            "diagnoses": ["E11.9 - Type 2 diabetes mellitus without complications"],
            "medications": ["Metformin 500mg"],
            "lab_results": ["HbA1c: 7.2%"],
            "procedures": []
        },
        {
            "age": 46,
            "diagnoses": ["E11.9 - Type 2 diabetes mellitus without complications"],
            "medications": ["Metformin 500mg", "Lisinopril 10mg"],
            "lab_results": ["HbA1c: 6.8%", "Blood pressure: 145/92"],
            "procedures": []
        }
    ]
}

# In practice, load from your data file
# with open('data/patient_data.json', 'r') as f:
#     patient_data = json.load(f)

print("Example patient data structure:")
print(json.dumps(example_patient, indent=2))

---
## Part 2: LLM Evaluation Framework

### 2.1 Prompt Engineering

The prompt instructs the LLM to:
1. Act as an experienced doctor
2. Review the patient's clinical note
3. Assess presence of specific chronic conditions
4. Consider:
   - Direct mentions in diagnoses
   - Related medical terms and subtypes
   - Indirect evidence (medications, labs, procedures)
5. Return structured JSON output with:
   - Comorbidity name
   - Rationale for decision
   - Binary classification (true/false)
   - Confidence level (low/medium/high)

In [None]:
def create_prompt(patient_summary, chronic_condition):
    """
    Create a structured prompt for LLM chronic condition assessment.
    
    Args:
        patient_summary: JSON string of patient clinical data
        chronic_condition: Name of condition to assess
    
    Returns:
        Formatted prompt string
    """
    prompt = f"""
# Task
You are an experienced doctor tasked with determining if the patient has the chronic condition listed below based on their clinical note.

# Patient
Below is a clinical note describing the patient's aggregated health information:
{patient_summary}

# Chronic Condition
The chronic condition being assessed is:
{chronic_condition}

# Assessment Instructions
Evaluate only the chronic condition listed above. Do not consider or mention any other conditions.

Use the patient's clinical note to determine whether the patient has that chronic condition. Provide a detailed explanation.

First, consider any related medical terms, subtypes, or diagnoses related to the chronic condition. If any related terms or subtypes are found, confirm the presence of the condition.
Then, look at the clinical note sections such as procedures, measurements, and medications to infer if there could be a chronic condition based on the procedures, laboratory results, and medications.
Remember to consider subtypes of the chronic conditions when making your assessment.

Provide your response as a JSON dictionary with the following elements:
* comorbidity: str - The name of the comorbidity being assessed
* rationale: str - Your reasoning for the assessment
* is_met: bool - "true" if the patient has the comorbidity, otherwise "false"
* confidence: str - Your confidence level ("low", "medium", "high")

An example of how your JSON response should be formatted is shown below:
```json
{{
    "comorbidity": "{chronic_condition}",
    "rationale": "Reason for assessment",
    "is_met": true/false,
    "confidence": "low/medium/high"
}}
```
"""
    return prompt

# Example usage
example_condition = "Diabetes"
example_prompt = create_prompt(json.dumps(example_patient, indent=2), example_condition)
print("Example prompt (truncated):")
print(example_prompt[:500] + "...")

### 2.2 Generic LLM API Client

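A generic client class for chat-style LLM API endpoints. As written, it sends an OpenAI-compatible chat-completions request with bearer-token authentication; other providers (e.g., Azure OpenAI, Anthropic, or local model servers) may require different headers, payload fields, or response parsing.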
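
Long evaluation runs tend to hit transient API failures. One possible mitigation, not part of the original pipeline, is to retry with exponential backoff. The sketch below assumes a callable that returns `None` on failure (as `get_llm_response` does); the helper name `call_with_retries` and its parameters are hypothetical.

```python
import time

def call_with_retries(call, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call `call()` until it returns a non-None value or attempts run out."""
    for attempt in range(max_attempts):
        result = call()
        if result is not None:
            return result
        if attempt < max_attempts - 1:
            # Wait base_delay, then 2x, then 4x, ... between attempts
            sleep(base_delay * 2 ** attempt)
    return None

# Demo with a stub that fails twice, then succeeds (no network involved)
attempts = {"count": 0}

def flaky_call():
    attempts["count"] += 1
    return "ok" if attempts["count"] >= 3 else None

outcome = call_with_retries(flaky_call, sleep=lambda seconds: None)
print(outcome, attempts["count"])  # ok 3
```

With a real client, the callable could be something like `lambda: client.get_llm_response(model_name, prompt)`.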

In [None]:
import os
import logging
import requests
from dotenv import load_dotenv

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class GenericLLMClient:
    """
    Generic client for LLM API interactions.
    Supports any API that accepts chat-style requests.
    """
    
    def __init__(self, env_path='.env'):
        """
        Initialize LLM client.
        
        Args:
            env_path: Path to .env file containing API credentials
        """
        load_dotenv(dotenv_path=env_path)
        
        # Generic configuration - adapt based on your API
        self.api_key = os.getenv("LLM_API_KEY")
        self.api_endpoint = os.getenv("LLM_API_ENDPOINT")
        self.api_version = os.getenv("LLM_API_VERSION", "")
        
        self.validate_env_vars()
    
    def validate_env_vars(self):
        """Validate required environment variables."""
        if not all([self.api_key, self.api_endpoint]):
            raise ValueError("Missing required API credentials in .env file")
    
    def get_llm_response(self, model_name, prompt_text, temperature=0.0):
        """
        Get LLM response for a given prompt.
        
        This is a generic implementation. Adapt the request format
        based on your specific API requirements.
        
        Args:
            model_name: Model identifier
            prompt_text: The prompt to send
            temperature: Sampling temperature (0.0 for deterministic)
        
        Returns:
            Response text from the model
        """
        try:
            # Generic request structure - adapt for your API
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            
            payload = {
                "model": model_name,
                "messages": [
                    {
                        "role": "system",
                        "content": "You are an experienced doctor evaluating a patient's medical record."
                    },
                    {
                        "role": "user",
                        "content": prompt_text
                    }
                ],
                "temperature": temperature
            }
            
            response = requests.post(
                self.api_endpoint,
                headers=headers,
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            
            # Parse response - adapt based on API response format
            result = response.json()
            return result['choices'][0]['message']['content']
            
        except Exception as e:
            logging.error(f"Error in getting LLM response: {e}")
            return None

print("Generic LLM client class defined")

### 2.3 Response Processing

LLM responses need to be parsed and validated to extract structured predictions.
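
Models do not always return bare JSON: the object may be wrapped in explanatory text, and `is_met` may arrive as the string `"false"`, which Python treats as truthy. The sketch below shows one defensive way to handle both cases. The helper names `extract_json_object` and `normalize_is_met` are hypothetical and are not part of the paper's pipeline.

```python
import json
import re

def extract_json_object(text):
    """Parse the span from the first '{' to the last '}' as JSON, or return None."""
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None

def normalize_is_met(value):
    """Map booleans and the strings 'true'/'false' to a bool; otherwise None."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str) and value.strip().lower() in ("true", "false"):
        return value.strip().lower() == "true"
    return None

# Hypothetical response with prose around the JSON and a string-valued is_met
raw = ('Assessment follows. {"comorbidity": "Diabetes", "rationale": "E11.9 recorded", '
       '"is_met": "true", "confidence": "high"} Let me know if more detail is needed.')
parsed = extract_json_object(raw)
is_met = normalize_is_met(parsed["is_met"])
```

The greedy match is a simplification: it returns `None` if the trailing prose itself contains a closing brace.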

In [None]:
def clean_json_response(response):
    """
    Clean JSON response from LLM output.
    
    Removes markdown code fences and other formatting.
    
    Args:
        response: Raw response string from LLM
    
    Returns:
        Cleaned JSON string
    """
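    # Strip surrounding whitespace first so the fence checks below can match
    response = response.strip()
    # Remove markdown code fences (with or without a "json" language tag)
    if response.startswith('```json'):
        response = response[7:]
    elif response.startswith('```'):
        response = response[3:]
    if response.endswith('```'):
        response = response[:-3]
    return response.strip()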

def evaluate_condition(client, model_name, patient_id, patient_summary, chronic_condition):
    """
    Evaluate a single chronic condition for a patient.
    
    Args:
        client: LLM client instance (e.g., GenericLLMClient)
        model_name: Model identifier
        patient_id: Patient identifier
        patient_summary: JSON string of patient data
        chronic_condition: Condition name to evaluate
    
    Returns:
        Dictionary with assessment results or None if error
    """
    # Create the prompt
    prompt = create_prompt(patient_summary, chronic_condition)
    logging.info(f"Evaluating patient {patient_id} for {chronic_condition}")
    
    # Get the response from the model
    response = client.get_llm_response(model_name, prompt)
    
    # Parse and validate the response
    if response:
        try:
            cleaned_response = clean_json_response(response)
            assessment = json.loads(cleaned_response)
            
            # Validate response structure
            required_fields = ['comorbidity', 'rationale', 'is_met', 'confidence']
            if not all(field in assessment for field in required_fields):
                raise ValueError("Missing required fields in response")
            
            # Validate condition matches
            if assessment["comorbidity"] != chronic_condition:
                raise ValueError("Response comorbidity does not match requested condition")
            
            return assessment
        except (json.JSONDecodeError, ValueError) as e:
            logging.error(f"Error parsing response for patient {patient_id}, "
                         f"condition {chronic_condition}: {e}")
            return None
    else:
        return None

print("Response processing functions defined")

---
## Part 3: Running the Experiment

### 3.1 Main Evaluation Loop

The evaluation process:
1. Loads patient data
2. For each patient:
   - For each chronic condition:
     - Generate prompt
     - Get LLM prediction
     - Parse and store result
3. Save results incrementally (to handle interruptions)
4. Resume from last checkpoint if interrupted
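
Because results are rewritten after every condition, an interruption in the middle of a write could leave a truncated JSON file that cannot be reloaded on resume. One way to reduce that risk is to write to a temporary file and rename it into place. This is a sketch under that assumption, with a hypothetical helper name `save_results_atomic`; it is not the notebook's own save routine.

```python
import json
import os
import tempfile

def save_results_atomic(results, output_file):
    """Write JSON to a temp file in the target directory, then rename it into place."""
    directory = os.path.dirname(os.path.abspath(output_file))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(results, f, indent=4)
        # os.replace overwrites the destination; on POSIX the rename is atomic
        os.replace(tmp_path, output_file)
    except BaseException:
        # Clean up the temp file if anything went wrong before the rename
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Demo in a throwaway directory
with tempfile.TemporaryDirectory() as workdir:
    path = os.path.join(workdir, "results", "model_results.json")
    save_results_atomic({"1": {"assessments": []}}, path)
    with open(path) as f:
        reloaded = json.load(f)
```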

In [None]:
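def save_results(results, output_file):
    """Save evaluation results to a JSON file, creating the output directory if needed."""
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=4)
    logging.info(f"Results saved to {output_file}")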

def run_evaluation(client, model_name, patient_data_file, chronic_conditions_list, 
                   output_file, evaluation_type='single'):
    """
    Run chronic condition evaluation for all patients.
    
    Args:
        client: LLM client instance
        model_name: Model identifier
        patient_data_file: Path to patient data JSON
        chronic_conditions_list: List of conditions to evaluate
        output_file: Path to save results
        evaluation_type: 'single' (one condition at a time) or 'batch' (all conditions).
            Currently unused: the loop below always evaluates one condition at a time.
    """
    results = {}
    
    # Load existing results if available (for resuming)
    if os.path.exists(output_file):
        with open(output_file, 'r') as f:
            results = json.load(f)
        logging.info(f"Loaded existing results from {output_file}")
    
    # Load patient data
    with open(patient_data_file, 'r') as f:
        patient_data = json.load(f)
    
    # Evaluate each patient
    for patient_id, patient_info in patient_data.items():
        patient_summary = json.dumps(patient_info)
        logging.info(f"Evaluating patient ID: {patient_id}")
        
        # Initialize patient results if not exists
        if patient_id not in results:
            results[patient_id] = {
                "patient_summary": patient_summary,
                "assessments": []
            }
        
        # Determine which conditions still need evaluation
        existing_conditions = {
            assessment['comorbidity'] 
            for assessment in results[patient_id]['assessments']
        }
        remaining_conditions = [
            condition 
            for condition in chronic_conditions_list 
            if condition not in existing_conditions
        ]
        
        # Skip if patient fully evaluated
        if not remaining_conditions:
            logging.info(f"Patient {patient_id} already fully evaluated")
            continue
        
        # Evaluate remaining conditions
        for condition in remaining_conditions:
            assessment = evaluate_condition(
                client, model_name, patient_id, patient_summary, condition
            )
            
            if assessment:
                results[patient_id]['assessments'].append(assessment)
                # Save progress after each condition
                save_results(results, output_file)
    
    logging.info(f"Evaluation complete. Results saved to {output_file}")
    return results

print("Evaluation loop defined")

### 3.2 Example: Running Evaluation

Example of running the evaluation with an LLM API:

In [None]:
# Configuration
CHRONIC_CONDITIONS = [
    'Hyperlipidemia', 'Autoimmune diseases', 'Hypertension', 
    'Benign prostate hypertrophy', 'Cancer', 'Heart Failure', 
    'Arrhythmias', 'Anxiety disorder', 'Disorders of thyroid gland', 
    'Ischemic heart disease', 'Osteoporosis', 'Disorder of vertebral column', 
    'Depression', 'Urinary incontinence', 'Osteoarthritis', 'Diabetes', 
    'Asthma', 'Anemia', 'Peripheral neuropathy', 'Epilepsy'
]

# Example: Run evaluation (requires actual data and API credentials)
"""
# Initialize LLM client
client = GenericLLMClient(env_path='.env')

# Run evaluation
results = run_evaluation(
    client=client,
    model_name='your-model-name',  # Specify your model
    patient_data_file='data/patient_data.json',
    chronic_conditions_list=CHRONIC_CONDITIONS,
    output_file='results/model_results.json',
    evaluation_type='single'
)
"""

print("Example evaluation code provided (commented out)")

---
## Part 4: Evaluation Metrics

### 4.1 Computing Performance Metrics

We compare LLM predictions against ground truth labels using standard classification metrics:
- **Accuracy**: Overall correctness
- **Precision**: Of predicted positives, how many are correct
- **Recall (Sensitivity)**: Of actual positives, how many are detected
- **F1 Score**: Harmonic mean of precision and recall
- **Specificity**: Of actual negatives, how many are correctly identified
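
To make the definitions concrete, the sketch below computes each metric directly from confusion-matrix counts. The counts are invented for illustration and are not results from the paper; `metrics_from_counts` is a hypothetical helper.

```python
def metrics_from_counts(tp, fp, tn, fn):
    """Compute standard binary classification metrics from confusion counts."""
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    # F1 is the harmonic mean of precision and recall
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    accuracy = (tp + tn) / total if total else 0.0
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "specificity": specificity,
    }

# Invented counts for 100 patients: 45 true cases, 55 non-cases
example = metrics_from_counts(tp=40, fp=10, tn=45, fn=5)
for name, value in example.items():
    print(f"{name}: {value:.3f}")
```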

In [None]:
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, 
    f1_score, confusion_matrix
)

def calculate_metrics(ground_truth, predictions):
    """
    Calculate classification metrics.
    
    Args:
        ground_truth: List of true labels (0/1)
        predictions: List of predicted labels (0/1)
    
    Returns:
        Dictionary of metrics
    """
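    # Calculate confusion matrix; fixing the labels keeps the 2x2 shape even
    # when only one class is present in the inputs
    tn, fp, fn, tp = confusion_matrix(ground_truth, predictions, labels=[0, 1]).ravel()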
    
    # Calculate metrics
    metrics = {
        'accuracy': accuracy_score(ground_truth, predictions),
        'precision': precision_score(ground_truth, predictions, zero_division=0),
        'recall': recall_score(ground_truth, predictions, zero_division=0),
        'f1_score': f1_score(ground_truth, predictions, zero_division=0),
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'true_positives': tp,
        'false_positives': fp,
        'true_negatives': tn,
        'false_negatives': fn
    }
    
    return metrics

def evaluate_model_performance(results_file, ground_truth_dict, conditions_list):
    """
    Evaluate model performance against ground truth.
    
    Args:
        results_file: Path to model results JSON
        ground_truth_dict: Dictionary mapping conditions to patient IDs
        conditions_list: List of conditions to evaluate
    
    Returns:
        DataFrame with metrics per condition
    """
    # Load results
    with open(results_file, 'r') as f:
        results = json.load(f)
    
    # Calculate metrics for each condition
    condition_metrics = []
    
    for condition in conditions_list:
        ground_truth = []
        predictions = []
        
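        # Build the ground truth set once per condition for fast membership checks
        condition_patients = set(ground_truth_dict.get(condition, []))
        
        # Collect ground truth and predictions for this condition
        for patient_id, patient_results in results.items():
            # Ground truth: is patient in ground truth set?
            gt_label = 1 if int(patient_id) in condition_patients else 0
            
            # Find prediction for this condition; a missing assessment
            # (e.g., an unparsed response) is counted as a negative prediction
            pred_label = 0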
            for assessment in patient_results['assessments']:
                if assessment['comorbidity'] == condition:
                    pred_label = 1 if assessment['is_met'] else 0
                    break
            
            ground_truth.append(gt_label)
            predictions.append(pred_label)
        
        # Calculate metrics
        metrics = calculate_metrics(ground_truth, predictions)
        metrics['condition'] = condition
        metrics['prevalence'] = sum(ground_truth) / len(ground_truth)
        condition_metrics.append(metrics)
    
    # Create DataFrame
    df_metrics = pd.DataFrame(condition_metrics)
    
    # Reorder columns
    column_order = [
        'condition', 'prevalence', 'accuracy', 'precision', 
        'recall', 'f1_score', 'specificity',
        'true_positives', 'false_positives', 
        'true_negatives', 'false_negatives'
    ]
    df_metrics = df_metrics[column_order]
    
    return df_metrics

print("Metrics calculation functions defined")

### 4.2 Example: Computing Metrics

Example of computing metrics for a model:

In [None]:
# Example: Compute metrics (requires actual results)
"""
# Load ground truth
with open('data/patients_any.pkl', 'rb') as f:
    ground_truth = pickle.load(f)

# Evaluate model performance
model_metrics = evaluate_model_performance(
    results_file='results/model_results.json',
    ground_truth_dict=ground_truth,
    conditions_list=CHRONIC_CONDITIONS
)

# Display results
print("Model Performance Metrics:")
print(model_metrics.round(3))

# Calculate overall metrics
overall_metrics = {
    'Mean Accuracy': model_metrics['accuracy'].mean(),
    'Mean Precision': model_metrics['precision'].mean(),
    'Mean Recall': model_metrics['recall'].mean(),
    'Mean F1 Score': model_metrics['f1_score'].mean(),
    'Mean Specificity': model_metrics['specificity'].mean()
}

print("\nOverall Performance:")
for metric, value in overall_metrics.items():
    print(f"{metric}: {value:.3f}")
"""

print("Example metrics computation provided (commented out)")

### 4.3 Comparing Multiple Models

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def compare_models(results_files, model_names, ground_truth_dict, conditions_list):
    """
    Compare performance across multiple models.
    
    Args:
        results_files: List of paths to model results
        model_names: List of model names
        ground_truth_dict: Ground truth dictionary
        conditions_list: List of conditions
    
    Returns:
        DataFrame with comparison
    """
    all_metrics = []
    
    for results_file, model_name in zip(results_files, model_names):
        metrics = evaluate_model_performance(
            results_file, ground_truth_dict, conditions_list
        )
        metrics['model'] = model_name
        all_metrics.append(metrics)
    
    # Combine all metrics
    df_comparison = pd.concat(all_metrics, ignore_index=True)
    
    return df_comparison

def plot_model_comparison(df_comparison, metric='f1_score'):
    """
    Create visualization comparing models.
    
    Args:
        df_comparison: DataFrame from compare_models
        metric: Which metric to plot
    """
    plt.figure(figsize=(14, 6))
    
    # Boxplot comparing models
    sns.boxplot(data=df_comparison, x='model', y=metric)
    plt.title(f'Model Comparison: {metric.replace("_", " ").title()}')
    plt.ylabel(metric.replace("_", " ").title())
    plt.xlabel('Model')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print(f"\n{metric.upper()} Summary by Model:")
    summary = df_comparison.groupby('model')[metric].agg(['mean', 'std', 'min', 'max'])
    print(summary.round(3))

print("Model comparison functions defined")

---
## Part 5: Analysis and Insights

### 5.1 Condition-Specific Performance

Analyze which conditions are easier or harder to detect:

In [None]:
def analyze_condition_difficulty(df_metrics, top_n=10):
    """
    Identify easiest and hardest conditions to detect.
    
    Args:
        df_metrics: DataFrame with condition metrics
        top_n: Number of conditions to show
    """
    # Sort by F1 score
    df_sorted = df_metrics.sort_values('f1_score', ascending=False)
    
    print(f"Top {top_n} Easiest Conditions to Detect:")
    print(df_sorted[['condition', 'f1_score', 'precision', 'recall']].head(top_n))
    
    print(f"\nTop {top_n} Hardest Conditions to Detect:")
    print(df_sorted[['condition', 'f1_score', 'precision', 'recall']].tail(top_n))
    
    # Visualize
    plt.figure(figsize=(12, 8))
    
    # Plot F1 scores
    plt.barh(range(len(df_sorted)), df_sorted['f1_score'])
    plt.yticks(range(len(df_sorted)), df_sorted['condition'])
    plt.xlabel('F1 Score')
    plt.title('F1 Score by Chronic Condition')
    plt.tight_layout()
    plt.show()

print("Condition difficulty analysis function defined")

### 5.2 Error Analysis

Examine false positives and false negatives for a single condition, together with the model's rationale and stated confidence.
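
Once error cases have been collected, a simple first step is to tally them by the confidence the model reported, which can hint at whether errors are confidently wrong or hedged. The sketch below assumes records in the shape returned by `analyze_errors` (defined in the next cell); the helper name `count_errors_by_confidence` and the example records are hypothetical.

```python
from collections import Counter

def count_errors_by_confidence(errors):
    """Tally error records by the confidence level the model reported."""
    # Empty or missing confidence values are grouped under "unknown"
    return Counter(error.get("confidence") or "unknown" for error in errors)

# Hypothetical records in the shape produced by analyze_errors
example_errors = [
    {"patient_id": "101", "error_type": "False Positive",
     "rationale": "Statin prescribed", "confidence": "high"},
    {"patient_id": "102", "error_type": "False Positive",
     "rationale": "Borderline LDL", "confidence": "low"},
    {"patient_id": "103", "error_type": "False Positive",
     "rationale": "Statin prescribed", "confidence": "high"},
    {"patient_id": "104", "error_type": "False Positive",
     "rationale": "", "confidence": ""},
]
by_confidence = count_errors_by_confidence(example_errors)
print(by_confidence.most_common())
```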

In [None]:
def analyze_errors(results_file, ground_truth_dict, condition, error_type='fp'):
    """
    Analyze false positives or false negatives for a condition.
    
    Args:
        results_file: Path to model results
        ground_truth_dict: Ground truth dictionary
        condition: Condition to analyze
        error_type: 'fp' for false positives, 'fn' for false negatives
    
    Returns:
        List of error cases with details
    """
    # Load results
    with open(results_file, 'r') as f:
        results = json.load(f)
    
    errors = []
    ground_truth_set = set(ground_truth_dict.get(condition, []))
    
    for patient_id, patient_results in results.items():
        patient_id_int = int(patient_id)
        has_condition = patient_id_int in ground_truth_set
        
        # Find prediction
        predicted = False
        rationale = ""
        confidence = ""
        
        for assessment in patient_results['assessments']:
            if assessment['comorbidity'] == condition:
                predicted = assessment['is_met']
                rationale = assessment.get('rationale', '')
                confidence = assessment.get('confidence', '')
                break
        
        # Check for errors
        if error_type == 'fp' and predicted and not has_condition:
            errors.append({
                'patient_id': patient_id,
                'error_type': 'False Positive',
                'rationale': rationale,
                'confidence': confidence
            })
        elif error_type == 'fn' and not predicted and has_condition:
            errors.append({
                'patient_id': patient_id,
                'error_type': 'False Negative',
                'rationale': rationale,
                'confidence': confidence
            })
    
    return errors

print("Error analysis function defined")

---
## Appendix: Dependencies and Setup

### Required Python Packages

```bash
pip install requests python-dotenv pandas numpy scikit-learn matplotlib seaborn
```

### Environment Variables (.env file)

Configure based on your LLM API provider:

```
# Generic LLM API Configuration
LLM_API_KEY=your_api_key_here
LLM_API_ENDPOINT=https://api.your-provider.com/v1/chat/completions
LLM_API_VERSION=2024-01-01
```

### Data Structure

```
project/
├── data/
│   ├── patient_data.json           # Patient clinical notes
│   ├── patients_any.pkl            # Ground truth dictionary
│   ├── patients_by_condition.pkl   # Diagnosis-based labels
│   ├── patients_by_drug_condition.pkl
│   ├── patients_by_lab.pkl
│   └── patients_by_procedure.pkl
├── results/
│   └── model_results.json          # Model evaluation results
└── .env                            # API credentials
```