<a id="top"></a>
# 03_TRAIN_CONTINUAL: Continual Training with *presets*

**What this notebook does**  
Trains and evaluates models in **continual learning** (a sequence of tasks) using a **unified configuration** from `configs/presets.yaml`. It lets you:  
1) launch a base *run* with the preset's method,  
2) **compare methods** while keeping data and model fixed, and  
3) generate an **aggregated summary** of results in `outputs/summary/`.


---

## 🎯 Goals
- Centralize the configuration of **model**, **data/temporal encoding**, **optimizer** and **continual method** via `presets.yaml`.
- Support **offline H5** (if `use_offline_spikes: true`) or **CSV + runtime encoding** (if `encode_runtime: true`), selected consistently with the preset.
- Compare methods (`naive`, `ewc`, `rehearsal`, `rehearsal+ewc`, and the planned bio-inspired ones: `as-snn`, `sa-snn`, `sca-snn`, `colanet`) with **identical data preparation**.
- Export an **aggregates CSV** with key metrics (per-task MAE/MSE, absolute/relative forgetting, etc.).

## ✅ Prerequisites
- `data/processed/tasks.json` (and optionally `tasks_balanced.json`) must already have been generated with **01_DATA_QC_PREP** or **01A_PREP_BALANCED**.
- If the preset uses **offline** mode (`use_offline_spikes: true`), the compatible H5 files must have been created with **02_ENCODE_OFFLINE** (same `encoder/T/gain/size/to_gray` as the preset).
- Review `configs/presets.yaml` (sections `model`, `data`, `optim`, `continual`, `prep`).

## ⚠️ Important notes
- **Do not combine** `use_offline_spikes: true` and `encode_runtime: true`. The notebook detects this and raises an error.
- The global **seed** is taken from `CFG["data"]["seed"]` for reproducibility.
- The output folder name includes the preset, method, *encoder*, model, *seed*, etc., so that each run is traceable.
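
As an illustration of the seeding note, the following is a minimal sketch of a global seeding helper. The function name `seed_everything` and the stand-in `CFG` dictionary are hypothetical; how the project's runner actually applies `CFG["data"]["seed"]` is not shown in this notebook, so treat this as an assumption about one reasonable approach.

```python
import random


def seed_everything(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed (hypothetical helper)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the PyTorch generators
    except ImportError:
        pass  # PyTorch not installed: nothing to seed


# Stand-in for the preset dictionary used in this notebook (illustrative value).
CFG = {"data": {"seed": 42}}

seed_everything(int(CFG["data"]["seed"]))
first = random.random()
seed_everything(int(CFG["data"]["seed"]))
second = random.random()  # same draw as `first`, because the seed was reset
```

Seeding alone does not make GPU training bit-exact, in particular with `cudnn.benchmark = True`, so small run-to-run differences can remain.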

<a id="toc"></a>

## 🧭 Index

- [1) Environment setup and paths](#sec-01)  
- [2) Loading the unified preset (`configs/presets.yaml`)](#sec-02)  
- [3) Data verification and `tasks.json` selection](#sec-03)  
- [4) Unified factories: DataLoaders + Model (+ tasks)](#sec-04)  
- [5) Base run with the preset (config echo + run)](#sec-05)  
- [6) Method comparison and optional sweeps (same preset/seed/data)](#sec-06)  
- [7) Full summary: inventory → parsing → aggregates → table](#sec-07)



<a id="sec-01"></a>
## 1) Environment setup and paths

**Goal**  
Prepare the environment: cap the CPU thread count with `torch.set_num_threads` (to avoid *oversubscription*), detect `ROOT` (the repo root) and add it to `sys.path`, import the project utilities, and select the device (`cuda` if available). PyTorch GPU optimizations (TF32/cuDNN) are enabled for speed.

> Presets are **not** read here yet; this cell only configures the global runtime. 
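
The `ROOT` detection rule used in this section can be sketched in isolation as follows. The helper names `detect_root` and `ensure_importable` are hypothetical (the notebook does this inline); the logic mirrors the cell: if the working directory is `notebooks/`, the repo root is its parent, otherwise it is the working directory itself.

```python
import sys
from pathlib import Path


def detect_root(cwd: Path) -> Path:
    """Return the repo root given the current working directory (hypothetical helper)."""
    return cwd.parent if cwd.name == "notebooks" else cwd


def ensure_importable(root: Path) -> None:
    """Append the root to sys.path once, so `import src...` resolves."""
    if str(root) not in sys.path:
        sys.path.append(str(root))


root_from_notebooks = detect_root(Path("/repo/notebooks"))
root_from_repo = detect_root(Path("/repo"))
ensure_importable(root_from_repo)
```

This rule only covers the two launch locations named above; starting Jupyter from any other subdirectory would need a different check.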

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# 1) Environment setup and paths
# =============================================================================
import os, torch

# Multiprocessing/WSL robustness
# os.environ["PYTORCH_SHARING_STRATEGY"] = "file_system"
# os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

# >>> CUDA memory: stable allocator for long runs <<<
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64,garbage_collection_threshold:0.6"

# Log it/s every epoch from training.py (optional)
os.environ["TRAIN_LOG_ITPS"] = "1"

try:
    import torch.multiprocessing as mp
    mp.set_sharing_strategy("file_system")
except Exception:
    pass

# (only for debugging occasional crashes; uncomment to locate the failing call)
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from pathlib import Path
import sys  # torch is already imported at the top of this cell

ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from src.datasets import ImageTransform, AugmentConfig
from src.models   import build_model
from src.utils    import load_preset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

OUT = ROOT / "outputs"
OUT.mkdir(parents=True, exist_ok=True)

print("ROOT:", ROOT)
print("OUT :", OUT)
print("Device:", device)


<a id="sec-02"></a>

## 2) Loading the unified preset (`configs/presets.yaml`)

**Goal**  
Load a **preset** (`fast` | `std` | `accurate`) and derive the full operating configuration:

- **Model/transform**: image size, grayscale, etc.
- **Data/encoding**: `encoder` (`rate|latency|raw`), `T`, `gain`, `seed`.
- **DataLoader**: `num_workers`, `prefetch_factor`, `pin_memory`, `persistent_workers`.
- Optional **augment** and **online balancing** where applicable.

It includes a **guardrail**: if `use_offline_spikes: true` and `encode_runtime: true` are both set, the cell aborts with a clear error (invalid config).  
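
The guardrail can be factored into a small function, sketched below. The function name `resolve_data_mode` and the returned labels (`"offline_h5"`, `"runtime_encode"`, `"unspecified"`) are hypothetical; only the two flag names and the rule that they are mutually exclusive come from this notebook.

```python
def resolve_data_mode(data_cfg: dict) -> str:
    """Return the data mode implied by the preset flags, rejecting the invalid combination."""
    offline = bool(data_cfg.get("use_offline_spikes", False))
    runtime = bool(data_cfg.get("encode_runtime", False))
    if offline and runtime:
        raise ValueError(
            "Invalid config: use_offline_spikes and encode_runtime cannot both be true."
        )
    if offline:
        return "offline_h5"       # hypothetical label: read pre-encoded H5 files
    if runtime:
        return "runtime_encode"   # hypothetical label: read CSV and encode on the fly
    return "unspecified"          # neither flag set; handling of this case is an assumption


mode = resolve_data_mode({"use_offline_spikes": True, "encode_runtime": False})
```

Checking the combination once, right after loading the preset, keeps the failure close to its cause instead of surfacing later inside a DataLoader worker.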

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# 2) Loading the preset (configs/presets.yaml)
# =============================================================================
from src.utils import load_preset

PRESET = "accurate"  # "fast" for quick tests, "accurate" for final results
CFG = load_preset(ROOT / "configs" / "presets.yaml", PRESET)

# ---- Model / Transform ----
MODEL_NAME = CFG["model"]["name"]
tfm = ImageTransform(
    CFG["model"]["img_w"],
    CFG["model"]["img_h"],
    to_gray=bool(CFG["model"]["to_gray"]),
    crop_top=None,
)

# ---- Data / encoding ----
ENCODER = CFG["data"]["encoder"]
T       = int(CFG["data"]["T"])
GAIN    = float(CFG["data"]["gain"])
SEED    = int(CFG["data"]["seed"])

USE_OFFLINE_SPIKES = bool(CFG["data"].get("use_offline_spikes", False))
RUNTIME_ENCODE     = bool(CFG["data"].get("encode_runtime", False))

# ---- Loader / augment / balancing ----
NUM_WORKERS = int(CFG["data"].get("num_workers") or 0)
PREFETCH    = int(CFG["data"].get("prefetch_factor") or 2)
PIN_MEMORY  = bool(CFG["data"].get("pin_memory", True))
PERSISTENT  = bool(CFG["data"].get("persistent_workers", True))

AUG_CFG = AugmentConfig(**(CFG["data"].get("aug_train") or {})) \
          if CFG["data"].get("aug_train") else None

USE_ONLINE_BALANCING = bool(CFG["data"].get("balance_online", False))
BAL_BINS = int(CFG["data"].get("balance_bins") or 50)
BAL_EPS  = float(CFG["data"].get("balance_smooth_eps") or 1e-3)

# Guardrail: the two data modes are mutually exclusive
if USE_OFFLINE_SPIKES and RUNTIME_ENCODE:
    raise RuntimeError("Invalid config: use_offline_spikes=True and encode_runtime=True at the same time.")

print(f"[PRESET={PRESET}] model={MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
print(f"[DATA] encoder={ENCODER} T={T} gain={GAIN} seed={SEED}")
print(f"[LOADER] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")
print(f"[BALANCE] online={USE_ONLINE_BALANCING} bins={BAL_BINS} eps={BAL_EPS}")
print(f"[RUNTIME_ENCODE] {RUNTIME_ENCODE} | [OFFLINE_SPIKES] {USE_OFFLINE_SPIKES}")


<a id="sec-03"></a>

## 3) Data verification and `tasks.json` selection

**Goal**  
Build `task_list` and verify that the per-task *splits* exist:

- If the preset asks for **balanced** tasks (`prep.use_balanced_tasks: true`) and `tasks_balanced.json` exists, it is used; otherwise the notebook falls back to `tasks.json` (a notice is printed).
- It checks that `train/val/test.csv` exist for each *run*.  
- If you train with **offline H5**, it checks that the H5 files are present with a **naming scheme compatible** with the preset (`encoder/T/gain/size/to_gray`).

> If any required H5 file is missing, generate it first with **02_ENCODE_OFFLINE**.  
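
The H5 naming scheme checked in this section can be written as a small helper, sketched here. The function names `expected_h5_name` and `missing_h5` are hypothetical, and the numeric values in the example are illustrative only; the file-name pattern and the rule that non-`rate` encoders use a gain tag of `0` follow the verification cell of this section.

```python
from pathlib import Path


def expected_h5_name(split: str, encoder: str, T: int, gain: float,
                     to_gray: bool, img_w: int, img_h: int) -> str:
    """Build the H5 file name expected for one split under a given preset."""
    color = "gray" if to_gray else "rgb"
    gain_tag = gain if encoder == "rate" else 0  # gain only matters for rate coding
    return f"{split}_{encoder}_T{T}_gain{gain_tag}_{color}_{img_w}x{img_h}.h5"


def missing_h5(task_dir: Path, **preset) -> list[Path]:
    """Return the expected H5 paths that do not exist under task_dir."""
    expected = [task_dir / expected_h5_name(s, **preset) for s in ("train", "val", "test")]
    return [p for p in expected if not p.exists()]


# Illustrative preset values, not taken from configs/presets.yaml.
name_rate = expected_h5_name("train", "rate", 30, 0.5, True, 200, 66)
name_latency = expected_h5_name("val", "latency", 20, 0.5, False, 160, 80)
```

Because the gain is formatted with its default string representation, `0.5` and `0.50` map to the same file name, but `1` and `1.0` do not; the preset and the encoder notebook need to agree on the numeric type.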

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# 3) Data verification (splits and, where applicable, H5)
# =============================================================================
import json
from pathlib import Path as _P

PROC = ROOT / "data" / "processed"

USE_BALANCED = bool(CFG.get("prep", {}).get("use_balanced_tasks", False))
tb_name = (CFG.get("prep", {}).get("tasks_balanced_file_name") or "tasks_balanced.json")
t_name  = (CFG.get("prep", {}).get("tasks_file_name") or "tasks.json")

cand_bal = PROC / tb_name
cand_std = PROC / t_name
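TASKS_FILE = cand_bal if (USE_BALANCED and cand_bal.exists()) else cand_std
if USE_BALANCED and TASKS_FILE != cand_bal:
    print(f"[INFO] {cand_bal.name} not found; falling back to {cand_std.name}.")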

with open(TASKS_FILE, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)

task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]
print("Using:", TASKS_FILE.name)
for t in task_list:
    print(f" - {t['name']}: {_P(t['paths']['train']).name}")

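# Enforce balanced train files only when the balanced tasks file is actually in use;
# otherwise the documented fallback to tasks.json would always fail here.
if USE_BALANCED and TASKS_FILE == cand_bal:
    for t in task_list:
        train_path = _P(tasks_json["splits"][t["name"]]["train"])
        if train_path.name != "train_balanced.csv":
            raise RuntimeError(
                f"[{t['name']}] Expected 'train_balanced.csv' in balanced mode, but found '{train_path.name}'."
            )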

if USE_OFFLINE_SPIKES:
    mw, mh = CFG["model"]["img_w"], CFG["model"]["img_h"]
    color = "gray" if CFG["model"]["to_gray"] else "rgb"
    gain_tag = (GAIN if ENCODER == "rate" else 0)
    missing = []
    for t in task_list:
        base = PROC / t["name"]
        for split in ("train", "val", "test"):
            expected = base / f"{split}_{ENCODER}_T{T}_gain{gain_tag}_{color}_{mw}x{mh}.h5"
            if not expected.exists():
                missing.append(str(expected))
    if missing:
        print("[WARN] H5 files compatible with the preset are missing. Generate them with 02_ENCODE_OFFLINE:")
        for m in missing:
            print("   -", m)
print("OK: split verification finished.")


<a id="sec-04"></a>
## 4) Unified factories: DataLoaders + Model (+ tasks)

**Goal**  
Create, in a single call, the **components consistent with the preset**:

- `build_components_for(CFG, ROOT)` → returns `tfm`, `make_loader_fn`, `make_model_fn`.
  - The **loader** automatically respects the data mode (offline H5 vs. CSV + runtime encoding), *workers/prefetch/pin/persistent*, *augment*, and **online balancing** if enabled.
  - The **model** is instantiated according to `model.name` and its associated parameters.
- `build_task_list_for(CFG, ROOT)` → returns `task_list` and the *tasks file* actually used.

> This avoids duplicating logic across notebooks and ensures that **benchmarking, training and comparison** use the **same** configuration.  

[↑ Back to index](#toc)


In [None]:
# =============================================================================
# 4) Factories: DataLoaders + Modelo + task_list
# =============================================================================
from src.utils import build_task_list_for, build_components_for

tfm, make_loader_fn, make_model_fn = build_components_for(CFG, ROOT)
task_list, tasks_file = build_task_list_for(CFG, ROOT)

print("Tasks file:", tasks_file.name)
print("make_loader_fn ready (uses H5 if offline; otherwise CSV + runtime encoding).")


<a id="sec-05"></a>
## 5) Base run with the preset (config echo + run)

**Goal**  
Launch **one experiment** with the method and parameters from the preset (`CFG["continual"]`). The cell:

- Prints a **configuration summary** (model, data, loader, method).
- Runs `run_continual(...)`.
- Saves results under `outputs/continual_*` (including `continual_results.json` and a per-task `manifest.json`).

> Check the console output to confirm the device, *encoder/T/gain* and data mode (offline/runtime).  

[↑ Back to index](#toc)

In [None]:
# =============================================================================
# 5) Base run with the preset
# =============================================================================
from src.runner import run_continual

print(f"[RUN] preset={PRESET} | method={CFG['continual']['method']} "
      f"| seed={CFG['data']['seed']} | enc={CFG['data']['encoder']} "
      f"| kwargs={CFG['continual'].get('params', {})}")
print(f"[MODEL] {MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
# Optional keys are read with .get(), matching the defaults used in section 2
print(f"[DATA] T={CFG['data']['T']} gain={CFG['data']['gain']} "
      f"| offline_spikes={CFG['data'].get('use_offline_spikes', False)} "
      f"| runtime_encode={CFG['data'].get('encode_runtime', False)}")
print(f"[LOADER] workers={CFG['data'].get('num_workers')} "
      f"prefetch={CFG['data'].get('prefetch_factor')} pin={CFG['data'].get('pin_memory')} "
      f"persistent={CFG['data'].get('persistent_workers')} "
      f"| aug={bool(CFG['data'].get('aug_train'))} "
      f"| balance_online={CFG['data'].get('balance_online', False)}")

out_path, _ = run_continual(
    task_list=task_list,
    make_loader_fn=make_loader_fn,
    make_model_fn=make_model_fn,
    tfm=tfm,
    cfg=CFG,
    preset_name=PRESET,
    out_root=OUT,
    verbose=True,
)

print("OK:", out_path)


<a id="sec-06"></a>
## 6) Method comparison (same preset / same seed / same data)

**Goal**  
Run a **battery of methods** changing **only** `continual.method` and its `params`, while keeping fixed: preset, seed, *encoder/T/gain*, image size, *augment*, etc.

- `CFG` is cloned per method and `run_continual(...)` is called with the **factories** built from that `cfg_i`.
- The `EXPS` list can be extended with names registered in `src/methods/`:
  - `naive`, `ewc`, `rehearsal`, `rehearsal+ewc`
  - (planned bio-inspired methods) `as-snn`, `sa-snn`, `sca-snn`, `colanet`

**Recommendations**
- If you use **offline H5**, make sure the files exist for the preset (`02_ENCODE_OFFLINE`).
- If you enable *replay* (rehearsal), you can **disable** `persistent_workers` to avoid DataLoader stalls in some environments; the cell already does this as a precaution.
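
The per-method cloning described above can be sketched as a helper. The name `make_variant` and the stand-in `base` dictionary are hypothetical; the fields it touches (`continual.method`, `continual.params`, `naming.tag`, `data.persistent_workers`) are the ones the sweep cells of this section modify.

```python
from copy import deepcopy


def make_variant(base_cfg: dict, method: str, params: dict, tag: str) -> dict:
    """Return a deep copy of base_cfg that differs only in the continual method settings."""
    cfg = deepcopy(base_cfg)                      # never mutate the shared preset
    cfg["continual"]["method"] = method
    cfg["continual"]["params"] = dict(params)
    cfg.setdefault("naming", {})["tag"] = tag     # label used to name the output folder
    if "rehearsal" in method:
        cfg["data"]["persistent_workers"] = False  # precaution against DataLoader stalls
    return cfg


# Stand-in for the loaded preset (illustrative values).
base = {
    "continual": {"method": "naive", "params": {}},
    "data": {"seed": 42, "persistent_workers": True},
}
variant = make_variant(base, "rehearsal", {"buffer_size": 3000, "replay_ratio": 0.1}, "demo_rr10")
```

A deep copy matters here because the preset is a nested dictionary: a shallow copy would share the inner `continual` and `data` dictionaries, so one experiment's settings would leak into the next.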

[↑ Back to index](#toc)

In [None]:

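# NOTE: scratch lists of candidate settings. Each assignment below overwrites the
# previous one, so only the last EXPERIMENTS list survives; the sweep cells that
# follow read EXPS, not EXPERIMENTS.
EXPERIMENTS = [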
    # ("naive", {}),
    # ("ewc", {"lam": 1e9, "fisher_batches": 1000}),
    # ("rehearsal", {"buffer_size": 3000, "replay_ratio": 0.1}),
    # ("rehearsal+ewc", {"buffer_size": 3000, "replay_ratio": 0.1, "lam": 1e9, "fisher_batches": 1000}),
    # ("sa-snn", {"attach_to":"f6","k":8,"tau":28,"vt_scale":1.33,"p":2_000_000,
    #             "flatten_spatial":False,"assume_binary_spikes":False,"reset_counters_each_task":False}),
    # ("sa-snn", {"attach_to":"f6","k":8,"tau":32,"vt_scale":1.33,"p":5_000_000,
    #             "flatten_spatial":False,"assume_binary_spikes":False,"reset_counters_each_task":False}),
    # ("sa-snn", {"attach_to":"f6","k":9,"tau":28,"vt_scale":1.33,"p":5_000_000,
    #             "flatten_spatial":False,"assume_binary_spikes":False,"reset_counters_each_task":False}),
]

EXPERIMENTS = [
    # --- Classic methods ---
    ("naive", {}),  # baseline without forgetting mitigation

    # Best EWC setting seen so far with the accurate preset
    ("ewc", {"lam": 7e8, "fisher_batches": 1000}),   # 500 for a quick run; 1000 if time allows

    # Rehearsal setting that was stable in previous runs
    ("rehearsal", {"buffer_size": 3000, "replay_ratio": 0.1}),  # or 0.2 if resources allow

    # Combination that already appeared on the Pareto front (accurate)
    ("rehearsal+ewc", {"buffer_size": 3000, "replay_ratio": 0.2, "lam": 1e9, "fisher_batches": 1500}),
]

EXPERIMENTS = [
    ("sa-snn", {"attach_to":"f6","k":8,"tau":28,"vt_scale":1.33,"p":2_000_000,
                "th_min":1.0,"th_max":2.0,"flatten_spatial":False,
                "assume_binary_spikes":False,"reset_counters_each_task":False}),
    ("sa-snn", {"attach_to":"f6","k":8,"tau":32,"vt_scale":1.33,"p":5_000_000,
                "th_min":1.0,"th_max":2.0,"flatten_spatial":False,
                "assume_binary_spikes":False,"reset_counters_each_task":False}),
    ("sa-snn", {"attach_to":"f6","k":9,"tau":28,"vt_scale":1.33,"p":5_000_000,
                "th_min":1.0,"th_max":2.0,"flatten_spatial":False,
                "assume_binary_spikes":False,"reset_counters_each_task":False}),
]

EXPERIMENTS = [
    # "Intermediate" SA-SNN (between the settings already launched):
    ("sa-snn", {
        "attach_to":"f6",
        "k": 8,
        "tau": 30,          # ← intermediate (the preset uses 30)
        "vt_scale": 1.33,   # ← same as the variants that reduced forgetting
        "p": 5_000_000,
        "th_min": 1.0, "th_max": 2.0,
        "flatten_spatial": False,
        "assume_binary_spikes": False,
        "reset_counters_each_task": False
    }),
]

EXPERIMENTS = [
    ("sca-snn", {
    "attach_to":"f6",
    "flatten_spatial": False,
    "num_bins": 50,
    "bin_lo": -1.0, "bin_hi": 1.0,
    "anchor_batches": 10,
    "max_per_bin": 512,
    "beta": 0.65,         # ↓
    "bias": 0.05,         # ↓
    "soft_mask_temp": 1.0,
    "habit_decay": 0.995,
    "verbose": True,
    "log_every": 50
}),
]


In [None]:
# =============================================================================
# 6) Method / variant sweep for the report
# =============================================================================
from copy import deepcopy
from src.runner import run_continual
from src.utils import build_task_list_for, build_components_for

# Define here ONLY the final list you want to launch now
EXPS = [
    # naive baseline (no forgetting mitigation)
    dict(
         method="naive",
         params={},
         tag="grid01"
    ),
    # SCA-SNN: best setting so far + two beta variations
    dict(
        method="sca-snn",
        params={
            "attach_to": "f6",
            "flatten_spatial": False,
            "num_bins": 50,
            "bin_lo": -1.0, "bin_hi": 1.0,
            "anchor_batches": 12,
            "max_per_bin": 512,
            "beta": 0.65,
            "bias": 0.0,
            "soft_mask_temp": 0.75,
            "habit_decay": 0.995,
            "verbose": True,
            "log_every": 8192
        },
        tag="grid06_sca_b065"
    ),
    dict(
        method="sca-snn",
        params={
            "attach_to": "f6",
            "flatten_spatial": False,
            "num_bins": 50,
            "bin_lo": -1.0, "bin_hi": 1.0,
            "anchor_batches": 12,
            "max_per_bin": 512,
            "beta": 0.60,
            "bias": 0.0,
            "soft_mask_temp": 0.75,
            "habit_decay": 0.995,
            "verbose": True,
            "log_every": 8192
        },
        tag="grid07_sca_b060"
    ),
    dict(
        method="sca-snn",
        params={
            "attach_to": "f6",
            "flatten_spatial": False,
            "num_bins": 50,
            "bin_lo": -1.0, "bin_hi": 1.0,
            "anchor_batches": 12,
            "max_per_bin": 512,
            "beta": 0.70,
            "bias": 0.0,
            "soft_mask_temp": 0.75,
            "habit_decay": 0.995,
            "verbose": True,
            "log_every": 8192
        },
        tag="grid08_sca_b070"
    ),

    # AS-SNN (best hyperparameters found so far by HPO)
    dict(
        method="as-snn",
        params={"gamma_ratio": 0.3, "lambda_a": 1.6, "ema": 0.9},
        tag="grid09_as"
    ),

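    # SA-SNN (hyperparameters that worked well)
    # NOTE: other cells in this notebook pass "th_min"/"th_max"/"p" to sa-snn;
    # check that "thresh_lo" and "period" match the current implementation.
    dict(
        method="sa-snn",
        params={"k": 8, "tau": 28, "thresh_lo": 1.2, "period": 200000},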
        tag="grid10_sa"
    ),

    # EWC (best lambda and fisher_batches so far)
    dict(
        method="ewc",
        params={"lam": 7e8, "fisher_batches": 1000},
        tag="grid11_ewc"
    ),
    # Rehearsal: replay_ratio sweep (buffer_size fixed at 3000)
    dict(
        method="rehearsal",
        params={"buffer_size": 3000, "replay_ratio": 0.10},
        tag="grid02_rr10"
    ),
    dict(
        method="rehearsal",
        params={"buffer_size": 3000, "replay_ratio": 0.15},
        tag="grid03_rr15"
    ),
    dict(
        method="rehearsal",
        params={"buffer_size": 3000, "replay_ratio": 0.20},
        tag="grid04_rr20"
    ),
    dict(
        method="rehearsal",
        params={"buffer_size": 3000, "replay_ratio": 0.25},
        tag="grid05_rr25"
    ),


]

runs_out = []

for exp in EXPS:
    method_name   = exp["method"]
    method_params = exp["params"]

    # clone the base preset so that CFG is not overwritten
    cfg_i = deepcopy(CFG)
    cfg_i["continual"]["method"] = method_name
    cfg_i["continual"]["params"] = method_params

    # optional metadata so that the runner labels the folder properly
    # (the runner already reads cfg["naming"]["tag"] if present)
    cfg_i.setdefault("naming", {})
    cfg_i["naming"]["tag"] = exp["tag"]

    # small robustness tweak for rehearsal: disable persistent workers
    if "rehearsal" in method_name:
        cfg_i["data"]["persistent_workers"] = False

    # build components consistent with THIS cfg_i
    tfm_i, make_loader_fn_i, make_model_fn_i = build_components_for(cfg_i, ROOT)
    task_list_i, tasks_file_i = build_task_list_for(cfg_i, ROOT)

    print(
        f"\n=== RUN: preset={PRESET} | method={method_name} "
        f"| seed={cfg_i['data']['seed']} | enc={cfg_i['data']['encoder']} "
        f"| kwargs={method_params} | tag={exp['tag']} ==="
    )

    out_dir, _ = run_continual(
        task_list=task_list_i,
        make_loader_fn=make_loader_fn_i,
        make_model_fn=make_model_fn_i,
        tfm=tfm_i,
        cfg=cfg_i,
        preset_name=PRESET,
        out_root=OUT,
        verbose=True,
    )

    runs_out.append(out_dir)

print("\nDone:", [str(p) for p in runs_out])


In [None]:
# =============================================================================
# 6) "Best per method" sweep (accurate, T=30, B=160)
# =============================================================================
from copy import deepcopy
from src.runner import run_continual
from src.utils import load_preset, build_task_list_for, build_components_for
from pathlib import Path
import sys

# --- Load preset (repeated here so that this cell can run on its own) ---
ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

PRESET = "accurate"
CFG = load_preset(ROOT / "configs" / "presets.yaml", PRESET)

# === Best setting per method (according to the results table) ===
# Names and parameters aligned with the current implementations
EXPS = [
  # Winning rehearsal: buffer=3000, rr=0.10
  # dict(
  #   method="rehearsal",
  #   params={
  #     "buffer_size": 3000,
  #     "replay_ratio": 0.10
  #   },
  #   tag="best_reh_buf3000_rr10"
  # ),

  # Winning SA-SNN: k=8, tau=28, p=2e6, th 1 to 2
  # dict(
  #   method="sa-snn",
  #   params={
  #     "attach_to": "f6",
  #     "k": 8,
  #     "tau": 28.0,
  #     "th_min": 1.0,
  #     "th_max": 2.0,
  #     "p": 2000000,
  #     "vt_scale": 1.0,
  #     "flatten_spatial": False,
  #     "assume_binary_spikes": False,
  #     "reset_counters_each_task": False
  #   },
  #   tag="best_sa_k8_tau28_p2m"
  # ),

  # Winning SCA-SNN: bins50, beta=0.60, bias=0.05, temp=0.5, ab=16, flat=0
  # dict(
  #     method="sca-snn",
  #     params={
  #         "attach_to": "f6",
  #         "flatten_spatial": False,
  #         "num_bins": 50,
  #         "anchor_batches": 16,
  #         "beta": 0.60,
  #         "bias": 0.05,
  #         "soft_mask_temp": 0.50,
  #         "verbose": False,
  #         "log_every": 65536
  #     },
  #     tag="best_sca_b060_bias005_t050_ab16"
  # ),

  # Winning AS-SNN: gamma_ratio=0.3, lambda≈1.59168
  dict(
    method="as-snn",
    params={
      "gamma_ratio": 0.3,
      "lambda_a": 1.59168,
      "ema": 0.9
    },
    tag="best_as_gr03_lam1p59168"
  ),

  # Winning EWC: lam=7e8, fisher=1000
  dict(
    method="ewc",
    params={
      "lam": 7e8,
      "fisher_batches": 1000
    },
    tag="best_ewc_lam7e8_f1000"
  ),

  # Naive baseline (no mitigation)
  dict(
    method="naive",
    params={},
    tag="baseline_naive"
  ),
]

# --- Run ---
OUT = ROOT / "outputs"
OUT.mkdir(parents=True, exist_ok=True)

runs_out = []
for exp in EXPS:
    method_name   = exp["method"]
    method_params = exp["params"]

    cfg_i = deepcopy(CFG)
    cfg_i["continual"]["method"]  = method_name
    cfg_i["continual"]["params"]  = method_params
    cfg_i.setdefault("naming", {})
    cfg_i["naming"]["tag"] = exp["tag"]

    # Safety: rehearsal sometimes misbehaves with persistent workers
    if "rehearsal" in method_name:
        cfg_i["data"]["persistent_workers"] = False

    tfm_i, make_loader_fn_i, make_model_fn_i = build_components_for(cfg_i, ROOT)
    task_list_i, tasks_file_i = build_task_list_for(cfg_i, ROOT)

    print(f"\n=== RUN accurate: method={method_name} | seed={cfg_i['data']['seed']} "
          f"| enc={cfg_i['data']['encoder']} | kwargs={method_params} | tag={exp['tag']} ===")

    out_dir, _ = run_continual(
        task_list=task_list_i,
        make_loader_fn=make_loader_fn_i,
        make_model_fn=make_model_fn_i,
        tfm=tfm_i,
        cfg=cfg_i,
        preset_name=PRESET,
        out_root=OUT,
        verbose=True,
    )
    runs_out.append(out_dir)

print("\nDone:", [str(p) for p in runs_out])


<a id="sec-07"></a>
## 7) Full summary: inventory → parsing → aggregates → table

**Goal**  
Create a **reproducible summary** of all *runs*:

- **Inventory** of the `outputs/continual_*` folders.
- **Parsing** of folder names to extract `preset`, `method`, `encoder`, `seed`, `model`, and relevant parameters.
- Computation of **forgetting** (absolute and relative) and **aggregates** per group (mean, σ, n).
- Export to `outputs/summary/continual_summary_agg.csv` and a **formatted table** for the report. Note that the code cell shown below writes `results_table.csv` in that folder.

> If no *runs* are detected, check that `continual_results.json` exists inside each folder.  
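
To make the forgetting and aggregate columns concrete, here is a minimal sketch under an assumed definition; the project's `build_results_table` may compute them differently. The assumption is that, for an error metric such as MAE (lower is better), absolute forgetting on a task is its MAE after the last task minus its MAE right after it was learned, and relative forgetting divides that by the MAE right after learning. The helper names and the use of the sample standard deviation are also assumptions.

```python
import math
import statistics


def forgetting(mae_after_learning: float, mae_final: float) -> tuple[float, float]:
    """Return (absolute, relative) forgetting for one task under the assumed definition."""
    absolute = mae_final - mae_after_learning
    relative = absolute / mae_after_learning if mae_after_learning > 0 else math.nan
    return absolute, relative


def aggregate(values: list[float]) -> tuple[float, float, int]:
    """Return (mean, sample standard deviation, n) for one group of runs."""
    n = len(values)
    mean = statistics.mean(values)
    std = statistics.stdev(values) if n > 1 else math.nan  # undefined for a single run
    return mean, std, n


# Illustrative numbers, not results from this project.
abs_f, rel_f = forgetting(mae_after_learning=0.10, mae_final=0.15)
mean_f, std_f, n_f = aggregate([1.0, 2.0, 3.0])
```

With this sign convention a positive value means the task got worse after later training, and a negative value indicates backward transfer.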

[↑ Back to index](#toc)


In [None]:
# =============================================================================
# 7) Summary and plots
# =============================================================================
from pathlib import Path
import json
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from src.results_io import build_results_table
from src.plots import plot_across_runs

summary_dir = OUT / "summary"
summary_dir.mkdir(parents=True, exist_ok=True)

def canonical_method(s: str) -> str:
    if not isinstance(s, str):
        return "unknown"
    t = s.lower()
    if ("rehearsal" in t) and ("+ewc" in t or "_ewc" in t):
        return "rehearsal+ewc"
    if "sca-snn" in t:
        return "sca-snn"
    if re.search(r"\bsa[-_]snn\b", t):
        return "sa-snn"
    if re.search(r"\bas[-_]snn\b", t):
        return "as-snn"
    if "colanet" in t:
        return "colanet"
    if re.search(r"\bewc\b", t) or "ewc_lam" in t:
        return "ewc"
    if "rehearsal" in t:
        return "rehearsal"
    if "naive" in t or "finetune" in t or "fine-tune" in t:
        return "naive"
    return t.split("_")[0]

# --- 7.1 Consolidated table ---
df = build_results_table(OUT)
df["method_base"] = df["method"].astype(str).apply(canonical_method)
display(df)
df.to_csv(summary_dir / "results_table.csv", index=False)
print(f"[OK] Table saved to {summary_dir/'results_table.csv'}")

# --- 7.2 Comparison plots (final MAE, forgetting, emissions, trade-off) ---
plots_dir = plot_across_runs(df, summary_dir / "plots")
print("[OK] Comparison plots in:", plots_dir)

# --- 7.3 Per-task loss curves (for the report) ---
def plot_losses_for_run(run_dir: Path, outdir: Path):
    """Look in run_dir/task_*/manifest.json and plot train/val loss curves per task."""
    outdir.mkdir(parents=True, exist_ok=True)
    task_dirs = sorted(run_dir.glob("task_*"))
    if not task_dirs:
        print(f"[WARN] No task_* folders in {run_dir}")
        return

    for td in task_dirs:
        # manifest.json, falling back to metrics.json
        man = None
        for cand in ("manifest.json", "metrics.json"):
            p = td / cand
            if p.exists():
                with open(p, "r", encoding="utf-8") as f:
                    man = json.load(f)
                break
        if man is None:
            print(f"[WARN] No manifest/metrics in {td.name}")
            continue

        hist = (man.get("history") or {})
        tr = hist.get("train_loss") or []
        va = hist.get("val_loss") or []

        if not tr and not va:
            print(f"[WARN] {td.name}: no 'train_loss'/'val_loss' in history.")
            continue

        plt.figure(figsize=(7,4))
        if tr:
            plt.plot(range(1, len(tr)+1), tr, label="train_loss")
        if va:
            plt.plot(range(1, len(va)+1), va, label="val_loss")
        plt.title(f"{run_dir.name} — {td.name}")
        plt.xlabel("Epoch")
        plt.ylabel("MSE loss")
        plt.legend()
        plt.tight_layout()
        plt.savefig(outdir / f"{run_dir.name}__{td.name}_loss.png", dpi=160)
        plt.savefig(outdir / f"{run_dir.name}__{td.name}_loss.svg")
        plt.show()

# Choose the run(s) for the loss curves (here, the most recently modified one):
runs = sorted(OUT.glob("continual_*"), key=lambda p: p.stat().st_mtime, reverse=True)
if runs:
    loss_plots_dir = summary_dir / "loss_curves"
    print("Generating loss curves for:", runs[0].name)
    plot_losses_for_run(runs[0], loss_plots_dir)
    print("[OK] Loss curves in:", loss_plots_dir)
else:
    print("[INFO] No runs in outputs/ yet.")


In [None]:
# === Automatic selection of representative runs ===
import re
import numpy as np
import pandas as pd

def norm01(x):
    x = x.astype(float)
    lo, hi = np.nanmin(x), np.nanmax(x)
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return np.ones_like(x) * 0.5
    return (x - lo) / (hi - lo)

# Use the table built in the previous cell
d = df.copy()

# Normalize the base method (in case this cell is run on its own)
if "method_base" not in d.columns:
    d["method_base"] = d["method"].astype(str).apply(canonical_method)

# Filter preset = accurate
d = d[d["preset"] == "accurate"].copy()
assert not d.empty, "No runs with preset='accurate'."

# --- Primary task policy: LAST *_final_mae ---
# (if you prefer the mean across tasks, change the selection below)
task_cols = [c for c in d.columns if c.endswith("_final_mae")]
assert len(task_cols) > 0, "No *_final_mae columns found in the table."

def sort_key(col):
    name = col.replace("_final_mae", "")
    m = re.search(r"(\d+)$", name)
    base = re.sub(r"\d+$", "", name)
    idx = int(m.group(1)) if m else 0
    return (base, idx)

task_cols_sorted = sorted(task_cols, key=sort_key)
primary_mae  = task_cols_sorted[-1]    # e.g. 'circuito2_final_mae'
primary_task = primary_mae.replace("_final_mae", "")
print("[INFO] Detected tasks:", task_cols_sorted)
print("[INFO] Using as primary MAE:", primary_mae)

# Keep the required columns and coerce them to numeric
keep = ["run_dir","preset","method","method_base","encoder","model","seed",
        "elapsed_sec","emissions_kg","avg_forget_rel", primary_mae]
d = d[keep].copy()
for c in ["emissions_kg","avg_forget_rel", primary_mae]:
    d[c] = pd.to_numeric(d[c], errors="coerce")

# Conservative fill: median emissions, worst-case (max) forgetting
if d["emissions_kg"].isna().all():
    d["emissions_kg"] = 0.0
else:
    d["emissions_kg"] = d["emissions_kg"].fillna(d["emissions_kg"].median())
d["avg_forget_rel"] = d["avg_forget_rel"].fillna(d["avg_forget_rel"].max())

# --- Pareto front (minimize: MAE, forgetting, emissions) ---
M = d[[primary_mae,"avg_forget_rel","emissions_kg"]].values
is_dominated = np.zeros(len(d), dtype=bool)
for i in range(len(d)):
    ai = np.nan_to_num(M[i], nan=np.inf)
    for j in range(len(d)):
        if i == j: continue
        aj = np.nan_to_num(M[j], nan=np.inf)
        if np.all(aj <= ai) and np.any(aj < ai):
            is_dominated[i] = True
            break

pareto = d.loc[~is_dominated].sort_values([primary_mae, "avg_forget_rel", "emissions_kg"])
print(f"=== Pareto front (non-dominated): final MAE ({primary_task}) ===")
display(pareto)

# --- Ranking by composite score (optional) ---
w_mae, w_forget, w_emiss = 0.5, 0.4, 0.1
d["_mae_n"]    = norm01(d[primary_mae].values)
d["_forget_n"] = norm01(d["avg_forget_rel"].values)
d["_emiss_n"]  = norm01(d["emissions_kg"].values)
d["score"]     = w_mae*d["_mae_n"] + w_forget*d["_forget_n"] + w_emiss*d["_emiss_n"]

topN = d.sort_values("score").head(6)
print("=== Top-6 by composite score (↓ is better) ===")
display(topN[["run_dir","preset","method","method_base","seed",primary_mae,"avg_forget_rel","emissions_kg","score"]])

# ---- If you prefer the MEAN across tasks instead of the last one ----
# NOTE: compute this before `d = d[keep]` above, while all *_final_mae columns are still in `d`.
# mean_mae = d[task_cols].mean(axis=1, skipna=True)
# d_alt = d.copy()
# d_alt["mean_mae_final"] = mean_mae
# ... and replace 'primary_mae' with 'mean_mae_final' in the Pareto/ranking.
