# LSTM Bidireccional para Sleep Staging

Este notebook entrena un modelo **LSTM Bidireccional** para clasificación de estadios de sueño usando datos PSG **trimmed** (preprocesados y recortados del dataset Sleep-EDF).

**Optimizado para Kaggle con 2x Tesla T4 (16GB VRAM cada una)**

### Características:
- LSTM Bidireccional con mecanismo de atención opcional
- Soporte multi-GPU con MirroredStrategy
- Optimización de hiperparámetros con Optuna
- División train/val/test por sujetos (sin data leakage)
- **Soporte para reanudar entrenamiento desde checkpoints**

### Datos requeridos:
- Dataset `sleep-edf-trimmed-f32` en Kaggle (https://www.kaggle.com/datasets/ignaciolinari/sleep-edf-trimmed-f32)
  - `manifest_trimmed_spt.csv` (1 episodio por noche, 100 Hz)
  - `sleep_trimmed_spt/psg/*.fif` (PSG a 100 Hz, float32)
  - `sleep_trimmed_spt/hypnograms/*.csv` (anotaciones)
- Si usas la versión 200 Hz, apunta a `manifest_trimmed_resamp200.csv` + carpeta `sleep_trimmed_resamp200/`.
- Nota Kaggle: la versión 200 Hz puede consumir más VRAM; deja `sfreq=100` o usa el dataset 100 Hz para evitar OOM.
- Si usas otro slug, ajusta `DATA_PATH` en la celda siguiente.

In [None]:
# ============================================================
# CONFIGURACION INICIAL - Ejecutar primero
# ============================================================

import os
import warnings
from pathlib import Path

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings("ignore")

# The Kaggle runtime is recognised by the presence of its input mount point.
IN_KAGGLE = os.path.exists("/kaggle/input")
print(f"Entorno: {'Kaggle' if IN_KAGGLE else 'Local'}")

# Resolve data/output locations for the detected environment.
if not IN_KAGGLE:
    # Local run: paths are relative to the notebook's working directory.
    DATA_PATH = "../data/processed"
    OUTPUT_PATH = "../models"
else:
    base_input = Path("/kaggle/input")
    default_slug = base_input / "sleep-edf-trimmed-f32"
    if default_slug.exists():
        dataset_root = default_slug
    else:
        # The dataset was attached under another slug: take the first input
        # directory (in sorted order) whose name contains "sleep-edf".
        candidates = sorted(
            entry
            for entry in base_input.iterdir()
            if entry.is_dir() and "sleep-edf" in entry.name
        )
        if not candidates:
            raise FileNotFoundError(
                "No se encontró dataset sleep-edf* en /kaggle/input; ajusta DATA_PATH manualmente."
            )
        dataset_root = candidates[0]

    # DATA_PATH is the dataset root; sub-folders are resolved by later cells.
    DATA_PATH = str(dataset_root)
    OUTPUT_PATH = "/kaggle/working"

print(f"Data path: {DATA_PATH}")
print(f"Output path: {OUTPUT_PATH}")

In [None]:
# ============================================================
# VERIFICAR GPUs DISPONIBLES
# ============================================================

import tensorflow as tf  # noqa: E402

# Global seed for reproducibility
SEED = 42
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("\nGPUs disponibles:")

gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    print("[WARN] No se detectaron GPUs. El entrenamiento sera lento.")
else:
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")

    # Enable memory growth on every GPU (intended to avoid OOM)
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    print(f"\n[OK] {len(gpus)} GPU(s) configuradas con memory growth")

# Distribution strategy: mirror across devices only when several GPUs exist,
# otherwise fall back to TensorFlow's current default strategy.
if len(gpus) <= 1:
    strategy = tf.distribute.get_strategy()
    print("\nUsando estrategia por defecto (1 GPU o CPU)")
else:
    strategy = tf.distribute.MirroredStrategy()
    print(f"\nUsando MirroredStrategy con {strategy.num_replicas_in_sync} GPUs")

In [None]:
# ============================================================
# IMPORTS Y DEPENDENCIAS
# ============================================================

import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import seaborn as sns  # noqa: E402
from pathlib import Path  # noqa: E402
from collections import Counter  # noqa: E402
import logging  # noqa: E402
import json  # noqa: E402
from datetime import datetime  # noqa: E402
import pickle  # noqa: E402
import zipfile  # noqa: E402
import random  # noqa: E402
import re  # noqa: E402
import math  # noqa: E402
import gc  # noqa: E402

# Global seeds for reproducibility: NumPy, Python's `random`, and Keras.
# NOTE(review): tf.keras.utils.set_random_seed is documented to seed Python,
# NumPy and TensorFlow together, so the first two calls are probably
# redundant — confirm before removing them.
np.random.seed(SEED)
random.seed(SEED)
tf.keras.utils.set_random_seed(SEED)

from sklearn.preprocessing import LabelEncoder  # noqa: E402
from sklearn.utils.class_weight import compute_class_weight  # noqa: E402
from sklearn.metrics import (  # noqa: E402
    classification_report,
    confusion_matrix,
    cohen_kappa_score,
    accuracy_score,
    f1_score,
)

from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint  # noqa: E402

# Configure logging: INFO level with timestamp, level name and message.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Configure the plotting style (matplotlib style sheet + seaborn palette).
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

print("[OK] Imports completados")

In [None]:
# ============================================================
# INSTALAR DEPENDENCIAS (solo en Kaggle)
# ============================================================

# Import mne; if it is missing, install it on the fly when running on Kaggle
# and re-raise the ImportError everywhere else.
try:
    import mne
except ImportError:
    if IN_KAGGLE:
        print("Instalando dependencias...")
        # IPython magic: only valid inside a notebook cell, not in plain Python.
        %pip install -q mne
        print("[OK] Dependencias instaladas")
        import mne
    else:
        raise

# Restrict mne's own log output to errors.
mne.set_log_level("ERROR")
print(f"MNE version: {mne.__version__}")

## Configuracion del Experimento

In [None]:
# ============================================================
# HIPERPARAMETROS - AJUSTAR SEGUN NECESIDAD
# ============================================================

# Locate the first manifest that exists (priority: resamp200 -> spt -> trimmed).
manifest_candidates = [
    Path(DATA_PATH) / "manifest_trimmed_resamp200.csv",
    Path(DATA_PATH) / "manifest_trimmed_spt.csv",
    Path(DATA_PATH) / "manifest_trimmed.csv",
]
manifest_path = next(
    (str(candidate) for candidate in manifest_candidates if candidate.exists()),
    None,
)
if manifest_path is None:
    raise FileNotFoundError(
        f"No se encontró manifest_*.csv en {DATA_PATH}; ajusta DATA_PATH o el slug/version en Inputs."
    )

# Experiment configuration. Key order is preserved because the dict is printed
# below in insertion order.
CONFIG = dict(
    # Dataset (100 Hz)
    manifest_path=manifest_path,
    epoch_length=30.0,  # seconds
    sfreq=100,  # Hz (recordings are resampled to this rate when they differ)
    limit_sessions=None,  # None = all sessions, a number = cap for debugging
    # Subject-wise split
    test_size=0.15,
    val_size=0.15,
    random_state=42,
    # LSTM model
    lstm_units=128,  # LSTM units (halved in the second layer)
    dropout_rate=0.5,
    bidirectional=True,  # bidirectional LSTM
    use_attention=True,  # attention mechanism
    # Training
    learning_rate=0.001,
    batch_size=64,  # per-replica batch size; may be raised to 128 with 2x T4
    epochs=100,
    early_stopping_patience=15,
    reduce_lr_patience=7,
    # Hyper-parameter search (Optuna)
    run_optimization=False,  # set to True to search hyper-parameters
    n_optuna_trials=30,
    # TFRecord streaming (avoids OOM on Kaggle)
    streaming=bool(IN_KAGGLE),
    tfrecord_dir=f"{OUTPUT_PATH}/tfrecords_lstm",
    shuffle_buffer=5000,
)

# Scale the batch size by the number of replicas for multi-GPU training.
CONFIG["effective_batch_size"] = CONFIG["batch_size"] * strategy.num_replicas_in_sync

print("Configuracion del experimento:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

## Carga de Datos

In [None]:
# ============================================================
# FUNCIONES DE CARGA DE DATOS
# ============================================================

# Channels to load by default, grouped by signal type.
DEFAULT_CHANNELS = dict(
    EEG=["EEG Fpz-Cz", "EEG Pz-Oz"],
    EOG=["EOG horizontal"],
    EMG=["EMG submental"],
)

# Annotation description -> canonical stage label.
# Stages 3 and 4 are merged into N3.
STAGE_CANONICAL = dict(
    [
        ("Sleep stage W", "W"),
        ("Sleep stage 1", "N1"),
        ("Sleep stage 2", "N2"),
        ("Sleep stage 3", "N3"),
        ("Sleep stage 4", "N3"),
        ("Sleep stage R", "REM"),
    ]
)

# Canonical stages; the position in this list is the integer class label.
STAGE_ORDER = ["W", "N1", "N2", "N3", "REM"]


def extract_subject_core(subject_id):
    """Strip a trailing ``E<digits>`` suffix from a subject identifier.

    The identifier is converted to ``str``. If it ends in ``E``/``e`` followed
    by one or more digits, that suffix is removed (``SC4001E0`` -> ``SC4001``);
    otherwise the string is returned unchanged. The result is used as the
    grouping key for the subject-wise split.

    NOTE(review): in the original Sleep-EDF file naming (``SC4ssNE0``) the
    night number appears to be the digit *before* ``E``, so ``SC4001E0`` and
    ``SC4002E0`` may be two nights of the same person yet map to different
    cores here — confirm against the manifest's ``subject_id`` format.
    """
    text = str(subject_id)
    found = re.match(r"(.+?)[Ee]\d+$", text)
    if found is None:
        return text
    return found.group(1)


def load_psg_data(psg_path, channels=None, target_sfreq=None):
    """Load a PSG recording from a ``.fif`` file.

    Args:
        psg_path: Path of the ``.fif`` file.
        channels: Channel names to keep. When ``None``, every channel listed
            in ``DEFAULT_CHANNELS`` that is present in the file is kept, in
            ``DEFAULT_CHANNELS`` order.
        target_sfreq: If truthy and different from the file's sampling rate,
            the recording is resampled to this rate.

    Returns:
        Tuple ``(data, sfreq, ch_names)`` taken from the (possibly resampled)
        recording.
    """
    recording = mne.io.read_raw_fif(str(psg_path), preload=True, verbose="ERROR")

    if channels is None:
        present = set(recording.ch_names)
        channels = [
            name
            for group in DEFAULT_CHANNELS.values()
            for name in group
            if name in present
        ]

    recording.pick(channels)  # pick_channels() is deprecated

    needs_resampling = target_sfreq and recording.info["sfreq"] != target_sfreq
    if needs_resampling:
        recording.resample(target_sfreq)

    signal = recording.get_data()
    return signal, recording.info["sfreq"], recording.ch_names


def load_hypnogram(hyp_path):
    """Read a hypnogram CSV and add a ``stage_canonical`` column.

    The new column maps ``description`` through ``STAGE_CANONICAL``;
    descriptions without an entry there become NaN.
    """
    return pd.read_csv(hyp_path).assign(
        stage_canonical=lambda frame: frame["description"].map(STAGE_CANONICAL)
    )


def create_epochs(data, sfreq, epoch_length=30.0):
    """Split a continuous recording into consecutive fixed-length epochs.

    Args:
        data: 2-D array indexed as (channel, sample).
        sfreq: Sampling frequency in Hz.
        epoch_length: Epoch duration in seconds.

    Returns:
        Tuple ``(epochs, epoch_times)``. ``epochs`` has shape
        ``(n_epochs, n_channels, samples_per_epoch)`` and is an independent
        copy of the input samples; ``epoch_times`` holds the start time of
        each epoch in seconds. Trailing samples that do not fill a whole epoch
        are dropped. A recording shorter than one epoch yields an empty array
        that still has three dimensions.

    Raises:
        ValueError: If ``epoch_length * sfreq`` is smaller than one sample.
    """
    samples_per_epoch = int(epoch_length * sfreq)
    if samples_per_epoch <= 0:
        raise ValueError(
            "epoch_length * sfreq debe ser al menos 1 muestra "
            f"(epoch_length={epoch_length}, sfreq={sfreq})"
        )

    n_channels, n_samples = data.shape
    n_epochs = n_samples // samples_per_epoch

    # Drop the incomplete tail, then expose the epoch axis with a reshape
    # instead of building a Python list of slices.
    usable = data[:, : n_epochs * samples_per_epoch]
    epochs = (
        usable.reshape(n_channels, n_epochs, samples_per_epoch)
        .transpose(1, 0, 2)
        .copy()  # keep the result independent of `data`, as a list build would
    )
    epoch_times = np.arange(n_epochs) * epoch_length

    return epochs, epoch_times


def assign_stages(epoch_times, hypnogram, epoch_length=30.0):
    """Look up the sleep stage covering the centre of each epoch.

    Args:
        epoch_times: Epoch start times in seconds.
        hypnogram: DataFrame with ``onset``, ``duration`` and
            ``stage_canonical`` columns.
        epoch_length: Epoch duration in seconds.

    Returns:
        One entry per epoch: the ``stage_canonical`` value of the first
        annotation whose interval ``[onset, onset + duration)`` contains the
        epoch centre, or ``None`` when no annotation covers it.
    """
    half_epoch = epoch_length / 2

    def _stage_at(start):
        center = start + half_epoch
        covering = hypnogram[
            (hypnogram["onset"] <= center)
            & (hypnogram["onset"] + hypnogram["duration"] > center)
        ]
        if len(covering) > 0:
            return covering.iloc[0]["stage_canonical"]
        return None

    return [_stage_at(start) for start in epoch_times]


print("[OK] Funciones de carga definidas")

In [None]:
# ============================================================
# PIPELINE STREAMING CON TFRECORD (evita OOM en Kaggle)
# ============================================================


def resolve_paths(row, manifest_dir, dataset_dir_name):
    """Return ``(psg_path, hyp_path)`` for one manifest row.

    Resolution rules:

    * If the row lacks ``psg_trimmed_path`` or ``hypnogram_trimmed_path``
      (missing, empty or NaN), both paths are built from ``subject_id``,
      ``subset`` and ``version`` under ``manifest_dir / dataset_dir_name``.
    * On Kaggle, the manifest paths are re-rooted under ``DATA_PATH`` starting
      at their first component that begins with ``sleep_trimmed``; if either
      path has no such component, only the file names are kept and placed
      under ``DATA_PATH / dataset_dir_name``.
    * Locally, absolute paths are used as-is, paths starting with ``data/``
      are re-rooted at the parent of ``manifest_dir``, and any other relative
      path is taken relative to ``manifest_dir``.

    Args:
        row: Manifest row (anything supporting ``.get`` and ``[]``).
        manifest_dir: Directory containing the manifest (a ``Path``).
        dataset_dir_name: Name of the dataset sub-folder used for fallbacks.
    """

    def _blank_if_missing(value):
        # NaN (e.g. an empty CSV cell) is treated like an empty string.
        return "" if pd.isna(value) else value

    def _anchor(parts):
        # Index of the first path component starting with "sleep_trimmed".
        return next(
            (idx for idx, part in enumerate(parts) if part.startswith("sleep_trimmed")),
            None,
        )

    def _resolve_local(relative):
        if relative.is_absolute():
            return relative
        if relative.parts and relative.parts[0] == "data":
            return base_data_root / relative.relative_to("data")
        return manifest_dir / relative

    psg_path_str = _blank_if_missing(row.get("psg_trimmed_path", ""))
    hyp_path_str = _blank_if_missing(row.get("hypnogram_trimmed_path", ""))
    base_data_root = manifest_dir.parent

    if not (psg_path_str and hyp_path_str):
        # No usable explicit paths: derive the file names from the row.
        subset = row.get("subset", "sleep-cassette")
        version = row.get("version", "1.0.0")
        dataset_dir = manifest_dir / dataset_dir_name
        psg_path = (
            dataset_dir
            / "psg"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_raw.fif"
        )
        hyp_path = (
            dataset_dir
            / "hypnograms"
            / f"{row['subject_id']}_{subset}_{version}_trimmed_annotations.csv"
        )
        return psg_path, hyp_path

    psg_rel = Path(psg_path_str)
    hyp_rel = Path(hyp_path_str)

    if not IN_KAGGLE:
        return _resolve_local(psg_rel), _resolve_local(hyp_rel)

    psg_anchor_idx = _anchor(psg_rel.parts)
    hyp_anchor_idx = _anchor(hyp_rel.parts)

    if psg_anchor_idx is None or hyp_anchor_idx is None:
        # At least one path has no anchor: fall back to the bare file names.
        psg_path = Path(DATA_PATH) / dataset_dir_name / "psg" / psg_rel.name
        hyp_path = Path(DATA_PATH) / dataset_dir_name / "hypnograms" / hyp_rel.name
        return psg_path, hyp_path

    psg_path = Path(DATA_PATH) / Path(*psg_rel.parts[psg_anchor_idx:])
    hyp_path = Path(DATA_PATH) / Path(*hyp_rel.parts[hyp_anchor_idx:])
    return psg_path, hyp_path


def iter_sessions(manifest_path, epoch_length, sfreq, limit=None):
    """Yield one tuple per usable session with its resolved file paths.

    Only manifest rows whose ``status`` is ``"ok"`` are considered, optionally
    capped to the first ``limit`` rows. Sessions whose PSG or hypnogram file
    does not exist are reported and skipped.

    Args:
        manifest_path: Path of the manifest CSV.
        epoch_length: Not used by this function.
        sfreq: Not used by this function.
        limit: Maximum number of sessions to consider (falsy = all).

    Yields:
        ``(i, total, subject_id, subject_core, psg_path, hyp_path)`` where
        ``i`` is the 1-based position among the considered rows (skipped
        sessions still consume a position) and ``total`` is their count.
    """
    manifest = pd.read_csv(manifest_path)
    usable = manifest[manifest["status"] == "ok"].copy()

    if limit:
        usable = usable.head(limit)
        print(f"[WARN] Modo debug: procesando solo {limit} sesiones")

    manifest_dir = Path(manifest_path).parent

    # Pick the first known dataset folder that exists next to the manifest;
    # "sleep_trimmed" is the default when neither of the others is present.
    dataset_dir_name = "sleep_trimmed"
    for folder in ("sleep_trimmed_resamp200", "sleep_trimmed_spt"):
        if (manifest_dir / folder).exists():
            dataset_dir_name = folder
            break

    total_sessions = len(usable)
    print(f"\nProcesando {total_sessions} sesiones (streaming)...")

    for position, (_, row) in enumerate(usable.iterrows(), start=1):
        subject_id = row["subject_id"]
        psg_path, hyp_path = resolve_paths(row, manifest_dir, dataset_dir_name)

        if not (psg_path.exists() and hyp_path.exists()):
            print(f"[WARN] Archivos faltantes para {subject_id}; se omite esta sesion")
            continue

        yield (
            position,
            total_sessions,
            subject_id,
            extract_subject_core(subject_id),
            psg_path,
            hyp_path,
        )


def update_running_stats(stats, epochs):
    """Accumulate per-channel sums and sums of squares.

    Args:
        stats: Running state returned by a previous call, or ``None`` to start
            a new accumulation. The dict is updated in place.
        epochs: Array indexed as (epoch, channel, sample).

    Returns:
        The ``stats`` dict with keys ``"n"`` (samples seen per channel),
        ``"sum"`` and ``"sumsq"`` (float64 arrays, one entry per channel).
    """
    if stats is None:
        stats = {"n": 0, "sum": None, "sumsq": None}

    if stats["sum"] is None:
        n_channels = epochs.shape[1]
        stats["sum"] = np.zeros(n_channels, dtype=np.float64)
        stats["sumsq"] = np.zeros(n_channels, dtype=np.float64)

    # Force float64 for both the squaring and the reductions: accumulating in
    # the input dtype loses precision for float32 data and overflows for
    # integer data.
    channel_sum = epochs.sum(axis=(0, 2), dtype=np.float64)
    channel_sumsq = np.square(epochs, dtype=np.float64).sum(axis=(0, 2))

    stats["n"] += epochs.shape[0] * epochs.shape[2]
    stats["sum"] += channel_sum
    stats["sumsq"] += channel_sumsq
    return stats


def finalize_running_stats(stats):
    """Turn accumulated sums into per-channel mean and standard deviation.

    Args:
        stats: Dict with ``"n"``, ``"sum"`` and ``"sumsq"`` as produced by
            ``update_running_stats``.

    Returns:
        Tuple ``(mean, std)`` of float32 arrays. The variance is computed as
        ``E[x^2] - E[x]^2`` and clamped to at least ``1e-8`` before the square
        root, so ``std`` is never below ``1e-4``.
    """
    count = stats["n"]
    mean = stats["sum"] / count
    variance = stats["sumsq"] / count - mean**2
    # The clamp keeps sqrt away from tiny negative round-off values and
    # guarantees a non-zero divisor for later normalisation.
    clamped = np.maximum(variance, 1e-8)
    return mean.astype(np.float32), np.sqrt(clamped).astype(np.float32)


def pass1_stats(
    manifest_path,
    epoch_length,
    sfreq,
    limit=None,
    split_map=None,
    splits=("train",),
):
    """First pass over the data: per-channel mean/std and class counts.

    Args:
        manifest_path: Path of the manifest CSV.
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency passed to ``load_psg_data``.
        limit: Maximum number of sessions to consider (falsy = all).
        split_map: Optional mapping ``subject_core -> split name``. When
            given, only sessions whose split is listed in ``splits`` feed the
            statistics, which allows them to be computed from training
            subjects only. When ``None`` (default) every session is used.
        splits: Split names to include when ``split_map`` is given.

    Returns:
        Tuple ``(mean, std, class_counts, input_shape)`` where
        ``class_counts`` is a ``Counter`` of stage labels and ``input_shape``
        is ``(channels, samples)`` of one epoch.

    Raises:
        ValueError: If no epoch with a stage in ``STAGE_ORDER`` was found.
    """

    stats = None
    class_counts = Counter()
    input_shape = None
    allowed_splits = None if split_map is None else set(splits)

    for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
        manifest_path, epoch_length, sfreq, limit
    ):
        if (
            allowed_splits is not None
            and split_map.get(subject_core) not in allowed_splits
        ):
            continue

        data, actual_sfreq, _ = load_psg_data(psg_path, target_sfreq=sfreq)
        hypnogram = load_hypnogram(hyp_path)
        epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
        stages = assign_stages(epoch_times, hypnogram, epoch_length)

        # Keep only epochs labelled with one of the canonical stages.
        valid_mask = np.array([s in STAGE_ORDER for s in stages], dtype=bool)
        if not valid_mask.any():
            del data, epochs, epoch_times, stages, hypnogram
            gc.collect()
            continue

        valid_epochs = epochs[valid_mask]
        valid_stages = [s for s, keep in zip(stages, valid_mask) if keep]

        if input_shape is None and len(valid_epochs) > 0:
            input_shape = valid_epochs.shape[1:]  # (channels, samples)

        stats = update_running_stats(stats, valid_epochs)
        class_counts.update(valid_stages)

        if i % 20 == 0 or i == total:
            print(f"   Pasada 1: {i}/{total} sesiones")

        # Release the per-session arrays before loading the next recording.
        del data, epochs, epoch_times, stages, hypnogram, valid_epochs, valid_stages
        gc.collect()

    # Explicit raise instead of `assert`: assertions are stripped under -O,
    # which would let `stats=None` reach finalize_running_stats.
    if input_shape is None:
        raise ValueError("No se encontraron epochs validos")

    mean, std = finalize_running_stats(stats)
    return mean, std, class_counts, input_shape


def make_example(x, y):
    """Serialize one sample as a ``tf.train.Example``.

    The array ``x`` is flattened into the float feature ``"x"`` and the
    integer label ``y`` is stored as the single-element feature ``"y"``.

    Returns:
        The serialized example (bytes).
    """
    signal = tf.train.Feature(float_list=tf.train.FloatList(value=x.ravel()))
    label = tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))
    features = tf.train.Features(feature={"x": signal, "y": label})
    return tf.train.Example(features=features).SerializeToString()


