# Interpretabilidad de modelos ML (SHAP)

Notebook para analizar importancias y explicaciones locales de modelos tradicionales (RandomForest / XGBoost) entrenados con features PSG.

> Ajusta las rutas de modelo y dataset según tus artefactos guardados.


In [None]:
import glob
import os
import re
from collections import defaultdict
from datetime import datetime

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from scipy.stats import kendalltau, spearmanr
from sklearn.inspection import permutation_importance

# Input data and model artifacts
DATA_PATH = "../data/processed/features_resamp200.parquet"

# Available trained models; switch MODEL_NAME to analyse another one.
MODELS = dict(
    xgboost_loso="../models/xgb_loso_best/xgboost_model.pkl",
    random_forest_bayes="../models/rf_opt_bayes_best/random_forest_model.pkl",
)
MODEL_NAME = "xgboost_loso"  # use "random_forest_bayes" for the RF model
MODEL_PATH = MODELS[MODEL_NAME]

# Every figure/CSV produced by this notebook goes under a per-model folder.
OUTPUT_DIR = "../reports/shap_" + MODEL_NAME
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Sleep-stage encoding: integer class -> stage name
STAGE_LABELS = dict(enumerate(["W", "N1", "N2", "N3", "REM"]))

print(f"Data: {DATA_PATH}\nModelo: {MODEL_NAME} -> {MODEL_PATH}\nOutput: {OUTPUT_DIR}")

In [None]:
# ------------------------------------------------------------
# Load data and model
# ------------------------------------------------------------

# Load the feature table (parquet; anything else is read as feather). Model
# inputs are every column except the known label/metadata columns below.
if DATA_PATH.endswith(".parquet"):
    df = pd.read_parquet(DATA_PATH)
else:
    df = pd.read_feather(DATA_PATH)

# Adjust if your label column has a different name
target_col = "stage"
meta_cols = [
    col
    for col in (
        target_col,
        "subject_core",
        "subject_id",
        "epoch_time_start",
        "epoch_index",
    )
    if col in df.columns
]
feature_cols = [col for col in df.columns if col not in meta_cols]

X = df[feature_cols]
y = None if target_col not in df.columns else df[target_col]

print(f"Features shape: {X.shape}")
print(f"Features: {feature_cols[:10]}{'...' if len(feature_cols) > 10 else ''}")

# Class balance of the label column, when present
if y is not None:
    class_counts = y.value_counts().sort_index()
    print(f"\nDistribución de clases ({y.nunique()} clases):")
    for cls, count in class_counts.items():
        label = STAGE_LABELS.get(cls, str(cls))
        share = 100 * count / len(y)
        print(f"  {label} ({cls}): {count:,} ({share:.1f}%)")

# Load the trained model.
# NOTE: joblib.load unpickles arbitrary objects; only load artifacts you trust.
model = joblib.load(MODEL_PATH)
print(f"\nModelo: {type(model).__name__}")

# Align the feature matrix with the columns the model was fitted on. sklearn
# estimators (and XGBoost's sklearn wrapper) expose `feature_names_in_` when
# trained on a DataFrame with string column names. Selecting by it guarantees
# the same columns in the same order. Predictions, SHAP values and the
# importances below then line up with `feature_cols`. Models trained without
# names keep the dataset's column selection unchanged.
_model_features = getattr(model, "feature_names_in_", None)
if _model_features is not None:
    _model_features = [str(f) for f in _model_features]
    _missing = [f for f in _model_features if f not in df.columns]
    if _missing:
        raise KeyError(
            f"Faltan {len(_missing)} features del modelo en el dataset: {_missing[:10]}"
        )
    if _model_features != feature_cols:
        feature_cols = _model_features
        X = df[feature_cols]
        print(f"Features alineadas con el modelo: {len(feature_cols)} columnas")

# Show the model's built-in importances when available
if hasattr(model, "feature_importances_"):
    imp_native = pd.DataFrame(
        {"feature": feature_cols, "importance": model.feature_importances_}
    ).sort_values("importance", ascending=False)
    print("\nTop 10 features (importancia nativa):")
    print(imp_native.head(10).to_string(index=False))

# Optional: subsample rows (fixed seed) so SHAP stays fast on large datasets;
# labels are taken for exactly the same rows.
MAX_SAMPLES = 8000
if len(X) <= MAX_SAMPLES:
    X_sample, y_sample = X, y
    print(f"\nUsando dataset completo ({len(X_sample):,} muestras) para SHAP")
else:
    X_sample = X.sample(MAX_SAMPLES, random_state=42)
    y_sample = None if y is None else y.loc[X_sample.index]
    print(f"\nUsando submuestra de {len(X_sample):,} para SHAP")

In [None]:
# ------------------------------------------------------------
# SHAP: global and local explanations
# ------------------------------------------------------------

# TreeExplainer is the efficient explainer for tree ensembles (RF / XGBoost).
# Probability output is tried first for interpretability; if SHAP rejects it,
# fall back to the default model output.
# NOTE(review): TreeExplainer generally only accepts model_output="probability"
# with interventional perturbation and a background dataset (`data=...`).
# Without one, this attempt most likely always falls back. SHAP values are then
# in raw units (log-odds margin for XGBoost). Confirm with the printed message.
try:
    explainer = shap.TreeExplainer(model, model_output="probability")
