# 03_TRAIN_CONTINUAL: Continual Training with *presets*

**What this notebook does:**

This notebook trains and evaluates models in a **continual learning** setting, driven by a single unified configuration in `configs/presets.yaml`.  
It lets you: (1) launch a baseline *run*, (2) compare continual methods under an identical data/model configuration, and (3) generate a **summary** and **aggregates** of the results.

---

## 🎯 Goals
- Centralize the configuration (model, data, optimizer, continual) in `presets.yaml`.
- Train from **offline H5** or **CSV + runtime encoding** (chosen automatically from the preset flags).
- Compare methods (`naive`, `ewc`, `rehearsal`, `rehearsal+ewc`) reproducibly.
- Export summaries to `outputs/summary/continual_summary_agg.csv`.

## ✅ Prerequisites
- `data/processed/tasks.json` or `data/processed/tasks_balanced.json` must already exist.
- For **offline** mode (H5), the H5 files must have been created with `tools/encode_tasks.py`.
- Review `configs/presets.yaml` (sections: `model`, `data`, `optim`, `continual`).
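
As a quick sanity check before running anything, the sketch below verifies that a loaded preset contains the four sections listed above. It is only an illustration: `missing_sections` is a hypothetical helper (not part of `src.utils`), and the example dictionary is made up rather than taken from the real `presets.yaml`.

```python
# Sections the notebook expects in every preset (from the prerequisites above)
REQUIRED_SECTIONS = ("model", "data", "optim", "continual")

def missing_sections(cfg: dict, required=REQUIRED_SECTIONS) -> list:
    """Return the required top-level sections that are absent from cfg."""
    return [name for name in required if name not in cfg]

# Hypothetical preset contents, for illustration only
example_cfg = {
    "model": {"name": "pilotnet_snn"},
    "data": {"encoder": "rate", "T": 10},
    "continual": {"method": "naive", "params": {}},
}

missing = missing_sections(example_cfg)
print("Missing sections:", missing)
```

In the notebook, the same check could be applied to `CFG` right after `load_preset(...)` returns, assuming it returns a plain dictionary as the later cells suggest.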

## ⚠️ Important notes
- Do not combine `use_offline_spikes=True` with `encode_runtime=True`.
- The experiment **seed** comes from `CFG["data"]["seed"]` (for reproducibility).
- The output folder name encodes the preset, method, and metadata (see `src/runner.py`).

<a id="toc"></a>

## 🧭 Index

- [1) Environment and path setup](#sec-01)
- [2) Loading the unified preset (`configs/presets.yaml`)](#sec-02)
- [3) Data verification and `tasks.json` selection](#sec-03)
- [4) DataLoader factory (offline H5 or CSV + runtime encoding)](#sec-04)
- [5) Model factory](#sec-05)
- [6) (Optional) Patch: print *it/s* per epoch, with optional early stopping](#sec-06)
- [7) Baseline run with the preset (config echo + run)](#sec-07)
- [8) Method comparison (same preset/seed/data)](#sec-08)
- [9) Combination sweep (optional)](#sec-09)
- [10) Full summary: inventory → parsing → aggregates → table](#sec-10)



<a id="sec-01"></a>

## 1) Environment and path setup

**Goal:** prepare the runtime: thread limits, device selection, and project paths.

- Sets environment variables that cap BLAS threads (for reproducibility and to avoid oversubscription). They are set before `torch` is imported so the libraries pick them up.
- Detects `ROOT` (the repo root) and adds it to `sys.path`.
- Imports the project utilities (datasets, models, presets).
- Selects the device (`cuda` when available).
- Enables PyTorch optimizations (TF32/cuDNN) to speed up GPU training.

> **Note:** No presets are read here yet; this cell only configures the global runtime.

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# Imports and environment setup (threads, paths, device)
# =============================================================================
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from pathlib import Path
import sys, json, torch

# Repo root and sys.path
ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Project libraries
from src.datasets import ImageTransform, AugmentConfig
from src.models import build_model, default_tfm_for_model
from src.utils import load_preset

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Performance settings (optional)
# set_num_threads sets PyTorch's intra-op CPU thread count explicitly.
torch.set_num_threads(4)
# cudnn.benchmark auto-tunes convolution kernels; it may reduce run-to-run determinism.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

print("Device:", device)

<a id="sec-02"></a>

## 2) Loading the unified preset (`configs/presets.yaml`)

**Goal:** load a *preset* and derive the whole working configuration from it.

What this cell does:
- Reads the preset (`PRESET = "fast" | "std" | "accurate"`).
- Builds the `ImageTransform` from the preset's model settings.
- Extracts the **data/temporal encoding** parameters (`ENCODER`, `T`, `GAIN`, `SEED`).
- Extracts the **DataLoader** configuration (workers, prefetch, pin/persistent).
- Prepares the optional *augment* config (`AUG_CFG`) and **online balancing** (when enabled).
- Guardrail: raises an error if `use_offline_spikes=True` and `encode_runtime=True` are both set.

> **Tip:** change `PRESET` here to switch configurations without touching code anywhere else.

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# Global config: presets.yaml
# =============================================================================
from pathlib import Path
from src.datasets import ImageTransform, AugmentConfig
from src.utils import load_preset

PRESET = "fast"  # change the preset to use here: fast | std | accurate

CFG = load_preset(ROOT / "configs" / "presets.yaml", PRESET)

# ---- Model / Transform -------------------------------------------------------
MODEL_NAME = CFG["model"]["name"]
tfm = ImageTransform(
    CFG["model"]["img_w"], CFG["model"]["img_h"],
    to_gray=bool(CFG["model"]["to_gray"]),
    crop_top=None
)

# ---- Data / temporal encoding -----------------------------------------------
ENCODER  = CFG["data"]["encoder"]
T        = int(CFG["data"]["T"])
GAIN     = float(CFG["data"]["gain"])
SEED     = int(CFG["data"]["seed"])

USE_OFFLINE_SPIKES   = bool(CFG["data"]["use_offline_spikes"])
USE_OFFLINE_BALANCED = bool(CFG["data"]["use_offline_balanced"])
RUNTIME_ENCODE       = bool(CFG["data"]["encode_runtime"])   # forwarded as runtime_encode to the loader factory

# ---- DataLoader / augment / balancing ---------------------------------------
NUM_WORKERS = int(CFG["data"]["num_workers"])
PREFETCH    = CFG["data"]["prefetch_factor"]
PIN_MEMORY  = bool(CFG["data"]["pin_memory"])
PERSISTENT  = bool(CFG["data"]["persistent_workers"])

AUG_CFG = AugmentConfig(**CFG["data"]["aug_train"]) if CFG["data"]["aug_train"] else None

USE_ONLINE_BALANCING = bool(CFG["data"]["balance_online"])
BAL_BINS             = int(CFG["data"]["balance_bins"])
BAL_EPS              = float(CFG["data"]["balance_smooth_eps"])

# Guardrails
if USE_OFFLINE_SPIKES and RUNTIME_ENCODE:
    raise RuntimeError("Invalid config: use_offline_spikes=True and encode_runtime=True at the same time.")

print(f"[PRESET={PRESET}] model={MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
print(f"[DATA] encoder={ENCODER} T={T} gain={GAIN} seed={SEED}")
print(f"[LOADER] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")
print(f"[BALANCE] offline={USE_OFFLINE_BALANCED} online={USE_ONLINE_BALANCING} bins={BAL_BINS}")
print(f"[RUNTIME_ENCODE] {RUNTIME_ENCODE} | [OFFLINE_SPIKES] {USE_OFFLINE_SPIKES}")

<a id="sec-03"></a>

## 3) Data verification and `tasks.json` selection

Checks that the `train/val/test` CSVs exist for each task and (when applicable)
that `train_balanced.csv` is available for **offline balanced** mode.

- Reads `tasks_balanced.json` when `USE_OFFLINE_BALANCED=True`; if any balanced CSV is missing, falls back to `tasks.json`.
- Builds `task_list` with the paths for each split.

> *Expected output:* the list of tasks with their `train` CSV, followed by an OK/warning message.

[↑ Back to index](#toc)
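
To make the expected file layout concrete, here is a minimal sketch of a tasks file and how `task_list` is derived from it. Only the key layout (`tasks_order`, `splits`, and per-split paths) mirrors what the verification cell reads; the task names and paths are hypothetical.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical task names and paths; only the key layout mirrors the loading code.
tasks_json = {
    "tasks_order": ["task_a", "task_b"],
    "splits": {
        "task_a": {
            "train": "hypothetical_root/task_a/train.csv",
            "val": "hypothetical_root/task_a/val.csv",
            "test": "hypothetical_root/task_a/test.csv",
        },
        "task_b": {
            "train": "hypothetical_root/task_b/train.csv",
            "val": "hypothetical_root/task_b/val.csv",
            "test": "hypothetical_root/task_b/test.csv",
        },
    },
}

# Round-trip through a temporary file to mimic reading tasks.json from disk
with tempfile.TemporaryDirectory() as tmp:
    tasks_file = Path(tmp) / "tasks.json"
    tasks_file.write_text(json.dumps(tasks_json), encoding="utf-8")
    loaded = json.loads(tasks_file.read_text(encoding="utf-8"))

# One entry per task, in the order given by tasks_order
task_list = [{"name": n, "paths": loaded["splits"][n]} for n in loaded["tasks_order"]]

# Collect split files that do not exist on disk (all of them here: the paths are made up)
missing = [p for t in task_list for p in t["paths"].values() if not Path(p).exists()]
print([t["name"] for t in task_list], "missing files:", len(missing))
```

The order of `task_list` follows `tasks_order`, which is what determines the training sequence of the continual run.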

In [None]:
# =============================================================================
# Data verification (splits and, when applicable, balanced CSVs)
# =============================================================================
PROC = ROOT/"data"/"processed"
TASKS_FILE = PROC / ("tasks_balanced.json" if USE_OFFLINE_BALANCED else "tasks.json")
with open(TASKS_FILE, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)

task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]

print("Tasks and TRAIN CSV to use:")
for t in task_list:
    print(f" - {t['name']}: {Path(t['paths']['train']).name}")

if USE_OFFLINE_BALANCED:
    missing = []
    for t in task_list:
        p = Path(t["paths"]["train"])
        if p.name != "train_balanced.csv" or not p.exists():
            missing.append(str(p))
    if missing:
        print("[WARN] Missing balanced CSVs:", missing, " → falling back to tasks.json (unbalanced).")
        USE_OFFLINE_BALANCED = False
        with open(PROC/"tasks.json", "r", encoding="utf-8") as f:
            tasks_json = json.load(f)
        task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]

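# Check that the split CSVs listed for each task exist (warn only; does not abort)
missing_splits = [
    str(t["paths"][s]) for t in task_list
    for s in ("train", "val", "test")
    if s in t["paths"] and not Path(t["paths"][s]).exists()
]
if missing_splits:
    print("[WARN] Missing split files:", missing_splits)
else:
    print("OK: split verification passed.")

print(f"Preset in use: {PRESET}")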

<a id="sec-04"></a>

## 4) DataLoader factory (offline H5 or CSV + runtime encoding)

Builds a unified loader **builder**:

- `_raw_make_loader_fn` chooses between offline H5 and CSV according to the flags.
- `make_loader_fn(...)` is a **pass-through wrapper**: it receives the runner's parameters and
  **forwards** the DataLoader kwargs (workers, prefetch, pin_memory, persistent, augment, balancing).

> Pass this `make_loader_fn` to the runner so the notebook does not duplicate that logic.

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# Loader factory (picks offline H5 or CSV + runtime encode) + kwargs pass-through
# =============================================================================
from src.utils import build_make_loader_fn

_raw_make_loader_fn = build_make_loader_fn(
    root=ROOT,
    use_offline_spikes=USE_OFFLINE_SPIKES,
    runtime_encode=RUNTIME_ENCODE,
)

def make_loader_fn(task, batch_size, encoder, T, gain, tfm, seed, **dl_kwargs):
    """Wrapper that simply forwards its arguments to the real factory (the runner adds dl_kwargs)."""
    return _raw_make_loader_fn(
        task=task, batch_size=batch_size, encoder=encoder, T=T, gain=gain, tfm=tfm, seed=seed,
        **dl_kwargs
    )

print("make_loader_fn ready (pass-through of the runner's kwargs).")

<a id="sec-05"></a>

## 5) Model factory

Returns the model selected by the preset name (`MODEL_NAME`).

- For `pilotnet_snn`, the neuron hyperparameters are applied (fixed in this cell at `beta=0.9`, `threshold=0.5`).
- For other models, these kwargs are ignored.

> *Expected output:* the name of the selected model is printed.

[↑ Back to index](#toc)


In [None]:
# =============================================================================
# Model construction (factory)
# =============================================================================
def make_model_fn(tfm):
    """
    Return the model built with the neuron hyperparameters (beta/threshold).
    These kwargs apply to 'pilotnet_snn'; other models ignore them.
    """
    return build_model(MODEL_NAME, tfm, beta=0.9, threshold=0.5)

print("Model:", MODEL_NAME)

<a id="sec-06"></a>

## 6) (Optional) Patch: print iterations/second per epoch, with optional early stopping

Temporarily overrides `training.train_supervised` to:

- Measure **it/s** per epoch (useful for performance benchmarks).
- Apply **early stopping** on `val_loss` when the config defines both `es_patience` and `es_min_delta`.
- Keep the rest of the training functionally unchanged.

> To restore the original behavior: `training.train_supervised = orig_train_supervised`.

[↑ Back to index](#toc)
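
If you prefer the override to undo itself automatically, one option is to wrap it in a context manager. The sketch below is an illustration only: `patched` is a hypothetical helper, and a `SimpleNamespace` stands in for the real `src.training` module so the example runs on its own.

```python
from contextlib import contextmanager
from types import SimpleNamespace

@contextmanager
def patched(obj, name, replacement):
    """Temporarily set obj.<name> to replacement, restoring the original on exit."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield original
    finally:
        # Runs even if the body raises, so the original is always restored
        setattr(obj, name, original)

# Stand-in for the real `src.training` module (hypothetical, for illustration)
training = SimpleNamespace(train_supervised=lambda: "original")

with patched(training, "train_supervised", lambda: "patched"):
    inside = training.train_supervised()   # the replacement is active here
outside = training.train_supervised()      # the original is back here

print(inside, outside)
```

Note that this kind of patch only affects callers that look the function up on the module at call time; code that imported the function directly beforehand (`from src.training import train_supervised`) keeps its own reference to the original.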


In [None]:
# =============================================================================
# (Optional) Patch: it/s + Early Stopping (controlled by the preset)
# =============================================================================
import time, json
from pathlib import Path
import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler

import src.training as training
from src.utils import set_seeds

orig_train_supervised = training.train_supervised  # backup

def train_supervised_ips_es(model: nn.Module, train_loader, val_loader, loss_fn: nn.Module,
                            cfg, out_dir: Path, method=None):
    """
    it/s + Early Stopping:
      - Active when cfg.es_patience and cfg.es_min_delta are both not None.
      - Criterion: minimum val_loss with tolerance es_min_delta.
    Writes manifest.json with 'history' and 'early_stop_epoch' (when applicable).
    """
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    if cfg.seed is not None:
        set_seeds(cfg.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr)
    use_amp = bool(cfg.amp and torch.cuda.is_available())
    scaler = GradScaler(enabled=use_amp)

    # ES params (read from the preset)
    patience = getattr(cfg, "es_patience", None)
    min_delta = getattr(cfg, "es_min_delta", None)
    use_es = (patience is not None) and (min_delta is not None)

    best_val = float("inf")
    wait = 0
    early_stop_epoch = None

    history = {"train_loss": [], "val_loss": []}
    for epoch in range(1, cfg.epochs + 1):
        # -------- train --------
        model.train()
        running = 0.0; nb = 0
        t0 = time.perf_counter()

        for x, y in train_loader:
            x = training._permute_if_needed(x).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=use_amp):
                y_hat = model(x)
                loss = loss_fn(y_hat, y)
                if method is not None:
                    loss = loss + method.penalty()

            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(opt); scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                opt.step()

            running += loss.item(); nb += 1

        dt = time.perf_counter() - t0
        ips = nb / dt if dt > 0 else float("nan")
        print(f"[TRAIN it/s] epoch {epoch}/{cfg.epochs}: {ips:.1f} it/s  ({nb} iters in {dt:.2f}s)")
        train_loss = running / max(1, nb)

        # -------- val --------
        model.eval()
        v_running = 0.0; nvb = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = training._permute_if_needed(x).to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with autocast("cuda", enabled=use_amp):
                    y_hat = model(x)
                    v_loss = loss_fn(y_hat, y)
                v_running += v_loss.item(); nvb += 1
        val_loss = v_running / max(1, nvb)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        # -------- Early Stopping check --------
        if use_es:
            improved = (best_val - val_loss) > float(min_delta)
            if improved:
                best_val = val_loss
                wait = 0
            else:
                wait += 1
                if wait >= int(patience):
                    early_stop_epoch = epoch
                    print(f"[EarlyStopping] Stop at epoch={epoch} (best_val={best_val:.6f})")
                    break

    manifest = {
        "epochs": cfg.epochs, "batch_size": cfg.batch_size, "lr": cfg.lr,
        "amp": cfg.amp, "seed": cfg.seed, "history": history,
        "early_stop_epoch": early_stop_epoch,
    }
    (out_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return history

training.train_supervised = train_supervised_ips_es
print("it/s + EarlyStopping patch ENABLED. To disable it: training.train_supervised = orig_train_supervised")


<a id="sec-07"></a>

## 7) Baseline run with the preset (config echo + run)

Launches **one experiment** with the method and parameters defined in the preset (`CFG["continual"]`).  
The most relevant fields (model, data, loader) are printed, and the output is saved under `outputs/continual_*`.

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# Baseline run with the preset (config echo + run)
# =============================================================================
from src.runner import run_continual

# Condensed configuration echo (the essentials for the run)
print(f"[RUN] preset={PRESET} | method={CFG['continual']['method']} "
      f"| seed={CFG['data']['seed']} | enc={CFG['data']['encoder']} "
      f"| kwargs={CFG['continual'].get('params', {})}")
print(f"[MODEL] {MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
print(f"[DATA] T={CFG['data']['T']} gain={CFG['data']['gain']} "
      f"| offline_spikes={CFG['data']['use_offline_spikes']} "
      f"| runtime_encode={CFG['data']['encode_runtime']}")
print(f"[LOADER] workers={CFG['data']['num_workers']} "
      f"prefetch={CFG['data']['prefetch_factor']} pin={CFG['data']['pin_memory']} "
      f"persistent={CFG['data']['persistent_workers']} "
      f"| aug={bool(CFG['data']['aug_train'])} "
      f"| balance_online={CFG['data']['balance_online']}")

out_path, _ = run_continual(
    task_list=task_list,
    make_loader_fn=make_loader_fn,   # wrapper (Section 4)
    make_model_fn=make_model_fn,     # factory (Section 5)
    tfm=tfm,
    cfg=CFG,                         # full preset
    preset_name=PRESET,              # naming only
    out_root=ROOT / "outputs",
    verbose=True,
)
print("OK:", out_path)

<a id="sec-08"></a>

## 8) Method comparison (same preset / same seed / same data)

This cell runs several experiments that change **only** the continual learning
method, keeping fixed the `preset` loaded in Section 2 (model, loader, AMP,
LR, epochs, T, gain, image size, etc.).

- Uses `CFG` as is (it already carries `continual.method`/`params`, `data.encoder`, `data.seed`, etc.).
- For each method, `CFG` is cloned and **only** `continual.method` and `continual.params` are overwritten.
- Results are written to `outputs/continual_<preset>_<tag>_<encoder>_model-..._seed_<seed>/continual_results.json`.
- Afterwards, run Section 10 to generate the summary and the aggregates CSV.

**Requirements**:
- If you use *offline spikes*, make sure H5 files compatible with the preset exist
  (encoder/T/gain/size/to_gray). Otherwise the loader will raise a `FileNotFoundError`.

[↑ Back to index](#toc)

In [None]:
# === METHOD COMPARISON: same preset, same seed, same data ===
from copy import deepcopy
from src.runner import run_continual

# 1) Base configuration: the CFG already loaded in Section 2
CFG_BASE = deepcopy(CFG)  # optionally pin the shared seed here
# CFG_BASE["data"]["seed"] = 42

# 2) Define the methods to compare (tune the hyperparameters as needed)
METHODS = {
    "naive": {},
    "ewc": {"lam": 7e8, "fisher_batches": 800},
    "rehearsal": {"buffer_size": 5000, "replay_ratio": 0.2},
    "rehearsal+ewc": {"buffer_size": 5000, "replay_ratio": 0.2, "lam": 7e8, "fisher_batches": 800},
}

runs_out = []
for method_name, method_params in METHODS.items():
    cfg_i = deepcopy(CFG_BASE)
    cfg_i["continual"]["method"] = method_name
    cfg_i["continual"]["params"] = method_params

    print(
        f"\n=== RUN: preset={PRESET} | method={method_name} | "
        f"seed={cfg_i['data']['seed']} | enc={cfg_i['data']['encoder']} | kwargs={method_params} ==="
    )

    out_dir, _ = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,  # pass-through wrapper defined in Section 4
        make_model_fn=make_model_fn,
        tfm=tfm,
        cfg=cfg_i,                      # full preset configuration with the method swapped
        preset_name=PRESET,             # only used to name the output folder
        out_root=ROOT / "outputs",
        verbose=True,
    )
    runs_out.append(out_dir)

print("\nDone:", [str(p) for p in runs_out])
print("Now run Section 10 to see the summary and comparisons.")

<a id="sec-09"></a>

## 9) Combination sweep (optional)

Generic driver to explore:

- `presets × seeds × encoders × methods`.

Useful for broader studies, but expensive: the number of runs is the product of the four list lengths.  
In offline mode, make sure compatible H5 files exist; reduce the loader load (workers/prefetch) if GPU resources are tight.

[↑ Back to index](#toc)
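
Before launching a sweep, it can help to enumerate the grid and see how many runs it implies. The sketch below does this with `itertools.product`; the list contents are hypothetical values shaped like the driver's `PRESETS`, `SEEDS`, `ENCODERS`, and `METHODS`, and nothing is actually trained.

```python
from itertools import product

# Hypothetical sweep lists, shaped like the driver's PRESETS/SEEDS/ENCODERS/METHODS
presets = ["fast"]
seeds = [42, 43]
encoders = ["rate"]
methods = [("naive", {}), ("ewc", {"lam": 1e9, "fisher_batches": 600})]

# Cartesian product: one tuple per run, with the rightmost list varying fastest
grid = list(product(presets, seeds, encoders, methods))
print(f"{len(grid)} runs planned")

for preset_i, seed_i, enc_i, (method_name, method_params) in grid:
    print(f"preset={preset_i} seed={seed_i} enc={enc_i} method={method_name} kwargs={method_params}")
```

This is equivalent to the four nested `for` loops of the driver; a flat list simply makes it easier to count, filter, or resume runs.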

In [None]:
# =============================================================================
# Execution driver: combination sweep (optional)
# =============================================================================
from copy import deepcopy
from src.runner import run_continual

# NOTE: cfg_i below is always copied from the CFG loaded in Section 2, so adding
# "std", "accurate", etc. here only changes the output folder name unless CFG,
# tfm and MODEL_NAME are reloaded for each preset.
PRESETS   = [PRESET]
SEEDS     = [CFG["data"]["seed"], 43]
ENCODERS  = [CFG["data"]["encoder"]]  # e.g. ["rate", "latency"]
METHODS   = [
    ("naive", {}),
    ("ewc", {"lam": 1e9, "fisher_batches": 600}),
    # ("rehearsal", {"buffer_size": 5000, "replay_ratio": 0.2}),
    # ("rehearsal+ewc", {"buffer_size": 5000, "replay_ratio": 0.2, "lam": 7e8, "fisher_batches": 600}),
]

for preset_i in PRESETS:
    for seed_i in SEEDS:
        for enc_i in ENCODERS:
            for method_name, method_params in METHODS:
                cfg_i = deepcopy(CFG)
                cfg_i["data"]["seed"] = seed_i
                cfg_i["data"]["encoder"] = enc_i
                cfg_i["continual"]["method"] = method_name
                cfg_i["continual"]["params"] = method_params

                print(
                    f"\n=== RUN: preset={preset_i} | method={method_name} "
                    f"| seed={seed_i} | enc={enc_i} | kwargs={method_params} ==="
                )
                out_path, _ = run_continual(
                    task_list=task_list,
                    make_loader_fn=make_loader_fn,   # same factory, as long as the offline/runtime flags do not change
                    make_model_fn=make_model_fn,
                    tfm=tfm,
                    cfg=cfg_i,
                    preset_name=preset_i,
                    out_root=ROOT / "outputs",
                    verbose=True,
                )
                print("OK:", out_path)

<a id="sec-10"></a>

## 10) Full summary: inventory → parsing → aggregates → table

- **Inventory** of the runs in `outputs/continual_*`.  
- **Parsing** of the folder names to extract `preset`, `method`, `λ`, `encoder`, `seed`, `model`.  
- **Forgetting computation**: absolute and relative change of the task 1 (T1) test MAE after training on task 2 (T2).  
- **Aggregates** per group (mean/σ/n), exported to `outputs/summary/continual_summary_agg.csv`.  
- **Formatted table** with the key metrics and their standard deviations.

> If nothing shows up, check that `continual_results.json` exists in the output folders.

[↑ Back to index](#toc)
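
The forgetting metrics can be illustrated on a single run. Absolute forgetting is the T1 MAE measured after training on T2 minus the T1 MAE measured right after training on T1; relative forgetting divides that difference by the original T1 MAE and expresses it as a percentage. The sketch below uses made-up numbers and hypothetical task names; only the key layout (`test_mae`, `after_<task>_mae`) follows what the summary cell reads.

```python
# Hypothetical results for one run; the key layout mirrors what the summary cell reads
# ("test_mae" per task, plus "after_<task2>_mae" stored under the first task).
res = {
    "task_a": {"test_mae": 0.10, "after_task_b_mae": 0.15},
    "task_b": {"test_mae": 0.12},
}

c1, c2 = "task_a", "task_b"
c1_mae = res[c1]["test_mae"]                   # T1 error right after training on T1
c1_after_c2_mae = res[c1][f"after_{c2}_mae"]   # T1 error after also training on T2

# Positive values mean the model got worse on T1 (it forgot)
forgetting_abs = c1_after_c2_mae - c1_mae
# Guard against a zero or NaN baseline before dividing
forgetting_rel = forgetting_abs / c1_mae * 100.0 if c1_mae > 0 else float("nan")

print(f"abs={forgetting_abs:.4f}  rel={forgetting_rel:.1f}%")
```

Because MAE is an error, lower is better, so a negative forgetting value would indicate that training on T2 actually improved T1 performance.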


In [None]:
## === 10) Full summary: inventory → parsing → aggregates → table ===
import re, json, pandas as pd
from IPython.display import display

ALLOWED_ENC = r"(rate|latency|raw|image)"

def parse_exp_name(name: str):
    """Parse continual_<preset>_<tag>_<encoder>[_model-<model>][_seed_<seed>] into a metadata dict."""
    pat = re.compile(
        rf"^continual_(?P<preset>[^_]+)_(?P<tag>.+)_(?P<enc>{ALLOWED_ENC})"
        rf"(?:_model\-(?P<model>.+?))?(?:_seed_(?P<seed>\d+))?$"
    )
    m = pat.match(name)
    meta = {"preset": None, "method": None, "lambda": None, "encoder": None, "seed": None, "model": None}
    if not m:
        return meta
    preset = m.group("preset"); tag = m.group("tag"); enc = m.group("enc")
    seed = m.group("seed"); model = m.group("model")
    lam = None; mlam = re.search(r"_lam_([^_]+)", tag)
    if mlam:
        lam = mlam.group(1)
        method = tag.replace(f"_lam_{lam}", "")
    else:
        method = tag
    meta.update({"preset": preset, "method": method, "lambda": lam, "encoder": enc,
                 "seed": int(seed) if seed is not None else None, "model": model})
    return meta

def build_runs_df(outputs_root: Path) -> pd.DataFrame:
    rows = []
    for exp_dir in sorted((outputs_root).glob("continual_*")):
        name = exp_dir.name
        meta = parse_exp_name(name)
        if meta["preset"] is None:
            continue
        results_path = exp_dir / "continual_results.json"
        if not results_path.exists():
            continue
        with open(results_path, "r", encoding="utf-8") as f:
            res = json.load(f)

        task_names = list(res.keys())
        if len(task_names) < 2:
            continue  # with a single task there are no after_* entries

        def is_last(d: dict) -> bool:
            return not any(k.startswith("after_") for k in d.keys())

        # First task = first entry that has after_* keys; last task = the entry without them.
        # (Relies on the JSON preserving the task order.)
        last_task = None; first_task = None
        for tn in task_names:
            if is_last(res[tn]): last_task = tn
            elif first_task is None: first_task = tn
        if first_task is None or last_task is None:
            task_names_sorted = sorted(task_names)
            first_task = task_names_sorted[0]; last_task = task_names_sorted[-1]

        c1, c2 = first_task, last_task
        c1_test_mae = float(res[c1].get("test_mae", float("nan")))
        c2_test_mae = float(res[c2].get("test_mae", float("nan")))
        after_key_mae = f"after_{c2}_mae"
        c1_after_c2_mae = float(res[c1].get(after_key_mae, float("nan")))
        forgetting_abs = c1_after_c2_mae - c1_test_mae
        forgetting_rel = (forgetting_abs / c1_test_mae * 100.0) if c1_test_mae > 0 else float("nan")  # NaN or zero baseline → NaN

        rows.append({
            "exp": name, "preset": meta["preset"], "method": meta["method"], "lambda": meta["lambda"],
            "encoder": meta["encoder"], "model": meta["model"], "seed": meta["seed"],
            "c1_name": c1, "c2_name": c2,
            "c1_mae": c1_test_mae, "c1_after_c2_mae": c1_after_c2_mae,
            "c1_forgetting_mae_abs": forgetting_abs, "c1_forgetting_mae_rel_%": forgetting_rel,
            "c2_mae": c2_test_mae,
        })
    df = pd.DataFrame(rows)
    if df.empty:
        return df
    df["lambda_num"] = pd.to_numeric(df["lambda"], errors="coerce")
    df["seed"] = pd.to_numeric(df["seed"], errors="coerce").astype("Int64")
    df = df.sort_values(by=["preset", "method", "encoder", "model", "lambda_num", "seed"],
                        na_position="last", ignore_index=True)
    return df

def aggregate_and_show(df: pd.DataFrame, outputs_root: Path):
    if df.empty:
        print("No rows (are the JSON files missing, or did each run have only 1 task?).")
        return
    # Detail view
    display(df[[
        "exp","preset","method","lambda","encoder","model","seed",
        "c1_name","c2_name","c1_mae","c1_after_c2_mae",
        "c1_forgetting_mae_abs","c1_forgetting_mae_rel_%","c2_mae","lambda_num"
    ]])

    # Aggregates
    cols_metrics = ["c1_mae", "c1_after_c2_mae", "c1_forgetting_mae_abs", "c1_forgetting_mae_rel_%", "c2_mae"]
    gdf = df.copy()
    if "lambda_num" not in gdf.columns:
        gdf["lambda_num"] = pd.to_numeric(gdf["lambda"], errors="coerce")
    agg = (gdf
           .groupby(["preset", "method", "encoder", "lambda", "lambda_num"], dropna=False)[cols_metrics]
           .agg(["mean", "std", "count"])
           .reset_index())
    agg.columns = ["_".join(filter(None, map(str, col))).rstrip("_") for col in agg.columns.to_flat_index()]
    agg = agg.sort_values(by=["preset", "method", "encoder", "lambda_num"], na_position="last", ignore_index=True)

    # Persist to CSV and pretty-print
    summary_dir = outputs_root / "summary"
    summary_dir.mkdir(parents=True, exist_ok=True)
    out_csv = summary_dir / "continual_summary_agg.csv"
    agg.to_csv(out_csv, index=False)
    print("Saved:", out_csv)

    def fmt(x, prec=4):
        return "" if pd.isna(x) else f"{x:.{prec}f}"

    show = agg.copy()
    count_cols = [c for c in show.columns if c.endswith("_count")]
    if count_cols:
        show["count"] = show[count_cols[0]].astype("Int64")
        show = show.drop(columns=count_cols)
    for c in [c for c in show.columns if c.endswith("_mean") or c.endswith("_std")]:
        show[c] = show[c].map(lambda v: fmt(v, 4))
    cols = [
        "preset", "method", "encoder", "lambda",
        "c1_mae_mean", "c1_forgetting_mae_rel_%_mean", "c2_mae_mean",
        "c1_mae_std",  "c1_forgetting_mae_rel_%_std",  "c2_mae_std",
        "count"
    ]
    cols = [c for c in cols if c in show.columns]
    show = show[cols].rename(columns={
        "preset": "preset", "method": "method", "encoder": "encoder", "lambda": "λ",
        "c1_mae_mean": "Task 1 MAE (mean)",
        "c1_forgetting_mae_rel_%_mean": "T1 forgetting (%) (mean)",
        "c2_mae_mean": "Task 2 MAE (mean)",
        "c1_mae_std": "Task 1 MAE (σ)",
        "c1_forgetting_mae_rel_%_std": "T1 forgetting (%) (σ)",
        "c2_mae_std": "Task 2 MAE (σ)",
        "count": "n (seeds)"
    })
    display(show)

# Run the summary
outputs_root = ROOT / "outputs"
print("Run inventory in:", outputs_root)
for p in sorted(outputs_root.glob("continual_*")):
    print(" -", p.name, "| results.json:", (p / "continual_results.json").exists())

df = build_runs_df(outputs_root)
print(f"runs in summary: {len(df)}")
aggregate_and_show(df, outputs_root)
