# 03 — Entrenamiento y Evaluación (SUPERVISED y CONTINUAL con EWC/NAIVE)

Este notebook entrena un modelo **SNN** para **regresión del ángulo de dirección (steering)** en dos protocolos:

- **Supervised** sobre `circuito1`.
- **Continual** con dos tareas secuenciales `circuito1 → circuito2` usando:
  - **EWC** (consolidación elástica de pesos), o
  - **NAIVE** (baseline sin penalización; equivalente a λ=0).

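> **Requisitos previos**: Ejecuta `01_DATA_QC_PREP.ipynb` para generar `train/val/test.csv` y `tasks.json`. Si vas a usar el modo OFFLINE balanceado (`USE_OFFLINE_BALANCED=True`), ejecuta además `01A_PREP_BALANCED.ipynb` (o `tools/make_splits_balanced.py`): en ese modo este notebook lee `tasks_balanced.json` y exige un `train_balanced.csv` por tarea.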


In [1]:
# =============================================================================
# Imports y setup
# =============================================================================
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from pathlib import Path
import sys, json, torch

# Locate the repository root: when the working directory is named "notebooks",
# use its parent; otherwise use the working directory itself.
ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
# Add the root to sys.path (only once) ahead of the project imports that follow.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Utilidades y módulos del proyecto
from src.datasets import AugmentConfig
from src.utils import set_seeds, load_preset, make_loaders_from_csvs
from src.datasets import ImageTransform
from src.models import SNNVisionRegressor
from src.training import TrainConfig, train_supervised, _permute_if_needed
from src.methods.ewc import EWC, EWCConfig

# Compute device: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): bare expression in the middle of the cell. A notebook only
# displays the value of a cell's last expression, so this line shows nothing.
ROOT, device

# Seed passed to the loader helper in the smoke-test cell.
SEED = 42

# Image transform.
# IMPORTANT: pass the arguments positionally (w, h, to_gray, crop_top).
# Avoid keywords such as target_w/target_h because the class does not define them.
tfm = ImageTransform(160, 80, True, None)

def make_snn_model(tfm):
    """Build an ``SNNVisionRegressor`` whose input channels follow ``tfm``.

    One input channel is used when ``tfm.to_gray`` is truthy and three
    otherwise; ``lif_beta`` is fixed at 0.95.
    """
    if tfm.to_gray:
        n_channels = 1
    else:
        n_channels = 3
    return SNNVisionRegressor(in_channels=n_channels, lif_beta=0.95)

# NOTE(review): the top of this cell sets OMP_NUM_THREADS / MKL_NUM_THREADS /
# OPENBLAS_NUM_THREADS to "1" while this line asks torch for 4 threads —
# confirm which thread count is intended.
torch.set_num_threads(4)               # avoid CPU over-subscription while decoding
torch.backends.cudnn.benchmark = True  # pick the best conv implementation for fixed-size inputs

# Allow TF32 (cheap; usually speeds up matmul/convs on Ampere+ without affecting FP16 precision)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# "high" precision mode for FP32 matmul kernels when AMP is not used
torch.set_float32_matmul_precision("high")



In [2]:
GPU_ENCODE = True  # enable spike encoding (rate/latency/raw) on the GPU
RUN_BENCH = False   # set to True to run the benchmark

# --- SAFE MODE: switches off everything heavy to stabilise the run ---
SAFE_MODE = False  # True applies the conservative overrides below; False keeps the performance settings

# DataLoader settings used when SAFE_MODE is off.
NUM_WORKERS   = 12
PREFETCH      = 2
PIN_MEMORY    = True
PERSISTENT    = True

# AUG_CFG = None
from src.datasets import AugmentConfig
# --- Augmentation profiles ---
AUG_CFG_LIGHT = AugmentConfig(prob_hflip=0.5, brightness=None, gamma=None, noise_std=0.0)
AUG_CFG_FULL  = AugmentConfig(prob_hflip=0.5, brightness=(0.9, 1.1), gamma=(0.95, 1.05), noise_std=0.005)

AUG_CFG = AUG_CFG_LIGHT   # use LIGHT for now; try FULL later

# Balancing switches: OFFLINE selects tasks_balanced.json; ONLINE is forwarded
# to the loaders. The notebook refuses to run with both enabled.
USE_OFFLINE_BALANCED  = True
USE_ONLINE_BALANCING  = False

if SAFE_MODE:
    NUM_WORKERS  = 0
    PREFETCH     = None   # very important with num_workers=0
    PIN_MEMORY   = False
    PERSISTENT   = False  # also mandatory with num_workers=0
    USE_OFFLINE_BALANCED = False
    USE_ONLINE_BALANCING = False
    AUG_CFG = None

# Echo the effective loader configuration.
print(f"[SAFE_MODE={SAFE_MODE}] workers={NUM_WORKERS} prefetch={PREFETCH} "
      f"pin={PIN_MEMORY} persistent={PERSISTENT}")



[SAFE_MODE=False] workers=12 prefetch=2 pin=True persistent=True


In [3]:
# =============================================================================
# Verificación de datos (normal y, si existe, balanceado offline)
# =============================================================================
from pathlib import Path

# Locations of the raw recordings and of the processed CSV splits.
RAW = ROOT / "data" / "raw" / "udacity"
PROC = ROOT / "data" / "processed"

RUNS = ["circuito1", "circuito2"]  # adjust if needed

# Mandatory check: every run must provide the three regular splits.
missing = [
    str(PROC / run / f"{part}.csv")
    for run in RUNS
    for part in ("train", "val", "test")
    if not (PROC / run / f"{part}.csv").exists()
]

# Optional check: train_balanced.csv (only needed for the OFFLINE balanced mode).
for run in RUNS:
    p_bal = PROC / run / "train_balanced.csv"
    if not p_bal.exists():
        print(f"⚠️  Falta {p_bal}. Si más abajo pones USE_OFFLINE_BALANCED=True, "
              f"ejecuta 01A_PREP_BALANCED.ipynb o el script tools/make_splits_balanced.py")
        continue
    print(f"✓ {p_bal} OK")

# Abort with the full list of absent mandatory files, one per line.
if missing:
    raise FileNotFoundError(
        "Faltan CSV obligatorios (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep):\n"
        + "\n".join(" - " + m for m in missing)
    )

print("OK: splits 'train/val/test' encontrados.")


✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito1/train_balanced.csv OK
✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito2/train_balanced.csv OK
OK: splits 'train/val/test' encontrados.


In [4]:
# ===================== Balanceo: helper =====================
# Report which balancing mode is active.
if USE_OFFLINE_BALANCED:
    _mode_label = "OFFLINE (tasks_balanced.json)"
else:
    _mode_label = "ORIGINAL (tasks.json)"
print("Modo balanceo:", _mode_label, "| Balanceo ONLINE:", USE_ONLINE_BALANCING)

# Guard against balancing twice (offline CSV plus online balancing).
if USE_ONLINE_BALANCING and USE_OFFLINE_BALANCED:
    raise RuntimeError("Doble balanceo detectado: OFFLINE y ONLINE a la vez. "
                       "Pon USE_ONLINE_BALANCING=False cuando uses train_balanced.csv.")

from pathlib import Path  # (omite esta línea si ya importaste Path arriba)

def _balance_flag(train_csv_path: str | Path) -> bool:
    """
    Activa balanceo ONLINE solo si:
    - USE_ONLINE_BALANCING == True
    - Y el CSV de train NO es 'train_balanced.csv'
    """
    is_balanced_csv = Path(train_csv_path).name == "train_balanced.csv"
    return bool(USE_ONLINE_BALANCING and not is_balanced_csv)


Modo balanceo: OFFLINE (tasks_balanced.json) | Balanceo ONLINE: False


In [5]:
# =============================================================================
# Función para crear loaders de una tarea dada (respeta cfg del preset)
# =============================================================================
# Sentinel meaning "read the notebook-level setting when the function is called".
_NOTEBOOK_DEFAULT = object()


def make_loader_fn(task, batch_size, encoder, T, gain, tfm, seed,
                   num_workers=_NOTEBOOK_DEFAULT, prefetch_factor=_NOTEBOOK_DEFAULT,
                   pin_memory=_NOTEBOOK_DEFAULT, persistent_workers=_NOTEBOOK_DEFAULT):
    """Create the (train, val, test) loaders of one task via ``make_loaders_from_csvs``.

    Args:
        task: dict with ``"name"`` (sub-folder of ``RAW``) and ``"paths"``
            (mapping with the ``"train"``, ``"val"`` and ``"test"`` CSV paths).
        batch_size, encoder, T, gain, tfm, seed: forwarded to
            ``make_loaders_from_csvs``; ``encoder`` is replaced by ``"image"``
            when ``GPU_ENCODE`` is truthy.
        num_workers, prefetch_factor, pin_memory, persistent_workers: DataLoader
            options. When omitted they are read from ``NUM_WORKERS``,
            ``PREFETCH``, ``PIN_MEMORY`` and ``PERSISTENT`` at call time, so
            re-running the configuration cell (e.g. toggling SAFE_MODE) takes
            effect without redefining this function.

    Returns:
        Whatever ``make_loaders_from_csvs`` returns (the notebook unpacks it as
        ``tr, va, te``).

    NOTE(review): the smoke-test output recorded in this notebook is a DataLoader
    worker error, "Encoder desconocido: image", raised from src/datasets.py.
    Confirm that the dataset accepts ``encoder="image"`` before relying on
    ``GPU_ENCODE=True``.
    """
    from pathlib import Path

    # Resolve omitted options against the current notebook configuration.
    if num_workers is _NOTEBOOK_DEFAULT:
        num_workers = NUM_WORKERS
    if prefetch_factor is _NOTEBOOK_DEFAULT:
        prefetch_factor = PREFETCH
    if pin_memory is _NOTEBOOK_DEFAULT:
        pin_memory = PIN_MEMORY
    if persistent_workers is _NOTEBOOK_DEFAULT:
        persistent_workers = PERSISTENT

    name  = task["name"]
    paths = task["paths"]

    # Persistent workers and prefetching only make sense with worker processes.
    pw = persistent_workers and (num_workers > 0)
    pf = prefetch_factor if (num_workers > 0) else None

    # KEY: when encoding on the GPU the loader must yield plain images (B,C,H,W).
    encoder_for_loader = "image" if GPU_ENCODE else encoder

    return make_loaders_from_csvs(
        base_dir=RAW/name,
        train_csv=Path(paths["train"]),
        val_csv=Path(paths["val"]),
        test_csv=Path(paths["test"]),
        batch_size=batch_size,
        encoder=encoder_for_loader,
        T=T,
        gain=gain,
        tfm=tfm,
        seed=seed,
        # ---- WSL stability patch ----
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=pw,
        prefetch_factor=pf,
        aug_train=AUG_CFG,
        # Online balancing is skipped for an already balanced train CSV.
        balance_train=_balance_flag(paths["train"]),
        balance_bins=21,
    )

print(f"[Loaders] workers={NUM_WORKERS} prefetch={PREFETCH} "
      f"pin={PIN_MEMORY} persistent={PERSISTENT} "
      f"offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}")


[Loaders] workers=12 prefetch=2 pin=True persistent=True offline_bal=True online_bal=False


In [6]:
import src.training as training

# =============================================================================
# Helper de evaluación (permuta a (T,B,C,H,W) y usa copias no bloqueantes)
# =============================================================================
def eval_loader(loader, model, device):
    """Return ``(MAE, MSE)`` averaged per sample over every batch of ``loader``.

    Each batch goes through ``training._permute_if_needed`` before being moved
    to ``device``. Per-batch means are weighted by the batch length so that a
    smaller final batch does not skew the result. An empty loader gives
    ``(0.0, 0.0)``.
    """
    model.eval()  # evaluation mode
    abs_total = 0.0
    sq_total = 0.0
    seen = 0

    # A single no_grad() context around the whole loop.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = training._permute_if_needed(inputs).to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            residual = model(inputs) - targets
            batch_n = len(targets)

            abs_total += residual.abs().mean().item() * batch_n
            sq_total += (residual ** 2).mean().item() * batch_n
            seen += batch_n

    denom = max(seen, 1)  # avoid dividing by zero on an empty loader
    return (abs_total / denom), (sq_total / denom)


In [7]:
# =============================================================================
# Elegir split: normal (tasks.json) o balanceado offline (tasks_balanced.json)
# =============================================================================
# Pick the task description file that matches the balancing mode and load it.
_tasks_filename = "tasks_balanced.json" if USE_OFFLINE_BALANCED else "tasks.json"
with open(PROC / _tasks_filename, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)

# One entry per task, in the order declared by the JSON file.
task_list = []
for _task_name in tasks_json["tasks_order"]:
    task_list.append({"name": _task_name, "paths": tasks_json["splits"][_task_name]})

# Quick view: the train CSV that each task will use.
print("Tareas y su TRAIN CSV:")
for t in task_list:
    train_csv_name = Path(t["paths"]["train"]).name
    print(f" - {t['name']}: {train_csv_name}")

# Extra guard rail: in OFFLINE balanced mode every task must point at an
# existing file called train_balanced.csv.
if USE_OFFLINE_BALANCED:
    for t in task_list:
        train_path = Path(t["paths"]["train"])
        wrong_name = train_path.name != "train_balanced.csv"
        if wrong_name:
            raise RuntimeError(
                f"[{t['name']}] Esperaba 'train_balanced.csv' pero encontré '{train_path.name}'. "
                "Repite 01A_PREP_BALANCED.ipynb o ajusta USE_OFFLINE_BALANCED=False."
            )
        if not train_path.exists():
            raise FileNotFoundError(
                f"[{t['name']}] No existe {train_path}. Genera los balanceados con 01A_PREP_BALANCED.ipynb."
            )
    print("✔ Verificación OFFLINE balanceado superada (train_balanced.csv por tarea).")


Tareas y su TRAIN CSV:
 - circuito1: train_balanced.csv
 - circuito2: train_balanced.csv
✔ Verificación OFFLINE balanceado superada (train_balanced.csv por tarea).


In [8]:
# === PRUEBA UNIVERSAL: loader -> (T,B,C,H,W) -> forward con fallback AMP ===
import torch, src.training as training

# --- 1) Loader pequeño con tu helper ---
# --- 1) Small loader built with the notebook helper ---
# NOTE(review): the output recorded for this cell is a DataLoader worker error,
# "Encoder desconocido: image", raised from src/datasets.py. With GPU_ENCODE=True
# make_loader_fn requests encoder="image"; confirm that the dataset supports it.
tr, va, te = make_loader_fn(
    task=task_list[0],
    batch_size=8,
    encoder="rate",   # if the pipeline already returns 4D batches, that is detected below
    T=10,
    gain=0.5,
    tfm=tfm,
    seed=SEED,
)

# Peek at one training batch to learn its layout.
xb, yb = next(iter(tr))
print("batch del loader:", xb.shape, yb.shape)

# --- 2) Convert to (T,B,C,H,W) depending on the batch layout ---
#    - 5D batch (dataset already encoded): permute only.
#    - 4D batch (image): enable runtime encoding and use the runtime helper.
if xb.ndim == 5:  # (B,T,C,H,W)
    x5d = xb.permute(1,0,2,3,4).contiguous()
    used_runtime_encode = False
    print("dataset ya codificado; solo permuto a (T,B,C,H,W)")
elif xb.ndim == 4:  # (B,C,H,W)
    training.set_runtime_encode(mode="rate", T=10, gain=0.5,
                                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    x5d = training._permute_if_needed(xb)  # presumably encodes and permutes to (T,B,C,H,W) — confirm in src.training
    used_runtime_encode = True
    print("dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)")
else:
    raise RuntimeError(f"Forma inesperada del batch: {xb.shape}")

print("x5d.device:", x5d.device, "| shape:", tuple(x5d.shape))

# --- 3) Model in eval mode on the target device (the forward pass follows) ---
model = make_snn_model(tfm).to(device).eval()

def forward_with_auto_amp(model, x5d, device):
    """Run one inference forward pass, preferring CUDA AMP with an FP32 fallback.

    Args:
        model: callable module, already placed on ``device``.
        x5d: input tensor; it is moved to ``device`` and cast here.
        device: target device (``torch.device`` or a device string).

    Returns:
        The model output, computed under ``torch.inference_mode()``.

    AMP is attempted only when ``device`` is a CUDA device and CUDA is
    available. Gating on CUDA availability alone would send an fp16 input to a
    CPU model whenever the machine merely has a GPU, producing a spurious
    "AMP failed" message before the FP32 retry.
    """
    wants_cuda = torch.device(device).type == "cuda"

    # Attempt 1: AMP (CUDA targets only).
    if wants_cuda and torch.cuda.is_available():
        try:
            x_amp = x5d.to(device, dtype=torch.float16, non_blocking=True)
            with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
                y = model(x_amp)
            print("[forward] ejecutado con AMP (fp16)")
            return y
        except Exception as e:
            # Deliberate best effort: any AMP failure falls through to FP32.
            print("[forward] AMP falló, reintento en FP32. Motivo:", str(e))

    # Attempt 2: FP32 (CPU target or fallback).
    x_fp32 = x5d.to(device, dtype=torch.float32, non_blocking=True)
    with torch.inference_mode():
        y = model(x_fp32)
    print("[forward] ejecutado en FP32")
    return y

# Run the forward pass; the runtime-encode state is restored even if it raises,
# so a failure here cannot leave encoding enabled for later cells.
try:
    yhat = forward_with_auto_amp(model, x5d, device)
    print("yhat:", tuple(yhat.shape))
finally:
    # --- 4) Runtime-encode cleanup (only if it was enabled above) ---
    if used_runtime_encode:
        training.set_runtime_encode(None)


ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/cesar/proyectos/TFM_SNN/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/cesar/proyectos/TFM_SNN/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/home/cesar/proyectos/TFM_SNN/src/datasets.py", line 365, in __getitem__
    X = self._encode(x_img)  # (T, C, H, W) con C=1 o 3
        ^^^^^^^^^^^^^^^^^^^
  File "/home/cesar/proyectos/TFM_SNN/src/datasets.py", line 297, in _encode
    raise ValueError(f"Encoder desconocido: {self.cfg.encoder}")
ValueError: Encoder desconocido: image


In [None]:
# ===================== BENCH: toggle and configuration echo =====================
# Uses the RUN_BENCH flag defined in the configuration cell (cell 2)
print(
    f"[Bench RUN_BENCH={RUN_BENCH}] workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT} "
    f"| offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}"
)


In [None]:
# NOTE(review): this cell repeats the configuration cell near the top of the
# notebook; running it resets every setting below to the values written here.
GPU_ENCODE = True  # enable spike encoding (rate/latency/raw) on the GPU
RUN_BENCH = False   # set to True to run the benchmark

# --- SAFE MODE: switches off everything heavy to stabilise the run ---
SAFE_MODE = False  # True applies the conservative overrides below; False keeps the performance settings

# DataLoader settings used when SAFE_MODE is off.
NUM_WORKERS   = 12
PREFETCH      = 2
PIN_MEMORY    = True
PERSISTENT    = True

# AUG_CFG = None
from src.datasets import AugmentConfig
# --- Augmentation profiles ---
AUG_CFG_LIGHT = AugmentConfig(prob_hflip=0.5, brightness=None, gamma=None, noise_std=0.0)
AUG_CFG_FULL  = AugmentConfig(prob_hflip=0.5, brightness=(0.9, 1.1), gamma=(0.95, 1.05), noise_std=0.005)

AUG_CFG = AUG_CFG_LIGHT   # use LIGHT for now; try FULL later

# Balancing switches: OFFLINE selects tasks_balanced.json; ONLINE is forwarded
# to the loaders. The notebook refuses to run with both enabled.
USE_OFFLINE_BALANCED  = True
USE_ONLINE_BALANCING  = False

if SAFE_MODE:
    NUM_WORKERS  = 0
    PREFETCH     = None   # very important with num_workers=0
    PIN_MEMORY   = False
    PERSISTENT   = False  # also mandatory with num_workers=0
    USE_OFFLINE_BALANCED = False
    USE_ONLINE_BALANCING = False
    AUG_CFG = None

# Echo the effective loader configuration.
print(f"[SAFE_MODE={SAFE_MODE}] workers={NUM_WORKERS} prefetch={PREFETCH} "
      f"pin={PIN_MEMORY} persistent={PERSISTENT}")



In [None]:
# ====================== run_continual (versión conservadora) ======================
from pathlib import Path
import json, time, torch
from src.utils import load_preset, set_seeds
from src.training import TrainConfig, train_supervised
from src.models import SNNVisionRegressor
from src.methods.ewc import EWC, EWCConfig
import src.training as training  # para _permute_if_needed y runtime encode

def _lambda_tag(lam: float) -> str:
    """Format λ in scientific notation with the fewest digits that round-trip.

    One-significant-digit values keep the historical ``.0e`` spelling
    (``1e9 -> "1e+09"``), while values such as ``1.2e9`` become ``"1.2e+09"``
    instead of collapsing onto ``"1e+09"``.
    """
    for digits in range(17):
        tag = f"{lam:.{digits}e}"
        if float(tag) == lam:
            return tag
    return repr(lam)  # only reached by values that never round-trip (e.g. NaN)


def run_continual(
    preset: str,                 # "fast" | "std" | "accurate"
    method: str,                 # "ewc" | "naive"
    lam: float | None,           # λ for EWC; None for naive
    seed: int,
    encoder: str,                # "rate" | "latency" | "raw"
    tfm,                         # ImageTransform
    fisher_batches_by_preset: dict[str,int] | None = None,
    epochs_override: int | None = None,   # optional override of the preset's epochs
    verbose: bool = False,                # optional tracing
):
    """Train one model sequentially on every task of ``task_list`` and evaluate it.

    After each task the model is evaluated on that task's test loader and
    re-evaluated on the test loaders of all earlier tasks (to measure
    forgetting). With ``method="ewc"`` the EWC object is passed to the trainer
    and its Fisher information is estimated on the validation loader at the end
    of every task.

    Args:
        preset: preset name looked up in ``configs/presets.yaml`` (provides
            ``T``, ``gain``, ``lr``, ``epochs``, ``batch_size`` and ``amp``).
        method: ``"ewc"`` or ``"naive"``.
        lam: EWC strength; required for ``"ewc"`` and ignored for ``"naive"``.
        seed: seed given to ``set_seeds``, the loaders and ``TrainConfig``.
        encoder: encoder name used for the loaders / runtime encoding.
        tfm: image transform used by the model factory and the loaders.
        fisher_batches_by_preset: optional per-preset number of Fisher batches
            (100 when the preset is not listed).
        epochs_override: when given, replaces the preset's number of epochs.
        verbose: print per-task progress details.

    Returns:
        ``(out_dir, results)``. ``results`` maps each task name to
        ``{"test_mae", "test_mse"}`` plus ``"after_<task>_mae"`` /
        ``"after_<task>_mse"`` for every later task. The same dict is written to
        ``out_dir / "continual_results.json"``.

    Raises:
        ValueError: for an unknown ``method`` or for ``method="ewc"`` without ``lam``.
    """
    # Validate with real exceptions (assert statements vanish under `python -O`).
    if method not in ("ewc", "naive"):
        raise ValueError(f"Método desconocido: {method!r} (usa 'ewc' o 'naive')")
    if method == "ewc" and lam is None:
        raise ValueError("Para EWC debes pasar λ (lam)")

    cfg    = load_preset(ROOT/"configs"/"presets.yaml", preset)
    T      = int(cfg["T"])
    gain   = float(cfg["gain"])
    lr     = float(cfg["lr"])
    epochs = int(epochs_override if epochs_override is not None else cfg["epochs"])
    bs     = int(cfg["batch_size"])
    use_amp= bool(cfg["amp"])

    fb = 100
    if fisher_batches_by_preset and preset in fisher_batches_by_preset:
        fb = int(fisher_batches_by_preset[preset])

    set_seeds(seed)

    model = make_snn_model(tfm)
    ewc = None
    lam_suffix = ""
    if method == "ewc":
        lam_value = float(lam)
        ewc = EWC(model, EWCConfig(lambd=lam_value, fisher_batches=fb))
        # A fixed ".0e" format maps e.g. 1e9 and 1.2e9 to the same directory,
        # so one run would silently overwrite the other.
        lam_suffix = f"_lam_{_lambda_tag(lam_value)}"

    out_tag = f"continual_{preset}_{method}{lam_suffix}_{encoder}_seed_{seed}"
    out_dir = ROOT/"outputs"/out_tag
    out_dir.mkdir(parents=True, exist_ok=True)

    device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = torch.nn.MSELoss()
    tcfg    = TrainConfig(epochs=epochs, batch_size=bs, lr=lr, amp=use_amp, seed=seed)

    results = {}
    seen = []  # (task name, test loader) of every task trained so far

    for i, t in enumerate(task_list):
        name = t["name"]
        if verbose:
            print(f"\n--- Tarea {i+1}/{len(task_list)}: {name} | preset={preset} | method={method} "
                  f"| λ={lam if method=='ewc' else '-'} | B={bs} T={T} AMP={use_amp} | enc={encoder} ---")

        # With GPU_ENCODE=True the loader helper requests encoder="image" (4D batches).
        tr, va, te = make_loader_fn(
            task=t, batch_size=bs, encoder=encoder, T=T, gain=gain, tfm=tfm, seed=seed,
        )

        # Inspect the first batch to decide whether runtime encoding is needed.
        xb0, yb0 = next(iter(tr))
        if verbose:
            print("  loader batch shape:", tuple(xb0.shape), "| y:", tuple(yb0.shape))

        # Runtime encoding is enabled only when it was requested (GPU_ENCODE)
        # and the batch is 4D (B,C,H,W), i.e. not yet temporally encoded.
        used_runtime = False
        if globals().get("GPU_ENCODE", False) and xb0.ndim == 4:
            training.set_runtime_encode(mode=encoder, T=T, gain=gain, device=device)
            used_runtime = True
            if verbose: print("  runtime encode: ON (GPU)")

        try:
            # Look the trainer up on the module at call time so that a
            # replacement assigned to `training.train_supervised` is honoured.
            _ = training.train_supervised(
                model, tr, va, loss_fn, tcfg,
                out_dir/f"task_{i+1}_{name}",
                method=ewc if method=="ewc" else None
            )

            # EWC: estimate Fisher at the end of the task.
            if method=="ewc":
                print("Estimando Fisher…")
                ewc.estimate_fisher(va, loss_fn, device=device)

            # Evaluate the task that was just trained.
            te_mae, te_mse = eval_loader(te, model, device)
            results[name] = {"test_mae": te_mae, "test_mse": te_mse}
            seen.append((name, te))

            # Re-evaluate every earlier task (forgetting).
            for pname, p_loader in seen[:-1]:
                p_mae, p_mse = eval_loader(p_loader, model, device)
                results[pname][f"after_{name}_mae"] = p_mae
                results[pname][f"after_{name}_mse"] = p_mse
        finally:
            # Always switch runtime encoding off again, even if the task
            # failed, so the global state does not leak into other cells.
            if used_runtime:
                training.set_runtime_encode(None)
                if verbose: print("  runtime encode: OFF")

    (out_dir/"continual_results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    return out_dir, results



In [None]:
# === Activar métrica de it/s por época (parche temporal) ===
import time, json
from pathlib import Path
import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler
import src.training as training
from src.utils import set_seeds  # ya lo tienes importado en el notebook

# Keep a reference to the original trainer so it can be restored later.
# Guarded: when this cell is re-run while the patch is already active,
# `training.train_supervised` is the patched function, and saving it would
# lose the real original.
if getattr(training.train_supervised, "__name__", "") != "train_supervised_ips":
    orig_train_supervised = training.train_supervised

def train_supervised_ips(model: nn.Module, train_loader, val_loader, loss_fn: nn.Module,
                         cfg, out_dir: Path, method=None):
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    if cfg.seed is not None:
        set_seeds(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr)

    use_amp = bool(cfg.amp and torch.cuda.is_available())
    scaler = GradScaler(enabled=use_amp)

    history = {"train_loss": [], "val_loss": []}
    t0_total = time.perf_counter()

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        running = 0.0
        nb = 0
        t_epoch0 = time.perf_counter()

        for x, y in train_loader:
            # encode/permutación runtime y subida a device (usa tu helper actual)
            x = training._permute_if_needed(x).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=use_amp):
                y_hat = model(x)
                loss = loss_fn(y_hat, y)
                if method is not None:
                    loss = loss + method.penalty()

            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(opt); scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                opt.step()

            running += loss.item()
            nb += 1

        epoch_time = time.perf_counter() - t_epoch0
        ips = nb / epoch_time if epoch_time > 0 else float("nan")
        print(f"[TRAIN it/s] epoch {epoch}/{cfg.epochs}: {ips:.1f} it/s  "
              f"({nb} iters en {epoch_time:.2f}s)")

        train_loss = running / max(1, nb)

        # --- validación ---
        model.eval()
        v_running = 0.0; nvb = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = training._permute_if_needed(x).to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with autocast("cuda", enabled=use_amp):
                    y_hat = model(x)
                    v_loss = loss_fn(y_hat, y)
                v_running += v_loss.item(); nvb += 1
        val_loss = v_running / max(1, nvb)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

    elapsed = time.perf_counter() - t0_total
    manifest = {
        "epochs": cfg.epochs, "batch_size": cfg.batch_size, "lr": cfg.lr,
        "amp": cfg.amp, "seed": cfg.seed, "elapsed_sec": elapsed,
        "device": str(device), "history": history,
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return history

# Activate the patch.
# NOTE: this rebinds the attribute on the `src.training` module only. Code that
# reaches the trainer as `training.train_supervised` sees the patch; a name
# bound earlier with `from src.training import train_supervised` still refers
# to the original function.
training.train_supervised = train_supervised_ips
print("✅ it/s por época ACTIVADO. Para desactivarlo: training.train_supervised = orig_train_supervised")


In [None]:
# ===================== Demo smoke (naive + ewc, 2 épocas) =====================
RUN_DEMO = False
if RUN_DEMO:
    preset_demo = "std"
    seed_demo = 42
    enc_demo = load_preset(ROOT / "configs" / "presets.yaml", preset_demo)["encoder"]

    # Same smoke run (2 epochs) twice: first the naive baseline, then EWC.
    _demo_cases = (
        ("\n>>> NAIVE (smoke)", "naive", None),
        ("\n>>> EWC (smoke)", "ewc", 1e9),
    )
    for _banner, _method, _lam in _demo_cases:
        print(_banner)
        out_path, res = run_continual(
            preset=preset_demo,
            method=_method,
            lam=_lam,
            seed=seed_demo,
            encoder=enc_demo,
            tfm=tfm,
            fisher_batches_by_preset={"std": 600},
            epochs_override=2,   # quick smoke run
            verbose=True,
        )
        print("OK:", out_path)


In [None]:
# Mini-sweep de λ en preset std (rápido)
# lams = [3e8, 5e8, 7e8]
# out_runs = []
# for lam in lams:
#     print(f"\n>>> EWC SMOKE λ={lam:.0e}")
#     out_dir, _ = run_continual(
#         preset="std",
#         method="ewc",
#         lam=lam,
#         seed=42,
#         encoder="rate",
#         tfm=tfm,
#         fisher_batches_by_preset={"std": 600},
#         epochs_override=2,      # smoke rápido
#         verbose=True,
#     )
#     out_runs.append(out_dir)
# print("\nHecho:", out_runs)


In [None]:
# ====================== CONFIG POR DEFECTO PARA LAS COMPARATIVAS ======================
# Recomendaciones acordadas:
# - fast: EWC λ=1e9 (estable). Extra: λ=1e8 (mejor T2, algo más de olvido)
# - std : EWC λ=1e7 (baseline estable). Extra: λ=3e7 (mejor T2, más olvido)

# EWC_DEFAULTS = {
#     "fast":     {"primary": [1e9, 3e8],      "extra": [3e9]},
#     "std": {"primary": [7e8, 1e9, 1.2e9, 1.5e9], "extra": []},
#     "accurate": {"primary": [1e6, 3e6, 1e7], "extra": []},  # ← mucho más bajo que 1e9
# }

# λ values per preset: "primary" always runs, "extra" only with INCLUDE_EXTRAS.
# (Alternative kept for reference: "accurate": {"primary": [], "extra": []}
#  to run accurate with NAIVE only.)
EWC_DEFAULTS = {
    "fast":     {"primary": [7e8, 1e9, 1.2e9], "extra": []},
    "std":      {"primary": [5e8, 7e8, 1e9], "extra": []},
    "accurate": {"primary": [1e6, 3e6, 1e7], "extra": []},
}

INCLUDE_NAIVE = True     # add the baseline without EWC
INCLUDE_EXTRAS = False   # enable the "extra" λ values of each preset
SEEDS = [42]             # e.g. [42, 43, 44] for multi-seed mean/σ
ENCODERS = ["rate"]      # "latency" can be added later

# Fisher batches per preset. Earlier settings kept for reference:
#   {"fast": 200, "std": 600, "accurate": 600}
#   {"fast": 800, "std": 1000}
#   {"fast": 800, "std": 1200, "accurate": 1500}
FISHER_BY_PRESET = {"fast": 800, "std": 1000, "accurate": 1500}

# Presets to launch (e.g. ["accurate"] to run a single preset).
PRESETS_TO_RUN = ["fast", "std", "accurate"]

# ---- Build the execution plan: (preset, method, λ) tuples ----
runs_plan = []
for preset_i in PRESETS_TO_RUN:
    lambdas = list(EWC_DEFAULTS[preset_i]["primary"])
    if INCLUDE_EXTRAS:
        lambdas += EWC_DEFAULTS[preset_i]["extra"]
    runs_plan.extend((preset_i, "ewc", lam_value) for lam_value in lambdas)
    if INCLUDE_NAIVE:
        # Baseline without EWC goes last within each preset.
        runs_plan.append((preset_i, "naive", None))

print("Plan de runs (preset, método, λ):")
for preset_i, method_i, lam_i in runs_plan:
    print(f"  {preset_i:>7}  {method_i:>5}  λ={lam_i}")
print("Semillas:", SEEDS, " | Encoders:", ENCODERS)
print("Fisher batches por preset:", FISHER_BY_PRESET)


In [None]:
# ====================== DRIVER MULTISEMILLAS (usa la CONFIG de arriba) ======================
# Flatten encoder × seed × plan into one schedule (same order as nested loops).
_schedule = [
    (enc_name, seed_value, run)
    for enc_name in ENCODERS
    for seed_value in SEEDS
    for run in runs_plan
]

for enc, seed, (preset_i, method_i, lam_i) in _schedule:
    print(f"\n=== RUN: preset={preset_i} | method={method_i} | λ={lam_i} | seed={seed} | encoder={enc} ===")
    lam_arg = None if method_i != "ewc" else lam_i  # λ only applies to EWC runs
    out_path, _ = run_continual(
        preset=preset_i,
        method=method_i,
        lam=lam_arg,
        seed=seed,
        encoder=enc,
        tfm=tfm,  # defined in the setup cell
        fisher_batches_by_preset=FISHER_BY_PRESET,
    )
    print("OK:", out_path)
print("\n✅ Listo. Ejecuta las celdas de resumen.")


In [None]:
# =============================================================================
# Resumen comparativo de todos los continual_* en outputs/
# =============================================================================
import re, json
from pathlib import Path
import pandas as pd

def parse_exp_name(name: str):
    """Split an experiment folder name into its components.

    Expected layout:

      continual_<preset>_<method>[_lam_<lambda>]_<encoder>[_seed_<seed>]

    Examples:
      continual_fast_naive_rate_seed_42
      continual_fast_ewc_lam_1e+08_rate_seed_42
      continual_std_ewc_lam_3e+07_latency_seed_43

    Returns a dict with the keys ``preset``, ``method``, ``lambda``,
    ``encoder`` and ``seed``. Values are strings as found in the name; optional
    parts that are absent are ``None``, and every value is ``None`` when the
    name does not follow the layout.
    """
    pattern = (
        r"continual_(?P<preset>\w+)_(?P<method>ewc|naive)"
        r"(?:_lam_(?P<lambda>[^_]+))?_(?P<enc>[^_]+)"
        r"(?:_seed_(?P<seed>\d+))?$"
    )
    found = re.match(pattern, name)
    if found is None:
        return dict.fromkeys(("preset", "method", "lambda", "encoder", "seed"))
    return {
        "preset": found["preset"],
        "method": found["method"],
        "lambda": found["lambda"],
        "encoder": found["enc"],
        "seed": found["seed"],
    }

# Column layout of the per-run table. Declared up front so the DataFrame has
# these columns even when no run is found (otherwise df["lambda"] below would
# raise KeyError on an empty outputs/ folder).
RESULT_COLUMNS = [
    "exp", "preset", "method", "lambda", "encoder", "seed",
    "c1_name", "c2_name",
    "c1_mae", "c1_after_c2_mae",
    "c1_forgetting_mae_abs", "c1_forgetting_mae_rel_%",
    "c2_mae",
]


def is_last(d):
    """Return True when a task entry has no 'after_*' keys (i.e. it was trained last)."""
    return not any(k.startswith("after_") for k in d)


rows = []
root_out = ROOT / "outputs"

for exp_dir in sorted(root_out.glob("continual_*")):
    name = exp_dir.name
    meta = parse_exp_name(name)

    # Skip folder names that do not follow the naming scheme (very old runs).
    if meta["preset"] is None:
        continue

    results_path = exp_dir / "continual_results.json"
    if not results_path.exists():
        continue

    with open(results_path, "r", encoding="utf-8") as f:
        res = json.load(f)

    # At least two tasks are needed to measure forgetting.
    task_names = list(res.keys())
    if len(task_names) < 2:
        continue

    # The "last" task is the one without 'after_*' keys; the other is the first.
    last_task = None
    first_task = None
    for tn in task_names:
        if is_last(res[tn]):
            last_task = tn
        else:
            first_task = tn

    # Fallback when the tasks cannot be told apart: alphabetical order.
    if first_task is None or last_task is None:
        task_names_sorted = sorted(task_names)
        first_task = task_names_sorted[0]
        last_task  = task_names_sorted[-1]

    c1, c2 = first_task, last_task

    c1_test_mae = float(res[c1].get("test_mae", float("nan")))
    c2_test_mae = float(res[c2].get("test_mae", float("nan")))
    after_key_mae = f"after_{c2}_mae"
    c1_after_c2_mae = float(res[c1].get(after_key_mae, float("nan")))

    forgetting_abs = c1_after_c2_mae - c1_test_mae
    # Relative forgetting is undefined for a NaN or zero baseline MAE
    # (x == x is False only for NaN); report NaN instead of dividing by zero.
    if c1_test_mae == c1_test_mae and c1_test_mae != 0:
        forgetting_rel = forgetting_abs / c1_test_mae * 100.0
    else:
        forgetting_rel = float("nan")

    rows.append({
        "exp": name,
        "preset": meta["preset"],
        "method": meta["method"],
        "lambda": meta["lambda"] if meta["method"] == "ewc" else None,
        "encoder": meta["encoder"],
        "seed": int(meta["seed"]) if meta["seed"] is not None else None,
        "c1_name": c1,
        "c2_name": c2,
        "c1_mae": c1_test_mae,
        "c1_after_c2_mae": c1_after_c2_mae,
        "c1_forgetting_mae_abs": forgetting_abs,
        "c1_forgetting_mae_rel_%": forgetting_rel,
        "c2_mae": c2_test_mae,
    })

df = pd.DataFrame(rows, columns=RESULT_COLUMNS)

# Numeric helper column for sorting by λ ('1e+08' -> 1e+08; NAIVE -> NaN).
if "lambda_num" not in df.columns:
    df["lambda_num"] = pd.to_numeric(df["lambda"], errors="coerce")

# Keep 'seed' as a nullable integer and drop 'seed_num' if present.
df["seed"] = pd.to_numeric(df["seed"], errors="coerce").astype("Int64")
if "seed_num" in df.columns:
    df = df.drop(columns=["seed_num"])

# Sort by preset, method, encoder, λ (NaN last) and seed.
df = df.sort_values(
    by=["preset", "method", "encoder", "lambda_num", "seed"],
    na_position="last",
    ignore_index=True,
)

df


In [None]:
# ====================== Vista agregada (media±std por preset/method/λ/encoder) ======================
import pandas as pd

# Métricas a agregar
# Metrics to aggregate.
cols_metrics = ["c1_mae", "c1_after_c2_mae", "c1_forgetting_mae_abs", "c1_forgetting_mae_rel_%", "c2_mae"]

# Work on a copy; make sure the numeric λ helper column exists (NaN for NAIVE).
gdf = df.copy()
if "lambda_num" not in gdf.columns:
    gdf["lambda_num"] = pd.to_numeric(gdf["lambda"], errors="coerce")

# Mean, std and number of runs (seeds) per combination; dropna=False keeps the
# NAIVE rows, whose λ is missing.
_group_keys = ["preset", "method", "encoder", "lambda", "lambda_num"]
_grouped = gdf.groupby(_group_keys, dropna=False)[cols_metrics]
agg = _grouped.agg(["mean", "std", "count"]).reset_index()

# Flatten the two-level column index into single names such as "c1_mae_mean".
_flat_names = []
for col in agg.columns.to_flat_index():
    _parts = [part for part in map(str, col) if part]
    _flat_names.append("_".join(_parts).rstrip("_"))
agg.columns = _flat_names

# Sort by preset/method/encoder/λ, with missing λ last within each group.
agg = agg.sort_values(
    by=["preset", "method", "encoder", "lambda_num"],
    na_position="last",
    ignore_index=True,
)

# Save the aggregated table to CSV.
summary_dir = ROOT / "outputs" / "summary"
summary_dir.mkdir(parents=True, exist_ok=True)
_summary_csv = summary_dir / "continual_summary_agg.csv"
agg.to_csv(_summary_csv, index=False)
print("Guardado:", _summary_csv)

agg


In [None]:
# ====================== Formateo para la memoria (tabla compacta) ======================

def fmt(x, prec=4):
    """Format ``x`` with ``prec`` decimals; missing values become an empty string."""
    import pandas as pd

    if pd.isna(x):
        return ""
    return f"{x:.{prec}f}"

show = agg.copy()

# 1) Derive a single 'count' column from the first *_count column (they should
#    all agree) and drop the individual ones.
count_cols = [c for c in show.columns if c.endswith("_count")]
if count_cols:
    show["count"] = show[count_cols[0]].astype("Int64")
    show = show.drop(columns=count_cols)

# 2) Render means and standard deviations with four decimals.
for c in [c for c in show.columns if c.endswith(("_mean", "_std"))]:
    show[c] = show[c].map(fmt)

# 3) Key columns, in display order.
cols = [
    "preset", "method", "encoder", "lambda",
    "c1_mae_mean", "c1_forgetting_mae_rel_%_mean", "c2_mae_mean",
    "c1_mae_std",  "c1_forgetting_mae_rel_%_std",  "c2_mae_std",
    "count"
]

# Columns that are not available are reported and skipped.
missing = [c for c in cols if c not in show.columns]
if missing:
    print("Aviso: faltan columnas en 'show':", missing)
    cols = [c for c in cols if c in show.columns]

# Display headers used in the written report.
_headers = {
    "method": "método",
    "encoder": "codificador",
    "lambda": "λ",
    "c1_mae_mean": "MAE Tarea1 (media)",
    "c1_forgetting_mae_rel_%_mean": "Olvido T1 (%) (media)",
    "c2_mae_mean": "MAE Tarea2 (media)",
    "c1_mae_std": "MAE Tarea1 (σ)",
    "c1_forgetting_mae_rel_%_std": "Olvido T1 (%) (σ)",
    "c2_mae_std": "MAE Tarea2 (σ)",
    "count": "n (semillas)",
}
show = show[cols].rename(columns=_headers)

show