except Exception as e:
    print(
        f"TreeExplainer con probability falló ({e}); usando TreeExplainer por defecto"
    )
    explainer = shap.TreeExplainer(model)
else:
    print("Explainer: TreeExplainer con output=probability")

shap_values = explainer.shap_values(X_sample)

# Multi-class layout: older SHAP returns a list with one (n_samples, n_features)
# array per class, newer SHAP returns a single (n_samples, n_features, n_classes) array.
is_multiclass_list = isinstance(shap_values, list)
is_multiclass_3d = isinstance(shap_values, np.ndarray) and shap_values.ndim == 3

if is_multiclass_list:
    n_classes = len(shap_values)
elif is_multiclass_3d:
    n_classes = shap_values.shape[2]
else:
    n_classes = 1
print(
    f"SHAP values calculados. Multi-clase: {is_multiclass_list or is_multiclass_3d} ({n_classes} clases)"
)


# Main class for single-class plots: most frequent true label if available,
# otherwise the class the model predicts most often.
def _label_to_position(label, classes):
    """Map a class label to its position along the model's class axis.

    ``classes`` is the fitted ``classes_`` array, or ``None``. sklearn's
    ``predict_proba`` columns follow ``classes_``, and per-class SHAP outputs
    follow the model outputs. So the position, not the raw label, is the
    correct index into multi-class SHAP values. Falls back to ``int(label)``
    when ``classes`` is unavailable or does not contain ``label``.
    """
    if classes is not None:
        hits = np.flatnonzero(np.asarray(classes) == label)
        if hits.size:
            return int(hits[0])
    return int(label)


def _choose_main_class(y=None, X=None, clf=None):
    """Return the SHAP class index used by the single-class plots.

    Preference order:
    1. the most frequent label in ``y`` (ties resolved by ``Series.mode``,
       i.e. the smallest tied label);
    2. otherwise the label the model predicts most often on ``X``;
    3. otherwise ``0``.

    The chosen label is converted to its position in ``clf.classes_`` when the
    model exposes it. Arguments default to the notebook globals ``y_sample``,
    ``X_sample`` and ``model``, so the zero-argument call keeps working.
    """
    labels = y_sample if y is None else y
    clf = model if clf is None else clf
    if labels is not None and len(labels) > 0:
        label = pd.Series(labels).mode().iloc[0]
    elif hasattr(clf, "predict"):
        # Most frequent prediction over the whole sample, not just one row
        data = X_sample if X is None else X
        label = pd.Series(np.asarray(clf.predict(data)).ravel()).mode().iloc[0]
    else:
        return 0
    return _label_to_position(label, getattr(clf, "classes_", None))


main_class = _choose_main_class()
print(
    "Clase principal para plots: "
    f"{STAGE_LABELS.get(main_class, main_class)} ({main_class})"
)


# For multi-class models, average |SHAP| across classes to get a global ranking
def _mean_abs_shap(shap_vals):
    """Collapse SHAP values to per-sample, per-feature magnitudes.

    Accepts a list of per-class 2D arrays (older SHAP), a 3D array with the
    class axis last (newer SHAP), or a plain 2D array. Returns |SHAP| averaged
    over classes, or just |SHAP| for single-output models.
    """
    if isinstance(shap_vals, list):
        per_class = np.abs(np.asarray(shap_vals))  # class axis first
        return per_class.mean(axis=0)
    has_class_axis = isinstance(shap_vals, np.ndarray) and shap_vals.ndim == 3
    return np.abs(shap_vals).mean(axis=2) if has_class_axis else np.abs(shap_vals)


# Global ranking: per-sample |SHAP| (class-averaged), then averaged over samples
shap_abs = _mean_abs_shap(shap_values)
mean_abs = shap_abs.mean(axis=0)
top_idx = np.argsort(mean_abs)[::-1]  # feature indices, most important first

# Print the 15 strongest features
print("\nTop 15 features por SHAP (|mean|):")
for rank, idx in zip(range(1, 16), top_idx[:15]):
    print(f"  {rank:2d}. {feature_cols[idx]}: {mean_abs[idx]:.4f}")

In [None]:
# ------------------------------------------------------------
# SHAP: global plots (bar + beeswarm)
# ------------------------------------------------------------

# Re-derive the focus class if this cell runs before the one that defines it
if "main_class" not in locals():
    main_class = _choose_main_class() if "_choose_main_class" in locals() else 0

is_multiclass = is_multiclass_list or is_multiclass_3d

# summary_plot's multi-output bar chart is driven by a *list* of per-class
# (n_samples, n_features) matrices. Newer SHAP returns a single 3D array,
# which not every SHAP version's summary_plot splits by class. A list works
# across versions.
if is_multiclass_3d:
    shap_for_bar = [shap_values[:, :, k] for k in range(n_classes)]
else:
    shap_for_bar = shap_values

