In [None]:
# =============================================================================
# Imports y setup
# =============================================================================
import os
# Pin the OpenMP / MKL / OpenBLAS thread pools to a single thread each.
# This is done before `torch` is imported further down in this cell
# (presumably so the native libraries read the values when they load).
os.environ.update(
    dict.fromkeys(
        ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"),
        "1",
    )
)

from pathlib import Path
import sys, json, torch

# Project root: the parent directory when the kernel runs inside "notebooks/",
# otherwise the current working directory itself.
ROOT = Path.cwd()
if ROOT.name == "notebooks":
    ROOT = ROOT.parent

# Make the project importable (the `src.*` imports below rely on this).
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))

from src.utils import set_seeds, load_preset
from src.datasets import ImageTransform, AugmentConfig
from src.models import SNNVisionRegressor
from src.training import TrainConfig
from src.eval import eval_loader

# Global seed used by the smoke-test cell of this notebook.
SEED = 42

# Prefer the GPU when one is visible; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Shared image transform. The positional arguments are project-defined;
# NOTE(review): presumably (width, height, to_gray, crop) — confirm against
# src.datasets.ImageTransform. `make_model_fn` reads `tfm.to_gray`.
tfm = ImageTransform(160, 80, True, None)
def make_model_fn(tfm):
    """Create a fresh `SNNVisionRegressor` matching the transform's colour mode.

    Args:
        tfm: transform object exposing a truthy/falsy `to_gray` attribute.

    Returns:
        A new `SNNVisionRegressor` with 1 input channel when `tfm.to_gray`
        is truthy and 3 otherwise, built with `lif_beta=0.95`.
    """
    n_channels = 3
    if tfm.to_gray:
        n_channels = 1
    return SNNVisionRegressor(in_channels=n_channels, lif_beta=0.95)

# --- PyTorch runtime configuration -------------------------------------------
# TF32 / float32 matmul precision settings.
torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable the cuDNN benchmark flag.
torch.backends.cudnn.benchmark = True

# Number of threads torch itself may use.
torch.set_num_threads(4)


In [None]:
# --- Runtime switches ---------------------------------------------------------
GPU_ENCODE = True    # ask loaders for plain images and encode at runtime
SAFE_MODE = False    # when True, every loader/balancing/augment option below is reset

# --- DataLoader settings ------------------------------------------------------
NUM_WORKERS, PREFETCH = 12, 2
PIN_MEMORY = PERSISTENT = True

# --- Augmentation presets -----------------------------------------------------
AUG_CFG_LIGHT = AugmentConfig(prob_hflip=0.5, brightness=None, gamma=None, noise_std=0.0)
AUG_CFG_FULL = AugmentConfig(prob_hflip=0.5, brightness=(0.9, 1.1), gamma=(0.95, 1.05), noise_std=0.005)
AUG_CFG = AUG_CFG_LIGHT

# --- Balancing strategy -------------------------------------------------------
USE_OFFLINE_BALANCED, USE_ONLINE_BALANCING = True, False

if SAFE_MODE:
    # Conservative fallback: no worker processes, no pinning, no balancing,
    # no augmentation.
    NUM_WORKERS, PREFETCH = 0, None
    PIN_MEMORY = PERSISTENT = False
    USE_OFFLINE_BALANCED = USE_ONLINE_BALANCING = False
    AUG_CFG = None

print(f"[SAFE_MODE={SAFE_MODE}] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")


In [None]:
# =============================================================================
# Verificación de datos (normal y, si existe, balanceado offline)
# =============================================================================
from pathlib import Path

RAW = ROOT / "data" / "raw" / "udacity"
PROC = ROOT / "data" / "processed"

RUNS = ["circuito1", "circuito2"]  # adjust if needed

# Mandatory check: the regular train/val/test CSV of every run must exist.
missing = [
    str(PROC / run / f"{part}.csv")
    for run in RUNS
    for part in ["train", "val", "test"]
    if not (PROC / run / f"{part}.csv").exists()
]

