# Sequential CNN-LSTM for Sleep Staging

This notebook trains a hybrid **CNN + LSTM** model for sleep stage classification using **sequences of multiple consecutive epochs**, so the model can capture temporal dependencies between epochs.

**Optimized for Kaggle with 2x Tesla T4 (30GB RAM + 32GB VRAM)**

### Architecture:
- **Level 1 (CNN)**: Extracts features from each 30 s epoch with TimeDistributed Conv1D
- **Level 2 (BiLSTM)**: Processes the feature sequence to capture temporal context
- **Many-to-many prediction**: Predicts the stage of every epoch in the sequence
- **Sample weights**: Per-epoch weighting to balance minority classes (N1)

### Execution modes:
- **Debug** (`EXECUTION_MODE = "debug"`): Loads data into RAM, subset of subjects.
- **Full** (`EXECUTION_MODE = "full"`): **TFRecord streaming**, all subjects.

### Features:
- Sequences of 11 epochs (5.5 min of temporal context)
- **Train/Val**: stride=2 (high overlap, ~60k sequences, ~32GB of TFRecords)
- **Test**: stride=11 (no overlap, for unbiased metrics)
- Train/val/test split **by SUBJECT** (no data leakage)
- TFRecords are cached so the notebook can be re-run without reprocessing
- Fixed seeds for reproducibility (SEED=42)
- Mixed precision disabled for better numerical stability

### Kaggle resources:
- RAM: 30GB
- VRAM: 2×16GB = 32GB
- Temporary disk (`/kaggle/tmp`): 50GB → TFRecords (~32GB with stride=2)
- Output disk (`/kaggle/working`): 19.5GB

### Required data:
- Dataset `sleep-edf-trimmed-f32` on Kaggle
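
To make the relationship between sequence length, stride and number of sequences concrete, here is a minimal sketch. The helper `count_windows` and the 1,000-epoch recording are illustrative assumptions, not part of the notebook, and the count ignores windows that are later dropped for invalid labels or non-finite samples.

```python
def count_windows(n_epochs: int, seq_length: int, stride: int) -> int:
    """Number of full windows of `seq_length` epochs taken every `stride` epochs."""
    if n_epochs < seq_length:
        return 0
    return (n_epochs - seq_length) // stride + 1


EPOCH_SECONDS = 30
SEQ_LENGTH = 11

# Temporal context covered by one sequence, in minutes
context_min = SEQ_LENGTH * EPOCH_SECONDS / 60

# Hypothetical recording of 1,000 epochs (about 8.3 hours)
train_windows = count_windows(1000, SEQ_LENGTH, stride=2)
test_windows = count_windows(1000, SEQ_LENGTH, stride=SEQ_LENGTH)

print(f"context: {context_min} min")
print(f"stride=2 windows: {train_windows}")
print(f"stride=11 windows: {test_windows}")
```

With stride equal to the sequence length, every epoch appears in at most one test window, which is why the test metrics are not inflated by repeated epochs.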

In [None]:
# ============================================================
# INITIAL SETUP - Run first
# ============================================================

import os
import warnings

warnings.filterwarnings("ignore")

# Detect whether we are running on Kaggle
IN_KAGGLE = os.path.exists("/kaggle/input")
print(f"Environment: {'Kaggle' if IN_KAGGLE else 'Local'}")

# Paths depend on the environment
if IN_KAGGLE:
    DATA_PATH = "/kaggle/input/sleep-edf-trimmed-f32/sleep_trimmed_spt"
    OUTPUT_PATH = "/kaggle/working"
else:
    DATA_PATH = "../data/processed"
    OUTPUT_PATH = "../models"

os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Data path: {DATA_PATH}")
print(f"Output path: {OUTPUT_PATH}")

In [None]:
# ============================================================
# CHECK AVAILABLE GPUs
# ============================================================

import tensorflow as tf  # noqa: E402

SEED = 42
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("\nAvailable GPUs:")

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"\n[OK] {len(gpus)} GPU(s) configured with memory growth")
else:
    print("[WARN] No GPUs detected. Training will be slow.")

if len(gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
    print(f"\nUsing MirroredStrategy with {strategy.num_replicas_in_sync} GPUs")
else:
    strategy = tf.distribute.get_strategy()
    if len(gpus) == 1:
        print("\nUsing the default strategy (1 GPU)")
    else:
        print("\n[WARN] Running on CPU - training will be slow")

print("[INFO] Mixed precision will be configured after CONFIG is defined")

In [None]:
# ============================================================
# IMPORTS AND DEPENDENCIES
# ============================================================

import gc  # noqa: E402
import hashlib  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
import math  # noqa: E402
import pickle  # noqa: E402
import random  # noqa: E402
import re  # noqa: E402
import shutil  # noqa: E402
import time  # noqa: E402
import zipfile  # noqa: E402
from collections import Counter  # noqa: E402
from datetime import datetime  # noqa: E402
from pathlib import Path  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402

np.random.seed(SEED)
random.seed(SEED)
tf.keras.utils.set_random_seed(SEED)

from sklearn.metrics import (  # noqa: E402
    accuracy_score,
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
)
from sklearn.preprocessing import LabelEncoder  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from tensorflow.keras.callbacks import (  # noqa: E402
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

print("[OK] Imports completed")

In [None]:
# ============================================================
# INSTALL DEPENDENCIES (Kaggle only)
# ============================================================

try:
    import mne
except ImportError:
    if IN_KAGGLE:
        print("Installing dependencies...")
        %pip install -q mne
        print("[OK] Dependencies installed")
        import mne
    else:
        raise

mne.set_log_level("ERROR")
print(f"MNE version: {mne.__version__}")

## Experiment Configuration

In [None]:
# ============================================================
# HYPERPARAMETERS - ADJUST AS NEEDED
# ============================================================

EXECUTION_MODE = "full"  # "debug" or "full"

# Kaggle resources:
# - RAM: 30GB
# - VRAM: 2x16GB = 32GB
# - Output disk: 19.5GB
# - Total disk: 57GB

CONFIG = {
    "execution_mode": EXECUTION_MODE,
    "manifest_path": (
        "/kaggle/input/sleep-edf-trimmed-f32/manifest_trimmed_spt.csv"
        if IN_KAGGLE
        else f"{DATA_PATH}/manifest_trimmed_spt.csv"
    ),
    "epoch_length": 30.0,
    "sfreq": 100,
    "debug_max_subjects": 16,
    "test_size": 0.15,
    "val_size": 0.15,
    "random_state": 42,
    # === SEQUENCES ===
    "seq_length": 11,  # Epochs per sequence (5.5 min of context)
    # Stride options for /kaggle/tmp (50GB available):
    # - stride=1: ~120k sequences (~64GB) - DOES NOT FIT
    # - stride=2: ~60k sequences (~32GB) - FITS, more data (RECOMMENDED)
    # - stride=3: ~40k sequences (~21GB) - FITS, conservative fallback
    "seq_stride_train": 2,  # Train stride (more overlap = more data)
    "seq_stride_val": 2,  # Validation stride
    "seq_stride_test": 11,  # Test stride = seq_length (no overlap)
    # === CNN Feature Extractor ===
    "cnn_filters": [32, 64, 128],
    "cnn_kernel_sizes": [5, 5, 3],
    "feature_dim": 128,
    # === LSTM Temporal ===
    "lstm_units": [64, 32],
    "bidirectional": True,
    "dropout_rate": 0.3,
    # === Training ===
    "learning_rate_initial": 3e-4,
    "learning_rate_min": 1e-6,
    "warmup_epochs": 5,
    "batch_size": 64,
    "epochs": 300 if EXECUTION_MODE == "full" else 50,
    "use_class_weights": True,
    "class_weight_clip": 1.5,
    "std_epsilon": 1e-6,
    "clip_value": 5.0,
    "early_stopping_patience": 40 if EXECUTION_MODE == "full" else 10,
    "use_mixed_precision": False,
    # === Control ===
    "run_optimization": False,
    "max_steps_per_epoch": 115,
    "shuffle_buffer": 3000,
    # === TFRecord Streaming (full mode) ===
    "streaming": EXECUTION_MODE == "full",
    "tfrecord_dir": "/kaggle/tmp/tfrecords_cnn_lstm_seq"
    if IN_KAGGLE
    else f"{OUTPUT_PATH}/tfrecords_cnn_lstm_seq",
}

CONFIG["effective_batch_size"] = CONFIG["batch_size"] * strategy.num_replicas_in_sync
CONFIG["samples_per_epoch"] = int(CONFIG["epoch_length"] * CONFIG["sfreq"])

if CONFIG.get("use_mixed_precision", False) and gpus:
    try:
        from tensorflow.keras import mixed_precision

        mixed_precision.set_global_policy("mixed_float16")
        print("[OK] Mixed precision (float16) enabled")
    except Exception as e:
        print(f"[WARN] Could not enable mixed precision: {e}")
else:
    print("[INFO] Mixed precision disabled (float32) for better stability")

print("\n" + "=" * 60)
if EXECUTION_MODE == "debug":
    print(" DEBUG MODE: Load into RAM, subset of subjects")
    print(f"   Max subjects: {CONFIG['debug_max_subjects']}")
else:
    print(" FULL MODE: TFRecord streaming, all subjects")
    print(f"   TFRecord dir: {CONFIG['tfrecord_dir']}")
print(f"   Training epochs: {CONFIG['epochs']}")
# Derive the context length from epoch_length instead of a hardcoded 30 s
context_min = CONFIG["seq_length"] * CONFIG["epoch_length"] / 60
print(
    f"   Seq length: {CONFIG['seq_length']} epochs ({context_min:.1f} min of context)"
)
print(f"   Seq stride train/val: {CONFIG['seq_stride_train']} (overlapping)")
print(f"   Seq stride test: {CONFIG['seq_stride_test']} (no overlap)")
print(f"   Max steps/epoch: {CONFIG['max_steps_per_epoch']}")
print("=" * 60)

print("\nExperiment configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")
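
The stride comments in `CONFIG` can be sanity-checked with simple arithmetic. The sketch below is a rough estimate only: it assumes 4 channels (2 EEG, EOG, EMG) stored as 4-byte floats, uses the approximate sequence counts quoted in the comments, and ignores TFRecord framing overhead and the small label and weight fields.

```python
SEQ_LENGTH = 11            # epochs per sequence
SAMPLES_PER_EPOCH = 3000   # 30 s at 100 Hz
N_CHANNELS = 4             # assumption: 2 EEG + EOG + EMG
BYTES_PER_FLOAT = 4        # float32
TMP_DISK_GB = 50           # /kaggle/tmp capacity quoted in the notebook

bytes_per_sequence = SEQ_LENGTH * SAMPLES_PER_EPOCH * N_CHANNELS * BYTES_PER_FLOAT


def estimate_gb(n_sequences: int) -> float:
    """Approximate signal payload in decimal gigabytes."""
    return n_sequences * bytes_per_sequence / 1e9


# Approximate sequence counts per stride, taken from the CONFIG comments
approx_counts = {1: 120_000, 2: 60_000, 3: 40_000}

for stride, n_seq in approx_counts.items():
    size_gb = estimate_gb(n_seq)
    verdict = "fits" if size_gb < TMP_DISK_GB else "does not fit"
    print(f"stride={stride}: ~{size_gb:.1f} GB ({verdict})")
```

Each sequence carries about 0.53 MB of signal data, so the payload scales linearly with the number of windows and roughly inversely with the stride.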

## Data Loading

In [None]:
# ============================================================
# DATA LOADING FUNCTIONS
# ============================================================

STAGE_CANONICAL = {
    "Sleep stage W": "W",
    "Sleep stage 1": "N1",
    "Sleep stage 2": "N2",
    "Sleep stage 3": "N3",
    "Sleep stage 4": "N3",
    "Sleep stage R": "REM",
}
STAGE_ORDER = ["W", "N1", "N2", "N3", "REM"]
REQUIRED_CHANNELS = ["EEG Fpz-Cz", "EEG Pz-Oz"]
OPTIONAL_CHANNELS = ["EOG horizontal", "EMG submental"]


def extract_subject_core(subject_id):
    """Group nights from the same subject. SC4XXNy -> SC4XX"""
    sid = str(subject_id)
    match = re.match(r"(SC4\d{2})", sid)
    return match.group(1) if match else sid


def load_psg_data(psg_path, channels=None, target_sfreq=None):
    """Load PSG data from a .fif file.

    NOTE: Uses REQUIRED_CHANNELS + OPTIONAL_CHANNELS (EEG + EOG + EMG).
    EOG is crucial for detecting REM; EMG helps distinguish stages.
    IMPORTANT: Validates that ALL channels are present, for consistency.
    """
    raw = mne.io.read_raw_fif(str(psg_path), preload=True, verbose="ERROR")
    available = set(raw.ch_names)
    if channels is None:
        # Use every expected channel, but validate that all are present
        all_expected = REQUIRED_CHANNELS + OPTIONAL_CHANNELS
        missing = [ch for ch in all_expected if ch not in available]
        if missing:
            raise ValueError(f"Missing channels in {psg_path}: {missing}")
        channels = all_expected.copy()
    raw.pick(channels)
    if target_sfreq and raw.info["sfreq"] != target_sfreq:
        raw.resample(target_sfreq)
    return raw.get_data(), raw.info["sfreq"], raw.ch_names


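def load_hypnogram(hyp_path):
    """Load the hypnogram CSV and map raw descriptions to canonical stages."""
    df = pd.read_csv(hyp_path)
    df["stage_canonical"] = df["description"].map(STAGE_CANONICAL)
    return df


def create_epochs(data, sfreq, epoch_length=30.0):
    """Split (channels, samples) data into fixed-length epochs.

    Trailing samples that do not fill a whole epoch are dropped.
    Returns epochs of shape (n_epochs, channels, samples) and their onsets in seconds.
    """
    samples_per_epoch = int(epoch_length * sfreq)
    n_channels, n_samples = data.shape
    n_epochs = n_samples // samples_per_epoch
    epochs = [
        data[:, i * samples_per_epoch : (i + 1) * samples_per_epoch]
        for i in range(n_epochs)
    ]
    epoch_times = [i * epoch_length for i in range(n_epochs)]
    return np.array(epochs), np.array(epoch_times)


def assign_stages(epoch_times, hypnogram, epoch_length=30.0):
    """Label each epoch with the annotation that covers its center.

    Epochs not covered by any annotation get None; annotations whose
    description is not in STAGE_CANONICAL yield NaN. Both are filtered later.
    """
    stages = []
    for t in epoch_times:
        epoch_center = t + epoch_length / 2
        mask = (hypnogram["onset"] <= epoch_center) & (
            hypnogram["onset"] + hypnogram["duration"] > epoch_center
        )
        matched = hypnogram[mask]
        stages.append(matched.iloc[0]["stage_canonical"] if len(matched) > 0 else None)
    return stages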


print("[OK] Loading functions defined")
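
As a small illustration of the epoching and labeling logic above, the following sketch applies the same rules to synthetic data: cut the signal into 30 s epochs, then label each epoch by the annotation that covers its center. The array contents, the `annotations` list and the helper `stage_at` are hypothetical stand-ins; the notebook itself works with MNE recordings and a pandas hypnogram.

```python
import numpy as np

sfreq = 100
epoch_length = 30.0
samples_per_epoch = int(epoch_length * sfreq)

# Synthetic 2-channel recording of 95 s: the last 5 s do not fill an epoch
data = np.arange(2 * 9500, dtype=np.float32).reshape(2, 9500)

n_epochs = data.shape[1] // samples_per_epoch
# Same result as slicing epoch by epoch: (n_epochs, channels, samples)
epochs = (
    data[:, : n_epochs * samples_per_epoch]
    .reshape(2, n_epochs, samples_per_epoch)
    .transpose(1, 0, 2)
)
epoch_times = np.arange(n_epochs) * epoch_length

# Hypothetical annotations as (onset_s, duration_s, stage)
annotations = [(0.0, 60.0, "W"), (60.0, 30.0, "N1")]


def stage_at(center: float):
    """Return the stage whose [onset, onset + duration) interval covers `center`."""
    for onset, duration, stage in annotations:
        if onset <= center < onset + duration:
            return stage
    return None


stages = [stage_at(t + epoch_length / 2) for t in epoch_times]
print(epochs.shape, stages)
```

Using the epoch center rather than its onset makes the label robust to annotations whose boundaries are slightly misaligned with the 30 s grid.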

In [None]:
# ============================================================
# SEQUENCE GENERATION FUNCTIONS
# ============================================================


def create_sequences_from_session(epochs, stages, seq_length, seq_stride=1):
    """
    Create sliding-window sequences of consecutive epochs from one session.

    Args:
        epochs: (N, channels, samples) - epochs of one session
        stages: (N,) - per-epoch labels
        seq_length: sequence length
        seq_stride: step between sequences (1=maximum overlap, seq_length=no overlap)

    Returns:
        X_seq: (M, seq_length, channels, samples)
        y_seq: (M, seq_length) - labels for each epoch
    """
    sequences_X, sequences_y = [], []
    n_epochs = len(epochs)

    if n_epochs < seq_length:
        return np.array([]), np.array([])

    for i in range(0, n_epochs - seq_length + 1, seq_stride):
        seq_X = epochs[i : i + seq_length]
        seq_y = stages[i : i + seq_length]

        # Keep only sequences with valid labels and no NaN/Inf
        if all(s in STAGE_ORDER for s in seq_y) and np.all(np.isfinite(seq_X)):
            sequences_X.append(seq_X)
            sequences_y.append([STAGE_ORDER.index(s) for s in seq_y])

    if len(sequences_X) == 0:
        return np.array([]), np.array([])

    return np.array(sequences_X, dtype=np.float32), np.array(
        sequences_y, dtype=np.int32
    )


def load_sequences_for_split(
    manifest_path,
    epoch_length,
    sfreq,
    allowed_cores,
    seq_length,
    seq_stride,
    split_name="train",
):
    """Load sequences for a specific split."""
    print(f"\n[INFO] Loading sequences for {split_name} (stride={seq_stride})...")
    start_time = time.time()
    all_seq_X, all_seq_y, all_subject_cores = [], [], []
    expected_channels = None
    total_epochs, total_seqs = 0, 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels: {list(ch_names)}")

        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        total_epochs += len(epochs)

        # Build sequences with the stride chosen for this split
        seq_X, seq_y = create_sequences_from_session(
            epochs, stages, seq_length, seq_stride
        )

        if len(seq_X) > 0:
            all_seq_X.extend(seq_X)
            all_seq_y.extend(seq_y)
            all_subject_cores.extend([subject_core] * len(seq_X))
            total_seqs += len(seq_X)

        if i % 20 == 0 or i == total:
            print(f"   {split_name}: {i}/{total} sessions, {total_seqs} sequences")

    if len(all_seq_X) == 0:
        return np.array([]), np.array([]), np.array([]), expected_channels

    X = np.array(all_seq_X, dtype=np.float32)
    y = np.array(all_seq_y, dtype=np.int32)
    subject_cores = np.array(all_subject_cores)

    # Transpose: (batch, seq_len, channels, samples) -> (batch, seq_len, samples, channels)
    X = np.transpose(X, (0, 1, 3, 2))

    print(
        f"[OK] {split_name}: {total_seqs} sequences in {time.time() - start_time:.1f}s"
    )
    print(f"   Shape: {X.shape}, original epochs: {total_epochs}")

    return X, y, subject_cores, expected_channels


def get_split_cores(
    manifest_path, test_size, val_size, random_state, allowed_cores=None
):
    """Split subject cores into train/val/test sets."""
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    subject_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()

    if allowed_cores is not None:
        subject_cores = np.array([c for c in subject_cores if c in allowed_cores])

    rng = np.random.default_rng(random_state)
    rng.shuffle(subject_cores)

    n_test = max(1, int(len(subject_cores) * test_size))
    n_val = max(1, int(len(subject_cores) * val_size))

    test_cores = set(subject_cores[:n_test])
    val_cores = set(subject_cores[n_test : n_test + n_val])
    train_cores = set(subject_cores[n_test + n_val :])

    return train_cores, val_cores, test_cores


def normalize_sequences(X_train, X_val, X_test, clip_value, std_epsilon):
    """Normalize sequences with train statistics (per-channel z-score)."""
    # X shape: (batch, seq_len, samples, channels)
    mean_ch = X_train.mean(axis=(0, 1, 2))  # (channels,)
    var_ch = X_train.var(axis=(0, 1, 2))
    std_ch = np.sqrt(var_ch + std_epsilon)

    X_train_norm = np.clip((X_train - mean_ch) / std_ch, -clip_value, clip_value)
    X_val_norm = (
        np.clip((X_val - mean_ch) / std_ch, -clip_value, clip_value)
        if len(X_val) > 0
        else X_val
    )
    X_test_norm = (
        np.clip((X_test - mean_ch) / std_ch, -clip_value, clip_value)
        if len(X_test) > 0
        else X_test
    )

    return X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch


print("[OK] Sequence functions defined")
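
The subject-level split is what prevents leakage: both nights of a subject must land in the same partition. The sketch below illustrates the idea with hypothetical recording IDs and Python's `random.Random` in place of the NumPy generator used in `get_split_cores`, so the exact assignment will differ from the notebook's.

```python
import random
import re


def subject_core(recording_id: str) -> str:
    """Strip the night suffix so that SC4XXNy maps to SC4XX."""
    match = re.match(r"(SC4\d{2})", recording_id)
    return match.group(1) if match else recording_id


# Hypothetical recording IDs: 20 subjects with two nights each
recordings = [f"SC4{s:02d}{night}E0" for s in range(20) for night in (1, 2)]
cores = sorted({subject_core(r) for r in recordings})

rng = random.Random(42)
rng.shuffle(cores)

n_test = max(1, int(len(cores) * 0.15))
n_val = max(1, int(len(cores) * 0.15))
test_cores = set(cores[:n_test])
val_cores = set(cores[n_test : n_test + n_val])
train_cores = set(cores[n_test + n_val :])


def split_name(recording_id: str) -> str:
    """Look up the partition of a recording through its subject core."""
    core = subject_core(recording_id)
    if core in test_cores:
        return "test"
    if core in val_cores:
        return "val"
    return "train"


split_of = {r: split_name(r) for r in recordings}
print(len(train_cores), len(val_cores), len(test_cores))
```

Because partition membership is decided on the subject core and never on the individual recording, no subject can contribute data to two partitions.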

In [None]:
# ============================================================
# TFRECORD FUNCTIONS FOR SEQUENCES
# ============================================================


def make_seq_example(
    x_seq, y_seq, sample_weights, seq_length, samples_per_epoch, n_channels
):
    """
    Serialize a full sequence into a TFRecord example.
    x_seq: (seq_length, samples_per_epoch, n_channels) - already transposed
    y_seq: (seq_length,) - labels
    sample_weights: (seq_length,) - per-epoch weight according to its class
    """
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=x_seq.ravel())),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=y_seq.ravel())),
        "sw": tf.train.Feature(
            float_list=tf.train.FloatList(value=sample_weights.ravel())
        ),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()


def compute_normalization_stats(manifest_path, epoch_length, sfreq, train_cores):
    """
    First pass: compute per-channel mean/std using train data only.

    NOTE: Keeps running_n as a per-channel vector because the number of
    NaN/Inf samples can differ between channels.
    """
    print("\n[INFO] Computing normalization statistics (train only)...")

    running_sum = None
    running_sumsq = None
    running_n = None  # Per-channel vector
    n_channels = None

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, train_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        n_ch = data.shape[0]

        if n_channels is None:
            n_channels = n_ch
            running_sum = np.zeros(n_ch, dtype=np.float64)
            running_sumsq = np.zeros(n_ch, dtype=np.float64)
            running_n = np.zeros(n_ch, dtype=np.int64)  # Per-channel vector

        # Accumulate per-channel statistics (each channel keeps its own count)
        for c in range(n_ch):
            channel_data = data[c, :]
            valid_mask = np.isfinite(channel_data)
            valid_data = channel_data[valid_mask]
            running_sum[c] += valid_data.sum()
            running_sumsq[c] += (valid_data**2).sum()
            running_n[c] += len(valid_data)  # Per-channel count

        if i % 20 == 0 or i == total:
            print(f"   Stats: {i}/{total} sessions")

    # Compute the final mean/std (each channel divided by its own count).
    # Protect against division by zero if a channel has no valid data.
    if np.any(running_n == 0):
        zero_channels = np.where(running_n == 0)[0]
        print(
            f"[WARN] Channels with no valid data: {zero_channels}. Using mean=0, std=1."
        )
        running_n = np.maximum(running_n, 1)  # Avoid division by zero

    mean_ch = running_sum / running_n
    var_ch = running_sumsq / running_n - mean_ch**2
    var_ch = np.maximum(var_ch, 0.0)
    std_ch = np.sqrt(var_ch + CONFIG.get("std_epsilon", 1e-6))

    # Fall back to default stats for channels with at most one valid sample
    mean_ch = np.where(running_n > 1, mean_ch, 0.0)
    std_ch = np.where(running_n > 1, std_ch, 1.0)

    print(f"[OK] Samples per channel: {running_n}")

    return mean_ch.astype(np.float32), std_ch.astype(np.float32), n_channels


def compute_class_counts(manifest_path, epoch_length, sfreq, train_cores):
    """
    Compute the class distribution by iterating over ALL train sessions.
    Does not stop early; every valid stage is counted.
    """
    print("\n[INFO] Computing class distribution on train...")
    class_counts = Counter()
    total_epochs = 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, train_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        for s in stages:
            if s in STAGE_ORDER:
                class_counts[s] += 1
                total_epochs += 1

        if i % 20 == 0 or i == total:
            print(f"   Class counts: {i}/{total} sessions, {total_epochs} epochs")

    print(f"[OK] Distribution: {dict(class_counts)}")
    return class_counts


def build_tfrecord_sequences(
    manifest_path,
    mean_ch,
    std_ch,
    split_cores,
    split_name,
    epoch_length,
    sfreq,
    seq_length,
    seq_stride,
    tfrecord_dir,
    class_weights=None,
):
    """
    Build TFRecords with sequences for one split.
    Processes one session at a time to minimize RAM usage.
    class_weights: dict {class_idx: weight} used to generate sample_weights
    """
    tfrecord_dir = Path(tfrecord_dir)
    tfrecord_dir.mkdir(parents=True, exist_ok=True)
    tfrecord_path = tfrecord_dir / f"{split_name}.tfrecord"

    writer = tf.io.TFRecordWriter(str(tfrecord_path))
    total_seqs = 0
    total_epochs = 0
    samples_per_epoch = int(epoch_length * sfreq)
    clip_value = CONFIG.get("clip_value", 5.0)
    n_channels = len(mean_ch)

    print(f"\n[INFO] Generating TFRecord for {split_name} (stride={seq_stride})...")
    if class_weights:
        print(f"   [INFO] Using class_weights for sample_weights: {class_weights}")

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, split_cores
    ):
        # Load the data for this session
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        total_epochs += len(epochs)

        # Normalize epochs
        epochs_norm = np.clip(
            (epochs - mean_ch[None, :, None]) / std_ch[None, :, None],
            -clip_value,
            clip_value,
        )

        # Transpose to (n_epochs, samples, channels) for the model
        epochs_norm = np.transpose(epochs_norm, (0, 2, 1))

        # Build the sequences for this session
        n_epochs = len(epochs_norm)
        if n_epochs < seq_length:
            continue

        for start_idx in range(0, n_epochs - seq_length + 1, seq_stride):
            seq_stages = stages[start_idx : start_idx + seq_length]

            if not all(s in STAGE_ORDER for s in seq_stages):
                continue

            # Check finiteness on the raw epochs: np.clip maps +/-Inf to
            # +/-clip_value, so the normalized data would hide Inf values
            if not np.all(np.isfinite(epochs[start_idx : start_idx + seq_length])):
                continue

            seq_X = epochs_norm[start_idx : start_idx + seq_length]

            seq_y = np.array([STAGE_ORDER.index(s) for s in seq_stages], dtype=np.int64)

            # Build sample_weights from the class of each epoch
            if class_weights:
                seq_sw = np.array([class_weights[yi] for yi in seq_y], dtype=np.float32)
            else:
                seq_sw = np.ones(seq_length, dtype=np.float32)

            writer.write(
                make_seq_example(
                    seq_X.astype(np.float32),
                    seq_y,
                    seq_sw,
                    seq_length,
                    samples_per_epoch,
                    n_channels,
                )
            )
            total_seqs += 1

        if i % 20 == 0 or i == total:
            print(f"   {split_name}: {i}/{total} sessions, {total_seqs} sequences")

    writer.close()

    file_size_mb = tfrecord_path.stat().st_size / (1024 * 1024)
    print(f"[OK] {split_name}: {total_seqs} sequences, {file_size_mb:.1f} MB")

    return str(tfrecord_path), total_seqs


def make_seq_dataset(
    tfrecord_path,
    seq_length,
    samples_per_epoch,
    n_channels,
    batch_size,
    shuffle=False,
    repeat=False,
):
    """Create a sequence dataset from a TFRecord file, including sample_weights."""

    feature_description = {
        "x": tf.io.FixedLenFeature(
            [seq_length * samples_per_epoch * n_channels], tf.float32
        ),
        "y": tf.io.FixedLenFeature([seq_length], tf.int64),
        "sw": tf.io.FixedLenFeature([seq_length], tf.float32),
    }

    def _parse(example_proto):
        example = tf.io.parse_single_example(example_proto, feature_description)
        x = tf.reshape(example["x"], (seq_length, samples_per_epoch, n_channels))
        y = tf.cast(example["y"], tf.int32)
        sw = example["sw"]  # (seq_length,)
        return x, y, sw

    ds = tf.data.TFRecordDataset([tfrecord_path], num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(_parse, num_parallel_calls=tf.data.AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(
            CONFIG["shuffle_buffer"],
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
    if repeat:
        ds = ds.repeat()

    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds


def get_tfrecord_seq_cache_key(manifest_path, config):
    """Build a cache key to detect whether existing TFRecords are still valid."""
    manifest_mtime = os.path.getmtime(manifest_path)
    key_data = (
        f"{manifest_path}_{manifest_mtime}_{config['sfreq']}_{config['epoch_length']}"
        f"_{config['seq_length']}_{config['seq_stride_train']}_{config['seq_stride_val']}_{config['seq_stride_test']}"
        f"_{config['clip_value']}_{config.get('std_epsilon', 1e-6)}_{config['random_state']}"
        f"_{config['test_size']}_{config['val_size']}_{config.get('debug_max_subjects', 'all')}"
        "_sw_v2"  # Version tag to invalidate the cache when the format changes
    )
    return hashlib.md5(key_data.encode()).hexdigest()[:12]


print("[OK] TFRecord functions for sequences defined (with sample_weights)")
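
The two-pass design relies on the fact that per-channel mean and standard deviation can be recovered from running sums, so no session has to stay in memory. The sketch below checks that claim on synthetic data with one NaN in a single channel. The session shapes and values are made up, and the `std_epsilon` term from the notebook is omitted so the result can be compared directly with NumPy's NaN-aware reductions.

```python
import numpy as np

rng = np.random.default_rng(0)
# Three hypothetical sessions, each (channels, samples), with different lengths
sessions = [rng.normal(loc=2.0, scale=3.0, size=(2, n)) for n in (500, 800, 300)]
sessions[1][0, 10] = np.nan  # one invalid sample, in channel 0 only

# Accumulate in float64, one session at a time
total = np.zeros(2, dtype=np.float64)
total_sq = np.zeros(2, dtype=np.float64)
count = np.zeros(2, dtype=np.int64)

for data in sessions:
    valid = np.isfinite(data)
    clean = np.where(valid, data, 0.0)  # invalid samples contribute nothing
    total += clean.sum(axis=1)
    total_sq += (clean**2).sum(axis=1)
    count += valid.sum(axis=1)  # each channel keeps its own count

mean = total / count
var = np.maximum(total_sq / count - mean**2, 0.0)
std = np.sqrt(var)

# Reference: concatenate everything and use NaN-aware reductions
all_data = np.concatenate(sessions, axis=1)
ref_mean = np.nanmean(all_data, axis=1)
ref_std = np.nanstd(all_data, axis=1)

print(count, np.allclose(mean, ref_mean), np.allclose(std, ref_std))
```

Keeping a separate count per channel matters here: dividing channel 0 by the shared total of 1,600 samples instead of its 1,599 valid ones would bias both statistics.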

In [None]:
# ============================================================
# DATA PIPELINE (DEBUG: RAM / FULL: TFRecord)
# ============================================================


def resolve_paths(row, manifest_dir, dataset_dir_name):
    """Resolve PSG and hypnogram paths for the Kaggle or local layout."""
    psg_path_str = row.get("psg_trimmed_path", "")
    hyp_path_str = row.get("hypnogram_trimmed_path", "")
    if pd.isna(psg_path_str):
        psg_path_str = ""
    if pd.isna(hyp_path_str):
        hyp_path_str = ""
    base_data_root = manifest_dir.parent

    if psg_path_str and hyp_path_str:
        if IN_KAGGLE:
            psg_rel, hyp_rel = Path(psg_path_str), Path(hyp_path_str)
            psg_anchor = next(
                (
                    i
                    for i, p in enumerate(psg_rel.parts)
                    if p.startswith("sleep_trimmed")
                ),
                None,
            )
            hyp_anchor = next(
                (
                    i
                    for i, p in enumerate(hyp_rel.parts)
                    if p.startswith("sleep_trimmed")
                ),
                None,
            )
            if psg_anchor is not None and hyp_anchor is not None:
                psg_path = Path(DATA_PATH) / Path(*psg_rel.parts[psg_anchor:])
                hyp_path = Path(DATA_PATH) / Path(*hyp_rel.parts[hyp_anchor:])
            else:
                psg_path = Path(DATA_PATH) / dataset_dir_name / "psg" / psg_rel.name
                hyp_path = (
                    Path(DATA_PATH) / dataset_dir_name / "hypnograms" / hyp_rel.name
                )
        else:
            psg_rel, hyp_rel = Path(psg_path_str), Path(hyp_path_str)
            psg_path = (
                base_data_root / psg_rel.relative_to("data")
                if psg_rel.parts and psg_rel.parts[0] == "data"
                else manifest_dir / psg_rel
                if not psg_rel.is_absolute()
                else psg_rel
            )
            hyp_path = (
                base_data_root / hyp_rel.relative_to("data")
                if hyp_rel.parts and hyp_rel.parts[0] == "data"
                else manifest_dir / hyp_rel
                if not hyp_rel.is_absolute()
                else hyp_rel
            )
    else:
        subset = row.get("subset", "sleep-cassette")
        version = row.get("version", "1.0.0")
        dataset_dir = manifest_dir / dataset_dir_name
        psg_path = (
            dataset_dir
            / "psg"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_raw.fif"
        )
        hyp_path = (
            dataset_dir
            / "hypnograms"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_annotations.csv"
        )
    return psg_path, hyp_path


def get_subject_cores_for_mode(
    manifest_path, execution_mode, debug_max_subjects, random_state
):
    """Return the subject cores to use for the given execution mode."""
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    all_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()
    rng = np.random.default_rng(random_state)
    rng.shuffle(all_cores)
    if execution_mode == "debug" and debug_max_subjects:
        selected_cores = set(all_cores[:debug_max_subjects])
        print(f"[DEBUG] Using {len(selected_cores)}/{len(all_cores)} subjects")
    else:
        selected_cores = set(all_cores)
        print(f"[FULL] Using all {len(selected_cores)} subjects")
    return selected_cores


def iter_sessions(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """Iterate over sessions, yielding resolved paths.

    Reports how many sessions were skipped because of missing files.
    """
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    if allowed_cores is not None:
        manifest_ok = manifest_ok[
            manifest_ok["subject_id"].apply(extract_subject_core).isin(allowed_cores)
        ]
    manifest_dir = Path(manifest_path).parent
    dataset_dir_name = (
        "sleep_trimmed_resamp200"
        if (manifest_dir / "sleep_trimmed_resamp200").exists()
        else "sleep_trimmed_spt"
        if (manifest_dir / "sleep_trimmed_spt").exists()
        else "sleep_trimmed"
    )
    total_sessions = len(manifest_ok)
    skipped_count = 0  # Sessions skipped because of missing files
    yielded_count = 0
    print(f"\nProcessing {total_sessions} sessions...")
    for i, (_, row) in enumerate(manifest_ok.iterrows(), start=1):
        subject_id = row["subject_id"]
        subject_core = extract_subject_core(subject_id)
        if allowed_cores is not None and subject_core not in allowed_cores:
            continue
        psg_path, hyp_path = resolve_paths(row, manifest_dir, dataset_dir_name)
        if not psg_path.exists() or not hyp_path.exists():
            skipped_count += 1
            continue
        yielded_count += 1
        yield i, total_sessions, subject_id, subject_core, psg_path, hyp_path
    # Log skipped sessions for auditing
    if skipped_count > 0:
        print(f"[WARN] Skipped {skipped_count} sessions because of missing files")
    print(f"[INFO] Sessions processed: {yielded_count}/{total_sessions}")


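def update_running_stats(stats, epochs):
    """Accumulate per-channel sum, sum of squares, and sample count.

    `epochs` has shape (n_epochs, n_channels, n_samples).
    """
    if stats is None:
        stats = {"n": 0, "sum": None, "sumsq": None}
    if stats["sum"] is None:
        stats["sum"] = np.zeros(epochs.shape[1], dtype=np.float64)
        stats["sumsq"] = np.zeros(epochs.shape[1], dtype=np.float64)
    # Reduce in float64 so float32 inputs do not lose precision in the sums
    stats["sum"] += epochs.sum(axis=(0, 2), dtype=np.float64)
    stats["sumsq"] += np.square(epochs, dtype=np.float64).sum(axis=(0, 2))
    stats["n"] += epochs.shape[0] * epochs.shape[2]
    return stats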


def finalize_running_stats(stats, std_epsilon=1e-6):
    """Return per-channel (mean, std) as float32 from the accumulated sums."""
    mean = stats["sum"] / stats["n"]
    var = stats["sumsq"] / stats["n"] - mean**2
    var = np.maximum(var, 0.0)
    std = np.sqrt(var + std_epsilon)
    return mean.astype(np.float32), std.astype(np.float32)


def load_all_data_to_ram(manifest_path, epoch_length, sfreq, allowed_cores):
    """Load all data into RAM (debug mode)."""
    print("\n[DEBUG] Loading data into RAM...")
    start_time = time.time()
    all_epochs, all_stages, all_subject_cores = [], [], []
    expected_channels = None
    nan_inf_count = 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels: {list(ch_names)}")
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)
        for epoch, stage in zip(epochs, stages):
            if stage not in STAGE_ORDER:
                continue
            if not np.all(np.isfinite(epoch)):
                nan_inf_count += 1
                continue
            all_epochs.append(epoch)
            all_stages.append(stage)
            all_subject_cores.append(subject_core)
        if i % 10 == 0 or i == total:
            print(f"   Loading: {i}/{total} sessions")

    if nan_inf_count > 0:
        print(f"[WARN] Discarded {nan_inf_count} epochs containing NaN/Inf")
    X = np.array(all_epochs, dtype=np.float32)
    y = np.array([STAGE_ORDER.index(s) for s in all_stages], dtype=np.int32)
    subject_cores = np.array(all_subject_cores)
    print(
        f"[OK] Data loaded in {time.time() - start_time:.1f}s. Shape: {X.shape}, Memory: {X.nbytes / 1024**3:.2f} GB"
    )
    return X, y, subject_cores, expected_channels


def split_data_by_subject(X, y, subject_cores, test_size, val_size, random_state):
    """Split data by subject so that no subject appears in two splits (no data leakage)."""
    unique_cores = np.unique(subject_cores)
    rng = np.random.default_rng(random_state)
    rng.shuffle(unique_cores)
    n_test = max(1, int(len(unique_cores) * test_size))
    n_val = max(1, int(len(unique_cores) * val_size))
    test_cores = set(unique_cores[:n_test])
    val_cores = set(unique_cores[n_test : n_test + n_val])
    train_cores = set(unique_cores[n_test + n_val :])
    train_mask = np.isin(subject_cores, list(train_cores))
    val_mask = np.isin(subject_cores, list(val_cores))
    test_mask = np.isin(subject_cores, list(test_cores))
    return (
        X[train_mask],
        y[train_mask],
        X[val_mask],
        y[val_mask],
        X[test_mask],
        y[test_mask],
        train_cores,
        val_cores,
        test_cores,
    )


def normalize_data(X_train, X_val, X_test, clip_value, std_epsilon):
    """Normalize all splits with train statistics (per-channel z-score, then clip)."""
    mean_ch = X_train.mean(axis=(0, 2))
    var_ch = X_train.var(axis=(0, 2))
    std_ch = np.sqrt(var_ch + std_epsilon)
    X_train_norm = np.clip(
        (X_train - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    X_val_norm = np.clip(
        (X_val - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    X_test_norm = np.clip(
        (X_test - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    return X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch


print("[OK] Pipeline functions defined")
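
The streaming statistics used by `update_running_stats` and `finalize_running_stats` rely on the identity Var(x) = E[x²] - E[x]², which lets the mean and standard deviation be accumulated session by session without holding all the data in memory. The following is a minimal single-channel sketch of that idea in plain Python; the helper names and the toy values are illustrative assumptions, not part of the notebook.

```python
import math


def update_stats(stats, chunk):
    """Accumulate count, sum, and sum of squares for one chunk of samples."""
    if stats is None:
        stats = {"n": 0, "sum": 0.0, "sumsq": 0.0}
    stats["n"] += len(chunk)
    stats["sum"] += math.fsum(chunk)
    stats["sumsq"] += math.fsum(v * v for v in chunk)
    return stats


def finalize_stats(stats, std_epsilon=1e-6):
    """Turn the accumulated sums into (mean, std)."""
    mean = stats["sum"] / stats["n"]
    # Rounding can make the difference slightly negative, so clamp at zero
    var = max(stats["sumsq"] / stats["n"] - mean**2, 0.0)
    return mean, math.sqrt(var + std_epsilon)


# Hypothetical single-channel signal split into two "sessions"
session_a = [1.0, 2.0, 3.0]
session_b = [4.0, 5.0]

stats = None
for chunk in (session_a, session_b):
    stats = update_stats(stats, chunk)

mean, std = finalize_stats(stats, std_epsilon=0.0)
print(mean, std)
```

The result matches what a single pass over the concatenated signal would give. The sum-of-squares form can lose precision when the mean is large relative to the spread, which is presumably why the notebook accumulates in float64 and clamps the variance at zero.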

In [None]:
# ============================================================
# TFRECORD FUNCTIONS
# ============================================================


def pass1_stats(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """First pass: per-channel mean/std and class counts."""
    stats, class_counts, input_shape, expected_channels = None, Counter(), None, None
    nan_inf_count = 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels in {psg_path}")
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        # Keep only epochs whose stage is one of the target classes
        valid_mask = np.array([s in STAGE_ORDER for s in stages], dtype=bool)
        if not valid_mask.any():
            continue
        valid_epochs = epochs[valid_mask]
        valid_stages = [s for s, m in zip(stages, valid_mask) if m]

        # Drop epochs containing NaN/Inf so they cannot poison the statistics
        finite_mask = np.all(np.isfinite(valid_epochs), axis=(1, 2))
        if not np.all(finite_mask):
            nan_inf_count += int(np.sum(~finite_mask))
            valid_epochs = valid_epochs[finite_mask]
            valid_stages = [s for s, m in zip(valid_stages, finite_mask) if m]

        if len(valid_epochs) == 0:
            continue
        if input_shape is None:
            input_shape = valid_epochs.shape[1:]
        stats = update_running_stats(stats, valid_epochs)
        class_counts.update(valid_stages)

        if i % 20 == 0 or i == total:
            print(f"   Pass 1: {i}/{total} sessions")

    if nan_inf_count > 0:
        print(f"[WARN] Discarded {nan_inf_count} epochs containing NaN/Inf")
    if input_shape is None:
        raise ValueError("No valid epochs found")
    mean, std = finalize_running_stats(
        stats, std_epsilon=CONFIG.get("std_epsilon", 1e-6)
    )
    return mean, std, class_counts, input_shape, expected_channels


def make_example(x, y):
    """Serialize one normalized epoch `x` and its integer label `y`."""
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=x.ravel())),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()


def assign_subject_splits(
    manifest_path, test_size, val_size, random_state, allowed_cores=None
):
    """Map each subject core to train/val/test so that no subject spans two splits."""
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    subject_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()
    if allowed_cores is not None:
        subject_cores = np.array([c for c in subject_cores if c in allowed_cores])
    rng = np.random.default_rng(random_state)
    rng.shuffle(subject_cores)
    n_test = max(1, int(len(subject_cores) * test_size))
    n_val = max(1, int(len(subject_cores) * val_size))
    test_cores = set(subject_cores[:n_test])
    val_cores = set(subject_cores[n_test : n_test + n_val])
    train_cores = set(subject_cores[n_test + n_val :])
    split_map = dict.fromkeys(train_cores, "train")
    split_map.update(dict.fromkeys(val_cores, "val"))
    split_map.update(dict.fromkeys(test_cores, "test"))
    return split_map, {
        "train": len(train_cores),
        "val": len(val_cores),
        "test": len(test_cores),
    }


def build_tfrecord_splits(
    manifest_path,
    mean,
    std,
    split_map,
    epoch_length,
    sfreq,
    tfrecord_dir,
    expected_channels=None,
):
    tfrecord_dir = Path(tfrecord_dir)
    tfrecord_dir.mkdir(parents=True, exist_ok=True)
    writers = {
        k: tf.io.TFRecordWriter(str(tfrecord_dir / f"{k}.tfrecord"))
        for k in ["train", "val", "test"]
    }
    counts, session_counts, subject_sets = (
        Counter(),
        Counter(),
        {k: set() for k in writers},
    )
    skipped_nan_inf = 0
    clip_value = CONFIG.get("clip_value", 5.0)
    allowed_cores = set(split_map.keys())

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        split = split_map.get(subject_core)
        if split is None:
            continue
        session_counts[split] += 1
        subject_sets[split].add(subject_core)
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels and list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels in {psg_path}")
        hypnogram = load_hypnogram(hyp_path)
        # Use the sampling rate actually returned by the loader, as in pass 1
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        for epoch, stage in zip(epochs, stages):
            if stage not in STAGE_ORDER:
                continue
            if not np.all(np.isfinite(epoch)):
                skipped_nan_inf += 1
                continue
            y = STAGE_ORDER.index(stage)
            x = np.clip((epoch - mean[:, None]) / std[:, None], -clip_value, clip_value)
            if not np.all(np.isfinite(x)):
                skipped_nan_inf += 1
                continue
            writers[split].write(make_example(x.astype(np.float32), y))
            counts[split] += 1

        if i % 20 == 0 or i == total:
            print(f"   Pass 2 ({split}): {i}/{total} sessions")

    for w in writers.values():
        w.close()
    if skipped_nan_inf > 0:
        print(f"[WARN] Discarded {skipped_nan_inf} epochs containing NaN/Inf")
    return (
        {k: str(tfrecord_dir / f"{k}.tfrecord") for k in writers},
        counts,
        session_counts,
        {k: len(subject_sets[k]) for k in subject_sets},
    )


def make_dataset(
    tfrecord_path, input_shape, batch_size, shuffle=False, repeat=False, for_lstm=True
):
    """Build a dataset from a TFRecord file.

    With for_lstm=True each epoch is transposed to (samples, channels) for the LSTM.
    """
    feature_description = {
        "x": tf.io.FixedLenFeature([input_shape[0] * input_shape[1]], tf.float32),
        "y": tf.io.FixedLenFeature([], tf.int64),
    }
    clip_value = CONFIG.get("clip_value", 5.0)

    def _parse(example_proto):
        example = tf.io.parse_single_example(example_proto, feature_description)
        x = tf.reshape(example["x"], input_shape)
        x = tf.clip_by_value(x, -clip_value, clip_value)
        x = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
        if for_lstm:
            x = tf.transpose(
                x, perm=[1, 0]
            )  # (channels, samples) -> (samples, channels)
        y = tf.cast(example["y"], tf.int32)
        return x, y

    ds = tf.data.TFRecordDataset([tfrecord_path], num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)
    if shuffle:
        ds = ds.shuffle(
            CONFIG["shuffle_buffer"],
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
    if repeat:
        ds = ds.repeat()
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds


def get_tfrecord_cache_key(manifest_path, config):
    manifest_mtime = os.path.getmtime(manifest_path)
    key_data = (
        f"{manifest_path}_{manifest_mtime}_{config['sfreq']}_{config['epoch_length']}"
        f"_{config.get('clip_value', 'na')}_{config['random_state']}"
        f"_{config['test_size']}_{config['val_size']}_{config.get('debug_max_subjects', 'all')}"
    )
    return hashlib.md5(key_data.encode()).hexdigest()[:12]


print("[OK] TFRecord functions defined")
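
The subject-wise split in `assign_subject_splits` is what prevents leakage: every recording of a subject lands in the same split. Below is a minimal sketch of the same idea using only the standard library. It swaps NumPy's `default_rng` for `random.Random`, so the resulting assignment will differ from the notebook's, and the subject identifiers are hypothetical. It also adds a guard, not present in the original, against the case where test and val together consume every subject.

```python
import random


def split_subjects(subject_ids, test_size, val_size, seed):
    """Assign each subject to exactly one of train/val/test."""
    subjects = sorted(set(subject_ids))  # sort first so the seed fully fixes the order
    random.Random(seed).shuffle(subjects)
    n_test = max(1, int(len(subjects) * test_size))
    n_val = max(1, int(len(subjects) * val_size))
    if n_test + n_val >= len(subjects):
        raise ValueError("Not enough subjects to leave any for training")
    split_map = {}
    for idx, subject in enumerate(subjects):
        if idx < n_test:
            split_map[subject] = "test"
        elif idx < n_test + n_val:
            split_map[subject] = "val"
        else:
            split_map[subject] = "train"
    return split_map


# Hypothetical subject identifiers
subjects = [f"S{i:02d}" for i in range(10)]
split_map = split_subjects(subjects, test_size=0.2, val_size=0.1, seed=42)
assigned = list(split_map.values())
counts = {name: assigned.count(name) for name in ("train", "val", "test")}
print(counts)
```

Because the mapping is keyed by subject rather than by session or epoch, multiple nights from one subject cannot end up on both sides of the train/test boundary.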

In [None]:
# ============================================================
# RUN PIPELINE ACCORDING TO MODE
# ============================================================

print(f"\n[INFO] Execution mode: {CONFIG['execution_mode'].upper()}")
print(f"[INFO] Sequence length: {CONFIG['seq_length']} epochs")

selected_cores = get_subject_cores_for_mode(
    CONFIG["manifest_path"],
    CONFIG["execution_mode"],
    CONFIG["debug_max_subjects"],
    CONFIG["random_state"],
)

# Split subjects into train/val/test
train_cores, val_cores, test_cores = get_split_cores(
    CONFIG["manifest_path"],
    CONFIG["test_size"],
    CONFIG["val_size"],
    CONFIG["random_state"],
    selected_cores,
)
print(
    f"\n[INFO] Subject-wise split: train={len(train_cores)}, val={len(val_cores)}, test={len(test_cores)}"
)

USE_STREAMING = CONFIG.get("streaming", False)

if USE_STREAMING:
    # ========== FULL MODE: TFRecord streaming ==========
    print("\n[FULL] Using TFRecord streaming...")

    tfrecord_path = Path(CONFIG["tfrecord_dir"])
    cache_key = get_tfrecord_seq_cache_key(CONFIG["manifest_path"], CONFIG)
    cache_marker = tfrecord_path / f".cache_{cache_key}"
    cache_metadata_path = tfrecord_path / "cache_metadata.json"

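    # Reuse the cache only if the marker AND the metadata file are both present
    if (
        tfrecord_path.exists()
        and cache_marker.exists()
        and cache_metadata_path.exists()
    ):
        print(f"[INFO] Reusing existing TFRecords (cache: {cache_key})")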
        with open(cache_metadata_path) as f:
            cache_meta = json.load(f)
        mean_ch = np.array(cache_meta["mean_ch"], dtype=np.float32)
        std_ch = np.array(cache_meta["std_ch"], dtype=np.float32)
        n_channels = cache_meta["n_channels"]
        train_count = cache_meta["train_count"]
        val_count = cache_meta["val_count"]
        test_count = cache_meta["test_count"]
        class_counts_train = Counter(cache_meta["class_counts_train"])
        class_weights = cache_meta.get("class_weights", {})
        # Convert string keys back to int
        class_weights = {int(k): v for k, v in class_weights.items()}
        tfrecord_paths = cache_meta["tfrecord_paths"]
    else:
        # Regenerate TFRecords
        if tfrecord_path.exists():
            print("[INFO] Cache invalid, regenerating TFRecords...")
            shutil.rmtree(tfrecord_path)

        # Step 1: compute normalization statistics (train only)
        mean_ch, std_ch, n_channels = compute_normalization_stats(
            CONFIG["manifest_path"],
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            train_cores,
        )
        print(f"[OK] Statistics: mean={mean_ch}, std={std_ch}")

        # Step 2: compute the class distribution in train (iterating over ALL sessions)
        class_counts_train = compute_class_counts(
            CONFIG["manifest_path"],
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            train_cores,
        )

        # Step 3: compute class_weights for sample_weights
        # Computed directly from the counts (no np.repeat), so memory stays flat
        # on large datasets. Balanced formula: w_c = N / (K * n_c)
        counts_list = [class_counts_train.get(stage, 1) for stage in STAGE_ORDER]
        total_samples = sum(counts_list)
        n_classes = len(STAGE_ORDER)
        class_weight_clip = CONFIG.get("class_weight_clip", 1.5)
        class_weights = {}
        for k, count in enumerate(counts_list):
            weight = total_samples / (n_classes * max(count, 1))
            class_weights[k] = float(np.clip(weight, 0.5, class_weight_clip))
        print(f"[OK] Class weights for sample_weights: {class_weights}")

        # Step 4: generate TFRecords for each split (with sample_weights)
        tfrecord_paths = {}

        train_path, train_count = build_tfrecord_sequences(
            CONFIG["manifest_path"],
            mean_ch,
            std_ch,
            train_cores,
            "train",
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            CONFIG["seq_length"],
            CONFIG["seq_stride_train"],
            CONFIG["tfrecord_dir"],
            class_weights=class_weights,
        )
        tfrecord_paths["train"] = train_path

        # Val/test use uniform weights (1.0) - class_weights only for training loss
        val_path, val_count = build_tfrecord_sequences(
            CONFIG["manifest_path"],
            mean_ch,
            std_ch,
            val_cores,
            "val",
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            CONFIG["seq_length"],
            CONFIG["seq_stride_val"],
            CONFIG["tfrecord_dir"],
            class_weights=None,  # Val: uniform weights
        )
        tfrecord_paths["val"] = val_path

        test_path, test_count = build_tfrecord_sequences(
            CONFIG["manifest_path"],
            mean_ch,
            std_ch,
            test_cores,
            "test",
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            CONFIG["seq_length"],
            CONFIG["seq_stride_test"],
            CONFIG["tfrecord_dir"],
            class_weights=None,  # Test: uniform weights
        )
        tfrecord_paths["test"] = test_path

        # Save metadata (including class_weights)
        cache_meta = {
            "mean_ch": mean_ch.tolist(),
            "std_ch": std_ch.tolist(),
            "n_channels": n_channels,
            "train_count": train_count,
            "val_count": val_count,
            "test_count": test_count,
            "class_counts_train": dict(class_counts_train),
            "class_weights": class_weights,
            "tfrecord_paths": tfrecord_paths,
        }
        tfrecord_path.mkdir(parents=True, exist_ok=True)
        with open(cache_metadata_path, "w") as f:
            json.dump(cache_meta, f, indent=2)
        cache_marker.touch()
        print(f"[OK] Cache saved (key: {cache_key})")

    # Create datasets from TFRecords
    INPUT_SHAPE = (CONFIG["seq_length"], CONFIG["samples_per_epoch"], n_channels)

    train_ds = make_seq_dataset(
        tfrecord_paths["train"],
        CONFIG["seq_length"],
        CONFIG["samples_per_epoch"],
        n_channels,
        CONFIG["effective_batch_size"],
        shuffle=True,
        repeat=True,
    )
    val_ds = (
        make_seq_dataset(
            tfrecord_paths["val"],
            CONFIG["seq_length"],
            CONFIG["samples_per_epoch"],
            n_channels,
            CONFIG["effective_batch_size"],
            shuffle=False,
            repeat=False,
        )
        if val_count > 0
        else None
    )
    test_ds = (
        make_seq_dataset(
            tfrecord_paths["test"],
            CONFIG["seq_length"],
            CONFIG["samples_per_epoch"],
            n_channels,
            CONFIG["effective_batch_size"],
            shuffle=False,
            repeat=False,
        )
        if test_count > 0
        else None
    )

else:
    # ========== DEBUG MODE: load into RAM ==========
    print("\n[DEBUG] Loading sequences into RAM...")

    # Load sequences for each split
    X_train, y_train, _, expected_channels = load_sequences_for_split(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        train_cores,
        CONFIG["seq_length"],
        CONFIG["seq_stride_train"],
        "train",
    )
    X_val, y_val, _, _ = load_sequences_for_split(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        val_cores,
        CONFIG["seq_length"],
        CONFIG["seq_stride_val"],  # same stride key as the streaming branch
        "val",
    )
    X_test, y_test, _, _ = load_sequences_for_split(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        test_cores,
        CONFIG["seq_length"],
        CONFIG["seq_stride_test"],
        "test",
    )

    n_channels = X_train.shape[-1] if len(X_train) > 0 else 4

    # Normalize
    X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch = normalize_sequences(
        X_train, X_val, X_test, CONFIG["clip_value"], CONFIG["std_epsilon"]
    )
    print(f"[OK] Normalization: mean={mean_ch}, std={std_ch}")

    del X_train, X_val, X_test
    gc.collect()

    INPUT_SHAPE = X_train_norm.shape[1:]
    train_count, val_count, test_count = (
        len(X_train_norm),
        len(X_val_norm),
        len(X_test_norm),
    )

    # Compute class_weights for sample_weights (debug mode)
    # Balanced formula: w_c = N / (K * n_c), computed directly from the counts
    y_flat = y_train.flatten()
    class_counts_train = Counter([STAGE_ORDER[yi] for yi in y_flat])
    counts_list = [class_counts_train.get(stage, 1) for stage in STAGE_ORDER]
    total_samples = sum(counts_list)
    n_classes = len(STAGE_ORDER)
    class_weight_clip = CONFIG.get("class_weight_clip", 1.5)
    class_weights = {}
    for k, count in enumerate(counts_list):
        weight = total_samples / (n_classes * max(count, 1))
        class_weights[k] = float(np.clip(weight, 0.5, class_weight_clip))
    print(f"[OK] Class weights for sample_weights: {class_weights}")

    # Build per-epoch sample_weights for each sequence
    def generate_sample_weights(y, class_weights):
        """Return sample_weights with the same shape as y via a lookup table."""
        lut = np.array(
            [class_weights[k] for k in range(len(class_weights))], dtype=np.float32
        )
        return lut[np.asarray(y, dtype=np.int64)]

    sw_train = generate_sample_weights(y_train, class_weights)
    # Val/test use uniform weights (1.0), matching the streaming branch
    sw_val = np.ones_like(y_val, dtype=np.float32)
    sw_test = np.ones_like(y_test, dtype=np.float32)

    # Build datasets (3-tuples: x, y, sample_weight)
    train_ds = (
        tf.data.Dataset.from_tensor_slices((X_train_norm, y_train, sw_train))
        .shuffle(
            min(5000, train_count),
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
        .batch(CONFIG["effective_batch_size"])
        .prefetch(tf.data.AUTOTUNE)
    )
    val_ds = (
        (
            tf.data.Dataset.from_tensor_slices((X_val_norm, y_val, sw_val))
            .batch(CONFIG["effective_batch_size"])
            .prefetch(tf.data.AUTOTUNE)
        )
        if val_count > 0
        else None
    )
    test_ds = (
        (
            tf.data.Dataset.from_tensor_slices((X_test_norm, y_test, sw_test))
            .batch(CONFIG["effective_batch_size"])
            .prefetch(tf.data.AUTOTUNE)
        )
        if test_count > 0
        else None
    )

# Label encoder for compatibility
le = LabelEncoder()
le.classes_ = np.array(STAGE_ORDER)

print(f"\n{'=' * 60}")
print(f"[OK] {'STREAMING' if USE_STREAMING else 'RAM'} pipeline ready")
print(f"   Input shape: {INPUT_SHAPE} (seq_len, samples, channels)")
print(f"   Train: {train_count:,} sequences")
print(f"   Val: {val_count:,} sequences")
print(
    f"   Test: {test_count:,} sequences (stride={CONFIG['seq_stride_test']}, no overlap)"
)
print(f"   Class distribution (train): {dict(class_counts_train)}")
print(f"   Class weights (sample_weights): {class_weights}")
print("   [OK] per-epoch sample_weight included in dataset")
print("=" * 60)
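
The class weights computed above follow the balanced formula w_c = N / (K * n_c), clipped to a fixed range so that a very rare class such as N1 cannot dominate the loss. A small worked sketch in plain Python shows the effect of the clipping. The stage names and counts below are made-up illustration values, not statistics from the dataset, and the clip range mirrors the notebook's default of [0.5, 1.5].

```python
def balanced_class_weights(class_counts, lower=0.5, upper=1.5):
    """Compute w_c = N / (K * n_c) per class, clipped to [lower, upper]."""
    total = sum(class_counts.values())
    n_classes = len(class_counts)
    weights = {}
    for name, count in class_counts.items():
        raw = total / (n_classes * max(count, 1))  # max() guards against n_c == 0
        weights[name] = min(max(raw, lower), upper)
    return weights


# Hypothetical epoch counts per stage (not taken from the dataset)
counts = {"W": 400, "N1": 50, "N2": 350, "N3": 100, "REM": 100}
weights = balanced_class_weights(counts)
for name, weight in weights.items():
    print(f"{name}: {weight:.3f}")
```

With these counts the unclipped N1 weight would be 4.0, eight times the weight of W. After clipping it is 1.5, so the ratio between the heaviest and lightest class is capped at 3.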

## Sequential CNN-LSTM Architecture
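
To make the tensor shapes concrete before reading the model code, the sketch below traces the output shape of each stage (without the batch axis) using plain arithmetic. The default hyperparameters mirror those of `build_cnn_lstm_model`. The input size of 11 epochs × 3000 samples × 4 channels is an assumption chosen for illustration (30 s at 100 Hz); the real values come from `CONFIG`.

```python
def trace_shapes(seq_length, samples_per_epoch, n_channels,
                 cnn_filters=(32, 64, 128), feature_dim=128,
                 lstm_units=(64, 32), bidirectional=True, n_classes=5):
    """Return (layer name, output shape without the batch axis) pairs."""
    shapes = [("sequence_input", (seq_length, samples_per_epoch, n_channels))]
    length = samples_per_epoch
    for i, filters in enumerate(cnn_filters):
        # Conv1D with padding="same" keeps the temporal length
        shapes.append((f"conv_{i + 1}", (length, filters)))
        if i < len(cnn_filters) - 1:
            length //= 2  # MaxPooling1D(2) halves the length, rounding down
            shapes.append((f"pool_{i + 1}", (length, filters)))
    shapes.append(("global_pool", (cnn_filters[-1],)))
    shapes.append(("feature_dense", (feature_dim,)))
    # TimeDistributed applies the per-epoch CNN to every epoch in the sequence
    shapes.append(("td_cnn", (seq_length, feature_dim)))
    for i, units in enumerate(lstm_units):
        # A bidirectional LSTM concatenates forward and backward outputs
        width = 2 * units if bidirectional else units
        shapes.append((f"lstm_{i + 1}", (seq_length, width)))
    shapes.append(("output", (seq_length, n_classes)))
    return shapes


# Assumed input: 11 epochs of 3000 samples (30 s at 100 Hz) with 4 channels
layer_shapes = trace_shapes(11, 3000, 4)
shapes = dict(layer_shapes)
for name, shape in layer_shapes:
    print(f"{name:>16}: {shape}")
```

The key point is that the CNN collapses each epoch to a single feature vector, so the LSTM only ever sees a sequence of length `seq_length`, regardless of how many samples an epoch contains.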

In [None]:
# ============================================================
# LEARNING RATE SCHEDULE
# ============================================================


def create_lr_schedule(
    initial_lr, min_lr, warmup_epochs, total_epochs, steps_per_epoch
):
    """Linear warmup followed by cosine decay."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    decay_steps = max(total_steps - warmup_steps, 1)

    class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, initial_lr, min_lr, warmup_steps, decay_steps):
            super().__init__()
            self.initial_lr = tf.cast(initial_lr, tf.float32)
            self.min_lr = tf.cast(min_lr, tf.float32)
            self.warmup_steps = tf.cast(warmup_steps, tf.float32)
            self.decay_steps = tf.cast(decay_steps, tf.float32)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup_progress = step / tf.maximum(self.warmup_steps, 1.0)
            warmup_lr = self.min_lr + (self.initial_lr - self.min_lr) * tf.minimum(
                warmup_progress, 1.0
            )
            decay_progress = tf.minimum(
                tf.maximum((step - self.warmup_steps) / self.decay_steps, 0.0), 1.0
            )
            cosine_decay = 0.5 * (
                1.0 + tf.cos(tf.constant(np.pi, dtype=tf.float32) * decay_progress)
            )
            decay_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
            return tf.cond(
                step < self.warmup_steps, lambda: warmup_lr, lambda: decay_lr
            )

        def get_config(self):
            # Use numpy() for proper tensor-to-Python conversion
            return {
                "initial_lr": float(
                    self.initial_lr.numpy()
                    if hasattr(self.initial_lr, "numpy")
                    else self.initial_lr
                ),
                "min_lr": float(
                    self.min_lr.numpy()
                    if hasattr(self.min_lr, "numpy")
                    else self.min_lr
                ),
                "warmup_steps": int(
                    self.warmup_steps.numpy()
                    if hasattr(self.warmup_steps, "numpy")
                    else self.warmup_steps
                ),
                "decay_steps": int(
                    self.decay_steps.numpy()
                    if hasattr(self.decay_steps, "numpy")
                    else self.decay_steps
                ),
            }

    return WarmupCosineDecay(initial_lr, min_lr, warmup_steps, decay_steps)


print("[OK] Learning rate schedule defined")
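
The schedule defined above ramps linearly from `min_lr` to `initial_lr` during warmup and then follows half a cosine back down to `min_lr`. The plain-Python sketch below reproduces the same arithmetic without TensorFlow so that individual steps can be checked by hand. The function name and the numeric settings are illustrative assumptions, not values from `CONFIG`.

```python
import math


def warmup_cosine_lr(step, initial_lr, min_lr, warmup_steps, decay_steps):
    """Learning rate at `step`: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        progress = step / max(warmup_steps, 1)
        return min_lr + (initial_lr - min_lr) * min(progress, 1.0)
    progress = min(max((step - warmup_steps) / decay_steps, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (initial_lr - min_lr) * cosine


# Hypothetical settings: 3 warmup epochs out of 30, 100 steps per epoch
steps_per_epoch = 100
warmup_steps = 3 * steps_per_epoch
decay_steps = max(30 * steps_per_epoch - warmup_steps, 1)

lr_start = warmup_cosine_lr(0, 1e-3, 1e-6, warmup_steps, decay_steps)
lr_peak = warmup_cosine_lr(warmup_steps, 1e-3, 1e-6, warmup_steps, decay_steps)
lr_mid = warmup_cosine_lr(warmup_steps + decay_steps // 2, 1e-3, 1e-6, warmup_steps, decay_steps)
lr_end = warmup_cosine_lr(warmup_steps + decay_steps, 1e-3, 1e-6, warmup_steps, decay_steps)
print(lr_start, lr_peak, lr_mid, lr_end)
```

Note that the step counts depend on `steps_per_epoch`, so if training steps per epoch are capped, the schedule has to be built from the capped value, as the notebook does with `steps_per_epoch_train`.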

In [None]:
# ============================================================
# SEQUENTIAL CNN-LSTM ARCHITECTURE
# ============================================================


def build_cnn_feature_extractor(
    samples_per_epoch, n_channels, cnn_filters, cnn_kernel_sizes, feature_dim
):
    """CNN that extracts features from a single epoch."""
    inputs = keras.Input(shape=(samples_per_epoch, n_channels), name="epoch_input")
    x = inputs

    for i, (filters, kernel) in enumerate(zip(cnn_filters, cnn_kernel_sizes)):
        x = layers.Conv1D(
            filters,
            kernel,
            padding="same",
            kernel_regularizer=keras.regularizers.l2(1e-4),
            kernel_initializer="he_uniform",
            name=f"conv_{i+1}",
        )(x)
        x = layers.BatchNormalization(name=f"bn_conv_{i+1}")(x)
        x = layers.Activation("gelu")(x)
        if i < len(cnn_filters) - 1:
            x = layers.MaxPooling1D(2, name=f"pool_{i+1}")(x)

    x = layers.GlobalAveragePooling1D(name="global_pool")(x)
    x = layers.Dense(
        feature_dim,
        kernel_regularizer=keras.regularizers.l2(1e-4),
        kernel_initializer="he_uniform",
        name="feature_dense",
    )(x)
    x = layers.BatchNormalization(name="bn_feature")(x)
    outputs = layers.Activation("gelu")(x)

    return keras.Model(inputs, outputs, name="cnn_feature_extractor")


def build_cnn_lstm_model(
    seq_length,
    samples_per_epoch,
    n_channels,
    n_classes=5,
    cnn_filters=(32, 64, 128),
    cnn_kernel_sizes=(5, 5, 3),
    feature_dim=128,
    lstm_units=(64, 32),
    dropout_rate=0.3,
    bidirectional=True,
    lr_schedule=None,
):
    """
    Hybrid CNN-LSTM model for sleep staging on sequences of epochs.

    Input: (batch, seq_length, samples_per_epoch, n_channels)
    Output: (batch, seq_length, n_classes) - many-to-many

    NOTE: the per-epoch sample_weight is passed as the third element of the dataset.
    Keras accepts (x, y, sample_weight) tuples, which is how minority classes are up-weighted.
    """
    inputs = keras.Input(
        shape=(seq_length, samples_per_epoch, n_channels),
        name="sequence_input",
        dtype="float32",
    )

    # Level 1: CNN feature extractor (TimeDistributed)
    cnn_model = build_cnn_feature_extractor(
        samples_per_epoch, n_channels, cnn_filters, cnn_kernel_sizes, feature_dim
    )
    features = layers.TimeDistributed(cnn_model, name="td_cnn")(inputs)
    features = layers.Dropout(dropout_rate, name="dropout_features")(features)
    # features shape: (batch, seq_length, feature_dim)

    # Level 2: temporal LSTM
    x = features
    for i, units in enumerate(lstm_units):
        lstm = layers.LSTM(
            units,
            return_sequences=True,  # Many-to-many
            kernel_regularizer=keras.regularizers.l2(1e-4),
            name=f"lstm_{i+1}",
        )
        if bidirectional:
            x = layers.Bidirectional(lstm, name=f"bilstm_{i+1}")(x)
        else:
            x = lstm(x)
        # LayerNorm instead of BatchNorm for sequences (more stable)
        x = layers.LayerNormalization(name=f"ln_lstm_{i+1}")(x)
        x = layers.Dropout(dropout_rate, name=f"dropout_lstm_{i+1}")(x)

    # Output: one prediction per epoch in the sequence
    outputs = layers.TimeDistributed(
        layers.Dense(n_classes, kernel_initializer="glorot_uniform"), name="output"
    )(x)
    # outputs shape: (batch, seq_length, n_classes) - logits

    model = keras.Model(inputs, outputs, name="CNN_LSTM_SeqSleepStaging")

    model.compile(
        optimizer=keras.optimizers.Adam(
            learning_rate=lr_schedule if lr_schedule is not None else 1e-4, clipnorm=1.0
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    return model


print("[OK] Sequential CNN-LSTM architecture defined")
print("   [OK] per-epoch sample_weight supported for class balancing")

In [None]:
# ============================================================
# CREATE CNN-LSTM MODEL
# ============================================================

seq_length = CONFIG["seq_length"]
samples_per_epoch = CONFIG["samples_per_epoch"]
n_channels = INPUT_SHAPE[-1]

print(f"Input shape: {INPUT_SHAPE}")
print(f"   seq_length: {seq_length}")
print(f"   samples_per_epoch: {samples_per_epoch}")
print(f"   n_channels: {n_channels}")

# Compute steps per epoch
MAX_STEPS_PER_EPOCH = CONFIG["max_steps_per_epoch"]
full_steps_train = math.ceil(train_count / CONFIG["effective_batch_size"])
full_steps_val = (
    math.ceil(val_count / CONFIG["effective_batch_size"]) if val_count else 0
)

steps_per_epoch_train = min(full_steps_train, MAX_STEPS_PER_EPOCH)
steps_per_epoch_val = full_steps_val if USE_STREAMING else None

print("\n[INFO] Steps/epoch:")
print(f"   Full steps: {full_steps_train}")
print(f"   Actual steps (cap {MAX_STEPS_PER_EPOCH}): {steps_per_epoch_train}")
if full_steps_train > MAX_STEPS_PER_EPOCH:
    pct = 100 * steps_per_epoch_train / full_steps_train
    print(f"   [NOTE] Each 'epoch' processes {pct:.1f}% of train")

lr_schedule = create_lr_schedule(
    initial_lr=CONFIG["learning_rate_initial"],
    min_lr=CONFIG["learning_rate_min"],
    warmup_epochs=CONFIG["warmup_epochs"],
    total_epochs=CONFIG["epochs"],
    steps_per_epoch=steps_per_epoch_train,
)

print(
    f"\n[INFO] LR Schedule: initial={CONFIG['learning_rate_initial']}, min={CONFIG['learning_rate_min']}"
)
print(f"   Warmup: {CONFIG['warmup_epochs']} epochs")

with strategy.scope():
    model = build_cnn_lstm_model(
        seq_length=seq_length,
        samples_per_epoch=samples_per_epoch,
        n_channels=n_channels,
        n_classes=len(STAGE_ORDER),
        cnn_filters=CONFIG["cnn_filters"],
        cnn_kernel_sizes=CONFIG["cnn_kernel_sizes"],
        feature_dim=CONFIG["feature_dim"],
        lstm_units=CONFIG["lstm_units"],
        dropout_rate=CONFIG["dropout_rate"],
        bidirectional=CONFIG["bidirectional"],
        lr_schedule=lr_schedule,
    )

model.summary()

## Training
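
The validation callback in this section reports macro F1 and Cohen's kappa over many-to-many predictions, which means the `(batch, seq_length)` label arrays are flattened so that every epoch counts as one sample. The sketch below shows both metrics in plain Python on a toy example. It is a simplified stand-in rather than the callback's actual implementation (the notebook presumably relies on library routines for this), and the function names and labels are hypothetical. One convention to note: classes absent from both labels and predictions contribute an F1 of 0 to the macro average here.

```python
from collections import Counter


def flatten(sequences):
    """Turn a list of per-sequence label lists into one flat list of epochs."""
    return [label for seq in sequences for label in seq]


def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN)."""
    scores = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / n_classes


def cohen_kappa(y_true, y_pred):
    """Agreement corrected for chance: (p_o - p_e) / (1 - p_e)."""
    n = len(y_true)
    observed = sum(1 for t, p in zip(y_true, y_pred) if t == p) / n
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    return (observed - expected) / (1.0 - expected) if expected < 1.0 else 0.0


# Hypothetical labels: 2 sequences of 3 epochs each, 3 classes
y_true = flatten([[0, 0, 1], [1, 2, 2]])
y_pred = flatten([[0, 1, 1], [1, 2, 0]])

f1 = macro_f1(y_true, y_pred, n_classes=3)
kappa = cohen_kappa(y_true, y_pred)
print(f"macro F1 = {f1:.4f}, kappa = {kappa:.4f}")
```

Because train and val sequences overlap (stride smaller than the sequence length), the same epoch can be counted several times after flattening. The non-overlapping test stride avoids that for the final metrics.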

In [None]:
# ============================================================
# CALLBACKS FOR SEQUENCES
# ============================================================

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"cnn_lstm_seq_{CONFIG['execution_mode']}_{timestamp}"


class SeqSleepMetricsCallback(Callback):
    """
    Computes F1-Macro, Kappa, val_loss and val_accuracy for sequences (many-to-many).

    IMPORTANT: on epochs without an evaluation, the last measured metrics are
    logged again so that EarlyStopping always finds the monitored key.
    val_loss is computed from the collected logits, without an extra pass over the data.
    """

    def __init__(self, val_ds, stage_order, eval_every=1):
        super().__init__()
        self.val_ds = val_ds
        self.stage_order = stage_order
        self.eval_every = max(1, eval_every)
        self.best_f1_macro, self.best_kappa, self.best_weights = -1.0, -1.0, None
        # Keep the last measured values to re-log on epochs without an evaluation
        self.last_f1_macro = 0.0
        self.last_kappa = 0.0
        self.last_val_loss = 0.0
        self.last_val_accuracy = 0.0

    def _log_last_values(self, logs):
        """Write the most recent metrics into `logs` so monitors always find them."""
        logs["val_f1_macro"] = self.last_f1_macro
        logs["val_kappa"] = self.last_kappa
        logs["val_loss"] = self.last_val_loss
        logs["val_accuracy"] = self.last_val_accuracy

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}

        # Epoch without evaluation: re-log the last known values for EarlyStopping
        if (epoch + 1) % self.eval_every != 0:
            self._log_last_values(logs)
            return

        # No validation dataset: report zeros
        if self.val_ds is None:
            self.last_f1_macro, self.last_kappa = 0.0, 0.0
            self.last_val_loss, self.last_val_accuracy = 0.0, 0.0
            self._log_last_values(logs)
            return

        y_true_list, y_pred_list, logits_list = [], [], []
        try:
            for batch in self.val_ds:
                # Works for both (x, y) and (x, y, sample_weight) batches
                x_batch, y_batch = batch[0], batch[1]
                y_pred_logits = self.model.predict(x_batch, verbose=0)
                y_pred = np.argmax(y_pred_logits, axis=-1)
                y_pred_list.append(y_pred.flatten())
                y_true_list.append(y_batch.numpy().flatten())
                logits_list.append(y_pred_logits.reshape(-1, y_pred_logits.shape[-1]))
        except Exception as e:
            print(f"[WARN] Error en evaluación: {e}")
            self._log_last_values(logs)
            return

        # The dataset yielded no batches
        if len(y_true_list) == 0:
            self._log_last_values(logs)
            return

        y_true = np.concatenate(y_true_list)
        y_pred = np.concatenate(y_pred_list)
        logits = np.concatenate(logits_list)

        # Batches were present but contained no labels
        if len(y_true) == 0:
            self._log_last_values(logs)
            return

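        # Compute val_loss from the logits already collected (no extra pass).
        # This is an unweighted mean over epochs, whereas the training loss uses
        # per-epoch sample weights, so the two values are not on the same scale.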
        val_loss = (
            tf.keras.losses.sparse_categorical_crossentropy(
                y_true, logits, from_logits=True
            )
            .numpy()
            .mean()
        )
        val_accuracy = accuracy_score(y_true, y_pred)

        kappa = cohen_kappa_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

        # Update the last known values
        self.last_f1_macro, self.last_kappa = f1, kappa
        self.last_val_loss, self.last_val_accuracy = val_loss, val_accuracy
        logs["val_f1_macro"], logs["val_kappa"] = f1, kappa
        logs["val_loss"], logs["val_accuracy"] = val_loss, val_accuracy

        if f1 > self.best_f1_macro:
            self.best_f1_macro, self.best_kappa = f1, kappa
            self.best_weights = self.model.get_weights()
            print(f" - val_f1_macro mejoró a {f1:.4f} (kappa={kappa:.4f})")
        print(
            f" - val_loss: {val_loss:.4f} - val_acc: {val_accuracy:.4f} - val_f1: {f1:.4f} - val_kappa: {kappa:.4f}"
        )

    def restore_best_weights(self):
        if self.best_weights:
            self.model.set_weights(self.best_weights)
            print(
                f"[OK] Restaurados pesos (F1={self.best_f1_macro:.4f}, Kappa={self.best_kappa:.4f})"
            )


class NaNDebugCallback(Callback):
    """Detecta NaN en loss."""

    def __init__(self):
        super().__init__()
        self.nan_count, self.last_valid_loss = 0, None

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss", 0)
        if not (np.isnan(loss) or np.isinf(loss)):
            self.last_valid_loss = loss
            return
        self.nan_count += 1
        if self.nan_count == 1:
            print(
                f"\n[WARN] NaN en batch {batch} (último válido={self.last_valid_loss})"
            )

    def on_epoch_end(self, epoch, logs=None):
        if self.nan_count > 0:
            print(f"   [INFO] Epoch {epoch+1}: {self.nan_count} batches con NaN")
        self.nan_count = 0


# Runs the full evaluation every 3 epochs and re-logs the last values in between for EarlyStopping
metrics_callback = SeqSleepMetricsCallback(val_ds, STAGE_ORDER, eval_every=3)

callbacks = [
    NaNDebugCallback(),
    metrics_callback,
    EarlyStopping(
        monitor="val_f1_macro",
        mode="max",
        patience=CONFIG["early_stopping_patience"],
        restore_best_weights=False,
        verbose=1,
    ),
    ModelCheckpoint(
        filepath=f"{OUTPUT_PATH}/{model_name}_best.keras",
        monitor="val_f1_macro",
        mode="max",
        save_best_only=True,
        verbose=1,
    ),
]

print(f"Modelo: {model_name}")
print(f"Checkpoint: {OUTPUT_PATH}/{model_name}_best.keras")
print("[INFO] val_f1_macro y val_loss se persisten entre epochs para EarlyStopping")
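
Because the metrics are re-logged between evaluations while Keras' `EarlyStopping` counts `patience` in epochs, a given patience covers fewer real evaluations when `eval_every > 1`. The following sketch is a simplified, pure-Python simulation of that interaction. `simulate_early_stopping` is a hypothetical function, the scores are invented, and it ignores details of the real callback such as `min_delta` and `start_from_epoch`.

```python
def simulate_early_stopping(real_scores, eval_every, patience):
    """Simplified max-mode early stopping over re-logged metrics.

    real_scores[i] is the metric measured at the i-th real evaluation.
    Between evaluations the last measured value (0.0 before the first one)
    is reported again, mirroring SeqSleepMetricsCallback.
    Returns the 1-based epoch at which training would stop, or None.
    """
    best, wait, last = float("-inf"), 0, 0.0
    n_epochs = len(real_scores) * eval_every
    for epoch in range(n_epochs):
        if (epoch + 1) % eval_every == 0:
            last = real_scores[(epoch + 1) // eval_every - 1]
        if last > best:
            best, wait = last, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1
    return None


scores = [0.60, 0.65, 0.64, 0.63, 0.62]  # made-up F1 values
print(simulate_early_stopping(scores, eval_every=3, patience=6))
print(simulate_early_stopping(scores, eval_every=1, patience=6))
```

With `eval_every=3` and `patience=6`, the simulation stops at epoch 12, after only two real evaluations without improvement. If the intent is to tolerate `k` non-improving evaluations, a patience of roughly `k * eval_every` epochs is one way to express that.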

## Resuming from a Checkpoint (optional)

If the Kaggle session disconnected during training, you can resume from the last saved checkpoint.
Run the next cell only if you need to resume.

In [None]:
# ============================================================
# RESUME FROM CHECKPOINT (run only if needed)
# ============================================================

RESUME_FROM_CHECKPOINT = False  # Set to True to resume
CHECKPOINT_NAME = None  # Example: "cnn_lstm_seq_full_20251125_143022_best.keras"

if RESUME_FROM_CHECKPOINT and CHECKPOINT_NAME:
    checkpoint_path = f"{OUTPUT_PATH}/{CHECKPOINT_NAME}"
    if os.path.exists(checkpoint_path):
        print(f"[INFO] Cargando modelo desde checkpoint: {checkpoint_path}")
        with strategy.scope():
            model = keras.models.load_model(checkpoint_path, custom_objects={})
        print("[OK] Modelo cargado exitosamente")
    else:
        print(f"[ERROR] Checkpoint no encontrado: {checkpoint_path}")
        for f in os.listdir(OUTPUT_PATH):
            if f.endswith(".keras"):
                print(f"   - {f}")
else:
    print("[INFO] Modo normal: se usará el modelo recién creado")
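
Since checkpoint names embed a timestamp, it can be convenient to look up the newest one instead of typing the name by hand. A minimal sketch using only the standard library follows; `find_latest_checkpoint` is a hypothetical helper, and the `_best.keras` suffix is taken from the `ModelCheckpoint` filepath used in this notebook.

```python
import os


def find_latest_checkpoint(directory, suffix="_best.keras"):
    """Return the most recently modified file ending in `suffix`, or None.

    Hypothetical helper for illustration; it does not load the model.
    """
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(suffix)
    ]
    return max(candidates, key=os.path.getmtime) if candidates else None
```

The result could then be assigned with `CHECKPOINT_NAME = os.path.basename(path)` after checking that `path` is not `None`.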

In [None]:
# ============================================================
# TRAIN THE CNN-LSTM MODEL
# ============================================================

print(f"\nIniciando entrenamiento ({CONFIG['execution_mode'].upper()} mode)...")
print("   Arquitectura: CNN-LSTM Secuencial (many-to-many)")
print(f"   Modo: {'Streaming TFRecord' if USE_STREAMING else 'RAM'}")
print(f"   Seq length: {CONFIG['seq_length']} epochs")
print(f"   Batch size efectivo: {CONFIG['effective_batch_size']}")
print(f"   Epochs maximos: {CONFIG['epochs']}")
print(f"   Steps/epoch: {steps_per_epoch_train}")

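# class_weight is not used with the many-to-many (TimeDistributed) output;
# per-epoch sample_weight handles class imbalance instead
print("   [INFO] sample_weight por epoch activo (corrige class imbalance)")

# Kept in sync with the metrics callback so both agree on the evaluation cadence
VALIDATION_FREQ = metrics_callback.eval_every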
training_start_time = time.time()

# In streaming mode train_ds already repeats; in RAM mode add repeat() so that
# steps_per_epoch never exhausts the dataset.
# validation_data is intentionally omitted: SeqSleepMetricsCallback already runs
# the full evaluation on val_ds, so Keras validation would duplicate that work.
fit_ds = train_ds if USE_STREAMING else train_ds.repeat()
history = model.fit(
    fit_ds,
    steps_per_epoch=steps_per_epoch_train,
    epochs=CONFIG["epochs"],
    callbacks=callbacks,
    verbose=1,
)

training_time = time.time() - training_start_time
print(
    f"\n[OK] Entrenamiento completado en {training_time / 60:.2f} min ({training_time / 3600:.2f} h)"
)

gc.collect()
metrics_callback.restore_best_weights()

In [None]:
# ============================================================
# PLOT LEARNING CURVES
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

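# The metrics callback logs val_* on every epoch (re-logging the last measured
# value between evaluations), so the x-axis is simply the epoch index.
# val_epochs is always defined, even when there are no validation metrics.
has_val_metrics = "val_loss" in history.history and len(history.history["val_loss"]) > 0
val_epochs = np.arange(len(history.history.get("val_loss", [])))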

axes[0, 0].plot(history.history["loss"], label="Train Loss", linewidth=2)
if has_val_metrics:
    axes[0, 0].plot(
        val_epochs, history.history["val_loss"], label="Val Loss", linewidth=2
    )
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Loss durante entrenamiento")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(history.history["accuracy"], label="Train Acc", linewidth=2)
if has_val_metrics:
    axes[0, 1].plot(
        val_epochs, history.history["val_accuracy"], label="Val Acc", linewidth=2
    )
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Accuracy")
axes[0, 1].set_title("Accuracy durante entrenamiento")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

if "val_kappa" in history.history:
    val_epochs_kappa = np.arange(len(history.history["val_kappa"]))
    axes[1, 0].plot(
        val_epochs_kappa,
        history.history["val_kappa"],
        label="Val Kappa",
        linewidth=2,
        color="green",
    )
    axes[1, 0].axhline(
        y=metrics_callback.best_kappa,
        color="red",
        linestyle="--",
        label=f"Best={metrics_callback.best_kappa:.4f}",
    )
    axes[1, 0].set_xlabel("Epoch")
    axes[1, 0].set_ylabel("Cohen's Kappa")
    axes[1, 0].set_title("Cohen's Kappa")
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
else:
    axes[1, 0].text(0.5, 0.5, "Kappa no disponible", ha="center", va="center")

if "val_f1_macro" in history.history:
    val_epochs_f1 = np.arange(len(history.history["val_f1_macro"]))
    axes[1, 1].plot(
        val_epochs_f1,
        history.history["val_f1_macro"],
        label="Val F1-Macro",
        linewidth=2,
        color="purple",
    )
    axes[1, 1].axhline(
        y=metrics_callback.best_f1_macro,
        color="red",
        linestyle="--",
        label=f"Best={metrics_callback.best_f1_macro:.4f}",
    )
    axes[1, 1].set_xlabel("Epoch")
    axes[1, 1].set_ylabel("F1-Macro")
    axes[1, 1].set_title("F1-Macro (métrica de selección)")
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
else:
    axes[1, 1].text(0.5, 0.5, "F1-Macro no disponible", ha="center", va="center")

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_training_curves.png", dpi=150)
plt.show()
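
The validation curves contain re-logged values on epochs where no evaluation ran, which shows up as flat steps in the plots. If only the real evaluations should be drawn, the history can be filtered by epoch index. The sketch below is a pure-Python illustration with invented values. `evaluated_epochs` and `keep_evaluated` are hypothetical helpers, and they assume evaluations happen exactly when `(epoch + 1) % eval_every == 0`, as in `SeqSleepMetricsCallback`.

```python
def evaluated_epochs(n_epochs, eval_every):
    """0-based epoch indices at which a real evaluation ran."""
    return [e for e in range(n_epochs) if (e + 1) % eval_every == 0]


def keep_evaluated(values, eval_every):
    """Drop re-logged entries, returning (epoch_indices, measured_values)."""
    idx = evaluated_epochs(len(values), eval_every)
    return idx, [values[i] for i in idx]


# Made-up per-epoch history with eval_every=3: the first two entries are
# placeholders and later entries repeat the last measurement.
val_f1 = [0.0, 0.0, 0.61, 0.61, 0.61, 0.66, 0.66]
epochs, measured = keep_evaluated(val_f1, eval_every=3)
print(epochs, measured)
```

Plotting `measured` against `epochs` with markers would show one point per real evaluation instead of a staircase.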

## Test Evaluation

In [None]:
# ============================================================
# TEST SET EVALUATION (NON-OVERLAPPING SEQUENCES)
# ============================================================

print("\nEvaluando en Test Set...")
print(f"[INFO] Test usa stride={CONFIG['seq_stride_test']} (sin overlap)")
print("[INFO] Cada epoch aparece exactamente UNA vez -> métricas sin sesgo")

# Initialize metrics to None so subsequent cells can check
accuracy, kappa, f1_macro, f1_weighted = None, None, None, None
y_test_enc, y_pred_enc, y_pred_proba = None, None, None

if test_ds is None or test_count == 0:
    print("[WARN] No hay datos de test disponibles")
else:
    # Collect predictions and labels
    y_true_list, y_pred_list, y_pred_proba_list = [], [], []

    for batch in test_ds:
        # Handle both 2-element (x, y) and 3-element (x, y, sw) datasets
        x_batch = batch[0]
        y_batch = batch[1]
        y_pred_logits = model.predict(x_batch, verbose=0)
        y_pred_proba = tf.nn.softmax(y_pred_logits, axis=-1).numpy()
        y_pred = np.argmax(y_pred_logits, axis=-1)

        # Flatten (batch, seq_len) -> (batch * seq_len,)
        y_true_list.append(y_batch.numpy().flatten())
        y_pred_list.append(y_pred.flatten())
        y_pred_proba_list.append(y_pred_proba.reshape(-1, y_pred_proba.shape[-1]))

    y_test_enc = np.concatenate(y_true_list)
    y_pred_enc = np.concatenate(y_pred_list)
    y_pred_proba = np.concatenate(y_pred_proba_list)

    accuracy = accuracy_score(y_test_enc, y_pred_enc)
    kappa = cohen_kappa_score(y_test_enc, y_pred_enc)
    f1_macro = f1_score(y_test_enc, y_pred_enc, average="macro", zero_division=0)
    f1_weighted = f1_score(y_test_enc, y_pred_enc, average="weighted", zero_division=0)

    print(f"\n{'=' * 50}")
    print("RESULTADOS EN TEST SET")
    print(f"{'=' * 50}")
    print(f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)")
    print(f"   Cohen Kappa: {kappa:.4f}  <- Métrica principal en literatura")
    print(f"   F1 Macro:    {f1_macro:.4f}")
    print(f"   F1 Weighted: {f1_weighted:.4f}")
    print(f"   Epochs únicos evaluados: {len(y_test_enc):,}")
    print(
        f"   (stride={CONFIG['seq_stride_test']}, sin overlap -> cada epoch cuenta 1 vez)"
    )
    print(f"{'=' * 50}")
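
Cohen's kappa corrects the observed agreement for the agreement expected by chance, which is why it is less flattering than accuracy on imbalanced sleep data. As a reference for what `cohen_kappa_score` reports in its default unweighted form, here is a minimal pure-Python sketch. The function name `cohen_kappa` and the labels are illustrative only.

```python
from collections import Counter


def cohen_kappa(y_true, y_pred):
    """Unweighted Cohen's kappa for two equal-length label sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    n = len(y_true)
    # Observed agreement: fraction of positions where both labels match
    observed = sum(t == p for t, p in zip(y_true, y_pred)) / n
    # Chance agreement: product of the marginal frequencies, summed per class
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if expected == 1.0:
        return float("nan")  # undefined when both sequences are one single class
    return (observed - expected) / (1.0 - expected)


# Made-up labels for three classes
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
print(round(cohen_kappa(y_true, y_pred), 4))
```

In this toy case accuracy is 5/6 while chance agreement is 1/3, so kappa comes out at 0.75.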

In [None]:
# ============================================================
# CLASSIFICATION REPORT
# ============================================================

if y_test_enc is not None and y_pred_enc is not None:
    print("\nClassification Report:")
    print(
        classification_report(
            y_test_enc,
            y_pred_enc,
            labels=np.arange(len(STAGE_ORDER)),
            target_names=STAGE_ORDER,
            digits=4,
            zero_division=0,
        )
    )
else:
    print("[SKIP] Classification report: no hay datos de test")

In [None]:
# ============================================================
# CONFUSION MATRIX
# ============================================================

if y_test_enc is None or y_pred_enc is None:
    print("[SKIP] Confusion matrix: no hay datos de test")
else:
    cm = confusion_matrix(y_test_enc, y_pred_enc, labels=np.arange(len(STAGE_ORDER)))
    # Protect against division by zero for classes absent from test set
    row_sums = cm.sum(axis=1)[:, np.newaxis]
    row_sums = np.where(row_sums == 0, 1, row_sums)  # Avoid NaN
    cm_normalized = cm.astype("float") / row_sums

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=STAGE_ORDER,
        yticklabels=STAGE_ORDER,
        ax=axes[0],
    )
    axes[0].set_xlabel("Predicho")
    axes[0].set_ylabel("Real")
    axes[0].set_title("Matriz de Confusión (Absoluta)")

    sns.heatmap(
        cm_normalized,
        annot=True,
        fmt=".2%",
        cmap="Blues",
        xticklabels=STAGE_ORDER,
        yticklabels=STAGE_ORDER,
        ax=axes[1],
    )
    axes[1].set_xlabel("Predicho")
    axes[1].set_ylabel("Real")
    axes[1].set_title("Matriz de Confusión (Normalizada)")

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_PATH}/{model_name}_confusion_matrix.png", dpi=150)
    plt.show()
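
The row normalization above divides each row by the number of true samples of that class, and replaces a zero row sum with 1 so that a class absent from the test set yields zeros rather than NaN. The same idea in plain Python, as a small sketch with an invented matrix (`normalize_rows` is a hypothetical name):

```python
def normalize_rows(cm):
    """Row-normalize a confusion matrix given as a list of lists of counts."""
    normalized = []
    for row in cm:
        total = sum(row) or 1  # an all-zero row stays all zeros
        normalized.append([count / total for count in row])
    return normalized


# Made-up 3-class matrix; the second class never occurs in the true labels
cm_example = [[8, 2, 0], [0, 0, 0], [1, 1, 2]]
print(normalize_rows(cm_example))
```

Each non-empty row of the result sums to 1, so the diagonal reads directly as per-class recall.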

## Hyperparameter Optimization (Optuna)

In [None]:
# ============================================================
# OPTUNA OPTIMIZATION (optional)
# ============================================================

if CONFIG["run_optimization"]:
    print("[WARN] Optuna no implementado para modo streaming. Skip.")
else:
    print(
        "[SKIP] Optimización deshabilitada. Cambiar CONFIG['run_optimization'] = True para ejecutar."
    )

## Save Model and Results

In [None]:
# ============================================================
# SAVE MODEL AND ARTIFACTS
# ============================================================

model.save(f"{OUTPUT_PATH}/{model_name}_final.keras")
print(f"[OK] Modelo guardado: {OUTPUT_PATH}/{model_name}_final.keras")

history_data = {k: pd.Series(v) for k, v in history.history.items()}
history_df = pd.DataFrame(history_data)
history_df.to_csv(f"{OUTPUT_PATH}/{model_name}_history.csv", index=False)

# Handle case where test evaluation was not run
_accuracy = float(accuracy) if accuracy is not None else None
_kappa = float(kappa) if kappa is not None else None
_f1_macro = float(f1_macro) if f1_macro is not None else None
_f1_weighted = float(f1_weighted) if f1_weighted is not None else None

results = {
    "model_name": model_name,
    "config": CONFIG,
    "metrics": {
        "accuracy": _accuracy,
        "kappa": _kappa,
        "f1_macro": _f1_macro,
        "f1_weighted": _f1_weighted,
    },
    "training": {
        "training_time_seconds": float(training_time),
        "training_time_minutes": float(training_time / 60),
        "epochs_trained": len(history.history["loss"]),
        "best_val_f1_macro": float(metrics_callback.best_f1_macro),
        "best_val_kappa": float(metrics_callback.best_kappa),
    },
    "dataset": {
        "train_samples": int(train_count),
        "val_samples": int(val_count),
        "test_samples": int(test_count),
        "seq_length": CONFIG["seq_length"],
        "model_type": "CNN-LSTM-Sequential",
        "execution_mode": CONFIG["execution_mode"],
    },
    "channel_stats": {"mean": mean_ch.tolist(), "std": std_ch.tolist()},
    "label_encoder_classes": list(le.classes_),
}

with open(f"{OUTPUT_PATH}/{model_name}_results.json", "w") as f:
    json.dump(results, f, indent=2, default=str)

with open(f"{OUTPUT_PATH}/{model_name}_label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)

print("\nArchivos generados:")
print(f"   - {model_name}_final.keras (modelo)")
print(f"   - {model_name}_best.keras (mejor checkpoint)")
print(f"   - {model_name}_history.csv (historial)")
print(f"   - {model_name}_label_encoder.pkl (label encoder)")
print(f"   - {model_name}_results.json (metricas y config)")
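
The results file is written with `json.dump(..., default=str)`, which means any value the JSON encoder cannot handle is silently stored as its string form. The sketch below shows that behavior with standard-library types only. The payload is invented and not the notebook's `results` dictionary.

```python
import json
from datetime import datetime
from pathlib import PurePosixPath

# Made-up payload mixing JSON-native values with types json cannot encode
payload = {
    "created": datetime(2025, 1, 2, 3, 4, 5),
    "path": PurePosixPath("/kaggle/working"),
    "epochs": 40,
    "lr": 1e-3,
}

text = json.dumps(payload, indent=2, default=str)
restored = json.loads(text)

print(type(restored["created"]).__name__, restored["created"])
print(type(restored["epochs"]).__name__, restored["epochs"])
```

Native numbers survive the round trip, while the datetime and the path come back as plain strings. This is one reason to convert metrics with `float(...)` before saving, as the cell above does: a value that falls through to `default=str` would be stored as text rather than as a number.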

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print("\n" + "=" * 60)
print("ENTRENAMIENTO CNN-LSTM SECUENCIAL COMPLETADO")
print("=" * 60)
print("\nResultados finales en Test Set:")
if accuracy is not None:
    print(f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)")
    print(f"   Cohen Kappa: {kappa:.4f}")
    print(f"   F1 Macro:    {f1_macro:.4f}")
    print(f"   F1 Weighted: {f1_weighted:.4f}")
else:
    print("   [WARN] No se ejecutó evaluación en test")
print(f"\nModelo guardado en: {OUTPUT_PATH}")
print("=" * 60)

In [None]:
# ============================================================
# COMPRESS ARTIFACTS FOR DOWNLOAD
# ============================================================

zip_path = f"{OUTPUT_PATH}/{model_name}_artifacts.zip"
exclude_extensions = (".tfrecord", ".zip")  # Skip TFRecords and other zip files
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for fname in os.listdir(OUTPUT_PATH):
        fpath = os.path.join(OUTPUT_PATH, fname)
        # Include only regular files (not directories) whose name starts with model_name
        if (
            os.path.isfile(fpath)
            and fname.startswith(model_name)
            and not fname.endswith(exclude_extensions)
        ):
            zf.write(fpath, arcname=fname)

print(f"[OK] Artefactos comprimidos en: {zip_path} (excluyendo TFRecords)")
print("Archivos incluidos:")
with zipfile.ZipFile(zip_path, "r") as zf:
    for info in zf.infolist():
        print(f" - {info.filename} ({info.file_size / 1024:.1f} KB)")