# Visualización de Resultados de Modelos

Este notebook permite cargar y visualizar los resultados de entrenamiento de los modelos de clasificación de estadios de sueño.

Carga métricas desde archivos JSON generados por `src/models.py` (o guardados en `artifacts/` para los modelos de Deep Learning) y genera tablas y gráficos comparativos (matplotlib/seaborn).

**Modelos soportados:**
- Random Forest
- XGBoost  
- CNN1D (Deep Learning)
- LSTM (Deep Learning): unidireccional, bidireccional y bidireccional con atención


## Configuración inicial


In [None]:
import json
from pathlib import Path
from statistics import NormalDist

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Image, display

# Global plotting theme used by every figure in this notebook.
sns.set_theme(palette="deep", style="whitegrid")

# Project layout. NOTE(review): assumes the notebook's working directory is
# one level below the project root (e.g. ``notebooks/``) — confirm.
PROJECT_ROOT = Path("..").resolve()
MODELS_DIR = PROJECT_ROOT.joinpath("models")
ARTIFACTS_DIR = PROJECT_ROOT.joinpath("artifacts")

# Canonical order of sleep stages (rows/columns of every confusion matrix)
# and the color assigned to each stage in plots.
STAGE_ORDER = "W N1 N2 N3 REM".split()
STAGE_COLORS = dict(
    zip(STAGE_ORDER, ("#fdae61", "#fee090", "#abd9e9", "#2c7bb6", "#d7191c"))
)

# Small constant added to denominators to avoid division by zero.
EPS = 1e-12


def load_metrics(path: Path) -> dict:
    """Load a metrics dictionary from a JSON file.

    Args:
        path: Path to the JSON metrics/results file.

    Returns:
        The parsed JSON content.
    """
    # Explicit UTF-8: JSON text is UTF-8 by specification, while the platform
    # default encoding (e.g. cp1252 on Windows) would mis-decode non-ASCII data.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def wilson_ci(
    successes: float, total: float, alpha: float = 0.05
) -> tuple[float, float]:
    """Return the Wilson score interval for a binomial proportion.

    Args:
        successes: Number of successes (numerator of the proportion).
        total: Number of trials; a non-positive value yields ``(nan, nan)``.
        alpha: Two-sided significance level (0.05 gives a 95% interval).

    Returns:
        ``(low, high)``, clipped to the ``[0, 1]`` range.
    """
    if total <= 0:
        return (np.nan, np.nan)

    z = NormalDist().inv_cdf(1 - alpha / 2)
    z_sq = z**2
    p_hat = successes / total
    shrink = 1 + z_sq / total
    midpoint = (p_hat + z_sq / (2 * total)) / shrink
    half_width = (
        z * np.sqrt(p_hat * (1 - p_hat) / total + z_sq / (4 * total**2)) / shrink
    )

    low = max(0, midpoint - half_width)
    high = min(1, midpoint + half_width)
    return low, high


def compute_class_metrics(
    cm: np.ndarray, labels=STAGE_ORDER, alpha: float = 0.05
) -> pd.DataFrame:
    """Derive one-vs-rest metrics for every class from a confusion matrix.

    Rows of ``cm`` are true classes and columns are predictions
    (``cm[i, j]`` counts samples of class ``i`` predicted as ``j``).

    Args:
        cm: Confusion matrix indexed in the same order as ``labels``.
        labels: Class names, one per row/column of ``cm``.
        alpha: Significance level for the Wilson intervals on recall and
            specificity.

    Returns:
        DataFrame with one row per class and columns ``Stage``, ``support``,
        ``precision``, ``recall``, ``specificity``, ``npv``, ``f1``, ``fpr``,
        ``fnr``, ``fdr``, ``recall_ci_low``, ``recall_ci_high``,
        ``spec_ci_low`` and ``spec_ci_high``. Ratios add ``EPS`` to their
        denominators, so an empty class yields 0 instead of a division error.
    """
    matrix = np.asarray(cm)
    grand_total = matrix.sum()
    rows = []
    for k, stage in enumerate(labels):
        true_pos = matrix[k, k]
        false_neg = matrix[k, :].sum() - true_pos
        false_pos = matrix[:, k].sum() - true_pos
        true_neg = grand_total - true_pos - false_neg - false_pos

        support = true_pos + false_neg  # actual members of this class
        predicted = true_pos + false_pos  # samples predicted as this class
        negatives = true_neg + false_pos  # actual members of other classes

        precision = true_pos / (predicted + EPS)
        recall = true_pos / (support + EPS)
        rec_lo, rec_hi = wilson_ci(true_pos, support, alpha=alpha)
        spec_lo, spec_hi = wilson_ci(true_neg, negatives, alpha=alpha)

        rows.append(
            dict(
                Stage=stage,
                support=support,
                precision=precision,
                recall=recall,
                specificity=true_neg / (negatives + EPS),
                npv=true_neg / (true_neg + false_neg + EPS),
                f1=2 * precision * recall / (precision + recall + EPS),
                fpr=false_pos / (negatives + EPS),
                fnr=false_neg / (support + EPS),
                fdr=false_pos / (predicted + EPS),
                recall_ci_low=rec_lo,
                recall_ci_high=rec_hi,
                spec_ci_low=spec_lo,
                spec_ci_high=spec_hi,
            )
        )

    return pd.DataFrame(rows)