# Optional check: train_balanced.csv (only needed for the OFFLINE balanced mode).
for run in RUNS:
    p_bal = PROC / run / "train_balanced.csv"
    if not p_bal.exists():
        print(f"⚠️  Falta {p_bal}. Si más abajo pones USE_OFFLINE_BALANCED=True, "
              "ejecuta 01A_PREP_BALANCED.ipynb o el script tools/make_splits_balanced.py")
    else:
        print(f"✓ {p_bal} OK")

# Abort early, listing every missing mandatory file at once.
if len(missing) > 0:
    missing_lines = "\n".join(" - " + m for m in missing)
    raise FileNotFoundError(
        "Faltan CSV obligatorios (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep):\n"
        + missing_lines
    )

print("OK: splits 'train/val/test' encontrados.")


In [None]:
# ===================== Balanceo: helper =====================
# Echo which task manifest and which balancing strategy are active.
split_mode_label = "ORIGINAL (tasks.json)"
if USE_OFFLINE_BALANCED:
    split_mode_label = "OFFLINE (tasks_balanced.json)"
print("Modo balanceo:", split_mode_label, "| Balanceo ONLINE:", USE_ONLINE_BALANCING)

# Safety net against double balancing: the two modes are mutually exclusive.
double_balancing = USE_OFFLINE_BALANCED and USE_ONLINE_BALANCING
if double_balancing:
    raise RuntimeError("Doble balanceo detectado: OFFLINE y ONLINE a la vez. "
                       "Pon USE_ONLINE_BALANCING=False cuando uses train_balanced.csv.")

from pathlib import Path  # (omite esta línea si ya importaste Path arriba)

def _balance_flag(train_csv_path: str | Path) -> bool:
    """Decide whether ONLINE balancing should be enabled for a train CSV.

    Args:
        train_csv_path: path to the training CSV; only its file name is inspected.

    Returns:
        True only when the notebook-level `USE_ONLINE_BALANCING` switch is
        truthy and the file is not named 'train_balanced.csv'.
    """
    csv_name = Path(train_csv_path).name
    online_requested = bool(USE_ONLINE_BALANCING)
    if csv_name == "train_balanced.csv":
        # An already-balanced CSV never gets balanced a second time.
        return False
    return online_requested


In [None]:
import src.training as training


In [None]:
# =============================================================================
# Elegir split: normal (tasks.json) o balanceado offline (tasks_balanced.json)
# =============================================================================
# Load the task manifest: tasks_balanced.json in OFFLINE balanced mode,
# tasks.json otherwise.
tasks_file = PROC / ("tasks_balanced.json" if USE_OFFLINE_BALANCED else "tasks.json")
with open(tasks_file, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)

# One entry per task, in the order declared by the manifest.
task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]

# Show which TRAIN CSV each task is going to use.
print("Tareas y su TRAIN CSV:")
for t in task_list:
    print(f" - {t['name']}: {Path(t['paths']['train']).name}")

# Extra guardrail: in OFFLINE balanced mode every task must point at an
# existing train_balanced.csv.
if USE_OFFLINE_BALANCED:
    for t in task_list:
        train_path = Path(t["paths"]["train"])
        if train_path.name != "train_balanced.csv":
            raise RuntimeError(
                f"[{t['name']}] Esperaba 'train_balanced.csv' pero encontré '{train_path.name}'. "
                "Repite 01A_PREP_BALANCED.ipynb o ajusta USE_OFFLINE_BALANCED=False."
            )
        if not train_path.exists():
            raise FileNotFoundError(
                f"[{t['name']}] No existe {train_path}. Genera los balanceados con 01A_PREP_BALANCED.ipynb."
            )
    print("✔ Verificación OFFLINE balanceado superada (train_balanced.csv por tarea).")

# Quick preview. A bare expression is only rendered by Jupyter when it is the
# LAST statement of the cell, so it lives here rather than mid-cell.
task_list[:2]


