In [30]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, roc_auc_score,
    brier_score_loss, log_loss
)
from sklearn.calibration import calibration_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import os

# ----------------------------------
# Helper functions (define first!)
# ----------------------------------
def multiclass_brier_score(y_true_int, y_proba, classes):
    """
    Generalized Brier score for multi-class: mean over samples of sum_j (p_ij - y_ij)^2.

    y_ij is 1 when sample i has label ``classes[j]`` and 0 otherwise. For
    probability rows that sum to 1 the score ranges from 0 (perfect) to 2
    (all mass on a wrong class).

    Parameters
    ----------
    y_true_int : array-like of shape (n_samples,)
        True labels (e.g. integer-encoded classes).
    y_proba : array-like of shape (n_samples, n_classes)
        Predicted probabilities; column j corresponds to ``classes[j]``.
    classes : array-like of shape (n_classes,)
        Label values in the same order as the columns of ``y_proba``.

    Returns
    -------
    float
        The mean generalized Brier score.

    Raises
    ------
    ValueError
        If ``y_proba`` is not of shape (n_samples, n_classes).
    """
    y_true = np.asarray(y_true_int).ravel()
    class_values = np.asarray(classes)
    proba = np.asarray(y_proba, dtype=float)

    # One-hot encode directly instead of using label_binarize: with exactly two
    # classes label_binarize returns a single column, which would silently
    # broadcast against the (n_samples, 2) probability matrix and give a wrong score.
    onehot = (y_true[:, None] == class_values[None, :]).astype(float)

    if proba.shape != onehot.shape:
        raise ValueError(
            f"y_proba has shape {proba.shape}, expected {onehot.shape} "
            "(n_samples, n_classes)"
        )
    return float(np.mean(np.sum((proba - onehot) ** 2, axis=1)))

def plot_reliability_curves(y_true_int, y_proba, class_names, model_name, safe_name,
                            outdir="./output", n_bins=10):
    """
    One-vs-rest reliability (calibration) curves for each class (per-model figure).

    Draws one panel per class and saves the figure to
    ``<outdir>/<safe_name>_reliability.png``.

    Parameters
    ----------
    y_true_int : array-like of int, shape (n_samples,)
        True labels encoded as 0..K-1; label k corresponds to column k of ``y_proba``.
    y_proba : array-like, shape (n_samples, K)
        Predicted class probabilities.
    class_names : sequence of str
        Display names, in the same order as the columns of ``y_proba``.
    model_name : str
        Shown in each panel title.
    safe_name : str
        Filename stem for the saved image.
    outdir : str
        Output directory (created if missing).
    n_bins : int
        Number of quantile bins passed to ``calibration_curve``.
    """
    y_true = np.asarray(y_true_int)
    proba = np.asarray(y_proba)
    n_classes = len(class_names)

    # squeeze=False always returns a 2-D array of axes, so a single class needs no special case.
    fig, axes = plt.subplots(1, n_classes, figsize=(5 * n_classes, 4), sharey=True, squeeze=False)
    try:
        for k, (ax, cls_name) in enumerate(zip(axes[0], class_names)):
            y_true_bin = (y_true == k).astype(int)
            frac_pos, mean_pred = calibration_curve(
                y_true_bin, proba[:, k], n_bins=n_bins, strategy="quantile"
            )

            ax.plot(mean_pred, frac_pos, marker="o")
            ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1)  # perfect-calibration diagonal
            ax.set_title(f"{model_name}: {cls_name}")
            ax.set_xlabel("Mean predicted probability")
            if k == 0:
                ax.set_ylabel("Fraction of positives")
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

        fig.tight_layout()
        os.makedirs(outdir, exist_ok=True)
        fig.savefig(os.path.join(outdir, f"{safe_name}_reliability.png"), dpi=200)
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)

