# 03 · Entrenamiento y evaluación (EWC baseline, preset FAST)

In [None]:
from pathlib import Path; import json, sys, torch
ROOT=Path.cwd().parent; sys.path.append(str(ROOT))
from src.utils import set_seeds, load_preset, make_loaders_from_csvs, ImageTransform
from src.models import SNNVisionRegressor
from src.training import TrainConfig, train_supervised
from src.methods.ewc import EWC, EWCConfig

# Seed through the project helper and load the hyper-parameter preset that
# every later step of this notebook reads from `cfg`.
set_seeds(42)
preset = 'fast'
cfg = load_preset(ROOT / 'configs' / 'presets.yaml', preset)
print('Preset:', cfg)

# Data directories, resolved relative to the project root.
PROC = ROOT / 'data' / 'processed'
RAW = ROOT / 'data' / 'raw' / 'udacity'

# Shared transform handed to every loader built below.
# NOTE(review): the positional arguments are passed unchanged; their meaning is
# defined by ImageTransform in src.utils -- confirm there before changing them.
tfm = ImageTransform(160, 80, True, None)

# SUPERVISED: train a single model on one run, with no continual-learning
# method attached (method=None).
run = 'circuito1'
train_csv = PROC / run / 'train.csv'
val_csv = PROC / run / 'val.csv'
test_csv = PROC / run / 'test.csv'

# Explicit check instead of `assert`: assertions are removed under `python -O`,
# which would let a missing split surface later as a less helpful error.
if not train_csv.exists():
    raise FileNotFoundError('Ejecuta 01_DATA_QC_PREP.ipynb')

# Loader settings (batch size, encoder, T, gain) all come from the preset.
# test_loader is built here but not used by the supervised run below.
train_loader, val_loader, test_loader = make_loaders_from_csvs(
    RAW / run, train_csv, val_csv, test_csv,
    batch_size=cfg['batch_size'], encoder=cfg['encoder'], T=cfg['T'], gain=cfg['gain'], tfm=tfm,
)

model = SNNVisionRegressor(in_channels=1, lif_beta=0.95)
loss_fn = torch.nn.MSELoss()
tcfg = TrainConfig(
    epochs=cfg['epochs'],
    batch_size=cfg['batch_size'],
    lr=cfg['lr'],
    amp=cfg['amp'],
)
out_dir = ROOT / 'outputs' / f'supervised_{preset}_ewc0'
print('Entrenando SUPERVISED...')
# The return value of train_supervised is intentionally discarded.
_ = train_supervised(model, train_loader, val_loader, loss_fn, tcfg, out_dir, method=None)
print('OK:', out_dir)

# CONTINUAL (c1 -> c2) with EWC
# Read the task description file written to the processed-data directory.
with open(PROC / 'tasks.json', 'r', encoding='utf-8') as f:
    tasks = json.load(f)
def make_loader_fn(task, batch):
    """Build the data loaders for one continual-learning task.

    Args:
        task: Mapping with a 'name' entry (used as the sub-directory of RAW)
            and a 'paths' entry holding the 'train', 'val' and 'test' CSV paths.
        batch: Batch size forwarded to the loader builder.

    Returns:
        Whatever make_loaders_from_csvs returns; the caller unpacks it as
        (train, val, test) loaders.
    """
    data_root = RAW / task['name']
    csvs = task['paths']
    train_csv, val_csv, test_csv = (Path(csvs[split]) for split in ('train', 'val', 'test'))
    # Encoder settings come from the notebook-level preset, not from the task.
    return make_loaders_from_csvs(
        data_root, train_csv, val_csv, test_csv,
        batch_size=batch, encoder=cfg['encoder'], T=cfg['T'], gain=cfg['gain'], tfm=tfm,
    )
# One {'name', 'paths'} entry per task, in the order listed in tasks.json.
task_list = [
    {'name': n, 'paths': tasks['splits'][n]}
    for n in tasks['tasks_order']
]

# Separate model instance for the continual run, handed to the EWC method.
model2 = SNNVisionRegressor(in_channels=1, lif_beta=0.95)
ewc = EWC(
    model2,
    EWCConfig(lambd=1e10, fisher_batches=25),
)

# Device is obtained from the training module's own helper.
from src.training import _device
device = _device()

tcfg2 = TrainConfig(
    epochs=cfg['epochs'],
    batch_size=cfg['batch_size'],
    lr=cfg['lr'],
    amp=cfg['amp'],
)

# Output directory for the continual run; created up front.
outc = ROOT / 'outputs' / f'continual_{preset}_ewc'
outc.mkdir(parents=True, exist_ok=True)

# results: task name -> metrics dict; seen: (task name, test loader) pairs
# for every task trained so far.
results = {}
seen = []
def eval_loader(loader):
    """Return the sample-weighted (MAE, MSE) of model2 over `loader`.

    Each batch contributes its mean absolute / squared error multiplied by
    len(y), and the totals are divided by the number of samples seen.

    The model is put in eval mode for the duration of the pass and its
    previous train/eval mode is restored afterwards, so the next training
    call sees the model exactly as it was left.

    Raises:
        ValueError: if the loader yields no samples.
    """
    was_training = model2.training
    model2.eval()
    mae_sum = mse_sum = n = 0.0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                # NOTE(review): assumes model output and y have matching shapes;
                # a (B, 1) vs (B,) mismatch would broadcast silently -- confirm.
                err = model2(x) - y
                mae_sum += torch.mean(torch.abs(err)).item() * len(y)
                mse_sum += torch.mean(err ** 2).item() * len(y)
                n += len(y)
    finally:
        model2.train(was_training)
    if n == 0:
        raise ValueError('El loader de evaluación no contiene muestras')
    return mae_sum / n, mse_sum / n


# Train on each task in order; after each one, record test metrics for the
# current task and re-evaluate every earlier task to measure forgetting.
for i, t in enumerate(task_list):
    name = t['name']
    tr, va, te = make_loader_fn(t, tcfg2.batch_size)
    print(f'Tarea {i+1}:', name)
    _ = train_supervised(model2, tr, va, torch.nn.MSELoss(), tcfg2, outc/f'task_{i+1}_{name}', method=ewc)

    # NOTE(review): Fisher is estimated on the validation loader here; confirm
    # this is intended rather than the training loader.
    print('Estimando Fisher...')
    ewc.estimate_fisher(va, torch.nn.MSELoss(), device=device)

    te_mae, te_mse = eval_loader(te)
    results[name] = {'test_mae': te_mae, 'test_mse': te_mse}
    seen.append((name, te))

    # seen[:-1] excludes the task that was just trained.
    for pname, p_loader in seen[:-1]:
        p_mae, p_mse = eval_loader(p_loader)
        results[pname][f'after_{name}_mae'] = p_mae
        results[pname][f'after_{name}_mse'] = p_mse

with open(outc/'continual_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
# Report success only once the file has been closed.
print('OK:', outc/'continual_results.json')