In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random, math, json, hashlib
from types import SimpleNamespace
from pathlib import Path
from datetime import datetime
from contextlib import nullcontext
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import amp
from tqdm.auto import tqdm
import chinese_calendar as cc
import warnings
from sklearn.neighbors import BallTree

from uni2ts.model.moirai import MoiraiModule


# Device / Seed

if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

SEED = 0
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)


# 경로 & 설정

ROOT = Path(".").resolve()
DATA_DIR = ROOT / "UrbanEV" / "data"
PATHS = {
    "occ": str(DATA_DIR / "occupancy.csv"),
    "dur": str(DATA_DIR / "duration.csv"),
    "vol": str(DATA_DIR / "volume.csv"),
    "e_price": str(DATA_DIR / "e_price.csv"),
    "s_price": str(DATA_DIR / "s_price.csv"),
    "weather": str(DATA_DIR / "weather_central.csv"),
    "inf": str(DATA_DIR / "inf.csv"),
    "adj": str(DATA_DIR / "adj.csv"),
    "dist": str(DATA_DIR / "distance.csv"),
    "poi": str(DATA_DIR / "poi.csv"),
}




Using device: cuda


In [2]:
# Helpers
def build_clock(train_range, test_range):
    start = pd.to_datetime(train_range[0])
    end   = pd.to_datetime(test_range[1]) + pd.Timedelta(hours=23)
    return pd.date_range(start=start, end=end, freq="h")

def read_ts_csv(path):
    df = pd.read_csv(path, index_col=0)
    df.index = pd.to_datetime(df.index)
    return df.sort_index().apply(pd.to_numeric, errors="coerce")

def reindex_local(df, clock):  return df.reindex(clock).fillna(0)
def reindex_price(df, clock):  return df.reindex(clock).ffill().bfill()

def reindex_weather(df, clock):
    out = df.reindex(clock)
    cont_cols = [c for c in out.columns if c.lower() not in ("nrain","rain","n_rain")]
    rain_cols = [c for c in out.columns if c.lower() in ("nrain","rain","n_rain")]
    if cont_cols: out[cont_cols] = out[cont_cols].interpolate("time").ffill().bfill()
    if rain_cols: out[rain_cols] = out[rain_cols].ffill().bfill().clip(lower=0).astype(int)
    return out

def load_inf(path):
    df = pd.read_csv(path, index_col=0)
    cols = ["longitude","latitude","charge_count","area","perimeter"]
    for c in cols: df[c] = pd.to_numeric(df.get(c, np.nan), errors="coerce")
    df.index = df.index.astype(str)
    return df[cols]

def load_matrix(path, as_float=False):
    df = pd.read_csv(path)
    df.index = df.columns.astype(str)
    df.columns = df.columns.astype(str)
    if df.index.duplicated().any():
        print(f"[warn] duplicated index in {path}, keeping first")
        df = df.loc[~df.index.duplicated(keep='first')]
    if df.columns.duplicated().any():
        print(f"[warn] duplicated columns in {path}, keeping first")
        df = df.loc[:, ~df.columns.duplicated(keep='first')]
    df = df.astype(float if as_float else int)
    np.fill_diagonal(df.values, 0)
    return df


def precompute_time(clock):
    hrs  = clock.hour; dows = clock.dayofweek
    hour_sin = np.sin(2*np.pi*hrs/24).astype(np.float32)
    hour_cos = np.cos(2*np.pi*hrs/24).astype(np.float32)
    dow_sin  = np.sin(2*np.pi*dows/7).astype(np.float32)
    dow_cos  = np.cos(2*np.pi*dows/7).astype(np.float32)

    # 주말 또는 중국 공휴일이면 1
    is_offday = np.array([
        float((d.weekday() in [5,6]) or cc.is_holiday(d.date()))
        for d in clock
    ], dtype=np.float32)

    return pd.DataFrame({
        "hour_sin":hour_sin,"hour_cos":hour_cos,
        "dow_sin":dow_sin,"dow_cos":dow_cos,
        "is_offday":is_offday
    }, index=clock)


def prepare_data(cfg):
    clock = build_clock(cfg.data.train_range, cfg.data.test_range)
    occ = reindex_local(read_ts_csv(cfg.data.paths["occ"]), clock)
    dur = reindex_local(read_ts_csv(cfg.data.paths["dur"]), clock)
    vol = reindex_local(read_ts_csv(cfg.data.paths["vol"]), clock)
    epr = reindex_price(read_ts_csv(cfg.data.paths["e_price"]), clock)
    spr = reindex_price(read_ts_csv(cfg.data.paths["s_price"]), clock)
    wth = reindex_weather(read_ts_csv(cfg.data.paths["weather"]), clock)
    inf = load_inf(cfg.data.paths["inf"])
    adj = load_matrix(cfg.data.paths["adj"])
    dist= load_matrix(cfg.data.paths["dist"], as_float=True)

    tables = {"occ":occ,"dur":dur,"vol":vol,"e_price":epr,"s_price":spr,
              "weather":wth,"inf":inf,"adj":adj,"distance":dist,"time":precompute_time(clock)}

    if cfg.data.use_poi and os.path.exists(cfg.data.paths["poi"]):
        tables["poi_raw"] = pd.read_csv(cfg.data.paths["poi"])

    zone_ids = sorted(list(set(occ.columns) & set(epr.columns) & set(spr.columns)))
    if len(zone_ids) != len(set(zone_ids)): raise ValueError("zone_ids duplicated")

    for k in ["occ","dur","vol","e_price","s_price"]:
        tables[k] = tables[k][zone_ids]
    tables["adj"]      = tables["adj"].reindex(index=zone_ids, columns=zone_ids, fill_value=0)
    tables["distance"] = tables["distance"].reindex(index=zone_ids, columns=zone_ids, fill_value=float('inf'))
    tables["inf"]      = tables["inf"].reindex(index=zone_ids)

    datahash  = hashlib.sha256(str(cfg.data.paths).encode()).hexdigest()[:16]
    clockhash = hashlib.sha256(str(clock).encode()).hexdigest()[:16]
    hashes = SimpleNamespace(datahash=datahash, clockhash=clockhash)
    return clock, tables, zone_ids, hashes

def compute_fit_stats(tables, train_times, zone_ids):
    """
    학습(train) 구간에서 표준화에 쓸 통계치 계산.
    - z-score 대상: occ, dur, e_price, s_price, weather(T,P0,P,U,Td)
    - log1p+z 대상(통계만): vol, inf(charge_count/area/perimeter), poi_counts[*]
    """
    stats = {}

    # 1) 시계열: z-score
    for key in ["occ", "dur", "e_price", "s_price"]:
        if key in tables:
            df = tables[key].loc[train_times, zone_ids]
            mu = float(np.nanmean(df.values))
            sd = float(np.nanstd(df.values)) or 1.0
            stats[key] = {"mean": mu, "std": sd}

    # 2) 시계열: log1p + z-score (vol)
    if "vol" in tables:
        df = tables["vol"].loc[train_times, zone_ids]
        vals = np.log1p(df.values.astype(float))
        mu = float(np.nanmean(vals))
        sd = float(np.nanstd(vals)) or 1.0
        stats["log_vol"] = {"mean": mu, "std": sd}

    # 3) 날씨 연속값: z-score  (rain 계열 제외)
    if "weather" in tables:
        wdf = tables["weather"].loc[train_times]
        for c in ["T", "P0", "P", "U", "Td"]:
            if c in wdf.columns:
                col = pd.to_numeric(wdf[c], errors="coerce").values
                mu = float(np.nanmean(col))
                sd = float(np.nanstd(col)) or 1.0
                stats[c] = {"mean": mu, "std": sd}
            else:
                stats[c] = {"mean": 0.0, "std": 1.0}

    # 4) 정적: log1p + z-score (원본은 유지할 것이므로 통계만)
    if "inf" in tables:
        inf = tables["inf"]
        for c in ["charge_count", "area", "perimeter"]:
            if c in inf.columns:
                vals = np.log1p(pd.to_numeric(inf[c], errors="coerce").values.astype(float))
                mu = float(np.nanmean(vals))
                sd = float(np.nanstd(vals)) or 1.0
                stats[f"log_{c}"] = {"mean": mu, "std": sd}

    # 5) POI 카운트: log1p + z-score (있을 때)
    if "poi_counts" in tables:
        poi = tables["poi_counts"]
        for c in poi.columns:
            vals = np.log1p(pd.to_numeric(poi[c], errors="coerce").values.astype(float))
            mu = float(np.nanmean(vals))
            sd = float(np.nanstd(vals)) or 1.0
            stats[f"log_{c}"] = {"mean": mu, "std": sd}

    # 버전 태깅
    stats_json = json.dumps(stats, sort_keys=True)
    stats["norm_version"] = hashlib.sha256(stats_json.encode()).hexdigest()[:16]
    return stats


