# 03 — Entrenamiento y Evaluación (SUPERVISED y CONTINUAL con EWC/NAIVE)

Este notebook entrena un modelo **SNN** para **regresión del ángulo de dirección (steering)** en dos protocolos:

- **Supervised** sobre `circuito1`.
- **Continual** con dos tareas secuenciales `circuito1 → circuito2` usando:
  - **EWC** (consolidación elástica de pesos), o
  - **NAIVE** (baseline sin penalización; equivalente a λ=0).

> **Requisitos previos**: Ejecuta `01_DATA_QC_PREP.ipynb` para generar `train/val/test.csv` y `tasks.json`.


In [None]:
# =============================================================================
# Imports y setup
# =============================================================================
from pathlib import Path
import sys, json, torch

# Locate the repository root: when the kernel runs from inside notebooks/,
# the root is one directory up; otherwise the working directory is the root.
ROOT = Path.cwd()
if ROOT.name == "notebooks":
    ROOT = ROOT.parent
# Register the root on sys.path (once) so the `src` package can be imported.
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))

# Utilidades y módulos del proyecto
from src.utils import set_seeds, load_preset, make_loaders_from_csvs, ImageTransform
from src.models import SNNVisionRegressor
from src.training import TrainConfig, train_supervised, _permute_if_needed
from src.methods.ewc import EWC, EWCConfig

# Fixed seed shared by every experiment in this notebook.
SEED = 42

# Compute device: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep this expression LAST: Jupyter only renders the value of a cell's final
# expression, so any statement placed after it suppresses the display.
ROOT, device

(PosixPath('/home/cesar/proyectos/TFM_SNN'), device(type='cuda'))

In [2]:
# =============================================================================
# Selecciona preset de ejecución
#   - 'fast': corridas rápidas (sanity check)
#   - 'std': equilibrio calidad/tiempo
#   - 'accurate': más épocas/T (finales)
# =============================================================================
preset = "fast"   # one of: 'fast' | 'std' | 'accurate'
cfg = load_preset(ROOT.joinpath("configs", "presets.yaml"), preset)
print("Preset:", cfg)

# Image transform.
# NOTE: pass the arguments POSITIONALLY (w, h, to_gray, crop_top); keyword
# names such as target_w/target_h are not defined by the class.
tfm = ImageTransform(160, 80, True, None)

Preset: {'epochs': 2, 'batch_size': 8, 'T': 10, 'gain': 0.5, 'encoder': 'rate', 'lr': 0.001, 'amp': True}


In [3]:
# =============================================================================
# Verificación de datos
# =============================================================================
RAW = ROOT.joinpath("data", "raw", "udacity")
PROC = ROOT.joinpath("data", "processed")

# Every circuit must already have its three split CSVs; abort on the first
# one that is missing and point the user at the preparation notebook.
for run in ("circuito1", "circuito2"):
    for part in ("train", "val", "test"):
        path = PROC.joinpath(run, f"{part}.csv")
        if path.exists():
            continue
        raise FileNotFoundError(f"Falta {path}. Ejecuta 01_DATA_QC_PREP.ipynb primero.")
print("OK splits encontrados")

OK splits encontrados


In [None]:
# =============================================================================
# SUPERVISED: entrenamiento en circuito1
# =============================================================================
# Supervised baseline on circuito1: seed, build loaders, build model, train.
# The statement order matters for reproducibility: the seed is set before the
# loaders and the model are created.
set_seeds(SEED)  # reproducibility

# Loaders for the three circuito1 splits. The encoder name, T and gain come
# from the selected preset; the original author notes that the temporal
# encoding is applied on the fly by the loader.
# NOTE(review): test_loader is created here but not used in this cell.
train_loader, val_loader, test_loader = make_loaders_from_csvs(
    base_dir=RAW/"circuito1",
    train_csv=PROC/"circuito1"/"train.csv",
    val_csv=PROC/"circuito1"/"val.csv",
    test_csv=PROC/"circuito1"/"test.csv",
    batch_size=cfg["batch_size"],
    encoder=cfg["encoder"],   # 'rate' | 'latency' | 'raw'
    T=cfg["T"],
    gain=cfg["gain"],
    tfm=tfm,
    seed=SEED,
)

