In [1]:
# =============================================================================
# Imports y setup
# =============================================================================
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from pathlib import Path
import sys, json, torch

ROOT = Path.cwd().parents[0] if (Path.cwd().name == "notebooks") else Path.cwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from src.utils import set_seeds, load_preset
from src.datasets import ImageTransform, AugmentConfig
from src.models import SNNVisionRegressor
from src.training import TrainConfig
from src.eval import eval_loader

# Benchmark and smoke-test helpers
from src.bench import make_loader_fn_factory, universal_smoke_forward, enable_epoch_ips, disable_epoch_ips, print_bench_config

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed passed to the loader factory and to the smoke tests further down.
SEED = 42

# Image preprocessing shared by the loaders and by make_model_fn (which reads tfm.to_gray).
# NOTE(review): the positional arguments look like (width=160, height=80, to_gray=True, <4th>=None),
# judging by the (B, 1, 80, 160) batches printed by later cells — confirm against ImageTransform.
tfm = ImageTransform(160, 80, True, None)
def make_model_fn(tfm):
    """Build a fresh SNNVisionRegressor whose input channels match the transform.

    Args:
        tfm: image transform; only its ``to_gray`` flag is read here.

    Returns:
        A new SNNVisionRegressor with 1 input channel when ``tfm.to_gray`` is
        truthy and 3 otherwise, using ``lif_beta=0.95``.
    """
    n_channels = 1 if tfm.to_gray else 3
    return SNNVisionRegressor(in_channels=n_channels, lif_beta=0.95)

# Set PyTorch's CPU thread count to 4.
# NOTE(review): OMP/MKL/OPENBLAS_NUM_THREADS are set to "1" at the top of this cell while
# 4 threads are requested here — confirm which of the two limits is intended.
torch.set_num_threads(4)
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
# Allow TF32 for CUDA matmuls and for cuDNN, and request "high" float32 matmul precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")


In [None]:
# --- Master switches ---
GPU_ENCODE = True
SAFE_MODE = False

# --- Augmentation presets ---
AUG_CFG_LIGHT = AugmentConfig(prob_hflip=0.5, brightness=None, gamma=None, noise_std=0.0)
AUG_CFG_FULL  = AugmentConfig(prob_hflip=0.5, brightness=(0.9, 1.1), gamma=(0.95, 1.05), noise_std=0.005)

# --- DataLoader, balancing and augmentation settings ---
# SAFE_MODE selects the conservative profile: no worker processes, no prefetch,
# no pinned memory, no balancing and no augmentation.
if SAFE_MODE:
    NUM_WORKERS, PREFETCH, PIN_MEMORY, PERSISTENT = 0, None, False, False
    USE_OFFLINE_BALANCED, USE_ONLINE_BALANCING = False, False
    AUG_CFG = None
else:
    NUM_WORKERS, PREFETCH, PIN_MEMORY, PERSISTENT = 12, 2, True, True
    USE_OFFLINE_BALANCED, USE_ONLINE_BALANCING = True, False
    AUG_CFG = AUG_CFG_LIGHT

print(f"[SAFE_MODE={SAFE_MODE}] workers={NUM_WORKERS} prefetch={PREFETCH} pin={PIN_MEMORY} persistent={PERSISTENT}")


[SAFE_MODE=False] workers=12 prefetch=2 pin=True persistent=True


In [3]:
# =============================================================================
# Verificación de datos (normal y, si existe, balanceado offline)
# =============================================================================
from pathlib import Path

RAW  = ROOT / "data" / "raw" / "udacity"
PROC = ROOT / "data" / "processed"

# Every sub-directory of PROC is treated as one run. The names are sorted because
# directory iteration order is arbitrary and the report below should be reproducible.
# To pin specific runs instead: RUNS = ["circuito1", "circuito2"]
RUNS = sorted(d.name for d in PROC.iterdir() if d.is_dir())