# Readable legend entries, in model-output (classes_) order
bar_class_names = (
    [STAGE_LABELS.get(c, str(c)) for c in getattr(model, "classes_", range(n_classes))]
    if is_multiclass
    else None
)

# Summary bar (global importances; stacked by class for multi-class models)
plt.figure(figsize=(10, 8))
shap.summary_plot(
    shap_for_bar,
    X_sample,
    plot_type="bar",
    class_names=bar_class_names,
    show=False,
    max_display=20,
)
plt.title("SHAP Feature Importances (global)")
plt.tight_layout()
plt.savefig(
    os.path.join(OUTPUT_DIR, "shap_importances_bar.png"), dpi=150, bbox_inches="tight"
)
plt.show()

# Beeswarm: multi-class models show one representative class (main_class)
if is_multiclass_list:
    sv_beeswarm = shap_values[main_class]
elif is_multiclass_3d:
    sv_beeswarm = shap_values[:, :, main_class]
else:
    sv_beeswarm = shap_values

plt.figure(figsize=(10, 8))
shap.summary_plot(sv_beeswarm, X_sample, show=False, max_display=20)
if is_multiclass:
    plt.title(f"SHAP Beeswarm - Clase {STAGE_LABELS.get(main_class, main_class)}")
else:
    plt.title("SHAP Beeswarm")
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "shap_beeswarm.png"), dpi=150, bbox_inches="tight")
plt.show()

print(f"Plots guardados en {OUTPUT_DIR}")

In [None]:
# ------------------------------------------------------------
# Per-class SHAP: stage-specific analysis
# ------------------------------------------------------------

# One beeswarm per class (at most 5 panels): which features drive each stage
n_classes_to_plot = min(n_classes, 5)

# squeeze=False always yields a 2D grid, so a single panel indexes the same way
fig, axes = plt.subplots(
    1, n_classes_to_plot, figsize=(5 * n_classes_to_plot, 6), squeeze=False
)

for class_i, ax in zip(range(n_classes_to_plot), axes[0]):
    class_label = STAGE_LABELS.get(class_i, str(class_i))
    # Per-class slice of the SHAP values (single-output models reuse the matrix)
    sv_class = (
        shap_values[class_i]
        if is_multiclass_list
        else shap_values[:, :, class_i]
        if is_multiclass_3d
        else shap_values
    )
    plt.sca(ax)
    shap.summary_plot(sv_class, X_sample, show=False, max_display=10, plot_size=None)
    ax.set_title(f"Clase: {class_label}")

plt.tight_layout()
plt.savefig(
    os.path.join(OUTPUT_DIR, "shap_beeswarm_by_class.png"), dpi=150, bbox_inches="tight"
)
plt.show()

# Table: top-5 features per class by mean |SHAP|
print("\nTop 5 features más importantes por clase:")
print("-" * 60)
for class_i in range(n_classes):
    class_label = STAGE_LABELS.get(class_i, str(class_i))
    sv_class = (
        shap_values[class_i]
        if is_multiclass_list
        else shap_values[:, :, class_i]
        if is_multiclass_3d
        else shap_values
    )
    mean_abs_class = np.abs(sv_class).mean(axis=0)
    top5_idx = np.argsort(mean_abs_class)[::-1][:5]
    top5_features = [feature_cols[k] for k in top5_idx]
    print(f"{class_label}: {', '.join(top5_features)}")

In [None]:
# ------------------------------------------------------------
# SHAP: dependence plots for the top features
# ------------------------------------------------------------

N_TOP = 5
# Never request more panels than there are features
n_dep = min(N_TOP, len(feature_cols))

# dependence_plot needs a 2D (n_samples, n_features) matrix for a single class
if "main_class" not in locals():
    main_class = _choose_main_class() if "_choose_main_class" in locals() else 0

if is_multiclass_list:
    shap_for_dep = shap_values[main_class]
elif is_multiclass_3d:
    shap_for_dep = shap_values[:, :, main_class]
else:
    shap_for_dep = shap_values

# squeeze=False keeps a 2D grid even for one panel, so indexing is uniform
fig, axes = plt.subplots(1, n_dep, figsize=(4 * n_dep, 4), squeeze=False)
for ax, idx in zip(axes[0], top_idx[:n_dep]):
    fname = feature_cols[idx]
    shap.dependence_plot(fname, shap_for_dep, X_sample, show=False, ax=ax)
    ax.set_title(fname)
plt.tight_layout()
# File name follows the actual panel count (unchanged "top5" for the default)
plt.savefig(
    os.path.join(OUTPUT_DIR, f"shap_dependence_top{n_dep}.png"),
    dpi=150,
    bbox_inches="tight",
)
plt.show()

print(f"Dependence plots para top {n_dep} features guardados")

In [None]:
# ------------------------------------------------------------
# Local SHAP: explain individual cases
# ------------------------------------------------------------

# Pick the row (index label) to inspect, e.g. the first example or a false negative
sample_idx = X_sample.index[0]
x_single = X_sample.loc[[sample_idx]]
shap_value_single = explainer.shap_values(x_single)