def standardize_tables(tables, stats):
    """
    tables를 제자리(in-place)에서 표준화.
    - z-score: occ, dur, e_price, s_price, weather(T,P0,P,U,Td)
    - log1p+z: vol, poi_counts[*]
    - inf는 **원본 유지**(area/perimeter 원값 필요). 필요 시 fb_static 등에서 on-the-fly로 변환하세요.
    """
    eps = 1e-8

    # 1) 시계열: z-score
    for key in ["occ", "dur", "e_price", "s_price"]:
        if key in tables and key in stats:
            mu = stats[key]["mean"]; sd = max(stats[key]["std"], eps)
            tables[key] = ((tables[key] - mu) / sd).astype(np.float32)

    # 2) 시계열: log1p + z-score (vol)
    if "vol" in tables and "log_vol" in stats:
        mu = stats["log_vol"]["mean"]; sd = max(stats["log_vol"]["std"], eps)
        df = tables["vol"].astype(float)
        tables["vol"] = ((np.log1p(df) - mu) / sd).astype(np.float32)

    # 3) 날씨 연속값: z-score (rain 계열은 손대지 않음)
    if "weather" in tables:
        w = tables["weather"].copy()
        for c in ["T", "P0", "P", "U", "Td"]:
            if c in w.columns and c in stats:
                mu = stats[c]["mean"]; sd = max(stats[c]["std"], eps)
                w[c] = ((pd.to_numeric(w[c], errors="coerce") - mu) / sd).astype(np.float32)
        tables["weather"] = w

    # 4) POI 카운트: log1p + z-score (있을 때, 모델 입력만 쓸 거라면 덮어써도 OK)
    if "poi_counts" in tables:
        poi = tables["poi_counts"].copy()
        for c in poi.columns:
            key = f"log_{c}"
            if key in stats:
                mu = stats[key]["mean"]; sd = max(stats[key]["std"], eps)
                col = np.log1p(pd.to_numeric(poi[c], errors="coerce").astype(float))
                poi[c] = ((col - mu) / sd).astype(np.float32)
        tables["poi_counts"] = poi



def build_sample_indices(tables, clock, zone_ids, cfg):
    # ---- 파라미터 ----
    L = int(cfg.data.L)
    H = int(cfg.data.H)
    O = int(getattr(cfg.data, "pred_offset", 1))
    zone_ids = list(zone_ids)

    # ---- split 라벨링 (끝-of-day 보정 포함) ----
    split_sr = pd.Series(index=clock, data="none")

    def _eod(ts):
        ts = pd.to_datetime(ts)
        return ts + pd.Timedelta(hours=23) if (ts.hour==0 and ts.minute==0 and ts.second==0) else ts

    tr_s, tr_e = pd.to_datetime(cfg.data.train_range[0]), _eod(cfg.data.train_range[1])
    va_s, va_e = pd.to_datetime(cfg.data.val_range[0]),   _eod(cfg.data.val_range[1])
    te_s, te_e = pd.to_datetime(cfg.data.test_range[0]),  _eod(cfg.data.test_range[1])

    split_sr.loc[(split_sr.index>=tr_s)&(split_sr.index<=tr_e)] = "train"
    split_sr.loc[(split_sr.index>=va_s)&(split_sr.index<=va_e)] = "val"
    split_sr.loc[(split_sr.index>=te_s)&(split_sr.index<=te_e)] = "test"

    # ---- 통계/표준화 ----
    train_times = split_sr[split_sr=="train"].index
    stats = compute_fit_stats(tables, train_times, zone_ids)
    standardize_tables(tables, stats)

    # ====== 유효성 마스크 ======
    occ_fin = pd.DataFrame(np.isfinite(tables["occ"].values), index=clock, columns=zone_ids)
    epr_fin = pd.DataFrame(np.isfinite(tables["e_price"].values), index=clock, columns=zone_ids)
    spr_fin = pd.DataFrame(np.isfinite(tables["s_price"].values), index=clock, columns=zone_ids)

    masks = [occ_fin, epr_fin, spr_fin]

    # dur/vol을 cfg에 따라 추가
    if "dur" in getattr(cfg.data, "moirai_channels", []):
        dur_fin = pd.DataFrame(np.isfinite(tables["dur"].values), index=clock, columns=zone_ids)
        masks.append(dur_fin)

    if "vol" in getattr(cfg.data, "moirai_channels", []):
        vol_fin = pd.DataFrame(np.isfinite(tables["vol"].values), index=clock, columns=zone_ids)
        masks.append(vol_fin)

    # 입력창 L 연속 조건
    input_ok = True
    for m in masks:
        input_ok = input_ok & (m.rolling(L, min_periods=L).sum().eq(L).shift(-(L-1), fill_value=False))

    # 날씨는 zone 축이 없으니 컨티뉴어스 5개 모두 finite → L 연속 후 브로드캐스트
    w5 = tables["weather"][["T","P0","P","U","Td"]]
    w5_fin = pd.Series(np.isfinite(w5.values).all(axis=1), index=clock)
    w_ok_sr = w5_fin.rolling(L, min_periods=L).sum().eq(L).shift(-(L-1), fill_value=False)
    w_ok = pd.DataFrame(np.repeat(w_ok_sr.values[:,None], len(zone_ids), axis=1), index=clock, columns=zone_ids)

    # 타깃창 (t0+O 부터 H개 연속 finite)
    if H == 1:
        tgt_ok = occ_fin.shift(-O, fill_value=False)
    else:
        tgt_ok = (occ_fin.shift(-O, fill_value=False)
                          .rolling(H, min_periods=H).sum().eq(H)
                          .shift(-(H-1), fill_value=False))

    # 종합 유효성
    combined_ok = input_ok & w_ok & tgt_ok

    # ====== 경계(윈도우) 포함을 수식으로 강제 ======
    # 허용 t0 범위: [S, E - max(L-1, O+H-1)]
    Lm1 = pd.Timedelta(hours=L-1)
    Hm1 = pd.Timedelta(hours=H-1)
    Off = pd.Timedelta(hours=O)

    out = {}
    for split, (S, E) in {
        "train": (tr_s, tr_e),
        "val"  : (va_s, va_e),
        "test" : (te_s, te_e),
    }.items():
        upper = E - max(Lm1, Off + Hm1)  # t0 상한
        # 시간 인덱스 기반 부울 시리즈 → DataFrame으로 브로드캐스트
        t0_ok_sr = (clock >= S) & (clock <= upper)                  # ndarray(bool)
        t0_ok = pd.DataFrame(np.repeat(t0_ok_sr[:, None], len(zone_ids), axis=1),
                            index=clock, columns=zone_ids)

        split_ok_sr = (split_sr == split).to_numpy()  # shape (n,)
        split_ok = pd.DataFrame(np.repeat(split_ok_sr[:, None], len(zone_ids), axis=1),
                                index=clock, columns=zone_ids)

        valid = combined_ok & t0_ok & split_ok
        # 평탄화
        flat = valid.stack().reset_index()
        flat = flat[flat[0]].drop(columns=0)
        flat.columns = ["t_start", "zone_id"]

        # 메타
        if flat.empty:
            out[split] = flat.reset_index(drop=True)
            continue

        flat["zone_idx"] = flat["zone_id"].map({z:i for i,z in enumerate(zone_ids)})
        flat["L"] = L
        flat["H"] = H
        flat["occ_only"] = int(getattr(cfg.data, "baseline_occ_only", True))
        flat["split"] = split
        out[split] = flat.reset_index(drop=True)

    return out, stats





In [3]:
# POI utilities 
def compute_zone_radius(inf, beta=0.7, r_min=300, r_max=2000):
    A = inf["area"].to_numpy(); P = inf["perimeter"].to_numpy()
    r_area  = np.sqrt(np.clip(A,0,None)/np.pi)
    r_perim = P/(2*np.pi)
    r = beta*r_area + (1-beta)*r_perim
    return pd.Series(np.clip(r, r_min, r_max), index=inf.index)


def build_poi_counts_fast(inf, poi, radius_m):
    # radians
    z = np.deg2rad(np.c_[inf["latitude"].values, inf["longitude"].values])
    p = np.deg2rad(np.c_[poi["latitude"].values, poi["longitude"].values])

    # BallTree on POI
    tree = BallTree(p, metric="haversine")

    # radius vector in radians
    r_vec = (radius_m.values / 6371000.0).astype(np.float64)

    # single batched query
    idx_lists = tree.query_radius(z, r=r_vec)  # list of index arrays, len = n_zones

    types = poi["primary_types"].str.lower().values
    is_life = np.char.find(types.astype(str), "lifestyle services") >= 0
    is_bres = np.char.find(types.astype(str), "business and residential") >= 0
    is_food = np.char.find(types.astype(str), "food and beverage services") >= 0

    out = np.zeros((len(inf), 3), dtype=np.int32)
    for i, idxs in enumerate(idx_lists):
        if idxs.size == 0: 
            continue
        out[i, 0] = int(is_life[idxs].sum())
        out[i, 1] = int(is_bres[idxs].sum())
        out[i, 2] = int(is_food[idxs].sum())

    return pd.DataFrame(out, index=inf.index,
        columns=["poi_lifestyle","poi_business_residential","poi_food_beverage"])


