# LSTM Bidireccional para Sleep Staging

Este notebook entrena un modelo **LSTM Bidireccional** para clasificación de estadios de sueño usando datos PSG **trimmed** (preprocesados y recortados del dataset Sleep-EDF).

**Optimizado para Kaggle con 2x Tesla T4 (16GB VRAM cada una)**

### Características:
- LSTM Bidireccional con mecanismo de atención opcional
- Soporte multi-GPU con MirroredStrategy
- Optimización de hiperparámetros con Optuna
- División train/val/test por sujetos (sin data leakage)
- **Soporte para reanudar entrenamiento desde checkpoints**

### Datos requeridos:
- Dataset `sleep-trimmed` en Kaggle (https://www.kaggle.com/datasets/ignaciolinari/sleep-trimmed)
  - `manifest_trimmed.csv`
  - `sleep_trimmed/psg/*.fif` (PSG a 100 Hz, float32)
  - `sleep_trimmed/hypnograms/*.csv` (anotaciones)
- Si usas otro slug, ajusta `DATA_PATH` en la celda siguiente.

In [None]:
# ============================================================
# CONFIGURACION INICIAL - Ejecutar primero
# ============================================================

import os
import warnings

warnings.filterwarnings("ignore")

# The presence of /kaggle/input is used as the marker for a Kaggle session.
IN_KAGGLE = os.path.exists("/kaggle/input")
print(f"Entorno: {'Kaggle' if IN_KAGGLE else 'Local'}")

# Environment-specific input and output locations
if IN_KAGGLE:
    # On Kaggle: adjust to the slug/version of the dataset you attached
    DATA_PATH = "/kaggle/input/sleep-trimmed"
    OUTPUT_PATH = "/kaggle/working"
else:
    # Local run
    DATA_PATH = "../data/processed"
    OUTPUT_PATH = "../models"

# Checkpoints, figures and JSON results are written under OUTPUT_PATH by later
# cells; create it up front so those writes do not fail on a fresh checkout
# where the directory does not exist yet.
os.makedirs(OUTPUT_PATH, exist_ok=True)

print(f"Data path: {DATA_PATH}")
print(f"Output path: {OUTPUT_PATH}")

In [None]:
# ============================================================
# VERIFICAR GPUs DISPONIBLES
# ============================================================

import tensorflow as tf  # noqa: E402

# Global seed for reproducibility
SEED = 42
tf.random.set_seed(SEED)

print("TensorFlow version:", tf.__version__)
print("\nGPUs disponibles:")

gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    print("[WARN] No se detectaron GPUs. El entrenamiento sera lento.")
else:
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")

    # Turn on memory growth for every GPU (intended to avoid OOM errors)
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)

    print(f"\n[OK] {len(gpus)} GPU(s) configuradas con memory growth")

# Distribution strategy: mirror across devices only when several GPUs exist,
# otherwise fall back to TensorFlow's current default strategy.
if len(gpus) <= 1:
    strategy = tf.distribute.get_strategy()
    print("\nUsando estrategia por defecto (1 GPU o CPU)")
else:
    strategy = tf.distribute.MirroredStrategy()
    print(f"\nUsando MirroredStrategy con {strategy.num_replicas_in_sync} GPUs")

In [None]:
# ============================================================
# IMPORTS Y DEPENDENCIAS
# ============================================================

import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import seaborn as sns  # noqa: E402
from pathlib import Path  # noqa: E402
import logging  # noqa: E402
import json  # noqa: E402
from datetime import datetime  # noqa: E402
import pickle  # noqa: E402
import gc  # noqa: E402
import zipfile  # noqa: E402

from sklearn.preprocessing import LabelEncoder  # noqa: E402
from sklearn.utils.class_weight import compute_class_weight  # noqa: E402
from sklearn.metrics import (  # noqa: E402
    classification_report,
    confusion_matrix,
    cohen_kappa_score,
    accuracy_score,
    f1_score,
)

from tensorflow import keras  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint  # noqa: E402

# Plot styling: apply the matplotlib theme first, then the seaborn palette
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_palette("husl")

# Root logger: INFO level with timestamped, level-tagged messages
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

print("[OK] Imports completados")

In [None]:
# ============================================================
# INSTALAR DEPENDENCIAS (solo en Kaggle)
# ============================================================

# Extra packages are installed only on Kaggle; a local environment is
# expected to provide them already.
if IN_KAGGLE:
    print("Instalando dependencias...")
    !pip install -q mne yasa
    print("[OK] Dependencias instaladas")

# Imported after the optional install so the package is available on Kaggle.
import mne  # noqa: E402

# Restrict MNE's own logging to errors.
mne.set_log_level("ERROR")
print(f"MNE version: {mne.__version__}")

## Configuracion del Experimento

In [None]:
# ============================================================
# HIPERPARAMETROS - AJUSTAR SEGUN NECESIDAD
# ============================================================

# Experiment configuration read by the cells below.
CONFIG = {
    # Dataset (100 Hz)
    "manifest_path": f"{DATA_PATH}/manifest_trimmed.csv",
    "epoch_length": 30.0,  # seconds
    "sfreq": 100,  # Hz (dataset sampled at 100 Hz)
    "limit_sessions": None,  # None = all sessions, a number = cap for debugging
    # Split
    "test_size": 0.15,
    "val_size": 0.15,
    "random_state": 42,
    # LSTM model
    "lstm_units": 128,  # LSTM units (halved in the second layer)
    "dropout_rate": 0.5,
    "bidirectional": True,  # bidirectional LSTM
    "use_attention": True,  # attention mechanism
    # Training
    "learning_rate": 0.001,
    "batch_size": 64,  # per-replica value; with 2x T4 it can be raised to 128
    "epochs": 100,
    "early_stopping_patience": 15,
    "reduce_lr_patience": 7,
    # Hyperparameter search (Optuna)
    "run_optimization": False,  # set to True to run the hyperparameter search
    "n_optuna_trials": 30,
}

# Scale the batch size by the number of replicas for multi-GPU training
CONFIG["effective_batch_size"] = CONFIG["batch_size"] * strategy.num_replicas_in_sync

print("Configuracion del experimento:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")

## Carga de Datos

In [None]:
# ============================================================
# FUNCIONES DE CARGA DE DATOS
# ============================================================

# Channels and sleep stages

# Channel names looked up in each recording when the caller does not pass an
# explicit channel list, grouped by signal type.
DEFAULT_CHANNELS = {
    "EEG": ["EEG Fpz-Cz", "EEG Pz-Oz"],
    "EOG": ["EOG horizontal"],
    "EMG": ["EMG submental"],
}

# Maps annotation descriptions to canonical stage labels; stages 3 and 4 are
# both mapped to N3. Descriptions not listed here have no canonical label.
STAGE_CANONICAL = {
    "Sleep stage W": "W",
    "Sleep stage 1": "N1",
    "Sleep stage 2": "N2",
    "Sleep stage 3": "N3",
    "Sleep stage 4": "N3",
    "Sleep stage R": "REM",
}

# Display/reporting order of the canonical stages.
STAGE_ORDER = ["W", "N1", "N2", "N3", "REM"]


def load_psg_data(psg_path, channels=None, target_sfreq=None):
    """Load a PSG recording from a ``.fif`` file.

    Args:
        psg_path: Path to the ``.fif`` file.
        channels: Channel names to keep. When ``None``, every name listed in
            ``DEFAULT_CHANNELS`` that exists in the recording is kept.
        target_sfreq: Optional sampling frequency; the recording is resampled
            when its own frequency differs from it.

    Returns:
        A ``(data, sfreq, ch_names)`` tuple taken from the loaded recording
        after channel selection and optional resampling.
    """
    raw = mne.io.read_raw_fif(str(psg_path), preload=True, verbose="ERROR")

    if channels is None:
        present = set(raw.ch_names)
        channels = [
            name
            for group in DEFAULT_CHANNELS.values()
            for name in group
            if name in present
        ]

    # pick() is used because pick_channels() is deprecated
    raw.pick(channels)

    needs_resampling = bool(target_sfreq) and raw.info["sfreq"] != target_sfreq
    if needs_resampling:
        raw.resample(target_sfreq)

    return raw.get_data(), raw.info["sfreq"], raw.ch_names


def load_hypnogram(hyp_path):
    """Read a hypnogram CSV and add a ``stage_canonical`` column.

    The new column is the ``description`` column mapped through
    ``STAGE_CANONICAL``; descriptions without a mapping become NaN.
    """
    return pd.read_csv(hyp_path).assign(
        stage_canonical=lambda table: table["description"].map(STAGE_CANONICAL)
    )


def create_epochs(data, sfreq, epoch_length=30.0):
    """Split a continuous recording into fixed-length, non-overlapping epochs.

    Args:
        data: 2-D array of shape ``(n_channels, n_samples)``.
        sfreq: Sampling frequency in Hz.
        epoch_length: Epoch duration in seconds.

    Returns:
        A ``(epochs, epoch_times)`` tuple. ``epochs`` has shape
        ``(n_epochs, n_channels, samples_per_epoch)`` and the dtype of
        ``data``; trailing samples that do not fill a whole epoch are dropped.
        ``epoch_times`` holds the start of each epoch in seconds. A recording
        shorter than one epoch yields an empty array that still has three
        dimensions, so results from several recordings can be concatenated.

    Raises:
        ValueError: If ``epoch_length * sfreq`` is smaller than one sample.
    """
    samples_per_epoch = int(epoch_length * sfreq)
    if samples_per_epoch < 1:
        raise ValueError(
            "epoch_length * sfreq debe ser de al menos 1 muestra "
            f"(epoch_length={epoch_length}, sfreq={sfreq})"
        )

    n_channels, n_samples = data.shape
    n_epochs = n_samples // samples_per_epoch

    # Drop the incomplete tail, then view the time axis as (epoch, sample)
    # and move the epoch axis to the front.
    usable = data[:, : n_epochs * samples_per_epoch]
    epochs = usable.reshape(n_channels, n_epochs, samples_per_epoch).transpose(1, 0, 2)
    epoch_times = np.arange(n_epochs) * epoch_length

    # Return an independent, contiguous array rather than a view into `data`.
    return np.ascontiguousarray(epochs), epoch_times


def assign_stages(epoch_times, hypnogram, epoch_length=30.0):
    """Label each epoch with the annotation that covers its midpoint.

    Args:
        epoch_times: Iterable of epoch start times in seconds.
        hypnogram: DataFrame with ``onset``, ``duration`` and
            ``stage_canonical`` columns.
        epoch_length: Epoch duration in seconds.

    Returns:
        A list with one entry per epoch: the ``stage_canonical`` value of the
        first annotation whose ``[onset, onset + duration)`` interval contains
        the epoch midpoint, or ``None`` when no annotation covers it.
    """
    # These do not depend on the epoch, so compute them once instead of on
    # every iteration.
    onsets = hypnogram["onset"]
    offsets = onsets + hypnogram["duration"]
    labels = hypnogram["stage_canonical"]
    half_epoch = epoch_length / 2

    stages = []
    for start in epoch_times:
        center = start + half_epoch
        covering = (onsets <= center) & (offsets > center)
        hits = labels[covering]
        stages.append(hits.iloc[0] if len(hits) > 0 else None)

    return stages


print("[OK] Funciones de carga definidas")

In [None]:
# ============================================================
# CARGAR DATASET COMPLETO
# ============================================================


def _subpath_from_anchor(path):
    """Return *path* cut to start at its first ``sleep_trimmed*`` component, or ``None``."""
    parts = path.parts
    for pos, part in enumerate(parts):
        if part.startswith("sleep_trimmed"):
            return Path(*parts[pos:])
    return None


def _resolve_local_path(path, manifest_dir):
    """Resolve a manifest path for a local (non-Kaggle) run.

    Absolute paths are returned unchanged, paths starting with ``data/`` are
    re-rooted at the parent of the manifest directory, and any other relative
    path is taken relative to the manifest directory.
    """
    if path.is_absolute():
        return path
    if path.parts and path.parts[0] == "data":
        return manifest_dir.parent / path.relative_to("data")
    return manifest_dir / path


def _resolve_session_paths(row, manifest_dir, dataset_dir_name):
    """Return ``(psg_path, hyp_path)`` for one manifest row."""
    psg_path_str = row.get("psg_trimmed_path", "")
    hyp_path_str = row.get("hypnogram_trimmed_path", "")

    # Missing cells arrive as NaN, which is truthy; treat them as empty.
    if pd.isna(psg_path_str):
        psg_path_str = ""
    if pd.isna(hyp_path_str):
        hyp_path_str = ""

    if not (psg_path_str and hyp_path_str):
        # The manifest has no usable paths: rebuild the file names from the
        # row's identifying fields.
        subset = row.get("subset", "sleep-cassette")
        version = row.get("version", "1.0.0")
        stem = f"{row['subject_id']}_{subset}_{version}_trimmed"
        dataset_dir = manifest_dir / dataset_dir_name
        return (
            dataset_dir / "psg" / f"{stem}_raw.fif",
            dataset_dir / "hypnograms" / f"{stem}_annotations.csv",
        )

    psg_rel = Path(psg_path_str)
    hyp_rel = Path(hyp_path_str)

    if not IN_KAGGLE:
        return (
            _resolve_local_path(psg_rel, manifest_dir),
            _resolve_local_path(hyp_rel, manifest_dir),
        )

    # On Kaggle the manifest paths are re-rooted under DATA_PATH, starting at
    # their ``sleep_trimmed*`` component. Both paths must have such a
    # component; otherwise both fall back to the bare file name.
    data_root = Path(DATA_PATH)
    psg_sub = _subpath_from_anchor(psg_rel)
    hyp_sub = _subpath_from_anchor(hyp_rel)
    if psg_sub is not None and hyp_sub is not None:
        return data_root / psg_sub, data_root / hyp_sub
    return (
        data_root / dataset_dir_name / "psg" / psg_rel.name,
        data_root / dataset_dir_name / "hypnograms" / hyp_rel.name,
    )


def prepare_raw_epochs_dataset(manifest_path, limit=None, epoch_length=30.0, sfreq=100):
    """Build the raw-epoch dataset used to train the LSTM.

    Reads the manifest, keeps rows whose ``status`` is ``"ok"``, loads each
    session's PSG and hypnogram, cuts the signal into epochs and keeps the
    epochs whose stage is one of ``STAGE_ORDER``.

    Args:
        manifest_path: Path to the manifest CSV.
        limit: When truthy, only the first ``limit`` sessions are processed.
        epoch_length: Epoch duration in seconds.
        sfreq: Target sampling frequency passed to ``load_psg_data``.

    Returns:
        ``(X_raw, y, metadata_df)``: the stacked epochs, the stage label of
        each epoch, and a DataFrame with ``subject_id``, ``subject_core``
        (first five characters of the subject id), ``epoch_time_start`` and
        ``epoch_index`` per epoch.

    Raises:
        ValueError: If no valid epoch was loaded or fewer than two classes
            were found. (These checks used to be ``assert`` statements, which
            are removed when Python runs with ``-O``.)
    """
    manifest = pd.read_csv(manifest_path)
    manifest_ok = manifest[manifest["status"] == "ok"].copy()

    if limit:
        manifest_ok = manifest_ok.head(limit)
        print(f"[WARN] Modo debug: procesando solo {limit} sesiones")

    all_epochs = []
    all_stages = []
    all_metadata = []
    skipped_files = 0
    error_count = 0

    manifest_dir = Path(manifest_path).parent
    dataset_dir_name = (
        "sleep_trimmed_resamp200"
        if (manifest_dir / "sleep_trimmed_resamp200").exists()
        else "sleep_trimmed"
    )
    total_sessions = len(manifest_ok)

    print(f"\nProcesando {total_sessions} sesiones...")

    for i, (_, row) in enumerate(manifest_ok.iterrows()):
        subject_id = row["subject_id"]
        psg_path, hyp_path = _resolve_session_paths(row, manifest_dir, dataset_dir_name)

        if not psg_path.exists() or not hyp_path.exists():
            skipped_files += 1
            continue

        # Loading is best-effort per session: a failure is reported and
        # counted, and the remaining sessions are still processed.
        try:
            data, actual_sfreq, _ = load_psg_data(psg_path, target_sfreq=sfreq)
            hypnogram = load_hypnogram(hyp_path)

            epochs, epoch_times = create_epochs(data, actual_sfreq, epoch_length)
            stages = assign_stages(epoch_times, hypnogram, epoch_length)

            # str() keeps the prefix slice working when pandas parsed the
            # subject ids as numbers.
            subject_core = str(subject_id)[:5]

            for epoch_idx, (epoch, stage, epoch_time) in enumerate(
                zip(epochs, stages, epoch_times)
            ):
                # None and NaN (unmapped description) are not in STAGE_ORDER,
                # so this also drops unlabeled epochs.
                if stage in STAGE_ORDER:
                    all_epochs.append(epoch)
                    all_stages.append(stage)
                    all_metadata.append(
                        {
                            "subject_id": subject_id,
                            "subject_core": subject_core,
                            "epoch_time_start": epoch_time,
                            "epoch_index": epoch_idx,
                        }
                    )

            if (i + 1) % 20 == 0:
                print(f"   Procesadas {i + 1}/{total_sessions} sesiones...")

        except Exception as e:
            error_count += 1
            print(f"   [WARN] Error en {subject_id}: {e}")
            continue

    # Loading summary
    if skipped_files > 0:
        print(f"   [INFO] Archivos no encontrados: {skipped_files}")
    if error_count > 0:
        print(f"   [WARN] Errores durante carga: {error_count}")

    # Validation
    if not all_epochs:
        raise ValueError("ERROR: No se cargaron epochs válidos")

    X_raw = np.array(all_epochs)
    y = np.array(all_stages)
    metadata_df = pd.DataFrame(all_metadata)

    n_classes_found = len(set(y))
    if n_classes_found < 2:
        raise ValueError(
            f"ERROR: Se esperaban al menos 2 clases, se encontraron {n_classes_found}"
        )

    print("\n[OK] Dataset cargado:")
    print(f"   X_raw shape: {X_raw.shape}")
    print(f"   Clases encontradas: {n_classes_found}")
    print("   Distribucion de clases:")
    for stage in STAGE_ORDER:
        count = (y == stage).sum()
        pct = count / len(y) * 100
        print(f"      {stage}: {count:,} ({pct:.1f}%)")

    return X_raw, y, metadata_df


# Load the dataset using the settings from CONFIG
X_raw, y, metadata_df = prepare_raw_epochs_dataset(
    CONFIG["manifest_path"],
    limit=CONFIG["limit_sessions"],
    epoch_length=CONFIG["epoch_length"],
    sfreq=CONFIG["sfreq"],
)

# Quick sanity report: array dtype and average number of epochs per subject
# (subjects are identified by the `subject_core` column)
print("\n[CHECK] Validacion de datos:")
print(f"   X_raw dtype: {X_raw.dtype}")
print(
    f"   Epochs por sujeto (promedio): {len(X_raw) / metadata_df['subject_core'].nunique():.1f}"
)

In [None]:
# ============================================================
# DIVISION TRAIN/VAL/TEST POR SUJETOS
# ============================================================


def split_by_subjects(
    X, y, metadata_df, test_size=0.15, val_size=0.15, random_state=42
):
    """Split the dataset into train/val/test so that no subject is shared.

    Subjects are identified by ``metadata_df["subject_core"]``; each subject's
    epochs go to exactly one split, which prevents data leakage.

    Args:
        X: Array of epochs, indexed along its first axis.
        y: Array of labels aligned with ``X``.
        metadata_df: DataFrame aligned row-by-row with ``X`` that has a
            ``subject_core`` column.
        test_size: Fraction of subjects for the test split (at least one).
        val_size: Fraction of subjects for the validation split (at least one).
        random_state: Seed for the subject shuffle.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test))``.

    Raises:
        ValueError: If there are too few subjects to leave at least one for
            training after the test and validation subjects are taken.
    """
    subject_cores = metadata_df["subject_core"].unique()
    n_cores = len(subject_cores)

    n_test = max(1, int(n_cores * test_size))
    n_val = max(1, int(n_cores * val_size))

    if n_cores - n_test - n_val < 1:
        raise ValueError(
            f"Se necesitan al menos {n_test + n_val + 1} sujetos para dividir en "
            f"train/val/test, se encontraron {n_cores}"
        )

    # A local legacy generator yields the same permutation as
    # np.random.seed(random_state) followed by np.random.permutation, without
    # reseeding the process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(random_state)
    shuffled_cores = rng.permutation(subject_cores)

    test_cores = set(shuffled_cores[:n_test])
    val_cores = set(shuffled_cores[n_test : n_test + n_val])
    train_cores = set(shuffled_cores[n_test + n_val :])

    # Plain boolean arrays make the indexing positional regardless of the
    # DataFrame's index labels.
    core_of_epoch = metadata_df["subject_core"]
    train_mask = core_of_epoch.isin(train_cores).to_numpy()
    val_mask = core_of_epoch.isin(val_cores).to_numpy()
    test_mask = core_of_epoch.isin(test_cores).to_numpy()

    X_train, y_train = X[train_mask], y[train_mask]
    X_val, y_val = X[val_mask], y[val_mask]
    X_test, y_test = X[test_mask], y[test_mask]

    print("\nDivision del dataset:")
    print(f"   Train: {len(X_train):,} epochs de {len(train_cores)} sujetos")
    print(f"   Val:   {len(X_val):,} epochs de {len(val_cores)} sujetos")
    print(f"   Test:  {len(X_test):,} epochs de {len(test_cores)} sujetos")

    return (X_train, y_train), (X_val, y_val), (X_test, y_test)


# Subject-wise split using the fractions and seed from CONFIG
(X_train, y_train), (X_val, y_val), (X_test, y_test) = split_by_subjects(
    X_raw,
    y,
    metadata_df,
    test_size=CONFIG["test_size"],
    val_size=CONFIG["val_size"],
    random_state=CONFIG["random_state"],
)

# Free memory: the unsplit array is no longer needed
del X_raw
gc.collect()

In [None]:
# ============================================================
# NORMALIZACION Y PREPARACION PARA LSTM
# ============================================================


def normalize_for_lstm(X_train, X_val, X_test):
    """Standardize each channel with train statistics and reorder axes.

    Inputs are ``(samples, channels, timesteps)``; outputs are float32 arrays
    shaped ``(samples, timesteps, channels)``, the layout the LSTM expects.
    The mean and standard deviation of every channel are computed on the
    training split only and applied to all three splits. A channel whose
    training standard deviation is not positive is copied through unscaled.

    Returns:
        ``(X_train_lstm, X_val_lstm, X_test_lstm, channel_stats)`` where
        ``channel_stats`` is a list of ``{"mean", "std"}`` dicts, one per
        channel.
    """
    splits = (X_train, X_val, X_test)
    scaled = [np.zeros_like(split, dtype=np.float32) for split in splits]
    channel_stats = []

    for ch in range(X_train.shape[1]):
        train_channel = X_train[:, ch, :]
        mean = train_channel.mean()
        std = train_channel.std()
        channel_stats.append({"mean": mean, "std": std})

        for source, target in zip(splits, scaled):
            if std > 0:
                target[:, ch, :] = (source[:, ch, :] - mean) / std
            else:
                target[:, ch, :] = source[:, ch, :]

    # (samples, channels, timesteps) -> (samples, timesteps, channels)
    X_train_lstm, X_val_lstm, X_test_lstm = (
        np.transpose(block, (0, 2, 1)) for block in scaled
    )

    print("[OK] Normalizacion y transposicion completada")
    print(f"   Shape para LSTM: {X_train_lstm.shape}")
    return X_train_lstm, X_val_lstm, X_test_lstm, channel_stats


# Normalize with train statistics and reshape for the LSTM
X_train_lstm, X_val_lstm, X_test_lstm, channel_stats = normalize_for_lstm(
    X_train, X_val, X_test
)

# Encode the string labels as integers. The encoder is fitted on STAGE_ORDER
# so every stage gets an id even if it is missing from a split; the ids follow
# the order of `le.classes_`, which is printed below.
le = LabelEncoder()
le.fit(STAGE_ORDER)
y_train_enc = le.transform(y_train)
y_val_enc = le.transform(y_val)
y_test_enc = le.transform(y_test)

print(f"\nClases: {le.classes_}")

# Class weights, computed over the classes actually present in the train split
present_classes = np.unique(y_train_enc)
class_weights_arr = compute_class_weight(
    "balanced", classes=present_classes, y=y_train_enc
)
# Key each weight by its encoded label rather than by its position in the
# array: the two only coincide when every class occurs in the training data,
# and a positional mapping would shift weights onto the wrong classes otherwise.
class_weights = dict(zip(present_classes.tolist(), class_weights_arr))
print(f"Class weights: {class_weights}")

# Free memory: the un-normalized arrays are no longer needed
del X_train, X_val, X_test
gc.collect()

## Arquitectura LSTM Bidireccional

In [None]:
# ============================================================
# CAPA DE ATENCION PERSONALIZADA
# ============================================================


class AttentionLayer(layers.Layer):
    """Simple attention pooling over the time axis of an LSTM output.

    A score is computed for every timestep, the scores are turned into
    weights with a softmax along the time axis, and the layer returns the
    weighted sum of the inputs, collapsing ``(batch, timesteps, features)``
    to ``(batch, features)``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        n_timesteps, n_features = input_shape[1], input_shape[-1]
        # Projects each timestep's feature vector to a single score
        self.W = self.add_weight(
            name="attention_weight",
            shape=(n_features, 1),
            initializer="glorot_uniform",
            trainable=True,
        )
        # One bias per timestep, so the layer is tied to the sequence length
        self.b = self.add_weight(
            name="attention_bias",
            shape=(n_timesteps, 1),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        # x: (batch, timesteps, features)
        K = keras.backend
        scores = K.tanh(K.dot(x, self.W) + self.b)
        # Normalize the scores across timesteps
        weights = K.softmax(scores, axis=1)
        # Weighted sum over the time axis
        return K.sum(x * weights, axis=1)

    def compute_output_shape(self, input_shape):
        batch, n_features = input_shape[0], input_shape[-1]
        return (batch, n_features)

    def get_config(self):
        # No constructor arguments of its own; the base config is enough
        return super().get_config()


print("[OK] Capa de Atencion definida")

In [None]:
# ============================================================
# MODELO LSTM BIDIRECCIONAL CON ATENCION
# ============================================================


def build_lstm_model(
    input_shape,
    n_classes=5,
    lstm_units=128,
    dropout_rate=0.5,
    learning_rate=0.001,
    bidirectional=True,
    use_attention=True,
):
    """Build and compile the (Bi)LSTM sleep-staging classifier.

    Architecture:
    - two LSTM stages (optionally bidirectional), the second with half the
      units, each followed by BatchNorm and Dropout
    - optional attention pooling over time
    - two dense layers (128 and 64 units) with Dropout
    - softmax output with ``n_classes`` units

    NOTE: ``recurrent_dropout`` is deliberately not used, to stay compatible
    with the cuDNN LSTM implementation.

    Args:
        input_shape: ``(timesteps, features)`` of one input sample.
        n_classes: Number of output classes.
        lstm_units: Units of the first LSTM stage.
        dropout_rate: Dropout rate used after every stage.
        learning_rate: Learning rate of the Adam optimizer.
        bidirectional: Wrap both LSTM stages in ``Bidirectional``.
        use_attention: Pool the second stage's sequence with ``AttentionLayer``
            instead of using only its final output.

    Returns:
        The compiled ``keras.Model``.
    """

    def _l2():
        return keras.regularizers.l2(1e-4)

    def _recurrent_stage(tensor, index, units, return_sequences):
        recurrent = layers.LSTM(
            units,
            return_sequences=return_sequences,
            kernel_regularizer=_l2(),
            name=f"lstm_{index}",
        )
        if bidirectional:
            recurrent = layers.Bidirectional(recurrent, name=f"bidirectional_{index}")
        tensor = recurrent(tensor)
        tensor = layers.BatchNormalization(name=f"bn_{index}")(tensor)
        return layers.Dropout(dropout_rate, name=f"dropout_lstm_{index}")(tensor)

    # Input: (timesteps, features)
    input_layer = keras.Input(shape=input_shape, name="input")

    # First stage always returns the full sequence; the second one only does
    # so when the attention layer needs it.
    x = _recurrent_stage(input_layer, 1, lstm_units, True)
    x = _recurrent_stage(x, 2, lstm_units // 2, use_attention)

    if use_attention:
        x = AttentionLayer(name="attention")(x)

    # Dense classification head
    for index, width in enumerate((128, 64), start=1):
        x = layers.Dense(
            width,
            activation="relu",
            kernel_regularizer=_l2(),
            name=f"dense_{index}",
        )(x)
        x = layers.Dropout(dropout_rate, name=f"dropout_{index}")(x)

    output_layer = layers.Dense(n_classes, activation="softmax", name="output")(x)

    # Model name reflects the chosen configuration
    name_parts = ["BiLSTM" if bidirectional else "LSTM"]
    if use_attention:
        name_parts.append("Attention")
    name_parts.append("SleepStaging")

    model = keras.Model(
        inputs=input_layer, outputs=output_layer, name="_".join(name_parts)
    )

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    return model


print("[OK] Arquitectura LSTM definida")

In [None]:
# ============================================================
# CREAR MODELO (con estrategia multi-GPU si está disponible)
# ============================================================

# Input shape: (timesteps, features) = (X_train_lstm.shape[1], n_channels)
input_shape = (X_train_lstm.shape[1], X_train_lstm.shape[2])
print(f"Input shape: {input_shape}")

# Build and compile inside the strategy scope so the model is created under
# the distribution strategy chosen earlier in the notebook.
with strategy.scope():
    model = build_lstm_model(
        input_shape=input_shape,
        n_classes=len(STAGE_ORDER),
        lstm_units=CONFIG["lstm_units"],
        dropout_rate=CONFIG["dropout_rate"],
        learning_rate=CONFIG["learning_rate"],
        bidirectional=CONFIG["bidirectional"],
        use_attention=CONFIG["use_attention"],
    )

model.summary()

## Entrenamiento

In [None]:
# ============================================================
# CALLBACKS
# ============================================================

# A timestamped name keeps the artifacts of different runs apart
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"lstm_{timestamp}"

# All three callbacks monitor the validation loss
callbacks = [
    # Stop when val_loss stops improving and restore the best weights
    EarlyStopping(
        monitor="val_loss",
        patience=CONFIG["early_stopping_patience"],
        restore_best_weights=True,
        verbose=1,
    ),
    # Halve the learning rate on a plateau, down to a floor of 1e-6
    ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=CONFIG["reduce_lr_patience"],
        min_lr=1e-6,
        verbose=1,
    ),
    # Keep only the best model on disk
    ModelCheckpoint(
        filepath=f"{OUTPUT_PATH}/{model_name}_best.keras",
        monitor="val_loss",
        save_best_only=True,
        verbose=1,
    ),
]

print(f"Modelo se guardara en: {OUTPUT_PATH}/{model_name}_best.keras")

## Reanudacion desde Checkpoint (opcional)

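Si Kaggle se desconectó durante el entrenamiento, puedes reanudar desde el último checkpoint guardado.
Solo ejecuta la siguiente celda si necesitas reanudar. Ten en cuenta que se restaura el modelo guardado, pero la celda de `fit()` vuelve a contar los epochs desde cero.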

In [None]:
# ============================================================
# REANUDAR DESDE CHECKPOINT (ejecutar solo si es necesario)
# ============================================================
# Descomenta y ajusta el nombre del checkpoint si necesitas reanudar

RESUME_FROM_CHECKPOINT = False  # Set to True to resume
CHECKPOINT_NAME = None  # Example: "lstm_20251125_143022_best.keras"

if not (RESUME_FROM_CHECKPOINT and CHECKPOINT_NAME):
    # Nothing to resume: keep the freshly built model
    print("[INFO] Modo normal: se usara el modelo recien creado")
    print("[INFO] Para reanudar desde checkpoint, cambia RESUME_FROM_CHECKPOINT=True")
else:
    checkpoint_path = f"{OUTPUT_PATH}/{CHECKPOINT_NAME}"
    if not os.path.exists(checkpoint_path):
        # Help the user pick a valid name by listing the saved models
        print(f"[ERROR] Checkpoint no encontrado: {checkpoint_path}")
        print(f"[INFO] Archivos disponibles en {OUTPUT_PATH}:")
        for f in os.listdir(OUTPUT_PATH):
            if not f.endswith(".keras"):
                continue
            print(f"   - {f}")
    else:
        print(f"[INFO] Cargando modelo desde checkpoint: {checkpoint_path}")
        # The custom attention layer must be registered for deserialization
        with strategy.scope():
            model = keras.models.load_model(
                checkpoint_path, custom_objects={"AttentionLayer": AttentionLayer}
            )
        print("[OK] Modelo cargado exitosamente")
        print("[INFO] Puedes continuar el entrenamiento ejecutando la celda de fit()")
        print("[INFO] El modelo continuara desde donde quedo")

In [None]:
# ============================================================
# ENTRENAR MODELO
# ============================================================

print("\nIniciando entrenamiento LSTM...")
print(f"   Batch size efectivo: {CONFIG['effective_batch_size']}")
print(f"   Epochs maximos: {CONFIG['epochs']}")

# Train with the replica-scaled batch size; class weights are passed to the
# loss, and the callbacks handle early stopping, learning-rate reduction and
# checkpointing.
history = model.fit(
    X_train_lstm,
    y_train_enc,
    validation_data=(X_val_lstm, y_val_enc),
    batch_size=CONFIG["effective_batch_size"],
    epochs=CONFIG["epochs"],
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

print("\n[OK] Entrenamiento completado")

In [None]:
# ============================================================
# VISUALIZAR CURVAS DE APRENDIZAJE
# ============================================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric: (history key, axis/title name, legend abbreviation)
_panels = (
    ("loss", "Loss", "Loss"),
    ("accuracy", "Accuracy", "Acc"),
)

for _ax, (_key, _name, _abbr) in zip(axes, _panels):
    _ax.plot(history.history[_key], label=f"Train {_abbr}", linewidth=2)
    _ax.plot(history.history[f"val_{_key}"], label=f"Val {_abbr}", linewidth=2)
    _ax.set_xlabel("Epoch")
    _ax.set_ylabel(_name)
    _ax.set_title(f"{_name} durante entrenamiento")
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_training_curves.png", dpi=150)
plt.show()

## Evaluacion en Test

In [None]:
# ============================================================
# EVALUACION EN TEST SET
# ============================================================

print("\nEvaluando en Test Set...")

# Predictions: class probabilities -> encoded class ids -> stage names.
# Inference uses twice the configured training batch size.
y_pred_proba = model.predict(
    X_test_lstm, batch_size=CONFIG["batch_size"] * 2, verbose=1
)
y_pred_enc = np.argmax(y_pred_proba, axis=1)
y_pred = le.inverse_transform(y_pred_enc)

# Metrics, computed on the string stage labels
accuracy = accuracy_score(y_test, y_pred)
kappa = cohen_kappa_score(y_test, y_pred)
f1_macro = f1_score(y_test, y_pred, average="macro")
f1_weighted = f1_score(y_test, y_pred, average="weighted")

print(f"\n{'='*50}")
print("RESULTADOS EN TEST SET")
print(f"{'='*50}")
print(f"   Accuracy:    {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"   Cohen Kappa: {kappa:.4f}")
print(f"   F1 Macro:    {f1_macro:.4f}")
print(f"   F1 Weighted: {f1_weighted:.4f}")
print(f"{'='*50}")

In [None]:
# ============================================================
# CLASSIFICATION REPORT
# ============================================================

print("\nClassification Report:")
# `y_test` and `y_pred` hold string labels. Without an explicit `labels`
# argument the report rows follow the sorted label values, which is not the
# order of STAGE_ORDER, so `target_names` alone would attach the wrong stage
# name to each row. Passing `labels=STAGE_ORDER` fixes the row order and also
# keeps the report working when a stage is absent from the test split.
print(
    classification_report(
        y_test,
        y_pred,
        labels=STAGE_ORDER,
        target_names=STAGE_ORDER,
        digits=4,
        zero_division=0,
    )
)

In [None]:
# ============================================================
# MATRIZ DE CONFUSIÓN
# ============================================================

cm = confusion_matrix(y_test, y_pred, labels=STAGE_ORDER)

# Row-normalize (fraction of each true stage). A stage with no test samples
# has a zero row total; leave that row at 0 instead of dividing by zero and
# filling it with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Absolute counts
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[0],
)
axes[0].set_xlabel("Predicho")
axes[0].set_ylabel("Real")
axes[0].set_title("Matriz de Confusión (Absoluta)")

# Row-normalized
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt=".2%",
    cmap="Blues",
    xticklabels=STAGE_ORDER,
    yticklabels=STAGE_ORDER,
    ax=axes[1],
)
axes[1].set_xlabel("Predicho")
axes[1].set_ylabel("Real")
axes[1].set_title("Matriz de Confusión (Normalizada)")

plt.tight_layout()
plt.savefig(f"{OUTPUT_PATH}/{model_name}_confusion_matrix.png", dpi=150)
plt.show()

## Optimizacion de Hiperparametros (Optuna)

In [None]:
# ============================================================
# OPTIMIZACION CON OPTUNA (opcional)
# ============================================================

if CONFIG["run_optimization"]:
    !pip install -q optuna
    import optuna
    from optuna.integration import TFKerasPruningCallback

    def objective(trial):
        """Funcion objetivo para Optuna."""

        # Hiperparametros a optimizar
        lstm_units = trial.suggest_categorical("lstm_units", [64, 128, 256])
        dropout_rate = trial.suggest_float("dropout_rate", 0.3, 0.6)
        learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)
        batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
        bidirectional = trial.suggest_categorical("bidirectional", [True, False])
        use_attention = trial.suggest_categorical("use_attention", [True, False])

        # Crear modelo
        with strategy.scope():
            trial_model = build_lstm_model(
                input_shape=input_shape,
                n_classes=len(STAGE_ORDER),
                lstm_units=lstm_units,
                dropout_rate=dropout_rate,
                learning_rate=learning_rate,
                bidirectional=bidirectional,
                use_attention=use_attention,
            )

        # Callbacks
        trial_callbacks = [
            EarlyStopping(monitor="val_loss", patience=7, restore_best_weights=True),
            TFKerasPruningCallback(trial, "val_loss"),
        ]

        # Entrenar
        trial_model.fit(
            X_train_lstm,
            y_train_enc,
            validation_data=(X_val_lstm, y_val_enc),
            batch_size=batch_size * strategy.num_replicas_in_sync,
            epochs=30,
            class_weight=class_weights,
            callbacks=trial_callbacks,
            verbose=0,
        )

        # Evaluar en validacion
        y_val_pred = np.argmax(trial_model.predict(X_val_lstm, verbose=0), axis=1)
        kappa = cohen_kappa_score(y_val_enc, y_val_pred)

        # Limpiar
        keras.backend.clear_session()
        gc.collect()

        return kappa

    # Crear estudio
    study = optuna.create_study(
        direction="maximize",
        study_name="lstm_sleep_staging",
        pruner=optuna.pruners.MedianPruner(),
    )

    print(f"\nIniciando optimizacion con {CONFIG['n_optuna_trials']} trials...")
    study.optimize(
        objective, n_trials=CONFIG["n_optuna_trials"], show_progress_bar=True
    )

    print("\n[OK] Optimizacion completada")
    print(f"   Mejor Kappa: {study.best_value:.4f}")
    print("   Mejores hiperparametros:")
    for key, value in study.best_params.items():
        print(f"      {key}: {value}")

    # Guardar resultados
    with open(f"{OUTPUT_PATH}/optuna_lstm_results.json", "w") as f:
        json.dump(
            {
                "best_value": study.best_value,
                "best_params": study.best_params,
            },
            f,
            indent=2,
        )
else:
    print(
        "[SKIP] Optimizacion deshabilitada. Cambiar CONFIG['run_optimization'] = True para ejecutar."
    )

## Guardar Modelo y Resultados

In [None]:
# ============================================================
# SAVE MODEL AND ARTIFACTS
# ============================================================


def _to_jsonable(value):
    """Fallback encoder for values the json module cannot serialize itself.

    Objects exposing ``tolist()`` (NumPy arrays and scalars, for example) are
    converted to plain Python lists/numbers so they stay numeric in the
    output. Anything else falls back to its ``str()`` form.
    """
    if hasattr(value, "tolist"):
        return value.tolist()
    return str(value)


# Save the full model
model.save(f"{OUTPUT_PATH}/{model_name}_final.keras")
print(f"[OK] Modelo guardado: {OUTPUT_PATH}/{model_name}_final.keras")

# Save the training history
history_df = pd.DataFrame(history.history)
history_df.to_csv(f"{OUTPUT_PATH}/{model_name}_history.csv", index=False)

# Collect metrics, configuration, and dataset sizes in one record
results = {
    "model_name": model_name,
    "config": CONFIG,
    "metrics": {
        "accuracy": float(accuracy),
        "kappa": float(kappa),
        "f1_macro": float(f1_macro),
        "f1_weighted": float(f1_weighted),
    },
    "dataset": {
        "train_samples": len(X_train_lstm),
        "val_samples": len(X_val_lstm),
        "test_samples": len(X_test_lstm),
    },
    # NOTE(review): channel_stats may hold NumPy values; _to_jsonable keeps
    # them numeric instead of turning them into strings. Confirm its type.
    "channel_stats": channel_stats,
    "label_encoder_classes": list(le.classes_),
}

with open(f"{OUTPUT_PATH}/{model_name}_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, default=_to_jsonable)

print(f"[OK] Resultados guardados: {OUTPUT_PATH}/{model_name}_results.json")

# Save the LabelEncoder
with open(f"{OUTPUT_PATH}/{model_name}_label_encoder.pkl", "wb") as f:
    pickle.dump(le, f)

print("\nArchivos generados:")
print(f"   - {model_name}_final.keras (modelo)")
print(f"   - {model_name}_best.keras (mejor checkpoint)")
print(f"   - {model_name}_history.csv (historial)")
print(f"   - {model_name}_results.json (metricas y config)")
print(f"   - {model_name}_label_encoder.pkl (encoder)")
print(f"   - {model_name}_training_curves.png")
print(f"   - {model_name}_confusion_matrix.png")

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

# Assemble the report first, then emit it with a single print call.
summary_lines = [
    "\n" + "=" * 60,
    "ENTRENAMIENTO LSTM COMPLETADO",
    "=" * 60,
    "\nResultados finales en Test Set:",
    f"   Accuracy:    {accuracy:.4f} ({accuracy*100:.2f}%)",
    f"   Cohen Kappa: {kappa:.4f}",
    f"   F1 Macro:    {f1_macro:.4f}",
    f"   F1 Weighted: {f1_weighted:.4f}",
    f"\nModelo guardado en: {OUTPUT_PATH}",
    "=" * 60,
]
print("\n".join(summary_lines))

In [None]:
# ============================================================
# COMPRESS ARTIFACTS FOR DOWNLOAD
# ============================================================

zip_path = f"{OUTPUT_PATH}/{model_name}_artifacts.zip"
zip_name = os.path.basename(zip_path)

# List the artifacts BEFORE creating the archive and leave the archive itself
# out: it lives in OUTPUT_PATH and its name also starts with model_name, so it
# would otherwise be added to its own contents (as would a stale archive from
# a previous run). Sorting keeps the entry order stable.
artifact_names = sorted(
    fname
    for fname in os.listdir(OUTPUT_PATH)
    if fname.startswith(model_name) and fname != zip_name
)

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for fname in artifact_names:
        zf.write(os.path.join(OUTPUT_PATH, fname), arcname=fname)

print(f"[OK] Artefactos comprimidos en: {zip_path}")
print("Archivos incluidos:")
with zipfile.ZipFile(zip_path, "r") as zf:
    for info in zf.infolist():
        print(f" - {info.filename} ({info.file_size/1024:.1f} KB)")