def summarize_class_metrics(df: pd.DataFrame) -> dict[str, float]:
    """Aggregate per-class metrics into macro averages robust to imbalance.

    Args:
        df: One row per class with ``recall``, ``specificity``, ``precision``
            and ``f1`` columns (e.g. the output of ``compute_class_metrics``).

    Returns:
        Dict with ``balanced_accuracy`` (mean recall), ``macro_specificity``,
        ``macro_precision``, ``macro_f1``, ``gmean_sens_spec`` (geometric mean
        of balanced accuracy and macro specificity), ``chance_accuracy``
        (``1 / n_classes``) and ``balanced_accuracy_lift`` (relative gain of
        balanced accuracy over chance). Chance and lift are NaN for an empty
        ``df``.
    """
    n_rows = len(df)
    chance = np.nan if not n_rows else 1 / n_rows
    sens = df["recall"].mean()
    spec = df["specificity"].mean()
    lift = (sens / chance - 1) if chance else np.nan

    summary = {
        "balanced_accuracy": sens,
        "macro_specificity": spec,
        "macro_precision": df["precision"].mean(),
        "macro_f1": df["f1"].mean(),
        "gmean_sens_spec": np.sqrt(sens * spec),
        "chance_accuracy": chance,
        "balanced_accuracy_lift": lift,
    }
    return summary


def most_confused_pairs(
    cm: np.ndarray, labels=STAGE_ORDER, top_n: int = 5
) -> pd.DataFrame:
    """Return the most frequent off-diagonal confusions (true -> predicted).

    Args:
        cm: Confusion matrix with true classes on rows and predictions on
            columns, in the same order as ``labels``.
        labels: Class names, one per row/column of ``cm``.
        top_n: Maximum number of pairs to return.

    Returns:
        DataFrame with columns ``Verdadero``, ``Predicho``, ``Conteo`` and
        ``Proporción_sobre_clase`` (count divided by the true-class total),
        sorted by ``Conteo`` descending; pairs with equal counts keep the
        row-major order of ``cm``. Empty (with those columns) when there are
        no off-diagonal counts.
    """
    columns = ["Verdadero", "Predicho", "Conteo", "Proporción_sobre_clase"]
    cm = np.asarray(cm)
    records = []
    for i, true_label in enumerate(labels):
        row_sum = cm[i, :].sum()
        for j, pred_label in enumerate(labels):
            if i == j:
                continue  # diagonal = correct predictions, not confusions
            count = int(cm[i, j])
            if count > 0:
                records.append(
                    {
                        "Verdadero": true_label,
                        "Predicho": pred_label,
                        "Conteo": count,
                        "Proporción_sobre_clase": count / (row_sum + EPS),
                    }
                )

    if not records:
        return pd.DataFrame(columns=columns)

    df = pd.DataFrame(records, columns=columns)
    # Stable sort: the default quicksort leaves ties in arbitrary order, which
    # makes the cut at ``top_n`` (and the table order) unpredictable.
    return df.sort_values(by="Conteo", ascending=False, kind="mergesort").head(
        top_n
    )