def assign_subject_splits(manifest_path, test_size, val_size, random_state, limit=None):
    """Assign every subject core to train, val or test.

    Subject cores (see ``extract_subject_core``) of the manifest rows with
    ``status == "ok"`` are shuffled with a generator seeded by
    ``random_state``; the first block becomes test, the next one val and the
    remainder train. Test and val each receive at least one core, so with
    very few subjects the train (and possibly val) set ends up empty.

    Args:
        manifest_path: Path of the manifest CSV.
        test_size: Fraction of subject cores assigned to test.
        val_size: Fraction of subject cores assigned to val.
        random_state: Seed for the shuffle.
        limit: Maximum number of manifest rows to consider (falsy = all).

    Returns:
        Tuple ``(split_map, counts)``: ``split_map`` maps each core to its
        split name and ``counts`` gives the number of cores per split.
    """
    manifest = pd.read_csv(manifest_path)
    usable = manifest[manifest["status"] == "ok"].copy()
    if limit:
        usable = usable.head(limit)

    cores = usable["subject_id"].apply(extract_subject_core).unique()
    np.random.default_rng(random_state).shuffle(cores)

    n_cores = len(cores)
    n_test = max(1, int(n_cores * test_size))
    n_val = max(1, int(n_cores * val_size))
    val_end = n_test + n_val

    groups = {
        "train": set(cores[val_end:]),
        "val": set(cores[n_test:val_end]),
        "test": set(cores[:n_test]),
    }

    split_map = {}
    for split_name, members in groups.items():
        for core in members:
            split_map[core] = split_name

    return split_map, {name: len(members) for name, members in groups.items()}