# Example info
true_label = y_sample.loc[sample_idx] if y_sample is not None else "?"
pred_label = model.predict(x_single)[0]
pred_proba = (
    model.predict_proba(x_single)[0] if hasattr(model, "predict_proba") else None
)

# Class labels in model-output order (predict_proba columns / SHAP class axis)
model_classes = getattr(model, "classes_", None)

print(f"Ejemplo idx={sample_idx}")
print(f"  Etiqueta real: {STAGE_LABELS.get(true_label, true_label)}")
print(f"  Predicción:    {STAGE_LABELS.get(pred_label, pred_label)}")
if pred_proba is not None:
    proba_classes = (
        model_classes if model_classes is not None else range(len(pred_proba))
    )
    print(
        f"  Probabilidades: {dict(zip([STAGE_LABELS.get(c, c) for c in proba_classes], pred_proba.round(3)))}"
    )

# Multi-class: explain the predicted class. Its SHAP slice is the label's
# position in classes_, which only equals the label for 0..K-1 encodings.
is_multi_single = isinstance(shap_value_single, list) or (
    isinstance(shap_value_single, np.ndarray) and shap_value_single.ndim == 3
)
if is_multi_single:
    matches = (
        np.flatnonzero(np.asarray(model_classes) == pred_label)
        if model_classes is not None
        else np.array([], dtype=int)
    )
    class_idx = int(matches[0]) if matches.size else int(pred_label)
    if isinstance(shap_value_single, list):
        values = shap_value_single[class_idx][0]
    else:
        values = shap_value_single[0, :, class_idx]
    base_value = explainer.expected_value[class_idx]
else:
    values = shap_value_single[0]
    base_value = explainer.expected_value

plt.figure(figsize=(10, 8))
shap.plots.waterfall(
    shap.Explanation(
        values=values,
        base_values=base_value,
        data=x_single.iloc[0].values,  # the exact row that was explained
        feature_names=feature_cols,
    ),
    max_display=15,
    show=False,
)
# Show the SHAP units (explainer output, e.g. "raw" = log-odds for XGBoost)
plt.title(
    f"SHAP Waterfall - Predicción: {STAGE_LABELS.get(pred_label, pred_label)}"
    f" ({getattr(explainer, 'model_output', 'raw')})"
)
plt.tight_layout()
plt.savefig(
    os.path.join(OUTPUT_DIR, f"shap_waterfall_idx{sample_idx}.png"),
    dpi=150,
    bbox_inches="tight",
)
plt.show()

print(f"\nGuardado waterfall para idx={sample_idx}")

In [None]:
# ------------------------------------------------------------
# Error analysis: explain misclassified cases
# ------------------------------------------------------------

if y_sample is not None:
    from collections import Counter

    # Predictions on the sample; compare positionally as numpy arrays so the
    # row positions line up with the rows of shap_values
    y_pred = np.asarray(model.predict(X_sample))
    y_true = np.asarray(y_sample)
    errors_mask = y_true != y_pred
    n_errors = int(errors_mask.sum())
    accuracy = 1 - n_errors / len(y_sample)

    print(
        f"Accuracy en muestra: {accuracy:.3f} ({n_errors} errores de {len(y_sample)})"
    )

    if n_errors > 0:
        # Most frequent (true, predicted) confusions
        error_counts = Counter(zip(y_true[errors_mask], y_pred[errors_mask]))

        print("\nErrores más frecuentes (real -> predicho):")
        for (true_cls, pred_cls), count in error_counts.most_common(10):
            true_label = STAGE_LABELS.get(true_cls, str(true_cls))
            pred_label = STAGE_LABELS.get(pred_cls, str(pred_cls))
            print(f"  {true_label} -> {pred_label}: {count}")

        # Class labels in model-output order, used to locate each label's SHAP slice
        model_classes = getattr(model, "classes_", None)

        # SHAP of a few misclassified examples, addressed by row position
        error_positions = np.flatnonzero(errors_mask)
        n_error_examples = min(3, len(error_positions))

        print(
            f"\n--- Análisis SHAP de {n_error_examples} ejemplos mal clasificados ---"
        )

        fig, axes = plt.subplots(
            1, n_error_examples, figsize=(6 * n_error_examples, 5), squeeze=False
        )
        for ax, pos in zip(axes[0], error_positions[:n_error_examples]):
            true_cls, pred_cls = y_true[pos], y_pred[pos]
            true_label = STAGE_LABELS.get(true_cls, str(true_cls))
            pred_label = STAGE_LABELS.get(pred_cls, str(pred_cls))

            # SHAP slice of the *predicted* class (why the model chose it)
            matches = (
                np.flatnonzero(np.asarray(model_classes) == pred_cls)
                if model_classes is not None
                else np.array([], dtype=int)
            )
            class_idx = int(matches[0]) if matches.size else int(pred_cls)
            if is_multiclass_list:
                sv_err = shap_values[class_idx][pos]
            elif is_multiclass_3d:
                sv_err = shap_values[pos, :, class_idx]
            else:
                sv_err = shap_values[pos]

            # Ten features with the largest |SHAP| for this error
            top_err_idx = np.argsort(np.abs(sv_err))[::-1][:10]

            colors = ["red" if v > 0 else "blue" for v in sv_err[top_err_idx]]
            ax.barh(range(len(top_err_idx)), sv_err[top_err_idx], color=colors)
            ax.set_yticks(range(len(top_err_idx)))
            ax.set_yticklabels([feature_cols[j] for j in top_err_idx])
            ax.set_xlabel("SHAP value")
            ax.set_title(f"Error: {true_label} -> {pred_label}")
            ax.invert_yaxis()

        plt.tight_layout()
        plt.savefig(
            os.path.join(OUTPUT_DIR, "shap_error_analysis.png"),
            dpi=150,
            bbox_inches="tight",
        )
        plt.show()

        print("Rojo = empuja hacia predicción errónea, Azul = empuja contra")

