In [None]:
# =============================================================================
# Imports y setup
# =============================================================================
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

from pathlib import Path
import sys, json, torch

# Locate the repository root: when the kernel runs inside notebooks/, the root is one level up.
ROOT = Path.cwd()
if ROOT.name == "notebooks":
    ROOT = ROOT.parent
# Expose the root on sys.path (only once) so the project's `src` package can be imported.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Utilidades y módulos del proyecto
from src.datasets import AugmentConfig
from src.utils import set_seeds, load_preset, make_loaders_from_csvs
from src.datasets import ImageTransform
from src.models import SNNVisionRegressor
from src.training import TrainConfig, train_supervised, _permute_if_needed
from src.methods.ewc import EWC, EWCConfig

# Device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# A bare `ROOT, device` expression is only rendered when it is the last statement
# of a cell; in the middle of the cell it was silently discarded, so print explicitly.
print(f"ROOT={ROOT} | device={device}")

SEED = 42

# Image transform.
# IMPORTANT: pass the arguments positionally; the original note gives the order as
# (w, h, to_gray, crop_top) and warns that keywords such as target_w/target_h are
# not defined by the class.
tfm = ImageTransform(160, 80, True, None)

def make_snn_model(tfm):
    """Build an ``SNNVisionRegressor`` whose input channels match the transform.

    Args:
        tfm: object exposing a truthy/falsy ``to_gray`` attribute.

    Returns:
        ``SNNVisionRegressor`` created with ``in_channels`` equal to 1 when
        ``tfm.to_gray`` is truthy (grayscale) and 3 otherwise (colour), and
        ``lif_beta=0.95``.
    """
    if tfm.to_gray:
        channels = 1
    else:
        channels = 3
    return SNNVisionRegressor(in_channels=channels, lif_beta=0.95)

# Global PyTorch runtime tuning. Keep the statement order as is: the TF32 flag and
# the matmul-precision setting below overlap.
torch.set_num_threads(4)               # cap torch CPU threads (original note: avoids CPU over-contention while decoding)
torch.backends.cudnn.benchmark = True  # let cuDNN pick the best conv implementation for fixed input sizes

# Allow TF32 for CUDA matmuls and cuDNN convolutions (original note: cheap speed-up on Ampere+ GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# NOTE(review): per the PyTorch docs, "high" allows float32 matmuls to use TF32-class
# reduced-precision internals; it trades precision for speed relative to the default
# "highest" rather than increasing precision. Applies to FP32 kernels when AMP is not used.
torch.set_float32_matmul_precision("high")


GPU_ENCODE = True  # perform the spike encoding (rate/latency/raw) on the GPU
RUN_BENCH = True   # True = run the benchmark; presumably read by a later cell (TODO confirm)

# --- SAFE MODE: switches off every heavy option to stabilise the session ---
SAFE_MODE = False  # set to True to stabilise; keep False for full performance

# DataLoader defaults (performance profile).
NUM_WORKERS = 12
PREFETCH = 2
PIN_MEMORY = True
PERSISTENT = True

# AugmentConfig is already imported at the top of this cell; the import is repeated as in the original.
from src.datasets import AugmentConfig
# --- Augmentation profiles ---
AUG_CFG_LIGHT = AugmentConfig(prob_hflip=0.5, brightness=None, gamma=None, noise_std=0.0)
AUG_CFG_FULL = AugmentConfig(prob_hflip=0.5, brightness=(0.9, 1.1), gamma=(0.95, 1.05), noise_std=0.005)

AUG_CFG = AUG_CFG_LIGHT  # LIGHT for now; try FULL later

# Balancing strategy switches.
USE_OFFLINE_BALANCED = True
USE_ONLINE_BALANCING = False

if SAFE_MODE:
    # Single-process loading. The original notes flag prefetch=None and
    # persistent=False as mandatory when num_workers=0.
    NUM_WORKERS, PREFETCH = 0, None
    PIN_MEMORY = PERSISTENT = False
    # No balancing and no augmentation in safe mode.
    USE_OFFLINE_BALANCED = USE_ONLINE_BALANCING = False
    AUG_CFG = None