# Location of each model's metrics/results JSON: tree-based models live under
# models/, deep-learning runs under artifacts/<run>_artifacts/.
MODEL_PATHS = dict(
    random_forest=MODELS_DIR.joinpath(
        "rf_opt_bayes_best", "random_forest_metrics.json"
    ),
    xgboost=MODELS_DIR.joinpath("xgb_loso_best", "xgboost_metrics.json"),
    cnn1d=ARTIFACTS_DIR.joinpath(
        "cnn1d_full_20251210_201502_artifacts",
        "cnn1d_full_20251210_201502_results.json",
    ),
    lstm_unidir=ARTIFACTS_DIR.joinpath(
        "lstm_full_unidir_artifacts",
        "lstm_full_20251210_193039_results.json",
    ),
    lstm_bidir=ARTIFACTS_DIR.joinpath(
        "lstm_full_bidir_artifacts",
        "lstm_full_20251211_031820_results.json",
    ),
    lstm_bidir_attention=ARTIFACTS_DIR.joinpath(
        "lstm_full_bidir_attention_artifacts",
        "lstm_full_20251211_145034_results.json",
    ),
)

# Model identifiers, in declaration order.
AVAILABLE_MODELS = [*MODEL_PATHS]

print("Modelos disponibles:", AVAILABLE_MODELS)

## Cargar métricas de un modelo


In [None]:
# Select the model to inspect. Options (keys of MODEL_PATHS): "random_forest",
# "xgboost", "cnn1d", "lstm_unidir", "lstm_bidir", "lstm_bidir_attention".
model_type = "xgboost"

metrics_path = MODEL_PATHS.get(model_type)

if metrics_path and metrics_path.exists():
    metrics = load_metrics(metrics_path)
    print(f"Métricas cargadas para {model_type}")
    print(f"Timestamp: {metrics.get('timestamp', 'N/A')}")
    print(
        f"Tipo de modelo: {metrics.get('model_type', metrics.get('model_name', 'N/A'))}"
    )
    print(
        f"Parámetros/Config: {metrics.get('model_params', metrics.get('config', {}))}"
    )
else:
    # Reset to an empty dict so the `... in metrics` checks in the cells below
    # evaluate to False, instead of raising NameError on a fresh kernel or
    # silently showing the metrics of a previously selected model.
    metrics = {}
    print(f"No se encontró archivo de métricas para {model_type}")
    print(f"Path buscado: {metrics_path}")
    print("\nModelos disponibles con archivos:")
    for name, path in MODEL_PATHS.items():
        status = "✓" if path.exists() else "✗"
        print(f"  {status} {name}: {path}")

## Métricas generales


In [None]:
# Headline metrics of the selected model. Missing entries print "N/A" rather
# than a fake 0.0000 that could be mistaken for a real score.
if "metrics" in metrics:
    m = metrics["metrics"]
    print("=" * 60)
    print("MÉTRICAS GENERALES")
    print("=" * 60)
    for label, key in (
        ("Accuracy:        ", "accuracy"),
        ("Cohen's Kappa:   ", "kappa"),
        ("F1-score (macro): ", "f1_macro"),
        ("F1-score (weighted): ", "f1_weighted"),
    ):
        value = m.get(key)
        print(f"{label}{value:.4f}" if value is not None else f"{label}N/A")

    if "cv_results" in metrics:
        cv = metrics["cv_results"]
        mean_score, std_score = cv.get("mean_score"), cv.get("std_score")
        mean_txt = "N/A" if mean_score is None else f"{mean_score:.4f}"
        std_txt = "N/A" if std_score is None else f"{std_score:.4f}"
        print("\nCross-Validation:")
        print(f"  Mean F1-macro: {mean_txt} ± {std_txt}")
else:
    print("No hay métricas generales disponibles.")

## F1-score por estadio