In [None]:
# ------------------------------------------------------------
# Permutation importance (model-agnostic)
# ------------------------------------------------------------

if y_sample is None:
    print("No se encontró columna de etiqueta; se omite permutation importance")
else:
    print("Calculando permutation importance (puede tardar)...")
    # Accuracy drop when each feature column is shuffled, over 10 repeats
    pi = permutation_importance(
        model,
        X_sample,
        y_sample,
        scoring="accuracy",
        n_repeats=10,
        n_jobs=-1,
        random_state=42,
    )
    pi_df = pd.DataFrame(
        {
            "feature": feature_cols,
            "importance_mean": pi.importances_mean,
            "importance_std": pi.importances_std,
        }
    )
    pi_df = pi_df.sort_values("importance_mean", ascending=False)
    pi_df.to_csv(os.path.join(OUTPUT_DIR, "permutation_importance.csv"), index=False)

    # Horizontal bars for the 20 strongest features (most important on top)
    top_pi = pi_df.head(20)
    bar_positions = range(len(top_pi))
    plt.figure(figsize=(10, 8))
    plt.barh(
        bar_positions,
        top_pi["importance_mean"],
        xerr=top_pi["importance_std"],
        color="steelblue",
    )
    plt.yticks(bar_positions, top_pi["feature"])
    plt.xlabel("Decrease in accuracy")
    plt.title("Permutation Feature Importance (Top 20)")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(
        os.path.join(OUTPUT_DIR, "permutation_importance.png"),
        dpi=150,
        bbox_inches="tight",
    )
    plt.show()

    print("\nTop 10 features (permutation importance):")
    print(pi_df.head(10).to_string(index=False))

In [None]:
# ------------------------------------------------------------
# Comparison and correlation between importance methods
# ------------------------------------------------------------

# One row per feature. Each method adds a `<method>_importance` column and an
# integer `<method>_rank` column (1 = most important). rank(method="min")
# yields integer competition ranks. The default "average" gives x.5 values on
# ties, which astype(int) would truncate into collisions with the next rank.
comparison = pd.DataFrame({"feature": feature_cols})

# SHAP importance
comparison["shap_importance"] = mean_abs
comparison["shap_rank"] = (
    comparison["shap_importance"].rank(ascending=False, method="min").astype(int)
)

# Native importance (if available)
if hasattr(model, "feature_importances_"):
    comparison["native_importance"] = model.feature_importances_
    comparison["native_rank"] = (
        comparison["native_importance"].rank(ascending=False, method="min").astype(int)
    )

# Permutation importance (only if the permutation cell has run)
if y_sample is not None and "pi" in globals():
    comparison["perm_importance"] = pi.importances_mean
    comparison["perm_rank"] = (
        comparison["perm_importance"].rank(ascending=False, method="min").astype(int)
    )

# Sort by SHAP rank (stable, so tied features keep their column order)
comparison = comparison.sort_values("shap_rank", kind="stable")

# Show top 15
print("Comparación de rankings de importancia (Top 15):")
rank_cols = [c for c in comparison.columns if c.endswith("_rank")]
display_cols = ["feature"] + rank_cols
print(comparison[display_cols].head(15).to_string(index=False))

# --- Correlations between methods ---
print("\n" + "=" * 60)
print("CORRELACIÓN ENTRE MÉTODOS DE IMPORTANCIA")
print("=" * 60)

importance_cols = [c for c in comparison.columns if c.endswith("_importance")]
# Every unordered pair of methods, built once and reused for table and plots
method_pairs = [
    (col1, col2)
    for i, col1 in enumerate(importance_cols)
    for col2 in importance_cols[i + 1 :]
]

