In [None]:
from pathlib import Path
import sys
import numpy as np
import pandas as pd
import torch
from torch.amp import autocast
import cv2
from IPython.display import display

# =========================
# 0) BASIC CONFIGURATION
# =========================
ROOT = Path("/home/cesar/proyectos/TFM_SNN").resolve()
_root_str = str(ROOT)
if _root_str not in sys.path:
    # Put the project root first on the import path so the `src.*` imports below resolve.
    sys.path.insert(0, _root_str)

# Circuit / lap to analyse (folder holding driving_log.csv and the IMG/ directory)
CIRCUIT_DIR = ROOT.joinpath("data", "raw", "udacity", "circuito2", "vuelta2")
LOG_PATH = CIRCUIT_DIR.joinpath("driving_log.csv")
IMG_DIR = CIRCUIT_DIR.joinpath("IMG")

# Checkpoint under test: task_2_circuito2 of the chosen training run
RUN_NAME = "continual_std_as-snn_gr_0.5_lam_0.2_ema_0.9_l1_scale_on_std_as_input_gr050_lam0p20_scaling_on_s07_rate_model-PilotNetSNN_66x200_gray_seed_42"
CKPT = ROOT.joinpath("outputs", RUN_NAME, "task_2_circuito2", "best.pth")

# Preset the run was trained with ("std" / "accurate" / "fast")
PRESET = "std"

# How many straight / curve frames to evaluate
N_STRAIGHT = 5
N_CURVE = 5

# Steering labels are in normalised space [-1, 1]
STRAIGHT_MAX_ABS = 0.02  # |steer| at or below this counts as "straight" (≈ ±0.5° if DEG_RANGE is 25)
DEG_RANGE = 25.0         # degrees for |steer| == 1, assuming labels in [-1, 1] map to ±25°

for _label, _value in (
    ("[CFG] ROOT=", ROOT),
    ("[CFG] CIRCUIT_DIR=", CIRCUIT_DIR),
    ("[CFG] CKPT=", CKPT),
    ("[CFG] PRESET=", PRESET),
):
    print(_label, _value)

# =========================
# 1) Imports del proyecto
# =========================
from src.config import load_preset
from src.models import build_model
from src.datasets import ImageTransform, encode_rate as enc_rate, encode_latency as enc_latency

# =========================
# 2) Load preset + model
# =========================
cfg = load_preset(ROOT / "configs" / "presets.yaml", PRESET)
DATA = cfg["data"]
MODEL = cfg["model"]

# Temporal-encoding settings from the preset's "data" section
encoder = DATA["encoder"]
T = int(DATA["T"])
gain = float(DATA["gain"])
# Network input geometry from the preset's "model" section
W, H = int(MODEL["img_w"]), int(MODEL["img_h"])
to_gray = bool(MODEL["to_gray"])

print(f"[CFG] encoder={encoder} T={T} gain={gain}")
print(f"[CFG] img={W}x{H} to_gray={to_gray}")

tfm = ImageTransform(W, H, to_gray=to_gray, crop_top=None)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): beta/threshold are hard-coded here instead of being read from MODEL;
# confirm they match the values used to train the checkpoint.
model = build_model("pilotnet_snn", tfm, beta=0.9, threshold=0.5)

# Checkpoints may be a bare state_dict or a dict wrapping it under one of these keys
# ("state_dict" is preferred when both are present).
state = torch.load(CKPT, map_location=device)
sd = state
if isinstance(state, dict):
    for _wrap_key in ("state_dict", "model_state_dict"):
        if _wrap_key in state:
            sd = state[_wrap_key]
            break
model.load_state_dict(sd)
model.to(device)
model.eval()
print(f"[OK] Modelo cargado en {device}")

