In [1]:
# =============================================================================
# Imports and setup
# =============================================================================
import os
# Pin the OpenMP / MKL / OpenBLAS thread-count variables to "1". They are set
# before torch is imported below -- presumably so the native libraries pick
# them up at load time (TODO confirm the ordering is intentional).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from pathlib import Path
import sys, json, torch

# Project root: the parent directory when the kernel runs inside "notebooks/",
# otherwise the current working directory. It is appended to sys.path so the
# local `src` package can be imported.
ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from src.utils import load_preset, set_seeds
from src.datasets import ImageTransform, AugmentConfig
from src.models import build_model, default_tfm_for_model
from src.training import TrainConfig
from src.eval import eval_loader
from src.bench import make_loader_fn_factory, universal_smoke_forward, enable_epoch_ips, disable_epoch_ips, print_bench_config

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch is asked for 4 threads here while OMP/MKL/OpenBLAS were
# pinned to 1 above -- confirm this combination is what you want.
torch.set_num_threads(4)
# Backend switches: cuDNN benchmark mode, TF32 for CUDA matmul and cuDNN,
# and the "high" float32 matmul precision setting.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
print("Device:", device)


Device: cuda


In [2]:
# =============================================================================
# Config: read presets.yaml
# =============================================================================
# Every constant below is read from the selected preset so that later cells
# share one source of configuration.
PRESET = "fast"  # fast | std | accurate
CFG = load_preset(ROOT / "configs" / "presets.yaml", PRESET)

# Model / image transform
MODEL_NAME = CFG["model"]["name"]
tfm = ImageTransform(
    CFG["model"]["img_w"], CFG["model"]["img_h"],
    to_gray=bool(CFG["model"]["to_gray"]), crop_top=None
)

# Data / temporal encoding
ENCODER = CFG["data"]["encoder"]
T       = int(CFG["data"]["T"])
GAIN    = float(CFG["data"]["gain"])
SEED    = int(CFG["data"]["seed"])

USE_OFFLINE_SPIKES   = bool(CFG["data"]["use_offline_spikes"])
USE_OFFLINE_BALANCED = bool(CFG["data"]["use_offline_balanced"])
RUNTIME_ENCODE       = bool(CFG["data"]["encode_runtime"])

# DataLoader / augmentation / balancing
NUM_WORKERS = int(CFG["data"]["num_workers"])
PREFETCH    = CFG["data"]["prefetch_factor"]  # passed through as-is (no cast)
PIN_MEMORY  = bool(CFG["data"]["pin_memory"])
PERSISTENT  = bool(CFG["data"]["persistent_workers"])

# Augmentation config is only built when the preset provides a non-empty "aug_train".
AUG_CFG = AugmentConfig(**CFG["data"]["aug_train"]) if CFG["data"]["aug_train"] else None

USE_ONLINE_BALANCING = bool(CFG["data"]["balance_online"])
BAL_BINS = int(CFG["data"]["balance_bins"])
BAL_EPS  = float(CFG["data"]["balance_smooth_eps"])

# Echo the effective configuration so it is recorded in the notebook output.
print(f"[PRESET={PRESET}] model={MODEL_NAME} {tfm.w}x{tfm.h} gray={tfm.to_gray}")
print(f"[DATA] encoder={ENCODER} T={T} gain={GAIN} seed={SEED}")
print(f"[LOADER] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")
print(f"[BALANCE] offline={USE_OFFLINE_BALANCED} online={USE_ONLINE_BALANCING} bins={BAL_BINS}")
print(f"[RUNTIME_ENCODE] {RUNTIME_ENCODE} | [OFFLINE_SPIKES] {USE_OFFLINE_SPIKES}")

def make_model_fn(tfm):
    """Build a model for the preset's ``MODEL_NAME`` using the transform ``tfm``.

    ``beta`` and ``threshold`` are always forwarded to ``build_model``; per the
    original note they are only required by ``pilotnet_snn``.
    """
    snn_kwargs = {"beta": 0.9, "threshold": 0.5}
    model = build_model(MODEL_NAME, tfm, **snn_kwargs)
    return model