def ensure_poi_counts(tables, cfg):
    """
    - 캐시된 poi_counts 있으면 로드
    - 없으면 compute_zone_radius + build_poi_counts 돌려 계산 후 저장
    """
    # 캐시 폴더
    cache_dir = Path(getattr(cfg.exec, "poi_shared_dir", "poi_cache_global"))
    cache_dir.mkdir(parents=True, exist_ok=True)

    # 캐시 파일 이름 만들기 (데이터/파라미터 기반 해시)
    meta = {
        "beta": cfg.data.poi_radius_beta,
        "rmin": cfg.data.poi_rmin,
        "rmax": cfg.data.poi_rmax,
        # inf/poi 내용도 포함 (여기서는 간단히 길이와 해시로)
        "n_zones": len(tables["inf"]),
        "n_poi": len(tables["poi_raw"])
    }
    h = hashlib.sha256(json.dumps(meta, sort_keys=True).encode()).hexdigest()[:16]
    cache_file = cache_dir / f"poi_counts_{h}.parquet"

    # 이미 있으면 로드
    if cache_file.exists():
        poi_counts = pd.read_parquet(cache_file)
        tables["poi_counts"] = poi_counts
        return poi_counts

    # 없으면 새로 계산
    radius = compute_zone_radius(
        tables["inf"],
        beta=cfg.data.poi_radius_beta,
        r_min=cfg.data.poi_rmin,
        r_max=cfg.data.poi_rmax
    )
    poi_counts = build_poi_counts(tables["inf"], tables["poi_raw"], radius)

    # 저장
    poi_counts.to_parquet(cache_file)
    tables["poi_counts"] = poi_counts
    return poi_counts



In [4]:
# Moirai cache (keying fixed)
KEY_COLS = ["zone_id","t_start_iso","L","occ_only","c_sig","model_id","psig","pool","clock_hash","norm_version","datahash"]


def make_moirai_keys(si, hashes, stats, cfg):
    df = si.copy()
    df["t_start_iso"] = df["t_start"].dt.strftime("%Y-%m-%dT%H:%M:%S")
    df["L"] = df["L"].astype(int)
    df["occ_only"] = df["occ_only"].astype(int)
    df["model_id"] = cfg.moirai.model_id
    df["psig"] = "-".join(map(str, cfg.moirai.patch_sizes))
    df["pool"] = cfg.moirai.pool
    df["clock_hash"] = hashes.clockhash
    df["norm_version"] = stats["norm_version"]
    df["datahash"] = hashes.datahash
    chan = getattr(cfg.data, "moirai_channels", ["occ"])
    df["c_sig"] = "+".join(chan)  # 예: "occ" or "occ+e_price+s_price"
    return df[KEY_COLS]

@torch.no_grad()
def encode_batch(x_np, backbone, patch_sizes, pool):
    x = torch.from_numpy(x_np).to(DEVICE).float()  # (B, C, L) ; 여기선 C=1
    B, C, L = x.shape
    target = x.transpose(1,2)  # (B, L, C)
    observed_mask = torch.ones_like(target, dtype=torch.bool)
    prediction_mask = torch.zeros(B, L, dtype=torch.bool, device=DEVICE)
    sample_id = torch.arange(B, device=DEVICE).unsqueeze(1).expand(B, L)
    time_id   = torch.arange(L, device=DEVICE).unsqueeze(0).expand(B, L)
    variate_id= torch.zeros(B, L, dtype=torch.long, device=DEVICE)

    hs_list = []
    for ps in patch_sizes:
        patch_size = torch.full((B, L), ps, dtype=torch.long, device=DEVICE)
        try:
            hs = backbone.encode(target=target, observed_mask=observed_mask, patch_size=patch_size)
        except AttributeError:
            output = backbone(target=target, observed_mask=observed_mask,
                              prediction_mask=prediction_mask, sample_id=sample_id,
                              time_id=time_id, variate_id=variate_id, patch_size=patch_size)
            hs = output if isinstance(output, torch.Tensor) else getattr(output, "mean", output.sample())
        if hs.dim()==3: hs = hs.mean(dim=1) if pool=="mean" else hs.max(dim=1)[0]
        elif hs.dim()!=2: raise ValueError(f"Unexpected hs shape: {hs.shape}")
        hs_list.append(hs)
    emb = torch.stack(hs_list).mean(dim=0)  # (B, D_emb)
    return emb.cpu().numpy()

def _sanitize(s:str)->str:
    return "".join(ch if (str(ch).isalnum() or ch in "-._") else "_" for ch in str(s))

def _make_cache_path_with_dim(cfg, model_id, patch_sig, pool, emb_dim, hashes, stats):
    c_sig = "+".join(getattr(cfg.data, "moirai_channels", ["occ"]))
    sig = {"model":model_id,"psig":patch_sig,"pool":pool,"edim":str(emb_dim),
           "c":c_sig, 
           "clock":hashes.clockhash[:8],"norm":stats["norm_version"][:8],"data":hashes.datahash[:8]}
    fname = "moirai_" + "_".join(f"{k}-{_sanitize(v)}" for k,v in sig.items()) + ".pt"
    base_dir = str(Path(getattr(cfg.exec, "moirai_shared_dir", "")))
    return os.path.join(base_dir, fname)


def ensure_moirai_cache(keys_df, tables, cfg, hashes, stats):
    # 빈 키 가드
    if len(keys_df) == 0:
        empty_df = pd.DataFrame(columns=keys_df.columns)
        edim_default = int(getattr(cfg.moirai, "emb_dim", 0))
        empty_tensor = torch.zeros(0, edim_default, dtype=torch.float32)
        return empty_df, empty_tensor, edim_default

    # 사용할 채널 결정
    chan = getattr(cfg.data, "moirai_channels", None)
    if chan is None:
        chan = ["occ"] if bool(getattr(cfg.data, "baseline_occ_only", True)) else ["occ","e_price","s_price"]

    # 간단한 헬퍼: (C,L) 스택
    def stack_channels(tables, z, t0, L, chan_names):
        win = slice(t0, t0 + pd.Timedelta(hours=L-1))
        arrs = []
        for name in chan_names:
            # 각 채널은 시간×zone 테이블이어야 함(예: occ/e_price/s_price)
            series = tables[name].loc[win, z].values.astype(np.float32)
            arrs.append(series)  # (L,)
        return np.stack(arrs, axis=0)  # (C, L)

    backbone = MoiraiModule.from_pretrained(cfg.moirai.model_id).to(DEVICE).eval()
    patch_sig = "-".join(map(str, cfg.moirai.patch_sizes))
    pool = cfg.moirai.pool

    # ---- 임베딩 차원 프로빙 (1개 샘플로 실행) ----
    first_row = keys_df.iloc[0]
    z, t0 = first_row["zone_id"], pd.to_datetime(first_row["t_start_iso"])
    x_probe = stack_channels(tables, z, t0, cfg.data.L, chan)     # (C,L)
    x_probe = x_probe[None, ...]                                  # (1,C,L)
    probe_np = encode_batch(x_probe, backbone, cfg.moirai.patch_sizes, pool)  # (1,E)
    emb_dim = int(probe_np.shape[1])

    # ---- 캐시 경로(파일명에 c_sig 포함되도록 앞서 구현 완료라고 가정) ----
    cache_file = _make_cache_path_with_dim(
        cfg, cfg.moirai.model_id, patch_sig, pool, emb_dim, hashes, stats
    )
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    # ---- 로드 or 초기화 ----
    if os.path.exists(cache_file):
        moirai_df, moirai_emb = torch.load(cache_file, map_location="cpu")
        if moirai_emb.ndim != 2 or moirai_emb.shape[1] != emb_dim:
            print(f"[MoiraiCache] dim mismatch in file: {moirai_emb.shape} vs emb_dim={emb_dim} → reset")
            moirai_df = pd.DataFrame(columns=keys_df.columns)
            moirai_emb = torch.zeros(0, emb_dim, dtype=torch.float32)
    else:
        moirai_df = pd.DataFrame(columns=keys_df.columns)
        moirai_emb = torch.zeros(0, emb_dim, dtype=torch.float32)

    # ---- 누락 키만 인코딩 ----
    key2row = {tuple(row[KEY_COLS]): i for i, (_, row) in enumerate(moirai_df.iterrows())}
    todo_rows = [r for _, r in keys_df.iterrows() if tuple(r[KEY_COLS]) not in key2row]

    if len(todo_rows) > 0:
        B = cfg.moirai.batch_size
        for i in tqdm(range(0, len(todo_rows), B), desc="[MoiraiCache] encode"):
            batch_rows = todo_rows[i:i+B]
            xs = []
            for r in batch_rows:
                z = r["zone_id"]
                t0 = pd.to_datetime(r["t_start_iso"])
                x_cl = stack_channels(tables, z, t0, cfg.data.L, chan)  # (C,L)
                xs.append(x_cl)
            x_np = np.stack(xs).astype(np.float32)  # (b, C, L)

            y_np = encode_batch(x_np, backbone, cfg.moirai.patch_sizes, pool)  # (b, E)
            new_emb = torch.from_numpy(y_np).float()   # CPU

            if new_emb.shape[1] != moirai_emb.shape[1]:
                print(f"[MoiraiCache] runtime dim mismatch: cache={moirai_emb.shape[1]} vs new={new_emb.shape[1]} → reset")
                moirai_df = pd.DataFrame(columns=keys_df.columns)
                moirai_emb = torch.zeros(0, new_emb.shape[1], dtype=torch.float32)

            batch_df = pd.DataFrame(batch_rows, columns=keys_df.columns)
            moirai_df = pd.concat([moirai_df, batch_df], ignore_index=True)
            moirai_emb = torch.cat([moirai_emb, new_emb], dim=0)

            torch.save((moirai_df, moirai_emb), cache_file + ".tmp")
        os.replace(cache_file + ".tmp", cache_file)

    # 참고: 경로/행수/차원 출력(원하면 유지)
    # print(f"[MoiraiCache] path={cache_file} | rows={len(moirai_df)} | dim={emb_dim}")

    return moirai_df, moirai_emb, emb_dim