# =========================
# 3) Encoder temporal runtime (mismo que sim_drive)
# =========================
def make_encode_runtimer(encoder: str, T: int, gain: float):
    """Return a per-frame temporal encoder for inference (same scheme as sim_drive).

    The returned callable takes one preprocessed frame ``x_img`` and returns:

    - ``"image"``: ``x_img`` with a leading axis of size 1.
    - ``"rate"``: ``enc_rate(x_img, T=T, gain=gain)``.
    - ``"latency"``: ``enc_latency(x_img, T=T)`` (``gain`` is unused).
    - ``"raw"``: ``T`` stacked copies of the frame. A 2-D frame first gets a channel
      axis, so the result is ``(T, 1, H, W)`` for 2-D input and ``(T, *x_img.shape)``
      otherwise.

    Output shapes of ``enc_rate`` / ``enc_latency`` are, per the original comments,
    ``(T, C, H, W)`` or ``(T, H, W)``. TODO confirm against src.datasets.

    Raises:
        ValueError: if ``encoder`` (case-insensitive) is not one of the modes above.
    """
    mode = str(encoder).lower()

    def _image(x_img: torch.Tensor) -> torch.Tensor:
        return x_img.unsqueeze(0)

    def _rate(x_img: torch.Tensor) -> torch.Tensor:
        return enc_rate(x_img, T=T, gain=gain)

    def _latency(x_img: torch.Tensor) -> torch.Tensor:
        return enc_latency(x_img, T=T)

    def _raw(x_img: torch.Tensor) -> torch.Tensor:
        frame = x_img.unsqueeze(0) if x_img.dim() == 2 else x_img
        return frame.unsqueeze(0).expand(T, *frame.shape).contiguous()

    dispatch = {"rate": _rate, "latency": _latency, "raw": _raw, "image": _image}
    if mode not in dispatch:
        raise ValueError(f"Encoder no soportado: {encoder}")
    return dispatch[mode]

# Per-frame temporal encoder built from the preset's encoder / T / gain.
encode_runtime = make_encode_runtimer(encoder, T, gain)

# =========================
# 4) Load driving_log and automatically pick straight / curve frames
# =========================
if not LOG_PATH.exists():
    raise FileNotFoundError(f"No existe driving_log: {LOG_PATH}")

if not IMG_DIR.exists():
    raise FileNotFoundError(f"No existe carpeta IMG: {IMG_DIR}")

# Read without a header (the simulator writes none). Some driving_log.csv copies do
# start with a "center,left,right,steering,..." header row, which would then be read
# as data and break the later astype(float). Coerce the numeric columns and drop rows
# whose steering is not a number, so such a header row is discarded explicitly.
log = pd.read_csv(LOG_PATH, header=None)
log.columns = ["center", "left", "right", "steering", "throttle", "brake", "speed"]
for _num_col in ("steering", "throttle", "brake", "speed"):
    log[_num_col] = pd.to_numeric(log[_num_col], errors="coerce")
_n_rows_read = len(log)
log = log.dropna(subset=["steering"])
if len(log) < _n_rows_read:
    print(f"[WARN] Descartadas {_n_rows_read - len(log)} filas con 'steering' no numérico (¿cabecera?)")
# --- IMPORTANTE: extraer basename desde rutas Windows tipo C:\...\IMG\center_xxx.jpg ---
def _basename_any(p) -> str:
    s = str(p).strip()
    # Normaliza separadores para quedarnos con lo que va después del último / o \
    s = s.replace("\\", "/")
    return s.split("/")[-1]

# Map each centre-camera entry to a file inside the local IMG folder.
log["img_name"] = log["center"].apply(_basename_any)
log["img_path"] = log["img_name"].apply(lambda n: IMG_DIR / n)
log["exists"]   = log["img_path"].apply(lambda p: p.exists())

print("\n[DEBUG] Primeras 5 filas de driving_log con nombres normalizados:")
display(log.head())

if not log.empty:
    print(f"[DEBUG] Ejemplo de img_path calculado: {log.iloc[0]['img_path']} (exists={log.iloc[0]['exists']})")