# Echo the effective loader settings.
print(
    f"[SAFE_MODE={SAFE_MODE}] workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT}"
)



In [None]:
# =============================================================================
# Verificación de datos (normal y, si existe, balanceado offline)
# =============================================================================
from pathlib import Path

# Dataset locations, relative to the repository root.
RAW = ROOT / "data" / "raw" / "udacity"
PROC = ROOT / "data" / "processed"

RUNS = ["circuito1", "circuito2"]  # adjust if needed

missing = []
for run in RUNS:
    base = PROC / run

    # Mandatory check: the regular train/val/test splits must exist.
    missing.extend(
        str(base / f"{part}.csv")
        for part in ("train", "val", "test")
        if not (base / f"{part}.csv").exists()
    )

    # Optional check: train_balanced.csv (used by the OFFLINE balanced mode).
    p_bal = base / "train_balanced.csv"
    if not p_bal.exists():
        print(f"⚠️  Falta {p_bal}. Si más abajo pones USE_OFFLINE_BALANCED=True, "
              f"ejecuta 01A_PREP_BALANCED.ipynb o el script tools/make_splits_balanced.py")
    else:
        print(f"✓ {p_bal} OK")

# Fail once, listing every mandatory CSV that was not found.
if missing:
    raise FileNotFoundError(
        "Faltan CSV obligatorios (ejecuta 01A_PREP_BALANCED.ipynb o tu pipeline de prep):\n"
        + "\n".join(" - " + m for m in missing)
    )

print("OK: splits 'train/val/test' encontrados.")


In [None]:
from src.utils import make_loaders_from_csvs
# =============================================================================
# Función para crear loaders de una tarea dada (respeta cfg del preset)
# =============================================================================
def make_loader_fn(task, batch_size, encoder, T, gain, tfm, seed,
                   num_workers=NUM_WORKERS, prefetch_factor=PREFETCH,
                   pin_memory=PIN_MEMORY, persistent_workers=PERSISTENT):
    """Create the data loaders for one task by delegating to ``make_loaders_from_csvs``.

    Args:
        task: mapping with a ``"name"`` entry (sub-directory of ``RAW``) and a
            ``"paths"`` mapping holding the ``"train"``, ``"val"`` and ``"test"`` CSVs.
        batch_size, T, gain, tfm, seed: forwarded unchanged.
        encoder: encoder name requested by the caller; replaced by ``"image"``
            when the notebook-level ``GPU_ENCODE`` flag is set.
        num_workers, prefetch_factor, pin_memory, persistent_workers: loader
            settings; their defaults are the notebook constants captured when
            this cell is executed.

    Returns:
        Whatever ``make_loaders_from_csvs`` returns (the notebook unpacks it
        into three loaders: train, val, test).
    """
    name = task["name"]
    csvs = task["paths"]

    # Worker-only options are neutralised when loading happens in the main process.
    has_workers = num_workers > 0
    persistent = persistent_workers and has_workers
    prefetch = prefetch_factor if has_workers else None

    # KEY (original note): with GPU-side encoding the loader must yield plain
    # image batches (B,C,H,W), hence the "image" encoder.
    if GPU_ENCODE:
        loader_encoder = "image"
    else:
        loader_encoder = encoder

    # NOTE(review): `_balance_flag` is not defined in the visible cells; confirm
    # it is defined before the first call to this function.
    loader_kwargs = dict(
        base_dir=RAW / name,
        train_csv=Path(csvs["train"]),
        val_csv=Path(csvs["val"]),
        test_csv=Path(csvs["test"]),
        batch_size=batch_size,
        encoder=loader_encoder,
        T=T,
        gain=gain,
        tfm=tfm,
        seed=seed,
        # ---- WSL stability patch ----
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=persistent,
        prefetch_factor=prefetch,
        aug_train=AUG_CFG,
        balance_train=_balance_flag(csvs["train"]),
        balance_bins=21,
    )
    return make_loaders_from_csvs(**loader_kwargs)