[PRESET=fast] model=pilotnet_snn 200x66 gray=True
[DATA] encoder=rate T=10 gain=0.5 seed=42
[LOADER] workers=8 prefetch=2 pin=True persistent=True
[BALANCE] offline=False online=False bins=21
[RUNTIME_ENCODE] False | [OFFLINE_SPIKES] True


In [3]:
# =============================================================================
# Data check (regular splits and, if present, the offline-balanced split)
# =============================================================================
from pathlib import Path

RAW  = ROOT / "data" / "raw" / "udacity"
PROC = ROOT / "data" / "processed"

# Fail early with an explicit message instead of letting iterdir() raise.
if not PROC.is_dir():
    raise FileNotFoundError(f"No existe el directorio de datos procesados: {PROC}")

# Adjust if you use other runs
# RUNS = ["circuito1", "circuito2"]
# iterdir() yields entries in arbitrary order; sort so the run order (and the
# printed report) is deterministic across machines and filesystems.
RUNS = sorted(d.name for d in PROC.iterdir() if d.is_dir())

# Without this guard an empty PROC would report "OK" without checking anything.
if not RUNS:
    raise FileNotFoundError(f"No se encontró ningún recorrido (subcarpeta) en {PROC}")

missing = []
for run in RUNS:
    base = PROC / run

    # Mandatory check: the regular train/val/test splits
    for part in ("train", "val", "test"):
        p = base / f"{part}.csv"
        if not p.exists():
            missing.append(str(p))

    # Optional check: train_balanced.csv (needed only for OFFLINE balanced mode)
    p_bal = base / "train_balanced.csv"
    if p_bal.exists():
        print(f"✓ {p_bal} OK")
    else:
        print(f"  Falta {p_bal}. Si más abajo pones USE_OFFLINE_BALANCED=True, "
              f"ejecuta 01A_PREP_BALANCED.ipynb o tools/make_splits_balanced.py")

# Report every missing mandatory CSV at once rather than stopping at the first.
if missing:
    raise FileNotFoundError(
        "Faltan CSV obligatorios (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep):\n"
        + "\n".join(" - " + m for m in missing)
    )

print("OK: splits 'train/val/test' encontrados.")


✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito1/train_balanced.csv OK
✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito2/train_balanced.csv OK
OK: splits 'train/val/test' encontrados.


In [4]:
# Imported once at the top of the cell (it used to be re-imported inside the
# loop and again in the guard). The `_P` alias is kept for any later use.
from pathlib import Path as _P

PROC = ROOT / "data" / "processed"

# Choose the task manifest that matches the offline-balancing flag.
tasks_file = PROC / ("tasks_balanced.json" if USE_OFFLINE_BALANCED else "tasks.json")
with open(tasks_file, "r", encoding="utf-8") as f:
    tasks_json = json.load(f)

# One entry per task, in the order declared by the manifest.
task_list = [{"name": n, "paths": tasks_json["splits"][n]} for n in tasks_json["tasks_order"]]

# Later cells index task_list[0]; fail here with a clear message instead.
if not task_list:
    raise ValueError(f"'tasks_order' está vacío en {tasks_file}")

print("Tareas y su TRAIN CSV:")
for t in task_list:
    print(f" - {t['name']}: {_P(t['paths']['train']).name}")

# Guardrail when OFFLINE balancing is enabled: every task must point at an
# existing 'train_balanced.csv'.
if USE_OFFLINE_BALANCED:
    for t in task_list:
        train_path = _P(t["paths"]["train"])
        if train_path.name != "train_balanced.csv":
            raise RuntimeError(f"[{t['name']}] Esperaba 'train_balanced.csv' pero encontré '{train_path.name}'.")
        if not train_path.exists():
            raise FileNotFoundError(f"[{t['name']}] No existe {train_path}.")
    print(" Verificación OFFLINE balanceado superada.")