def plot_reliability_combined(results, class_names, outpath="./output/reliability_combined.png", n_bins=10):
    """
    Combined reliability figure: 1 row x (#models) columns.

    Each panel overlays one one-vs-rest calibration curve per class for a
    single model, plus the y = x perfect-calibration diagonal.

    Parameters
    ----------
    results : list of dict
        One entry per model with keys ``"name"``, ``"y_true_int"`` (labels
        0..K-1) and ``"y_proba"`` (array of shape (n_samples, K)).
    class_names : sequence of str
        Display names in the same order as the probability columns.
    outpath : str
        Destination image path; its parent directory is created if needed.
    n_bins : int
        Number of quantile bins per calibration curve.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must contain at least one model")

    n_models = len(results)
    label_ids = np.arange(len(class_names))  # same for every panel; computed once
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 4), sharey=True, squeeze=False)
    axes = axes[0]
    try:
        for ax, res in zip(axes, results):
            y_true_int = np.asarray(res["y_true_int"])
            y_proba = np.asarray(res["y_proba"])

            for k, cls_name in zip(label_ids, class_names):
                y_true_bin = (y_true_int == k).astype(int)
                frac_pos, mean_pred = calibration_curve(
                    y_true_bin, y_proba[:, k], n_bins=n_bins, strategy="quantile"
                )
                ax.plot(mean_pred, frac_pos, marker="o", label=cls_name)

            ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1)
            ax.set_title(res["name"])
            ax.set_xlabel("Mean predicted probability")
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

        axes[0].set_ylabel("Fraction of positives")
        axes[-1].legend(title="Class", loc="lower right")
        fig.tight_layout()

        # os.path.dirname("file.png") is "" and os.makedirs("") raises,
        # so only create a directory when the path actually contains one.
        outdir = os.path.dirname(outpath)
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        fig.savefig(outpath, dpi=200)
    finally:
        plt.close(fig)

def plot_confusion_matrices_combined(results, class_names, outpath="./output/confusion_matrices_combined.png"):
    """
    Combined confusion-matrix figure: 1 row x (#models) columns.

    Parameters
    ----------
    results : list of dict
        One entry per model with keys ``"name"`` and ``"cm"`` (a square
        integer confusion matrix; annotated with ``fmt='d'``).
    class_names : sequence of str
        Tick labels for both axes, in the matrix's row/column order.
    outpath : str
        Destination image path; its parent directory is created if needed.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must contain at least one model")

    n_models = len(results)
    fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 4), sharey=True, squeeze=False)
    try:
        for ax, res in zip(axes[0], results):
            sns.heatmap(res["cm"], annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names, ax=ax)
            ax.set_title(f"{res['name']}")
            ax.set_xlabel("Predicted")
            ax.set_ylabel("Actual")

        fig.tight_layout()

        # Guard against a bare filename: os.makedirs("") raises.
        outdir = os.path.dirname(outpath)
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        fig.savefig(outpath, dpi=200)
    finally:
        plt.close(fig)

def print_reliability_tables(y_true_int, y_proba, class_names, n_bins=10):
    """
    Print one quantile-binned reliability (calibration) table per class.

    For class k (integer label k, probability column k) the one-vs-rest
    calibration curve is computed and shown as a table with columns:
      - bin:       1-based row number of the returned bins
      - mean_pred: average predicted probability in the bin (3 decimals)
      - frac_pos:  observed fraction of samples of class k (3 decimals)
    """
    label_ids = np.arange(len(class_names))
    for label_id, label in zip(label_ids, class_names):
        observed, predicted = calibration_curve(
            (y_true_int == label_id).astype(int),
            y_proba[:, label_id],
            n_bins=n_bins,
            strategy="quantile",
        )
        table = pd.DataFrame(
            {
                "bin": np.arange(1, observed.size + 1),
                "mean_pred": predicted.round(3),
                "frac_pos": observed.round(3),
            }
        )
        print(f"\nReliability (quantile bins) — Class: {label}")
        print(table.to_string(index=False))

# ----------------------------------
# Main
# ----------------------------------
os.makedirs("./output", exist_ok=True)

# Load dataset (must contain the feature columns below plus a "burnout" target column).
df = pd.read_csv("./data/simulated_weekly_burnout.csv")

# Encode categorical target as integers 0..K-1; label_encoder.classes_[k] is
# the original label for integer k.
label_encoder = LabelEncoder()
df["burnout_label"] = label_encoder.fit_transform(df["burnout"])

# Class metadata shared by every model (computed once instead of per model).
class_names = label_encoder.classes_
n_classes = len(class_names)
classes_idx = np.arange(n_classes)

# Define features
features = ['avg_tired', 'avg_capable', 'avg_meaningful']
X = df[features]
y = df["burnout_label"]

# Stratified train/test split keeps class proportions equal in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.3, random_state=42
)