# Echo the effective DataLoader and balancing configuration for this session.
print(
    f"[Loaders] workers={NUM_WORKERS} prefetch={PREFETCH} "
    f"pin={PIN_MEMORY} persistent={PERSISTENT} "
    f"offline_bal={USE_OFFLINE_BALANCED} online_bal={USE_ONLINE_BALANCING}"
)


In [None]:
# === PRUEBA UNIVERSAL: loader -> (T,B,C,H,W) -> forward con fallback AMP ===
import torch, src.training as training

# --- 1) Small loader built with the notebook helper ---
tr, va, te = make_loader_fn(
    task=task_list[0], batch_size=8,
    encoder="rate",  # if the pipeline already returns 4D batches, it is detected below
    T=10, gain=0.5,
    tfm=tfm, seed=SEED,
)

xb, yb = next(iter(tr))
print("batch del loader:", xb.shape, yb.shape)

# --- 2) Bring the batch to (T,B,C,H,W) depending on the input format ---
# Anything other than a 4D or 5D batch is rejected up front.
if xb.ndim not in (4, 5):
    raise RuntimeError(f"Forma inesperada del batch: {xb.shape}")

if xb.ndim == 4:
    # Image batch (B,C,H,W per the original note): enable the runtime encoder and
    # let the training helper encode + permute to (T,B,C,H,W).
    training.set_runtime_encode(
        mode="rate", T=10, gain=0.5,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    x5d = training._permute_if_needed(xb)
    used_runtime_encode = True
    print("dataset 4D; uso encode en GPU y permuto a (T,B,C,H,W)")
else:
    # Already-encoded batch (B,T,C,H,W per the original note): only swap the first two axes.
    x5d = xb.permute(1, 0, 2, 3, 4).contiguous()
    used_runtime_encode = False
    print("dataset ya codificado; solo permuto a (T,B,C,H,W)")

print("x5d.device:", x5d.device, "| shape:", tuple(x5d.shape))

# --- 3) Model and forward pass with automatic AMP fallback ---
# Build the SNN, move it to the target device, then switch it to eval mode.
model = make_snn_model(tfm).to(device)
model = model.eval()

def forward_with_auto_amp(model, x5d, device):
    """Run one inference forward pass, trying CUDA AMP first and falling back to FP32.

    Args:
        model: callable applied to the input tensor.
        x5d: input tensor; it is moved to ``device`` and cast before the call.
        device: target device, given as a ``torch.device`` or a device string.

    Returns:
        Whatever ``model`` returns for the batch. The call runs under
        ``torch.inference_mode()``.

    Raises:
        Any exception raised by the FP32 attempt. Failures of the AMP attempt
        are reported on stdout and trigger the FP32 retry instead.
    """
    target = torch.device(device)

    # Attempt 1: AMP, only when the *target* device is CUDA. Checking only
    # torch.cuda.is_available() would send a CPU target down the fp16 path on
    # any machine that happens to have a GPU.
    if target.type == "cuda" and torch.cuda.is_available():
        try:
            x_amp = x5d.to(target, dtype=torch.float16, non_blocking=True)
            with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
                y = model(x_amp)
            print("[forward] ejecutado con AMP (fp16)")
            return y
        except Exception as e:
            # Deliberate best-effort: report the reason and retry in full precision.
            print("[forward] AMP falló, reintento en FP32. Motivo:", str(e))

    # Attempt 2: FP32 (CPU target, or fallback after an AMP failure).
    x_fp32 = x5d.to(target, dtype=torch.float32, non_blocking=True)
    with torch.inference_mode():
        y = model(x_fp32)
    print("[forward] ejecutado en FP32")
    return y

try:
    yhat = forward_with_auto_amp(model, x5d, device)
    print("yhat:", tuple(yhat.shape))
finally:
    # --- 4) Runtime-encode cleanup (if it was used) ---
    # Done in `finally` so the module-level encoder setting in `training` is
    # reset even when the forward pass raises.
    if used_runtime_encode:
        training.set_runtime_encode(None)
