# Bidirectional LSTM for Sleep Staging

This notebook trains a **Bidirectional LSTM** to classify sleep stages from **trimmed** PSG data (preprocessed and cropped recordings from the Sleep-EDF dataset).

**Optimized for Kaggle with 2x Tesla T4 (16 GB VRAM each)**

### Execution modes:
- **Debug** (`EXECUTION_MODE = "debug"`): loads data into RAM and uses a subset of subjects; about 5 min of setup, about 20 min total run time
- **Full** (`EXECUTION_MODE = "full"`): streams from TFRecord files and uses all subjects; about 30 min of setup, about 2 hours total run time

### Features:
- Bidirectional LSTM with an optional attention mechanism
- Multi-GPU support via MirroredStrategy
- Hyperparameter optimization with Optuna (optional)
- **Train/val/test split by SUBJECT (no data leakage)**
- Support for resuming training from checkpoints
- Fixed seeds for reproducibility (SEED=42)
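
The subject-level split can be checked on toy data. The following is a minimal sketch, assuming Sleep-EDF style recording IDs such as `SC4011E0`; `subject_core` and `split_cores` are hypothetical helpers that mirror the shuffle-and-slice logic defined later in the notebook, and the recording IDs are made up.

```python
import re

import numpy as np


def subject_core(recording_id):
    """Map a recording ID such as 'SC4011E0' to its subject core 'SC401'."""
    match = re.match(r"(SC4\d{2})", str(recording_id))
    return match.group(1) if match else str(recording_id)


def split_cores(cores, test_size, val_size, seed):
    """Shuffle unique subject cores and slice them into train/val/test sets."""
    unique = np.unique(cores)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique)
    n_test = max(1, int(len(unique) * test_size))
    n_val = max(1, int(len(unique) * val_size))
    test = set(unique[:n_test])
    val = set(unique[n_test : n_test + n_val])
    train = set(unique[n_test + n_val :])
    return train, val, test


# Made-up recording IDs: two nights for each of ten subjects.
recordings = [f"SC4{s:02d}{n}E0" for s in range(10) for n in (1, 2)]
cores = [subject_core(r) for r in recordings]
train, val, test = split_cores(cores, test_size=0.15, val_size=0.15, seed=42)

# Both nights of a subject share one core, so no subject spans two splits.
assert train.isdisjoint(val) and train.isdisjoint(test) and val.isdisjoint(test)
print(len(train), len(val), len(test))
```

Splitting on the subject core rather than on individual 30-second segments keeps both nights of a subject in the same split, which is what prevents leakage.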

### Required data:
- Kaggle dataset `sleep-edf-trimmed-f32` (https://www.kaggle.com/datasets/ignaciolinari/sleep-edf-trimmed-f32)
  - `manifest_trimmed_spt.csv` (1 episode per night, 100 Hz)
  - `sleep_trimmed_spt/psg/*.fif` (PSG at 100 Hz, float32)
  - `sleep_trimmed_spt/hypnograms/*.csv` (annotations)
- If you use the 200 Hz version, point to `manifest_trimmed_resamp200.csv` and the `sleep_trimmed_resamp200/` folder.
- Kaggle note: the 200 Hz version can use more VRAM; keep `sfreq=100` or use the 100 Hz dataset to avoid OOM errors.
- If you use a different slug, update `DATA_PATH` below.
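
Before running the pipeline it can help to confirm that the files listed above are present. This is a small sketch, not part of the notebook: `check_dataset_layout` is a hypothetical helper, and it assumes the manifest sits next to the `sleep_trimmed_spt/` folder as the list suggests (the actual nesting on Kaggle may differ). The demo builds a temporary directory so it runs anywhere.

```python
import tempfile
from pathlib import Path


def check_dataset_layout(
    root, manifest_name="manifest_trimmed_spt.csv", data_dir="sleep_trimmed_spt"
):
    """Return the expected paths that are missing under root (empty if complete)."""
    root = Path(root)
    expected = [
        root / manifest_name,
        root / data_dir / "psg",
        root / data_dir / "hypnograms",
    ]
    return [str(p) for p in expected if not p.exists()]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Start with only the PSG folder present.
    (root / "sleep_trimmed_spt" / "psg").mkdir(parents=True)
    missing_before = check_dataset_layout(root)
    # Add the hypnogram folder and an empty manifest file.
    (root / "sleep_trimmed_spt" / "hypnograms").mkdir()
    (root / "manifest_trimmed_spt.csv").touch()
    missing_after = check_dataset_layout(root)

print(len(missing_before), len(missing_after))
```

For the 200 Hz variant, the same helper could be called with `manifest_name="manifest_trimmed_resamp200.csv"` and `data_dir="sleep_trimmed_resamp200"`.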

In [None]:
# ============================================================
# INITIAL SETUP - Run this first
# ============================================================

import os
import warnings

warnings.filterwarnings("ignore")

# Detect whether we are running on Kaggle
IN_KAGGLE = os.path.exists("/kaggle/input")
print(f"Environment: {'Kaggle' if IN_KAGGLE else 'Local'}")

# Paths depend on the environment
if IN_KAGGLE:
    DATA_PATH = "/kaggle/input/sleep-edf-trimmed-f32/sleep_trimmed_spt"
    OUTPUT_PATH = "/kaggle/working"
else:
    DATA_PATH = "../data/processed"
    OUTPUT_PATH = "../models"

os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Data path: {DATA_PATH}")
print(f"Output path: {OUTPUT_PATH}")

In [None]:
# ============================================================
# CHECK AVAILABLE GPUs
# ============================================================

import tensorflow as tf  # noqa: E402

SEED = 42
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("\nAvailable GPUs:")

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"\n[OK] {len(gpus)} GPU(s) configured with memory growth")
else:
    print("[WARN] No GPUs detected. Training will be slow.")

if len(gpus) > 1:
    strategy = tf.distribute.MirroredStrategy()
    print(f"\nUsing MirroredStrategy with {strategy.num_replicas_in_sync} GPUs")
else:
    strategy = tf.distribute.get_strategy()
    if len(gpus) == 1:
        print("\nUsing the default strategy (1 GPU)")
    else:
        print("\n[WARN] Running on CPU - training will be slow")

print("[INFO] Mixed precision is configured after CONFIG is defined")

In [None]:
# ============================================================
# IMPORTS AND DEPENDENCIES
# ============================================================

import gc  # noqa: E402
import hashlib  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
import math  # noqa: E402
import pickle  # noqa: E402
import random  # noqa: E402
import re  # noqa: E402
import shutil  # noqa: E402
import time  # noqa: E402
import zipfile  # noqa: E402
from collections import Counter  # noqa: E402
from datetime import datetime  # noqa: E402
from pathlib import Path  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402

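# set_random_seed already seeds Python's random, NumPy, and TensorFlow;
# the explicit calls are kept so the intent stays visible.
np.random.seed(SEED)
random.seed(SEED)
tf.keras.utils.set_random_seed(SEED)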

from sklearn.metrics import (  # noqa: E402
    accuracy_score,
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
)
from sklearn.preprocessing import LabelEncoder  # noqa: E402
from sklearn.utils.class_weight import compute_class_weight  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from tensorflow.keras.callbacks import (  # noqa: E402
    Callback,
    EarlyStopping,
    ModelCheckpoint,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

print("[OK] Imports completed")

In [None]:
# ============================================================
# INSTALL DEPENDENCIES (Kaggle only)
# ============================================================

try:
    import mne
except ImportError:
    if IN_KAGGLE:
        print("Installing dependencies...")
        %pip install -q mne
        print("[OK] Dependencies installed")
        import mne
    else:
        raise

mne.set_log_level("ERROR")
print(f"MNE version: {mne.__version__}")

## Experiment Configuration

In [None]:
# ============================================================
# HYPERPARAMETERS - ADJUST AS NEEDED
# ============================================================

EXECUTION_MODE = "debug"  # "debug" or "full"

CONFIG = {
    "execution_mode": EXECUTION_MODE,
    "manifest_path": (
        "/kaggle/input/sleep-edf-trimmed-f32/manifest_trimmed_spt.csv"
        if IN_KAGGLE
        else f"{DATA_PATH}/manifest_trimmed_spt.csv"
    ),
    "epoch_length": 30.0,
    "sfreq": 100,
    "debug_max_subjects": 30,
    "limit_sessions": None,
    "test_size": 0.15,
    "val_size": 0.15,
    "random_state": 42,
    "lstm_units": 128,
    "dropout_rate": 0.4,
    "bidirectional": True,  # Bidirectional works better for offline sleep staging; use unidirectional to simulate real time.
    "use_attention": True,  # The attention layer adds overhead but usually improves N1 performance.
    "learning_rate_initial": 3e-4,
    "learning_rate_min": 1e-6,
    "warmup_epochs": 3,
    "batch_size": 64,
    "epochs": 300 if EXECUTION_MODE == "full" else 150,
    "use_class_weights": True,
    "class_weight_clip": 1.5,
    "std_epsilon": 1e-6,
    "clip_value": 5.0,
    "early_stopping_patience": 40 if EXECUTION_MODE == "full" else 25,
    "use_mixed_precision": False,
    "run_optimization": False,
    "n_optuna_trials": 30,
    "streaming": EXECUTION_MODE == "full",
    "tfrecord_dir": f"{OUTPUT_PATH}/tfrecords_lstm",
    "shuffle_buffer": 5000,
}

CONFIG["effective_batch_size"] = CONFIG["batch_size"] * strategy.num_replicas_in_sync

if CONFIG.get("use_mixed_precision", False) and gpus:
    try:
        from tensorflow.keras import mixed_precision

        mixed_precision.set_global_policy("mixed_float16")
        print("[OK] Mixed precision (float16) enabled")
    except Exception as e:
        print(f"[WARN] Could not enable mixed precision: {e}")
else:
    print("[INFO] Mixed precision disabled (float32) for better stability")

print("\n" + "=" * 60)
if EXECUTION_MODE == "debug":
    print(" DEBUG MODE: fast in-RAM loading, subset of subjects")
    print(f"   Max subjects: {CONFIG['debug_max_subjects']}")
else:
    print(" FULL MODE: TFRecord streaming, all subjects")
print(f"   Epochs: {CONFIG['epochs']}")
print("=" * 60)

print("\nExperiment configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")
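
`effective_batch_size` is the per-replica batch size multiplied by the number of replicas in the strategy. When a dataset repeats indefinitely, as a streaming pipeline typically does, the number of steps per epoch has to be computed explicitly. Here is a sketch of that arithmetic; `steps_for` is a hypothetical helper and the sample count is invented for illustration.

```python
import math


def steps_for(n_samples, batch_size, num_replicas):
    """Return (global batch size, steps needed to see every sample once)."""
    global_batch = batch_size * num_replicas
    return global_batch, math.ceil(n_samples / global_batch)


# Hypothetical numbers: 64 per replica on 2 GPUs, 10,000 training segments.
global_batch, steps = steps_for(n_samples=10_000, batch_size=64, num_replicas=2)
print(global_batch, steps)  # 128 79
```

Rounding up means the last, partially filled batch is still counted, so every segment is visited at least once per epoch.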

## Data Loading

In [None]:
# ============================================================
# DATA LOADING FUNCTIONS
# ============================================================

STAGE_CANONICAL = {
    "Sleep stage W": "W",
    "Sleep stage 1": "N1",
    "Sleep stage 2": "N2",
    "Sleep stage 3": "N3",
    "Sleep stage 4": "N3",
    "Sleep stage R": "REM",
}
STAGE_ORDER = ["W", "N1", "N2", "N3", "REM"]
REQUIRED_CHANNELS = ["EEG Fpz-Cz", "EEG Pz-Oz"]
OPTIONAL_CHANNELS = ["EOG horizontal", "EMG submental"]


def extract_subject_core(subject_id):
    """Group nights from the same subject: SC4XXNy -> SC4XX."""
    sid = str(subject_id)
    match = re.match(r"(SC4\d{2})", sid)
    return match.group(1) if match else sid


def load_psg_data(psg_path, channels=None, target_sfreq=None):
    """Load PSG data from a .fif file."""
    raw = mne.io.read_raw_fif(str(psg_path), preload=True, verbose="ERROR")
    available = set(raw.ch_names)
    if channels is None:
        missing_required = [ch for ch in REQUIRED_CHANNELS if ch not in available]
        if missing_required:
            raise ValueError(f"Missing channels: {missing_required}")
        channels = REQUIRED_CHANNELS.copy()
        for ch in OPTIONAL_CHANNELS:
            if ch in available:
                channels.append(ch)
    raw.pick(channels)
    if target_sfreq and raw.info["sfreq"] != target_sfreq:
        raw.resample(target_sfreq)
    return raw.get_data(), raw.info["sfreq"], raw.ch_names


def load_hypnogram(hyp_path):
    """Load a hypnogram CSV and map raw labels to canonical stages (unmapped labels become NaN)."""
    df = pd.read_csv(hyp_path)
    df["stage_canonical"] = df["description"].map(STAGE_CANONICAL)
    return df


def create_epochs(data, sfreq, epoch_length=30.0):
    """Split (channels, samples) data into fixed-length epochs; trailing samples are dropped."""
    samples_per_epoch = int(epoch_length * sfreq)
    n_channels, n_samples = data.shape
    n_epochs = n_samples // samples_per_epoch
    epochs = [
        data[:, i * samples_per_epoch : (i + 1) * samples_per_epoch]
        for i in range(n_epochs)
    ]
    epoch_times = [i * epoch_length for i in range(n_epochs)]
    return np.array(epochs), np.array(epoch_times)


def assign_stages(epoch_times, hypnogram, epoch_length=30.0):
    """Label each epoch with the stage of the annotation covering its midpoint (None if uncovered)."""
    stages = []
    for t in epoch_times:
        epoch_center = t + epoch_length / 2
        mask = (hypnogram["onset"] <= epoch_center) & (
            hypnogram["onset"] + hypnogram["duration"] > epoch_center
        )
        matched = hypnogram[mask]
        stages.append(matched.iloc[0]["stage_canonical"] if len(matched) > 0 else None)
    return stages


print("[OK] Loading functions defined")
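
The epoching and midpoint-based stage assignment can be illustrated on synthetic data. The following is a minimal sketch that does not read any files: the signal and the two-row hypnogram are made up, the `stage` column stands in for `stage_canonical`, and a reshape is used as an equivalent to slicing epoch by epoch.

```python
import numpy as np
import pandas as pd

sfreq, epoch_length = 100, 30.0
samples_per_epoch = int(epoch_length * sfreq)

# Toy recording: 2 channels, 100 s of signal (the last 10 s do not fill an epoch).
data = np.arange(2 * 10_000, dtype=np.float32).reshape(2, 10_000)
n_epochs = data.shape[1] // samples_per_epoch

# (channels, samples) -> (epochs, channels, samples), dropping the remainder.
epochs = (
    data[:, : n_epochs * samples_per_epoch]
    .reshape(data.shape[0], n_epochs, samples_per_epoch)
    .transpose(1, 0, 2)
)
epoch_times = np.arange(n_epochs) * epoch_length

# Toy hypnogram: 30 s of wake followed by 60 s of N2.
hypnogram = pd.DataFrame(
    {"onset": [0.0, 30.0], "duration": [30.0, 60.0], "stage": ["W", "N2"]}
)

stages = []
for t in epoch_times:
    center = t + epoch_length / 2
    hit = hypnogram[
        (hypnogram["onset"] <= center)
        & (hypnogram["onset"] + hypnogram["duration"] > center)
    ]
    # An epoch whose midpoint falls outside every annotation gets None.
    stages.append(hit.iloc[0]["stage"] if len(hit) else None)

print(epochs.shape, stages)
```

Using the epoch midpoint avoids ambiguity when an annotation boundary does not line up exactly with an epoch boundary.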

In [None]:
# ============================================================
# DATA PIPELINE (DEBUG: RAM / FULL: TFRecord)
# ============================================================


def resolve_paths(row, manifest_dir, dataset_dir_name):
    """Resolve PSG and hypnogram paths for both Kaggle and local layouts."""
    psg_path_str = row.get("psg_trimmed_path", "")
    hyp_path_str = row.get("hypnogram_trimmed_path", "")
    if pd.isna(psg_path_str):
        psg_path_str = ""
    if pd.isna(hyp_path_str):
        hyp_path_str = ""
    base_data_root = manifest_dir.parent

    if psg_path_str and hyp_path_str:
        if IN_KAGGLE:
            psg_rel, hyp_rel = Path(psg_path_str), Path(hyp_path_str)
            psg_anchor = next(
                (
                    i
                    for i, p in enumerate(psg_rel.parts)
                    if p.startswith("sleep_trimmed")
                ),
                None,
            )
            hyp_anchor = next(
                (
                    i
                    for i, p in enumerate(hyp_rel.parts)
                    if p.startswith("sleep_trimmed")
                ),
                None,
            )
            if psg_anchor is not None and hyp_anchor is not None:
                psg_path = Path(DATA_PATH) / Path(*psg_rel.parts[psg_anchor:])
                hyp_path = Path(DATA_PATH) / Path(*hyp_rel.parts[hyp_anchor:])
            else:
                psg_path = Path(DATA_PATH) / dataset_dir_name / "psg" / psg_rel.name
                hyp_path = (
                    Path(DATA_PATH) / dataset_dir_name / "hypnograms" / hyp_rel.name
                )
        else:
            psg_rel, hyp_rel = Path(psg_path_str), Path(hyp_path_str)
            psg_path = (
                base_data_root / psg_rel.relative_to("data")
                if psg_rel.parts and psg_rel.parts[0] == "data"
                else manifest_dir / psg_rel
                if not psg_rel.is_absolute()
                else psg_rel
            )
            hyp_path = (
                base_data_root / hyp_rel.relative_to("data")
                if hyp_rel.parts and hyp_rel.parts[0] == "data"
                else manifest_dir / hyp_rel
                if not hyp_rel.is_absolute()
                else hyp_rel
            )
    else:
        subset = row.get("subset", "sleep-cassette")
        version = row.get("version", "1.0.0")
        dataset_dir = manifest_dir / dataset_dir_name
        psg_path = (
            dataset_dir
            / "psg"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_raw.fif"
        )
        hyp_path = (
            dataset_dir
            / "hypnograms"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_annotations.csv"
        )
    return psg_path, hyp_path


def get_subject_cores_for_mode(
    manifest_path, execution_mode, debug_max_subjects, random_state
):
    """Return the subject cores to use for the given execution mode."""
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    all_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()
    rng = np.random.default_rng(random_state)
    rng.shuffle(all_cores)
    if execution_mode == "debug" and debug_max_subjects:
        selected_cores = set(all_cores[:debug_max_subjects])
        print(f"[DEBUG] Using {len(selected_cores)}/{len(all_cores)} subjects")
    else:
        selected_cores = set(all_cores)
        print(f"[FULL] Using all {len(selected_cores)} subjects")
    return selected_cores


def iter_sessions(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """Iterate over sessions, yielding resolved paths."""
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    if allowed_cores is not None:
        manifest_ok = manifest_ok[
            manifest_ok["subject_id"].apply(extract_subject_core).isin(allowed_cores)
        ]
    manifest_dir = Path(manifest_path).parent
    dataset_dir_name = (
        "sleep_trimmed_resamp200"
        if (manifest_dir / "sleep_trimmed_resamp200").exists()
        else "sleep_trimmed_spt"
        if (manifest_dir / "sleep_trimmed_spt").exists()
        else "sleep_trimmed"
    )
    total_sessions = len(manifest_ok)
    print(f"\nProcessing {total_sessions} sessions...")
    for i, (_, row) in enumerate(manifest_ok.iterrows(), start=1):
        subject_id = row["subject_id"]
        subject_core = extract_subject_core(subject_id)
        if allowed_cores is not None and subject_core not in allowed_cores:
            continue
        psg_path, hyp_path = resolve_paths(row, manifest_dir, dataset_dir_name)
        if not psg_path.exists() or not hyp_path.exists():
            continue
        yield i, total_sessions, subject_id, subject_core, psg_path, hyp_path


def update_running_stats(stats, epochs):
    if stats is None:
        stats = {"n": 0, "sum": None, "sumsq": None}
    if stats["sum"] is None:
        stats["sum"] = np.zeros(epochs.shape[1], dtype=np.float64)
        stats["sumsq"] = np.zeros(epochs.shape[1], dtype=np.float64)
    stats["sum"] += epochs.sum(axis=(0, 2))
    stats["sumsq"] += (epochs**2).sum(axis=(0, 2))
    stats["n"] += epochs.shape[0] * epochs.shape[2]
    return stats


def finalize_running_stats(stats, std_epsilon=1e-6):
    mean = stats["sum"] / stats["n"]
    var = stats["sumsq"] / stats["n"] - mean**2
    var = np.maximum(var, 0.0)
    std = np.sqrt(var + std_epsilon)
    return mean.astype(np.float32), std.astype(np.float32)


def load_all_data_to_ram(manifest_path, epoch_length, sfreq, allowed_cores):
    """Load all data into RAM (debug mode)."""
    print("\n[DEBUG] Loading data into RAM...")
    start_time = time.time()
    all_epochs, all_stages, all_subject_cores = [], [], []
    expected_channels = None
    nan_inf_count = 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels is None:
            expected_channels = list(ch_names)
        elif list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels: {list(ch_names)}")
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)
        for epoch, stage in zip(epochs, stages):
            if stage not in STAGE_ORDER:
                continue
            if not np.all(np.isfinite(epoch)):
                nan_inf_count += 1
                continue
            all_epochs.append(epoch)
            all_stages.append(stage)
            all_subject_cores.append(subject_core)
        if i % 10 == 0 or i == total:
            print(f"   Loading: {i}/{total} sessions")

    if nan_inf_count > 0:
        print(f"[WARN] Discarded {nan_inf_count} epochs containing NaN/Inf")
    X = np.array(all_epochs, dtype=np.float32)
    y = np.array([STAGE_ORDER.index(s) for s in all_stages], dtype=np.int32)
    subject_cores = np.array(all_subject_cores)
    print(
        f"[OK] Data loaded in {time.time() - start_time:.1f}s. Shape: {X.shape}, Memory: {X.nbytes / 1024**3:.2f} GB"
    )
    return X, y, subject_cores, expected_channels


def split_data_by_subject(X, y, subject_cores, test_size, val_size, random_state):
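    """Split data by subject so that no subject appears in more than one split."""
    unique_cores = np.unique(subject_cores)
    rng = np.random.default_rng(random_state)
    rng.shuffle(unique_cores)
    n_test = max(1, int(len(unique_cores) * test_size))
    n_val = max(1, int(len(unique_cores) * val_size))
    # With very few subjects, test and val could consume every subject
    if n_test + n_val >= len(unique_cores):
        raise ValueError("Not enough subjects to leave any for training")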
    test_cores = set(unique_cores[:n_test])
    val_cores = set(unique_cores[n_test : n_test + n_val])
    train_cores = set(unique_cores[n_test + n_val :])
    train_mask = np.isin(subject_cores, list(train_cores))
    val_mask = np.isin(subject_cores, list(val_cores))
    test_mask = np.isin(subject_cores, list(test_cores))
    return (
        X[train_mask],
        y[train_mask],
        X[val_mask],
        y[val_mask],
        X[test_mask],
        y[test_mask],
        train_cores,
        val_cores,
        test_cores,
    )


def normalize_data(X_train, X_val, X_test, clip_value, std_epsilon):
    """Normalize data with training-set statistics (per-channel z-score)."""
    mean_ch = X_train.mean(axis=(0, 2))
    var_ch = X_train.var(axis=(0, 2))
    std_ch = np.sqrt(var_ch + std_epsilon)
    X_train_norm = np.clip(
        (X_train - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    X_val_norm = np.clip(
        (X_val - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    X_test_norm = np.clip(
        (X_test - mean_ch[None, :, None]) / std_ch[None, :, None],
        -clip_value,
        clip_value,
    )
    return X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch


print("[OK] Pipeline functions defined")
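
The streaming statistics accumulate a per-channel sum and sum of squares session by session, so the mean and variance can be recovered without holding all epochs in memory. This is a minimal sketch on random toy data (the shapes and values are invented) that checks the accumulated result against NumPy's direct computation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: three "sessions" shaped (epochs, channels, samples), 2 channels each.
sessions = [
    rng.normal(loc=[[1.0], [-2.0]], scale=[[0.5], [3.0]], size=(n, 2, 50))
    for n in (4, 7, 5)
]

total = np.zeros(2, dtype=np.float64)
total_sq = np.zeros(2, dtype=np.float64)
count = 0
for x in sessions:
    # Reduce over epochs and samples, keeping one value per channel.
    total += x.sum(axis=(0, 2))
    total_sq += (x**2).sum(axis=(0, 2))
    count += x.shape[0] * x.shape[2]

mean = total / count
# E[x^2] - E[x]^2 can dip slightly below zero from rounding, hence the clamp.
var = np.maximum(total_sq / count - mean**2, 0.0)

full = np.concatenate(sessions, axis=0)
mean_ok = np.allclose(mean, full.mean(axis=(0, 2)))
var_ok = np.allclose(var, full.var(axis=(0, 2)))
print(mean_ok, var_ok)
```

Note that `std_epsilon` is added to the variance rather than to the standard deviation. If the signals are stored in volts, which is MNE's usual convention for EEG, the variance may be far smaller than `1e-6`, so it can be worth checking the printed `std` values after normalization.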

In [None]:
# ============================================================
# TFRECORD FUNCTIONS
# ============================================================


def pass1_stats(manifest_path, epoch_length, sfreq, allowed_cores=None):
    """First pass: per-channel mean/std and class counts."""
    stats, class_counts, input_shape, expected_channels = None, Counter(), None, None
    nan_inf_count = 0

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels is None:
            expected_channels = list(ch_names)
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        valid_mask = [s in STAGE_ORDER for s in stages]
        if not any(valid_mask):
            continue
        valid_epochs = epochs[valid_mask]
        valid_stages = [s for s in stages if s in STAGE_ORDER]

        finite_mask = np.all(np.isfinite(valid_epochs), axis=(1, 2))
        if not np.all(finite_mask):
            nan_inf_count += np.sum(~finite_mask)
            valid_epochs = valid_epochs[finite_mask]
            valid_stages = [s for s, m in zip(valid_stages, finite_mask) if m]

        if len(valid_epochs) == 0:
            continue
        if input_shape is None:
            input_shape = valid_epochs.shape[1:]
        stats = update_running_stats(stats, valid_epochs)
        class_counts.update(valid_stages)

        if i % 20 == 0 or i == total:
            print(f"   Pass 1: {i}/{total} sessions")

    if nan_inf_count > 0:
        print(f"[WARN] Discarded {nan_inf_count} epochs containing NaN/Inf")
    # Raise explicitly: an assert would be skipped when Python runs with -O
    if input_shape is None:
        raise ValueError("No valid epochs found")
    mean, std = finalize_running_stats(
        stats, std_epsilon=CONFIG.get("std_epsilon", 1e-6)
    )
    return mean, std, class_counts, input_shape, expected_channels


def make_example(x, y):
    feature = {
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=x.ravel())),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()


def assign_subject_splits(
    manifest_path, test_size, val_size, random_state, allowed_cores=None
):
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()
    subject_cores = manifest_ok["subject_id"].apply(extract_subject_core).unique()
    if allowed_cores is not None:
        subject_cores = np.array([c for c in subject_cores if c in allowed_cores])
    rng = np.random.default_rng(random_state)
    rng.shuffle(subject_cores)
    n_test = max(1, int(len(subject_cores) * test_size))
    n_val = max(1, int(len(subject_cores) * val_size))
    test_cores = set(subject_cores[:n_test])
    val_cores = set(subject_cores[n_test : n_test + n_val])
    train_cores = set(subject_cores[n_test + n_val :])
    split_map = dict.fromkeys(train_cores, "train")
    split_map.update(dict.fromkeys(val_cores, "val"))
    split_map.update(dict.fromkeys(test_cores, "test"))
    return split_map, {
        "train": len(train_cores),
        "val": len(val_cores),
        "test": len(test_cores),
    }


def build_tfrecord_splits(
    manifest_path,
    mean,
    std,
    split_map,
    epoch_length,
    sfreq,
    tfrecord_dir,
    expected_channels=None,
):
    tfrecord_dir = Path(tfrecord_dir)
    tfrecord_dir.mkdir(parents=True, exist_ok=True)
    writers = {
        k: tf.io.TFRecordWriter(str(tfrecord_dir / f"{k}.tfrecord"))
        for k in ["train", "val", "test"]
    }
    counts, session_counts, subject_sets = (
        Counter(),
        Counter(),
        {k: set() for k in writers},
    )
    skipped_nan_inf = 0
    clip_value = CONFIG.get("clip_value", 5.0)
    allowed_cores = set(split_map.keys())

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, allowed_cores
    ):
        split = split_map.get(subject_core)
        if split is None:
            continue
        session_counts[split] += 1
        subject_sets[split].add(subject_core)
        data, actual_sfreq, ch_names = load_psg_data(psg_path, target_sfreq=sfreq)
        if expected_channels and list(ch_names) != expected_channels:
            raise ValueError(f"Inconsistent channels in {psg_path}")
        hypnogram = load_hypnogram(hyp_path)
        # Use the sampling rate reported by the file, as pass1_stats does
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        for epoch, stage in zip(epochs, stages):
            if stage not in STAGE_ORDER:
                continue
            if not np.all(np.isfinite(epoch)):
                skipped_nan_inf += 1
                continue
            y = STAGE_ORDER.index(stage)
            x = np.clip((epoch - mean[:, None]) / std[:, None], -clip_value, clip_value)
            if not np.all(np.isfinite(x)):
                skipped_nan_inf += 1
                continue
            writers[split].write(make_example(x.astype(np.float32), y))
            counts[split] += 1

        if i % 20 == 0 or i == total:
            print(f"   Pass 2 ({split}): {i}/{total} sessions")

    for w in writers.values():
        w.close()
    if skipped_nan_inf > 0:
        print(f"[WARN] Discarded {skipped_nan_inf} epochs containing NaN/Inf")
    return (
        {k: str(tfrecord_dir / f"{k}.tfrecord") for k in writers},
        counts,
        session_counts,
        {k: len(subject_sets[k]) for k in subject_sets},
    )


def make_dataset(
    tfrecord_path, input_shape, batch_size, shuffle=False, repeat=False, for_lstm=True
):
    """Build a dataset from a TFRecord file. for_lstm=True transposes each example to (samples, channels)."""
    feature_description = {
        "x": tf.io.FixedLenFeature([input_shape[0] * input_shape[1]], tf.float32),
        "y": tf.io.FixedLenFeature([], tf.int64),
    }
    clip_value = CONFIG.get("clip_value", 5.0)

    def _parse(example_proto):
        example = tf.io.parse_single_example(example_proto, feature_description)
        x = tf.reshape(example["x"], input_shape)
        x = tf.clip_by_value(x, -clip_value, clip_value)
        x = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
        if for_lstm:
            x = tf.transpose(
                x, perm=[1, 0]
            )  # (channels, samples) -> (samples, channels)
        y = tf.cast(example["y"], tf.int32)
        return x, y

    ds = tf.data.TFRecordDataset([tfrecord_path], num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(
            CONFIG["shuffle_buffer"],
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
    if repeat:
        ds = ds.repeat()
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return ds


def get_tfrecord_cache_key(manifest_path, config):
    manifest_mtime = os.path.getmtime(manifest_path)
    key_data = (
        f"{manifest_path}_{manifest_mtime}_{config['sfreq']}_{config['epoch_length']}"
        f"_{config.get('clip_value', 'na')}_{config['random_state']}"
        f"_{config['test_size']}_{config['val_size']}_{config.get('debug_max_subjects', 'all')}"
    )
    return hashlib.md5(key_data.encode()).hexdigest()[:12]


print("[OK] TFRecord functions defined")
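
The TFRecord cache is keyed on a hash of the manifest identity plus a few configuration values, so changing any of them forces regeneration. Here is a simplified sketch of that idea using only the standard library; `cache_key` and its `fields` argument are hypothetical, and the manifest name and modification time are placeholders.

```python
import hashlib


def cache_key(manifest_path, manifest_mtime, config, fields):
    """Hash the manifest identity plus selected config fields into a short key."""
    parts = [str(manifest_path), str(manifest_mtime)]
    parts += [str(config.get(name, "na")) for name in fields]
    return hashlib.md5("_".join(parts).encode()).hexdigest()[:12]


fields = ["sfreq", "epoch_length", "clip_value", "random_state", "test_size", "val_size"]
base = {
    "sfreq": 100,
    "epoch_length": 30.0,
    "clip_value": 5.0,
    "random_state": 42,
    "test_size": 0.15,
    "val_size": 0.15,
}

key_a = cache_key("manifest.csv", 1700000000.0, base, fields)
# Changing a hashed field produces a different key, which invalidates the cache.
key_b = cache_key("manifest.csv", 1700000000.0, {**base, "sfreq": 200}, fields)
# An identical configuration reproduces the same key, so the cache is reused.
key_c = cache_key("manifest.csv", 1700000000.0, dict(base), fields)
print(key_a != key_b, key_a == key_c)
```

Only the hashed fields can invalidate the cache. A setting that affects the written records but is left out of the key, such as `std_epsilon` in this notebook's key, would not trigger regeneration on its own, so deleting the TFRecord folder after changing it is the safer option.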

In [None]:
# ============================================================
# RUN THE PIPELINE FOR THE SELECTED MODE
# ============================================================

print(f"\n[INFO] Execution mode: {CONFIG['execution_mode'].upper()}")

selected_cores = get_subject_cores_for_mode(
    CONFIG["manifest_path"],
    CONFIG["execution_mode"],
    CONFIG["debug_max_subjects"],
    CONFIG["random_state"],
)

if CONFIG["streaming"]:
    print("\n[FULL] Starting TFRecord pipeline (streaming)...")
    tfrecord_path = Path(CONFIG["tfrecord_dir"])
    cache_key = get_tfrecord_cache_key(CONFIG["manifest_path"], CONFIG)
    cache_marker = tfrecord_path / f".cache_{cache_key}"
    cache_metadata_path = tfrecord_path / "cache_metadata.json"

    # Reuse only if the metadata file is present too; it is read right after
    if (
        tfrecord_path.exists()
        and cache_marker.exists()
        and cache_metadata_path.exists()
    ):
        print(f"[INFO] Reusing existing TFRecords (cache: {cache_key})")
        REUSE_TFRECORDS = True
    else:
        if tfrecord_path.exists():
            print("[INFO] Cache is invalid, regenerating TFRecords...")
            shutil.rmtree(tfrecord_path)
        REUSE_TFRECORDS = False

    if REUSE_TFRECORDS:
        with open(cache_metadata_path) as f:
            cache_meta = json.load(f)
        mean_ch = np.array(cache_meta["mean_ch"], dtype=np.float32)
        std_ch = np.array(cache_meta["std_ch"], dtype=np.float32)
        class_counts_train = Counter(cache_meta["class_counts_train"])
        input_shape = tuple(cache_meta["input_shape"])
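        split_counts = cache_meta["split_counts"]
        # Define expected_channels on this branch too; caches without the key yield None
        expected_channels = cache_meta.get("expected_channels")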
        tfrecord_paths = {
            k: str(tfrecord_path / f"{k}.tfrecord") for k in ["train", "val", "test"]
        }
        split_map, _ = assign_subject_splits(
            CONFIG["manifest_path"],
            CONFIG["test_size"],
            CONFIG["val_size"],
            CONFIG["random_state"],
            selected_cores,
        )
        train_cores = {c for c, s in split_map.items() if s == "train"}
        val_cores = {c for c, s in split_map.items() if s == "val"}
        test_cores = {c for c, s in split_map.items() if s == "test"}
        print(f"   mean: {mean_ch}, std: {std_ch}")
        print(
            f"   train: {split_counts['train']:,}, val: {split_counts['val']:,}, test: {split_counts['test']:,} epochs"
        )
    else:
        split_map, split_subjects = assign_subject_splits(
            CONFIG["manifest_path"],
            CONFIG["test_size"],
            CONFIG["val_size"],
            CONFIG["random_state"],
            selected_cores,
        )
        train_cores = {c for c, s in split_map.items() if s == "train"}
        val_cores = {c for c, s in split_map.items() if s == "val"}
        test_cores = {c for c, s in split_map.items() if s == "test"}
        mean_ch, std_ch, class_counts_train, input_shape, expected_channels = (
            pass1_stats(
                CONFIG["manifest_path"],
                CONFIG["epoch_length"],
                CONFIG["sfreq"],
                train_cores,
            )
        )
        print(f"\n[OK] Statistics: mean={mean_ch}, std={std_ch}")
        if np.any(std_ch < 1e-6):
            std_ch = np.maximum(std_ch, 1e-6)
        tfrecord_paths, split_counts, _, _ = build_tfrecord_splits(
            CONFIG["manifest_path"],
            mean_ch,
            std_ch,
            split_map,
            CONFIG["epoch_length"],
            CONFIG["sfreq"],
            CONFIG["tfrecord_dir"],
            expected_channels,
        )
        print(f"\n[OK] TFRecords generated: {split_counts}")
        cache_meta = {
            "mean_ch": mean_ch.tolist(),
            "std_ch": std_ch.tolist(),
            "class_counts_train": dict(class_counts_train),
            "input_shape": list(input_shape),
            "split_counts": split_counts,
            # Stored so a cached run can recover the channel names
            "expected_channels": expected_channels,
        }
        tfrecord_path.mkdir(parents=True, exist_ok=True)
        with open(cache_metadata_path, "w") as f:
            json.dump(cache_meta, f, indent=2)
        cache_marker.touch()
        print(f"[OK] Cache saved (key: {cache_key})")

    INPUT_SHAPE = (input_shape[1], input_shape[0])  # (samples, channels) for LSTM
    train_count, val_count, test_count = (
        split_counts.get("train", 0),
        split_counts.get("val", 0),
        split_counts.get("test", 0),
    )
    train_ds = make_dataset(
        tfrecord_paths["train"],
        input_shape,
        CONFIG["effective_batch_size"],
        shuffle=True,
        repeat=True,
        for_lstm=True,
    )
    val_ds = make_dataset(
        tfrecord_paths["val"],
        input_shape,
        CONFIG["effective_batch_size"],
        shuffle=False,
        repeat=True,
        for_lstm=True,
    )
    test_ds = make_dataset(
        tfrecord_paths["test"],
        input_shape,
        CONFIG["effective_batch_size"],
        shuffle=False,
        repeat=False,
        for_lstm=True,
    )
    USE_STREAMING = True
else:
    print("\n[DEBUG] Starting in-RAM pipeline...")
    X, y, subject_cores_arr, expected_channels = load_all_data_to_ram(
        CONFIG["manifest_path"], CONFIG["epoch_length"], CONFIG["sfreq"], selected_cores
    )
    (
        X_train,
        y_train,
        X_val,
        y_val,
        X_test,
        y_test,
        train_cores,
        val_cores,
        test_cores,
    ) = split_data_by_subject(
        X,
        y,
        subject_cores_arr,
        CONFIG["test_size"],
        CONFIG["val_size"],
        CONFIG["random_state"],
    )
    print(
        f"\n[OK] Split: Train={len(X_train):,}, Val={len(X_val):,}, Test={len(X_test):,}"
    )
    X_train_norm, X_val_norm, X_test_norm, mean_ch, std_ch = normalize_data(
        X_train, X_val, X_test, CONFIG["clip_value"], CONFIG["std_epsilon"]
    )
    print(f"[OK] Normalization: mean={mean_ch}, std={std_ch}")
    del X, X_train, X_val, X_test
    gc.collect()

    # Transpose for the LSTM: (epochs, channels, samples) -> (epochs, samples, channels)
    X_train_lstm = np.transpose(X_train_norm, (0, 2, 1))
    X_val_lstm = np.transpose(X_val_norm, (0, 2, 1))
    X_test_lstm = np.transpose(X_test_norm, (0, 2, 1))

    INPUT_SHAPE = X_train_lstm.shape[1:]
    train_count, val_count, test_count = (
        len(X_train_lstm),
        len(X_val_lstm),
        len(X_test_lstm),
    )
    train_ds = (
        tf.data.Dataset.from_tensor_slices((X_train_lstm, y_train))
        .shuffle(
            min(10000, train_count),
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
        .batch(CONFIG["effective_batch_size"])
        .prefetch(tf.data.AUTOTUNE)
    )
    val_ds = (
        tf.data.Dataset.from_tensor_slices((X_val_lstm, y_val))
        .batch(CONFIG["effective_batch_size"])
        .prefetch(tf.data.AUTOTUNE)
    )
    test_ds = (
        tf.data.Dataset.from_tensor_slices((X_test_lstm, y_test))
        .batch(CONFIG["effective_batch_size"])
        .prefetch(tf.data.AUTOTUNE)
    )
    class_counts_train = Counter([STAGE_ORDER[yi] for yi in y_train])
    tfrecord_paths = None
    USE_STREAMING = False

# Class weights ("balanced" heuristic, clipped below to limit their range)
le = LabelEncoder()
le.classes_ = np.array(STAGE_ORDER)
# Count classes missing from the training split as one sample so that
# compute_class_weight always sees every class
counts_list = [max(class_counts_train.get(stage, 0), 1) for stage in STAGE_ORDER]
y_for_weights = np.repeat(np.arange(len(STAGE_ORDER)), counts_list)
class_weights_arr = compute_class_weight(
    "balanced", classes=np.arange(len(STAGE_ORDER)), y=y_for_weights
)
class_weight_clip = CONFIG.get("class_weight_clip", 1.5)
class_weights = (
    {
        k: float(np.clip(v, 0.5, class_weight_clip))
        for k, v in enumerate(class_weights_arr)
    }
    if CONFIG.get("use_class_weights", True)
    else None
)
print(
    f"Class weights: {class_weights}" if class_weights else "Class weights disabled"
)

print(f"\n{'=' * 60}")
print(f"[OK] Pipeline {'STREAMING' if USE_STREAMING else 'RAM'} listo")
print(f"   Input shape (samples, channels): {INPUT_SHAPE}")
print(f"   Train: {train_count:,}, Val: {val_count:,}, Test: {test_count:,} epochs")
print("=" * 60)
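
The class weights above follow the "balanced" heuristic, `w_k = N / (K * n_k)`, and are then clipped to `[0.5, class_weight_clip]`. The sketch below reproduces that arithmetic in plain Python. The helper name `balanced_clipped_weights` and the counts are made up for illustration; they are not Sleep-EDF statistics.

```python
def balanced_clipped_weights(counts, low=0.5, high=1.5):
    """Return {class_index: weight} using N / (K * n_k), clipped to [low, high]."""
    # Same guard as the notebook: an empty class counts as one sample.
    safe = [max(c, 1) for c in counts]
    total, n_classes = sum(safe), len(safe)
    return {
        k: min(max(total / (n_classes * n), low), high)
        for k, n in enumerate(safe)
    }


# Hypothetical epoch counts for five stages (illustrative only).
example_counts = [400, 100, 1000, 300, 200]
weights = balanced_clipped_weights(example_counts)
print(weights)
```

With these counts the rare classes hit the upper clip and the dominant class hits the lower one, which is the effect the clipping is meant to have: rare stages are up-weighted, but only within a bounded range.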

## Bidirectional LSTM Architecture

In [None]:
# ============================================================
# CUSTOM ATTENTION LAYER
# ============================================================


class AttentionLayer(layers.Layer):
    """Simple attention pooling over the time axis of an LSTM output."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One scoring weight per feature: (features, 1)
        self.W = self.add_weight(
            name="attention_weight",
            shape=(input_shape[-1], 1),
            initializer="glorot_uniform",
            trainable=True,
        )
        # One bias per time step: (timesteps, 1); needs a fixed sequence length
        self.b = self.add_weight(
            name="attention_bias",
            shape=(input_shape[1], 1),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # x: (batch, timesteps, features). TensorFlow ops are used directly so
        # the layer does not depend on the legacy keras.backend helpers.
        e = tf.tanh(tf.tensordot(x, self.W, axes=1) + self.b)  # (batch, timesteps, 1)
        a = tf.nn.softmax(e, axis=1)  # attention weights over time
        return tf.reduce_sum(x * a, axis=1)  # (batch, features)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        return super().get_config()


print("[OK] Attention layer defined")
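
To make the pooling step concrete: the layer scores each time step with `tanh(x_t . W + b_t)`, normalizes the scores with a softmax over time, and returns the weighted sum of the time steps. The sketch below does this for a single sequence in plain Python. The helper name `attention_pool` and all numbers are illustrative assumptions, not part of the notebook.

```python
import math


def attention_pool(x, w, b):
    """Pool a (time, features) sequence into one feature vector.

    x: list of T feature vectors, w: one weight per feature,
    b: one bias per time step.
    """
    # Score per time step: tanh(x_t . w + b_t)
    scores = [
        math.tanh(sum(xi * wi for xi, wi in zip(x_t, w)) + b_t)
        for x_t, b_t in zip(x, b)
    ]
    # Softmax over the time axis (shifted by the max for numerical stability)
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    # Weighted sum of the time steps -> one vector of length `features`
    pooled = [
        sum(a * x_t[j] for a, x_t in zip(alphas, x)) for j in range(len(x[0]))
    ]
    return pooled, alphas


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # T=3 steps, 2 features (toy data)
pooled, alphas = attention_pool(x, w=[0.5, -0.5], b=[0.0, 0.0, 0.0])
print(alphas, pooled)
```

The output has one value per feature regardless of sequence length, which is why the second LSTM keeps `return_sequences=True` only when attention is enabled.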

In [None]:
# ============================================================
# LEARNING RATE SCHEDULE AND LSTM MODEL
# ============================================================


def create_lr_schedule(
    initial_lr, min_lr, warmup_epochs, total_epochs, steps_per_epoch
):
    """Linear warmup from min_lr to initial_lr, followed by cosine decay."""
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    decay_steps = max(total_steps - warmup_steps, 1)

    class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __init__(self, initial_lr, min_lr, warmup_steps, decay_steps):
            super().__init__()
            self.initial_lr = tf.cast(initial_lr, tf.float32)
            self.min_lr = tf.cast(min_lr, tf.float32)
            self.warmup_steps = tf.cast(warmup_steps, tf.float32)
            self.decay_steps = tf.cast(decay_steps, tf.float32)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            warmup_progress = step / tf.maximum(self.warmup_steps, 1.0)
            warmup_lr = self.min_lr + (self.initial_lr - self.min_lr) * tf.minimum(
                warmup_progress, 1.0
            )
            decay_progress = tf.minimum(
                tf.maximum((step - self.warmup_steps) / self.decay_steps, 0.0), 1.0
            )
            cosine_decay = 0.5 * (
                1.0 + tf.cos(tf.constant(np.pi, dtype=tf.float32) * decay_progress)
            )
            decay_lr = self.min_lr + (self.initial_lr - self.min_lr) * cosine_decay
            return tf.cond(
                step < self.warmup_steps, lambda: warmup_lr, lambda: decay_lr
            )

        def get_config(self):
            return {
                "initial_lr": float(self.initial_lr),
                "min_lr": float(self.min_lr),
                "warmup_steps": int(self.warmup_steps),
                "decay_steps": int(self.decay_steps),
            }

    return WarmupCosineDecay(initial_lr, min_lr, warmup_steps, decay_steps)


def build_lstm_model(
    input_shape,
    n_classes=5,
    lstm_units=128,
    dropout_rate=0.4,
    lr_schedule=None,
    bidirectional=True,
    use_attention=True,
):
    """Build an LSTM model (bidirectional by default) for sleep staging."""
    input_layer = keras.Input(shape=input_shape, name="input", dtype="float32")

    lstm_1 = layers.LSTM(
        lstm_units,
        return_sequences=True,
        kernel_regularizer=keras.regularizers.l2(1e-4),
        name="lstm_1",
    )
    x = (
        layers.Bidirectional(lstm_1, name="bidirectional_1")(input_layer)
        if bidirectional
        else lstm_1(input_layer)
    )
    x = layers.BatchNormalization(name="bn_1")(x)
    x = layers.Dropout(dropout_rate, name="dropout_lstm_1")(x)

    lstm_2 = layers.LSTM(
        lstm_units // 2,
        return_sequences=use_attention,
        kernel_regularizer=keras.regularizers.l2(1e-4),
        name="lstm_2",
    )
    x = (
        layers.Bidirectional(lstm_2, name="bidirectional_2")(x)
        if bidirectional
        else lstm_2(x)
    )
    x = layers.BatchNormalization(name="bn_2")(x)
    x = layers.Dropout(dropout_rate, name="dropout_lstm_2")(x)

    if use_attention:
        x = AttentionLayer(name="attention")(x)

    x = layers.Dense(
        128,
        kernel_regularizer=keras.regularizers.l2(1e-4),
        kernel_initializer="he_uniform",
        name="dense_1",
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Dropout(dropout_rate, name="dropout_1")(x)

    x = layers.Dense(
        64,
        kernel_regularizer=keras.regularizers.l2(1e-4),
        kernel_initializer="he_uniform",
        name="dense_2",
    )(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Dropout(dropout_rate, name="dropout_2")(x)

    output_layer = layers.Dense(
        n_classes, kernel_initializer="glorot_uniform", name="output", dtype="float32"
    )(x)  # logits

    model_name = (
        ("BiLSTM" if bidirectional else "LSTM")
        + ("_Attention" if use_attention else "")
        + "_SleepStaging"
    )
    model = keras.Model(inputs=input_layer, outputs=output_layer, name=model_name)
    model.compile(
        optimizer=keras.optimizers.Adam(
            learning_rate=lr_schedule if lr_schedule else 1e-4, clipnorm=1.0
        ),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


print("[OK] LSTM architecture defined")
print("[OK] Learning rate schedule (warmup + cosine decay) defined")
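
The schedule defined above can be checked by hand with a scalar version of the same formula: a linear ramp from `min_lr` to `initial_lr` during warmup, then a half cosine back down to `min_lr`. The sketch below is a plain-Python restatement, not the Keras class itself; the function name `warmup_cosine_lr` and the settings (3 warmup epochs, 30 epochs, 115 steps per epoch, learning rates `1e-3` and `1e-6`) are assumptions chosen for illustration.

```python
import math


def warmup_cosine_lr(step, initial_lr, min_lr, warmup_steps, decay_steps):
    """Learning rate at `step`: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        progress = step / max(warmup_steps, 1)
        return min_lr + (initial_lr - min_lr) * min(progress, 1.0)
    progress = min(max((step - warmup_steps) / decay_steps, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (initial_lr - min_lr) * cosine


# Hypothetical settings, for illustration only.
warmup_steps, total_steps = 3 * 115, 30 * 115
decay_steps = max(total_steps - warmup_steps, 1)
for step in (0, warmup_steps // 2, warmup_steps, total_steps):
    lr = warmup_cosine_lr(step, 1e-3, 1e-6, warmup_steps, decay_steps)
    print(f"step {step:5d}: lr = {lr:.2e}")
```

Because the step counts are derived from `steps_per_epoch`, changing `MAX_STEPS_PER_EPOCH` or the batch size also changes how quickly the schedule decays.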

In [None]:
# ============================================================
# CREATE MODEL
# ============================================================

input_shape = INPUT_SHAPE
print(f"Input shape: {input_shape}")

# Upper bound on the number of steps per epoch
MAX_STEPS_PER_EPOCH = 115
full_steps_train = math.ceil(train_count / CONFIG["effective_batch_size"])
full_steps_val = (
    math.ceil(val_count / CONFIG["effective_batch_size"]) if val_count else 0
)

# Training steps are capped in both modes. Validation steps are only set in
# streaming mode, where val_ds repeats indefinitely; in RAM mode Keras iterates
# the finite val_ds once per epoch.
steps_per_epoch_train = min(full_steps_train, MAX_STEPS_PER_EPOCH)
steps_per_epoch_val = (
    min(full_steps_val, MAX_STEPS_PER_EPOCH)
    if USE_STREAMING and full_steps_val
    else None
)

lr_schedule = create_lr_schedule(
    initial_lr=CONFIG["learning_rate_initial"],
    min_lr=CONFIG["learning_rate_min"],
    warmup_epochs=CONFIG["warmup_epochs"],
    total_epochs=CONFIG["epochs"],
    steps_per_epoch=steps_per_epoch_train,
)

print(
    f"\n[INFO] LR Schedule: initial={CONFIG['learning_rate_initial']}, min={CONFIG['learning_rate_min']}, warmup={CONFIG['warmup_epochs']} epochs"
)
print(f"   Steps per epoch (train): {steps_per_epoch_train}")

with strategy.scope():
    model = build_lstm_model(
        input_shape=input_shape,
        n_classes=len(STAGE_ORDER),
        lstm_units=CONFIG["lstm_units"],
        dropout_rate=CONFIG["dropout_rate"],
        lr_schedule=lr_schedule,
        bidirectional=CONFIG["bidirectional"],
        use_attention=CONFIG["use_attention"],
    )

model.summary()

## Training

In [None]:
# ============================================================
# CALLBACKS
# ============================================================

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"lstm_{CONFIG['execution_mode']}_{timestamp}"


class SleepMetricsCallback(Callback):
    """Computes macro F1 and Cohen's kappa on the validation set after each epoch."""

    def __init__(
        self, val_ds, val_steps, stage_order, use_streaming=False, eval_every=1
    ):
        super().__init__()
        self.val_ds, self.val_steps, self.stage_order = val_ds, val_steps, stage_order
        self.use_streaming, self.eval_every = use_streaming, max(1, eval_every)
        self.best_f1_macro, self.best_kappa, self.best_weights = -1.0, -1.0, None

    def on_epoch_end(self, epoch, logs=None):
        # Keep the same dict object: EarlyStopping and ModelCheckpoint read the
        # metrics written here
        logs = {} if logs is None else logs
        if (epoch + 1) % self.eval_every != 0:
            return
        y_true_list, y_pred_list = [], []
        for batch_idx, (x_batch, y_batch) in enumerate(self.val_ds):
            # In streaming mode val_ds repeats indefinitely, so stop at val_steps
            if self.use_streaming and batch_idx >= (self.val_steps or 0):
                break
            y_pred_list.append(
                np.argmax(self.model.predict(x_batch, verbose=0), axis=1)
            )
            y_true_list.append(y_batch.numpy())
        if not y_true_list:
            return  # no validation batches available
        y_true, y_pred = np.concatenate(y_true_list), np.concatenate(y_pred_list)
        kappa = float(cohen_kappa_score(y_true, y_pred))
        f1 = float(f1_score(y_true, y_pred, average="macro", zero_division=0))
        logs["val_kappa"], logs["val_f1_macro"] = kappa, f1
        if f1 > self.best_f1_macro:
            self.best_f1_macro, self.best_kappa, self.best_weights = (
                f1,
                kappa,
                self.model.get_weights(),
            )
            print(f" - val_f1_macro improved to {f1:.4f} (kappa={kappa:.4f})")
        else:
            print(f" - val_f1_macro: {f1:.4f} - val_kappa: {kappa:.4f}")

    def restore_best_weights(self):
        # best_weights is a list of arrays, so compare against None explicitly
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
            print(
                f"[OK] Restored best weights (F1={self.best_f1_macro:.4f}, Kappa={self.best_kappa:.4f})"
            )


class NaNDebugCallback(Callback):
    """Reports NaN/Inf values in the training loss."""

    def __init__(self):
        super().__init__()
        self.nan_count, self.last_valid_loss, self.first_nan_batch, self.diagnosed = (
            0,
            None,
            None,
            False,
        )

    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss", 0)
        if not (np.isnan(loss) or np.isinf(loss)):
            self.last_valid_loss = loss
            return
        self.nan_count += 1
        # Only report the first offending batch of each epoch
        if self.first_nan_batch is None:
            self.first_nan_batch = batch
            print(
                f"\n[WARN] NaN/Inf loss at batch {batch} (last valid={self.last_valid_loss})"
            )

    def on_epoch_end(self, epoch, logs=None):
        if self.nan_count > 0:
            print(f"   [INFO] Epoch {epoch+1}: {self.nan_count} batches with NaN/Inf loss")
        self.nan_count, self.first_nan_batch = 0, None


val_steps_for_callback = steps_per_epoch_val if USE_STREAMING else None
metrics_callback = SleepMetricsCallback(
    val_ds, val_steps_for_callback, STAGE_ORDER, USE_STREAMING, eval_every=1
)

callbacks = [
    NaNDebugCallback(),
    metrics_callback,
    EarlyStopping(
        monitor="val_f1_macro",
        mode="max",
        patience=CONFIG["early_stopping_patience"],
        restore_best_weights=False,
        verbose=1,
    ),
    ModelCheckpoint(
        filepath=f"{OUTPUT_PATH}/{model_name}_best.keras",
        monitor="val_f1_macro",
        mode="max",
        save_best_only=True,
        verbose=1,
    ),
]

print(f"Model: {model_name}")
print(f"Checkpoint: {OUTPUT_PATH}/{model_name}_best.keras")
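
Model selection above relies on macro F1, with Cohen's kappa tracked alongside it. For reference, both can be derived from a confusion matrix, as in the plain-Python sketch below. The helper name `kappa_and_macro_f1` and the toy labels are illustrative. Two simplifications are assumptions of this sketch: F1 is averaged over all `n_classes` (scoring 0 for a class that never appears), and kappa is reported as 0 when expected agreement is exactly 1.

```python
def kappa_and_macro_f1(y_true, y_pred, n_classes):
    """Cohen's kappa and macro-averaged F1 from two label sequences."""
    n = len(y_true)
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    row = [sum(r) for r in cm]  # true count per class (TP + FN)
    col = [sum(cm[i][j] for i in range(n_classes)) for j in range(n_classes)]
    # Observed agreement vs. agreement expected by chance
    p_obs = sum(cm[k][k] for k in range(n_classes)) / n
    p_exp = sum(row[k] * col[k] for k in range(n_classes)) / (n * n)
    kappa = (p_obs - p_exp) / (1.0 - p_exp) if p_exp < 1.0 else 0.0
    f1s = []
    for k in range(n_classes):
        denom = row[k] + col[k]  # equals 2*TP + FP + FN
        f1s.append(2 * cm[k][k] / denom if denom else 0.0)
    return kappa, sum(f1s) / n_classes


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 2, 2, 2]
kappa_demo, f1_demo = kappa_and_macro_f1(y_true, y_pred, n_classes=3)
print(f"kappa={kappa_demo:.4f}, macro F1={f1_demo:.4f}")
```

Macro F1 gives every stage the same weight, so a model that ignores a rare stage is penalized even if its overall accuracy stays high.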

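## Resuming from a Checkpoint (optional)

If the Kaggle session disconnected during training, you can resume from the last saved (best) checkpoint.
Checkpoint file names include the timestamp of the run that wrote them, so set `CHECKPOINT_NAME` explicitly.
Run the next cell only if you need to resume.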
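
If you do not remember the exact file name, one option is to pick the most recently modified checkpoint in the output directory. The helper below is a hypothetical convenience, not part of the notebook; it assumes checkpoints end in `_best.keras`, as the `ModelCheckpoint` callback above writes them.

```python
from pathlib import Path


def latest_checkpoint(output_dir, suffix="_best.keras"):
    """Return the file name of the newest checkpoint in output_dir, or None."""
    candidates = [p for p in Path(output_dir).glob(f"*{suffix}") if p.is_file()]
    if not candidates:
        return None
    # Newest by modification time
    return max(candidates, key=lambda p: p.stat().st_mtime).name


# Prints None when the current directory holds no matching checkpoint.
print(latest_checkpoint("."))
```

In the notebook this could be used as `CHECKPOINT_NAME = latest_checkpoint(OUTPUT_PATH)` before running the resume cell.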

In [None]:
# ============================================================
# RESUME FROM CHECKPOINT (run only if needed)
# ============================================================

RESUME_FROM_CHECKPOINT = False  # Set to True to resume
CHECKPOINT_NAME = None  # Example: "lstm_full_20251125_143022_best.keras"

if RESUME_FROM_CHECKPOINT and CHECKPOINT_NAME:
    checkpoint_path = f"{OUTPUT_PATH}/{CHECKPOINT_NAME}"
    if os.path.exists(checkpoint_path):
        print(f"[INFO] Loading model from checkpoint: {checkpoint_path}")
        with strategy.scope():
            model = keras.models.load_model(
                checkpoint_path,
                custom_objects={
                    "AttentionLayer": AttentionLayer,
                    # The LR schedule class is local to create_lr_schedule, so
                    # expose it by name for optimizer deserialization
                    "WarmupCosineDecay": type(lr_schedule),
                },
            )
        print("[OK] Model loaded successfully")
    else:
        print(f"[ERROR] Checkpoint not found: {checkpoint_path}")
        print("   Available .keras files:")
        for f in os.listdir(OUTPUT_PATH):
            if f.endswith(".keras"):
                print(f"   - {f}")
else:
    print("[INFO] Normal mode: the freshly built model will be used")

In [None]:
# ============================================================
# TRAIN MODEL
# ============================================================

print(f"\nStarting training ({CONFIG['execution_mode'].upper()} mode)...")
print(f"   Effective batch size: {CONFIG['effective_batch_size']}")
print(f"   Max epochs: {CONFIG['epochs']}")
print(f"   Train steps: {steps_per_epoch_train}")
if USE_STREAMING:
    print(f"   Val steps: {steps_per_epoch_val}")

training_start_time = time.time()

if CONFIG.get("use_class_weights", True) and class_weights:
    print(f"   Class weights: {class_weights}")
else:
    print("   Class weights: DISABLED")

if USE_STREAMING:
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        steps_per_epoch=steps_per_epoch_train,
        validation_steps=steps_per_epoch_val,
        epochs=CONFIG["epochs"],
        class_weight=class_weights if CONFIG.get("use_class_weights", True) else None,
        callbacks=callbacks,
        verbose=1,
    )
else:
    train_ds_repeat = train_ds.repeat()
    history = model.fit(
        train_ds_repeat,
        validation_data=val_ds,
        steps_per_epoch=steps_per_epoch_train,
        epochs=CONFIG["epochs"],
        class_weight=class_weights if CONFIG.get("use_class_weights", True) else None,
        callbacks=callbacks,
        verbose=1,
    )

training_time = time.time() - training_start_time
print(
    f"\n[OK] Training completed in {training_time / 60:.2f} min ({training_time / 3600:.2f} h)"
)

gc.collect()
metrics_callback.restore_best_weights()

In [None]:
# ============================================================
# PLOT LEARNING CURVES
# ============================================================

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

axes[0, 0].plot(history.history["loss"], label="Train Loss", linewidth=2)
axes[0, 0].plot(history.history["val_loss"], label="Val Loss", linewidth=2)
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Loss")
axes[0, 0].set_title("Loss during training")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

axes[0, 1].plot(history.history["accuracy"], label="Train Acc", linewidth=2)
axes[0, 1].plot(history.history["val_accuracy"], label="Val Acc", linewidth=2)
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Accuracy")
axes[0, 1].set_title("Accuracy during training")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

if "val_kappa" in history.history:
    axes[1, 0].plot(
        history.history["val_kappa"], label="Val Kappa", linewidth=2, color="green"
    )
    axes[1, 0].axhline(
        y=metrics_callback.best_kappa,
        color="red",
        linestyle="--",
        label=f"Best={metrics_callback.best_kappa:.4f}",
    )
    axes[1, 0].set_xlabel("Epoch")
    axes[1, 0].set_ylabel("Cohen's Kappa")
    axes[1, 0].set_title("Cohen's Kappa")
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
else:
    axes[1, 0].text(0.5, 0.5, "Kappa not available", ha="center", va="center")

if "val_f1_macro" in history.history:
    axes[1, 1].plot(
        history.history["val_f1_macro"],
        label="Val F1-Macro",
        linewidth=2,
        color="purple",
    )
    axes[1, 1].axhline(
        y=metrics_callback.best_f1_macro,
        color="red",
        linestyle="--",
        label=f"Best={metrics_callback.best_f1_macro:.4f}",
    )
    axes[1, 1].set_xlabel("Epoch")
    axes[1, 1].set_ylabel("F1-Macro")
    axes[1, 1].set_title("F1-Macro (model selection metric)")
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
else:
    axes[1, 1].text(0.5, 0.5, "F1-Macro not available", ha="center", va="center")

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_training_curves.png", dpi=150)
plt.show()

## Test Set Evaluation

In [None]:
# ============================================================
# TEST SET EVALUATION
# ============================================================

print("\nEvaluating on the test set...")


def collect_labels(ds):
    return np.concatenate(
        list(ds.map(lambda _x, y: y).unbatch().batch(1024).as_numpy_iterator())
    )


y_test_enc = collect_labels(test_ds).astype(int)
y_pred_logits = model.predict(test_ds, verbose=1)
y_pred_proba = tf.nn.softmax(y_pred_logits, axis=1).numpy()
y_pred_enc = np.argmax(y_pred_proba, axis=1)

accuracy = accuracy_score(y_test_enc, y_pred_enc)
kappa = cohen_kappa_score(y_test_enc, y_pred_enc)
f1_macro = f1_score(y_test_enc, y_pred_enc, average="macro", zero_division=0)
f1_weighted = f1_score(y_test_enc, y_pred_enc, average="weighted", zero_division=0)

print(f"\n{'=' * 50}")
print("TEST SET RESULTS")
print(f"{'=' * 50}")
print(f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)")
print(f"   Cohen Kappa: {kappa:.4f}  <- Primary metric in the literature")
print(f"   F1 Macro:    {f1_macro:.4f}")
print(f"   F1 Weighted: {f1_weighted:.4f}")
print(f"{'=' * 50}")

In [None]:
# ============================================================
# CLASSIFICATION REPORT
# ============================================================

print("\nClassification Report:")
print(
    classification_report(
        y_test_enc,
        y_pred_enc,
        labels=np.arange(len(STAGE_ORDER)),
        target_names=STAGE_ORDER,
        digits=4,
        zero_division=0,
    )
)

In [None]:
# ============================================================
# CONFUSION MATRIX
# ============================================================

cm = confusion_matrix(y_test_enc, y_pred_enc, labels=np.arange(len(STAGE_ORDER)))
# Row-normalize per true class; the floor of 1 avoids 0/0 (NaN) for a class
# that is absent from the test set
row_sums = cm.sum(axis=1, keepdims=True)
cm_normalized = cm.astype("float") / np.maximum(row_sums, 1)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[0],
)
axes[0].set_xlabel("Predicted")
axes[0].set_ylabel("True")
axes[0].set_title("Confusion Matrix (Counts)")

sns.heatmap(
    cm_normalized,
    annot=True,
    fmt=".2%",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[1],
)
axes[1].set_xlabel("Predicted")
axes[1].set_ylabel("True")
axes[1].set_title("Confusion Matrix (Row-Normalized)")

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_confusion_matrix.png", dpi=150)
plt.show()
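
Row normalization divides each row of the confusion matrix by the number of true samples of that class, so the diagonal reads as per-class recall. The plain-Python sketch below shows the operation, including the edge case of a class with no test samples. The helper name `normalize_rows` and the counts are illustrative.

```python
def normalize_rows(cm):
    """Divide each row by its sum; rows with no samples stay all zeros."""
    out = []
    for row in cm:
        total = sum(row)
        out.append([v / total if total else 0.0 for v in row])
    return out


# Toy 3-class matrix; the last class has no true samples.
cm_demo = [
    [8, 2, 0],
    [1, 3, 0],
    [0, 0, 0],
]
cm_demo_norm = normalize_rows(cm_demo)
for row in cm_demo_norm:
    print([f"{v:.2f}" for v in row])
```

Without the zero-row guard, the empty class would produce a division by zero and leave NaN cells in the heatmap.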

## Hyperparameter Optimization (Optuna)

In [None]:
# ============================================================
# OPTUNA OPTIMIZATION (optional)
# ============================================================

if CONFIG["run_optimization"]:
    # No search is run here in either mode; the flag only changes the message
    print("[WARN] Optuna search is not implemented in this notebook yet. Skipping.")
else:
    print(
        "[SKIP] Optimization disabled. Set CONFIG['run_optimization'] = True to enable it."
    )

## Save Model and Results

In [None]:
# ============================================================
# SAVE MODEL AND ARTIFACTS
# ============================================================

model.save(f"{OUTPUT_PATH}/{model_name}_final.keras")
print(f"[OK] Model saved: {OUTPUT_PATH}/{model_name}_final.keras")

history_data = {k: pd.Series(v) for k, v in history.history.items()}
history_df = pd.DataFrame(history_data)
history_df.to_csv(f"{OUTPUT_PATH}/{model_name}_history.csv", index=False)

results = {
    "model_name": model_name,
    "config": CONFIG,
    "metrics": {
        "accuracy": float(accuracy),
        "kappa": float(kappa),
        "f1_macro": float(f1_macro),
        "f1_weighted": float(f1_weighted),
    },
    "training": {
        "training_time_seconds": float(training_time),
        "training_time_minutes": float(training_time / 60),
        "epochs_trained": len(history.history["loss"]),
        "best_val_f1_macro": float(metrics_callback.best_f1_macro),
        "best_val_kappa": float(metrics_callback.best_kappa),
    },
    "dataset": {
        "train_samples": int(train_count),
        "val_samples": int(val_count),
        "test_samples": int(test_count),
        "tfrecord_paths": tfrecord_paths if USE_STREAMING else None,
        "execution_mode": CONFIG["execution_mode"],
    },
    "channel_stats": {"mean": mean_ch.tolist(), "std": std_ch.tolist()},
    "label_encoder_classes": list(le.classes_),
}

with open(f"{OUTPUT_PATH}/{model_name}_results.json", "w") as f:
    json.dump(results, f, indent=2, default=str)

with open(f"{OUTPUT_PATH}/{model_name}_label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)

print("\nGenerated files:")
print(f"   - {model_name}_final.keras (model)")
print(f"   - {model_name}_best.keras (best checkpoint)")
print(f"   - {model_name}_history.csv (training history)")
print(f"   - {model_name}_results.json (metrics and config)")
print(f"   - {model_name}_label_encoder.pkl (label encoder)")

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print("\n" + "=" * 60)
print("LSTM TRAINING COMPLETED")
print("=" * 60)
print("\nFinal results on the test set:")
print(f"   Accuracy:    {accuracy:.4f} ({accuracy * 100:.2f}%)")
print(f"   Cohen Kappa: {kappa:.4f}")
print(f"   F1 Macro:    {f1_macro:.4f}")
print(f"   F1 Weighted: {f1_weighted:.4f}")
print(f"\nModel saved in: {OUTPUT_PATH}")
print("=" * 60)

In [None]:
# ============================================================
# COMPRESS ARTIFACTS FOR DOWNLOAD
# ============================================================

zip_name = f"{model_name}_artifacts.zip"
zip_path = f"{OUTPUT_PATH}/{zip_name}"
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for fname in os.listdir(OUTPUT_PATH):
        # Skip the archive itself (it already exists in OUTPUT_PATH and shares
        # the model_name prefix), TFRecords (very large), and directories
        if fname == zip_name or fname.endswith(".tfrecord"):
            continue
        if not fname.startswith(model_name):
            continue
        fpath = os.path.join(OUTPUT_PATH, fname)
        if os.path.isfile(fpath):
            zf.write(fpath, arcname=fname)

print(f"[OK] Artifacts compressed to: {zip_path}")
print("[INFO] TFRecords excluded from the zip (very large)")
with zipfile.ZipFile(zip_path, "r") as zf:
    for info in zf.infolist():
        print(f" - {info.filename} ({info.file_size/1024:.1f} KB)")