# SNN model (1 input channel, i.e. grayscale), MSE loss and training config
model = SNNVisionRegressor(in_channels=1, lif_beta=0.95)
loss_fn = torch.nn.MSELoss()
tcfg = TrainConfig(epochs=cfg["epochs"], batch_size=cfg["batch_size"], lr=cfg["lr"], amp=cfg["amp"])

# Output folder for the supervised run; method=None means no continual-learning
# method is passed to the trainer.
out_dir = ROOT/"outputs"/f"supervised_{preset}_ewc0"
print("Entrenando SUPERVISED...")
_ = train_supervised(model, train_loader, val_loader, loss_fn, tcfg, out_dir, method=None)
print("OK:", out_dir)

Entrenando SUPERVISED...


Epoch 1/2: 100%|██████████| 437/437 [00:07<00:00, 58.63it/s, loss=0.00615] 
Epoch 2/2: 100%|██████████| 437/437 [00:09<00:00, 45.93it/s, loss=0.0121]  


OK: /home/cesar/proyectos/TFM_SNN/outputs/supervised_fast_ewc0


In [4]:
# =============================================================================
# Helper de evaluación (asegura forma correcta (T,B,C,H,W) antes del modelo)
# =============================================================================
def eval_loader(loader, model, device):
    """Compute MAE and MSE averaged over every sample yielded by ``loader``.

    Per the notebook's convention the DataLoader yields inputs shaped
    (B, T, C, H, W) while the model expects (T, B, C, H, W); the conversion is
    delegated to ``_permute_if_needed``.

    The model is evaluated in inference mode: gradients are disabled and, when
    ``model`` is a ``torch.nn.Module``, it is switched to ``eval()`` for the
    duration of the call and restored to its previous mode afterwards.

    Args:
        loader: Iterable of ``(x, y)`` tensor batches.
        model: Callable mapping an input batch to predictions.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mae, mse)`` of Python floats, weighted by batch size.
        Returns ``(0.0, 0.0)`` when the loader yields no batches.
    """
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        # Train-mode layers (dropout, batch-norm, ...) would distort the metrics.
        model.eval()

    mae_sum = mse_sum = 0.0
    n = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                # (B,T,C,H,W) -> (T,B,C,H,W) when applicable
                x = _permute_if_needed(x.to(device))
                y = y.to(device)
                y_hat = model(x)
                # Guard against silent broadcasting: a (B, 1) prediction minus
                # a (B,) target would expand to (B, B) and corrupt both metrics.
                if y_hat.shape != y.shape and y_hat.numel() == y.numel():
                    y_hat = y_hat.reshape(y.shape)
                err = y_hat - y
                batch = len(y)
                # Batch means are re-weighted by batch size so a smaller last
                # batch does not bias the overall average.
                mae_sum += torch.mean(torch.abs(err)).item() * batch
                mse_sum += torch.mean(err ** 2).item() * batch
                n += batch
    finally:
        if was_training:
            model.train()
    return mae_sum / max(n, 1), mse_sum / max(n, 1)

In [5]:
# =============================================================================
# Cargar orden de tareas (continual) desde tasks.json
# =============================================================================
# Read the task manifest (tasks.json) from the processed-data folder.
with (PROC/"tasks.json").open("r", encoding="utf-8") as f:
    tasks_json = json.loads(f.read())

# One entry per task, following "tasks_order":
#   {'name': <task name>, 'paths': <that task's entry under "splits">}
task_list = list(map(
    lambda n: {"name": n, "paths": tasks_json["splits"][n]},
    tasks_json["tasks_order"],
))
task_list