if method_pairs:
    corr_results = []
    spearman_by_pair = {}
    for col1, col2 in method_pairs:
        spearman_corr, spearman_p = spearmanr(comparison[col1], comparison[col2])
        kendall_corr, kendall_p = kendalltau(comparison[col1], comparison[col2])
        spearman_by_pair[(col1, col2)] = spearman_corr
        corr_results.append(
            {
                "Método 1": col1.replace("_importance", ""),
                "Método 2": col2.replace("_importance", ""),
                "Spearman ρ": f"{spearman_corr:.3f}",
                "p-value": f"{spearman_p:.2e}",
                "Kendall τ": f"{kendall_corr:.3f}",
                "Kendall p-value": f"{kendall_p:.2e}",
            }
        )

    corr_df = pd.DataFrame(corr_results)
    print("\nCorrelaciones de ranking:")
    print(corr_df.to_string(index=False))

    # Interpretation
    print("\nInterpretación:")
    print("  ρ/τ > 0.7: Alta concordancia entre métodos (robusto)")
    print("  ρ/τ 0.4-0.7: Concordancia moderada")
    print("  ρ/τ < 0.4: Baja concordancia (investigar diferencias)")

    # Scatter plot of importances, one panel per method pair
    n_pairs = len(method_pairs)
    fig, axes = plt.subplots(1, n_pairs, figsize=(6 * n_pairs, 5), squeeze=False)
    for ax, (col1, col2) in zip(axes[0], method_pairs):
        x_vals = comparison[col1].to_numpy(dtype=float)
        y_vals = comparison[col2].to_numpy(dtype=float)
        ax.scatter(x_vals, y_vals, alpha=0.5, s=20)
        ax.set_xlabel(col1.replace("_importance", "").upper())
        ax.set_ylabel(col2.replace("_importance", "").upper())

        # Linear trend line; skipped when x is constant (the fit is undefined)
        if np.ptp(x_vals) > 0:
            slope, intercept = np.polyfit(x_vals, y_vals, 1)
            x_line = np.linspace(x_vals.min(), x_vals.max(), 100)
            ax.plot(
                x_line, slope * x_line + intercept, "r--", alpha=0.8, label="Tendencia"
            )

        # Spearman correlation in the title
        ax.set_title(f"Spearman ρ = {spearman_by_pair[(col1, col2)]:.3f}")

    plt.tight_layout()
    plt.savefig(
        os.path.join(OUTPUT_DIR, "importance_correlation_scatter.png"),
        dpi=150,
        bbox_inches="tight",
    )
    plt.show()

# Save the full comparison
comparison.to_csv(os.path.join(OUTPUT_DIR, "importance_comparison.csv"), index=False)
print(f"\nComparación completa guardada en {OUTPUT_DIR}/importance_comparison.csv")

In [None]:
# ------------------------------------------------------------
# Importance by feature group
# ------------------------------------------------------------


# Feature groups inferred from typical PSG feature names.
# Adjust these patterns to match your own feature naming.
def categorize_feature(fname):
    """Assign a feature name to a coarse PSG feature group.

    Matching is case-insensitive. Checks run from most to least specific:

    - Connectivity comes first, because such features mention several channels.
    - EOG/EMG markers come next, so their band powers are not labelled as EEG.
    - Then Hjorth and entropy families.
    - Then EEG band names.
    - Summary statistics come last. Aggregation suffixes such as ``mean`` or
      ``std`` appear inside many specific features (e.g. ``perm_entropy_std``).

    Short abbreviations (``plv``, ``pli``, ``std``, ``var``, ``rms``) must be
    whole name tokens (split on non-alphanumerics). So e.g. ``amplitude`` no
    longer matches ``pli``.

    Parameters
    ----------
    fname : str
        Feature (column) name.

    Returns
    -------
    str
        One of ``"Connectivity"``, ``"EOG"``, ``"EMG"``, ``"Hjorth"``,
        ``"Entropy"``, ``"EEG_bands"``, ``"Statistical"`` or ``"Other"``.
    """
    name = str(fname).lower()
    tokens = set(re.split(r"[^0-9a-z]+", name))

    def has_substring(*patterns):
        return any(p in name for p in patterns)

    def has_token(*patterns):
        return any(p in tokens for p in patterns)

    if has_substring("coherence", "correlation") or has_token("plv", "pli"):
        return "Connectivity"
    if has_substring("eog", "eye"):
        return "EOG"
    if has_substring("emg", "chin"):
        return "EMG"
    if has_substring("hjorth", "mobility", "complexity"):
        return "Hjorth"
    if has_substring("entropy", "perm", "sample", "approx"):
        return "Entropy"
    if has_substring("delta", "theta", "alpha", "beta", "gamma", "sigma"):
        return "EEG_bands"
    if has_substring("mean", "skew", "kurt", "variance") or has_token(
        "std", "var", "rms"
    ):
        return "Statistical"
    return "Other"


# Assign each feature to its group, next to its global SHAP importance
feature_groups = pd.DataFrame(
    {
        "feature": feature_cols,
        "group": list(map(categorize_feature, feature_cols)),
        "shap_importance": mean_abs,
    }
)

# Aggregate importance per group (largest total first)
group_importance = (
    feature_groups.groupby("group")["shap_importance"]
    .agg(["sum", "mean", "count"])
    .sort_values("sum", ascending=False)
)

print("Importancia SHAP por grupo de features:")
print(group_importance.round(4).to_string())