In [None]:
from pathlib import Path
from src.utils import make_loaders_from_csvs

def make_loader_fn(task, batch_size, encoder, T, gain, tfm, seed, **dl_kwargs):
    """Build the train/val/test loaders for one task via `make_loaders_from_csvs`.

    Args:
        task: dict with "name" (sub-folder of data/raw/udacity) and "paths"
            (mapping with "train", "val" and "test" CSV paths).
        batch_size: forwarded unchanged.
        encoder: requested encoder. When `GPU_ENCODE` is on and the encoder is
            one of "rate" / "latency" / "raw", the loader is asked for
            "image" instead so the temporal encoding can be applied later.
        T, gain, tfm: forwarded unchanged.
        seed: seed forwarded to the loader factory.
        **dl_kwargs: accepted for call compatibility and currently ignored;
            DataLoader settings come from the notebook-level constants
            (NUM_WORKERS, PIN_MEMORY, PERSISTENT, PREFETCH).

    Returns:
        Whatever `make_loaders_from_csvs` returns (unpacked elsewhere in this
        notebook as three loaders: train, val, test).
    """
    raw_dir = ROOT / "data" / "raw" / "udacity" / task["name"]
    paths = task["paths"]
    train_csv = Path(paths["train"])

    # If encoding happens on the GPU, request plain images from the loader;
    # otherwise leave the temporal encoder inside the dataset itself.
    encoder_for_loader = "image" if (GPU_ENCODE and encoder in {"rate", "latency", "raw"}) else encoder

    return make_loaders_from_csvs(
        base_dir=raw_dir,
        train_csv=train_csv,
        val_csv=Path(paths["val"]),
        test_csv=Path(paths["test"]),
        batch_size=batch_size,
        encoder=encoder_for_loader,
        # Honour the caller's seed (the global SEED was used here before,
        # silently ignoring the `seed` argument).
        T=T, gain=gain, tfm=tfm, seed=seed,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        persistent_workers=PERSISTENT,
        prefetch_factor=PREFETCH,
        aug_train=AUG_CFG,
        # Optional online balancing; the decision is delegated to the shared
        # notebook helper so the rule is defined in a single place.
        balance_train=_balance_flag(train_csv),
        balance_bins=21,
        balance_smooth_eps=1e-3,
    )


In [None]:
# === PRUEBA UNIVERSAL: loader -> (T,B,C,H,W) -> forward con fallback AMP ===
import torch, src.training as training

# --- 1) Small loaders built through the notebook helper ---
tr, va, te = make_loader_fn(
    task=task_list[0], batch_size=8,
    encoder="rate",   # if the pipeline already returns 4D batches it is detected below
    T=10, gain=0.5,
    tfm=tfm, seed=SEED,
)

xb, yb = next(iter(tr))
print("batch del loader:", xb.shape, yb.shape)