In [5]:
# Feature builders
def fb_local(tables, z, t0, L, occ_only=True):
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    occ = tables["occ"].loc[win, z].values.astype(np.float32)
    return occ[None, :]  # (1, L)

def fb_price(tables, z, t0, L):
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    epr = tables["e_price"].loc[win, z].values.astype(np.float32)
    spr = tables["s_price"].loc[win, z].values.astype(np.float32)
    return np.stack([epr, spr]).astype(np.float32)

def fb_weather(tables, t0, L):
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    w = tables["weather"].loc[win]
    cont = w[["T","P0","P","U","Td"]].values.T.astype(np.float32)
    rain = w.get("nRAIN", pd.Series(0, index=w.index)).clip(0,3).values
    rain_oh = np.zeros((4, L), dtype=np.float32); rain_oh[rain, np.arange(L)] = 1.0
    return np.concatenate([cont, rain_oh], axis=0).astype(np.float32)

def fb_spatial(tables, z, t0, L, zone_ids, spatial_usekeys=("occ", "dur", "vol"), use_adj=True):
    """
    공간 피처 (mean/gap/ratio) × len(use_keys)
    use_keys: 고려할 시계열 키 목록 ("occ","dur","vol")
    use_adj : True면 adj 행렬 곱해서 인접 zone만 고려
    """
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    zi = zone_ids.index(z)
    D = tables["distance"].values.astype(np.float32)
    W = 1.0 / (D + 1e-6)
    np.fill_diagonal(W, 0.0)

    if use_adj and "adj" in tables:
        A = tables["adj"].values.astype(np.float32)
        np.fill_diagonal(A, 0.0)
        W *= A   # adj=0이면 제외

    # 행 정규화
    W = W / (W.sum(axis=1, keepdims=True) + 1e-6)
    w_i = W[zi]

    feats = []
    for key in spatial_usekeys:
        if key not in tables:
            continue
        mat = tables[key].loc[win].values.astype(np.float32)  # (L, Z)
        mean_val = mat @ w_i
        self_val = mat[:, zi]
        gap   = self_val - mean_val
        ratio = np.clip(self_val / (mean_val + 1e-6), 0, 5)
        feats.append(np.stack([mean_val, gap, ratio]))

    return np.concatenate(feats, axis=0).astype(np.float32)  # (3*len(use_keys), L)



def fb_static(tables, z, stats=None):
    row = tables["inf"].loc[z]

    def _safe_standardize_log1p(val, stat_key):
        v = pd.to_numeric(val, errors="coerce")
        if not np.isfinite(v):
            return np.float32(0.0)  # mean-impute in standardized space

        # 🔧 음수 방어: log1p 전에 0으로 클립
        v = float(v)
        if v < 0:
            v = 0.0

        l = math.log1p(v)
        s = stats.get(stat_key, {"mean": 0.0, "std": 1.0})
        mu, sd = float(s.get("mean", 0.0)), max(float(s.get("std", 1.0)), 1e-8)
        return np.float32((l - mu) / sd)


    # meta: 3개 모두 log1p + z-score (NaN→0.0)
    cc  = _safe_standardize_log1p(row.get("charge_count", np.nan), "log_charge_count")
    ar  = _safe_standardize_log1p(row.get("area", np.nan),          "log_area")
    pe  = _safe_standardize_log1p(row.get("perimeter", np.nan),     "log_perimeter")
    meta = np.array([cc, ar, pe], dtype=np.float32)

    # poi: 있으면 동일하게 처리, 없으면 0
    if "poi_counts" in tables and isinstance(tables["poi_counts"], pd.DataFrame):
        poi_row = tables["poi_counts"].loc[z]
        vals = []
        for c in poi_row.index:
            stat_key = f"log_{c}"
            if stats is not None and stat_key in stats:
                vals.append(_safe_standardize_log1p(poi_row[c], stat_key))
            else:
                # 통계가 없으면 0.0으로
                vals.append(np.float32(0.0))
        poi = np.array(vals, dtype=np.float32)
    else:
        poi = np.zeros(3, dtype=np.float32)

    return np.concatenate([meta, poi]).astype(np.float32)



def fb_time(tables, t0, L):
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    return tables["time"].loc[win].values.T.astype(np.float32)

def fb_target(tables, z, t0, H, offset=None):
    if offset is None:
        offset = getattr(cfg.data, "pred_offset", 1)  # 기본 1
    start = t0 + pd.Timedelta(hours=offset)
    end   = start + pd.Timedelta(hours=H-1)
    if H == 1:
        y = tables["occ"].loc[start, z].astype(np.float32)
        y = np.array([y], dtype=np.float32)
        mask = np.isfinite(y)
        return y, mask
    else:
        y_idx = slice(start, end)
        y = tables["occ"].loc[y_idx, z].values.astype(np.float32)
        mask = np.isfinite(y)
        return y, mask




In [6]:
#### 위에꺼 디버그
# 새로/수정: EVDataset
# ------------------------------
# EVDataset (수정 버전)
# ------------------------------
class EVDataset(torch.utils.data.Dataset):
    def __init__(self, si_df: pd.DataFrame, tables, zone_ids, cfg, stats=None):
        self.si = si_df.reset_index(drop=True)
        self.tables = tables
        self.zone_ids = list(zone_ids)
        self.cfg = cfg
        self.stats = stats  # ★ fb_static에서 사용

        # spatial 옵션 (cfg.data.spatial 네임스페이스)
        self.spatial_keys = tuple(getattr(cfg.data, "spatial_usekeys", ("occ",)))
        self.spatial_use_adj = bool(getattr(cfg.data, "spatial_use_adj", True))
        self.spatial_dist_thresh = getattr(getattr(cfg.data, "spatial", SimpleNamespace()), "dist_thresh", None)
        self.spatial_fallback_k = getattr(getattr(cfg.data, "spatial", SimpleNamespace()), "fallback_k", 5)

    def __len__(self):
        return len(self.si)

    def __getitem__(self, idx):
        r = self.si.iloc[idx]
        z   = r["zone_id"]
        t0  = pd.to_datetime(r["t_start"])
        L   = int(r["L"])
        H   = int(self.cfg.data.H)

        # ------ build inputs ------
        x_local   = fb_local  (self.tables, z, t0, L)                         
        x_weather = fb_weather(self.tables, t0, L)                            
        x_price   = fb_price  (self.tables, z, t0, L)                        

        # ★ fb_spatial 확장 버전 호출
        x_spatial = fb_spatial(
            self.tables, z, t0, L, self.zone_ids,
            spatial_usekeys=self.spatial_keys,
            use_adj=self.spatial_use_adj
        )

        x_time    = fb_time   (self.tables, t0, L)                            
        # ★ fb_static에 stats 전달
        x_static  = fb_static (self.tables, z, stats=self.stats)              

        # ------ target ------
        y, mask = fb_target(self.tables, z, t0, H, offset=getattr(self.cfg.data, "pred_offset", 1))

        # 디버그 유효성 체크 (원래 코드 그대로 두어도 OK)
        def _chk(name, arr, expect_ndim=None):
            if not np.isfinite(arr).all():
                bad = np.sum(~np.isfinite(arr))
                raise ValueError(f"[EVDataset] non-finite in {name} (bad={bad}) | idx={idx} z={z} t0={t0}")
            if expect_ndim is not None and arr.ndim != expect_ndim:
                raise ValueError(f"[EVDataset] ndim mismatch for {name}: got {arr.ndim}, expect {expect_ndim}")
        _chk("x_local",   x_local,   expect_ndim=2)
        _chk("x_weather", x_weather, expect_ndim=2)
        _chk("x_price",   x_price,   expect_ndim=2)
        _chk("x_spatial", x_spatial, expect_ndim=2)
        _chk("x_time",    x_time,    expect_ndim=2)
        _chk("x_static",  x_static,  expect_ndim=1)
        _chk("y", y, expect_ndim=1)

        has_target = bool(np.isfinite(y[mask]).any()) if mask.any() else False

        moirai_key = {
            "zone_id": z,
            "t_start_iso": t0.strftime("%Y-%m-%dT%H:%M:%S"),
            "L": L,
            "occ_only": int(getattr(self.cfg.data, "baseline_occ_only", True)),
            "c_sig": "+".join(getattr(self.cfg.data, "moirai_channels", ["occ"])),
            "model_id": self.cfg.moirai.model_id,
            "psig": "-".join(map(str, self.cfg.moirai.patch_sizes)),
            "pool": self.cfg.moirai.pool,
        }

        return {
            "x_local":  x_local,
            "x_weather":x_weather,
            "x_price":  x_price,
            "x_spatial":x_spatial,
            "x_time":   x_time,
            "x_static": x_static,
            "y":        y.astype(np.float32),
            "mask":     mask.astype(bool),
            "has_target": has_target,
            "moirai_key": moirai_key,
        }