def build_tfrecord_splits(
    manifest_path,
    mean,
    std,
    split_map,
    epoch_length,
    sfreq,
    tfrecord_dir,
    limit=None,
):
    """Second pass: write normalised epochs to one TFRecord file per split.

    Each epoch labelled with a stage in ``STAGE_ORDER`` is normalised per
    channel with ``mean``/``std``, cast to float32 and written, together with
    its integer label (index in ``STAGE_ORDER``), to the file of the split
    its subject core belongs to. Sessions whose core is missing from
    ``split_map`` are skipped.

    Args:
        manifest_path: Path of the manifest CSV.
        mean: Per-channel mean (1-D array).
        std: Per-channel standard deviation (1-D array).
        split_map: Mapping ``subject_core -> "train" | "val" | "test"``.
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency passed to ``load_psg_data``.
        tfrecord_dir: Output directory (created if needed).
        limit: Maximum number of sessions to consider (falsy = all).

    Returns:
        Tuple ``(tf_paths, counts)``: file path per split and a ``Counter``
        with the number of epochs written per split.
    """
    tfrecord_dir = Path(tfrecord_dir)
    tfrecord_dir.mkdir(parents=True, exist_ok=True)

    tf_paths = {
        name: str(tfrecord_dir / f"{name}.tfrecord")
        for name in ("train", "val", "test")
    }
    counts = Counter()

    # Loop invariants: label lookup table and broadcastable statistics.
    stage_to_label = {stage: idx for idx, stage in enumerate(STAGE_ORDER)}
    mean_col = mean[:, None]
    std_col = std[:, None]

    writers = {}
    try:
        for name, path in tf_paths.items():
            writers[name] = tf.io.TFRecordWriter(path)

        for i, total, subject_id, subject_core, psg_path, hyp_path in iter_sessions(
            manifest_path, epoch_length, sfreq, limit
        ):
            split = split_map.get(subject_core)
            if split is None:
                continue

            data, actual_sfreq, _ = load_psg_data(psg_path, target_sfreq=sfreq)
            hypnogram = load_hypnogram(hyp_path)
            epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
            stages = assign_stages(epoch_times, hypnogram, epoch_length)

            for epoch, stage in zip(epochs, stages):
                y = stage_to_label.get(stage)
                if y is None:
                    # Unlabelled epoch or a stage outside STAGE_ORDER.
                    continue
                x = (epoch - mean_col) / std_col
                writers[split].write(make_example(x.astype(np.float32), y))
                counts[split] += 1

            if i % 20 == 0 or i == total:
                print(f"   Pasada 2 ({split}): {i}/{total} sesiones")

            # Release the per-session arrays before loading the next recording.
            del data, epochs, epoch_times, stages, hypnogram
            gc.collect()
    finally:
        # Always flush and close the files, even if a session fails midway.
        for writer in writers.values():
            writer.close()

    return tf_paths, counts