# Scale features (only for Logistic Regression); fit on train only to avoid leakage.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Initialize models
models = {
    # lbfgs already fits a multinomial (softmax) model when there are more than
    # two classes; the explicit multi_class='multinomial' argument is deprecated
    # since scikit-learn 1.5 (FutureWarning, slated for removal), so it is omitted.
    "Logistic Regression": LogisticRegression(
        solver='lbfgs', max_iter=1000, random_state=42
    ),
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    # num_class follows the encoded target instead of a hard-coded 3.
    "XGBoost": XGBClassifier(
        objective='multi:softprob', num_class=n_classes, eval_metric='mlogloss',
        random_state=42
    )
}

# Collect results for combined figures
combined_results = []

# Train, evaluate, and save each model
for name, model in models.items():
    print(f"\n=== {name} ===")
    safe_name = name.lower().replace(" ", "_")  # filename stem, e.g. "random_forest"

    # Fit: only Logistic Regression uses the standardized features.
    if name == "Logistic Regression":
        model.fit(X_train_scaled, y_train)
        X_eval = X_test_scaled
    else:
        model.fit(X_train, y_train)
        X_eval = X_test

    # Predict labels and probabilities
    y_pred = model.predict(X_eval)
    y_proba = model.predict_proba(X_eval)

    # ---------- Calibration metrics ----------
    # One-hot truth matrix of shape (n_samples, n_classes). Built directly rather
    # than with label_binarize, which collapses to a single column when there are
    # only two classes and would break the per-class indexing below.
    Y_bin = (np.asarray(y_test)[:, None] == classes_idx[None, :]).astype(int)

    # (a) One-vs-rest Brier score per class (str keys print cleanly on any NumPy version)
    brier_per_class = {
        str(cls): brier_score_loss(Y_bin[:, k], y_proba[:, k])
        for k, cls in enumerate(class_names)
    }

    # (b) Multiclass Brier score (generalized)
    brier_multi = multiclass_brier_score(y_test, y_proba, classes_idx)

    # (c) Log loss (cross-entropy); labels= pins the column order to classes_idx.
    ll = log_loss(y_test, y_proba, labels=classes_idx)

    print("Calibration:")
    print("  Brier (OVR) per class:", {k: f"{v:.4f}" for k, v in brier_per_class.items()})
    print(f"  Brier (multiclass): {brier_multi:.4f}")
    print(f"  Log loss: {ll:.4f}")

    # Print reliability (calibration) tables per class
    print_reliability_tables(y_true_int=y_test, y_proba=y_proba, class_names=class_names, n_bins=10)

    # (d) Reliability (calibration) curves (per-model; saved to ./output)
    plot_reliability_curves(
        y_true_int=y_test,
        y_proba=y_proba,
        class_names=class_names,
        model_name=name,
        safe_name=safe_name,
        outdir="./output"
    )

    # ---------- Interpretability summaries ----------
    # Logistic Regression: coefficients & odds ratios
    if name == "Logistic Regression":
        # Coefficients are per 1-SD change in each feature, because the model was
        # fitted on StandardScaler output.
        # NOTE: scikit-learn's multinomial model uses a symmetric softmax
        # parameterization, so exp(coef) is not an odds ratio against a fixed
        # reference class; exp(coef_k - coef_j) is the odds ratio of class k vs j.
        coefs_df = pd.DataFrame(model.coef_, columns=features, index=class_names)
        odds_df = np.exp(coefs_df)
        coefs_df.to_csv(f"./output/{safe_name}_coefficients.csv")
        odds_df.to_csv(f"./output/{safe_name}_odds_ratios.csv")

        plt.figure(figsize=(6, 3.5))
        sns.heatmap(coefs_df, annot=True, fmt=".2f", cmap="coolwarm", center=0)
        plt.title("Multinomial Logistic Regression Coefficients")
        plt.tight_layout()
        plt.savefig(f"./output/{safe_name}_coef_heatmap.png", dpi=200)
        plt.close()

        # Console summaries (rounded) for quick reading
        print("\nLogistic Regression — Coefficients (per class):")
        print(coefs_df.round(3).to_string())

        print("\nLogistic Regression — Odds ratios (per class):")
        print(odds_df.round(3).to_string())

        # Quick "top driver" per class by absolute coefficient
        top_drivers = coefs_df.abs().idxmax(axis=1)
        print("\nTop driver feature per class (by |coef|):")
        for cls in coefs_df.index:
            feat = top_drivers.loc[cls]
            coef = coefs_df.loc[cls, feat]
            # A positive coefficient raises this class's linear score (logit). Whether
            # that means more or less burnout depends on the class (e.g. "Low"), so
            # the label no longer claims "risk up/down".
            sign = "↑ (raises class score)" if coef > 0 else "↓ (lowers class score)"
            print(f"  {cls}: {feat} ({coef:.3f}) {sign}")

    # Tree models: native feature importances
    if hasattr(model, "feature_importances_"):
        imp = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
        imp.to_csv(f"./output/{safe_name}_feature_importances.csv")

        plt.figure(figsize=(5, 3.5))
        imp.plot(kind="bar")
        plt.ylabel("Importance")
        plt.title(f"{name} Feature Importances")
        plt.tight_layout()
        plt.savefig(f"./output/{safe_name}_feature_importances.png", dpi=200)
        plt.close()

        print(f"\n{name} — Feature importances:")
        print(imp.round(3).to_string())
        # One-line driver summary (imp is sorted descending)
        top_feat = imp.index[0]
        print(f"Top driver: {top_feat} ({imp.iloc[0]:.3f})")

    # Decode labels for display
    y_test_labels = label_encoder.inverse_transform(y_test)
    y_pred_labels = label_encoder.inverse_transform(y_pred)

    # 1) Classification Report
    print("Classification Report:")
    print(classification_report(y_test_labels, y_pred_labels))

    # 2) ROC-AUC (macro-averaged One-vs-Rest)
    roc_auc = roc_auc_score(
        y_true=y_test,
        y_score=y_proba,
        multi_class='ovr',
        average='macro'
    )
    print(f"ROC-AUC (macro, OVR): {roc_auc:.3f}\n")

    # 3) Confusion Matrix (rows = actual, columns = predicted, in class_names order)
    cm = confusion_matrix(
        y_test_labels, y_pred_labels, labels=class_names
    )
    print("Confusion Matrix:")
    print(pd.DataFrame(cm, index=class_names, columns=class_names))

    # Plot heatmap AND SAVE IMAGE (per-model)
    plt.figure(figsize=(5, 4))
    sns.heatmap(
        cm, annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names
    )
    plt.title(f"{name} Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.tight_layout()
    plt.savefig(f"./output/{safe_name}_confusion_matrix.png", dpi=200)
    plt.close()

    # Save the model
    joblib.dump(model, f"./output/{safe_name}_model.pkl")

    # ---- Collect for combined figures ----
    combined_results.append({
        "name": name,
        "safe_name": safe_name,
        "y_true_int": y_test,
        "y_proba": y_proba,
        "cm": cm
    })

# Save preprocessing tools
joblib.dump(scaler, "./output/burnout_scaler.pkl")
joblib.dump(label_encoder, "./output/burnout_label_encoder.pkl")

# ---------- Create combined figures ----------
plot_reliability_combined(
    results=combined_results,
    class_names=class_names,
    outpath="./output/reliability_combined.png"
)

plot_confusion_matrices_combined(
    results=combined_results,
    class_names=class_names,
    outpath="./output/confusion_matrices_combined.png"
)



=== Logistic Regression ===
Calibration:
  Brier (OVR) per class: {'High': '0.0029', 'Low': '0.0021', 'Moderate': '0.0052'}
  Brier (multiclass): 0.0102
  Log loss: 0.0376
Classification Report:
              precision    recall  f1-score   support

        High       1.00      0.98      0.99        60
         Low       1.00      1.00      1.00       150
    Moderate       0.99      1.00      0.99        90

    accuracy                           1.00       300
   macro avg       1.00      0.99      1.00       300
weighted avg       1.00      1.00      1.00       300

ROC-AUC (macro, OVR): 1.000

Confusion Matrix:
          High  Low  Moderate
High        59    0         1
Low          0  150         0
Moderate     0    0        90

=== Random Forest ===
Calibration:
  Brier (OVR) per class: {'High': '0.0006', 'Low': '0.0006', 'Moderate': '0.0012'}
  Brier (multiclass): 0.0023
  Log loss: 0.0083
Classification Report:
              precision    recall  f1-score   support

        Hig