def _to_f32_on_device(xs):
    return torch.from_numpy(np.stack(xs)).float().to(DEVICE)

def _stack_bool_on_device(xs):
    return torch.from_numpy(np.stack(xs)).to(DEVICE)

def collate_fn_builder_baseline(moirai_emb, key2row, hashes, stats, *, debug_print=False):
    def collate_fn(batch):
        # --- 1) 샘플 레벨 필터: 유효 타깃 있는 샘플만
        valid_batch = [b for b in batch if b.get("has_target", True) and b["mask"].any()]
        if len(valid_batch) == 0:
            # 학습 루프에서 skip 처리할 수 있게 플래그 반환
            if debug_print:
                print("[collate/baseline] skip: no valid samples in this batch")
            return {"skip": True}

        batch = valid_batch

        # --- 2) Moirai 키 구성
        rows = []
        for b in batch:
            mk = dict(b["moirai_key"])
            mk["clock_hash"] = hashes.clockhash
            mk["norm_version"] = stats["norm_version"]
            mk["datahash"] = hashes.datahash
            rows.append([mk[k] for k in KEY_COLS])
        keys_df = pd.DataFrame(rows, columns=KEY_COLS)
        idxs = [key2row[tuple(r)] for _, r in keys_df.iterrows()]

        # --- 3) 스택
        out = {
            "x_local":   _to_f32_on_device([b["x_local"]   for b in batch]),  # (B,1,L)
            "x_weather": _to_f32_on_device([b["x_weather"] for b in batch]),
            "x_price":   _to_f32_on_device([b["x_price"]   for b in batch]),
            "x_spatial": _to_f32_on_device([b["x_spatial"] for b in batch]),
            "x_static":  _to_f32_on_device([b["x_static"]  for b in batch]),
            "x_time":    _to_f32_on_device([b["x_time"]    for b in batch]),
            "y":         _to_f32_on_device([b["y"]         for b in batch]),  # (B,H)
            "mask":      _stack_bool_on_device([b["mask"]  for b in batch]),  # (B,H)
            "h_local_moirai": moirai_emb[idxs].to(DEVICE).float(),            # (B, E)
            "skip": False,
        }

        # --- 4) 배치 유효성 디버그
        sel = out["mask"].bool()
        if debug_print:
            B, H = out["y"].shape
            print(f"[collate/baseline] B={B} H={H} | valid_targets={int(sel.sum().item())}")
        if sel.sum().item() == 0:
            if debug_print:
                print("[collate/baseline] skip: mask has no True")
            out["skip"] = True
            return out

        # NaN/Inf 체크 (y는 mask 기준으로만 검사)
        if not torch.isfinite(out["y"][sel]).all():
            raise ValueError("[collate/baseline] non-finite in y (masked selection)")
        # 입력들 간단 검사(원하면 더 추가)
        for k in ["x_local","x_weather","x_price","x_spatial","x_time","x_static","h_local_moirai"]:
            if not torch.isfinite(out[k]).all():
                raise ValueError(f"[collate/baseline] non-finite in {k}")

        return out
    return collate_fn


def collate_fn_builder_finetune(*, debug_print=False):
    def collate_fn(batch):
        # --- 1) 샘플 레벨 필터
        valid_batch = [b for b in batch if b.get("has_target", True) and b["mask"].any()]
        if len(valid_batch) == 0:
            if debug_print:
                print("[collate/finetune] skip: no valid samples in this batch")
            return {"skip": True}

        batch = valid_batch

        # --- 2) 스택
        out = {
            "x_local":   _to_f32_on_device([b["x_local"]   for b in batch]),
            "x_weather": _to_f32_on_device([b["x_weather"] for b in batch]),
            "x_price":   _to_f32_on_device([b["x_price"]   for b in batch]),
            "x_spatial": _to_f32_on_device([b["x_spatial"] for b in batch]),
            "x_static":  _to_f32_on_device([b["x_static"]  for b in batch]),
            "x_time":    _to_f32_on_device([b["x_time"]    for b in batch]),
            "y":         _to_f32_on_device([b["y"]         for b in batch]),
            "mask":      _stack_bool_on_device([b["mask"]  for b in batch]),
            "skip": False,
        }

        # --- 3) 배치 유효성 디버그
        sel = out["mask"].bool()
        if debug_print:
            B, H = out["y"].shape
            print(f"[collate/finetune] B={B} H={H} | valid_targets={int(sel.sum().item())}")
        if sel.sum().item() == 0:
            if debug_print:
                print("[collate/finetune] skip: mask has no True")
            out["skip"] = True
            return out

        if not torch.isfinite(out["y"][sel]).all():
            raise ValueError("[collate/finetune] non-finite in y (masked selection)")
        for k in ["x_local","x_weather","x_price","x_spatial","x_time","x_static"]:
            if not torch.isfinite(out[k]).all():
                raise ValueError(f"[collate/finetune] non-finite in {k}")

        return out
    return collate_fn


In [7]:
class EVForecastModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        D   = cfg.model.D_model
        p   = cfg.model.dropout

        # ----- local branch (mode별) -----
        if cfg.model.mode == "baseline":
            # emb_dim은 ensure_moirai_cache() 이후 cfg.model.D_moirai가 세팅되어 있음
            D_m = cfg.model.D_moirai
            self.adapter = nn.Sequential(
                nn.Linear(D_m, D), nn.GELU(), nn.Dropout(p), nn.Linear(D, D)
            )
            self.moirai_backbone = None
            self.local_proj = None  # baseline에선 사용하지 않음
        else:  # finetune
            # 학습 가능한 Moirai
            self.moirai_backbone = MoiraiModule.from_pretrained(cfg.moirai.model_id)
            for prm in self.moirai_backbone.parameters():
                prm.requires_grad = True
            # D_m은 모듈 내부 설정에 따라 달라질 수 있어 → 첫 forward에서 lazy init
            self.adapter = None
            self.local_proj = None  # 첫 forward에서 Moirai 출력 차원 보고 생성

        # ----- 동적 채널 계산 -----
        # weather: 연속 5(T,P0,P,U,Td) + rain 원핫 4 = 9
        cin_weather = 9
        # price: e_price, s_price = 2
        cin_price   = 2
        # time: hour_sin, hour_cos, dow_sin, dow_cos, is_offday = 5
        cin_time    = 5

        # spatial: (mean, gap, ratio) × len(use_keys)
        # 우선순위: cfg.data.spatial.use_keys -> cfg.data.spatial_usekeys -> 기본 ("occ",)
        spatial_use_keys = None
    
        if hasattr(cfg.data, "spatial_usekeys"):
            spatial_use_keys = tuple(cfg.data.spatial_usekeys)
        else:
            spatial_use_keys = ("occ",)  # fallback

        cin_spatial = 3 * max(1, len(spatial_use_keys))

        # ----- other encoders -----
        ks = cfg.model.KERNEL_SIZE; pad = ks//2
        def conv_block(cin):
            return nn.Sequential(
                nn.Conv1d(cin, 64, ks, padding=pad), nn.GELU(),
                nn.Conv1d(64, 64, ks, padding=pad), nn.GELU()
            )

        self.enc_weather = conv_block(cin_weather)
        self.enc_price   = conv_block(cin_price)
        self.enc_spatial = conv_block(cin_spatial)
        self.enc_time    = conv_block(cin_time)

        # static: charge_count/area/perimeter + poi 3개 = 6
        self.enc_static  = nn.Sequential(nn.Linear(6, 128), nn.GELU(), nn.Dropout(p), nn.Linear(128, D))

        self.proj_weather= nn.Linear(64, D)
        self.proj_price  = nn.Linear(64, D)
        self.proj_spatial= nn.Linear(64, D)
        self.proj_time   = nn.Linear(64, D)

        def align(): return nn.Sequential(nn.LayerNorm(D), nn.Dropout(p), nn.Linear(D, D))
        self.align_local = align(); self.align_weather = align(); self.align_price = align()
        self.align_spatial = align(); self.align_static = align(); self.align_time = align()

        self.mha  = nn.MultiheadAttention(D, cfg.model.nhead, dropout=p, batch_first=True)
        self.head = nn.Linear(D, self.cfg.data.H)
        self.use_softplus = bool(getattr(cfg.model, "nonneg_head", False))
        self.softplus = nn.Softplus()

    def forward(self, batch):
        # ===== 1) Local (mode별 경로) =====
        if self.cfg.model.mode == "baseline":
            # 캐시에서 온 h_local_moirai → adapter
            h_local = self.adapter(batch["h_local_moirai"])
        else:
            # x_local: (B,1,L) → Moirai -> (B,T,Dm or B,Dm)
            x_local = batch["x_local"]  # (B,1,L)
            B, C, L = x_local.shape
            target = x_local.transpose(1, 2)  # (B,L,1)
            observed_mask = torch.ones_like(target, dtype=torch.bool)
            prediction_mask = torch.zeros(B, L, dtype=torch.bool, device=target.device)
            sample_id = torch.arange(B, device=target.device).unsqueeze(1).expand(B, L)
            time_id   = torch.arange(L, device=target.device).unsqueeze(0).expand(B, L)
            variate_id= torch.zeros(B, L, dtype=torch.long, device=target.device)
            patch_size = torch.full((B, L), self.cfg.moirai.patch_sizes[0], dtype=torch.long, device=target.device)

            # ── lazy init: local_proj가 아직 없으면 모양만 한 번 확인
            if self.local_proj is None:
                with torch.no_grad():
                    out_probe = self.moirai_backbone(
                        target=target, observed_mask=observed_mask,
                        prediction_mask=prediction_mask, sample_id=sample_id,
                        time_id=time_id, variate_id=variate_id, patch_size=patch_size
                    )
                    hs_probe = out_probe if isinstance(out_probe, torch.Tensor) else getattr(out_probe, "mean", out_probe.sample())
                    if hs_probe.dim() == 3:  # (B,T,Dm)
                        hs_probe = hs_probe.mean(dim=1)  # cfg.moirai.pool 반영하려면 여기서 분기
                    D_m = int(hs_probe.shape[1])
                # 이제 투사기 생성(D_m→D)
                D = self.cfg.model.D_model; p = self.cfg.model.dropout
                self.local_proj = nn.Sequential(nn.Linear(D_m, D), nn.GELU(), nn.Dropout(p), nn.Linear(D, D)).to(target.device)

            # 실제 forward는 그래프를 유지해야 하므로 재계산(프로브는 no_grad로만 모양 확인용)
            out = self.moirai_backbone(
                target=target, observed_mask=observed_mask,
                prediction_mask=prediction_mask, sample_id=sample_id,
                time_id=time_id, variate_id=variate_id, patch_size=patch_size
            )
            hs = out if isinstance(out, torch.Tensor) else getattr(out, "mean", out.sample())
            if hs.dim() == 3:
                hs = hs.mean(dim=1)  # pool="mean" 기본
            h_local = self.local_proj(hs)

        # ===== 2) Other modalities =====
        h_weather = self.proj_weather(self.enc_weather(batch["x_weather"]).mean(-1))
        h_price   = self.proj_price  (self.enc_price  (batch["x_price"  ]).mean(-1))
        h_spatial = self.proj_spatial(self.enc_spatial(batch["x_spatial"]).mean(-1))
        h_time    = self.proj_time   (self.enc_time   (batch["x_time"   ]).mean(-1))
        h_static  = self.enc_static(batch["x_static"])

        # ===== 3) Align + fusion + head =====
        def gate(align, h): return torch.sigmoid(align(h)) * h
        h_local_a   = gate(self.align_local,   h_local)
        h_weather_a = gate(self.align_weather, h_weather)
        h_price_a   = gate(self.align_price,   h_price)
        h_spatial_a = gate(self.align_spatial, h_spatial)
        h_static_a  = gate(self.align_static,  h_static)
        h_time_a    = gate(self.align_time,    h_time)

        tokens = torch.stack([h_weather_a, h_price_a, h_spatial_a, h_static_a, h_time_a], dim=1)
        q = h_local_a.unsqueeze(1)
        h, _ = self.mha(q, tokens, tokens)
        y_hat = self.head(h.squeeze(1))
        return self.softplus(y_hat) if self.use_softplus else y_hat


