# CNN1D para Sleep Staging

Este notebook entrena un modelo **CNN1D con conexiones residuales** para clasificación de estadios de sueño usando datos PSG **trimmed** (preprocesados y recortados del dataset Sleep-EDF).

**Optimizado para Kaggle con 2x Tesla T4 (16GB VRAM cada una)**

### Modos de ejecución:
- **Debug** (`EXECUTION_MODE = "debug"`): Carga datos en RAM, usa subset de sujetos, ~30 seg de setup
- **Full** (`EXECUTION_MODE = "full"`): Streaming TFRecord, todos los sujetos, ~30 min de setup

### Características:
- CNN1D con bloques residuales
- Data augmentation con ruido gaussiano (opcional)
- Soporte multi-GPU con MirroredStrategy
- Optimización de hiperparámetros con Optuna (opcional)
- **División train/val/test por SUJETOS (sin data leakage)** 
- Soporte para reanudar entrenamiento desde checkpoints
- Semillas fijas para reproducibilidad (SEED=42)

### Datos requeridos:
- Dataset `sleep-edf-trimmed-f32` en Kaggle (https://www.kaggle.com/datasets/ignaciolinari/sleep-edf-trimmed-f32)
  - `manifest_trimmed_spt.csv` (1 episodio por noche, 100 Hz)
  - `sleep_trimmed_spt/psg/*.fif` (PSG a 100 Hz, float32)
  - `sleep_trimmed_spt/hypnograms/*.csv` (anotaciones)
- Si usas la versión 200 Hz, apunta a `manifest_trimmed_resamp200.csv` + carpeta `sleep_trimmed_resamp200/`.
- Nota Kaggle: la versión 200 Hz puede consumir más VRAM; deja `sfreq=100` o usa el dataset 100 Hz para evitar OOM.
- Si usas otro slug, actualiza `DATA_PATH` abajo.

In [None]:
# ============================================================
# INITIAL SETUP - Run this cell first
# ============================================================

import os
import warnings

# Silence every Python warning for the rest of the session
warnings.filterwarnings("ignore")

# Detect whether we are running on Kaggle (its input mount exists)
IN_KAGGLE = os.path.exists("/kaggle/input")
print(f"Entorno: {'Kaggle' if IN_KAGGLE else 'Local'}")

# Paths depend on the environment
if IN_KAGGLE:
    # On Kaggle: adjust to the name of your dataset
    DATA_PATH = "/kaggle/input/sleep-edf-trimmed-f32/sleep_trimmed_spt"  # <- Adjust to the slug/version you use
    OUTPUT_PATH = "/kaggle/working"
else:
    # Local
    DATA_PATH = "../data/processed"
    OUTPUT_PATH = "../models"

# Make sure the output directory exists before any artifact is saved
os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Data path: {DATA_PATH}")
print(f"Output path: {OUTPUT_PATH}")

In [None]:
# ============================================================
# CHECK AVAILABLE GPUs
# ============================================================

import tensorflow as tf  # noqa: E402

# Global seed for reproducibility
SEED = 42
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("\nGPUs disponibles:")

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")

    # Enable memory growth to avoid OOM
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    print(f"\n[OK] {len(gpus)} GPU(s) configuradas con memory growth")
else:
    print("[WARN] No se detectaron GPUs. El entrenamiento sera lento.")

# Distribution strategy: mirror across GPUs only when more than one is present
if len(gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
    print(f"\nUsando MirroredStrategy con {strategy.num_replicas_in_sync} GPUs")
else:
    # Default (no-op) strategy for the single-GPU and CPU cases
    strategy = tf.distribute.get_strategy()
    if len(gpus) == 1:
        print("\nUsando estrategia por defecto (1 GPU)")
    else:
        print("\n[WARN] Usando CPU - entrenamiento será lento")
        print("       Para mejor rendimiento, ejecuta en entorno con GPU")


# Mixed precision can improve performance on modern GPUs (T4, V100, A100)
# NOTE: Disabled by default because it can cause NaN in some models
# It is enabled from CONFIG['use_mixed_precision'] once CONFIG has been defined
print("[INFO] Mixed precision se configurará después de definir CONFIG")

In [None]:
# ============================================================
# IMPORTS Y DEPENDENCIAS
# ============================================================

import gc  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
import math  # noqa: E402
import pickle  # noqa: E402
import random  # noqa: E402
import re  # noqa: E402
import shutil  # noqa: E402
import time  # noqa: E402
import zipfile  # noqa: E402
from collections import Counter  # noqa: E402
from pathlib import Path  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402

# Global seeds for reproducibility (NumPy, Python's random, and Keras/TensorFlow)
np.random.seed(SEED)
random.seed(SEED)
tf.keras.utils.set_random_seed(SEED)

from datetime import datetime  # noqa: E402

from sklearn.metrics import (  # noqa: E402
    accuracy_score,
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
)
from sklearn.preprocessing import LabelEncoder  # noqa: E402
from sklearn.utils.class_weight import compute_class_weight  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from tensorflow.keras.callbacks import (  # noqa: E402
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)

# Configure logging: INFO level with timestamp, level name and message
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Configure plot style
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

print("[OK] Imports completados")

In [None]:
# ============================================================
# INSTALL DEPENDENCIES (Kaggle only)
# ============================================================

# Import MNE; if it is missing, install it with pip when running on Kaggle and
# retry the import. Outside Kaggle the ImportError is re-raised unchanged.
try:
    import mne
except ImportError:
    if IN_KAGGLE:
        print("Instalando dependencias...")
        %pip install -q mne
        print("[OK] Dependencias instaladas")
        import mne
    else:
        raise

# Only show MNE messages at ERROR level
mne.set_log_level("ERROR")
print(f"MNE version: {mne.__version__}")

## Configuración del Experimento

In [None]:
# ============================================================
# HYPERPARAMETERS - ADJUST AS NEEDED
# ============================================================

# =====================================================
# EXECUTION MODE: "debug" or "full"
# =====================================================
# - "debug": Loads data into RAM (fast, ~30 s), uses a subset of subjects
# - "full":  TFRecord streaming (slow, ~30 min), uses all subjects
# =====================================================
EXECUTION_MODE = "debug"  # Change to "full" for a complete training run
# =====================================================

# Experiment configuration. Every later cell reads its settings from this dict.
CONFIG = {
    # Execution mode
    "execution_mode": EXECUTION_MODE,
    # Dataset (100 Hz).
    # On Kaggle the manifest sits next to the DATA_PATH folder, so derive its
    # location from DATA_PATH instead of repeating the dataset slug: changing
    # DATA_PATH to another slug then keeps both paths in sync.
    "manifest_path": (
        str(Path(DATA_PATH).parent / "manifest_trimmed_spt.csv")
        if IN_KAGGLE
        else f"{DATA_PATH}/manifest_trimmed_spt.csv"
    ),
    "epoch_length": 30.0,  # seconds
    "sfreq": 100,  # Hz (resample to 100 Hz to reduce RAM)
    # Subject limit depending on the mode
    "debug_max_subjects": 30,  # Subjects used in debug mode
    "limit_sessions": None,  # None = all sessions (obsolete, use debug_max_subjects)
    # Subject-wise split
    "test_size": 0.15,
    "val_size": 0.15,
    "random_state": 42,
    # CNN1D model
    "n_filters": 32,  # Reduced for better stability
    "kernel_size": 3,  # Small kernel to avoid exploding gradients
    "dropout_rate": 0.3,  # Increased for regularization
    "use_residual": True,
    "use_augmentation": False,
    # Training
    "learning_rate_initial": 3e-4,  # Regular LR
    "learning_rate_min": 1e-6,  # Minimum LR
    "warmup_epochs": 3,  # Warmup epochs
    "batch_size": 64,  # Base batch size
    # More epochs in debug mode to compensate for the limited steps
    "epochs": 300 if EXECUTION_MODE == "full" else 150,
    "use_class_weights": True,  # Enable to handle class imbalance
    "class_weight_clip": 1.5,  # Upper bound for class weights
    "std_epsilon": 1e-6,  # floor for std during normalization
    "clip_value": 5.0,  # Bound for normalized values (reduced from 8.0)
    # More patience with short epochs
    "early_stopping_patience": 40 if EXECUTION_MODE == "full" else 25,
    # Mixed precision (can cause NaN in some models)
    "use_mixed_precision": False,  # Keep disabled to avoid NaN
    # Optimization (Optuna)
    "run_optimization": False,
    "n_optuna_trials": 30,
    # TFRecord streaming (full mode only)
    "streaming": EXECUTION_MODE == "full",
    "tfrecord_dir": f"{OUTPUT_PATH}/tfrecords",
    "shuffle_buffer": 5000,
}

# Scale the batch size by the number of replicas for multi-GPU training
CONFIG["effective_batch_size"] = CONFIG["batch_size"] * strategy.num_replicas_in_sync

# Configure mixed precision according to CONFIG (must run after CONFIG is defined);
# it is only attempted when at least one GPU was detected
if CONFIG.get("use_mixed_precision", False) and gpus:
    try:
        from tensorflow.keras import mixed_precision

        mixed_precision.set_global_policy("mixed_float16")
        print("[OK] Mixed precision (float16) habilitado")
    except Exception as e:
        # Best effort: report the failure and continue with the current policy
        print(f"[WARN] No se pudo habilitar mixed precision: {e}")
else:
    print("[INFO] Mixed precision desactivado (float32) para mayor estabilidad")

# Show the execution mode prominently
print("\n" + "=" * 60)
if EXECUTION_MODE == "debug":
    print(" MODO DEBUG: Carga rápida en RAM, subset de sujetos")
    print(f"   Max sujetos: {CONFIG['debug_max_subjects']}")
    print(f"   Epochs: {CONFIG['epochs']}")
    print("   Para entrenamiento completo, cambiar EXECUTION_MODE = 'full'")
else:
    print(" MODO FULL: Streaming TFRecord, todos los sujetos")
    print(f"   Epochs: {CONFIG['epochs']}")
    print("   Para debugging rápido, cambiar EXECUTION_MODE = 'debug'")
print("=" * 60)

# Dump the full configuration for the run log
print("\nConfiguracion del experimento:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

## Carga de Datos

In [None]:
# ============================================================
# FUNCIONES DE CARGA DE DATOS
# ============================================================

# Sleep stages (AASM mapping): annotation description -> canonical label
STAGE_CANONICAL = {
    "Sleep stage W": "W",
    "Sleep stage 1": "N1",
    "Sleep stage 2": "N2",
    "Sleep stage 3": "N3",
    "Sleep stage 4": "N3",  # AASM merges S3+S4 into N3
    "Sleep stage R": "REM",
}

# Class order; the position in this list is used as the integer label
STAGE_ORDER = ["W", "N1", "N2", "N3", "REM"]

# REQUIRED channels (fixed order for consistency across files)
REQUIRED_CHANNELS = ["EEG Fpz-Cz", "EEG Pz-Oz"]
# Optional channels, appended after the required ones when a file has them
OPTIONAL_CHANNELS = ["EOG horizontal", "EMG submental"]


def extract_subject_core(subject_id):
    """Return the per-subject key shared by all nights of one subject.

    Grouping by this key is what keeps every night of a subject in the same
    split, which is required to avoid data leakage.

    Sleep-EDF recording IDs embed a two-digit subject number followed by the
    night number and a suffix:

    - Sleep cassette:  ``SC4XXNy`` (e.g. ``SC4011E``, ``SC4012E`` -> ``SC401``)
    - Sleep telemetry: ``ST7XXNy`` (e.g. ``ST7011J``, ``ST7012J`` -> ``ST701``)

    The telemetry prefix is handled as well so that telemetry nights are
    grouped per subject instead of being treated as independent subjects.

    Args:
        subject_id: Recording identifier; it is converted with ``str()``.

    Returns:
        The ``SC4XX`` / ``ST7XX`` prefix when the ID starts with one,
        otherwise the ID itself (as a string).
    """
    sid = str(subject_id)
    # Study prefix + the two subject digits; the night digit and suffix are dropped
    match = re.match(r"((?:SC4|ST7)\d{2})", sid)
    return match.group(1) if match else sid


def load_psg_data(psg_path, channels=None, target_sfreq=None):
    """Read a PSG recording from a .fif file and return its samples.

    Args:
        psg_path: Path of the .fif file.
        channels: Channel names to keep. When None, REQUIRED_CHANNELS are used
            (in their fixed order) followed by whichever OPTIONAL_CHANNELS the
            file contains.
        target_sfreq: When truthy and different from the file's sampling rate,
            the recording is resampled to this rate.

    Returns:
        Tuple ``(data, sfreq, ch_names)`` taken from the (possibly resampled)
        recording after channel selection.

    Raises:
        ValueError: If ``channels`` is None and a required channel is missing.
    """
    raw = mne.io.read_raw_fif(str(psg_path), preload=True, verbose="ERROR")
    present = set(raw.ch_names)

    if channels is None:
        # Fail early: a missing required channel would change the input layout
        absent = [name for name in REQUIRED_CHANNELS if name not in present]
        if absent:
            raise ValueError(
                f"Canales requeridos faltantes en {psg_path}: {absent}. "
                f"Disponibles: {sorted(present)}"
            )

        # Required channels first, then the optional ones this file provides
        extras = [name for name in OPTIONAL_CHANNELS if name in present]
        channels = [*REQUIRED_CHANNELS, *extras]

    raw.pick(channels)  # pick_channels() is deprecated

    needs_resample = bool(target_sfreq) and raw.info["sfreq"] != target_sfreq
    if needs_resample:
        raw.resample(target_sfreq)

    return raw.get_data(), raw.info["sfreq"], raw.ch_names


def load_hypnogram(hyp_path):
    """Read a hypnogram CSV and add a ``stage_canonical`` column.

    The new column maps the ``description`` column through STAGE_CANONICAL;
    descriptions without an entry in that mapping become NaN.
    """
    hypnogram = pd.read_csv(hyp_path)
    canonical_labels = hypnogram["description"].map(STAGE_CANONICAL)
    hypnogram["stage_canonical"] = canonical_labels
    return hypnogram


def create_epochs(data, sfreq, epoch_length=30.0):
    """Split a continuous recording into fixed-length, non-overlapping epochs.

    Trailing samples that do not fill a whole epoch are dropped.

    Args:
        data: 2-D array of shape (n_channels, n_samples).
        sfreq: Sampling frequency in Hz.
        epoch_length: Epoch duration in seconds.

    Returns:
        Tuple ``(epochs, epoch_times)`` where ``epochs`` has shape
        (n_epochs, n_channels, samples_per_epoch) and ``epoch_times`` holds the
        start time in seconds of every epoch. When the recording is shorter
        than one epoch, ``epochs`` is empty but keeps its three dimensions.

    Raises:
        ValueError: If ``epoch_length * sfreq`` is less than one sample.
    """
    samples_per_epoch = int(epoch_length * sfreq)
    if samples_per_epoch <= 0:
        raise ValueError(
            f"epoch_length * sfreq debe ser >= 1 muestra (obtenido {samples_per_epoch})"
        )

    n_channels, n_samples = data.shape
    n_epochs = n_samples // samples_per_epoch

    # Drop the incomplete tail, then view the rest as (channel, epoch, sample)
    usable = data[:, : n_epochs * samples_per_epoch]
    by_channel = usable.reshape(n_channels, n_epochs, samples_per_epoch)
    # Epoch-major layout; copy so the result does not alias the input buffer
    epochs = by_channel.transpose(1, 0, 2).copy()

    epoch_times = np.arange(n_epochs) * epoch_length
    return epochs, epoch_times


def assign_stages(epoch_times, hypnogram, epoch_length=30.0):
    """Label every epoch with the stage annotated at its midpoint.

    Args:
        epoch_times: Iterable of epoch start times in seconds.
        hypnogram: DataFrame with ``onset``, ``duration`` and
            ``stage_canonical`` columns.
        epoch_length: Epoch duration in seconds.

    Returns:
        List with one entry per epoch: the ``stage_canonical`` value of the
        first annotation whose interval [onset, onset + duration) contains the
        epoch midpoint, or None when no annotation covers it.
    """
    half_epoch = epoch_length / 2
    labels = []

    for start_time in epoch_times:
        midpoint = start_time + half_epoch
        starts_before = hypnogram["onset"] <= midpoint
        ends_after = hypnogram["onset"] + hypnogram["duration"] > midpoint
        covering = hypnogram[starts_before & ends_after]

        # The first covering annotation wins when several overlap
        label = covering.iloc[0]["stage_canonical"] if len(covering) > 0 else None
        labels.append(label)

    return labels


print("[OK] Funciones de carga definidas")

In [None]:
# ============================================================
# PIPELINE DE DATOS (DEBUG: RAM / FULL: TFRecord)
# ============================================================


def resolve_paths(row, manifest_dir, dataset_dir_name):
    """Resolve the PSG and hypnogram file paths for one manifest row.

    Three cases are handled:

    - The row carries both ``psg_trimmed_path`` and ``hypnogram_trimmed_path``
      and we run on Kaggle: both paths are re-rooted under DATA_PATH, starting
      at the first component that begins with ``sleep_trimmed``. If either
      path has no such component, both fall back to
      ``DATA_PATH / dataset_dir_name / {psg,hypnograms} / <file name>``.
    - The row carries both paths and we run locally: absolute paths are used
      as-is, paths starting with ``data`` are re-rooted under the parent of
      ``manifest_dir``, and other relative paths are taken relative to
      ``manifest_dir``.
    - Either path is missing/NaN: the file names are rebuilt from
      ``subject_id``, ``subset`` and ``version`` under
      ``manifest_dir / dataset_dir_name``.

    NOTE(review): on Kaggle the re-rooted path keeps the ``sleep_trimmed*``
    component and DATA_PATH itself ends in one - confirm this matches the
    dataset layout.

    Returns:
        Tuple ``(psg_path, hyp_path)`` of Path objects.
    """

    def _blank_if_nan(value):
        # Missing manifest cells arrive as NaN; treat them as "no path"
        return "" if pd.isna(value) else value

    def _anchor_index(parts):
        # Index of the first path component starting with "sleep_trimmed"
        for idx, part in enumerate(parts):
            if part.startswith("sleep_trimmed"):
                return idx
        return None

    psg_path_str = _blank_if_nan(row.get("psg_trimmed_path", ""))
    hyp_path_str = _blank_if_nan(row.get("hypnogram_trimmed_path", ""))

    base_data_root = manifest_dir.parent

    def _resolve_local(rel_path):
        # Local layout: absolute paths win, "data/..." hangs off the data root
        if rel_path.is_absolute():
            return rel_path
        if rel_path.parts and rel_path.parts[0] == "data":
            return base_data_root / rel_path.relative_to("data")
        return manifest_dir / rel_path

    if not (psg_path_str and hyp_path_str):
        # No explicit paths in the manifest: rebuild the conventional names
        subset = row.get("subset", "sleep-cassette")
        version = row.get("version", "1.0.0")
        dataset_dir = manifest_dir / dataset_dir_name
        psg_name = f"{row['subject_id']}_{subset}_{version}_trimmed_raw.fif"
        hyp_name = f"{row['subject_id']}_{subset}_{version}_trimmed_annotations.csv"
        return dataset_dir / "psg" / psg_name, dataset_dir / "hypnograms" / hyp_name

    psg_rel = Path(psg_path_str)
    hyp_rel = Path(hyp_path_str)

    if not IN_KAGGLE:
        return _resolve_local(psg_rel), _resolve_local(hyp_rel)

    kaggle_root = Path(DATA_PATH)
    psg_anchor = _anchor_index(psg_rel.parts)
    hyp_anchor = _anchor_index(hyp_rel.parts)

    if psg_anchor is None or hyp_anchor is None:
        # Without an anchor in both paths, fall back to the flat folder layout
        fallback_dir = kaggle_root / dataset_dir_name
        return (
            fallback_dir / "psg" / psg_rel.name,
            fallback_dir / "hypnograms" / hyp_rel.name,
        )

    return (
        kaggle_root / Path(*psg_rel.parts[psg_anchor:]),
        kaggle_root / Path(*hyp_rel.parts[hyp_anchor:]),
    )


def get_subject_cores_for_mode(
    manifest_path, execution_mode, debug_max_subjects, random_state
):
    """Select the subject cores to use for the given execution mode.

    Reads the manifest, keeps rows whose ``status`` is "ok", and shuffles the
    unique subject cores with a generator seeded by ``random_state``. In
    "debug" mode with a truthy ``debug_max_subjects`` only the first
    ``debug_max_subjects`` shuffled cores are kept; otherwise all are.

    Returns:
        Set of selected subject cores.
    """
    manifest = pd.read_csv(manifest_path)
    usable_rows = manifest[manifest["status"] == "ok"].copy()

    all_cores = usable_rows["subject_id"].apply(extract_subject_core).unique()
    np.random.default_rng(random_state).shuffle(all_cores)

    use_subset = execution_mode == "debug" and debug_max_subjects
    if not use_subset:
        selected_cores = set(all_cores)
        print(f"[FULL] Usando todos los {len(selected_cores)} sujetos")
        return selected_cores

    selected_cores = set(all_cores[:debug_max_subjects])
    print(f"[DEBUG] Usando {len(selected_cores)}/{len(all_cores)} sujetos")
    return selected_cores


def iter_sessions(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """Yield the resolved file paths of every usable session in the manifest.

    Only rows whose ``status`` is "ok" are considered, optionally restricted
    to the subject cores in ``allowed_cores``. Sessions whose PSG or hypnogram
    file does not exist are skipped; a summary warning is printed once the
    iteration finishes so that a wrong data path does not go unnoticed.

    Args:
        manifest_path: Path of the manifest CSV.
        epoch_length: Unused here; kept for a uniform call signature.
        sfreq: Unused here; kept for a uniform call signature.
        allowed_cores: Optional collection of subject cores to keep.

    Yields:
        Tuples ``(i, total_sessions, subject_id, subject_core, psg_path,
        hyp_path)`` where ``i`` is the 1-based position of the session among
        the filtered manifest rows (skipped sessions still consume an index).
    """

    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()

    if allowed_cores is not None:
        row_cores = manifest_ok["subject_id"].apply(extract_subject_core)
        manifest_ok = manifest_ok[row_cores.isin(allowed_cores)]

    manifest_dir = Path(manifest_path).parent
    # First dataset folder that exists next to the manifest, in priority order
    dataset_dir_name = next(
        (
            name
            for name in ("sleep_trimmed_resamp200", "sleep_trimmed_spt")
            if (manifest_dir / name).exists()
        ),
        "sleep_trimmed",
    )

    total_sessions = len(manifest_ok)
    print(f"\nProcesando {total_sessions} sesiones...")

    missing_sessions = []
    for i, (_, row) in enumerate(manifest_ok.iterrows(), start=1):
        subject_id = row["subject_id"]
        # Rows were already filtered by core above, so no second check is needed
        subject_core = extract_subject_core(subject_id)
        psg_path, hyp_path = resolve_paths(row, manifest_dir, dataset_dir_name)

        if not psg_path.exists() or not hyp_path.exists():
            missing_sessions.append(subject_id)
            continue

        yield i, total_sessions, subject_id, subject_core, psg_path, hyp_path

    if missing_sessions:
        print(
            f"[WARN] Se omitieron {len(missing_sessions)}/{total_sessions} sesiones "
            f"por archivos PSG/hipnograma faltantes (ej.: {missing_sessions[0]})"
        )


def update_running_stats(stats, epochs):
    """Accumulate per-channel sums and sums of squares.

    Args:
        stats: Accumulator dict with keys ``n``, ``sum`` and ``sumsq``, or
            None to start a new one. It is updated in place.
        epochs: Array indexed as (epoch, channel, sample).

    Returns:
        The accumulator dict (the same object when one was passed in).
    """
    n_channels = epochs.shape[1]

    if stats is None:
        stats = {"n": 0, "sum": None, "sumsq": None}

    if stats["sum"] is None:
        # Lazily size the float64 accumulators from the first batch
        stats["sum"] = np.zeros(n_channels, dtype=np.float64)
        stats["sumsq"] = np.zeros(n_channels, dtype=np.float64)

    # Reduce over epochs and samples, leaving one value per channel
    reduce_axes = (0, 2)
    batch_sum = epochs.sum(axis=reduce_axes)
    batch_sumsq = np.square(epochs).sum(axis=reduce_axes)

    stats["n"] += epochs.shape[0] * epochs.shape[2]
    stats["sum"] += batch_sum
    stats["sumsq"] += batch_sumsq
    return stats


def finalize_running_stats(stats, std_epsilon=1e-6):
    """Turn the accumulated sums into per-channel mean and std.

    Uses Var(X) = E[X^2] - E[X]^2, clamps the variance at zero, and returns
    ``sqrt(var + std_epsilon)`` as the std.

    NOTE(review): ``std_epsilon`` is added to the variance, so a channel whose
    variance is far below ``std_epsilon`` gets a std of about
    ``sqrt(std_epsilon)`` - confirm the signal scale makes the default
    negligible.

    Args:
        stats: Accumulator dict with keys ``n``, ``sum`` and ``sumsq``.
        std_epsilon: Value added to the variance before the square root.

    Returns:
        Tuple ``(mean, std)`` of float32 arrays.
    """
    count = stats["n"]
    mean = stats["sum"] / count
    second_moment = stats["sumsq"] / count
    # Rounding can push the difference slightly below zero; clamp it
    variance = np.maximum(second_moment - mean**2, 0.0)
    std = np.sqrt(variance + std_epsilon)
    return mean.astype(np.float32), std.astype(np.float32)


def load_all_data_to_ram(manifest_path, epoch_length, sfreq, allowed_cores):
    """Load every valid epoch of the allowed subjects into RAM (debug mode).

    For each session the PSG is loaded (resampled to ``sfreq`` if needed),
    split into epochs and labelled from the hypnogram. Epochs whose stage is
    not in STAGE_ORDER or that contain NaN/Inf are discarded.

    Kept epochs are converted to float32 as they are collected. This yields
    the same values as converting at the end, but avoids holding a reference
    to every session's full-precision array until the final conversion.

    Args:
        manifest_path: Path of the manifest CSV.
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency in Hz.
        allowed_cores: Subject cores to load (passed to ``iter_sessions``).

    Returns:
        Tuple ``(X, y, subject_cores, expected_channels)``: float32 epochs,
        int32 labels (indices into STAGE_ORDER), the subject core of every
        epoch, and the channel names shared by all sessions.

    Raises:
        ValueError: If a session's channel list differs from the first one.
    """

    print("\n[DEBUG] Cargando datos en RAM...")
    start_time = time.time()

    all_epochs = []
    all_stages = []
    all_subject_cores = []
    expected_channels = None
    nan_inf_count = 0

    for i, total, _subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores=allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)

        # Every session must expose the same channels in the same order
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(
                f"Canales inconsistentes: esperado {expected_channels}, encontrado {list(ch_names)} en {psg_path}"
            )

        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        for epoch, stage in zip(epochs, stages):
            if stage not in STAGE_ORDER:
                continue
            if not np.all(np.isfinite(epoch)):
                nan_inf_count += 1
                continue

            # astype copies, so the list no longer references the session array
            all_epochs.append(epoch.astype(np.float32))
            all_stages.append(stage)
            all_subject_cores.append(subject_core)

        if i % 10 == 0 or i == total:
            print(f"   Cargando: {i}/{total} sesiones")

    if nan_inf_count > 0:
        print(f"[WARN] Se descartaron {nan_inf_count} epochs con NaN/Inf")

    # Convert to arrays
    X = np.array(all_epochs, dtype=np.float32)
    y = np.array([STAGE_ORDER.index(s) for s in all_stages], dtype=np.int32)
    subject_cores = np.array(all_subject_cores)

    load_time = time.time() - start_time
    print(f"[OK] Datos cargados en {load_time:.1f} segundos")
    print(f"   Shape X: {X.shape}")
    print(f"   Shape y: {y.shape}")
    print(f"   Memoria: {X.nbytes / 1024**3:.2f} GB")

    return X, y, subject_cores, expected_channels


def split_data_by_subject(X, y, subject_cores, test_size, val_size, random_state):
    """Split the data into train/val/test by subject (no data leakage).

    The unique subject cores are shuffled with a generator seeded by
    ``random_state``; the first ones go to test, the next ones to validation
    and the rest to train. Test and validation each receive at least one
    subject.

    Args:
        X: Epoch array whose first axis matches ``subject_cores``.
        y: Label array whose first axis matches ``subject_cores``.
        subject_cores: Subject core of every epoch.
        test_size: Fraction of subjects assigned to test.
        val_size: Fraction of subjects assigned to validation.
        random_state: Seed for the shuffle.

    Returns:
        Tuple ``(X_train, y_train, X_val, y_val, X_test, y_test, train_cores,
        val_cores, test_cores)``; the last three are sets of subject cores.

    Raises:
        ValueError: If no subject is left for the train split.
    """

    unique_cores = np.unique(subject_cores)
    rng = np.random.default_rng(random_state)
    rng.shuffle(unique_cores)

    n_test = max(1, int(len(unique_cores) * test_size))
    n_val = max(1, int(len(unique_cores) * val_size))

    test_cores = set(unique_cores[:n_test])
    val_cores = set(unique_cores[n_test : n_test + n_val])
    train_cores = set(unique_cores[n_test + n_val :])

    # Fail here with a clear message instead of normalizing/training on nothing
    if not train_cores:
        raise ValueError(
            f"No quedan sujetos para train: {len(unique_cores)} sujetos disponibles, "
            f"{n_test} para test y {n_val} para val"
        )

    train_mask = np.isin(subject_cores, list(train_cores))
    val_mask = np.isin(subject_cores, list(val_cores))
    test_mask = np.isin(subject_cores, list(test_cores))

    return (
        X[train_mask],
        y[train_mask],
        X[val_mask],
        y[val_mask],
        X[test_mask],
        y[test_mask],
        train_cores,
        val_cores,
        test_cores,
    )


def normalize_data(X_train, X_val, X_test, clip_value, std_epsilon):
    """Z-score every split per channel using train statistics, then clip.

    - Mean and variance are computed on the train split only, to avoid leaking
      validation/test information.
    - The std is ``sqrt(var + std_epsilon)``, matching
      ``finalize_running_stats`` used by the streaming mode.
    - Normalized values are clipped to ``[-clip_value, clip_value]``.

    Args:
        X_train, X_val, X_test: Arrays indexed as (epoch, channel, sample).
        clip_value: Absolute bound applied after normalization.
        std_epsilon: Value added to the variance before the square root.

    Returns:
        Tuple ``(X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch)``.
    """

    # Per-channel statistics: reduce over the epoch and sample axes
    channel_axes = (0, 2)
    mean_ch = X_train.mean(axis=channel_axes)
    std_ch = np.sqrt(X_train.var(axis=channel_axes) + std_epsilon)

    # Broadcastable views of the statistics along the channel axis
    center = mean_ch[None, :, None]
    scale = std_ch[None, :, None]

    def _zscore_and_clip(block):
        return np.clip((block - center) / scale, -clip_value, clip_value)

    X_train_norm = _zscore_and_clip(X_train)
    X_val_norm = _zscore_and_clip(X_val)
    X_test_norm = _zscore_and_clip(X_test)

    return X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch


# ============================================================
# FUNCIONES PARA MODO STREAMING (TFRecord)
# ============================================================


def pass1_stats(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """First pass: per-channel mean/std and class counts over allowed subjects.

    Each session is loaded, split into epochs and labelled; epochs whose stage
    is not in STAGE_ORDER or that contain NaN/Inf are ignored. Running sums
    are accumulated so no more than one session's epochs is held at a time.
    The std floor is read from ``CONFIG["std_epsilon"]`` (default 1e-6).

    Args:
        manifest_path: Path of the manifest CSV.
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency in Hz.
        allowed_cores: Optional collection of subject cores to include.

    Returns:
        Tuple ``(mean, std, class_counts, input_shape, expected_channels)``
        where ``input_shape`` is the shape of a single epoch.

    Raises:
        ValueError: If a session's channel list differs from the first one,
            or if no valid epoch was found at all.
    """

    stats = None
    class_counts = Counter()
    input_shape = None
    expected_channels = None
    nan_inf_count = 0

    for i, total, _subject_id, _subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores=allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)

        # Every session must expose the same channels in the same order
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(
                f"Canales inconsistentes: esperado {expected_channels}, encontrado {list(ch_names)} en {psg_path}"
            )

        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        valid_mask = [s in STAGE_ORDER for s in stages]
        if not any(valid_mask):
            continue

        valid_epochs = epochs[valid_mask]
        valid_stages = [s for s in stages if s in STAGE_ORDER]

        # Drop epochs containing any NaN/Inf sample
        finite_mask = np.all(np.isfinite(valid_epochs), axis=(1, 2))
        if not np.all(finite_mask):
            nan_inf_count += int(np.sum(~finite_mask))
            valid_epochs = valid_epochs[finite_mask]
            valid_stages = [s for s, m in zip(valid_stages, finite_mask) if m]

        if len(valid_epochs) == 0:
            continue

        if input_shape is None:
            input_shape = valid_epochs.shape[1:]

        stats = update_running_stats(stats, valid_epochs)
        class_counts.update(valid_stages)

        if i % 20 == 0 or i == total:
            print(f"   Pasada 1: {i}/{total} sesiones")

    if nan_inf_count > 0:
        print(f"[WARN] Se descartaron {nan_inf_count} epochs con NaN/Inf en pasada 1")

    # Explicit raise instead of assert: asserts are stripped under `python -O`
    if input_shape is None:
        raise ValueError("No se encontraron epochs validos")

    mean, std = finalize_running_stats(
        stats, std_epsilon=CONFIG.get("std_epsilon", 1e-6)
    )
    return mean, std, class_counts, input_shape, expected_channels


def make_example(x, y):
    """Serialize one (signal, label) pair as a ``tf.train.Example``.

    The signal ``x`` is flattened into a float list under key ``"x"`` and the
    label ``y`` is stored as a single int64 under key ``"y"``.

    Returns:
        The serialized example as bytes.
    """
    signal_feature = tf.train.Feature(float_list=tf.train.FloatList(value=x.ravel()))
    label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))
    features = tf.train.Features(feature={"x": signal_feature, "y": label_feature})
    return tf.train.Example(features=features).SerializeToString()


def assign_subject_splits(
    manifest_path, test_size, val_size, random_state, allowed_cores=None
):
    """Assign every subject core of the manifest to train, val or test.

    Only manifest rows whose ``status`` is "ok" are considered, optionally
    restricted to ``allowed_cores``. The cores are shuffled with a generator
    seeded by ``random_state``; the first ones go to test, the next ones to
    validation and the rest to train. Test and validation each receive at
    least one subject.

    Args:
        manifest_path: Path of the manifest CSV.
        test_size: Fraction of subjects assigned to test.
        val_size: Fraction of subjects assigned to validation.
        random_state: Seed for the shuffle.
        allowed_cores: Optional collection of subject cores to include.

    Returns:
        Tuple ``(split_map, split_sizes)``: a dict mapping subject core to
        "train"/"val"/"test", and a dict with the subject count per split.

    Raises:
        ValueError: If no subject is left for the train split.
    """
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()

    subject_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()

    if allowed_cores is not None:
        subject_cores = np.array([c for c in subject_cores if c in allowed_cores])

    rng = np.random.default_rng(random_state)
    rng.shuffle(subject_cores)

    n_test = max(1, int(len(subject_cores) * test_size))
    n_val = max(1, int(len(subject_cores) * val_size))

    test_cores = set(subject_cores[:n_test])
    val_cores = set(subject_cores[n_test : n_test + n_val])
    train_cores = set(subject_cores[n_test + n_val :])

    # Fail here with a clear message instead of writing an empty train TFRecord
    if not train_cores:
        raise ValueError(
            f"No quedan sujetos para train: {len(subject_cores)} sujetos disponibles, "
            f"{n_test} para test y {n_val} para val"
        )

    split_map = dict.fromkeys(train_cores, "train")
    split_map.update(dict.fromkeys(val_cores, "val"))
    split_map.update(dict.fromkeys(test_cores, "test"))

    return split_map, {
        "train": len(train_cores),
        "val": len(val_cores),
        "test": len(test_cores),
    }


def build_tfrecord_splits(
    manifest_path,
    mean,
    std,
    split_map,
    epoch_length,
    sfreq,
    tfrecord_dir,
    expected_channels=None,
):
    """Second pass: normalize every valid epoch and write it to its split file.

    One TFRecord file per split (``train``, ``val``, ``test``) is written in
    ``tfrecord_dir``. Each epoch is z-scored with ``mean``/``std`` per
    channel, clipped to ``CONFIG["clip_value"]`` (default 5.0) and serialized
    with ``make_example``. Epochs whose stage is not in STAGE_ORDER, or that
    contain NaN/Inf before or after normalization, are skipped.

    The writers are closed in a ``finally`` block so that the files are
    flushed and released even when a session raises part-way through.

    Args:
        manifest_path: Path of the manifest CSV.
        mean: Per-channel mean used for normalization.
        std: Per-channel std used for normalization (already includes epsilon).
        split_map: Dict mapping subject core to "train"/"val"/"test".
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency in Hz.
        tfrecord_dir: Output directory (created if missing).
        expected_channels: Optional channel list every session must match.

    Returns:
        Tuple ``(tf_paths, counts, session_counts, subject_counts)``: file
        path per split, written epochs per split, sessions per split, and
        distinct subjects per split.

    Raises:
        ValueError: If a session's channels differ from ``expected_channels``.
    """
    tfrecord_dir = Path(tfrecord_dir)
    tfrecord_dir.mkdir(parents=True, exist_ok=True)

    split_names = ("train", "val", "test")
    tf_paths = {name: str(tfrecord_dir / f"{name}.tfrecord") for name in split_names}

    counts = Counter()
    session_counts = Counter()
    subject_sets = {name: set() for name in split_names}
    skipped_nan_inf = 0

    clip_value = CONFIG.get("clip_value", 5.0)
    allowed_cores = set(split_map.keys())

    writers = {}
    try:
        # Open inside the try so a failure opening a later writer still closes
        # the ones that were already created
        for name in split_names:
            writers[name] = tf.io.TFRecordWriter(tf_paths[name])

        for i, total, _subject_id, subject_core, psg_path, hyp_path in iter_sessions(
            manifest_path, epoch_length, sfreq, allowed_cores=allowed_cores
        ):
            split = split_map.get(subject_core)
            if split is None:
                continue
            session_counts[split] += 1
            subject_sets[split].add(subject_core)

            data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)

            if expected_channels is not None and list(ch_names) != expected_channels:
                raise ValueError(
                    f"Canales inconsistentes en {psg_path}: esperado {expected_channels}, encontrado {list(ch_names)}"
                )

            hypnogram = load_hypnogram(hyp_path)
            epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
            stages = assign_stages(epoch_times, hypnogram, epoch_length)

            for epoch, stage in zip(epochs, stages):
                if stage not in STAGE_ORDER:
                    continue

                if not np.all(np.isfinite(epoch)):
                    skipped_nan_inf += 1
                    continue

                y = STAGE_ORDER.index(stage)
                # std already includes epsilon from finalize_running_stats
                # (sqrt(var + epsilon)); do not add another one here, to stay
                # consistent with the RAM mode
                x = (epoch - mean[:, None]) / std[:, None]
                x = np.clip(x, -clip_value, clip_value)

                if not np.all(np.isfinite(x)):
                    skipped_nan_inf += 1
                    continue

                writers[split].write(make_example(x.astype(np.float32), y))
                counts[split] += 1

            if i % 20 == 0 or i == total:
                print(f"   Pasada 2 ({split}): {i}/{total} sesiones")
    finally:
        for w in writers.values():
            w.close()

    if skipped_nan_inf > 0:
        print(
            f"[WARN] Se descartaron {skipped_nan_inf} epochs con NaN/Inf al escribir TFRecords"
        )

    subject_counts = {split: len(subject_sets[split]) for split in subject_sets}
    return tf_paths, counts, session_counts, subject_counts