# If nothing matches, every later step silently runs on an empty frame and the cell
# only ends with "No se generaron resultados". Show expected vs. actual file names
# here so a path / naming mismatch is visible at its source.
n_with_img = int(log["exists"].sum())
if n_with_img == 0 and not log.empty:
    print(f"[WARN] Ninguna imagen del driving_log existe en {IMG_DIR}")
    print("  nombres esperados (ejemplo):", log["img_name"].head(3).tolist())
    print("  ficheros en IMG (ejemplo):", sorted(p.name for p in IMG_DIR.iterdir())[:3])
elif n_with_img < len(log):
    print(f"[WARN] {len(log) - n_with_img} filas sin imagen en IMG; se descartan")

log = log[log["exists"]].copy()
log["abs_steer"] = log["steering"].astype(float).abs()

# Clear straights: |steer| within STRAIGHT_MAX_ABS, kept in recording order.
straight = log[log["abs_steer"] <= STRAIGHT_MAX_ABS].copy()
straight = straight.reset_index(drop=False)  # original log row -> column 'index'

# Strong curves: only frames ABOVE the straight threshold, limited to the
# CURVE_POOL_SIZE largest |steer| values. Sorting the whole log instead would let
# _pick_even (which always includes the last row) pick near-zero-steering frames as
# "curvas". The pool is put back in recording order so the even picks spread over
# the lap instead of over steering magnitude.
CURVE_POOL_SIZE = max(10 * N_CURVE, 0)
curves = (
    log[log["abs_steer"] > STRAIGHT_MAX_ABS]
    .nlargest(CURVE_POOL_SIZE, "abs_steer")
    .sort_index()
    .reset_index(drop=False)  # original log row -> column 'index'
)

print(f"[INFO] muestras totales con imagen={len(log)}")
print(f"[INFO] rectas (|steer| <= {STRAIGHT_MAX_ABS})={len(straight)}")
print(f"[INFO] curvas (top-{CURVE_POOL_SIZE} |steer| > {STRAIGHT_MAX_ABS})={len(curves)}")