Tareas y su TRAIN CSV:
 - circuito1: train.csv
 - circuito2: train.csv


In [5]:
# Build the loader factory once from the preset-derived settings; the
# resulting `make_loader_fn` is what the following cells call per task.
make_loader_fn = make_loader_fn_factory(
    ROOT,
    RUNTIME_ENCODE=RUNTIME_ENCODE,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSISTENT,
    aug_train=AUG_CFG,
    # online balancing:
    balance_train=USE_ONLINE_BALANCING,
    balance_bins=BAL_BINS,
    balance_smooth_eps=BAL_EPS,
)


In [6]:
# Smoke check on the first task with the preset's encoder settings; the
# return value is discarded (only the printed diagnostics are of interest).
_ = universal_smoke_forward(
    make_loader_fn,
    task=task_list[0],
    encoder=ENCODER, T=T, gain=GAIN,
    tfm=tfm, seed=SEED, device=device,
    use_runtime_encode=RUNTIME_ENCODE,
)


dataset ya codificado; solo permuto a (T,B,C,H,W)
x5d.device: cpu | shape: (10, 8, 1, 66, 200)
[forward] ejecutado con AMP


In [7]:
# Echo the DataLoader and balancing settings used for the benchmark runs.
print_bench_config(
    NUM_WORKERS=NUM_WORKERS, PREFETCH=PREFETCH,
    PIN_MEMORY=PIN_MEMORY, PERSISTENT=PERSISTENT,
    USE_OFFLINE_BALANCED=USE_OFFLINE_BALANCED, USE_ONLINE_BALANCING=USE_ONLINE_BALANCING
)