In [None]:
# Bar chart and table of the F1 score per sleep stage (stages absent from the
# JSON are shown as 0.0).
if "metrics" in metrics and "f1_per_class" in metrics["metrics"]:
    per_stage_f1 = metrics["metrics"]["f1_per_class"]

    f1_df = pd.DataFrame(
        {
            "Stage": STAGE_ORDER,
            "F1-score": [per_stage_f1.get(s, 0.0) for s in STAGE_ORDER],
        }
    )

    stage_colors = [STAGE_COLORS.get(s, "gray") for s in STAGE_ORDER]
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_container = ax.bar(
        f1_df["Stage"],
        f1_df["F1-score"],
        color=stage_colors,
        alpha=0.7,
        edgecolor="black",
    )

    # Write each bar's value just above its top edge.
    for rect in bar_container:
        top = rect.get_height()
        x_center = rect.get_x() + rect.get_width() / 2.0
        ax.text(
            x_center, top, f"{top:.3f}", ha="center", va="bottom", fontweight="bold"
        )

    ax.set_ylabel("F1-score", fontsize=12)
    ax.set_xlabel("Estadio de Sueño", fontsize=12)
    title = f"F1-score por Estadio - {model_type.upper()}"
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_ylim([0, 1.1])
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()
    plt.show()

    f1_table = f1_df.style.format({"F1-score": "{:.4f}"})
    display(f1_table.background_gradient(cmap="YlGnBu", subset=["F1-score"]))

## Matriz de confusión


In [None]:
# Confusion matrix: raw counts (left) and row-normalised proportions (right).
if "confusion_matrix" in metrics:
    cm = np.array(metrics["confusion_matrix"])

    # Normalise each row (true class) so its cells are that class's recall
    # distribution. EPS avoids 0/0 for classes without samples, matching the
    # metric helpers defined at the top of the notebook.
    cm_normalized = cm.astype("float") / (cm.sum(axis=1)[:, np.newaxis] + EPS)

    # The "d" format only works on integer arrays; counts stored as floats in
    # the JSON (e.g. 12.0) would make the heatmap annotation raise ValueError.
    count_fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".0f"

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Absolute counts
    sns.heatmap(
        cm,
        annot=True,
        fmt=count_fmt,
        cmap="Blues",
        xticklabels=STAGE_ORDER,
        yticklabels=STAGE_ORDER,
        ax=axes[0],
        cbar_kws={"label": "Count"},
    )
    axes[0].set_title("Matriz de Confusión (Absoluta)", fontsize=14, fontweight="bold")
    axes[0].set_ylabel("Verdadero", fontsize=12)
    axes[0].set_xlabel("Predicho", fontsize=12)

    # Row-normalised proportions
    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt=".3f",
        cmap="Blues",
        xticklabels=STAGE_ORDER,
        yticklabels=STAGE_ORDER,
        ax=axes[1],
        cbar_kws={"label": "Proporción"},
        vmin=0,
        vmax=1,
    )
    axes[1].set_title(
        "Matriz de Confusión (Normalizada)", fontsize=14, fontweight="bold"
    )
    axes[1].set_ylabel("Verdadero", fontsize=12)
    axes[1].set_xlabel("Predicho", fontsize=12)

    plt.tight_layout()
    plt.show()

## Análisis complementario

Se amplía el análisis para interpretar los resultados con mayor rigor: métricas derivadas de la matriz de confusión (recall, especificidad, precisión, NPV y tasas de error), intervalos de confianza binomiales de Wilson por clase, exactitud balanceada (balanced accuracy) como medida robusta al desbalance de clases, y los pares de confusión más frecuentes, que revelan errores sistemáticos.