def _pick_even(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Coge n filas espaciadas a lo largo del df (para no caer siempre en seguidas)."""
    if df.empty or n <= 0:
        return df.head(0)
    if len(df) <= n:
        return df
    idx = np.linspace(0, len(df)-1, num=n, dtype=int)
    return df.iloc[idx]

# Spread the picks along each pool so we do not evaluate near-identical consecutive frames.
straight_sel, curves_sel = _pick_even(straight, N_STRAIGHT), _pick_even(curves, N_CURVE)

print(f"[SEL] Rectas seleccionadas={len(straight_sel)} | Curvas seleccionadas={len(curves_sel)}")

# =========================
# 5) Inference helper for one driving_log row
# =========================
# One-shot flag: run_one_row prints tensor shapes only for the first sample it processes.
DEBUG_FIRST = dict(done=False)

def run_one_row(row, kind: str):
    """Run the loaded model on one driving_log row and compare it with the label.

    Args:
        row: a row of ``straight_sel`` / ``curves_sel``. Reads ``img_path``,
            ``steering``, ``speed`` and ``index`` (the original log row).
        kind: tag copied into the result ("recta" / "curva" in this notebook).

    Returns:
        dict with the image name, ground-truth vs. predicted steering (normalised,
        and multiplied by ``DEG_RANGE``), their difference in degrees and the logged
        speed. The prediction is clipped to [-1, 1]; the label is not.

    Raises:
        RuntimeError: if OpenCV cannot read the image.
    """
    img_path = row["img_path"]
    frame = cv2.imread(str(img_path), cv2.IMREAD_COLOR)  # BGR image, or None on failure
    if frame is None:
        raise RuntimeError(f"No se pudo leer la imagen: {img_path}")

    # Preprocess with the shared ImageTransform.
    # NOTE(review): the frame is passed in OpenCV BGR order; confirm training fed tfm BGR too.
    x_img = tfm(frame).float()

    # Temporal encoding. A 3-D result gets a channel axis at dim 1; any 4-D result
    # then gets a batch axis of size 1 at dim 1 (5-D results pass through).
    xT = encode_runtime(x_img)
    if xT.dim() == 3:
        xT = xT.unsqueeze(1)
    x5d = xT.unsqueeze(1) if xT.dim() == 4 else xT

    if not DEBUG_FIRST["done"]:
        print("\n[DEBUG] Primera muestra")
        print("  img_name:", img_path.name)
        print("  bgr.shape:", frame.shape)
        print("  x_img.shape:", tuple(x_img.shape),
              "| min/max:", float(x_img.min().item()), float(x_img.max().item()))
        print("  xT.shape:", tuple(xT.shape))
        print("  x5d.shape:", tuple(x5d.shape))
        DEBUG_FIRST["done"] = True

    with torch.no_grad():
        net_in = x5d.to(device, non_blocking=True)
        with autocast("cuda", enabled=torch.cuda.is_available()):
            y = model(net_in)
        # .item() needs a single-element output: a per-timestep output such as
        # (T, B, 1) with T > 1 would raise here rather than being reduced.
        raw_pred = float(y.squeeze().detach().cpu().item())
    steer_pred = float(np.clip(raw_pred, -1.0, 1.0))

    steer_gt = float(row["steering"])
    speed_gt = float(row["speed"])
    steer_gt_deg = steer_gt * DEG_RANGE
    steer_pred_deg = steer_pred * DEG_RANGE

    return dict(
        kind=kind,
        log_index=int(row["index"]),
        img_name=img_path.name,
        steer_gt_norm=steer_gt,
        steer_pred_norm=steer_pred,
        steer_gt_deg=steer_gt_deg,
        steer_pred_deg=steer_pred_deg,
        err_deg=steer_pred_deg - steer_gt_deg,
        speed=speed_gt,
    )

# =========================
# 6) Run inference on the selected straight and curve frames
# =========================
results = []
for kind_tag, selected in (("recta", straight_sel), ("curva", curves_sel)):
    for _, r in selected.iterrows():
        results.append(run_one_row(r, kind_tag))

df_res = pd.DataFrame(results)
if df_res.empty:
    print("[WARN] No se generaron resultados; revisa filtros / rutas.")
else:
    df_res = df_res.sort_values(["kind", "log_index"]).reset_index(drop=True)

    print("\n=== Comparativa frame a frame (rectas / curvas) ===")
    # Scope the float format to this table only: pd.set_option would change the
    # display of every later DataFrame in the notebook session.
    with pd.option_context("display.float_format", lambda v: f"{v:7.3f}"):
        display(df_res)

    def agg_stats(kind):
        """Return (MAE in degrees, max |error| in degrees, n) for ``kind``, or None if absent."""
        sub = df_res[df_res["kind"] == kind]
        if sub.empty:
            return None
        abs_err = sub["err_deg"].abs()
        return abs_err.mean(), abs_err.max(), len(sub)

    for k in ["recta", "curva"]:
        stats = agg_stats(k)
        if stats is None:
            print(f"[STATS] {k}: sin muestras")
        else:
            mae, maxe, n = stats
            print(f"[STATS] {k}: n={n} | MAE_deg={mae:.2f} | max_err_deg={maxe:.2f}")


[CFG] ROOT= /home/cesar/proyectos/TFM_SNN
[CFG] CIRCUIT_DIR= /home/cesar/proyectos/TFM_SNN/data/raw/udacity/circuito2/vuelta2
[CFG] CKPT= /home/cesar/proyectos/TFM_SNN/outputs/continual_std_as-snn_gr_0.5_lam_0.2_ema_0.9_l1_scale_on_std_as_input_gr050_lam0p20_scaling_on_s07_rate_model-PilotNetSNN_66x200_gray_seed_42/task_2_circuito2/best.pth
[CFG] PRESET= std
[CFG] encoder=rate T=18 gain=0.7
[CFG] img=200x66 to_gray=True
[OK] Modelo cargado en cuda
[INFO] muestras totales con imagen=0
[INFO] rectas (|steer| <= 0.02)=0
[INFO] curvas (top-|steer|)=0
[SEL] Rectas seleccionadas=0 | Curvas seleccionadas=0
[WARN] No se generaron resultados; revisa filtros / rutas.