[Bench workers=8 prefetch=2 pin=True persistent=True | offline_bal=False online_bal=False


In [8]:
# === UNIVERSAL TEST: loader -> (T,B,C,H,W) -> forward with AMP fallback ===
import torch, src.training as training

# --- 1) Small loader built with the shared helper ---
# Use the preset-derived ENCODER / T / GAIN (instead of hard-coded literals)
# so this check exercises the same data configuration as the other cells.
tr, va, te = make_loader_fn(
    task=task_list[0],
    batch_size=8,
    encoder=ENCODER,   # if the pipeline already returns 4D, it is detected below
    T=T,
    gain=GAIN,
    tfm=tfm,
    seed=SEED,
)

xb, yb = next(iter(tr))
print("batch del loader:", xb.shape, yb.shape)

# --- 2) Bring the batch to (T,B,C,H,W) depending on the input format ---
#    - 5D batch (dataset already encoded): only permute.
#    - 4D batch (plain images): enable runtime encoding and use the training helper.
if xb.ndim == 5:  # (B,T,C,H,W)
    x5d = xb.permute(1,0,2,3,4).contiguous()
    used_runtime_encode = False
    print("dataset ya codificado; solo permuto a (T,B,C,H,W)")
elif xb.ndim == 4:  # (B,C,H,W)
    # Reuse the notebook-wide `device` rather than re-deriving it here.
    training.set_runtime_encode(mode=ENCODER, T=T, gain=GAIN, device=device)
    try:
        x5d = training._permute_if_needed(xb)  # encode + permute -> (T,B,C,H,W)
    except Exception:
        # Do not leave runtime encoding switched on if the encode step fails.
        training.set_runtime_encode(None)
        raise
    used_runtime_encode = True
    print("dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)")
else:
    raise RuntimeError(f"Forma inesperada del batch: {xb.shape}")

print("x5d.device:", x5d.device, "| shape:", tuple(x5d.shape))

# --- 3) Model and forward with automatic AMP fallback ---
model = make_model_fn(tfm).to(device).eval()

def forward_with_auto_amp(model, x5d, device):
    """Run one inference forward pass, trying CUDA AMP first and falling back to FP32.

    Args:
        model: callable module, already placed on ``device``.
        x5d: input tensor; it is moved to ``device`` and cast before the call.
        device: target device (``torch.device`` or a device string).

    Returns:
        Whatever ``model`` returns for the batch.

    The AMP attempt is made only when ``device`` itself is a CUDA device.
    Checking ``torch.cuda.is_available()`` instead would send a CPU device on a
    CUDA-capable machine down the fp16 path. Any exception raised during the
    AMP attempt is reported and the pass is retried in FP32 (deliberate
    best-effort fallback); errors in the FP32 pass propagate to the caller.
    """
    device = torch.device(device)

    # Attempt 1: AMP, only when the target device is CUDA
    if device.type == "cuda":
        try:
            x_amp = x5d.to(device, dtype=torch.float16, non_blocking=True)
            with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
                y = model(x_amp)
            print("[forward] ejecutado con AMP (fp16)")
            return y
        except Exception as e:
            print("[forward] AMP falló, reintento en FP32. Motivo:", str(e))

    # Attempt 2: FP32 (CPU device, or fallback after a failed AMP attempt)
    x_fp32 = x5d.to(device, dtype=torch.float32, non_blocking=True)
    with torch.inference_mode():
        y = model(x_fp32)
    print("[forward] ejecutado en FP32")
    return y

# Run the forward pass; the runtime-encode reset lives in `finally` so it is
# not skipped (and left enabled for later cells) if the forward raises.
try:
    yhat = forward_with_auto_amp(model, x5d, device)
    print("yhat:", tuple(yhat.shape))
finally:
    # --- 4) Runtime-encode cleanup (only if it was enabled for this test) ---
    if used_runtime_encode:
        training.set_runtime_encode(None)


batch del loader: torch.Size([8, 10, 1, 66, 200]) torch.Size([8, 1])
dataset ya codificado; solo permuto a (T,B,C,H,W)
x5d.device: cpu | shape: (10, 8, 1, 66, 200)
[forward] ejecutado con AMP (fp16)
yhat: (8, 1)


In [9]:
# ===================== BENCH: toggle and configuration echo =====================
# NOTE(review): the original comment said this uses a RUN_BENCH flag defined in
# "cell 2", but no RUN_BENCH is read here; the cell only prints loader/balance settings.
# NOTE(review): the printed text opens with "[" and never closes it -- confirm
# whether that is intended before changing the string.
print(
    f"[Bench workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT} "
    f"| offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}"
)


[Bench workers=8 prefetch=2 pin=True persistent=True | offline_bal=False online_bal=False


In [10]:
# Turn on per-epoch it/s reporting for the training runs below.
# NOTE(review): disable_epoch_ips is imported in the first cell but is not
# called in the cells visible here -- add a call once benchmarking is done if
# the default behaviour should be restored.
enable_epoch_ips()
print("it/s por época ACTIVADO. Llama a disable_epoch_ips() para restaurar.")


it/s por época ACTIVADO. Llama a disable_epoch_ips() para restaurar.


In [11]:
from copy import deepcopy
from src.utils import load_preset
from src.runner import run_continual

# 1) Load the COMPLETE (nested) preset
preset_demo = "fast"
CFG = load_preset(ROOT / "configs" / "presets.yaml", preset_demo)

# 2) Optional overrides (e.g. pin the seed here)
CFG = deepcopy(CFG)
CFG["data"]["seed"] = 42  # <- optional

# Refresh the notebook-level flags from the preset. NOTE: `make_loader_fn` was
# already built by an earlier cell, so these assignments do not change the
# loader factory; they only update the notebook variables.
USE_OFFLINE_SPIKES = bool(CFG["data"].get("use_offline_spikes", False))
RUNTIME_ENCODE     = bool(CFG["data"].get("encode_runtime", not USE_OFFLINE_SPIKES))


def _smoke_cfg(base_cfg, method, params):
    """Return a deep copy of ``base_cfg`` with the continual method and params replaced."""
    cfg = deepcopy(base_cfg)
    cfg["continual"]["method"] = method
    cfg["continual"]["params"] = dict(params)
    return cfg


def _run_smoke(label, cfg):
    """Run ``run_continual`` for one method config, print its banner and output path.

    Returns the ``(out_path, res)`` pair produced by ``run_continual``.
    """
    print(f"\n>>> {label} (smoke)")
    out_path, res = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,   # loader factory built earlier in this notebook
        make_model_fn=make_model_fn,
        tfm=tfm,
        cfg=cfg,                         # <- full configuration
        preset_name=preset_demo,         # <- only used for the output name
        out_root=ROOT / "outputs",
        verbose=True,
    )
    print("OK:", out_path)
    return out_path, res


# 3) NAIVE (smoke)
CFG_NAIVE = _smoke_cfg(CFG, "naive", {})
out_path, res = _run_smoke("NAIVE", CFG_NAIVE)

# 4) EWC (smoke)
CFG_EWC = _smoke_cfg(CFG, "ewc", {"lam": 1e9, "fisher_batches": 200})
out_path, res = _run_smoke("EWC", CFG_EWC)

# 5) REHEARSAL (smoke)
CFG_R = _smoke_cfg(CFG, "rehearsal", {"buffer_size": 5000, "replay_ratio": 0.2})
out_path, res = _run_smoke("REHEARSAL", CFG_R)

# 6) REHEARSAL+EWC (smoke)
CFG_RE = _smoke_cfg(
    CFG,
    "rehearsal+ewc",
    {"buffer_size": 5000, "replay_ratio": 0.2, "lam": 7e8, "fisher_batches": 200},
)
out_path, res = _run_smoke("REHEARSAL+EWC", CFG_RE)



>>> NAIVE (smoke)

--- Tarea 1/2: circuito1 | preset=fast | method=naive | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 10.0 it/s (164 iters en 16.38s)
[TRAIN it/s] epoch 2/2: 11.1 it/s (164 iters en 14.72s)

--- Tarea 2/2: circuito2 | preset=fast | method=naive | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 8.7 it/s (49 iters en 5.66s)
[TRAIN it/s] epoch 2/2: 11.1 it/s (49 iters en 4.42s)
OK: /home/cesar/proyectos/TFM_SNN/outputs/continual_fast_naive_rate_model-PilotNetSNN_66x200_gray_seed_42

>>> EWC (smoke)

--- Tarea 1/2: circuito1 | preset=fast | method=ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 11.1 it/s (164 iters en 14.74s)
[TRAIN it/s] epoch 2/2: 10.5 it/s (164 iters en 15.55s)

--- Tarea 2/2: circuito2 | preset=fast | method=ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 9.4 it/s (49 iters en 5.22s)
[TRAIN it/s] epoch 2/2: 9.7 it/s (49 iters en 5.04s)
OK: /home/cesar/proyectos/TFM_SNN/outputs

In [12]:
# Sweep of the EWC lambda for the combined "rehearsal+ewc" method: one
# run_continual call per value in `lams`, collecting the output directories.
from copy import deepcopy
from src.utils import load_preset
from src.runner import run_continual

preset_name = "fast"
base_cfg = load_preset(ROOT / "configs" / "presets.yaml", preset_name)
base_cfg = deepcopy(base_cfg)
base_cfg["data"]["seed"] = 42  # optional

lams = [7e8, 1e9]
out_runs = []
for lam in lams:
    print(f"\n>>> REHEARSAL+EWC SMOKE λ={lam:.0e}")
    # Fresh copy per run so one run's overrides never leak into the next.
    cfg = deepcopy(base_cfg)
    cfg["continual"]["method"] = "rehearsal+ewc"
    cfg["continual"]["params"] = {
        "buffer_size": 5000,
        "replay_ratio": 0.2,
        "lam": lam,
        "fisher_batches": 200,
    }

    # Only the output directory is kept; the second return value is discarded.
    out_dir, _ = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        cfg=cfg,                     # <- FULL CONFIG
        preset_name=preset_name,     # <- naming only
        out_root=ROOT / "outputs",
        verbose=True,
    )
    out_runs.append(out_dir)