# --- 2) Bring the batch to (T,B,C,H,W) depending on what the loader produced ---
#    - 4D batch (image): enable the runtime encoder and let the training helper
#      convert it (NOTE(review): `_permute_if_needed` is assumed to encode and
#      permute, per the original notebook comment — confirm in src.training).
#    - 5D batch (already encoded, assumed (B,T,C,H,W)): swap the first two axes.
if xb.ndim == 4:
    training.set_runtime_encode(mode="rate", T=10, gain=0.5,
                                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    x5d = training._permute_if_needed(xb)
    used_runtime_encode = True
    print("dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)")
elif xb.ndim == 5:
    x5d = xb.transpose(0, 1).contiguous()
    used_runtime_encode = False
    print("dataset ya codificado; solo permuto a (T,B,C,H,W)")
else:
    raise RuntimeError(f"Forma inesperada del batch: {xb.shape}")

print("x5d.device:", x5d.device, "| shape:", tuple(x5d.shape))

# --- 3) Model in eval mode on the target device ---
model = make_model_fn(tfm).to(device).eval()

def forward_with_auto_amp(model, x5d, device):
    """Run one inference forward pass, preferring CUDA autocast with an FP32 fallback.

    Args:
        model: callable module, expected to already live on `device`.
        x5d: input tensor (the notebook passes a (T,B,C,H,W) batch; the layout
            is not checked here).
        device: `torch.device` or device string where the forward pass runs.

    Returns:
        The model output, computed under `torch.inference_mode()`.
    """
    device = torch.device(device)

    # Attempt 1: AMP. Only tried when the *target* device is CUDA; gating on
    # global CUDA availability alone would send a CPU target through a
    # guaranteed-to-fail fp16 attempt on CUDA machines.
    if device.type == "cuda" and torch.cuda.is_available():
        try:
            x_amp = x5d.to(device, dtype=torch.float16, non_blocking=True)
            with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
                y = model(x_amp)
            print("[forward] ejecutado con AMP (fp16)")
            return y
        except Exception as e:
            # Deliberate best-effort: any AMP failure falls through to FP32.
            print("[forward] AMP falló, reintento en FP32. Motivo:", str(e))

    # Attempt 2: FP32 (CPU target, or fallback after an AMP failure).
    x_fp32 = x5d.to(device, dtype=torch.float32, non_blocking=True)
    with torch.inference_mode():
        y = model(x_fp32)
    print("[forward] ejecutado en FP32")
    return y

# Run the smoke forward pass and report the shape of the prediction.
yhat = forward_with_auto_amp(model=model, x5d=x5d, device=device)
print("yhat:", tuple(yhat.size()))

# --- 4) Undo the runtime-encode setting, but only if this cell enabled it ---
# NOTE(review): passing None is assumed to disable runtime encoding — confirm
# in src.training.set_runtime_encode.
if used_runtime_encode:
    training.set_runtime_encode(None)


In [None]:
# ===================== BENCH: configuration echo =====================
# Prints the DataLoader and balancing settings currently in effect so they end
# up next to any timing output. (No RUN_BENCH toggle is defined in the cells
# shown here; this cell only echoes configuration.)
print(
    f"[Bench] workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT} "
    f"| offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}"
)


In [None]:
from src.runner import run_continual


In [None]:
# === Activar métrica de it/s por época (parche temporal) ===
import time, json
from pathlib import Path
import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler
import src.training as training
from src.utils import set_seeds  # ya lo tienes importado en el notebook

# Keep a reference to the ORIGINAL implementation so it can be restored later.
# Guarded so the cell is safe to re-run: once the patch is installed,
# `training.train_supervised` IS the patched function, and capturing it again
# would overwrite the saved original with the patch itself.
_current_train_supervised = training.train_supervised
if getattr(_current_train_supervised, "__name__", "") != "train_supervised_ips":
    orig_train_supervised = _current_train_supervised

def train_supervised_ips(model: nn.Module, train_loader, val_loader, loss_fn: nn.Module,
                         cfg, out_dir: Path, method=None):
    """Train `model` with Adam and print training iterations/second for every epoch.

    Args:
        model: network to train; moved to CUDA when available, otherwise CPU.
        train_loader: iterable of `(x, y)` batches used for optimisation.
        val_loader: iterable of `(x, y)` batches used for the per-epoch validation loss.
        loss_fn: callable `loss_fn(y_hat, y)` returning a scalar loss tensor.
        cfg: config object; `seed`, `lr`, `amp`, `epochs` and `batch_size` are read.
        out_dir: directory (created if missing) that receives `manifest.json`.
        method: optional object whose `penalty()` is added to the training loss
            on every step (the validation loss does not include it).

    Returns:
        dict with the lists "train_loss" and "val_loss", one mean value per epoch.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    if cfg.seed is not None:
        set_seeds(cfg.seed)

    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")
    model = model.to(device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr)

    # Mixed precision is only used when requested AND a CUDA device exists.
    use_amp = bool(cfg.amp and cuda_ok)
    scaler = GradScaler(enabled=use_amp)

    history = {"train_loss": [], "val_loss": []}
    run_start = time.perf_counter()

    for epoch in range(1, cfg.epochs + 1):
        # ----------------------------- training -----------------------------
        model.train()
        loss_sum, n_iters = 0.0, 0
        epoch_start = time.perf_counter()

        for x, y in train_loader:
            # Runtime encode/permute via the project helper, then move to device.
            x = training._permute_if_needed(x).to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=use_amp):
                loss = loss_fn(model(x), y)
                if method is not None:
                    loss = loss + method.penalty()

            # Backward pass; with AMP the gradients are unscaled before clipping.
            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
            else:
                loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            if use_amp:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()

            loss_sum += loss.item()
            n_iters += 1

        epoch_time = time.perf_counter() - epoch_start
        if epoch_time > 0:
            ips = n_iters / epoch_time
        else:
            ips = float("nan")
        print(f"[TRAIN it/s] epoch {epoch}/{cfg.epochs}: {ips:.1f} it/s  "
              f"({n_iters} iters en {epoch_time:.2f}s)")

        train_loss = loss_sum / max(1, n_iters)

        # ---------------------------- validation ----------------------------
        model.eval()
        val_sum, n_val = 0.0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x = training._permute_if_needed(x).to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with autocast("cuda", enabled=use_amp):
                    v_loss = loss_fn(model(x), y)
                val_sum += v_loss.item()
                n_val += 1
        val_loss = val_sum / max(1, n_val)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

    # Persist the run settings and loss curves next to the other outputs.
    manifest = {
        "epochs": cfg.epochs,
        "batch_size": cfg.batch_size,
        "lr": cfg.lr,
        "amp": cfg.amp,
        "seed": cfg.seed,
        "elapsed_sec": time.perf_counter() - run_start,
        "device": str(device),
        "history": history,
    }
    manifest_path = out_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    return history

# Install the patch: rebind the module attribute so code that looks up
# `training.train_supervised` at call time gets the instrumented version.
# NOTE(review): callers that imported the function by name keep their own
# reference — confirm how src.runner resolves it.
setattr(training, "train_supervised", train_supervised_ips)
print("✅ it/s por época ACTIVADO. Para desactivarlo: training.train_supervised = orig_train_supervised")


In [None]:

preset_demo = "std"
seed_demo = 42
enc_demo = load_preset(ROOT / "configs" / "presets.yaml", preset_demo)["encoder"]

# One smoke run per (method, lambda) pair. Both runs share every other
# argument, so the call is written once instead of being copy-pasted.
for smoke_method, smoke_lam in (("naive", None), ("ewc", 1e9)):
    print(f"\n>>> {smoke_method.upper()} (smoke)")
    out_path, res = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        preset=preset_demo,
        method=smoke_method,
        lam=smoke_lam,
        seed=seed_demo,
        encoder=enc_demo,
        # Keyed by the active preset so the entry follows `preset_demo`.
        fisher_batches_by_preset={preset_demo: 600},
        epochs_override=2,
        runtime_encode=GPU_ENCODE,
        out_root=ROOT/"outputs",
        verbose=True,
    )
    print("OK:", out_path)


In [None]:
# Quick λ mini-sweep on the "std" preset.
lams = [3e8, 5e8, 7e8]
out_runs = []

for lam in lams:
    print(f"\n>>> EWC SMOKE λ={lam:.0e}")
    # The second return value is not needed. It is bound to a named throwaway
    # rather than `_`, because assigning `_` at notebook top level shadows
    # IPython's "last output" variable for the rest of the session.
    out_dir, _unused_result = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        preset="std",
        method="ewc",
        lam=lam,
        seed=42,
        encoder="rate",
        fisher_batches_by_preset={"std": 600},
        epochs_override=2,           # fast smoke run
        runtime_encode=GPU_ENCODE,   # must match the loader-side GPU_ENCODE choice
        out_root=ROOT/"outputs",
        verbose=True,
    )
    out_runs.append(out_dir)

print("\nHecho:", out_runs)