# Without this guard an empty PROC would skip the loop and report success below.
if not RUNS:
    raise FileNotFoundError(
        f"No hay recorridos en {PROC} (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep)."
    )

missing = []
for run in RUNS:
    base = PROC / run

    # Mandatory check: the regular train/val/test splits.
    for part in ("train", "val", "test"):
        p = base / f"{part}.csv"
        if not p.exists():
            missing.append(str(p))

    # Optional check: train_balanced.csv (only needed for OFFLINE balanced mode).
    p_bal = base / "train_balanced.csv"
    if p_bal.exists():
        print(f"✓ {p_bal} OK")
    else:
        print(f"  Falta {p_bal}. Si más abajo pones USE_OFFLINE_BALANCED=True, "
              f"ejecuta 01A_PREP_BALANCED.ipynb o tools/make_splits_balanced.py")

# Report every missing mandatory CSV at once instead of failing on the first one.
if missing:
    raise FileNotFoundError(
        "Faltan CSV obligatorios (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep):\n"
        + "\n".join(" - " + m for m in missing)
    )

print("OK: splits 'train/val/test' encontrados.")


✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito1/train_balanced.csv OK
✓ /home/cesar/proyectos/TFM_SNN/data/processed/circuito2/train_balanced.csv OK
OK: splits 'train/val/test' encontrados.


In [4]:
from pathlib import Path as _P

PROC = ROOT / "data" / "processed"

# Pick the task manifest that matches the balancing mode, then load it.
tasks_file = PROC / ("tasks_balanced.json" if USE_OFFLINE_BALANCED else "tasks.json")
with open(tasks_file, "r", encoding="utf-8") as fh:
    tasks_json = json.load(fh)

# One entry per task, in the order declared by the manifest.
task_list = [dict(name=n, paths=tasks_json["splits"][n]) for n in tasks_json["tasks_order"]]

print("Tareas y su TRAIN CSV:")
for task in task_list:
    train_csv = _P(task["paths"]["train"])
    print(f" - {task['name']}: {train_csv.name}")

# Guardrail for OFFLINE balanced mode: every task must point at an existing train_balanced.csv.
if USE_OFFLINE_BALANCED:
    for task in task_list:
        train_csv = _P(task["paths"]["train"])
        if train_csv.name != "train_balanced.csv":
            raise RuntimeError(f"[{task['name']}] Esperaba 'train_balanced.csv' pero encontré '{train_csv.name}'.")
        if not train_csv.exists():
            raise FileNotFoundError(f"[{task['name']}] No existe {train_csv}.")
    print(" Verificación OFFLINE balanceado superada.")


Tareas y su TRAIN CSV:
 - circuito1: train_balanced.csv
 - circuito2: train_balanced.csv
 Verificación OFFLINE balanceado superada.


In [5]:
# Build the loader factory once from the settings of the configuration cell.
# Later cells call make_loader_fn(task=..., batch_size=..., ...) and unpack three loaders from it,
# or hand it to universal_smoke_forward / run_continual.
make_loader_fn = make_loader_fn_factory(
    ROOT,
    GPU_ENCODE=GPU_ENCODE,
    SEED=SEED,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH,
    pin_memory=PIN_MEMORY,
    persistent_workers=PERSISTENT,
    aug_train=AUG_CFG,
    use_online_balancing=USE_ONLINE_BALANCING,
)


In [6]:
# Smoke-test the first task with the rate encoder (T=10, gain=0.5).
# The return value is discarded, so only the helper's printed diagnostics are kept.
_ = universal_smoke_forward(
    make_loader_fn,
    task=task_list[0],
    encoder="rate", T=10, gain=0.5,
    tfm=tfm, seed=SEED, device=device,
    use_runtime_encode=GPU_ENCODE,
)


dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)
x5d.device: cuda:0 | shape: (10, 8, 1, 80, 160)
[forward] ejecutado con AMP