print("\nHecho:", out_runs)



>>> REHEARSAL+EWC SMOKE λ=7e+08

--- Tarea 1/2: circuito1 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_7e+08 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 9.1 it/s (164 iters en 18.07s)
[TRAIN it/s] epoch 2/2: 9.8 it/s (164 iters en 16.66s)

--- Tarea 2/2: circuito2 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_7e+08 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 7.7 it/s (49 iters en 6.36s)
[TRAIN it/s] epoch 2/2: 8.0 it/s (49 iters en 6.11s)

>>> REHEARSAL+EWC SMOKE λ=1e+09

--- Tarea 1/2: circuito1 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 9.2 it/s (164 iters en 17.85s)
[TRAIN it/s] epoch 2/2: 9.7 it/s (164 iters en 16.88s)

--- Tarea 2/2: circuito2 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 7.8 it/s (49 iters en 6.29s)
[TRAIN it/s] epoch 2/2: 8.1 it/s (49 iters en 6.07s)

Hecho

In [13]:
# NOTE(review): the code in this cell is identical to the code of the previous
# cell (same preset, same `lams`, same parameters), so running it repeats both
# "rehearsal+ewc" runs and rebinds `out_runs`. Confirm whether the repeat is
# intentional (e.g. a timing re-run) or the cell should be removed.
from copy import deepcopy
from src.utils import load_preset
from src.runner import run_continual