# Two panels sharing group order and colours: total vs. mean importance per group
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors = plt.cm.Set2(np.linspace(0, 1, len(group_importance)))
group_positions = range(len(group_importance))
panel_specs = (
    ("sum", "SHAP importance (sum)", "Importancia total por grupo"),
    ("mean", "SHAP importance (mean)", "Importancia media por grupo"),
)
for ax, (stat, xlabel, title) in zip(axes, panel_specs):
    ax.barh(group_positions, group_importance[stat], color=colors)
    ax.set_yticks(group_positions)
    ax.set_yticklabels(group_importance.index)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.invert_yaxis()

plt.tight_layout()
plt.savefig(
    os.path.join(OUTPUT_DIR, "shap_by_feature_group.png"), dpi=150, bbox_inches="tight"
)
plt.show()

# How many features fall into each group
print("\nDistribución de features por grupo:")
print(feature_groups["group"].value_counts().to_string())

## Análisis avanzado

Las siguientes celdas pueden ser computacionalmente más costosas: valores de interacción SHAP entre features y análisis de estabilidad de los rankings mediante bootstrap.


In [None]:
# ------------------------------------------------------------
# SHAP Interaction Values
# ------------------------------------------------------------

# Interaction values hold an (n_features x n_features) matrix per sample (and
# per class), so memory and time grow quadratically with the feature count.
MAX_FEATS_INTERACTION = 200
if len(feature_cols) > MAX_FEATS_INTERACTION:
    print(
        f"Saltando SHAP interaction values: {len(feature_cols)} features (> {MAX_FEATS_INTERACTION}). Ajusta MAX_FEATS_INTERACTION si realmente lo necesitas."
    )
else:
    # Smaller subsample for interactions
    N_INTERACTION = min(500, len(X_sample))
    X_interaction = X_sample.sample(N_INTERACTION, random_state=42)

    print(f"Calculando SHAP interaction values para {N_INTERACTION} muestras...")
    print("(Esto puede tardar varios minutos)")

    if "main_class" not in locals():
        main_class = _choose_main_class() if "_choose_main_class" in locals() else 0

    shap_interaction = explainer.shap_interaction_values(X_interaction)

    # Multi-class: keep only the focus class (list layout or trailing class axis)
    is_multiclass_inter = isinstance(shap_interaction, list) or shap_interaction.ndim == 4
    if isinstance(shap_interaction, list):
        shap_inter_class = shap_interaction[main_class]
    elif shap_interaction.ndim == 4:
        shap_inter_class = shap_interaction[:, :, :, main_class]
    else:
        shap_inter_class = shap_interaction

    # Mean |interaction| over samples -> (n_features, n_features). The diagonal
    # holds main effects. SHAP splits each pairwise effect evenly between the
    # (i, j) and (j, i) entries, so an off-diagonal value is half the total.
    interaction_matrix = np.abs(shap_inter_class).mean(axis=0)

    # All pairs i < j (upper triangle, diagonal excluded), in the same row-major
    # order as a nested i/j loop
    n_feat = len(feature_cols)
    pair_i, pair_j = np.triu_indices(n_feat, k=1)
    feature_names_arr = np.asarray(feature_cols, dtype=object)
    interaction_df = pd.DataFrame(
        {
            "feature_1": feature_names_arr[pair_i],
            "feature_2": feature_names_arr[pair_j],
            "interaction_strength": interaction_matrix[pair_i, pair_j],
        }
    ).sort_values("interaction_strength", ascending=False)
    print("\nTop 10 interacciones entre features:")
    print(interaction_df.head(10).to_string(index=False))

    # Heatmap over the globally top-ranked features (never more than exist)
    N_TOP_HEAT = 15
    n_heat = min(N_TOP_HEAT, n_feat)
    top_feat_idx = top_idx[:n_heat]
    interaction_sub = interaction_matrix[np.ix_(top_feat_idx, top_feat_idx)]
    top_feat_names = [feature_cols[i] for i in top_feat_idx]

    heat_title = f"SHAP Interaction Matrix (Top {n_heat} features)"
    if is_multiclass_inter:
        heat_title += f" - Clase {STAGE_LABELS.get(main_class, main_class)}"

    plt.figure(figsize=(10, 8))
    plt.imshow(interaction_sub, cmap="YlOrRd")
    plt.xticks(range(n_heat), top_feat_names, rotation=45, ha="right")
    plt.yticks(range(n_heat), top_feat_names)
    plt.colorbar(label="Mean |SHAP interaction|")
    plt.title(heat_title)
    plt.tight_layout()
    plt.savefig(
        os.path.join(OUTPUT_DIR, "shap_interaction_heatmap.png"),
        dpi=150,
        bbox_inches="tight",
    )
    plt.show()

    interaction_df.to_csv(
        os.path.join(OUTPUT_DIR, "shap_interactions.csv"), index=False
    )
    print(f"Guardado en {OUTPUT_DIR}/shap_interactions.csv")

In [None]:
# ------------------------------------------------------------
# Ranking stability analysis (bootstrap)
# ------------------------------------------------------------