In [None]:
# Echo the DataLoader and balancing settings chosen in the configuration cell.
print_bench_config(
    NUM_WORKERS=NUM_WORKERS, PREFETCH=PREFETCH,
    PIN_MEMORY=PIN_MEMORY, PERSISTENT=PERSISTENT,
    USE_OFFLINE_BALANCED=USE_OFFLINE_BALANCED, USE_ONLINE_BALANCING=USE_ONLINE_BALANCING
)


[Bench RUN_BENCH=False] workers=12 prefetch=2 pin=True persistent=True | offline_bal=True online_bal=False


In [8]:
# === PRUEBA UNIVERSAL: loader -> (T,B,C,H,W) -> forward con fallback AMP ===
import torch, src.training as training

# --- 1) Small loader built with the shared helper ---
# The encoding settings are defined once so that the loader and the runtime
# encoder used below cannot drift apart.
SMOKE_ENCODER = "rate"
SMOKE_T = 10
SMOKE_GAIN = 0.5

tr, va, te = make_loader_fn(
    task=task_list[0],
    batch_size=8,
    encoder=SMOKE_ENCODER,   # if the pipeline hands back 4D image batches, this is detected below
    T=SMOKE_T,
    gain=SMOKE_GAIN,
    tfm=tfm,
    seed=SEED,
)

xb, yb = next(iter(tr))
print("batch del loader:", xb.shape, yb.shape)

# --- 2) Bring the batch to (T,B,C,H,W) depending on what the loader returned ---
#    - 5D batch (already encoded, (B,T,C,H,W)): only permute.
#    - 4D batch (plain images, (B,C,H,W)): enable the runtime encoder and let the
#      training helper encode and permute.
used_runtime_encode = False
if xb.ndim == 5:  # (B,T,C,H,W)
    x5d = xb.permute(1, 0, 2, 3, 4).contiguous()
    print("dataset ya codificado; solo permuto a (T,B,C,H,W)")
elif xb.ndim == 4:  # (B,C,H,W)
    # Reuse the notebook-wide `device` rather than re-deriving it here.
    training.set_runtime_encode(mode=SMOKE_ENCODER, T=SMOKE_T, gain=SMOKE_GAIN, device=device)
    try:
        x5d = training._permute_if_needed(xb)  # encode + permute -> (T,B,C,H,W)
    except Exception:
        # Do not leave the module-level runtime encoder switched on if encoding fails.
        training.set_runtime_encode(None)
        raise
    used_runtime_encode = True
    print("dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)")
else:
    raise RuntimeError(f"Forma inesperada del batch: {xb.shape}")

print("x5d.device:", x5d.device, "| shape:", tuple(x5d.shape))

# --- 3) Model in eval mode on the target device ---
model = make_model_fn(tfm).to(device).eval()

def forward_with_auto_amp(model, x5d, device):
    """Run one inference forward pass, preferring CUDA AMP and falling back to FP32.

    Args:
        model: callable (e.g. an ``nn.Module``) already placed on ``device``.
        x5d: input tensor; it is moved to ``device`` and cast before the call.
        device: target device (``torch.device`` or a device string).

    Returns:
        Whatever ``model`` returns for the input.

    Raises:
        Any exception raised by the FP32 forward pass. Failures of the AMP
        attempt are reported on stdout and trigger the FP32 retry instead.
    """
    # Attempt 1: AMP, only when the *target* device is CUDA. Checking just
    # torch.cuda.is_available() would send a CPU device through a doomed fp16
    # attempt and print a misleading failure message.
    target_is_cuda = torch.device(device).type == "cuda"
    if target_is_cuda and torch.cuda.is_available():
        try:
            x_amp = x5d.to(device, dtype=torch.float16, non_blocking=True)
            with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
                y = model(x_amp)
            print("[forward] ejecutado con AMP (fp16)")
            return y
        except Exception as e:
            # Deliberate best-effort: report the reason and retry in full precision.
            print("[forward] AMP falló, reintento en FP32. Motivo:", str(e))

    # Attempt 2: FP32 (CPU target, or fallback after an AMP failure).
    x_fp32 = x5d.to(device, dtype=torch.float32, non_blocking=True)
    with torch.inference_mode():
        y = model(x_fp32)
    print("[forward] ejecutado en FP32")
    return y