[{'name': 'circuito1',
  'paths': {'train': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito1/train.csv',
   'val': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito1/val.csv',
   'test': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito1/test.csv'}},
 {'name': 'circuito2',
  'paths': {'train': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito2/train.csv',
   'val': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito2/val.csv',
   'test': '/home/cesar/proyectos/TFM_SNN/data/processed/circuito2/test.csv'}}]

In [10]:
# =============================================================================
# Función para crear loaders de una tarea dada (respeta cfg del preset)
# =============================================================================
def make_loader_fn(task, batch_size, seed):
    """Build the loaders of a single task via ``make_loaders_from_csvs``.

    Args:
        task: Dict with a ``"name"`` (sub-folder of ``RAW`` holding the task's
            data) and ``"paths"`` (mapping with ``"train"``, ``"val"`` and
            ``"test"`` CSV paths).
        batch_size: Batch size forwarded to the loader factory.
        seed: Seed forwarded to the loader factory.

    Returns:
        Whatever ``make_loaders_from_csvs`` returns; the notebook unpacks it
        as ``(train, val, test)`` loaders.

    The encoder, ``T`` and ``gain`` are read from the notebook-level ``cfg``
    and the image transform from the notebook-level ``tfm``.
    """
    base = RAW / task["name"]
    csv_kwargs = {
        f"{split}_csv": Path(task["paths"][split])
        for split in ("train", "val", "test")
    }
    return make_loaders_from_csvs(
        base_dir=base,
        batch_size=batch_size,
        encoder=cfg["encoder"],
        T=cfg["T"],
        gain=cfg["gain"],
        tfm=tfm,
        seed=seed,
        **csv_kwargs,
    )

In [None]:
# =============================================================================
# Toggle del método a ejecutar: 'ewc' o 'naive'
#   - EWC: usa lambda > 0 (p.ej., 1e10)
#   - NAIVE: lambda = 0 (sin penalización)
# =============================================================================
METHOD = "ewc"   # "ewc" | "naive"

# Fail fast on typos: any other value would silently run the unpenalised
# baseline (λ = 0 and no method passed to the trainer) under a misleading tag.
if METHOD not in ("ewc", "naive"):
    raise ValueError(f"METHOD must be 'ewc' or 'naive', got {METHOD!r}")

# Penalty strength: only EWC uses a non-zero λ.
EWC_LAMBDA = 1e11 if METHOD == "ewc" else 0.0

# Global seeding so that NAIVE and EWC runs start from the same state
set_seeds(SEED)

# Unique output tag (preset, method, λ for EWC, encoder, seed) so that a new
# run does not overwrite the results of a previous configuration.
RUN_TAG = f"{preset}_{METHOD}" + (f"_lam_{EWC_LAMBDA:.0e}" if METHOD=="ewc" else "")
RUN_TAG += f"_{cfg['encoder']}"
RUN_TAG += f"_seed_{SEED}"

print("Método:", METHOD, "| λ:", EWC_LAMBDA, "| tag:", RUN_TAG)

Método: naive | λ: 0.0 | tag: fast_naive_rate


In [8]:
# =============================================================================
# Instanciación del modelo, método y carpeta de salida
# =============================================================================
# Fresh network for the continual run, plus the EWC helper wrapping it.
# The helper is built for both methods; λ is 0.0 when METHOD is "naive".
model2 = SNNVisionRegressor(in_channels=1, lif_beta=0.95)
ewc = EWC(model2, EWCConfig(fisher_batches=25, lambd=float(EWC_LAMBDA)))
loss_fn = torch.nn.MSELoss()
tcfg2 = TrainConfig(
    epochs=cfg["epochs"],
    batch_size=cfg["batch_size"],
    lr=cfg["lr"],
    amp=cfg["amp"],
)

# Output folder for this run, created up front.
outc = ROOT.joinpath("outputs", f"continual_{RUN_TAG}")
outc.mkdir(parents=True, exist_ok=True)
outc

PosixPath('/home/cesar/proyectos/TFM_SNN/outputs/continual_fast_naive_rate')

In [None]:
# =============================================================================
# Bucle continual: circuito1 -> circuito2 (entrenar, consolidar si EWC, evaluar)
# =============================================================================
# Continual protocol: for each task in task_list, train on it, optionally
# consolidate (EWC), evaluate on its own test split, then re-evaluate every
# previously seen task. Results are collected in `results` and saved as JSON.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = {}   # task name -> metrics dict
seen = []      # (task name, test loader) for every task trained so far

for i, t in enumerate(task_list):
    name = t["name"]
    print(f"Tarea {i+1}: {name}")
    tr, va, te = make_loader_fn(t, tcfg2.batch_size, seed=SEED)

    # Train on the current task; the EWC object is passed only if METHOD == 'ewc'
    _ = train_supervised(
        model2, tr, va, loss_fn, tcfg2,
        outc/f"task_{i+1}_{name}",
        method=ewc if METHOD == "ewc" else None
    )

    # Fisher consolidation (EWC only)
    # NOTE(review): Fisher is estimated on the validation loader (`va`), not on
    # the training loader — confirm this is intentional.
    if METHOD == "ewc":
        print("Estimando Fisher...")
        ewc.estimate_fisher(va, loss_fn, device=device)

    # Post-task evaluation on the task's own test split
    te_mae, te_mse = eval_loader(te, model2, device)
    results[name] = {"test_mae": te_mae, "test_mse": te_mse}
    seen.append((name, te))

    # Re-evaluate earlier tasks to measure forgetting; seen[:-1] skips the
    # task that was just trained.
    for pname, p_loader in seen[:-1]:
        p_mae, p_mse = eval_loader(p_loader, model2, device)
        results[pname][f"after_{name}_mae"] = p_mae
        results[pname][f"after_{name}_mse"] = p_mse

# Save the continual-run results
with open(outc/"continual_results.json","w",encoding="utf-8") as f:
    json.dump(results, f, indent=2)
print("OK:", outc/"continual_results.json")
results

Tarea 1: circuito1


Epoch 1/2: 100%|██████████| 437/437 [00:07<00:00, 54.84it/s, loss=0.0729]  
Epoch 2/2: 100%|██████████| 437/437 [00:07<00:00, 57.18it/s, loss=4.11e-6] 


Tarea 2: circuito2


Epoch 1/2: 100%|██████████| 129/129 [00:02<00:00, 56.95it/s, loss=0.0485]
Epoch 2/2: 100%|██████████| 129/129 [00:02<00:00, 47.99it/s, loss=0.0381] 


OK: /home/cesar/proyectos/TFM_SNN/outputs/continual_fast_naive_rate/continual_results.json


{'circuito1': {'test_mae': 0.06532558528406004,
  'test_mse': 0.019158605222068588,
  'after_circuito2_mae': 0.1160996605552012,
  'after_circuito2_mse': 0.020342665440302783},
 'circuito2': {'test_mae': 0.1860753110552256,
  'test_mse': 0.06933023987453332}}

In [20]:
# =============================================================================
# (Opcional) Resumen comparativo de todos los continual_* en outputs/
# =============================================================================
from pathlib import Path
import json, re
import pandas as pd

def parse_exp_name(name: str):
    """Extract preset, method, lambda and encoder from an output-folder name.

    Recognised layout (the ``_lam_<value>`` part only appears for EWC runs)::

        continual_<preset>_<ewc|naive>[_lam_<value>]_<encoder>[_<suffix>...]

    The encoder is a single underscore-free token, so trailing suffixes such
    as ``_seed_42`` are ignored instead of being absorbed into the encoder.

    Args:
        name: Folder name, e.g. ``"continual_fast_ewc_lam_1e+11_rate_seed_42"``.

    Returns:
        Dict with keys ``"preset"``, ``"method"``, ``"lambda"`` and
        ``"encoder"``. ``"lambda"`` is the raw string from the name, or
        ``None`` when absent. Every value is ``None`` when the name does not
        follow the layout.
    """
    m = re.match(
        r"continual_(?P<preset>\w+)_(?P<method>ewc|naive)"
        r"(?:_lam_(?P<lambda>[^_]+))?"
        r"_(?P<enc>[^_]+)",   # [^_]+ (not \w+): stop at the next underscore
        name,
    )
    if not m:
        return {"preset": None, "method": None, "lambda": None, "encoder": None}
    return {
        "preset": m["preset"],
        "method": m["method"],
        "lambda": m["lambda"],
        "encoder": m["enc"],
    }

# Column order of the summary table. Declared up front so that an outputs/
# folder without any continual_* results still yields an (empty) DataFrame
# that has the columns sort_values needs, instead of raising KeyError.
summary_columns = [
    "exp", "preset", "method", "lambda", "encoder",
    "c1_mae", "c1_after_c2_mae",
    "c1_forgetting_mae_abs", "c1_forgetting_mae_rel_%",
    "c2_mae",
]

rows = []
for exp_dir in sorted((ROOT/"outputs").glob("continual_*")):
    results_path = exp_dir/"continual_results.json"
    if not results_path.exists():
        continue
    meta = parse_exp_name(exp_dir.name)
    # Read with the encoding the results are written with (utf-8).
    res = json.loads(results_path.read_text(encoding="utf-8"))

    # Tolerate result files lacking a task or metric: the row is kept with
    # missing values rather than aborting the whole summary.
    c1 = res.get("circuito1", {})
    c1_mae = c1.get("test_mae")
    c1_after = c1.get("after_circuito2_mae")
    have_both = c1_mae is not None and c1_after is not None
    # Forgetting on circuito1 = MAE after learning circuito2 minus MAE right
    # after learning circuito1 (positive means the model got worse).
    c1_forget = (c1_after - c1_mae) if have_both else None
    c1_forget_rel = (c1_forget / c1_mae * 100.0) if (have_both and c1_mae != 0) else None

    c2_mae = res.get("circuito2", {}).get("test_mae")

    rows.append({
        "exp": exp_dir.name, **meta,
        "c1_mae": c1_mae,
        "c1_after_c2_mae": c1_after,
        "c1_forgetting_mae_abs": c1_forget,
        "c1_forgetting_mae_rel_%": c1_forget_rel,
        "c2_mae": c2_mae
    })

# "lambda" is parsed from the folder name as text (e.g. "1e+11"); compare it
# numerically while sorting so that, e.g., 5e+09 sorts before 1e+10.
df = pd.DataFrame(rows, columns=summary_columns).sort_values(
    ["preset", "encoder", "method", "lambda"],
    na_position="last",
    key=lambda col: pd.to_numeric(col, errors="coerce") if col.name == "lambda" else col,
)
df

Unnamed: 0,exp,preset,method,lambda,encoder,c1_mae,c1_after_c2_mae,c1_forgetting_mae_abs,c1_forgetting_mae_rel_%,c2_mae
0,continual_fast_ewc_lam_1e+09_rate,fast,ewc,1000000000.0,rate,0.080127,0.080117,-9.480523e-06,-0.011832,0.179358
1,continual_fast_ewc_lam_1e+10_rate,fast,ewc,10000000000.0,rate,0.08252,0.082522,1.299926e-06,0.001575,0.178743
2,continual_fast_ewc_lam_1e+11_rate,fast,ewc,100000000000.0,rate,0.082494,0.082493,-3.586093e-07,-0.000435,0.178792
3,continual_fast_naive_rate,fast,naive,,rate,0.105127,0.090398,-0.01472912,-14.010812,0.183966