def make_dataset(tfrecord_path, input_shape_ch_first, batch_size, shuffle=False):
    """Build a batched ``tf.data`` pipeline from one TFRecord file.

    Each record is parsed into ``(x, y)`` where ``x`` is reshaped to
    ``input_shape_ch_first`` = (channels, samples) and then transposed to
    (samples, channels), the (timesteps, features) layout the LSTM consumes.

    Args:
        tfrecord_path: Path of the TFRecord file.
        input_shape_ch_first: ``(channels, samples)`` of one stored epoch.
        batch_size: Batch size.
        shuffle: If true, shuffle with ``CONFIG["shuffle_buffer"]`` and seed
            ``CONFIG["random_state"]``, reshuffling on every iteration.

    Returns:
        A batched, prefetching ``tf.data.Dataset`` (single pass, not repeated).
    """
    n_channels, n_samples = input_shape_ch_first[0], input_shape_ch_first[1]
    schema = {
        "x": tf.io.FixedLenFeature([n_channels * n_samples], tf.float32),
        "y": tf.io.FixedLenFeature([], tf.int64),
    }

    def _decode(serialized):
        parsed = tf.io.parse_single_example(serialized, schema)
        signal = tf.reshape(parsed["x"], input_shape_ch_first)  # (channels, samples)
        # The LSTM expects (timesteps, features) = (samples, channels).
        return tf.transpose(signal, perm=[1, 0]), parsed["y"]

    records = tf.data.TFRecordDataset(
        [tfrecord_path], num_parallel_reads=tf.data.AUTOTUNE
    )
    pipeline = records.map(_decode, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        pipeline = pipeline.shuffle(
            CONFIG["shuffle_buffer"],
            seed=CONFIG["random_state"],
            reshuffle_each_iteration=True,
        )
    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)