try:
    yhat = forward_with_auto_amp(model, x5d, device)
    print("yhat:", tuple(yhat.shape))
finally:
    # --- 4) Reset the runtime encoder (if it was enabled) ---
    # Done in `finally` so a failing forward pass cannot leave the module-level
    # encoder state switched on for the cells that follow.
    if used_runtime_encode:
        training.set_runtime_encode(None)


batch del loader: torch.Size([8, 1, 80, 160]) torch.Size([8, 1])
dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)
x5d.device: cuda:0 | shape: (10, 8, 1, 80, 160)
[forward] ejecutado con AMP (fp16)
yhat: (8, 1)


In [9]:
# ===================== BENCH: configuration echo =====================
# Plain echo of the DataLoader / balancing settings set in the configuration cell.
# The "[Bench]" tag is now closed; the previous message opened a bracket and never closed it.
print(
    f"[Bench] workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT} "
    f"| offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}"
)


[Bench workers=12 prefetch=2 pin=True persistent=True | offline_bal=True online_bal=False


In [10]:
# Switch on the per-epoch iterations-per-second report for the training runs below
# (presumably the source of the "[TRAIN it/s]" lines in their output — confirm in src.bench).
enable_epoch_ips()
print("it/s por época ACTIVADO. Llama a disable_epoch_ips() para restaurar.")


it/s por época ACTIVADO. Llama a disable_epoch_ips() para restaurar.


In [11]:
from src.runner import run_continual

preset_demo = "fast"
seed_demo   = 42
enc_demo    = load_preset(ROOT / "configs" / "presets.yaml", preset_demo)["encoder"]

# Arguments shared by both smoke runs; only the method and its λ differ.
smoke_kwargs = dict(
    task_list=task_list,
    make_loader_fn=make_loader_fn,
    make_model_fn=make_model_fn,
    tfm=tfm,
    preset=preset_demo,
    seed=seed_demo,
    encoder=enc_demo,
    epochs_override=2,
    runtime_encode=GPU_ENCODE,
    out_root=ROOT/"outputs",
    verbose=True,
)

# (banner, method, λ) for each run, executed in this order.
for smoke_banner, smoke_method, smoke_lam in (("NAIVE", "naive", None), ("EWC", "ewc", 1e9)):
    print(f"\n>>> {smoke_banner} (smoke)")
    out_path, res = run_continual(
        method=smoke_method,
        lam=smoke_lam,
        fisher_batches_by_preset={"std": 600},  # fresh dict per call
        **smoke_kwargs,
    )
    print("OK:", out_path)



>>> NAIVE (smoke)

--- Tarea 1/2: circuito1 | preset=fast | method=naive | λ=- | B=64 T=10 AMP=True | enc=rate ---
  loader batch shape: (64, 1, 80, 160) | y: (64, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 58.1 it/s (909 iters en 15.65s)
[TRAIN it/s] epoch 2/2: 45.5 it/s (909 iters en 19.96s)
  runtime encode: OFF

--- Tarea 2/2: circuito2 | preset=fast | method=naive | λ=- | B=64 T=10 AMP=True | enc=rate ---
  loader batch shape: (64, 1, 80, 160) | y: (64, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 44.9 it/s (202 iters en 4.50s)
[TRAIN it/s] epoch 2/2: 44.2 it/s (202 iters en 4.57s)
  runtime encode: OFF
OK: /home/cesar/proyectos/TFM_SNN/outputs/continual_fast_naive_rate_seed_42

>>> EWC (smoke)

--- Tarea 1/2: circuito1 | preset=fast | method=ewc | λ=1000000000.0 | B=64 T=10 AMP=True | enc=rate ---
  loader batch shape: (64, 1, 80, 160) | y: (64, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 37.7 it/s (909 iters en 24.10s)
[TRAIN it/s] epoch 2/2: 39.9 

In [12]:
# Mini λ sweep on the "std" preset (quick)
lams = [3e8, 5e8, 7e8]

# Everything except λ stays fixed across the sweep.
sweep_fixed = dict(
    task_list=task_list,
    make_loader_fn=make_loader_fn,
    make_model_fn=make_model_fn,
    tfm=tfm,
    preset="std",
    method="ewc",
    seed=42,
    encoder="rate",
    epochs_override=2,           # quick smoke run
    runtime_encode=GPU_ENCODE,   # important: forwarded to run_continual as runtime_encode
    out_root=ROOT/"outputs",
    verbose=True,
)

out_runs = []
for lam in lams:
    print(f"\n>>> EWC SMOKE λ={lam:.0e}")
    out_dir, _ = run_continual(
        lam=lam,
        fisher_batches_by_preset={"std": 600},  # fresh dict per call
        **sweep_fixed,
    )
    out_runs.append(out_dir)

print("\nHecho:", out_runs)



>>> EWC SMOKE λ=3e+08

--- Tarea 1/2: circuito1 | preset=std | method=ewc | λ=300000000.0 | B=56 T=16 AMP=True | enc=rate ---
  loader batch shape: (56, 1, 80, 160) | y: (56, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 28.4 it/s (1038 iters en 36.55s)
[TRAIN it/s] epoch 2/2: 29.5 it/s (1038 iters en 35.16s)
  runtime encode: OFF

--- Tarea 2/2: circuito2 | preset=std | method=ewc | λ=300000000.0 | B=56 T=16 AMP=True | enc=rate ---
  loader batch shape: (56, 1, 80, 160) | y: (56, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 25.3 it/s (230 iters en 9.09s)
[TRAIN it/s] epoch 2/2: 30.2 it/s (230 iters en 7.61s)
  runtime encode: OFF

>>> EWC SMOKE λ=5e+08

--- Tarea 1/2: circuito1 | preset=std | method=ewc | λ=500000000.0 | B=56 T=16 AMP=True | enc=rate ---
  loader batch shape: (56, 1, 80, 160) | y: (56, 1)
  runtime encode: ON (GPU)
[TRAIN it/s] epoch 1/2: 31.4 it/s (1038 iters en 33.05s)
[TRAIN it/s] epoch 2/2: 32.5 it/s (1038 iters en 31.94s)
  runtime encode: OFF



In [None]:
# NOTE(review): this cell runs the same λ sweep, with the same settings, as the
# preceding cell and carries no execution count — confirm whether it is still needed.
def _ewc_smoke_run(lam):
    """Run the 2-epoch EWC smoke training on the "std" preset for one λ.

    Prints the run banner, calls run_continual and returns the first element
    of its result (the output location).
    """
    print(f"\n>>> EWC SMOKE λ={lam:.0e}")
    run_out, _unused = run_continual(
        task_list=task_list,
        make_loader_fn=make_loader_fn,
        make_model_fn=make_model_fn,
        tfm=tfm,
        preset="std",
        method="ewc",
        lam=lam,
        seed=42,
        encoder="rate",
        fisher_batches_by_preset={"std": 600},
        epochs_override=2,         # quick smoke run
        runtime_encode=GPU_ENCODE, # important
        out_root=ROOT/"outputs",
        verbose=True,
    )
    return run_out

lams = [3e8, 5e8, 7e8]
out_runs = []
for lam in lams:
    out_dir = _ewc_smoke_run(lam)
    out_runs.append(out_dir)
print("\nHecho:", out_runs)