def make_dataset(tfrecord_path, input_shape, batch_size, shuffle=False, repeat=False):
    """Build a batched ``tf.data`` pipeline that reads one TFRecord file.

    Each record is parsed into a float32 tensor reshaped to ``input_shape``
    and an int32 label. The signal is clipped to ``CONFIG["clip_value"]``
    (default 5.0) and any remaining non-finite value is replaced by zero.

    Args:
        tfrecord_path: Path of the TFRecord file to read.
        input_shape: Two-element shape of one example; the flat "x" feature
            holds ``input_shape[0] * input_shape[1]`` floats.
        batch_size: Number of examples per batch.
        shuffle: If true, shuffle with ``CONFIG["shuffle_buffer"]`` and
            ``CONFIG["random_state"]``, reshuffling on every iteration.
        repeat: If true, repeat the dataset indefinitely.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x, y)`` batches, with prefetching.
    """
    flat_length = input_shape[0] * input_shape[1]
    schema = {
        "x": tf.io.FixedLenFeature([flat_length], tf.float32),
        "y": tf.io.FixedLenFeature([], tf.int64),
    }
    limit = CONFIG.get("clip_value", 5.0)

    def _decode(serialized):
        record = tf.io.parse_single_example(serialized, schema)
        signal = tf.clip_by_value(
            tf.reshape(record["x"], input_shape), -limit, limit
        )
        # Clipping does not remove NaN, so zero out anything non-finite.
        signal = tf.where(tf.math.is_finite(signal), signal, tf.zeros_like(signal))
        return signal, tf.cast(record["y"], tf.int32)

    pipeline = tf.data.TFRecordDataset(
        [tfrecord_path], num_parallel_reads=tf.data.AUTOTUNE
    ).map(_decode, num_parallel_calls=tf.data.AUTOTUNE)

    if shuffle:
        pipeline = pipeline.shuffle(
            CONFIG["shuffle_buffer"],
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
    if repeat:
        pipeline = pipeline.repeat()

    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# ============================================================
# RUN THE PIPELINE FOR THE CONFIGURED MODE
# ============================================================

print(f"\n[INFO] Modo de ejecución: {CONFIG['execution_mode'].upper()}")

# Pick the subjects to use for the configured execution mode; the result is
# passed to the split / loading helpers in both branches below.
selected_cores = get_subject_cores_for_mode(
    CONFIG["manifest_path"],
    CONFIG["execution_mode"],
    CONFIG["debug_max_subjects"],
    CONFIG["random_state"],
)

# Both branches leave the same names defined for the cells that follow:
# train_ds / val_ds / test_ds, INPUT_SHAPE, train_count / val_count /
# test_count, mean_ch / std_ch, class_counts_train, tfrecord_paths and
# USE_STREAMING.
if CONFIG["streaming"]:
    # =========== FULL MODE: TFRecord ===========
    print("\n[FULL] Iniciando pipeline TFRecord (streaming)...")

    # Start from a clean directory so stale TFRecords are never reused
    tfrecord_path = Path(CONFIG["tfrecord_dir"])
    if tfrecord_path.exists():
        print(f"[INFO] Limpiando TFRecords anteriores en {tfrecord_path}")
        shutil.rmtree(tfrecord_path)

    # Assign every selected subject to a split
    split_map, split_subjects = assign_subject_splits(
        CONFIG["manifest_path"],
        CONFIG["test_size"],
        CONFIG["val_size"],
        CONFIG["random_state"],
        allowed_cores=selected_cores,
    )
    train_cores = {core for core, split in split_map.items() if split == "train"}
    val_cores = {core for core, split in split_map.items() if split == "val"}
    test_cores = {core for core, split in split_map.items() if split == "test"}

    # Compute statistics on train ONLY (avoids data leakage)
    mean_ch, std_ch, class_counts_train, input_shape, expected_channels = pass1_stats(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        allowed_cores=train_cores,
    )

    print("\n[OK] Estadísticas calculadas")
    print(f"   mean: {mean_ch}")
    print(f"   std:  {std_ch}")

    # Floor tiny per-channel std values before they are used as divisors
    if np.any(std_ch < 1e-6):
        print("[WARN] Algunos canales tienen std muy bajo, ajustando...")
        std_ch = np.maximum(std_ch, 1e-6)

    # Write the TFRecords (one file per split)
    tfrecord_paths, split_counts, split_session_counts, split_subject_counts = (
        build_tfrecord_splits(
            CONFIG["manifest_path"],
            mean_ch,
            std_ch,
            split_map,
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            CONFIG["tfrecord_dir"],
            expected_channels=expected_channels,
        )
    )

    print("\n[OK] TFRecords generados:")
    for split, path in tfrecord_paths.items():
        print(f"   {split}: {path} ({split_counts[split]:,} epochs)")

    INPUT_SHAPE = input_shape
    train_count = split_counts.get("train", 0)
    val_count = split_counts.get("val", 0)
    test_count = split_counts.get("test", 0)

    # Create the datasets. Train and val repeat forever, so whoever consumes
    # them must bound the iteration with an explicit number of steps; test is
    # a single finite pass.
    train_ds = make_dataset(
        tfrecord_paths["train"],
        INPUT_SHAPE,
        CONFIG["effective_batch_size"],
        shuffle=True,
        repeat=True,
    )
    val_ds = make_dataset(
        tfrecord_paths["val"],
        INPUT_SHAPE,
        CONFIG["effective_batch_size"],
        shuffle=False,
        repeat=True,
    )
    test_ds = make_dataset(
        tfrecord_paths["test"],
        INPUT_SHAPE,
        CONFIG["effective_batch_size"],
        shuffle=False,
        repeat=False,
    )

    USE_STREAMING = True

else:
    # =========== DEBUG MODE: RAM ===========
    print("\n[DEBUG] Iniciando pipeline en RAM...")

    # Load the data for the selected subjects
    X, y, subject_cores_arr, expected_channels = load_all_data_to_ram(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        selected_cores,
    )

    # Split by subject
    (
        X_train,
        y_train,
        X_val,
        y_val,
        X_test,
        y_test,
        train_cores,
        val_cores,
        test_cores,
    ) = split_data_by_subject(
        X,
        y,
        subject_cores_arr,
        CONFIG["test_size"],
        CONFIG["val_size"],
        CONFIG["random_state"],
    )

    print("\n[OK] División por sujeto:")
    print(f"   Train: {len(X_train):,} epochs ({len(train_cores)} sujetos)")
    print(f"   Val:   {len(X_val):,} epochs ({len(val_cores)} sujetos)")
    print(f"   Test:  {len(X_test):,} epochs ({len(test_cores)} sujetos)")

    # Normalize
    # NOTE(review): presumably the statistics come from X_train only, as in
    # the streaming branch — confirm in normalize_data.
    X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch = normalize_data(
        X_train,
        X_val,
        X_test,
        CONFIG["clip_value"],
        CONFIG["std_epsilon"],
    )

    print("\n[OK] Normalización aplicada")
    print(f"   mean: {mean_ch}")
    print(f"   std:  {std_ch}")

    # Free the memory held by the non-normalized arrays
    del X, X_train, X_val, X_test
    gc.collect()

    INPUT_SHAPE = X_train_norm.shape[1:]
    train_count = len(X_train_norm)
    val_count = len(X_val_norm)
    test_count = len(X_test_norm)

    # Create tf.data datasets (more efficient than feeding numpy arrays directly).
    # None of them repeats here.
    train_ds = tf.data.Dataset.from_tensor_slices((X_train_norm, y_train))
    train_ds = train_ds.shuffle(
        min(10000, train_count),
        seed=CONFIG["random_state"],
        reshuffle_each_iteration=True,  # Important: a different shuffle every epoch
    )
    train_ds = train_ds.batch(CONFIG["effective_batch_size"]).prefetch(tf.data.AUTOTUNE)

    val_ds = tf.data.Dataset.from_tensor_slices((X_val_norm, y_val))
    val_ds = val_ds.batch(CONFIG["effective_batch_size"]).prefetch(tf.data.AUTOTUNE)

    test_ds = tf.data.Dataset.from_tensor_slices((X_test_norm, y_test))
    test_ds = test_ds.batch(CONFIG["effective_batch_size"]).prefetch(tf.data.AUTOTUNE)

    # Per-stage counts used for the class weights (y_train holds indices
    # into STAGE_ORDER)
    class_counts_train = Counter([STAGE_ORDER[yi] for yi in y_train])

    # Defined for consistency with streaming mode (not used in RAM mode)
    tfrecord_paths = None

    USE_STREAMING = False

# ============================================================
# COMMON: class weights and sanity checks
# ============================================================

# Pin the class order explicitly so inverse_transform follows STAGE_ORDER
le = LabelEncoder()
le.classes_ = np.array(STAGE_ORDER)

# Training epochs per stage, listed in STAGE_ORDER
counts_list = [class_counts_train.get(stage, 0) for stage in STAGE_ORDER]
if 0 in counts_list:
    print("[WARN] Alguna clase no apareció en train; ajustando conteos mínimos a 1")
    counts_list = [max(c, 1) for c in counts_list]

# Rebuild a label vector with the same per-class counts so that
# compute_class_weight can derive the "balanced" weights from it
stage_ids = np.arange(len(STAGE_ORDER))
y_for_weights = np.repeat(stage_ids, counts_list)
class_weights_arr = compute_class_weight(
    "balanced", classes=stage_ids, y=y_for_weights
)

# Clamp every weight into [0.5, class_weight_clip]
class_weight_clip = CONFIG.get("class_weight_clip", 1.5)
class_weights = {
    idx: float(np.clip(weight, 0.5, class_weight_clip))
    for idx, weight in enumerate(class_weights_arr)
}

if CONFIG.get("use_class_weights", True):
    print(f"Class weights (clip 0.5-{class_weight_clip}): {class_weights}")
else:
    class_weights = None
    print("Class weights desactivados")

# Data integrity check
print("\n[CHECK] Verificando integridad de datos...")

if USE_STREAMING:
    # Streaming mode: only a single sample batch is inspected
    sample_batch = next(iter(train_ds.take(1)))
    x_sample, y_sample = sample_batch
    print(f"   Sample batch shape: {x_sample.shape}")
    print(
        f"   Sample x range: [{tf.reduce_min(x_sample).numpy():.4f}, {tf.reduce_max(x_sample).numpy():.4f}]"
    )
    print(f"   Sample x mean: {tf.reduce_mean(x_sample).numpy():.4f}")
    print(f"   Has NaN: {tf.reduce_any(tf.math.is_nan(x_sample)).numpy()}")
    print(f"   Has Inf: {tf.reduce_any(tf.math.is_inf(x_sample)).numpy()}")
else:
    # RAM mode: the arrays are at hand, so the WHOLE training set is checked
    print("   Verificando datos normalizados completos...")
    train_min = X_train_norm.min()
    train_max = X_train_norm.max()
    train_mean = X_train_norm.mean()
    train_std = X_train_norm.std()
    has_nan = np.isnan(X_train_norm).any()
    has_inf = np.isinf(X_train_norm).any()

    # Per-channel statistics (axis 1 indexes the channel)
    for ch_idx, ch_data in enumerate(np.moveaxis(X_train_norm, 1, 0)):
        ch_min, ch_max = ch_data.min(), ch_data.max()
        ch_std = ch_data.std()
        print(
            f"   Canal {ch_idx}: rango=[{ch_min:.3f}, {ch_max:.3f}], std={ch_std:.3f}"
        )

    print(f"\n   [TRAIN] Shape: {X_train_norm.shape}")
    print(f"   [TRAIN] Rango global: [{train_min:.4f}, {train_max:.4f}]")
    print(f"   [TRAIN] Mean: {train_mean:.4f}, Std: {train_std:.4f}")
    print(f"   [TRAIN] Has NaN: {has_nan}")
    print(f"   [TRAIN] Has Inf: {has_inf}")

    # Share of values in the extreme tails
    extreme_count = (np.abs(X_train_norm) > 4.0).sum()
    total_values = X_train_norm.size
    extreme_pct = extreme_count / total_values * 100
    print(f"   [TRAIN] Valores |x| > 4.0: {extreme_count:,} ({extreme_pct:.2f}%)")

    # Non-finite inputs would poison training, so abort here
    if has_nan or has_inf:
        print("\n   [ERROR] ¡Se detectaron NaN o Inf en los datos de entrenamiento!")
        print("   Esto causará problemas durante el entrenamiento.")
        raise ValueError("Datos de entrenamiento contienen NaN o Inf")

# Summary
print("\n" + "=" * 60)
print(f"[OK] Pipeline {'STREAMING' if USE_STREAMING else 'RAM'} listo")
print(f"   Input shape: {INPUT_SHAPE}")
print(f"   Train: {train_count:,} epochs")
print(f"   Val:   {val_count:,} epochs")
print(f"   Test:  {test_count:,} epochs")
print("=" * 60)

# Class distribution of the training split
print("\n[INFO] Distribución de clases en Train:")
train_class_counts = [class_counts_train.get(stage, 0) for stage in STAGE_ORDER]
total_train = sum(train_class_counts)
# Compute the percentages once for both the table and the chart. An empty
# training split reports 0.0% instead of raising ZeroDivisionError, matching
# the class-weight code above, which also tolerates zero counts.
train_class_pcts = [
    count / total_train * 100 if total_train else 0.0
    for count in train_class_counts
]
print("   Clase     Epochs      %")
print("   " + "-" * 25)
for stage, count, pct in zip(STAGE_ORDER, train_class_counts, train_class_pcts):
    print(f"   {stage:5s}   {count:7,d}   {pct:5.1f}%")

# Bar chart, one colour per stage (the palette follows len(STAGE_ORDER)
# rather than a hard-coded 5)
fig, ax = plt.subplots(figsize=(8, 4))
bars = ax.bar(
    STAGE_ORDER,
    train_class_counts,
    color=sns.color_palette("husl", len(STAGE_ORDER)),
)
ax.set_xlabel("Estadio de sueño")
ax.set_ylabel("Número de epochs")
ax.set_title(
    f"Distribución de clases en Train ({CONFIG['execution_mode'].upper()} mode)"
)
# Label every bar with its absolute count and percentage
for bar, count, pct in zip(bars, train_class_counts, train_class_pcts):
    ax.annotate(
        f"{count:,}\n({pct:.1f}%)",
        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
        ha="center",
        va="bottom",
        fontsize=9,
    )
plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/class_distribution_train.png", dpi=150)
plt.show()

In [None]:
# ============================================================
# SPLIT SUMMARY (same figures already printed by the pipeline cell)
# ============================================================

# Recap only: prints the epoch counts and the pipeline mode, computes nothing.
print(f"[INFO] Modo: {CONFIG['execution_mode'].upper()}")
print(f"   Train epochs: {train_count:,}")
print(f"   Val epochs:   {val_count:,}")
print(f"   Test epochs:  {test_count:,}")
print(f"   Streaming: {USE_STREAMING}")

In [None]:
# ============================================================
# NORMALIZATION AND PREPARATION (already reported by the pipeline cell)
# ============================================================

# Recap only: shows where normalization happened and the per-channel
# statistics that were used.
print(f"[INFO] Normalización aplicada {'en TFRecords' if USE_STREAMING else 'en RAM'}")
print(f"   Mean por canal: {mean_ch}")
print(f"   Std por canal:  {std_ch}")
print(f"   Clases: {STAGE_ORDER}")

## Arquitectura CNN1D

In [None]:
# ============================================================
# MODELO CNN1D CON CONEXIONES RESIDUALES
# ============================================================


def build_cnn1d_model(
    input_shape,
    n_classes=5,
    n_filters=32,
    kernel_size=3,
    dropout_rate=0.3,
    lr_schedule=None,  # LearningRateSchedule or float
    use_residual=True,
    use_augmentation=False,
):
    """Build and compile the CNN1D used for sleep staging.

    Layout: permute the input to (n_samples, n_channels) -> optional Gaussian
    noise -> one plain conv stage -> two conv stages with optional 1x1
    projection shortcuts -> global average pooling -> Dense(64) -> linear
    Dense(n_classes) producing logits.

    NOTE: the architecture is kept deliberately simple for numerical
    stability: ReLU activations, BatchNorm with momentum 0.99 and He-uniform
    initialization (Glorot-uniform on the output layer).

    Args:
        input_shape: Shape of one example, ``(n_channels, n_samples)``.
        n_classes: Number of output logits.
        n_filters: Filters of the first conv stage; later stages use 2x and 4x.
        kernel_size: Kernel size of the main convolutions.
        dropout_rate: Dropout rate applied after the pooled stages and the
            dense layer.
        lr_schedule: Learning rate (float) or ``LearningRateSchedule`` for
            Adam; ``1e-5`` is used when ``None``.
        use_residual: Add the 1x1-conv + BatchNorm shortcut in stages 2 and 3.
        use_augmentation: Insert ``GaussianNoise(0.05)`` after the permute.

    Returns:
        A compiled ``keras.Model`` named ``"CNN1D_SleepStaging"`` that outputs
        logits (the loss applies softmax via ``from_logits=True``).
    """

    def _regularized_conv(filters):
        # Same-padded convolution with He-uniform init and an L2 penalty
        return layers.Conv1D(
            filters,
            kernel_size,
            padding="same",
            kernel_initializer="he_uniform",
            kernel_regularizer=keras.regularizers.l2(1e-4),
        )

    def _residual_stage(stage_input, filters):
        # Conv + BatchNorm, optionally summed with a projected shortcut, then ReLU
        main = _regularized_conv(filters)(stage_input)
        main = layers.BatchNormalization(momentum=0.99)(main)
        if use_residual:
            shortcut = layers.Conv1D(
                filters, 1, padding="same", kernel_initializer="he_uniform"
            )(stage_input)
            shortcut = layers.BatchNormalization(momentum=0.99)(shortcut)
            main = layers.Add()([main, shortcut])
        return layers.Activation("relu")(main)

    # Input arrives as (n_channels, n_samples), explicitly float32
    inputs = keras.Input(shape=input_shape, name="input", dtype="float32")

    # Conv1D expects (n_samples, n_channels)
    net = layers.Permute((2, 1))(inputs)

    # Noise augmentation (only active during training)
    if use_augmentation:
        net = layers.GaussianNoise(0.05)(net)

    # Stage 1: plain convolution
    net = _regularized_conv(n_filters)(net)
    net = layers.BatchNormalization()(net)  # default momentum=0.99
    net = layers.Activation("relu")(net)
    net = layers.MaxPooling1D(2)(net)
    net = layers.Dropout(dropout_rate)(net)

    # Stage 2: residual, twice the filters
    net = _residual_stage(net, n_filters * 2)
    net = layers.MaxPooling1D(2)(net)
    net = layers.Dropout(dropout_rate)(net)

    # Stage 3: residual, four times the filters (no pooling before the GAP)
    net = _residual_stage(net, n_filters * 4)

    # Collapse the time axis
    net = layers.GlobalAveragePooling1D()(net)

    # Dense head
    net = layers.Dense(
        64,
        kernel_initializer="he_uniform",
        kernel_regularizer=keras.regularizers.l2(1e-4),
    )(net)
    net = layers.BatchNormalization(momentum=0.99)(net)
    net = layers.Activation("relu")(net)
    net = layers.Dropout(dropout_rate)(net)

    # Linear float32 output; softmax is computed inside the loss
    logits = layers.Dense(
        n_classes,
        kernel_initializer="glorot_uniform",
        name="output",
        dtype="float32",
    )(net)

    model = keras.Model(inputs=inputs, outputs=logits, name="CNN1D_SleepStaging")

    learning_rate = lr_schedule if lr_schedule is not None else 1e-5
    model.compile(
        optimizer=keras.optimizers.Adam(
            learning_rate=learning_rate,
            clipnorm=1.0,  # gradient clipping by norm
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    return model


def create_lr_schedule(
    initial_lr, min_lr, warmup_epochs, total_epochs, steps_per_epoch
):
    """Create a step-level schedule: linear warmup followed by cosine decay.

    - Warmup: the rate rises linearly from ``min_lr`` to ``initial_lr`` over
      ``warmup_epochs * steps_per_epoch`` optimizer steps.
    - Decay: the rate then follows half a cosine from ``initial_lr`` down to
      ``min_lr`` over the remaining steps and stays at ``min_lr`` afterwards.

    Args:
        initial_lr: Peak learning rate, reached at the end of the warmup.
        min_lr: Learning rate at step 0 and after the decay has finished.
        warmup_epochs: Number of epochs spent in the linear warmup.
        total_epochs: Total number of training epochs (warmup included).
        steps_per_epoch: Optimizer steps per epoch.

    Returns:
        A ``LearningRateSchedule`` instance that can be passed to an optimizer.
    """
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    # At least one decay step so the decay phase never divides by zero
    decay_steps = max(total_steps - warmup_steps, 1)

    def _plain_number(value):
        """Return *value* as a Python int when integral, else a Python float."""
        number = float(value)
        return int(number) if number.is_integer() else number

    class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Step-level schedule with linear warmup followed by cosine decay."""

        def __init__(self, initial_lr, min_lr, warmup_steps, decay_steps):
            super().__init__()
            # Keep the constructor arguments as plain Python numbers so that
            # get_config() returns them exactly. Reading them back from the
            # float32 tensors below would round the rates, truncate fractional
            # step counts and require eager tensors.
            self._config = {
                "initial_lr": float(initial_lr),
                "min_lr": float(min_lr),
                "warmup_steps": _plain_number(warmup_steps),
                "decay_steps": _plain_number(decay_steps),
            }
            self.initial_lr = tf.cast(initial_lr, tf.float32)
            self.min_lr = tf.cast(min_lr, tf.float32)
            self.warmup_steps = tf.cast(warmup_steps, tf.float32)
            self.decay_steps = tf.cast(decay_steps, tf.float32)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)

            # Linear warmup (the denominator is floored at 1 for warmup_steps=0)
            warmup_progress = step / tf.maximum(self.warmup_steps, 1.0)
            warmup_lr = self.min_lr + (self.initial_lr - self.min_lr) * tf.minimum(
                warmup_progress, 1.0
            )

            # Cosine decay after the warmup, with progress clamped to [0, 1]
            decay_progress = (step - self.warmup_steps) / self.decay_steps
            decay_progress = tf.minimum(tf.maximum(decay_progress, 0.0), 1.0)
            cosine_decay = 0.5 * (
                1.0 + tf.cos(tf.constant(np.pi, dtype=tf.float32) * decay_progress)
            )
            decay_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay

            # Select the phase the current step belongs to
            return tf.cond(
                step < self.warmup_steps, lambda: warmup_lr, lambda: decay_lr
            )

        def get_config(self):
            # Return a copy so callers cannot mutate the stored values
            return dict(self._config)

    return WarmupCosineDecay(initial_lr, min_lr, warmup_steps, decay_steps)


# Status messages: this cell only defines the two factories above; nothing
# has been built yet.
print("[OK] Arquitectura CNN1D definida")
print("[OK] Learning rate schedule (warmup + cosine decay) definido")

In [None]:
# ============================================================
# BUILD THE MODEL (under the multi-GPU strategy when available)
# ============================================================

input_shape = INPUT_SHAPE
print(f"Input shape: {input_shape}")

# Cap the steps per epoch to avoid the NaN seen around batches 119-121
MAX_STEPS_PER_EPOCH = 115  # applies to both modes

full_steps_train = math.ceil(train_count / CONFIG["effective_batch_size"])
full_steps_val = 0
if val_count:
    full_steps_val = math.ceil(val_count / CONFIG["effective_batch_size"])

# Training steps are capped in both modes; validation steps are only set
# explicitly in streaming mode and only when there is validation data
steps_per_epoch_train = min(full_steps_train, MAX_STEPS_PER_EPOCH)
steps_per_epoch_val = None
if USE_STREAMING and full_steps_val:
    steps_per_epoch_val = min(full_steps_val, MAX_STEPS_PER_EPOCH)

# The schedule is built on the effective (capped) steps per epoch
lr_schedule = create_lr_schedule(
    initial_lr=CONFIG["learning_rate_initial"],
    min_lr=CONFIG["learning_rate_min"],
    warmup_epochs=CONFIG["warmup_epochs"],
    total_epochs=CONFIG["epochs"],
    steps_per_epoch=steps_per_epoch_train,
)

print("\n[INFO] Learning Rate Schedule:")
print(f"   Initial LR: {CONFIG['learning_rate_initial']}")
print(f"   Min LR: {CONFIG['learning_rate_min']}")
print(f"   Warmup epochs: {CONFIG['warmup_epochs']}")
print(f"   Total epochs: {CONFIG['epochs']}")
print(f"   Steps per epoch (train): {steps_per_epoch_train}")
print(f"   Warmup steps: {CONFIG['warmup_epochs'] * steps_per_epoch_train}")
print(f"   Full steps train: {full_steps_train} -> cap {MAX_STEPS_PER_EPOCH}")
if USE_STREAMING:
    print(f"   Full steps val:   {full_steps_val} -> using {steps_per_epoch_val}")

# Create the model inside the distribution strategy scope
with strategy.scope():
    model = build_cnn1d_model(
        input_shape=input_shape,
        n_classes=len(STAGE_ORDER),
        n_filters=CONFIG["n_filters"],
        kernel_size=CONFIG["kernel_size"],
        dropout_rate=CONFIG["dropout_rate"],
        lr_schedule=lr_schedule,
        use_residual=CONFIG["use_residual"],
        use_augmentation=CONFIG["use_augmentation"],
    )

model.summary()

## Entrenamiento

In [None]:
# ============================================================
# CALLBACKS
# ============================================================

# Unique run name (execution mode + wall-clock timestamp); used below for the
# checkpoint file name.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"cnn1d_{CONFIG['execution_mode']}_{timestamp}"


# Custom callback para métricas de sleep staging
class SleepMetricsCallback(Callback):
    """Compute macro F1 and Cohen's kappa on validation data after an epoch.

    The metrics are written to the Keras ``logs`` dict as ``val_f1_macro`` and
    ``val_kappa`` so that callbacks listed after this one can monitor them.
    The weights of the epoch with the best macro F1 are kept in memory and can
    be put back with :meth:`restore_best_weights`.

    Args:
        val_ds: Iterable of ``(x_batch, y_batch)`` validation batches.
        val_steps: Number of batches to evaluate per epoch. Required when
            ``use_streaming`` is true; ignored otherwise.
        stage_order: Ordered stage names (stored for reference).
        use_streaming: Whether iteration over ``val_ds`` must be cut off after
            ``val_steps`` batches (needed when the dataset repeats).
        eval_every: Evaluate only every N-th epoch (values below 1 mean 1).

    Raises:
        ValueError: If ``use_streaming`` is true and ``val_steps`` is ``None``.
    """

    def __init__(
        self, val_ds, val_steps, stage_order, use_streaming=False, eval_every=1
    ):
        super().__init__()
        if use_streaming and val_steps is None:
            # Fail early: without a step limit the per-epoch comparison
            # against val_steps would raise a TypeError much later.
            raise ValueError("val_steps es obligatorio cuando use_streaming=True")
        self.val_ds = val_ds
        self.val_steps = val_steps
        self.stage_order = stage_order
        self.use_streaming = use_streaming
        self.eval_every = max(1, eval_every)
        self.best_f1_macro = -1.0
        self.best_kappa = -1.0
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        # Only replace a missing dict. `logs or {}` would also replace an
        # EMPTY dict, and the metrics added below would then never reach the
        # callbacks that run after this one.
        if logs is None:
            logs = {}

        # Evaluation can be spaced out to save time (useful on Kaggle)
        if (epoch + 1) % self.eval_every != 0:
            return

        y_true_list = []
        y_pred_list = []

        for batch_idx, (x_batch, y_batch) in enumerate(self.val_ds):
            # Streaming: stop at the step limit. RAM: consume the whole
            # dataset; a fresh iterator is created every epoch.
            if self.use_streaming and batch_idx >= self.val_steps:
                break
            y_pred_batch = self.model.predict(x_batch, verbose=0)
            y_pred_list.append(np.argmax(y_pred_batch, axis=1))
            y_true_list.append(y_batch.numpy())

        if not y_true_list:
            # np.concatenate([]) raises ValueError; an empty validation split
            # should skip the metrics instead of aborting training.
            print(" - [WARN] Validación vacía: se omiten val_f1_macro y val_kappa")
            return

        y_true = np.concatenate(y_true_list)
        y_pred = np.concatenate(y_pred_list)

        kappa = cohen_kappa_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

        logs["val_kappa"] = kappa
        logs["val_f1_macro"] = f1

        # Model selection is driven by macro F1; kappa is only reported
        if f1 > self.best_f1_macro:
            self.best_f1_macro = f1
            self.best_kappa = kappa
            self.best_weights = self.model.get_weights()
            print(f" - val_f1_macro mejoró a {f1:.4f} (kappa={kappa:.4f})")

        print(f" - val_f1_macro: {f1:.4f} - val_kappa: {kappa:.4f}")

    def restore_best_weights(self):
        """Load the best-F1 weights back into the model; no-op if none stored."""
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
            print(
                f"[OK] Restaurados pesos del mejor modelo "
                f"(F1-Macro={self.best_f1_macro:.4f}, Kappa={self.best_kappa:.4f})"
            )


class NaNDebugCallback(Callback):
    """Report non-finite training losses without ever stopping training.

    The first non-finite loss of each epoch is announced. The very first one
    of the whole run additionally triggers a one-off inspection of the model
    weights. A per-epoch summary is printed at epoch end.
    """

    def __init__(self):
        super().__init__()
        self.nan_count = 0
        self.last_valid_loss = None
        self.first_nan_batch = None
        self.diagnosed = False

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss", 0)

        # Healthy batch: remember the loss and leave
        if np.isfinite(loss):
            self.last_valid_loss = loss
            return

        self.nan_count += 1

        # Only the first bad batch of the epoch is announced
        if self.first_nan_batch is not None:
            return
        self.first_nan_batch = batch
        print(
            f"\n[WARN] Primer NaN en batch {batch} (loss={loss}, último válido={self.last_valid_loss})"
        )

        # The detailed diagnosis runs a single time per training run
        if self.diagnosed:
            return
        self.diagnosed = True
        self._report_weight_health()

    def _report_weight_health(self):
        """Print the largest absolute weight and the layers holding NaN/Inf."""
        try:
            largest = 0
            bad_layers = []
            for layer in self.model.layers:
                for tensor in layer.get_weights():
                    magnitude = np.max(np.abs(tensor))
                    largest = max(largest, magnitude)
                    if not np.all(np.isfinite(tensor)):
                        bad_layers.append(layer.name)
            print(f"   Max |peso|: {largest:.2e}")
            if bad_layers:
                print(f"   Capas con NaN/Inf: {bad_layers}")
        except Exception as e:
            # Best effort: the diagnosis must never interrupt training
            print(f"   Error en diagnóstico: {e}")

    def on_epoch_end(self, epoch, logs=None):
        if self.nan_count > 0:
            print(
                f"   [INFO] Epoch {epoch+1}: {self.nan_count} batches con NaN (primero en batch {self.first_nan_batch})"
            )
        # Reset the per-epoch counters
        self.nan_count = 0
        self.first_nan_batch = None


# Validation steps for the metrics callback: the full validation set in
# streaming mode; not used in RAM mode
val_steps_for_callback = (
    math.ceil(val_count / CONFIG["effective_batch_size"]) if USE_STREAMING else None
)

metrics_callback = SleepMetricsCallback(
    val_ds=val_ds,
    val_steps=val_steps_for_callback,
    stage_order=STAGE_ORDER,
    use_streaming=USE_STREAMING,
)

# Log only: this callback NEVER stops training.
# TerminateOnNaN is deliberately not used, so the model gets a chance to recover.
nan_debug_callback = NaNDebugCallback()

early_stopping_callback = EarlyStopping(
    monitor="val_f1_macro",
    mode="max",
    patience=CONFIG["early_stopping_patience"],
    restore_best_weights=False,  # metrics_callback takes care of this
    verbose=1,
)

checkpoint_callback = ModelCheckpoint(
    filepath=f"{OUTPUT_PATH}/{model_name}_best.keras",
    monitor="val_f1_macro",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# NOTE: callback order matters. metrics_callback must come before
# EarlyStopping and ModelCheckpoint because it adds val_f1_macro (and
# val_kappa) to the logs they monitor.
# ReduceLROnPlateau is not used: it is incompatible with a LearningRateSchedule.
callbacks = [
    nan_debug_callback,
    metrics_callback,
    early_stopping_callback,
    checkpoint_callback,
]

print(f"Modelo: {model_name}")
print(f"Checkpoint: {OUTPUT_PATH}/{model_name}_best.keras")
print("[INFO] F1-Macro (selección) y Kappa (reporte) se calculan cada epoch")

## Reanudación desde Checkpoint (opcional)

Si Kaggle se desconectó durante el entrenamiento, puedes reanudar desde el último checkpoint.
Ejecuta la siguiente celda solo si necesitas reanudar.

In [None]:
# ============================================================
# RESUME FROM CHECKPOINT (run only when needed)
# ============================================================
# Set the flag and the checkpoint file name below to resume

RESUME_FROM_CHECKPOINT = False  # Set to True to resume
CHECKPOINT_NAME = None  # Example: "cnn1d_20251125_143022_best.keras"

if RESUME_FROM_CHECKPOINT and CHECKPOINT_NAME:
    checkpoint_path = f"{OUTPUT_PATH}/{CHECKPOINT_NAME}"
    if os.path.exists(checkpoint_path):
        print(f"[INFO] Cargando modelo desde checkpoint: {checkpoint_path}")
        with strategy.scope():
            # compile=False: the saved optimizer config embeds
            # WarmupCosineDecay, a class defined inside create_lr_schedule()
            # that Keras cannot locate when deserializing, so restoring the
            # compiled state would fail. Load architecture + weights only and
            # recompile with the same settings build_cnn1d_model() uses.
            model = keras.models.load_model(checkpoint_path, compile=False)
            model.compile(
                optimizer=keras.optimizers.Adam(
                    learning_rate=lr_schedule,
                    clipnorm=1.0,
                ),
                loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=["accuracy"],
            )
        print("[OK] Modelo cargado exitosamente")
        print("[INFO] Puedes continuar el entrenamiento ejecutando la celda de fit()")
        print(
            "[INFO] Se restauraron los pesos; el optimizador y el LR schedule se reinician"
        )
    else:
        print(f"[ERROR] Checkpoint no encontrado: {checkpoint_path}")
        print(f"[INFO] Archivos disponibles en {OUTPUT_PATH}:")
        for f in os.listdir(OUTPUT_PATH):
            if f.endswith(".keras"):
                print(f"   - {f}")
elif RESUME_FROM_CHECKPOINT:
    # The flag is set but there is nothing to load: say so explicitly
    # instead of reporting "normal mode".
    print("[WARN] RESUME_FROM_CHECKPOINT=True pero CHECKPOINT_NAME está vacío")
    print("[INFO] Se usara el modelo recien creado")
else:
    print("[INFO] Modo normal: se usara el modelo recien creado")
    print("[INFO] Para reanudar desde checkpoint, cambia RESUME_FROM_CHECKPOINT=True")

In [None]:
# ============================================================
# TRAIN THE MODEL
# ============================================================

print(f"\nIniciando entrenamiento ({CONFIG['execution_mode'].upper()} mode)...")
print(f"   Batch size efectivo: {CONFIG['effective_batch_size']}")
print(f"   Epochs maximos: {CONFIG['epochs']}")

# Steps were already computed in the model cell
print(f"   Steps train (efectivos): {steps_per_epoch_train}")
if USE_STREAMING:
    print(f"   Steps val (efectivos): {steps_per_epoch_val}")
else:
    print("   Steps val: automático")

training_start_time = time.time()

# Report the class weights that will be applied
if CONFIG.get("use_class_weights", True) and class_weights:
    print(f"   Class weights: {class_weights}")
else:
    print("   Class weights: DESACTIVADOS")

# Arguments shared by both modes
shared_fit_kwargs = dict(
    validation_data=val_ds,
    steps_per_epoch=steps_per_epoch_train,
    epochs=CONFIG["epochs"],
    class_weight=class_weights if CONFIG.get("use_class_weights", True) else None,
    callbacks=callbacks,
    verbose=1,
)

if USE_STREAMING:
    # Streaming datasets already repeat, so validation also needs a step count
    history = model.fit(
        train_ds,
        validation_steps=steps_per_epoch_val,
        **shared_fit_kwargs,
    )
else:
    # RAM mode: steps_per_epoch is capped to avoid the NaN around batch ~120,
    # so train_ds must repeat or it would run out of data
    train_ds_repeat = train_ds.repeat()
    history = model.fit(train_ds_repeat, **shared_fit_kwargs)

training_time = time.time() - training_start_time
print(
    f"\n[OK] Entrenamiento completado en {training_time / 60:.2f} minutos ({training_time / 3600:.2f} horas)"
)

gc.collect()

# Put back the weights of the epoch with the best macro F1
metrics_callback.restore_best_weights()

In [None]:
# ============================================================
# PLOT THE LEARNING CURVES
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))


def _finish_panel(panel, ylabel, title):
    # Axis labels, title, legend and grid shared by every curve panel
    panel.set_xlabel("Epoch")
    panel.set_ylabel(ylabel)
    panel.set_title(title)
    panel.legend()
    panel.grid(True, alpha=0.3)


def _plot_tracked_metric(
    panel,
    key,
    curve_label,
    color,
    best_value,
    best_label,
    ylabel,
    title,
    missing_text,
    missing_title,
):
    # Validation-only metric with a dashed line at its best value; falls back
    # to a placeholder text when the metric was never logged
    if key in history.history:
        panel.plot(history.history[key], label=curve_label, linewidth=2, color=color)
        panel.axhline(y=best_value, color="red", linestyle="--", label=best_label)
        _finish_panel(panel, ylabel, title)
    else:
        panel.text(0.5, 0.5, missing_text, ha="center", va="center", fontsize=12)
        panel.set_title(missing_title)


# Top row: loss and accuracy, train vs. validation
for panel, train_key, val_key, train_label, val_label, ylabel, title in (
    (
        axes[0, 0],
        "loss",
        "val_loss",
        "Train Loss",
        "Val Loss",
        "Loss",
        "Loss durante entrenamiento",
    ),
    (
        axes[0, 1],
        "accuracy",
        "val_accuracy",
        "Train Acc",
        "Val Acc",
        "Accuracy",
        "Accuracy durante entrenamiento",
    ),
):
    panel.plot(history.history[train_key], label=train_label, linewidth=2)
    panel.plot(history.history[val_key], label=val_label, linewidth=2)
    _finish_panel(panel, ylabel, title)

# Cohen's kappa (for comparison with the literature)
_plot_tracked_metric(
    axes[1, 0],
    "val_kappa",
    "Val Kappa",
    "green",
    metrics_callback.best_kappa,
    f"Best Kappa={metrics_callback.best_kappa:.4f}",
    "Cohen's Kappa",
    "Cohen's Kappa (para comparación con literatura)",
    "Kappa no disponible",
    "Cohen's Kappa",
)

# Macro F1 (the model-selection metric)
_plot_tracked_metric(
    axes[1, 1],
    "val_f1_macro",
    "Val F1-Macro",
    "purple",
    metrics_callback.best_f1_macro,
    f"Best F1={metrics_callback.best_f1_macro:.4f}",
    "F1-Macro",
    "F1-Macro (métrica de selección)",
    "F1-Macro no disponible",
    "F1-Macro",
)

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_training_curves.png", dpi=150)
plt.show()

## Evaluación en Test

In [None]:
# ============================================================
# TEST SET EVALUATION
# ============================================================

print("\nEvaluando en Test Set (tf.data)...")


def collect_labels(ds):
    """Return every label of a batched ``(x, y)`` dataset as one NumPy array.

    The features are dropped, the dataset is unbatched and then re-batched in
    chunks of 1024 labels before being concatenated.
    """
    label_stream = ds.map(lambda _features, label: label).unbatch().batch(1024)
    chunks = list(label_stream.as_numpy_iterator())
    return np.concatenate(chunks)


# Ground-truth labels as integers (aligned with STAGE_ORDER)
y_test_enc = collect_labels(test_ds).astype(int)

# Predictions: the model returns logits, so softmax is applied to obtain
# per-class probabilities before taking the most likely class.
y_pred_logits = model.predict(test_ds, verbose=1)
y_pred_proba = tf.nn.softmax(y_pred_logits, axis=1).numpy()
y_pred_enc = y_pred_proba.argmax(axis=1)

# Metrics computed on the integer class indices.
# NOTE: labels_idx is defined here but is not passed to the metric calls below.
labels_idx = np.arange(len(STAGE_ORDER))
accuracy = accuracy_score(y_test_enc, y_pred_enc)
kappa = cohen_kappa_score(y_test_enc, y_pred_enc)
f1_macro, f1_weighted = (
    f1_score(y_test_enc, y_pred_enc, average=averaging, zero_division=0)
    for averaging in ("macro", "weighted")
)

# Single print call; sep="\n" yields the same output as one print per line.
print(
    f"\n{'=' * 50}",
    "RESULTADOS EN TEST SET",
    f"{'=' * 50}",
    f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)",
    f"   Cohen Kappa: {kappa:.4f}  <- Métrica principal en literatura",
    f"   F1 Macro:    {f1_macro:.4f}  <- Para comparar con datasets desbalanceados",
    f"   F1 Weighted: {f1_weighted:.4f}",
    f"{'=' * 50}",
    "\nNOTA: Según literatura (AASM), acuerdo inter-scorer humano tiene Kappa ~0.75-0.85",
    sep="\n",
)

In [None]:
# ============================================================
# CLASSIFICATION REPORT
# ============================================================

print("\nClassification Report:")

# Per-class precision/recall/F1, with rows named and ordered by STAGE_ORDER.
# The explicit `labels` list keeps a row for every stage index.
stage_indices = np.arange(len(STAGE_ORDER))
report_text = classification_report(
    y_test_enc,
    y_pred_enc,
    labels=stage_indices,
    target_names=STAGE_ORDER,
    digits=4,
    zero_division=0,
)
print(report_text)

In [None]:
# ============================================================
# CONFUSION MATRIX
# ============================================================

# Counts per (true, predicted) pair; the explicit `labels` list fixes the
# row/column order to STAGE_ORDER and keeps a row for every stage.
cm = confusion_matrix(y_test_enc, y_pred_enc, labels=np.arange(len(STAGE_ORDER)))

# Row-normalise so each row shows the fraction of that true class assigned to
# each predicted class. A stage with no true samples in the test set has a
# row total of 0; dividing blindly would give 0/0 = NaN (and silently, since
# warnings are suppressed at the top of the notebook). Such rows stay at 0.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype("float"),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Absolute counts
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[0],
)
axes[0].set_xlabel("Predicho")
axes[0].set_ylabel("Real")
axes[0].set_title("Matriz de Confusión (Absoluta)")

# Row-normalised (shown as percentages)
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt=".2%",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[1],
)
axes[1].set_xlabel("Predicho")
axes[1].set_ylabel("Real")
axes[1].set_title("Matriz de Confusión (Normalizada)")

# Persist the figure to disk before displaying it.
plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_confusion_matrix.png", dpi=150)
plt.show()

## Optimización de Hiperparámetros (Optuna)

In [None]:
# ============================================================
# OPTUNA OPTIMIZATION (optional)
# ============================================================

# Optuna is not supported in streaming mode. If you need tuning, disable streaming.
# This cell never runs a study: when CONFIG["run_optimization"] is truthy it
# only prints a warning, otherwise it prints a "skip" notice. Note that the
# warning branch does not check the streaming flag itself.
if CONFIG["run_optimization"]:
    print("[WARN] Optuna no está implementado para el modo streaming TFRecord. Skip.")
    # The triple-quoted string below is a bare expression statement, i.e. a
    # disabled draft of the Optuna objective/study code that is never executed.
    # The draft is incomplete as written (it has no `def objective(trial):`
    # header) and relies on in-memory arrays such as X_train_norm / X_val_norm.
    """
        # Hiperparametros a optimizar
        n_filters = trial.suggest_categorical("n_filters", [32, 64, 128])
        kernel_size = trial.suggest_categorical("kernel_size", [3, 5, 7, 9])
        dropout_rate = trial.suggest_float("dropout_rate", 0.2, 0.6)
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])

        # Crear modelo
        with strategy.scope():
            model = build_cnn1d_model(
                input_shape=input_shape,
                n_classes=len(STAGE_ORDER),
                n_filters=n_filters,
                kernel_size=kernel_size,
                dropout_rate=dropout_rate,
                learning_rate=learning_rate,
            )

        # Callbacks
        callbacks = [
            EarlyStopping(monitor="val_loss", patience=7, restore_best_weights=True),
            TFKerasPruningCallback(trial, "val_loss"),
        ]

        # Entrenar
        model.fit(
            X_train_norm,
            y_train_enc,
            validation_data=(X_val_norm, y_val_enc),
            batch_size=batch_size * strategy.num_replicas_in_sync,
            epochs=30,
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=0,
        )

        # Evaluar en validacion
        y_val_pred = np.argmax(model.predict(X_val_norm, verbose=0), axis=1)
        kappa = cohen_kappa_score(y_val_enc, y_val_pred)

        # Limpiar
        keras.backend.clear_session()
        gc.collect()

        return kappa

    # Crear estudio
    study = optuna.create_study(
        direction="maximize",
        study_name="cnn1d_sleep_staging",
        pruner=optuna.pruners.MedianPruner(),
    )

    print(f"\nIniciando optimizacion con {CONFIG['n_optuna_trials']} trials...")
    study.optimize(
        objective, n_trials=CONFIG["n_optuna_trials"], show_progress_bar=True
    )

    print("\n[OK] Optimizacion completada")
    print(f"   Mejor Kappa: {study.best_value:.4f}")
    print("   Mejores hiperparametros:")
    for key, value in study.best_params.items():
        print(f"      {key}: {value}")

    # Guardar resultados
    with open(f"{OUTPUT_PATH}/optuna_cnn1d_results.json", "w") as f:
        json.dump(
            {
                "best_value": study.best_value,
                "best_params": study.best_params,
            },
            f,
            indent=2,
        )
"""
else:
    # Optimization was not requested: just tell the user how to enable it.
    print(
        "[SKIP] Optimizacion deshabilitada. Cambiar CONFIG['run_optimization'] = True para ejecutar."
    )

## Guardar Modelo y Resultados

In [None]:
# ============================================================
# SAVE MODEL AND ARTIFACTS
# ============================================================

# Full model
model.save(f"{OUTPUT_PATH}/{model_name}_final.keras")
print(f"[OK] Modelo guardado: {OUTPUT_PATH}/{model_name}_final.keras")

# Per-epoch training history as CSV
history_df = pd.DataFrame(history.history)
history_df.to_csv(f"{OUTPUT_PATH}/{model_name}_history.csv", index=False)

# Results summary, assembled section by section (key order is preserved)
test_metrics = {
    metric_name: float(metric_value)
    for metric_name, metric_value in (
        ("accuracy", accuracy),
        ("kappa", kappa),
        ("f1_macro", f1_macro),
        ("f1_weighted", f1_weighted),
    )
}
training_summary = {
    "training_time_seconds": float(training_time),
    "training_time_minutes": float(training_time / 60),
    "epochs_trained": len(history.history["loss"]),
    "best_val_f1_macro": float(metrics_callback.best_f1_macro),
    "best_val_kappa": float(metrics_callback.best_kappa),
}
dataset_summary = {
    "train_samples": int(train_count),
    "val_samples": int(val_count),
    "test_samples": int(test_count),
    # tfrecord_paths is only read when streaming is enabled
    "tfrecord_paths": tfrecord_paths if USE_STREAMING else None,
    "execution_mode": CONFIG["execution_mode"],
}
results = {
    "model_name": model_name,
    "config": CONFIG,
    "metrics": test_metrics,
    "training": training_summary,
    "dataset": dataset_summary,
    "channel_stats": {"mean": mean_ch.tolist(), "std": std_ch.tolist()},
    "label_encoder_classes": list(le.classes_),
}

# default=str stringifies any value json cannot encode natively
with open(f"{OUTPUT_PATH}/{model_name}_results.json", "w") as results_file:
    json.dump(results, results_file, indent=2, default=str)

print(f"[OK] Resultados guardados: {OUTPUT_PATH}/{model_name}_results.json")

# Pickled LabelEncoder
with open(f"{OUTPUT_PATH}/{model_name}_label_encoder.pkl", "wb") as encoder_file:
    pickle.dump(le, encoder_file)

# Single print call; sep="\n" yields the same output as one print per line.
print(
    "\nArchivos generados:",
    f"   - {model_name}_final.keras (modelo)",
    f"   - {model_name}_best.keras (mejor checkpoint)",
    f"   - {model_name}_history.csv (historial)",
    f"   - {model_name}_results.json (metricas y config)",
    f"   - {model_name}_label_encoder.pkl (encoder)",
    f"   - {model_name}_training_curves.png",
    f"   - {model_name}_confusion_matrix.png",
    sep="\n",
)

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

# Recap of the test-set metrics and the output location, emitted with a
# single print call (sep="\n" yields the same output as one print per line).
print(
    "\n" + "=" * 60,
    "ENTRENAMIENTO CNN1D COMPLETADO",
    "=" * 60,
    "\nResultados finales en Test Set:",
    f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)",
    f"   Cohen Kappa: {kappa:.4f}",
    f"   F1 Macro:    {f1_macro:.4f}",
    f"   F1 Weighted: {f1_weighted:.4f}",
    f"\nModelo guardado en: {OUTPUT_PATH}",
    "=" * 60,
    sep="\n",
)

In [None]:
# ============================================================
# ZIP ARTIFACTS FOR DOWNLOAD
# ============================================================

zip_path = f"{OUTPUT_PATH}/{model_name}_artifacts.zip"
exclude_extensions = (".tfrecord", ".zip")  # Skip TFRecords and other zip files
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
    for entry_name in os.listdir(OUTPUT_PATH):
        entry_path = os.path.join(OUTPUT_PATH, entry_name)
        # Directories are never archived
        if not os.path.isfile(entry_path):
            continue
        # Keep only files prefixed with model_name that do not carry an
        # excluded extension (this also skips the archive being written)
        if not entry_name.startswith(model_name) or entry_name.endswith(
            exclude_extensions
        ):
            continue
        archive.write(entry_path, arcname=entry_name)

print(f"[OK] Artefactos comprimidos en: {zip_path} (excluyendo TFRecords)")
print("Archivos incluidos:")
# Re-open the archive read-only to list what was actually stored
with zipfile.ZipFile(zip_path, "r") as archive:
    for member in archive.infolist():
        print(f" - {member.filename} ({member.file_size / 1024:.1f} KB)")