In [None]:
alpha = 0.05  # significance level -> 95% confidence intervals
ci_label = f"IC {1 - alpha:.0%}"  # e.g. "IC 95%"; keeps labels in sync with alpha
if "confusion_matrix" in metrics:
    cm = np.array(metrics["confusion_matrix"])
    class_metrics = compute_class_metrics(cm, labels=STAGE_ORDER, alpha=alpha)
    summary = summarize_class_metrics(class_metrics)

    # Support (true samples) per class and its share of the total.
    support_df = class_metrics[["Stage", "support"]].copy()
    support_df["proportion"] = support_df["support"] / support_df["support"].sum()
    print("Distribución de soportes por clase:")
    display(
        support_df.style.format(
            {"support": "{:,.0f}", "proportion": "{:.2%}"}
        ).background_gradient(cmap="Purples", subset=["proportion"])
    )

    # Per-class metrics with Wilson intervals on recall and specificity.
    # NPV is TN / (TN + FN), the negative predictive value; the "TN rate" is
    # specificity, so the NPV column is labelled accordingly.
    npv_col = "npv (neg. pred. value)"
    metric_cols = [
        "precision",
        "recall",
        "specificity",
        "npv",
        "f1",
        "fpr",
        "fnr",
        "fdr",
        "recall_ci_low",
        "recall_ci_high",
        "spec_ci_low",
        "spec_ci_high",
    ]
    per_class_table = class_metrics[["Stage", *metric_cols]].rename(
        columns={"npv": npv_col}
    )
    print(f"\nMétricas por clase con {ci_label} (Wilson):")
    display(
        per_class_table.style.format(
            {col: "{:.3f}" for col in per_class_table.columns if col != "Stage"}
        ).background_gradient(
            cmap="YlGnBu",
            subset=["precision", "recall", "specificity", npv_col, "f1"],
        )
    )

    # Robust macro summary and lift over chance.
    print("\nResumen macro robusto:")
    summary_labels = {
        "balanced_accuracy": "Balanced accuracy (sensibilidad macro)",
        "macro_specificity": "Especificidad macro",
        "macro_precision": "Precisión macro",
        "macro_f1": "F1 macro",
        "gmean_sens_spec": "G-mean (sensibilidad x especificidad)",
        "chance_accuracy": "Azar (1/n clases)",
        "balanced_accuracy_lift": "Lift sobre azar (balanced acc)",
    }
    for key, label in summary_labels.items():
        value = summary.get(key, np.nan)
        if pd.isna(value):
            print(f"{label}: N/A")
        else:
            print(f"{label}: {value:.3f}")

    # Recall per class with asymmetric Wilson error bars.
    colors = [STAGE_COLORS.get(stage, "gray") for stage in class_metrics["Stage"]]
    fig, ax = plt.subplots(figsize=(10, 6))
    # Clip at 0: for a class with recall 0 the Wilson lower bound is 0 only up
    # to floating-point rounding, which can produce a tiny negative error length
    # that matplotlib rejects ("'yerr' must not contain negative values").
    yerr = np.clip(
        np.vstack(
            [
                class_metrics["recall"] - class_metrics["recall_ci_low"],
                class_metrics["recall_ci_high"] - class_metrics["recall"],
            ]
        ),
        0,
        None,
    )
    ax.bar(
        class_metrics["Stage"],
        class_metrics["recall"],
        color=colors,
        edgecolor="black",
        alpha=0.8,
        yerr=yerr,
        capsize=6,
    )
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Sensibilidad / Recall")
    ax.set_title(f"Recall por clase con {ci_label} (Wilson)", fontweight="bold")
    ax.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Most frequent confusion pairs (systematic errors).
    conf_df = most_confused_pairs(cm, labels=STAGE_ORDER, top_n=5)
    print("\nPares de confusión más frecuentes:")
    if not conf_df.empty:
        display(
            conf_df.style.format(
                {"Conteo": "{:,.0f}", "Proporción_sobre_clase": "{:.2%}"}
            ).background_gradient(
                cmap="Reds", subset=["Conteo", "Proporción_sobre_clase"]
            )
        )
    else:
        print("No se registran confusiones relevantes (fuera de la diagonal).")
else:
    print("No hay matriz de confusión disponible en las métricas.")

## Comparar múltiples modelos


In [None]:
# Load the metrics of every model whose file exists. The loop uses its own
# names (`name`, `path`) so it does not overwrite the `model_type` /
# `metrics_path` chosen in the single-model cell, which would mislabel the
# plots if those cells are re-run afterwards.
all_metrics = {}
for name, path in MODEL_PATHS.items():
    # A directory entry means "use the *_results.json inside it"; sort so the
    # choice does not depend on filesystem listing order.
    if path.is_dir():
        json_files = sorted(path.glob("*_results.json"))
        if not json_files:
            continue
        path = json_files[0]

    if path.exists():
        all_metrics[name] = load_metrics(path)
        print(f"✓ Cargado: {name}")

print(f"\nTotal de modelos cargados: {len(all_metrics)}")

# Display column -> key inside each JSON's "metrics" section.
score_keys = {
    "Accuracy": "accuracy",
    "Kappa": "kappa",
    "F1-macro": "f1_macro",
    "F1-weighted": "f1_weighted",
}

