<a id="top"></a>
# 03B: Hyperparameter search with Optuna (Continual Learning)

**What this notebook does**  
It automates **hyperparameter optimization (HPO)** for continual-learning methods while staying **fully consistent** with `configs/presets.yaml`. The methods it can tune include:

- **naive** (baseline),
- **ewc** (Elastic Weight Consolidation),
- **rehearsal** (replay from a buffer),
- **rehearsal+ewc** (combination),
- **as-snn** (bio-inspired; search space included).

It relies on:
- `configs/presets.yaml` (the same configuration as the other notebooks),
- `build_make_loader_fn` (chooses **offline H5** or **CSV + runtime encoding**),
- `run_continual` (trains, evaluates, and saves metrics under `outputs/`).

---

### Objective metric (minimized)

\[
\textbf{Objective}= \mathrm{MAE}_{\text{final task}} + \alpha \cdot \max\!\bigl(0,\; \text{RelativeForgetting}\,\%\bigr)
\]

- **Final-task MAE**: performance on the **last** task in the sequence.  
- **Relative forgetting (%)**: degradation on the **first** task after learning the subsequent task(s).  
- **α**: weight of the forgetting term (default **0.5**). Increase it to penalize forgetting more heavily.

> Metrics are extracted from `continual_results.json`. If information is missing, that *trial* is scored as the worst possible (infinite value) so that it does not bias the study.
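
As a rough illustration of the formula above, here is a minimal sketch. The helper name `composite_objective_sketch` is hypothetical; it is not the repo's `src.hpo.composite_objective`, whose internals are not shown in this notebook.

```python
import math

def composite_objective_sketch(final_mae, forget_rel_pct, alpha=0.5):
    """Hypothetical stand-in: final-task MAE + alpha * max(0, relative forgetting %)."""
    if final_mae is None or forget_rel_pct is None:
        return math.inf  # missing metrics: score the trial as the worst possible
    return final_mae + alpha * max(0.0, forget_rel_pct)

print(composite_objective_sketch(0.16, 12.5))  # 0.16 + 0.5 * 12.5
print(composite_objective_sketch(0.16, -3.0))  # negative forgetting is clipped to 0
```

Because forgetting is expressed in percent while MAE is on the scale of the target, even a small `alpha` makes the forgetting term dominate whenever forgetting is positive.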

---

## ✅ Prerequisites
- `pip install optuna`  
- **Data** prepared (`tasks.json` or `tasks_balanced.json` from 01/01A).  
- If the preset uses **offline** spikes (`use_offline_spikes: true`), the **H5** files must already have been generated with **02_ENCODE_OFFLINE** (same `encoder/T/gain/size/to_gray`).

> **Tip**: start with the `fast` preset and a few *trials*; once everything is stable, raise `N_TRIALS` and/or `epochs`.


<a id="toc"></a>

## 🧭 Contents
- [1) Imports and environment setup](#sec-01)  
- [2) Loading the preset and building the model/transform](#sec-02)  
- [3) Tasks and loader factory](#sec-03)  
- [4) Metrics and objective for Optuna](#sec-04)  
- [5) Search spaces (per method)](#sec-05)  
- [6) Optuna study: a single method](#sec-06)  
- [7) Joint Optuna study (chooses method + HPs)](#sec-07)  
- [8) Retrain with the best hyperparameters](#sec-08)  
- [9) Quick summary of *runs* (table)](#sec-09)



<a id="sec-01"></a>
## 1) Imports and environment setup

**Goal**  
Set up the HPO environment so that it is reproducible and efficient:

- Limit BLAS threads (`OMP/MKL/OPENBLAS`) to avoid oversubscribing the CPU.  
- Detect `ROOT` (the repo root) and add it to `sys.path`.  
- Import the project utilities (`load_preset`, `build_make_loader_fn`, `run_continual`, etc.).  
- Select the device (`cuda` if available) and enable PyTorch optimizations (TF32/cuDNN).

> If your CPU is struggling, lower `torch.set_num_threads(4)` to `2`.  

[↑ Back to contents](#toc)

In [None]:
# Limit BLAS threads (optional)
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TRAIN_LOG_ITPS"] = "1"   # logs it/s

from pathlib import Path
import sys, json, copy, time
import torch
import optuna

# Repo root and sys.path
ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Project
from src.config import load_preset, build_make_loader_fn
from src.datasets import ImageTransform, AugmentConfig
from src.models import build_model
from src.runner import run_continual

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

print("Device:", device)


<a id="sec-02"></a>
## 2) Loading the preset and building the model/transform

**Goal**  
Take the experiment's **source of truth** from `configs/presets.yaml` and derive:

- **Model/transform** (`ImageTransform`) from `img_w`, `img_h`, `to_gray`.  
- **Temporal encoding** (`encoder` ∈ `{rate, latency, raw}`, `T`, `gain`, `seed`).  
- **Data flags**: `use_offline_spikes` (offline H5) and/or `encode_runtime` (encoding on the GPU).  
- **DataLoader**: `num_workers`, `prefetch_factor`, `pin_memory`, `persistent_workers`.  
- **Augmentation** (`aug_train`) and **online balancing** where applicable.

`make_model_fn(tfm)` is defined to instantiate the model with the appropriate parameters.  
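
The cell for this section reads several loader settings with the pattern `cfg.get(key) or default`. The sketch below shows why that differs from `cfg.get(key, default)` when the YAML holds an explicit null; the `cfg_data` dictionary is a made-up preset fragment, not the real `presets.yaml`.

```python
# Hypothetical preset fragment: `num_workers: null` in YAML loads as None.
cfg_data = {"num_workers": None, "prefetch_factor": 4}

# .get(key, default) only applies the default when the key is absent...
print(cfg_data.get("num_workers", 0))           # None
# ...whereas `.get(key) or default` also covers an explicit null.
print(int(cfg_data.get("num_workers") or 0))    # 0
print(int(cfg_data.get("prefetch_factor") or 2))  # 4
```

One caveat of the `or` form: any falsy value (including `0`) is replaced by the default, which is harmless when the default is itself `0` but matters for other defaults.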

[↑ Back to contents](#toc)

In [None]:
PRESET = "accurate"  # fast | std | accurate
CFG = load_preset(ROOT / "configs" / "presets.yaml", PRESET)

# Model / transform
MODEL_NAME = CFG["model"]["name"]
tfm = ImageTransform(
    CFG["model"]["img_w"],
    CFG["model"]["img_h"],
    to_gray=bool(CFG["model"]["to_gray"]),
    crop_top=None
)

# Data / temporal encoding
ENCODER = CFG["data"]["encoder"]
T       = int(CFG["data"]["T"])
GAIN    = float(CFG["data"]["gain"])
SEED    = int(CFG["data"]["seed"])

# Flags & loader
USE_OFFLINE_SPIKES = bool(CFG["data"].get("use_offline_spikes", False))
RUNTIME_ENCODE     = bool(CFG["data"].get("encode_runtime", False))

NUM_WORKERS = int(CFG["data"].get("num_workers") or 0)
PREFETCH    = int(CFG["data"].get("prefetch_factor") or 2)
PIN_MEMORY  = bool(CFG["data"].get("pin_memory", True))
PERSISTENT  = bool(CFG["data"].get("persistent_workers", True))

AUG_CFG = AugmentConfig(**(CFG["data"].get("aug_train") or {})) \
          if CFG["data"].get("aug_train") else None

USE_ONLINE_BAL = bool(CFG["data"].get("balance_online", False))
BAL_BINS = int(CFG["data"].get("balance_bins") or 50)
BAL_EPS  = float(CFG["data"].get("balance_smooth_eps") or 1e-3)

print(f"[PRESET={PRESET}] model={MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
print(f"[DATA] encoder={ENCODER} T={T} gain={GAIN} seed={SEED}")
print(f"[LOADER] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")
print(f"[BALANCE] online={USE_ONLINE_BAL} bins={BAL_BINS}")
print(f"[RUNTIME_ENCODE] {RUNTIME_ENCODE} | [OFFLINE_SPIKES] {USE_OFFLINE_SPIKES}")

def make_model_fn(tfm):
    # kwargs specific to pilotnet_snn; ignore for other models
    return build_model(MODEL_NAME, tfm, beta=0.9, threshold=0.5)


<a id="sec-03"></a>
## 3) Tasks and loader factory

**Goal**  
Build the task list and a **DataLoader factory** consistent with the preset:

- `tasks_balanced.json` is chosen if `prep.use_balanced_tasks: true` and the file exists; otherwise `tasks.json`.  
- If `use_offline_spikes: true`, the expected **H5** files are checked (fixed naming scheme built from `encoder/T/gain/size/color`).  
- `build_make_loader_fn(...)` automatically selects **H5** (offline) or **CSV + runtime encoding** on the GPU.  
- The `make_loader_fn(...)` *wrapper* only **forwards kwargs** (augmentation, online balancing, *workers*, etc.) so that the *runner* does not need to change.

> With balanced *tasks*, the **train** split must be `train_balanced.csv` (or the H5 derived from it). The notebook checks this.  
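
To make the H5 naming scheme concrete, here is a small sketch that mirrors the file-name pattern checked in this section's cell. The helper `expected_h5_path` and the task name `circuito1` are hypothetical, and the values passed in are only examples.

```python
from pathlib import Path

def expected_h5_path(proc_dir, task, split, encoder, T, gain, to_gray, img_w, img_h):
    """Hypothetical helper: build the H5 path using the same pattern as the notebook's check."""
    color = "gray" if to_gray else "rgb"
    gain_tag = gain if encoder == "rate" else 0  # the gain is only encoded for rate coding
    name = f"{split}_{encoder}_T{T}_gain{gain_tag}_{color}_{img_w}x{img_h}.h5"
    return Path(proc_dir) / task / name

p_rate = expected_h5_path("data/processed", "circuito1", "train", "rate", 20, 0.5, True, 200, 66)
p_lat = expected_h5_path("data/processed", "circuito1", "val", "latency", 20, 0.5, False, 200, 66)
print(p_rate.name)  # train_rate_T20_gain0.5_gray_200x66.h5
print(p_lat.name)   # val_latency_T20_gain0_rgb_200x66.h5
```

If any of these fields differs from what 02_ENCODE_OFFLINE used, the file name will not match and the check reports the H5 as missing.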

[↑ Back to contents](#toc)

In [None]:
from pathlib import Path as _P
import json

PROC = ROOT / "data" / "processed"

USE_BALANCED = bool(CFG.get("prep", {}).get("use_balanced_tasks", False))
tb_name = (CFG.get("prep", {}).get("tasks_balanced_file_name") or "tasks_balanced.json")
t_name  = (CFG.get("prep", {}).get("tasks_file_name")            or "tasks.json")

cand_bal = PROC / tb_name
cand_std = PROC / t_name
TASKS_FILE = cand_bal if (USE_BALANCED and cand_bal.exists()) else cand_std

with open(TASKS_FILE, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)
task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]

print("Using:", TASKS_FILE.name)
print("Tasks and TRAIN CSV/H5 files to use:")
for t in task_list:
    print(f" - {t['name']}: {_P(t['paths']['train']).name}")

# Guardrail: in balanced mode, require train_balanced.csv
if USE_BALANCED:
    for t in task_list:
        train_path = _P(tasks_json["splits"][t["name"]]["train"])
        if train_path.name != "train_balanced.csv":
            raise RuntimeError(
                f"[{t['name']}] Expected 'train_balanced.csv' in balanced mode, but found '{train_path.name}'."
            )

# When using offline H5, check that the files exist
if USE_OFFLINE_SPIKES:
    mw, mh = CFG["model"]["img_w"], CFG["model"]["img_h"]
    color = "gray" if CFG["model"]["to_gray"] else "rgb"
    gain_tag = (GAIN if ENCODER == "rate" else 0)
    missing = []
    for t in task_list:
        base = PROC / t["name"]
        for split in ("train", "val", "test"):
            p = base / f"{split}_{ENCODER}_T{T}_gain{gain_tag}_{color}_{mw}x{mh}.h5"
            if not p.exists():
                missing.append(str(p))
    if missing:
        print("[WARN] H5 files are missing. Generate them first with 02_ENCODE_OFFLINE.ipynb (or tools/encode_tasks.py).")

# Loader factory using the preset's kwargs
from src.utils import build_make_loader_fn

_raw_make_loader_fn = build_make_loader_fn(
    root=ROOT, use_offline_spikes=USE_OFFLINE_SPIKES, encode_runtime=RUNTIME_ENCODE,
)
_DL_KW = dict(
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSISTENT,
    aug_train=AUG_CFG,
    balance_train=USE_ONLINE_BAL,
    balance_bins=BAL_BINS,
    balance_smooth_eps=BAL_EPS,
)
def make_loader_fn(task, batch_size, encoder, T, gain, tfm, seed, **dl_kwargs):
    """Pass-through wrapper: the runner adds dl_kwargs; here we just forward them."""
    return _raw_make_loader_fn(
        task=task, batch_size=batch_size, encoder=encoder, T=T, gain=gain, tfm=tfm, seed=seed,
        **{**_DL_KW, **dl_kwargs}
    )
print("make_loader_fn ready.")


<a id="sec-04"></a>
## 4) Metrics and objective for Optuna

**Goal**  
Define the HPO **objective function** from the stored metrics:

1. Read the *run*'s `continual_results.json`.  
2. Identify the **first** and **last** tasks.  
3. Extract:
   - `c1_mae`: MAE of the first task on its own test set.  
   - `c1_after_c2_mae`: MAE of the first task **after** learning the last one (used to measure forgetting).  
   - `c2_mae`: MAE of the **last** task.  
4. Compute **Relative forgetting (%)** = \((c1\_after\_c2\_mae - c1\_mae)/c1\_mae \times 100\).

The **loss** to minimize is:  
\[
\text{Objective}= \mathrm{MAE}_{\text{final task}} + \alpha \cdot \max(0, \text{relative forgetting } \%)
\]

> Adjust `ALPHA_FORGET` to trade off stability (low forgetting) against performance on the last task.  
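
Step 4 can be sketched as follows. The function `relative_forgetting_pct` is a hypothetical illustration; the notebook itself obtains this value through `extract_metrics` from `src.hpo`, whose implementation is not shown here.

```python
def relative_forgetting_pct(c1_mae, c1_after_c2_mae):
    """Hypothetical helper: percent change in first-task MAE after training on the last task."""
    if not c1_mae or c1_after_c2_mae is None:
        return None  # undefined when the baseline is missing or zero
    return (c1_after_c2_mae - c1_mae) / c1_mae * 100.0

print(relative_forgetting_pct(0.10, 0.125))  # first-task MAE got worse: positive forgetting
print(relative_forgetting_pct(0.10, 0.09))   # first-task MAE improved: negative value
```

A negative value means the first task improved after the second one was learned; the objective clips it to zero, so such a trial is not rewarded beyond its final-task MAE.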


[↑ Back to contents](#toc)

In [None]:
# === Metrics and objective: uses the repo utilities ===
from src.hpo import objective_value, composite_objective, extract_metrics, safe_read_json
from pathlib import Path
ALPHA_FORGET = 0.5  # weight of relative forgetting in the composite objective

def _load_results(out_dir: Path) -> dict:
    """Read outputs/<exp>/continual_results.json with robust error handling."""
    return safe_read_json(Path(out_dir) / "continual_results.json")


<a id="sec-05"></a>
## 5) Search spaces (per method)

**Goal**  
Declare which hyperparameters Optuna explores:

- **ewc** (categorical grids that depend on the preset)  
  - `lam`: `fast` {7e8, 1e9}; `std` {1.5e8, 4e8, 1e9}; `accurate` {5e8, 7e8, 1e9}.  
  - `fisher_batches`: `fast` {300, 1000}; `std` {500, 1000}; `accurate` {800, 1200}.  
- **rehearsal**  
  - `buffer_size` ∈ {1000, …, 8000} (step 1000).  
  - `replay_ratio` ∈ [0.05, 0.4] (step 0.05).  
- **rehearsal+ewc**: combines both.  
- **as-snn**  
  - `gamma_ratio` ∈ [0.3, 0.8] (step 0.1), `lambda_a` ∈ [1.0, 4.0], `ema` ∈ [0.70, 0.98].  
- **naive**: no HPs.

> You can extend this to other methods (e.g., `sa-snn`, `sca-snn`, `colanet`) by adding their search space and registering the name.  
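
In this section's code cell, the `ewc` branch draws `lam` and `fisher_batches` from small categorical grids, so the number of distinct configurations per preset is finite. The sketch below simply counts them; the grids are copied from that cell, while `ewc_combinations` is a hypothetical helper.

```python
from itertools import product

# Grids copied from the `ewc` branch of `suggest_params_for_method`.
LAM_GRID = {"fast": [7e8, 1e9], "std": [1.5e8, 4e8, 1e9], "accurate": [5e8, 7e8, 1e9]}
FISHER_GRID = {"fast": [300, 1000], "std": [500, 1000], "accurate": [800, 1200]}

def ewc_combinations(preset):
    """Hypothetical helper: all distinct (lam, fisher_batches) pairs for a preset."""
    return list(product(LAM_GRID[preset], FISHER_GRID[preset]))

for preset in LAM_GRID:
    print(preset, len(ewc_combinations(preset)))
```

For `ewc` alone, setting `N_TRIALS` above these counts can only revisit configurations that were already evaluated.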


[↑ Back to contents](#toc)

In [None]:
def suggest_params_for_method(trial: optuna.Trial, method: str, preset: str) -> dict:
    method = method.lower()
    preset = preset.lower()

    if method == "ewc":
        lam_grid = {
            "fast":     [7e8, 1e9],
            "std":      [1.5e8, 4e8, 1e9],
            "accurate": [5e8, 7e8, 1e9],
        }
        fisher_grid = {
            "fast":     [300, 1000],
            "std":      [500, 1000],
            "accurate": [800, 1200],
        }
        lam = trial.suggest_categorical("lam", lam_grid.get(preset, lam_grid["std"]))
        fb  = trial.suggest_categorical("fisher_batches", fisher_grid.get(preset, fisher_grid["std"]))
        return {"lam": float(lam), "fisher_batches": int(fb)}

    elif method == "rehearsal":
        buffer_size  = trial.suggest_int("buffer_size", 1000, 8000, step=1000)
        replay_ratio = trial.suggest_float("replay_ratio", 0.05, 0.4, step=0.05)
        return {"buffer_size": buffer_size, "replay_ratio": replay_ratio}

    elif method == "rehearsal+ewc":
        # Same as above, with preset-aware grids for lam/fisher
        lam_grid = {
            "fast":     [7e8, 1e9],
            "std":      [1.5e8, 4e8, 1e9],
            "accurate": [5e8, 7e8, 1e9],
        }
        fisher_grid = {
            "fast":     [300, 1000],
            "std":      [500, 1000],
            "accurate": [800, 1200],
        }
        lam = trial.suggest_categorical("lam", lam_grid.get(preset, lam_grid["std"]))
        fisher_batches = trial.suggest_categorical("fisher_batches", fisher_grid.get(preset, fisher_grid["std"]))
        buffer_size  = trial.suggest_int("buffer_size", 1000, 8000, step=1000)
        replay_ratio = trial.suggest_float("replay_ratio", 0.05, 0.4, step=0.05)
        return {
            "buffer_size": buffer_size,
            "replay_ratio": replay_ratio,
            "lam": float(lam),
            "fisher_batches": int(fisher_batches),
        }

    elif method == "as-snn":
        gamma_ratio = trial.suggest_float("gamma_ratio", 0.3, 0.8, step=0.1)
        lambda_a    = trial.suggest_float("lambda_a", 1.0, 4.0)
        ema         = trial.suggest_float("ema", 0.70, 0.98)
        return {"gamma_ratio": gamma_ratio, "lambda_a": lambda_a, "ema": ema}

    else:
        return {}


<a id="sec-06"></a>
## 6) Optuna study: a single method

**Goal**  
Optimize **one specific method** (`METHOD_TO_OPTIMIZE`) for `N_TRIALS` trials.

Flow of each *trial*:
1. Suggest HPs (`suggest_*`).  
2. Build `cfg` with those HPs (and, if `HPO_EPOCHS` is set, **override the number of epochs** for HPO only).  
3. Run `run_continual(...)`.  
4. Read the results, compute the **objective loss**, and return it to Optuna.

**Persistence**  
The study is stored in **SQLite** under `outputs/optuna/` so it can be resumed, and every *trial* is also exported to `*_trials.csv`.

> During HPO, `persistent_workers` is disabled for every method to avoid leaks or hangs between trials. For *replay* methods, AMP and `pin_memory` are also turned off and `num_workers` is capped at 2 for DataLoader stability (adjust these if you need to).  
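
The SQLite file name used in the cell below embeds a short hash of the search space, so changing the space starts a fresh study instead of resuming an incompatible one. A minimal sketch of that idea, using the same `json` + `sha1` recipe; the `space_fingerprint` name and the example dictionaries are illustrative only.

```python
import hashlib
import json

def space_fingerprint(bits):
    """Short, key-order-independent hash of a dict describing the search space."""
    raw = json.dumps(bits, sort_keys=True, default=str)
    return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:10]

a = space_fingerprint({"method": "ewc", "preset": "fast", "T": 20})
b = space_fingerprint({"T": 20, "preset": "fast", "method": "ewc"})  # same content, other order
c = space_fingerprint({"method": "ewc", "preset": "fast", "T": 30})  # one value changed
print(a, b, c)
```

Sorting the keys before hashing is what makes `a` and `b` identical; any change in a value, as in `c`, yields a different tag and therefore a different database file.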

[↑ Back to contents](#toc)

In [None]:
# === Cell 6: study configuration + Optuna objective (with automatic TAG) ===
import copy, json, inspect, hashlib, os
from pathlib import Path
import optuna, torch, gc, time
from src.telemetry import read_emissions_kg  # reads emissions when CodeCarbon output is available

# --- Study parameters ---
METHOD_TO_OPTIMIZE = "ewc"   # "naive" | "ewc" | "rehearsal" | "rehearsal+ewc" | "as-snn"
N_TRIALS = 2                 # raise it once you are happy with the setup
HPO_EPOCHS = None            # None -> preset epochs; or an int to speed things up
REHEARSAL_NAMES = ("rehearsal", "rehearsal+ewc")

def build_cfg_with_method(base_cfg: dict, method_name: str, params: dict, hpo_epochs: int|None):
    cfg = copy.deepcopy(base_cfg)
    cfg["continual"]["method"] = method_name
    cfg["continual"]["params"] = params or {}

    # Balanced, fast profile for HPO
    cfg["optim"]["amp"] = True
    cfg["data"]["pin_memory"] = True
    cfg["data"]["persistent_workers"] = False  # avoids leaks/hangs between trials
    cfg["data"]["num_workers"] = min(max(2, int(cfg["data"].get("num_workers") or 2)), 4)

    cfg.setdefault("logging", {}).setdefault("telemetry", {})["codecarbon"] = False

    # For replay-based HPO, use more conservative settings
    if method_name.lower() in REHEARSAL_NAMES:
        cfg["optim"]["amp"] = False
        cfg["data"]["pin_memory"] = False
        cfg["data"]["num_workers"] = min(int(cfg["data"].get("num_workers") or 0), 2)

    if hpo_epochs is not None:
        cfg["optim"]["epochs"] = int(hpo_epochs)
    return cfg

def run_one_cfg(cfg: dict) -> tuple[Path, dict, dict]:
    out_dir, res = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        cfg=cfg,
        preset_name=PRESET,
        out_root=ROOT / "outputs",
        verbose=True,
    )
    results = _load_results(out_dir) or (res if isinstance(res, dict) else {})
    return out_dir, res, results

def optuna_objective(trial: optuna.Trial):
    params = suggest_params_for_method(trial, METHOD_TO_OPTIMIZE, PRESET)
    cfg_i  = build_cfg_with_method(CFG, METHOD_TO_OPTIMIZE, params, HPO_EPOCHS)

    # Run tag
    cfg_i.setdefault("naming", {})
    tag = f"{METHOD_TO_OPTIMIZE}_hpo_t{trial.number}"
    if METHOD_TO_OPTIMIZE in ("ewc", "rehearsal+ewc") and "lam" in params:
        tag += f"_lam_{params['lam']:.1e}"
    cfg_i["naming"]["tag"] = tag

    try:
        out_dir, _, results = run_one_cfg(cfg_i)
        metrics = extract_metrics(results)
        val = composite_objective(metrics, ALPHA_FORGET)

        emissions = read_emissions_kg(out_dir)
        if emissions is not None:
            trial.set_user_attr("emissions_kg", float(emissions))

        trial.set_user_attr("out_dir", str(out_dir))
        trial.set_user_attr("metrics", metrics)
        trial.set_user_attr("method", METHOD_TO_OPTIMIZE)
        trial.set_user_attr("params", params)
        return val
    except Exception as e:
        trial.set_user_attr("error", repr(e))
        return float("inf")
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        time.sleep(0.5)

# -------------------------------
# AUTOMATIC TAG FOR THE HPO SPACE
# -------------------------------
def _space_fingerprint() -> str:
    bits = {
        "method": METHOD_TO_OPTIMIZE,
        "preset": PRESET,
        "model": MODEL_NAME,
        "encoder": ENCODER,
        "T": T,
        "gain": GAIN,
        "torch": torch.__version__,
        "optuna": optuna.__version__,
        "suggest_src": inspect.getsource(suggest_params_for_method),
    }
    g = suggest_params_for_method.__globals__
    for k in ("lam_grid", "fisher_grid", "replay_grid", "as_snn_grid"):
        if k in g:
            try:
                bits[k] = g[k]
            except Exception:
                bits[k] = str(g[k])

    raw = json.dumps(bits, sort_keys=True, default=str)
    return hashlib.sha1(raw.encode("utf-8")).hexdigest()[:10]

HPO_TAG = os.getenv("HPO_TAG_OVERRIDE") or f"space_{_space_fingerprint()}"

# --- Optuna persistence in SQLite ---
OPTUNA_DIR = ROOT / "outputs" / "optuna"
OPTUNA_DIR.mkdir(parents=True, exist_ok=True)

DB_PATH = OPTUNA_DIR / f"hpo_{METHOD_TO_OPTIMIZE}_{PRESET}_{MODEL_NAME}_{ENCODER}_T{T}_g{GAIN}_{HPO_TAG}.sqlite"
STORAGE = f"sqlite:///{DB_PATH}"
STUDY_NAME = f"HPO_{METHOD_TO_OPTIMIZE}_{PRESET}_{ENCODER}_T{T}_{HPO_TAG}"

study = optuna.create_study(
    direction="minimize",
    study_name=STUDY_NAME,
    storage=STORAGE,
    load_if_exists=True,
)
study.optimize(optuna_objective, n_trials=N_TRIALS, show_progress_bar=True)

print("HPO_TAG:", HPO_TAG)
print("SQLite:", DB_PATH)
print("Best value:", study.best_value)
print("Best params:", study.best_params)
print("Best attrs:", study.best_trial.user_attrs)

# Save the trial history to CSV
df_trials = study.trials_dataframe(attrs=("number","value","state","params","user_attrs"))
df_trials.to_csv(OPTUNA_DIR / f"{DB_PATH.stem}_trials.csv", index=False)



<a id="sec-07"></a>
## 7) Joint Optuna study (chooses method + HPs)

**Goal**  
Let **each *trial*** also choose the **method** (`naive`, `ewc`, `rehearsal`, `rehearsal+ewc`, `as-snn`) in addition to its HPs. This is useful when:

- You are not sure which method best suits your dataset.  
- You want an **automatic comparison** under the same compute budget.

> Set `RUN_JOINT=True` to launch this study, and raise `N_TRIALS_JOINT` once the workflow is stable.  

[↑ Back to contents](#toc)

In [None]:
RUN_JOINT = False   # Set to True to launch the joint study
N_TRIALS_JOINT = 10

def optuna_objective_joint(trial: optuna.Trial):
    # "as-snn" must match the branch name in suggest_params_for_method
    method = trial.suggest_categorical("method", ["naive","ewc","rehearsal","rehearsal+ewc","as-snn"])
    params = suggest_params_for_method(trial, method, PRESET)
    cfg_i  = build_cfg_with_method(CFG, method, params, HPO_EPOCHS)

    if method.lower() in ("rehearsal","rehearsal+ewc"):
        cfg_i["data"]["persistent_workers"] = False

    cfg_i.setdefault("naming", {})
    tag = f"{method}_hpo_t{trial.number}"
    if method in ("ewc", "rehearsal+ewc") and "lam" in params:
        tag += f"_lam_{params['lam']:.1e}"
    cfg_i["naming"]["tag"] = tag

    try:
        out_dir, _, results = run_one_cfg(cfg_i)
        metrics = extract_metrics(results)
        val = composite_objective(metrics, ALPHA_FORGET)  # same objective as the single-method study
        trial.set_user_attr("out_dir", str(out_dir))
        trial.set_user_attr("metrics", metrics)
        trial.set_user_attr("method", method)
        trial.set_user_attr("params", params)
        return val
    except Exception as e:
        trial.set_user_attr("error", repr(e))
        return float("inf")
    finally:
        import gc, time
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        time.sleep(0.5)

if RUN_JOINT:
    study_joint = optuna.create_study(direction="minimize", study_name="HPO_joint")
    study_joint.optimize(optuna_objective_joint, n_trials=N_TRIALS_JOINT, show_progress_bar=True)
    print("Best value:", study_joint.best_value)
    print("Best params:", study_joint.best_params)
    print("Best attrs:", study_joint.best_trial.user_attrs)
else:
    print("RUN_JOINT=False: skipped.")


<a id="sec-08"></a>
## 8) Retrain with the best hyperparameters

**Goal**  
Take `study.best_params` (and the winning method) and **retrain with the full training budget**:

- Rebuild `cfg_best` with the best HPs found.  
- (Optional) Restore the preset's `optim.epochs` if you used a reduced `HPO_EPOCHS` during the search.  
- Run `run_continual(...)` and show the **final metrics** and the **output folder**.

> Keep the **quick search** (fewer epochs) separate from the **final training** (preset epochs).  
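
The rebuild step can be sketched as below, assuming the config is a plain nested dictionary as in the rest of this notebook. The helper `build_best_cfg` and the toy `base` config are hypothetical.

```python
import copy

def build_best_cfg(base_cfg, method, params):
    """Hypothetical helper: copy the preset config and plug in the winning method and HPs."""
    cfg = copy.deepcopy(base_cfg)        # never mutate the shared preset config
    cfg["continual"]["method"] = method
    cfg["continual"]["params"] = dict(params)
    return cfg                           # optim.epochs keeps the preset value

# Toy config for illustration only.
base = {"continual": {"method": "naive", "params": {}}, "optim": {"epochs": 20}}
best = build_best_cfg(base, "ewc", {"lam": 7e8, "fisher_batches": 1000})
print(best["continual"], best["optim"]["epochs"], base["continual"]["method"])
```

Starting from a fresh deep copy of the preset, rather than from a config already modified for HPO, is what guarantees that the reduced epoch count and loader tweaks used during the search do not leak into the final run.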

[↑ Back to contents](#toc)

In [None]:
# Use the best trial of the single-method study (above)
BEST_PARAMS = study.best_params
BEST_METHOD = METHOD_TO_OPTIMIZE
print("Best method:", BEST_METHOD)
print("Best HPs:", BEST_PARAMS)

cfg_best = copy.deepcopy(CFG)
cfg_best["continual"]["method"] = BEST_METHOD
cfg_best["continual"]["params"] = BEST_PARAMS

# Restore the preset settings for the retrain
cfg_best["optim"]["amp"] = True
cfg_best["data"]["persistent_workers"] = CFG["data"]["persistent_workers"]
cfg_best["data"]["pin_memory"] = CFG["data"]["pin_memory"]
cfg_best["data"]["num_workers"] = CFG["data"]["num_workers"]

from src.telemetry import read_emissions_kg
out_dir, _, results = run_one_cfg(cfg_best)
metrics = extract_metrics(results)
print("Final results (retrain):", metrics)
print("Total emissions (kg CO2e):", read_emissions_kg(out_dir))
print("Saved to:", out_dir)


<a id="sec-09"></a>
## 9) Quick summary of runs (table)

**Goal**  
Inspect the results under `outputs/continual_*`. The cell below first launches a small fixed grid of runs without Optuna and then aggregates their `run_row.csv` summaries:

- Build a table with `preset`, `method`, `encoder`, `seed` (and `lambda` where applicable),
alongside `c1_mae`, `c1_after_c2_mae`, `forget_rel_%`, `c2_mae`. 
- Sort/filter so that you compare like with like (same preset/encoder).

> Export to CSV/Parquet for comparison plots. If the table comes out empty, check that the runs actually wrote their output files (`continual_results.json`, `run_row.csv`). 
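
The Δ% columns computed in the cell below follow a simple percent-change rule against the `naive` baseline. A standalone sketch of that rule, where `pct_delta` is a hypothetical name for the cell's inner `pct` helper without its pandas-specific NaN check:

```python
def pct_delta(x, base):
    """Percent change of x relative to base; None when the baseline is missing or zero."""
    if base in (None, 0) or x is None:
        return None
    return 100.0 * (float(x) - float(base)) / float(base)

print(pct_delta(150.0, 100.0))  # 50.0  (50% more than the baseline)
print(pct_delta(80.0, 100.0))   # -20.0 (20% less than the baseline)
print(pct_delta(1.0, 0))        # None  (no usable baseline)
```

Returning `None` rather than raising keeps the table intact when a run has no emissions data.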

[↑ Back to contents](#toc)

In [None]:
# === Lightweight launcher without Optuna (accurate) ===
from copy import deepcopy
from pathlib import Path
import json
from src.runner import run_continual
from src.models import build_model as _build_model

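# Note: this redefines `build_model`, shadowing the import from `src.models` that
# `make_model_fn` (section 2) calls; re-run the import cell before re-running HPO cells.
def build_model(tfm):
    return _build_model(MODEL_NAME, tfm, beta=0.9, threshold=0.5)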

EXPS = [
    dict(method="naive", params={}, tag="grid01"),
    dict(method="rehearsal", params={"buffer_size": 3000, "replay_ratio": 0.10}, tag="grid02_rr10"),
    dict(method="rehearsal", params={"buffer_size": 3000, "replay_ratio": 0.15}, tag="grid03_rr15"),
    dict(method="rehearsal", params={"buffer_size": 3000, "replay_ratio": 0.20}, tag="grid04_rr20"),
    dict(method="rehearsal", params={"buffer_size": 3000, "replay_ratio": 0.25}, tag="grid05_rr25"),
    dict(method="sca-snn", params={"attach_to": "f6","flatten_spatial": False,"num_bins": 50,
                                   "bin_lo": -1.0,"bin_hi": 1.0,"anchor_batches": 12,
                                   "max_per_bin": 512,"beta": 0.65,"bias": 0.0,
                                   "soft_mask_temp": 0.75,"habit_decay": 0.995,"verbose": True,"log_every": 750},
         tag="grid06_sca_b065"),
    dict(method="sca-snn", params={"attach_to": "f6","flatten_spatial": False,"num_bins": 50,
                                   "bin_lo": -1.0,"bin_hi": 1.0,"anchor_batches": 12,
                                   "max_per_bin": 512,"beta": 0.60,"bias": 0.0,
                                   "soft_mask_temp": 0.75,"habit_decay": 0.995,"verbose": True,"log_every": 750},
         tag="grid07_sca_b060"),
    dict(method="sca-snn", params={"attach_to": "f6","flatten_spatial": False,"num_bins": 50,
                                   "bin_lo": -1.0,"bin_hi": 1.0,"anchor_batches": 12,
                                   "max_per_bin": 512,"beta": 0.70,"bias": 0.0,
                                   "soft_mask_temp": 0.75,"habit_decay": 0.995,"verbose": True,"log_every": 750},
         tag="grid08_sca_b070"),
    dict(method="as-snn", params={"gamma_ratio": 0.3, "lambda_a": 1.6, "ema": 0.9}, tag="grid09_as"),
    dict(method="sa-snn", params={"k": 8, "tau": 28, "thresh_lo": 1.2, "period": 200000}, tag="grid10_sa"),
    dict(method="ewc", params={"lam": 7e8, "fisher_batches": 1000}, tag="grid11_ewc"),
]

OUTS = []
for i, e in enumerate(EXPS, 1):
    cfg_i = deepcopy(CFG)
    cfg_i["continual"]["method"] = e["method"]
    cfg_i["continual"]["params"] = e["params"]
    cfg_i.setdefault("naming", {})["tag"] = e.get("tag", f"grid{i:02d}")
    cfg_i["optim"]["es_patience"] = None

    print(f"\n>>> {i}/{len(EXPS)} :: {e['method']} :: {e['params']}")
    out_dir, _ = run_continual(task_list, make_loader_fn, build_model, tfm, cfg_i, PRESET, out_root=ROOT/"outputs", verbose=True)
    OUTS.append(out_dir)
    print("Saved to:", out_dir)

import pandas as pd, glob, os, json
from pathlib import Path

# 1) Combine all summary rows (run_row.csv)
rows = []
for f in glob.glob(str(ROOT / "outputs" / "continual_*" / "run_row.csv")):
    try:
        rows.append(pd.read_csv(f))
    except Exception:
        pass
df_all = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
if df_all.empty:
    print("No run_row.csv files found")
else:
    df_all = df_all[df_all["preset"]=="accurate"].copy()
    out_all = ROOT / "outputs" / "all_runs_accurate.csv"
    df_all.to_csv(out_all, index=False)
    print("Saved:", out_all)

# 2) Pick the best run per method (simple score)
def _pick_best(g):
    g = g.copy()
    g["r1"] = g["c2_final_mae"]
    g["r2"] = g["avg_forget_rel"]
    g["r3"] = g["emissions_kg"].fillna(0)
    return g.sort_values(["r1","r2","r3"], na_position="last").head(1).drop(columns=["r1","r2","r3"])

# Guard: with no rows there is no "method" column to group by
df_best = (df_all.groupby("method", group_keys=False)
           .apply(_pick_best)
           .reset_index(drop=True)) if not df_all.empty else pd.DataFrame(columns=["method"])
out_best = ROOT / "outputs" / "best_by_method_accurate.csv"
df_best.to_csv(out_best, index=False)
print("Saved:", out_best)

# 3) Δ% table vs naive (emissions and time)
if "naive" in df_best["method"].values:
    base = df_best[df_best["method"]=="naive"].iloc[0]
    base_e = float(base["emissions_kg"]) if pd.notna(base["emissions_kg"]) else None
    base_t = float(base["elapsed_sec"])
    def pct(x, base):
        if base in (None, 0) or pd.isna(x):
            return None
        return 100.0 * (float(x) - float(base)) / float(base)
    table_rows = []
    for _, r in df_best.iterrows():
        table_rows.append({
            "method": r["method"],
            "c2_final_mae": r["c2_final_mae"],
            "avg_forget_rel": r["avg_forget_rel"],
            "emissions_kg": r["emissions_kg"],
            "elapsed_sec": r["elapsed_sec"],
            "Δ%_emissions_vs_naive": None if base_e is None else pct(r["emissions_kg"], base_e),
            "Δ%_tiempo_vs_naive": pct(r["elapsed_sec"], base_t),
        })
    df_delta = pd.DataFrame(table_rows).sort_values(["c2_final_mae","avg_forget_rel"], na_position="last")
    out_delta = ROOT / "outputs" / "delta_vs_naive.csv"
    df_delta.to_csv(out_delta, index=False)
    print("Saved:", out_delta)
else:
    print("No naive baseline in df_best; cannot compute Δ% vs naive.")

# 4) Collect the Rehearsal curves (C2) for figures
reh_curves = []
for d in glob.glob(str(ROOT / "outputs" / "continual_*rehearsal*" )):
    for f in glob.glob(os.path.join(d, "task_*_*", "loss_curves.csv")):
        if any(k in f.lower() for k in ["c2","circuito2","task_2","track2"]):
            reh_curves.append(f)
print("C2 curves (Rehearsal) found:", len(reh_curves))


## Appendix: Optuna in 90 seconds

- **Study**: the HPO project (it contains all the trials).
- **Trial**: one evaluation with a given set of HPs (`suggest_int`, `suggest_float`, etc.).
- **Sampler**: the strategy for choosing the next point (TPE by default).
- **Pruner**: **stops early** trials that look unpromising (speeds up long searches).
- **Storage**: a database (SQLite, PostgreSQL) used to **resume** and/or **parallelize** studies.

### Resuming searches
You can create the study with a storage backend:
```python
study = optuna.create_study(
    direction="minimize",
    study_name="HPO_ewc",
    storage=f"sqlite:///{ROOT/'outputs'/'optuna_ewc.sqlite'}",
    load_if_exists=True,
)
```


In [None]:
from src.hpo import objective_value, composite_objective
m = {"c2_mae": 0.16, "forget_rel_%": 12.5}
print(objective_value(m, key="forget_rel_%"))         # 12.5
print(objective_value(m, key="c1_forgetting_mae_rel_%"))  # 12.5 (alias OK)
print(composite_objective(m, alpha=0.5))              # 0.16 + 0.5*12.5 = 6.41