In [8]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def loss_epoch(model, dataloader, criterion=F.mse_loss, optimizer=None, scheduler=None, sched_per_batch=False, cfg=None, desc=None):
    """
    한 epoch 학습/검증 (optimizer 유무로 train/eval).
    - loss는 mask=True 위치만 평균(reduction='mean').
    - epoch 평균은 유효타깃수 가중 평균.
    - clip_grad는 cfg.train.clip_grad 사용(없으면 미적용).
    - scheduler는 옵션: per-batch(step_per_batch=True) 또는 per-epoch(train() 바깥).
    """
    is_train = optimizer is not None
    model.train() if is_train else model.eval()

    clip_grad = None
    if cfg is not None and hasattr(cfg, "train") and hasattr(cfg.train, "clip_grad"):
        clip_grad = cfg.train.clip_grad
        # 0 또는 False면 비적용

    total_loss_w = 0.0
    total_valid  = 0
    batch_losses = []

    loop = tqdm(dataloader, desc=desc or ("train" if is_train else "valid"))
    for batch in loop:
        if batch.get("skip", False):
            continue

        if is_train:
            optimizer.zero_grad(set_to_none=True)

        # forward & masked loss
        y_hat = model(batch)             # (B,H)
        sel   = batch["mask"].bool()     # (B,H)
        n     = int(sel.sum().item())
        if n == 0:
            continue

        loss = criterion(y_hat[sel], batch["y"][sel])  # 평균 손실

        if is_train:
            loss.backward()
            if clip_grad and clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            if scheduler is not None and sched_per_batch:
                scheduler.step()

        total_loss_w += float(loss.item()) * n
        total_valid  += n
        batch_losses.append(float(loss.item()))
        loop.set_postfix(loss=f"{loss.item():.4f}", n_valid=n)

    epoch_loss = total_loss_w / max(1, total_valid)
    return epoch_loss, batch_losses

import os, re, time, torch

def _find_latest_epoch_ckpt(ckpt_dir):
    """
    ckpt_dir에서 epoch_####.ckpt / .pt 중 가장 번호가 큰 파일 경로 반환. 없으면 None.
    """
    if not os.path.isdir(ckpt_dir):
        return None
    files = os.listdir(ckpt_dir)
    cand = []
    for fn in files:
        m = re.match(r"epoch_(\d+)\.(ckpt|pt)$", fn)
        if m:
            cand.append((int(m.group(1)), os.path.join(ckpt_dir, fn)))
    if not cand:
        return None
    cand.sort(key=lambda x: x[0])
    return cand[-1][1]  # latest path

def train(model, train_loader, val_loader, criterion,
          optimizer, save_best_path, config=None, resume_path=None, scheduler=None, sched_per_batch=False):
    """
    - best 기준: validation loss (criterion)
    - metrics: cfg.train.metrics 사용 (없으면 기본 4종)
    - scheduler: 주어지면 사용(None이면 무시)
    - resume: cfg.exec.new_model_train==False면 run_dir/checkpoints에서 최신 epoch_* 로드(모델+옵티마이저(+스케줄러)).
              resume_path가 따로 주어지면 그걸 우선 사용.
    """
    cfg = config or {}
    EPOCHS   = getattr(cfg.train, "epochs", 5)
    METRICS  = tuple(getattr(cfg.train, "metrics", ["RMSE","MAPE","RAE","MAE"]))
    RUN_DIR  = cfg.out.run_dir
    CKPT_DIR = cfg.out.ckpt_dir
    NEW_TRAIN = bool(getattr(cfg.exec, "new_model_train", True))

    model = model.to(DEVICE)

    # ---- FULL RESUME (A)
    start_epoch = 0
    if not NEW_TRAIN:
        # 우선순위: 명시 resume_path > 최신 epoch_* 자동탐색
        latest = resume_path if resume_path and os.path.exists(resume_path) else _find_latest_epoch_ckpt(CKPT_DIR)
        if latest:
            print(f"▶️ Full resume from: {latest}")
            ckpt = torch.load(latest, map_location=DEVICE)
            model.load_state_dict(ckpt["model_state_dict"])
            if "optimizer_state_dict" in ckpt:
                optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if scheduler and ckpt.get("scheduler_state_dict") is not None:
                try:
                    scheduler.load_state_dict(ckpt["scheduler_state_dict"])
                except Exception:
                    print("[warn] scheduler state incompatible; ignoring")
            start_epoch = int(ckpt.get("epoch", 0))
        else:
            print("ℹ️ new_model_train=False 이지만 체크포인트가 없어 새로 시작합니다.")

    history = {"train_epoch": [], "train_iter": [], "val_epoch": [], "val_metrics": []}
    best_val_loss = float("inf")
    best_epoch = 0

    t0 = time.time()
    for ep in range(start_epoch, EPOCHS):
        eidx = ep + 1
        lr = optimizer.param_groups[0]["lr"]
        print(f"\n[Epoch {eidx}/{EPOCHS}] LR={lr:.3e}")

        # train
        train_loss, train_batches = loss_epoch(
            model, train_loader, criterion=criterion,
            optimizer=optimizer, scheduler=scheduler, sched_per_batch=sched_per_batch, cfg=cfg, desc="train"
        )
        history["train_epoch"].append(train_loss)
        history["train_iter"].append(train_batches)
        print(f" Train Loss: {train_loss:.4f} | iters: {len(train_batches)}")

        # valid
        val_loss, _ = loss_epoch(model, val_loader, criterion=criterion, optimizer=None, cfg=cfg, desc="valid")
        history["val_epoch"].append(val_loss)
        metrics = compute_metrics(model, val_loader, metrics=METRICS)
        history["val_metrics"].append(metrics)
        print(" Val  Loss: {:.4f} | {}".format(val_loss, " | ".join([f"{k}:{v:.4f}" for k,v in metrics.items()])))

        # best by val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = eidx
            os.makedirs(os.path.dirname(save_best_path) or ".", exist_ok=True)
            torch.save({
                "epoch": best_epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
                "train_loss": train_loss,
                "val_loss": best_val_loss,
                "val_metrics": metrics
            }, save_best_path)
            print(f" ✅ Best model updated @ epoch {best_epoch} (ValLoss {best_val_loss:.4f})")

        # per-epoch ckpt
        ep_ckpt = cfg.out.epoch_ckpt_tmpl.format(eidx)  # 예: ".../epoch_0001.ckpt"
        torch.save({
            "epoch": eidx,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_metrics": metrics
        }, ep_ckpt)

        # scheduler per-epoch step (per-batch면 loss_epoch에서 이미 처리)
        if scheduler is not None and not sched_per_batch:
            scheduler.step()

    print(f"\nDone. Best epoch={best_epoch} | best val loss={best_val_loss:.4f} | time={time.time()-t0:.1f}s")
    return history