# TreeSHAP explains each row independently of the other rows in the batch, so
# a bootstrap replicate does not need to re-run the explainer. Resample rows of
# the per-sample |SHAP| matrix computed earlier (`shap_abs`: n_samples x
# n_features, already averaged over classes). This gives the same importances
# a fresh `explainer.shap_values` call on those rows would, at a tiny fraction
# of the cost, so many more replicates are affordable. A fixed-seed generator
# makes the table reproducible.
N_BOOTSTRAP = 50  # number of bootstrap replicates
BOOTSTRAP_SIZE = min(2000, len(X_sample))
BOOTSTRAP_SEED = 42

print(
    f"Analisis de estabilidad: {N_BOOTSTRAP} repeticiones bootstrap de {BOOTSTRAP_SIZE} muestras"
)

rng = np.random.default_rng(BOOTSTRAP_SEED)
n_rows = shap_abs.shape[0]
# (N_BOOTSTRAP, n_features): mean |SHAP| per feature in each replicate
boot_importance = np.vstack(
    [
        shap_abs[rng.integers(0, n_rows, size=BOOTSTRAP_SIZE)].mean(axis=0)
        for _ in range(N_BOOTSTRAP)
    ]
)
# 1-indexed rank of each feature within each replicate (1 = most important)
boot_ranks = np.argsort(np.argsort(-boot_importance, axis=1), axis=1) + 1

# Stability statistics per feature across replicates
stability_df = pd.DataFrame(
    {
        "feature": feature_cols,
        "rank_mean": boot_ranks.mean(axis=0),
        "rank_std": boot_ranks.std(axis=0),
        "rank_min": boot_ranks.min(axis=0),
        "rank_max": boot_ranks.max(axis=0),
    }
)
stability_df["rank_range"] = stability_df["rank_max"] - stability_df["rank_min"]
stability_df = stability_df.sort_values("rank_mean")

print("\nEstabilidad de rankings (Top 20 features):")
print(stability_df.head(20).to_string(index=False))

# Flag features whose rank varies a lot between replicates
unstable = stability_df[stability_df["rank_std"] > 5].head(10)
if len(unstable) > 0:
    print("\nATENCION: Features con ranking inestable (std > 5):")
    print(
        unstable[["feature", "rank_mean", "rank_std", "rank_range"]].to_string(
            index=False
        )
    )

# Plot rank mean +/- std for the top 20
top20_stability = stability_df.head(20)
plt.figure(figsize=(12, 6))
plt.errorbar(
    range(len(top20_stability)),
    top20_stability["rank_mean"],
    yerr=top20_stability["rank_std"],
    fmt="o",
    capsize=3,
    capthick=1,
)
plt.xticks(
    range(len(top20_stability)), top20_stability["feature"], rotation=45, ha="right"
)
plt.ylabel("Rank (mean +/- std)")
plt.xlabel("Feature")
plt.title(f"Estabilidad de rankings SHAP ({N_BOOTSTRAP} bootstrap)")
plt.gca().invert_yaxis()  # rank 1 at the top
plt.tight_layout()
plt.savefig(
    os.path.join(OUTPUT_DIR, "shap_stability_bootstrap.png"),
    dpi=150,
    bbox_inches="tight",
)
plt.show()

stability_df.to_csv(os.path.join(OUTPUT_DIR, "shap_stability.csv"), index=False)
print(f"\nGuardado en {OUTPUT_DIR}/shap_stability.csv")

In [None]:
# ------------------------------------------------------------
# Analysis summary and metadata
# ------------------------------------------------------------

banner = "=" * 70
print(banner)
print("RESUMEN DEL ANÁLISIS DE INTERPRETABILIDAD")
print(banner)

# Run metadata
for meta_line in (
    f"\nFecha: {datetime.now().strftime('%Y-%m-%d %H:%M')}",
    f"Modelo: {type(model).__name__}",
    f"Dataset: {DATA_PATH}",
    f"N muestras para SHAP: {len(X_sample):,}",
    f"N features: {len(feature_cols)}",
    f"N clases: {n_classes}",
):
    print(meta_line)

# Global top features
print("\n--- TOP 10 FEATURES (SHAP global) ---")
for rank, idx in zip(range(1, 11), top_idx[:10]):
    print(f"  {rank:2d}. {feature_cols[idx]}")

# Three strongest features per class (single-output models repeat the same matrix)
print("\n--- FEATURES CLAVE POR ETAPA DE SUEÑO ---")
for class_i in range(n_classes):
    if is_multiclass_list:
        sv_class = shap_values[class_i]
    elif is_multiclass_3d:
        sv_class = shap_values[:, :, class_i]
    else:
        sv_class = shap_values

    mean_abs_class = np.abs(sv_class).mean(axis=0)
    top3_idx = np.argsort(mean_abs_class)[::-1][:3]
    top3_features = [feature_cols[k] for k in top3_idx]
    print(f"  {STAGE_LABELS.get(class_i, str(class_i))}: {', '.join(top3_features)}")

# Files written to the output folder
print("\n--- ARCHIVOS GENERADOS ---")
for path in sorted(glob.glob(os.path.join(OUTPUT_DIR, "*"))):
    print(f"  {os.path.basename(path)}")

print("\n" + banner)
print("Análisis completado. Revisa los plots y CSVs en:", OUTPUT_DIR)
print(banner)