if CONFIG["streaming"]:
    print("\n[INFO] Modo streaming activado (TFRecord)")

    # Pass 1: global per-channel statistics and class counts.
    # NOTE(review): this pass covers every session in the manifest and the
    # subject split is only assigned further down, so mean/std and the class
    # weights are computed from train, val and test sessions alike. Consider
    # restricting the statistics to training subjects.
    mean_ch, std_ch, class_counts, input_shape_ch_first = pass1_stats(
        CONFIG["manifest_path"],
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        limit=CONFIG["limit_sessions"],
    )

    print("\n[OK] Estadisticas globales calculadas")
    print(f"   mean shape: {mean_ch.shape}")
    print(f"   std  shape: {std_ch.shape}")

    # Class weights (a missing class is counted as 1 to avoid division by zero)
    counts_list = [class_counts.get(stage, 0) for stage in STAGE_ORDER]
    if any(c == 0 for c in counts_list):
        print(
            "[WARN] Alguna clase no apareció en la pasada 1; ajustando conteos mínimos a 1"
        )
        counts_list = [max(c, 1) for c in counts_list]

    # Rebuild a label vector with the observed class frequencies so that
    # compute_class_weight can derive the "balanced" weights from it.
    y_for_weights = np.repeat(np.arange(len(STAGE_ORDER)), counts_list)
    class_weights_arr = compute_class_weight(
        "balanced", classes=np.arange(len(STAGE_ORDER)), y=y_for_weights
    )
    class_weights = dict(enumerate(class_weights_arr))
    print(f"Class weights: {class_weights}")

    # Assign splits by subject and write the TFRecords
    split_map, split_subjects = assign_subject_splits(
        CONFIG["manifest_path"],
        CONFIG["test_size"],
        CONFIG["val_size"],
        CONFIG["random_state"],
        limit=CONFIG["limit_sessions"],
    )

    tfrecord_paths, split_counts = build_tfrecord_splits(
        CONFIG["manifest_path"],
        mean_ch,
        std_ch,
        split_map,
        CONFIG["epoch_length"],
        CONFIG["sfreq"],
        CONFIG["tfrecord_dir"],
        limit=CONFIG["limit_sessions"],
    )

    print("\n[OK] TFRecords generados:")
    for split, path in tfrecord_paths.items():
        print(f"   {split}: {path} ({split_counts[split]:,} epochs)")

    # NOTE(review): LabelEncoder stores its classes in sorted order
    # (N1, N2, N3, REM, W) whereas the labels written to the TFRecords are
    # indices into STAGE_ORDER (W, N1, N2, N3, REM). Decoding TFRecord labels
    # with `le.inverse_transform` therefore does not return the true stage
    # names — verify every consumer of `le`.
    le = LabelEncoder()
    le.fit(STAGE_ORDER)

    INPUT_SHAPE = (
        input_shape_ch_first[1],
        input_shape_ch_first[0],
    )  # (timesteps, channels)

    train_ds = make_dataset(
        tfrecord_paths["train"],
        input_shape_ch_first,
        CONFIG["effective_batch_size"],
        shuffle=True,
    )
    val_ds = make_dataset(
        tfrecord_paths["val"],
        input_shape_ch_first,
        CONFIG["effective_batch_size"],
        shuffle=False,
    )
    test_ds = make_dataset(
        tfrecord_paths["test"],
        input_shape_ch_first,
        CONFIG["effective_batch_size"],
        shuffle=False,
    )

    train_count = split_counts.get("train", 0)
    val_count = split_counts.get("val", 0)
    test_count = split_counts.get("test", 0)

    print("\n[CHECK] Resumen dataset:")
    print(f"   Input shape (timesteps, channels): {INPUT_SHAPE}")
    print(f"   Train epochs: {train_count:,} (sujetos: {split_subjects['train']})")
    print(f"   Val epochs:   {val_count:,} (sujetos: {split_subjects['val']})")
    print(f"   Test epochs:  {test_count:,} (sujetos: {split_subjects['test']})")
else:
    # This branch runs when streaming is already disabled, so the message must
    # tell the user to enable it: only the TFRecord pipeline exists here.
    raise NotImplementedError(
        "El pipeline en memoria no esta implementado en este notebook; "
        "usa CONFIG['streaming'] = True"
    )

In [None]:
# ============================================================
# DIVISION TRAIN/VAL/TEST POR SUJETOS
# ============================================================

if CONFIG["streaming"]:
    # Nothing to do here: the subject-wise split already happened while the
    # TFRecords were written. Only report the resulting sizes.
    print("[SKIP] División manejada en la escritura de TFRecords (por sujeto).")
    print(f"   Train epochs: {train_count:,}")
    print(f"   Val epochs:   {val_count:,}")
    print(f"   Test epochs:  {test_count:,}")
else:
    # Reached only when streaming is disabled; the in-memory split is not
    # implemented, so point the user at the streaming mode.
    raise NotImplementedError(
        "La division en memoria no esta implementada; usa CONFIG['streaming'] = True"
    )

In [None]:
# ============================================================
# NORMALIZACION Y PREPARACION PARA LSTM
# ============================================================

if CONFIG["streaming"]:
    # Nothing to do here: epochs were normalised when the TFRecords were written.
    print("[SKIP] Normalización aplicada al escribir TFRecords (float32 normalizado).")
    print(f"Clases: {STAGE_ORDER}")
else:
    # Reached only when streaming is disabled; the in-memory normalisation is
    # not implemented, so point the user at the streaming mode.
    raise NotImplementedError(
        "La normalizacion en memoria no esta implementada; "
        "usa CONFIG['streaming'] = True"
    )

## Arquitectura LSTM Bidireccional

In [None]:
# ============================================================
# CAPA DE ATENCION PERSONALIZADA
# ============================================================