@torch.no_grad()
def compute_metrics(model, dataloader, metrics=("RMSE","MAPE","RAE","MAE"), eps=1e-8):
    model.eval()
    agg = {m: 0.0 for m in metrics}
    total_valid = 0

    for batch in dataloader:
        if batch.get("skip", False):
            continue
        y_hat = model(batch)
        y     = batch["y"]
        sel   = batch["mask"].bool()
        n     = int(sel.sum().item())
        if n == 0:
            continue

        diff   = (y_hat - y)[sel]
        absdif = diff.abs()
        sqdif  = diff.pow(2)
        y_sel  = y[sel]

        rmse = torch.sqrt(sqdif.mean()).item()
        mae  = absdif.mean().item()
        mape = (absdif / (y_sel.abs() + eps)).mean().item() * 100.0
        rae  = (absdif.sum().item()) / ((y_sel - y_sel.mean()).abs().sum().item() + eps)

        if "RMSE" in metrics: agg["RMSE"] += rmse * n
        if "MAE"  in metrics: agg["MAE"]  += mae  * n
        if "MAPE" in metrics: agg["MAPE"] += mape * n
        if "RAE"  in metrics: agg["RAE"]  += rae  * n
        total_valid += n

    return {m: (agg[m] / max(1, total_valid)) for m in metrics}

@torch.no_grad()
def test(model, test_loader, criterion=F.mse_loss, metrics=("RMSE","MAPE","RAE","MAE")):
    model.eval()
    total_loss_w = 0.0
    total_valid  = 0
    for batch in tqdm(test_loader, desc="test"):
        if batch.get("skip", False):
            continue
        out = model(batch)
        sel = batch["mask"].bool()
        n   = int(sel.sum().item())
        if n == 0:
            continue
        loss = criterion(out[sel], batch["y"][sel])
        total_loss_w += float(loss.item()) * n
        total_valid  += n
    test_loss = total_loss_w / max(1, total_valid)

    m = compute_metrics(model, test_loader, metrics=metrics)
    print("Test:", " | ".join([f"{k}:{v:.4f}" for k,v in m.items()]), f"| Loss:{test_loss:.4f}")
    return {"loss": test_loss, **m}


In [9]:
def run_train(cfg):
    # 1) Load & align data
    clock, tables, zone_ids, hashes = prepare_data(cfg)

    # 2) (Optional) POI counts via cache (fast, idempotent)
    if getattr(cfg.data, "use_poi", False) and os.path.exists(cfg.data.paths.get("poi", "")):
        try:
            ensure_poi_counts(tables, cfg)  # you implemented (uses build_poi_counts_fast inside)
        except ModuleNotFoundError:
            print("[warn] scikit-learn not installed; skip POI")
            cfg.data.use_poi = False

    # 3) Sample indices + standardization (tables standardized in-place)
    si, stats = build_sample_indices(tables, clock, zone_ids, cfg)
    print(f"[norm_version] {stats['norm_version']}")

    # 4) Collate function & (optional) Moirai cache for baseline mode
    mode = str(getattr(cfg.model, "mode", "baseline")).lower()
    if mode == "baseline":
        # create Moirai cache for train+val to avoid misses during training
        si_train = si.get("train", pd.DataFrame())
        si_val   = si.get("val",   pd.DataFrame())
        si_tv    = pd.concat([si_train, si_val], ignore_index=True) if not si_train.empty or not si_val.empty else pd.DataFrame(columns=["t_start","zone_id","L","H","occ_only","split"])
        keys_tv  = make_moirai_keys(si_tv, hashes, stats, cfg) if not si_tv.empty else pd.DataFrame(columns=KEY_COLS)

        moirai_df, moirai_emb, emb_dim = ensure_moirai_cache(keys_tv, tables, cfg, hashes, stats)
        cfg.model.D_moirai = int(emb_dim)

        # build lookup index for cached embeddings
        key2row = {}
        for i, (_, r) in enumerate(moirai_df.iterrows()):
            key2row[tuple(r[KEY_COLS])] = i

        collate_fn = collate_fn_builder_baseline(moirai_emb, key2row, hashes, stats)
    else:
        collate_fn = collate_fn_builder_finetune()

    # 5) Datasets & Loaders (pass stats to keep fb_static normalization consistent)
    train_df = si.get("train", pd.DataFrame())
    val_df   = si.get("val",   pd.DataFrame())
    train_ds = EVDataset(train_df, tables, zone_ids, cfg, stats=stats)
    val_ds   = EVDataset(val_df,   tables, zone_ids, cfg, stats=stats)

    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.train.batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    # 6) Model & optimizer
    model = EVForecastModel(cfg).to(DEVICE)
    opt_name = str(getattr(cfg.train, "optimizer", "adamw")).lower()
    if opt_name == "adamw":
        optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

    # 7) Scheduler (inline)
    scheduler = None
    sched_per_batch = False
    sched_name = str(getattr(cfg.train, "scheduler", "none") or "none").lower()
    if sched_name == "onecycle":
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=cfg.train.lr,
            epochs=cfg.train.epochs,
            steps_per_epoch=max(1, len(train_loader))
        )
        sched_per_batch = True
    elif sched_name in ("steplr", "step"):
        step  = getattr(cfg.train, "lr_step", None)
        gamma = getattr(cfg.train, "lr_gamma", None)
        if step and gamma:
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step, gamma=gamma)
        else:
            print("[warn] StepLR needs cfg.train.lr_step & lr_gamma — skipping scheduler.")

    # 8) Full resume when exec.new_model_train == False
    resume_path = None
    if not bool(getattr(cfg.exec, "new_model_train", True)):
        latest_num, latest_path = -1, None
        if os.path.isdir(cfg.out.ckpt_dir):
            for fn in os.listdir(cfg.out.ckpt_dir):
                if fn.startswith("epoch_") and (fn.endswith(".ckpt") or fn.endswith(".pt")):
                    try:
                        n = int(fn.split("_")[1].split(".")[0])
                        if n > latest_num:
                            latest_num, latest_path = n, os.path.join(cfg.out.ckpt_dir, fn)
                    except Exception:
                        pass
        if latest_path:
            resume_path = latest_path
            print(f"▶️ Full resume from: {resume_path}")
            ckpt = torch.load(resume_path, map_location=DEVICE)
            model.load_state_dict(ckpt["model_state_dict"])
            if "optimizer_state_dict" in ckpt:
                optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if scheduler is not None and ckpt.get("scheduler_state_dict") is not None:
                try:
                    scheduler.load_state_dict(ckpt["scheduler_state_dict"])
                except Exception:
                    print("[warn] scheduler state incompatible; ignoring")
        else:
            print("ℹ️ new_model_train=False 이지만 체크포인트가 없어 새로 시작합니다.")

    # 9) Train (best by val loss)
    history = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=torch.nn.functional.mse_loss,
        optimizer=optimizer,
        save_best_path=cfg.out.best_ckpt,
        config=cfg,
        resume_path=resume_path,
        scheduler=scheduler,
        sched_per_batch=sched_per_batch
    )

    # 10) Return prepared bundle for reuse
    return {
        "history": history,
        "clock": clock,
        "tables": tables,
        "zone_ids": zone_ids,
        "hashes": hashes,
        "stats": stats,
        "si": si
    }


In [10]:
import os
import torch
from torch.utils.data import DataLoader