preset_name = "fast"
base_cfg = load_preset(ROOT / "configs" / "presets.yaml", preset_name)
base_cfg = deepcopy(base_cfg)
base_cfg["data"]["seed"] = 42  # optional

lams = [7e8, 1e9]
out_runs = []
for lam in lams:
    print(f"\n>>> REHEARSAL+EWC SMOKE λ={lam:.0e}")
    # Fresh copy per run so one run's overrides never leak into the next.
    cfg = deepcopy(base_cfg)
    cfg["continual"]["method"] = "rehearsal+ewc"
    cfg["continual"]["params"] = {
        "buffer_size": 5000,
        "replay_ratio": 0.2,
        "lam": lam,
        "fisher_batches": 200,
    }

    # Only the output directory is kept; the second return value is discarded.
    out_dir, _ = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        cfg=cfg,                     # <- FULL CONFIG
        preset_name=preset_name,     # <- naming only
        out_root=ROOT / "outputs",
        verbose=True,
    )
    out_runs.append(out_dir)

print("\nHecho:", out_runs)



>>> REHEARSAL+EWC SMOKE λ=7e+08

--- Tarea 1/2: circuito1 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_7e+08 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 9.0 it/s (164 iters en 18.26s)
[TRAIN it/s] epoch 2/2: 9.4 it/s (164 iters en 17.53s)

--- Tarea 2/2: circuito2 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_7e+08 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 7.6 it/s (49 iters en 6.43s)
[TRAIN it/s] epoch 2/2: 8.3 it/s (49 iters en 5.94s)

>>> REHEARSAL+EWC SMOKE λ=1e+09

--- Tarea 1/2: circuito1 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 8.9 it/s (164 iters en 18.42s)
[TRAIN it/s] epoch 2/2: 8.9 it/s (164 iters en 18.52s)

--- Tarea 2/2: circuito2 | preset=fast | method=rehearsal_buf_5000_rr_20+ewc_lam_1e+09 | B=64 T=10 AMP=True | enc=rate ---
[TRAIN it/s] epoch 1/2: 6.9 it/s (49 iters en 7.09s)
[TRAIN it/s] epoch 2/2: 8.1 it/s (49 iters en 6.08s)

Hecho