if len(all_metrics) > 1:
    # Missing metrics become NaN (shown as "N/A") instead of a misleading 0.
    comparison = [
        {
            "Modelo": model_name,
            **{
                col: model_metrics["metrics"].get(key, np.nan)
                for col, key in score_keys.items()
            },
        }
        for model_name, model_metrics in all_metrics.items()
        if "metrics" in model_metrics
    ]

    if comparison:
        df_comp = pd.DataFrame(comparison)
        metrics_to_plot = list(score_keys)
        display(
            df_comp.style.format(
                {col: "{:.4f}" for col in metrics_to_plot}, na_rep="N/A"
            ).background_gradient(cmap="RdYlGn", subset=metrics_to_plot)
        )

        # Grouped bar chart: one group per model, one bar per metric.
        fig, ax = plt.subplots(figsize=(12, 6))
        x = np.arange(len(df_comp))
        width = 0.8 / len(metrics_to_plot)  # 0.2 for the four metrics
        for i, metric in enumerate(metrics_to_plot):
            ax.bar(x + i * width, df_comp[metric], width, label=metric, alpha=0.8)

        ax.set_ylabel("Score", fontsize=12)
        ax.set_title("Comparación de Modelos", fontsize=14, fontweight="bold")
        # Centre each tick under its group of bars.
        ax.set_xticks(x + width * (len(metrics_to_plot) - 1) / 2)
        ax.set_xticklabels(df_comp["Modelo"], rotation=45, ha="right")
        ax.legend()
        ax.grid(axis="y", alpha=0.3)
        # Cohen's kappa can be negative; extend the axis so such bars stay visible.
        lowest = df_comp[metrics_to_plot].min().min()
        ax.set_ylim([min(0.0, lowest), 1])

        plt.tight_layout()
        plt.show()
else:
    print("Necesitas al menos 2 modelos para comparar")

In [None]:
# Show the confusion-matrix PNG saved for one deep-learning run.

# Each PNG sits next to its run's *_results.json and shares its file prefix,
# so the image paths are derived from MODEL_PATHS instead of being repeated.
CONFUSION_MATRIX_IMAGES = {
    key: MODEL_PATHS[key].with_name(
        MODEL_PATHS[key].name.replace("_results.json", "_confusion_matrix.png")
    )
    for key in ("cnn1d", "lstm_unidir", "lstm_bidir", "lstm_bidir_attention")
}

model_to_show = "cnn1d"  # any key of CONFUSION_MATRIX_IMAGES

cm_path = CONFUSION_MATRIX_IMAGES.get(model_to_show)

if not (cm_path and cm_path.exists()):
    print(f"No se encontró imagen de matriz de confusión para {model_to_show}")
    print(f"Path buscado: {cm_path}")
else:
    print(f"Matriz de Confusión - {model_to_show.upper()}")
    display(Image(filename=str(cm_path), width=700))

In [None]:
# Grid with the saved confusion-matrix PNGs of all deep-learning runs.

# Display title -> MODEL_PATHS key; each PNG shares its run's results prefix.
cm_images = {
    title: MODEL_PATHS[key].with_name(
        MODEL_PATHS[key].name.replace("_results.json", "_confusion_matrix.png")
    )
    for title, key in (
        ("CNN1D", "cnn1d"),
        ("LSTM Unidir", "lstm_unidir"),
        ("LSTM Bidir", "lstm_bidir"),
        ("LSTM Bidir+Attn", "lstm_bidir_attention"),
    )
}

fig, axes = plt.subplots(2, 2, figsize=(16, 14))
axes = axes.flatten()

# zip stops at the shorter sequence, so images beyond the 2x2 grid are skipped.
for panel, (name, path) in zip(axes, cm_images.items()):
    if not path.exists():
        panel.text(0.5, 0.5, f"No encontrado:\n{name}", ha="center", va="center")
    else:
        panel.imshow(mpimg.imread(str(path)))
        panel.set_title(name, fontsize=14, fontweight="bold")
    panel.axis("off")

plt.suptitle(
    "Matrices de Confusión - Modelos Deep Learning", fontsize=16, fontweight="bold"
)
plt.tight_layout()
plt.show()