# Advanced ED-AI Triage System

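Comprehensive implementation of an emergency-department triage model: ClinicalBERT text embeddings, an XGBoost + LightGBM late-fusion ensemble, and interpretability and fairness checks (SHAP, LIME, counterfactuals, disparate impact). The Temporal Fusion Transformer is imported but not yet used in this notebook.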

In [None]:
# Import libraries
import functools
import os
import tempfile
import warnings

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, brier_score_loss
from sklearn.impute import KNNImputer
import xgboost as xgb
import lightgbm as lgb
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
import shap
import lime
from lime.lime_tabular import LimeTabularExplainer
import captum
from captum.attr import IntegratedGradients
import mlflow
import mlflow.sklearn
import mlflow.pytorch
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
# Silence only library deprecation chatter. A blanket 'ignore' would also hide
# RuntimeWarning/UserWarning messages raised while processing the data and
# fitting the models, which are worth seeing in a clinical pipeline.
for _warning_category in (FutureWarning, DeprecationWarning):
    warnings.filterwarnings('ignore', category=_warning_category)

# Set style
plt.style.use('default')
sns.set_palette("husl")

print("Loading advanced ED datasets...")

In [None]:
# Read every raw ED table from the shared data directory, in a fixed order
_DATA_DIR = '../data'
_TABLE_STEMS = ('diagnosis', 'edstays', 'medrecon', 'pyxis', 'triage', 'vitalsign')
(diagnosis_df, edstays_df, medrecon_df,
 pyxis_df, triage_df, vitals_df) = [pd.read_csv(f'{_DATA_DIR}/{stem}.csv') for stem in _TABLE_STEMS]

# Report how many rows each table contributed
print("Datasets loaded successfully!")
for table_label, table in (
    ("Diagnosis records", diagnosis_df),
    ("ED stays", edstays_df),
    ("Medication records", medrecon_df),
    ("Pyxis records", pyxis_df),
    ("Triage records", triage_df),
    ("Vital signs", vitals_df),
):
    print(f"{table_label}: {len(table)}")

In [None]:
# Data preprocessing and feature engineering
print("Advanced data preprocessing...")

# Convert timestamps
triage_df['charttime'] = pd.to_datetime(triage_df['charttime'])
vitals_df['charttime'] = pd.to_datetime(vitals_df['charttime'])
edstays_df['intime'] = pd.to_datetime(edstays_df['intime'])
edstays_df['outtime'] = pd.to_datetime(edstays_df['outtime'])


def _join_keys(left, right, base_keys):
    """Return ``base_keys`` plus the visit key 'stay_id' when both frames carry it.

    Joining only on patient-level keys pairs every row of one table with
    every row of the other for that patient (a per-patient Cartesian
    product). That duplicates rows and lets copies of the same visit land in
    both the train and the test split.

    NOTE(review): assumes the per-visit identifier is named 'stay_id' —
    confirm against the table schemas. When either frame lacks it, the
    original keys are used unchanged. A table with several rows per visit
    will still repeat that visit's rows.
    """
    keys = list(base_keys)
    if 'stay_id' not in keys and 'stay_id' in left.columns and 'stay_id' in right.columns:
        keys.append('stay_id')
    return keys


# Merge datasets
merged_df = pd.merge(triage_df, vitals_df,
                     on=_join_keys(triage_df, vitals_df, ['subject_id', 'charttime']), how='left')
merged_df = pd.merge(merged_df, diagnosis_df,
                     on=_join_keys(merged_df, diagnosis_df, ['subject_id']), how='left')
merged_df = pd.merge(merged_df, edstays_df,
                     on=_join_keys(merged_df, edstays_df, ['subject_id']), how='left')

# Vital signs recorded in both the triage and the vitals table. After the first
# merge each appears twice: `<name>_x` (triage) and `<name>_y` (vitals table).
vital_cols = ['temperature', 'heart_rate', 'respiratory_rate', 'oxygen_saturation',
              'blood_pressure_systolic', 'blood_pressure_diastolic', 'pain_score']

# Handle missing values with advanced imputation (vitals-table copies only)
numeric_cols = [f'{col}_y' for col in vital_cols]

# KNN imputation for missing values.
# NOTE(review): fitted on the full dataset before the train/test split, so
# held-out rows influence the imputed values.
imputer = KNNImputer(n_neighbors=5)
merged_df[numeric_cols] = imputer.fit_transform(merged_df[numeric_cols])

# Use triage vitals if available, otherwise use the (imputed) vitals table
for col in vital_cols:
    merged_df[col] = merged_df[f'{col}_x'].fillna(merged_df[f'{col}_y'])

# Drop only the suffixed vital copies that were just coalesced. Other suffixed
# columns produced by the later merges are not redundant and are kept.
cols_to_drop = [f'{col}{suffix}' for col in vital_cols for suffix in ('_x', '_y')]
merged_df = merged_df.drop(columns=cols_to_drop)

print(f"Processed dataset shape: {merged_df.shape}")
print(merged_df.head())

In [None]:
# Text processing with ClinicalBERT
print("Setting up ClinicalBERT for text processing...")

# Cache the loaded pair so repeated calls (e.g. re-running cells) reuse it.
# The previous `@st.cache_resource` referenced `st`, which is never imported
# in this notebook, so defining the function raised NameError.
@functools.lru_cache(maxsize=1)
def load_clinical_bert():
    """Load the Bio_ClinicalBERT tokenizer and encoder once.

    Returns:
        ``(tokenizer, model)`` for "emilyalsentzer/Bio_ClinicalBERT", with the
        model in evaluation mode for inference.
    """
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model

tokenizer, bert_model = load_clinical_bert()

def get_bert_embeddings(texts, max_length=128, batch_size=32):
    """Extract ClinicalBERT [CLS] embeddings for a sequence of texts.

    Texts are encoded in mini-batches rather than one forward pass per text,
    which is much faster on thousands of rows. Each batch is padded to its
    longest member, and the tokenizer's attention mask (passed to the model)
    excludes the padding. Results may differ from one-at-a-time encoding only
    by floating-point rounding.

    Args:
        texts: iterable of texts; missing values (NaN/None) are encoded as "",
            other non-string values are converted with ``str``.
        max_length: maximum number of tokens kept per text (longer ones are
            truncated).
        batch_size: number of texts per forward pass; must be >= 1.

    Returns:
        2-D array with one [CLS] vector per text, in input order. An empty
        input yields an array of shape ``(0, hidden_size)`` so it can still be
        concatenated column-wise with other embedding matrices.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    cleaned = ["" if pd.isna(text) else str(text) for text in texts]
    if not cleaned:
        return np.empty((0, bert_model.config.hidden_size), dtype=np.float32)

    batch_embeddings = []
    for start in range(0, len(cleaned), batch_size):
        inputs = tokenizer(cleaned[start:start + batch_size], return_tensors="pt",
                           truncation=True, padding=True, max_length=max_length)

        with torch.no_grad():
            outputs = bert_model(**inputs)

        # Use CLS token embedding (position 0 of every sequence)
        batch_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    return np.concatenate(batch_embeddings, axis=0)

# Pull the two free-text fields, treating missing entries as empty strings
chief_complaints, diagnosis_texts = (
    merged_df[text_column].fillna('').tolist()
    for text_column in ('chief_complaint', 'icd_title')
)

print("Extracting ClinicalBERT embeddings...")
complaint_embeddings, diagnosis_embeddings = (
    get_bert_embeddings(text_list) for text_list in (chief_complaints, diagnosis_texts)
)

for embedding_label, embedding_matrix in (
    ("Complaint", complaint_embeddings),
    ("Diagnosis", diagnosis_embeddings),
):
    print(f"{embedding_label} embeddings shape: {embedding_matrix.shape}")

In [None]:
# Advanced feature engineering
print("Advanced feature engineering...")

# Create target variable: acuity levels 1-2 are treated as urgent.
# `NaN <= 2` evaluates to False, so rows without a recorded acuity would
# silently become "non-urgent" training examples. Remember which rows are
# labelled so they can be excluded from the model inputs below.
has_label = merged_df['acuity_level'].notna()
merged_df['urgent'] = (merged_df['acuity_level'] <= 2).astype(int)

# Demographic features
# NOTE(review): if no 'age' column was loaded, this falls back to a synthetic
# uniform draw (seeded so runs are reproducible). It carries no clinical
# signal; replace it with real patient age before trusting age-related results.
if 'age' not in merged_df.columns:
    _age_rng = np.random.default_rng(42)
    merged_df['age'] = _age_rng.integers(18, 90, len(merged_df))

# Vital sign derived features
merged_df['shock_index'] = merged_df['heart_rate'] / merged_df['blood_pressure_systolic']
merged_df['mean_arterial_pressure'] = (merged_df['blood_pressure_systolic'] + 2 * merged_df['blood_pressure_diastolic']) / 3
merged_df['pulse_pressure'] = merged_df['blood_pressure_systolic'] - merged_df['blood_pressure_diastolic']

# Clinical flags
merged_df['fever'] = (merged_df['temperature'] > 38.0).astype(int)
merged_df['hypotension'] = (merged_df['blood_pressure_systolic'] < 90).astype(int)
merged_df['tachycardia'] = (merged_df['heart_rate'] > 100).astype(int)
merged_df['tachypnea'] = (merged_df['respiratory_rate'] > 20).astype(int)
merged_df['hypoxia'] = (merged_df['oxygen_saturation'] < 95).astype(int)
merged_df['severe_pain'] = (merged_df['pain_score'] >= 7).astype(int)

# Encode categorical variables with one encoder per column, so every fitted
# mapping is kept. A single shared LabelEncoder only remembers the last
# column it was fitted on, making the other codes impossible to reproduce.
categorical_fill_values = {
    'arrival_mode': 'Walk-in',
    'consciousness': 'Alert',
    'gender': 'Unknown',
}
label_encoders = {}
for column, fill_value in categorical_fill_values.items():
    encoder = LabelEncoder()
    merged_df[f'{column}_encoded'] = encoder.fit_transform(merged_df[column].fillna(fill_value))
    label_encoders[column] = encoder
le = label_encoders['gender']  # same final state as the previously shared encoder

# Combine text embeddings
text_features = np.concatenate([complaint_embeddings, diagnosis_embeddings], axis=1)

# Structured features
structured_features = [
    'age', 'temperature', 'heart_rate', 'respiratory_rate', 'oxygen_saturation',
    'blood_pressure_systolic', 'blood_pressure_diastolic', 'pain_score',
    'shock_index', 'mean_arterial_pressure', 'pulse_pressure',
    'fever', 'hypotension', 'tachycardia', 'tachypnea', 'hypoxia', 'severe_pain',
    'arrival_mode_encoded', 'consciousness_encoded', 'gender_encoded'
]

# Keep only labelled rows. `.loc` preserves merged_df's index, and the text
# embeddings are row-aligned with merged_df, so the same mask applies to both.
X_structured = merged_df.loc[has_label, structured_features]
X_text = text_features[has_label.to_numpy()]
y = merged_df.loc[has_label, 'urgent']

print(f"Rows without acuity (excluded): {int((~has_label).sum())}")
print(f"Structured features shape: {X_structured.shape}")
print(f"Text features shape: {X_text.shape}")
print(f"Target distribution: {y.value_counts()}")

In [None]:
# Train XGBoost model for structured data
print("Training XGBoost model...")

# Hold out 20% of the rows, keeping the urgent/non-urgent ratio in both parts
X_train_struct, X_test_struct, y_train, y_test = train_test_split(
    X_structured, y, stratify=y, test_size=0.2, random_state=42,
)

# Standardise with statistics learned from the training split only
scaler = StandardScaler()
X_train_struct_scaled = scaler.fit_transform(X_train_struct)
X_test_struct_scaled = scaler.transform(X_test_struct)

# Up-weight positives by the negative:positive ratio of the training labels
n_negative = int((y_train == 0).sum())
n_positive = int((y_train == 1).sum())
xgb_params = dict(
    objective='binary:logistic',
    eval_metric='auc',
    max_depth=6,
    learning_rate=0.1,
    n_estimators=100,
    subsample=0.8,
    colsample_bytree=0.8,
    scale_pos_weight=n_negative / n_positive,
    random_state=42,
)

xgb_model = xgb.XGBClassifier(**xgb_params)
xgb_model.fit(X_train_struct_scaled, y_train)

# Score the held-out rows and threshold the urgent probability at 0.5
xgb_pred_proba = xgb_model.predict_proba(X_test_struct_scaled)[:, 1]
xgb_pred = np.greater(xgb_pred_proba, 0.5).astype(int)

print("XGBoost Performance:")
for metric_label, scorer in (
    ("AUC", roc_auc_score),
    ("PR-AUC", average_precision_score),
    ("Brier Score", brier_score_loss),
):
    print(f"{metric_label}: {scorer(y_test, xgb_pred_proba):.3f}")
print(classification_report(y_test, xgb_pred))

In [None]:
# Train LightGBM model
print("Training LightGBM model...")

# Same class re-weighting as the XGBoost model: negatives per positive
lgb_pos_weight = int((y_train == 0).sum()) / int((y_train == 1).sum())

lgb_params = dict(
    objective='binary',
    metric='auc',
    boosting_type='gbdt',
    num_leaves=31,
    learning_rate=0.1,
    feature_fraction=0.8,
    bagging_fraction=0.8,
    bagging_freq=5,
    verbose=-1,
    scale_pos_weight=lgb_pos_weight,
)

lgb_model = lgb.LGBMClassifier(**lgb_params)
lgb_model.fit(X_train_struct_scaled, y_train)

# Score the held-out rows and threshold the urgent probability at 0.5
lgb_pred_proba = lgb_model.predict_proba(X_test_struct_scaled)[:, 1]
lgb_pred = np.greater(lgb_pred_proba, 0.5).astype(int)

print("LightGBM Performance:")
for metric_label, scorer in (
    ("AUC", roc_auc_score),
    ("PR-AUC", average_precision_score),
    ("Brier Score", brier_score_loss),
):
    print(f"{metric_label}: {scorer(y_test, lgb_pred_proba):.3f}")
print(classification_report(y_test, lgb_pred))

In [None]:
# Late-fusion ensemble
print("Creating late-fusion ensemble...")

# Urgent-class probability from each booster on the held-out rows
xgb_proba, lgb_proba = (
    booster.predict_proba(X_test_struct_scaled)[:, 1] for booster in (xgb_model, lgb_model)
)

# Unweighted mean of the two probabilities, thresholded at 0.5
ensemble_proba = (xgb_proba + lgb_proba) / 2
ensemble_pred = np.greater(ensemble_proba, 0.5).astype(int)

print("Ensemble Performance:")
for metric_label, scorer in (
    ("AUC", roc_auc_score),
    ("PR-AUC", average_precision_score),
    ("Brier Score", brier_score_loss),
):
    print(f"{metric_label}: {scorer(y_test, ensemble_proba):.3f}")
print(classification_report(y_test, ensemble_pred))

In [None]:
# SHAP Interpretability
print("Generating SHAP explanations...")

# Create SHAP explainer on the same standardized inputs the model was trained on
explainer = shap.TreeExplainer(xgb_model)
shap_values = explainer.shap_values(X_test_struct_scaled)

# The SHAP values come from the scaled inputs, but the plots are annotated
# with the unscaled values of the same rows and columns. Readers then see the
# recorded feature values rather than z-scores.
X_test_display = X_test_struct.to_numpy(dtype=float)

# SHAP summary plot
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test_display, feature_names=structured_features, show=False)
plt.title('SHAP Feature Importance')
plt.tight_layout()
plt.show()

# SHAP waterfall plot for single prediction
sample_idx = 0
plt.figure(figsize=(12, 6))
shap.waterfall_plot(
    shap.Explanation(
        values=shap_values[sample_idx],
        base_values=explainer.expected_value,
        data=X_test_display[sample_idx],
        feature_names=structured_features
    ),
    show=False
)
plt.title(f'SHAP Waterfall Plot - Sample {sample_idx}')
plt.tight_layout()
plt.show()

In [None]:
# LIME Interpretability
print("Generating LIME explanations...")

# Create LIME explainer. The seed makes the perturbation sampling, and hence
# the explanation, reproducible between runs.
lime_explainer = LimeTabularExplainer(
    training_data=X_train_struct_scaled,
    feature_names=structured_features,
    class_names=['Non-urgent', 'Urgent'],
    mode='classification',
    random_state=42,
)

# Explain single prediction
sample_idx = 0
lime_exp = lime_explainer.explain_instance(
    data_row=X_test_struct_scaled[sample_idx],
    predict_fn=xgb_model.predict_proba,
    num_features=10
)

# Plot LIME explanation. `as_pyplot_figure` creates its own figure, so size
# that figure rather than opening an extra (empty) one beforehand.
lime_fig = lime_exp.as_pyplot_figure()
lime_fig.set_size_inches(12, 6)
plt.title(f'LIME Explanation - Sample {sample_idx}')
plt.tight_layout()
plt.show()

print("LIME Explanation:")
print(lime_exp.as_list())

In [None]:
# Fairness Assessment
print("Conducting fairness assessment...")

test_df = X_test_struct.copy()

# X_test_struct keeps merged_df's row index (it is a column subset of
# merged_df passed through train_test_split). The recorded attributes of the
# held-out rows can therefore be looked up directly. Random labels would make
# every group statistically identical and the audit meaningless.
# NOTE(review): the disparate-impact check compares groups named 'Male' and
# 'Female' — confirm these match the dataset's gender values.
test_df['gender'] = merged_df.loc[X_test_struct.index, 'gender'].fillna('Unknown')

if 'race' in merged_df.columns:
    test_df['race'] = merged_df.loc[X_test_struct.index, 'race'].fillna('Unknown')
else:
    # NOTE(review): no race column is available here; this placeholder is
    # simulated (seeded) and must not support any fairness conclusion.
    _fairness_rng = np.random.default_rng(42)
    test_df['race'] = _fairness_rng.choice(['White', 'Black', 'Hispanic', 'Asian', 'Other'], len(test_df))

# Open-ended top bin so ages above 100 still land in '71+' instead of NaN.
# NOTE(review): 'age' may itself be synthetic depending on the loaded data.
test_df['age_group'] = pd.cut(test_df['age'], bins=[0, 30, 50, 70, np.inf],
                              labels=['18-30', '31-50', '51-70', '71+'])
test_df['predictions'] = ensemble_pred
test_df['actual'] = y_test.values
# Calculate fairness metrics
def calculate_fairness_metrics(df, protected_attr, outcome, actual='actual'):
    """Summarise predicted and actual positive rates per group of a protected attribute.

    Args:
        df: frame holding the protected attribute, the prediction column and
            the ground-truth column.
        protected_attr: column whose distinct non-missing values define the groups.
        outcome: column of 0/1 predictions.
        actual: column of 0/1 ground-truth labels (default ``'actual'``).

    Returns:
        Dict mapping each group value (in order of first appearance) to
        ``{'predicted_positive_rate', 'actual_positive_rate', 'size'}``.
        Rows whose attribute is missing are not assigned to any group.
    """
    metrics = {}
    for group in df[protected_attr].dropna().unique():
        group_data = df[df[protected_attr] == group]
        if group_data.empty:
            continue
        metrics[group] = {
            'predicted_positive_rate': group_data[outcome].mean(),
            'actual_positive_rate': group_data[actual].mean(),
            'size': len(group_data)
        }
    return metrics

def _print_group_fairness(title, fairness_by_group):
    """Print a heading, then one line of rates and size per group."""
    print(title)
    for group_name, stats in fairness_by_group.items():
        print(
            f"{group_name}: Predicted positive rate = {stats['predicted_positive_rate']:.3f}, "
            f"Actual positive rate = {stats['actual_positive_rate']:.3f}, "
            f"Sample size = {stats['size']}"
        )


# Gender fairness
gender_fairness = calculate_fairness_metrics(test_df, 'gender', 'predictions')
_print_group_fairness("Gender Fairness Metrics:", gender_fairness)

# Age group fairness
age_fairness = calculate_fairness_metrics(test_df, 'age_group', 'predictions')
_print_group_fairness("\nAge Group Fairness Metrics:", age_fairness)

# Calculate disparate impact
def disparate_impact(metrics_dict, privileged_group, unprivileged_group):
    """Return the unprivileged/privileged ratio of predicted-positive rates.

    ``metrics_dict`` maps group -> dict containing 'predicted_positive_rate'
    (as produced by ``calculate_fairness_metrics``). Returns None when either
    group is absent, or when the privileged rate is not positive and the
    ratio is therefore undefined.
    """
    if privileged_group not in metrics_dict or unprivileged_group not in metrics_dict:
        return None
    priv_rate, unpriv_rate = (
        metrics_dict[name]['predicted_positive_rate']
        for name in (privileged_group, unprivileged_group)
    )
    if not priv_rate > 0:
        return None
    return unpriv_rate / priv_rate

# Example disparate impact calculation
di_gender = disparate_impact(gender_fairness, 'Male', 'Female')
# Compare against None explicitly: a ratio of 0.0 (no positive predictions for
# the unprivileged group) is the most severe disparity and must be reported,
# but it is falsy and was previously skipped.
if di_gender is not None:
    print(f"\nDisparate Impact (Female/Male): {di_gender:.3f}")
    if di_gender < 0.8:
        print("⚠️ Potential discrimination detected!")
    else:
        print("✅ Fair distribution")
else:
    print("\nDisparate Impact (Female/Male): not computable "
          "(group missing or zero predicted-positive rate for Male)")

In [None]:
# Counterfactual Generation
print("Generating counterfactual explanations...")

def generate_counterfactual(sample, model, feature_names, target_class=1, max_iterations=100,
                            feature_scales=None, max_changes=3,
                            magnitudes=(0.1, 0.5, 1.0, 2.0)):
    """Search for a small set of feature changes that flips the model's prediction.

    Greedy search. At each step, every not-yet-changed feature is nudged by
    ``±magnitude * scale``, with smaller magnitudes tried first (both
    directions at each magnitude). The first candidate predicted as
    ``target_class`` is returned. If none is, the candidate that raised the
    target-class probability the most is kept and the search continues on
    the remaining features, up to ``max_changes`` features in total.

    Args:
        sample: 1-D feature vector in the same space the model was trained on.
        model: object exposing ``predict_proba(X)`` returning one row of class
            probabilities per input row.
        feature_names: names aligned with the positions of ``sample``.
        target_class: class index the counterfactual should reach.
        max_iterations: unused; kept for backward compatibility.
        feature_scales: per-feature step unit. Defaults to 1.0 for every
            feature, i.e. one standard deviation when the input is standardized.
        max_changes: maximum number of features allowed to change.
        magnitudes: step sizes, in units of ``feature_scales``, tried in order.

    Returns:
        ``(None, "Already in target class")`` if ``sample`` is already
        predicted as ``target_class``. Otherwise ``(changes, counterfactual)``,
        where ``changes`` maps feature name -> delta from ``sample`` and
        ``counterfactual`` is the modified vector. If no flip is found within
        the limits, ``changes`` is ``{}`` and ``counterfactual`` equals ``sample``.
    """
    # Float copy: in-place fractional steps would fail on an integer array
    sample = np.asarray(sample, dtype=float)
    original_class = np.argmax(model.predict_proba(sample.reshape(1, -1))[0])
    if original_class == target_class:
        return None, "Already in target class"

    # The previous np.std(sample) was the spread across one row's features,
    # not a per-feature scale; use explicit per-feature units instead.
    if feature_scales is None:
        scales = np.ones_like(sample)
    else:
        scales = np.asarray(feature_scales, dtype=float)

    counterfactual = sample.copy()
    changed = set()  # indices of features already modified
    for _ in range(max_changes):
        best_score, best_candidate, best_index = -np.inf, None, None
        for magnitude in magnitudes:
            for i in range(len(feature_names)):
                if i in changed:
                    continue
                for direction in (-1, 1):
                    candidate = counterfactual.copy()
                    candidate[i] += direction * magnitude * scales[i]
                    proba = model.predict_proba(candidate.reshape(1, -1))[0]
                    if np.argmax(proba) == target_class:
                        changed.add(i)
                        changes = {feature_names[j]: candidate[j] - sample[j] for j in sorted(changed)}
                        return changes, candidate
                    if proba[target_class] > best_score:
                        best_score, best_candidate, best_index = proba[target_class], candidate, i
        if best_candidate is None:  # every feature already changed
            break
        # Keep the most promising single change and search again from there
        counterfactual = best_candidate
        changed.add(best_index)

    return {}, sample.copy()

# Generate counterfactual for a sample
sample_idx = 0
sample_data = X_test_struct_scaled[sample_idx]
changes, counterfactual = generate_counterfactual(
    sample_data, xgb_model, structured_features
)

print(f"Counterfactual for Sample {sample_idx}:")
if changes is None:
    print("Sample is already predicted Urgent")
elif changes:
    print("Changes needed to flip prediction to Urgent:")
    for feature, change in changes.items():
        # The search runs on StandardScaler output, so one unit is one
        # training-set standard deviation; scale_ converts back to raw units.
        raw_change = change * scaler.scale_[structured_features.index(feature)]
        print(f"  {feature}: {change:+.3f} SD ({raw_change:+.3f} in original units)")
else:
    print("No counterfactual found within the search limits")

In [None]:
# MLflow Experiment Tracking
print("Setting up MLflow experiment tracking...")

mlflow.set_experiment("ED-AI-Triage-Advanced")

with mlflow.start_run(run_name="Advanced-Triage-Ensemble") as run:
    # Log every hyper-parameter of both models, prefixed by model. This
    # includes the previously logged xgb_max_depth, xgb_learning_rate and
    # lgb_num_leaves keys.
    mlflow.log_params({f"xgb_{key}": value for key, value in xgb_params.items()})
    mlflow.log_params({f"lgb_{key}": value for key, value in lgb_params.items()})
    mlflow.log_param("ensemble_method", "simple_average")

    # Log metrics (ensemble on the held-out split)
    mlflow.log_metric("auc", roc_auc_score(y_test, ensemble_proba))
    mlflow.log_metric("pr_auc", average_precision_score(y_test, ensemble_proba))
    mlflow.log_metric("brier_score", brier_score_loss(y_test, ensemble_proba))

    # Log models
    mlflow.sklearn.log_model(xgb_model, "xgboost_model")
    mlflow.sklearn.log_model(lgb_model, "lightgbm_model")
    mlflow.sklearn.log_model(scaler, "scaler")

    # Log feature importance. Write it to a temporary directory so the run
    # neither leaves a stray file in the working directory nor overwrites one.
    feature_importance = pd.DataFrame({
        'feature': structured_features,
        'importance': xgb_model.feature_importances_
    })
    with tempfile.TemporaryDirectory() as tmp_dir:
        importance_path = os.path.join(tmp_dir, "feature_importance.csv")
        feature_importance.to_csv(importance_path, index=False)
        mlflow.log_artifact(importance_path)

    print("MLflow run completed!")
    print(f"Run ID: {run.info.run_id}")

In [None]:
# Save models and preprocessing objects
print("Saving advanced models and preprocessing objects...")

MODEL_DIR = '../models'
# Create the target directory up front; dumping into a missing folder fails
os.makedirs(MODEL_DIR, exist_ok=True)

# NOTE(review): the categorical encoders fitted during feature engineering
# are not persisted here; inference must reproduce the same integer codes.
joblib_artifacts = [
    (xgb_model, 'xgb_model.joblib'),
    (lgb_model, 'lgb_model.joblib'),
    (scaler, 'advanced_scaler.joblib'),
    (structured_features, 'structured_features.joblib'),
    (tokenizer, 'bert_tokenizer.joblib'),
]

saved_paths = []
for artifact, filename in joblib_artifacts:
    artifact_path = f'{MODEL_DIR}/{filename}'
    joblib.dump(artifact, artifact_path)
    saved_paths.append(artifact_path)

# Save ClinicalBERT weights
bert_weights_path = f'{MODEL_DIR}/bert_model.pth'
torch.save(bert_model.state_dict(), bert_weights_path)
saved_paths.append(bert_weights_path)

# Save SHAP explainer
shap_explainer_path = f'{MODEL_DIR}/shap_explainer.joblib'
joblib.dump(explainer, shap_explainer_path)
saved_paths.append(shap_explainer_path)

print("All models and preprocessing objects saved!")
print("Files saved:")
for artifact_path in saved_paths:
    print(f"- {artifact_path}")

# Test prediction on sample data
sample_data = X_test_struct_scaled[:3]
xgb_predictions = xgb_model.predict_proba(sample_data)
lgb_predictions = lgb_model.predict_proba(sample_data)
ensemble_predictions = (xgb_predictions + lgb_predictions) / 2

print("\nSample Predictions (Ensemble):")
for i, pred in enumerate(ensemble_predictions):
    print(f"Sample {i+1}: Urgent probability = {pred[1]:.3f}, "
          f"Predicted class = {'Urgent' if pred[1] > 0.5 else 'Non-urgent'}")