class AttentionLayer(layers.Layer):
    """Simple attention pooling over the time axis of a sequence.

    For an input of shape (batch, timesteps, features) the layer computes one
    score per timestep, ``tanh(x . W + b)``, turns the scores into weights
    with a softmax across timesteps and returns the weighted sum of the
    inputs, shape (batch, features).

    The bias holds one value per timestep, so the layer is tied to the
    sequence length it was built with.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        n_timesteps, n_features = input_shape[1], input_shape[-1]
        # Projection from the feature vector to a scalar score.
        self.W = self.add_weight(
            name="attention_weight",
            shape=(n_features, 1),
            initializer="glorot_uniform",
            trainable=True,
        )
        # One additive bias per timestep.
        self.b = self.add_weight(
            name="attention_bias",
            shape=(n_timesteps, 1),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # x: (batch, timesteps, features)
        backend = keras.backend
        scores = backend.tanh(backend.dot(x, self.W) + self.b)
        # Normalise the scores across the time axis.
        weights = backend.softmax(scores, axis=1)
        # Weighted sum over timesteps.
        return backend.sum(x * weights, axis=1)

    def compute_output_shape(self, input_shape):
        batch_dim, feature_dim = input_shape[0], input_shape[-1]
        return (batch_dim, feature_dim)

    def get_config(self):
        # The layer has no constructor arguments of its own to serialise.
        return super().get_config()


print("[OK] Capa de Atencion definida")

In [None]:
# ============================================================
# MODELO LSTM BIDIRECCIONAL CON ATENCION
# ============================================================


def build_lstm_model(
    input_shape,
    n_classes=5,
    lstm_units=128,
    dropout_rate=0.5,
    learning_rate=0.001,
    bidirectional=True,
    use_attention=True,
):
    """Build and compile the (Bi)LSTM sleep-staging classifier.

    Architecture:
    - two LSTM layers (``lstm_units`` and ``lstm_units // 2`` units), each
      optionally wrapped in ``Bidirectional`` and followed by BatchNorm and
      Dropout;
    - an optional ``AttentionLayer`` that pools the sequence (when attention
      is disabled the second LSTM returns only its last output instead);
    - two dense ReLU layers (128 and 64 units) with Dropout;
    - a softmax output with ``n_classes`` units.

    All LSTM and hidden dense kernels use L2 regularisation (1e-4). The model
    is compiled with Adam, sparse categorical cross-entropy and accuracy.

    NOTE: recurrent_dropout is deliberately not used, to stay compatible with
    the cuDNN kernels.

    Args:
        input_shape: ``(timesteps, features)`` of one sample.
        n_classes: Number of output classes.
        lstm_units: Units of the first LSTM layer.
        dropout_rate: Dropout rate used after every block.
        learning_rate: Adam learning rate.
        bidirectional: Wrap the LSTM layers in ``Bidirectional``.
        use_attention: Pool the sequence with ``AttentionLayer``.

    Returns:
        The compiled ``keras.Model``.
    """

    def _l2():
        # A fresh regulariser instance per layer.
        return keras.regularizers.l2(1e-4)

    def _maybe_bidirectional(lstm_layer, wrapper_name):
        if bidirectional:
            return layers.Bidirectional(lstm_layer, name=wrapper_name)
        return lstm_layer

    # Input: (timesteps, features)
    inputs = keras.Input(shape=input_shape, name="input")

    # First recurrent block (returns the full sequence)
    first_lstm = layers.LSTM(
        lstm_units,
        return_sequences=True,
        kernel_regularizer=_l2(),
        name="lstm_1",
    )
    x = _maybe_bidirectional(first_lstm, "bidirectional_1")(inputs)
    x = layers.BatchNormalization(name="bn_1")(x)
    x = layers.Dropout(dropout_rate, name="dropout_lstm_1")(x)

    # Second recurrent block; keeps the sequence only when attention pools it
    second_lstm = layers.LSTM(
        lstm_units // 2,
        return_sequences=use_attention,
        kernel_regularizer=_l2(),
        name="lstm_2",
    )
    x = _maybe_bidirectional(second_lstm, "bidirectional_2")(x)
    x = layers.BatchNormalization(name="bn_2")(x)
    x = layers.Dropout(dropout_rate, name="dropout_lstm_2")(x)

    # Attention pooling (optional)
    if use_attention:
        x = AttentionLayer(name="attention")(x)

    # Dense classification head
    x = layers.Dense(
        128, activation="relu", kernel_regularizer=_l2(), name="dense_1"
    )(x)
    x = layers.Dropout(dropout_rate, name="dropout_1")(x)
    x = layers.Dense(
        64, activation="relu", kernel_regularizer=_l2(), name="dense_2"
    )(x)
    x = layers.Dropout(dropout_rate, name="dropout_2")(x)

    outputs = layers.Dense(n_classes, activation="softmax", name="output")(x)

    # Model name reflects the chosen configuration
    name_parts = ["BiLSTM" if bidirectional else "LSTM"]
    if use_attention:
        name_parts.append("_Attention")
    name_parts.append("_SleepStaging")

    model = keras.Model(inputs=inputs, outputs=outputs, name="".join(name_parts))

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    return model


print("[OK] Arquitectura LSTM definida")

In [None]:
# ============================================================
# CREAR MODELO (con estrategia multi-GPU si está disponible)
# ============================================================

# Input shape: (timesteps, features)
input_shape = INPUT_SHAPE
print(f"Input shape: {input_shape}")

# Collect the hyper-parameters first so the strategy scope only wraps the
# actual model construction.
model_kwargs = dict(
    input_shape=input_shape,
    n_classes=len(STAGE_ORDER),
    lstm_units=CONFIG["lstm_units"],
    dropout_rate=CONFIG["dropout_rate"],
    learning_rate=CONFIG["learning_rate"],
    bidirectional=CONFIG["bidirectional"],
    use_attention=CONFIG["use_attention"],
)

# Variables must be created inside the strategy scope for multi-GPU training.
with strategy.scope():
    model = build_lstm_model(**model_kwargs)

model.summary()

## Entrenamiento

In [None]:
# ============================================================
# CALLBACKS
# ============================================================

# Timestamped run name so successive runs do not overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"lstm_{timestamp}"

# Stop when val_loss stops improving and roll back to the best weights.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=CONFIG["early_stopping_patience"],
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate on a val_loss plateau, down to 1e-6.
reduce_lr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=CONFIG["reduce_lr_patience"],
    min_lr=1e-6,
    verbose=1,
)

# Keep only the model with the best val_loss on disk.
checkpoint = ModelCheckpoint(
    filepath=f"{OUTPUT_PATH}/{model_name}_best.keras",
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)

callbacks = [early_stopping, reduce_lr, checkpoint]

print(f"Modelo se guardara en: {OUTPUT_PATH}/{model_name}_best.keras")

## Reanudacion desde Checkpoint (opcional)

Si Kaggle se desconectó durante el entrenamiento, puedes reanudar desde el último checkpoint.
Solo ejecuta la siguiente celda si necesitas reanudar.

In [None]:
# ============================================================
# REANUDAR DESDE CHECKPOINT (ejecutar solo si es necesario)
# ============================================================
# Descomenta y ajusta el nombre del checkpoint si necesitas reanudar

RESUME_FROM_CHECKPOINT = False  # Set to True to resume
CHECKPOINT_NAME = None  # Example: "lstm_20251125_143022_best.keras"

if not (RESUME_FROM_CHECKPOINT and CHECKPOINT_NAME):
    # Resuming needs both the flag and a checkpoint name; otherwise the
    # freshly built model is kept.
    print("[INFO] Modo normal: se usara el modelo recien creado")
    print("[INFO] Para reanudar desde checkpoint, cambia RESUME_FROM_CHECKPOINT=True")
else:
    checkpoint_path = f"{OUTPUT_PATH}/{CHECKPOINT_NAME}"
    if not os.path.exists(checkpoint_path):
        # Missing file: report it and list the checkpoints that do exist.
        # The freshly built model stays in place.
        print(f"[ERROR] Checkpoint no encontrado: {checkpoint_path}")
        print(f"[INFO] Archivos disponibles en {OUTPUT_PATH}:")
        for f in os.listdir(OUTPUT_PATH):
            if f.endswith(".keras"):
                print(f"   - {f}")
    else:
        print(f"[INFO] Cargando modelo desde checkpoint: {checkpoint_path}")
        # The custom attention layer must be passed in so Keras can
        # deserialize it; loading happens inside the strategy scope.
        with strategy.scope():
            model = keras.models.load_model(
                checkpoint_path, custom_objects={"AttentionLayer": AttentionLayer}
            )
        print("[OK] Modelo cargado exitosamente")
        print("[INFO] Puedes continuar el entrenamiento ejecutando la celda de fit()")
        print("[INFO] El modelo continuara desde donde quedo")

In [None]:
# ============================================================
# ENTRENAR MODELO
# ============================================================

# Fail fast, before any logging, when a split ended up empty.
if train_count == 0 or val_count == 0:
    raise ValueError(
        f"Dataset vacio: train={train_count}, val={val_count}. Revisa manifest en {CONFIG['manifest_path']} y DATA_PATH={DATA_PATH}"
    )

print("\nIniciando entrenamiento LSTM (streaming tf.data)...")
print(f"   Batch size efectivo: {CONFIG['effective_batch_size']}")
print(f"   Epochs maximos: {CONFIG['epochs']}")
print(f"   Steps train: {math.ceil(train_count / CONFIG['effective_batch_size'])}")
print(f"   Steps val:   {math.ceil(val_count / CONFIG['effective_batch_size'])}")

# steps_per_epoch / validation_steps are intentionally NOT passed. The
# datasets are finite, non-repeating TFRecord pipelines whose cardinality is
# unknown to Keras; with an explicit step count Keras keeps a single iterator
# across epochs, so the input would run out after the first epoch and
# training would be interrupted. Without them every epoch is one full pass.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=CONFIG["epochs"],
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

print("\n[OK] Entrenamiento completado")

In [None]:
# ============================================================
# VISUALIZAR CURVAS DE APRENDIZAJE
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric:
# (train key, val key, train label, val label, y-axis label, title)
panels = [
    ("loss", "val_loss", "Train Loss", "Val Loss", "Loss", "Loss durante entrenamiento"),
    (
        "accuracy",
        "val_accuracy",
        "Train Acc",
        "Val Acc",
        "Accuracy",
        "Accuracy durante entrenamiento",
    ),
]

for ax, (train_key, val_key, train_label, val_label, y_label, title) in zip(
    axes, panels
):
    ax.plot(history.history[train_key], label=train_label, linewidth=2)
    ax.plot(history.history[val_key], label=val_label, linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_training_curves.png", dpi=150)
plt.show()

## Evaluacion en Test

In [None]:
# ============================================================
# EVALUACION EN TEST SET
# ============================================================

print("\nEvaluando en Test Set (tf.data)...")


def collect_labels(ds):
    """Return all labels of a dataset as a single concatenated NumPy array.

    Args:
        ds: Dataset whose elements are (features, label) pairs. The labels
            must carry a leading batch dimension, since they are unbatched
            before being regrouped.

    Returns:
        One NumPy array holding every label, in the dataset's iteration order.
    """
    # Drop the features so only labels flow through the pipeline.
    labels_only = ds.map(lambda _features, label: label)
    # Regroup into chunks of up to 1024 labels, independent of the original
    # batch size, before pulling them out as NumPy arrays.
    label_chunks = list(labels_only.unbatch().batch(1024).as_numpy_iterator())
    return np.concatenate(label_chunks)


# Ground-truth labels: encoded integers from the dataset, decoded back to
# stage names with the fitted label encoder.
y_test_enc = collect_labels(test_ds)
y_test = le.inverse_transform(y_test_enc.astype(int))

# Predictions: take the highest-scoring class per sample and decode it.
# NOTE(review): labels and predictions come from two separate passes over
# test_ds, so this assumes test_ds yields samples in the same order each
# time (no shuffling) — confirm against the dataset builder.
y_pred_proba = model.predict(test_ds, verbose=1)
y_pred_enc = np.argmax(y_pred_proba, axis=1)
y_pred = le.inverse_transform(y_pred_enc)

# Metrics computed on the decoded stage names.
accuracy = accuracy_score(y_test, y_pred)
kappa = cohen_kappa_score(y_test, y_pred)
f1_macro = f1_score(y_test, y_pred, average="macro")
f1_weighted = f1_score(y_test, y_pred, average="weighted")

# Print the summary framed by 50-character rules.
report_rule = "=" * 50
report_lines = [
    f"\n{report_rule}",
    "RESULTADOS EN TEST SET",
    f"{report_rule}",
    f"   Accuracy:    {accuracy:.4f} ({accuracy*100:.2f}%)",
    f"   Cohen Kappa: {kappa:.4f}",
    f"   F1 Macro:    {f1_macro:.4f}",
    f"   F1 Weighted: {f1_weighted:.4f}",
    f"{report_rule}",
]
for report_line in report_lines:
    print(report_line)

In [None]:
# ============================================================
# CLASSIFICATION REPORT
# ============================================================

# Per-stage precision / recall / F1 on the test set.
#
# `labels` is passed together with `target_names` so that each report row is
# tied to the stage it names. Without `labels`, scikit-learn orders the rows
# by the sorted unique labels found in the data and applies `target_names`
# positionally, which mislabels rows whenever STAGE_ORDER is not in sorted
# order (and fails when a stage is absent from the test split). This also
# matches the confusion-matrix cell, which already uses labels=STAGE_ORDER.
print("\nClassification Report:")
print(
    classification_report(
        y_test,
        y_pred,
        labels=STAGE_ORDER,
        target_names=STAGE_ORDER,
        digits=4,
    )
)

In [None]:
# ============================================================
# MATRIZ DE CONFUSIÓN
# ============================================================

# Confusion matrix with rows = true stage and columns = predicted stage,
# both ordered as STAGE_ORDER.
cm = confusion_matrix(y_test, y_pred, labels=STAGE_ORDER)

# Row-normalize so each row shows how the true samples of one stage were
# distributed across predictions. A stage with no true samples has a row total
# of 0; the guarded division leaves that row at 0.0 instead of producing NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Absolute counts
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[0],
)
axes[0].set_xlabel("Predicho")
axes[0].set_ylabel("Real")
axes[0].set_title("Matriz de Confusión (Absoluta)")

# Row-normalized fractions, rendered as percentages
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt=".2%",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[1],
)
axes[1].set_xlabel("Predicho")
axes[1].set_ylabel("Real")
axes[1].set_title("Matriz de Confusión (Normalizada)")

# Save the figure next to the other model artifacts before displaying it.
plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_confusion_matrix.png", dpi=150)
plt.show()

## Optimización de Hiperparámetros (Optuna)

In [None]:
# ============================================================
# OPTIMIZACION CON OPTUNA (opcional)
# ============================================================

# Optuna is not supported in streaming mode. If you need tuning, disable streaming.
# When CONFIG["run_optimization"] is truthy this cell only prints a warning;
# otherwise it prints a hint on how to enable the option. No study is run in
# either branch.
if CONFIG["run_optimization"]:
    print("[WARN] Optuna no está implementado para el modo streaming TFRecord. Skip.")

    # The triple-quoted string below is a bare expression statement: it is
    # evaluated and discarded, so none of the code inside it ever runs.
    # NOTE(review): the text reads like the body of an Optuna objective (it uses
    # `trial` and ends with `return kappa`), but no `def objective(trial):`
    # header is included in the string, and it refers to in-memory arrays
    # (X_train_lstm, y_train_enc, X_val_lstm, y_val_enc) — presumably from a
    # non-streaming version of this notebook; confirm before re-enabling.
    """
        # Hiperparametros a optimizar
        lstm_units = trial.suggest_categorical("lstm_units", [64, 128, 256])
        dropout_rate = trial.suggest_float("dropout_rate", 0.3, 0.6)
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        bidirectional = trial.suggest_categorical("bidirectional", [True, False])
        use_attention = trial.suggest_categorical("use_attention", [True, False])

        # Crear modelo
        with strategy.scope():
            trial_model = build_lstm_model(
                input_shape=input_shape,
                n_classes=len(STAGE_ORDER),
                lstm_units=lstm_units,
                dropout_rate=dropout_rate,
                learning_rate=learning_rate,
                bidirectional=bidirectional,
                use_attention=use_attention,
            )

        # Callbacks
        trial_callbacks = [
            EarlyStopping(monitor="val_loss", patience=7, restore_best_weights=True),
            TFKerasPruningCallback(trial, "val_loss"),
        ]

        # Entrenar
        trial_model.fit(
            X_train_lstm,
            y_train_enc,
            validation_data=(X_val_lstm, y_val_enc),
            batch_size=batch_size * strategy.num_replicas_in_sync,
            epochs=30,
            class_weight=class_weights,
            callbacks=trial_callbacks,
            verbose=0,
        )

        # Evaluar en validacion
        y_val_pred = np.argmax(trial_model.predict(X_val_lstm, verbose=0), axis=1)
        kappa = cohen_kappa_score(y_val_enc, y_val_pred)

        # Limpiar
        keras.backend.clear_session()
        gc.collect()

        return kappa

    # Crear estudio
    study = optuna.create_study(
        direction="maximize",
        study_name="lstm_sleep_staging",
        pruner=optuna.pruners.MedianPruner(),
    )

    print(f"\nIniciando optimizacion con {CONFIG['n_optuna_trials']} trials...")
    study.optimize(
        objective, n_trials=CONFIG["n_optuna_trials"], show_progress_bar=True
    )

    print("\n[OK] Optimizacion completada")
    print(f"   Mejor Kappa: {study.best_value:.4f}")
    print("   Mejores hiperparametros:")
    for key, value in study.best_params.items():
        print(f"      {key}: {value}")

    # Guardar resultados
    with open(f"{OUTPUT_PATH}/optuna_lstm_results.json", "w") as f:
        json.dump(
            {
                "best_value": study.best_value,
                "best_params": study.best_params,
            },
            f,
            indent=2,
        )
    """
else:
    print(
        "[SKIP] Optimizacion deshabilitada. Cambiar CONFIG['run_optimization'] = True para ejecutar."
    )

## Guardar Modelo y Resultados

In [None]:
# ============================================================
# GUARDAR MODELO Y ARTEFACTOS
# ============================================================

# Guardar modelo completo
# Save the full trained model to a .keras file.
final_model_path = f"{OUTPUT_PATH}/{model_name}_final.keras"
model.save(final_model_path)
print(f"[OK] Modelo guardado: {final_model_path}")

# Save the per-epoch training history as CSV (one column per logged metric).
history_df = pd.DataFrame(history.history)
history_df.to_csv(f"{OUTPUT_PATH}/{model_name}_history.csv", index=False)

# Assemble the results record: test metrics, split sizes, channel statistics
# and the label-encoder classes, alongside the run configuration.
test_metrics = {
    "accuracy": float(accuracy),
    "kappa": float(kappa),
    "f1_macro": float(f1_macro),
    "f1_weighted": float(f1_weighted),
}
dataset_summary = {
    "train_samples": int(train_count),
    "val_samples": int(val_count),
    "test_samples": int(test_count),
    "tfrecord_paths": tfrecord_paths,
}
results = {
    "model_name": model_name,
    "config": CONFIG,
    "metrics": test_metrics,
    "dataset": dataset_summary,
    "channel_stats": {"mean": mean_ch.tolist(), "std": std_ch.tolist()},
    "label_encoder_classes": list(le.classes_),
}

# default=str stringifies any value the JSON encoder cannot serialize natively.
results_path = f"{OUTPUT_PATH}/{model_name}_results.json"
with open(results_path, "w") as results_file:
    json.dump(results, results_file, indent=2, default=str)

print(f"[OK] Resultados guardados: {results_path}")

# Pickle the fitted LabelEncoder so predictions can be decoded later.
with open(f"{OUTPUT_PATH}/{model_name}_label_encoder.pkl", "wb") as encoder_file:
    pickle.dump(le, encoder_file)

# List the expected artifact names (this listing does not check that each
# file actually exists on disk).
print("\nArchivos generados:")
artifact_entries = (
    "_final.keras (modelo)",
    "_best.keras (mejor checkpoint)",
    "_history.csv (historial)",
    "_results.json (metricas y config)",
    "_label_encoder.pkl (encoder)",
    "_training_curves.png",
    "_confusion_matrix.png",
)
for artifact_entry in artifact_entries:
    print(f"   - {model_name}{artifact_entry}")

In [None]:
# ============================================================
# RESUMEN FINAL
# ============================================================

# Closing banner: repeat the test-set metrics and where the artifacts were
# written, framed by 60-character rules.
summary_rule = "=" * 60
summary_lines = (
    "\n" + summary_rule,
    "ENTRENAMIENTO LSTM COMPLETADO",
    summary_rule,
    "\nResultados finales en Test Set:",
    f"   Accuracy:    {accuracy:.4f} ({accuracy*100:.2f}%)",
    f"   Cohen Kappa: {kappa:.4f}",
    f"   F1 Macro:    {f1_macro:.4f}",
    f"   F1 Weighted: {f1_weighted:.4f}",
    f"\nModelo guardado en: {OUTPUT_PATH}",
    summary_rule,
)
for summary_line in summary_lines:
    print(summary_line)

In [None]:
# ============================================================
# COMPRIMIR ARTEFACTOS PARA DESCARGA
# ============================================================

# Bundle every artifact whose file name starts with the model name into a
# single zip archive inside OUTPUT_PATH.
zip_name = f"{model_name}_artifacts.zip"
zip_path = f"{OUTPUT_PATH}/{zip_name}"

# Snapshot the artifact list BEFORE the archive is opened, and exclude the
# archive itself. Its name also starts with `model_name` and it lives in the
# same directory, so listing after opening it (or re-running this cell) would
# otherwise pack a partial or stale copy of the zip inside the zip.
# Sorting keeps the archive order deterministic.
artifact_names = sorted(
    fname
    for fname in os.listdir(OUTPUT_PATH)
    if fname.startswith(model_name) and fname != zip_name
)

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for fname in artifact_names:
        # arcname stores each file at the archive root, without OUTPUT_PATH.
        zf.write(os.path.join(OUTPUT_PATH, fname), arcname=fname)

print(f"[OK] Artefactos comprimidos en: {zip_path}")
print("Archivos incluidos:")
# Re-open read-only to report what was actually stored (uncompressed sizes).
with zipfile.ZipFile(zip_path, "r") as zf:
    for info in zf.infolist():
        print(f" - {info.filename} ({info.file_size/1024:.1f} KB)")