def run_test(cfg, prepared=None, mode="best", ckpt_path=None, batch_size=None, drop_last=False, num_workers=0):
    """
    mode:
      - "best":  cfg.out.best_ckpt 로드
      - "last":  cfg.out.last_ckpt 로드(없으면 ckpt_dir의 최대 epoch_* 자동탐색)
      - "path":  인자로 주는 ckpt_path 로드
    """
    # ---------- 0) sanity ----------
    mode = str(mode).lower()
    assert mode in ("best","last","path"), "mode must be one of {'best','last','path'}"
    if mode == "path":
        assert ckpt_path and os.path.exists(ckpt_path), f"ckpt_path not found: {ckpt_path}"

    # ---------- 1) 데이터/전처리 준비 ----------
    if prepared is not None:
        clock  = prepared["clock"]
        tables = prepared["tables"]
        zone_ids = prepared["zone_ids"]
        hashes = prepared["hashes"]
        stats  = prepared["stats"]
        si     = prepared["si"]
    else:
        clock, tables, zone_ids, hashes = prepare_data(cfg)

        # (옵션) POI 캐시/계산
        if getattr(cfg.data, "use_poi", False) and os.path.exists(cfg.data.paths.get("poi","")):
            try:
                ensure_poi_counts(tables, cfg)
            except ModuleNotFoundError:
                print("[warn] scikit-learn not installed; skip POI"); cfg.data.use_poi = False

        # 샘플 인덱스 + 표준화
        si, stats = build_sample_indices(tables, clock, zone_ids, cfg)

    # 테스트 분할 체크
    if "test" not in si or si["test"].empty:
        raise RuntimeError("[run_test] no test samples found in si['test'].")

    # ---------- 2) Collate & (baseline일 때) Moirai 캐시 ----------
    model_mode = str(getattr(cfg.model, "mode", "baseline")).lower()
    if model_mode == "baseline":
        keys_te = make_moirai_keys(si["test"], hashes, stats, cfg)
        moirai_df, moirai_emb, emb_dim = ensure_moirai_cache(keys_te, tables, cfg, hashes, stats)
        cfg.model.D_moirai = int(emb_dim)

        key2row = {}
        for i, (_, r) in enumerate(moirai_df.iterrows()):
            key2row[tuple(r[KEY_COLS])] = i

        collate_fn = collate_fn_builder_baseline(moirai_emb, key2row, hashes, stats)
    else:
        collate_fn = collate_fn_builder_finetune()

    # ---------- 3) Test Loader ----------
    bs = int(batch_size or cfg.train.batch_size)
    test_ds = EVDataset(si["test"], tables, zone_ids, cfg, stats=stats)
    test_loader = DataLoader(
        test_ds,
        batch_size=bs,
        shuffle=False,
        drop_last=drop_last,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    # ---------- 4) 체크포인트 선택 ----------
    load_path = None
    if mode == "best":
        load_path = cfg.out.best_ckpt
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"[run_test] best_ckpt not found: {load_path}")
    elif mode == "last":
        lp = cfg.out.last_ckpt
        if os.path.exists(lp):
            load_path = lp
        else:
            # ckpt_dir에서 가장 큰 epoch_* 탐색
            ckpt_dir = cfg.out.ckpt_dir
            best_n, best_path = -1, None
            if os.path.isdir(ckpt_dir):
                for fn in os.listdir(ckpt_dir):
                    if fn.startswith("epoch_") and (fn.endswith(".ckpt") or fn.endswith(".pt")):
                        try:
                            n = int(fn.split("_")[1].split(".")[0])
                            if n > best_n:
                                best_n, best_path = n, os.path.join(ckpt_dir, fn)
                        except Exception:
                            pass
            if best_path is None:
                raise FileNotFoundError("[run_test] no last checkpoint found (last_ckpt missing and no epoch_*.ckpt).")
            load_path = best_path
    else:  # mode == "path"
        load_path = ckpt_path

    # ---------- 5) 모델 로드 ----------
    model = EVForecastModel(cfg).to(DEVICE)
    ckpt = torch.load(load_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # ---------- 6) 평가 ----------
    # test()는 masked MSE로 loss 계산 + compute_metrics로 4종 지표 리턴
    result = test(
        model=model,
        test_loader=test_loader,
        criterion=torch.nn.functional.mse_loss,
        metrics=tuple(getattr(cfg.train, "metrics", ["RMSE","MAPE","RAE","MAE"]))
    )

    # 메타정보 추가
    out = {
        "mode": mode,
        "ckpt_path": load_path,
        "n_test_samples": len(test_ds),
        "batch_size": bs,
        **result,
    }

    return out


In [11]:
cfg = SimpleNamespace(
    exec = SimpleNamespace(
        new_model_train = True,
        poi_shared_dir = "poi_cache_global",
        moirai_shared_dir = "moirai_cache_global",
    ),
    data = SimpleNamespace(
        L = 24, H = 1, pred_offset=3,
        baseline_occ_only = True,
        moirai_channel = ["occ"], # occ, dur, etc...
        train_range = ("2022-09-01", "2022-12-31"),
        val_range   = ("2023-01-01", "2023-01-31"),
        test_range  = ("2023-02-01", "2023-02-28"),
        paths = PATHS,
        use_poi = True,
        poi_radius_beta = 0.7, poi_rmin = 300.0, poi_rmax = 2000.0,
        sample_stride = 1, min_spatial_neighbors = 1,
        spatial_usekeys = ('occ','dur','vol')
    ),
    moirai = SimpleNamespace(
        model_id = "Salesforce/moirai-1.0-R-base",
        patch_sizes = [1],
        pool = "mean",
        emb_dim = 768,
        batch_size = 64,
        cache_file = "moirai_cache.pt",   # 실제 저장은 out.embed_dir 아래
    ),
    model = SimpleNamespace(
        D_moirai = 768,     # ← 실제 임베딩 차원으로 런타임에 갱신됨
        D_model  = 256,
        dropout  = 0.1,
        nhead    = 8,
        head_hidden = 512,
        nonneg_head = False,  # 표준화 타깃 → 음수 허용
        mode = "baseline",
        HIDDEN = 64, KERNEL_SIZE = 5,
        fusion_layers = 1,
        head_kind = "linear",
    ),
    train = SimpleNamespace(
        epochs = 2, batch_size = 64,
        lr = 1e-3, weight_decay = 1e-4,
        optimizer = "adamw",
        scheduler = None ,#"onecycle",
        clip_grad = 1.0,
        early_stop_patience = 5,
        metrics = ["RMSE","MAPE","RAE", "MAE", ],
    ),
    out = SimpleNamespace(
        run_id = "baseline_occOnly",
        run_dir = "runs/baseline_occOnly",
        history_dir = "runs/baseline_occOnly/history",
        ckpt_dir = "runs/baseline_occOnly/checkpoints",
        embed_dir = "runs/baseline_occOnly/embeddings",
        artifacts_dir = "runs/baseline_occOnly/artifacts",
        config_dump = "runs/baseline_occOnly/config.yaml",
        train_log_csv = "runs/baseline_occOnly/history/train_log.csv",
        summary_json = "runs/baseline_occOnly/history/summary.json",
        best_ckpt = "runs/baseline_occOnly/checkpoints/best.ckpt",
        last_ckpt = "runs/baseline_occOnly/checkpoints/last.ckpt",
        epoch_ckpt_tmpl = "runs/baseline_occOnly/checkpoints/epoch_{:04d}.ckpt"
    ),
)

for d in [cfg.out.run_dir, cfg.out.history_dir, cfg.out.ckpt_dir, cfg.out.embed_dir, cfg.out.artifacts_dir]:
    os.makedirs(d, exist_ok=True)

In [13]:
prepared = run_train(cfg)


[norm_version] 2635e1472fea4ea7


[MoiraiCache] encode:  56%|█████▌    | 8662/15581 [2:23:45<3:38:34,  1.90s/it]

In [100]:
# 1) 베스트 체크포인트로 테스트
res_best = run_test(cfg, prepared=prepared, mode="best")




  moirai_df, moirai_emb = torch.load(cache_file, map_location="cpu")
[MoiraiCache] encode: 100%|██████████| 108/108 [00:16<00:00,  6.54it/s]
  ckpt = torch.load(load_path, map_location=DEVICE)
test: 100%|██████████| 108/108 [00:17<00:00,  6.27it/s]


Test: RMSE:0.3231 | MAPE:111.6814 | RAE:0.3321 | MAE:0.2017 | Loss:0.1117


In [None]:
# 2) 라스트 체크포인트로 테스트
res_last = run_test(cfg, prepared=prepared_bundle, mode="last")

# 3) 특정 에포크 경로로 테스트
res_ep2 = run_test(cfg, prepared=prepared_bundle, mode="path",
                   ckpt_path="runs/baseline_occOnly/checkpoints/epoch_0002.ckpt")

# 1) Best 모델로 테스트
res_best = run_eval(cfg,
                    prepared["si"], prepared["tables"], prepared["zone_ids"],
                    prepared["hashes"], prepared["stats"],
                    which="best")




In [None]:
# 2) Last 모델로 테스트
res_last = run_eval(cfg,
                    prepared["si"], prepared["tables"], prepared["zone_ids"],
                    prepared["hashes"], prepared["stats"],
                    which="last")

In [None]:
# 3) 임의 에포크 경로로 테스트 (예: epoch_0002)
some_ckpt = cfg.out.epoch_ckpt_tmpl.format(2)
if os.path.exists(some_ckpt):
    res_path = run_eval(cfg,
                        prepared["si"], prepared["tables"], prepared["zone_ids"],
                        prepared["hashes"], prepared["stats"],
                        which="path", path_override=some_ckpt)
else:
    print(f"[Info] {some_ckpt} not found; skip path-based eval.")
