In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import random, math, json, hashlib
from types import SimpleNamespace
from pathlib import Path
from datetime import datetime
from contextlib import nullcontext
from collections import deque

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch import amp
from tqdm.auto import tqdm
import chinese_calendar as cc

from uni2ts.model.moirai import MoiraiModule


# Device / Seed

# Pick the compute device: Apple MPS first, then CUDA, otherwise CPU.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {DEVICE}")

# Seed every RNG used by the notebook (python, numpy, torch CPU, torch CUDA).
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Ask cuDNN / PyTorch for deterministic kernels (CUBLAS_WORKSPACE_CONFIG is set
# at the top of the cell, before torch is imported).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)


# Paths & configuration

ROOT = Path(".").resolve()
DATA_DIR = ROOT / "UrbanEV" / "data"

# Logical table name -> CSV file name inside DATA_DIR (order is preserved in PATHS).
_CSV_FILES = (
    ("occ", "occupancy.csv"),
    ("dur", "duration.csv"),
    ("vol", "volume.csv"),
    ("e_price", "e_price.csv"),
    ("s_price", "s_price.csv"),
    ("weather", "weather_central.csv"),
    ("inf", "inf.csv"),
    ("adj", "adj.csv"),
    ("dist", "distance.csv"),
    ("poi", "poi.csv"),
)
PATHS = {table: str(DATA_DIR / file_name) for table, file_name in _CSV_FILES}

# Global experiment configuration. Several functions below read this module-level
# `cfg` directly (e.g. fb_target, EVDataset.__getitem__, train_epoch).
cfg = SimpleNamespace(
    exec = SimpleNamespace(
        # NOTE(review): none of the exec.* fields are read by the code in this notebook.
        new_model_train = True,
        poi_shared_dir = "poi_cache_global",
        moirai_shared_dir = "moirai_cache_global",
    ),
    data = SimpleNamespace(
        # L: input window length (hours); H: number of target steps;
        # pred_offset: hours added to the window START t0 to locate the first target.
        L = 24, H = 1, pred_offset=3,
        baseline_occ_only = True,
        # (start, end) date strings per split; see build_sample_indices for how they are applied.
        train_range = ("2022-09-01", "2022-09-03"),
        val_range   = ("2022-09-05", "2022-09-06"),
        test_range  = ("2022-09-07", "2022-09-08"),
        paths = PATHS,
        use_poi = True,
        # Zone radius = clip(beta*r_area + (1-beta)*r_perimeter, rmin, rmax); see compute_zone_radius.
        poi_radius_beta = 0.7, poi_rmin = 300.0, poi_rmax = 2000.0,
        # NOTE(review): sample_stride / min_spatial_neighbors are not read anywhere in this notebook.
        sample_stride = 1, min_spatial_neighbors = 1,
    ),
    moirai = SimpleNamespace(
        model_id = "Salesforce/moirai-1.0-R-base",
        patch_sizes = [1],
        pool = "mean",
        emb_dim = 768,
        batch_size = 64,
        cache_file = "moirai_cache.pt",   # not read below; the actual cache path is derived under out.embed_dir
    ),
    model = SimpleNamespace(
        D_moirai = 768,     # <- overwritten at runtime with the probed embedding dimension (see run_train)
        D_model  = 256,
        dropout  = 0.1,
        nhead    = 8,
        head_hidden = 512,
        nonneg_head = False,  # targets are standardized -> negative values are allowed
        mode = "baseline",
        # NOTE(review): only KERNEL_SIZE is read by EVForecastModel; HIDDEN, head_hidden,
        # mode, fusion_layers and head_kind are not used in this notebook.
        HIDDEN = 64, KERNEL_SIZE = 5,
        fusion_layers = 1,
        head_kind = "linear",
    ),
    train = SimpleNamespace(
        epochs = 4, batch_size = 64,
        lr = 1e-3, weight_decay = 1e-4,
        optimizer = "adamw",
        scheduler = "onecycle",
        clip_grad = 1.0,
        # NOTE(review): early_stop_patience and metrics are declared but not consulted by run_train/run_eval.
        early_stop_patience = 5,
        metrics = ["MSE", "RMSE", "MAE", "sMAPE"],
    ),
    out = SimpleNamespace(
        run_id = "baseline_occOnly",
        run_dir = "runs/baseline_occOnly",
        history_dir = "runs/baseline_occOnly/history",
        ckpt_dir = "runs/baseline_occOnly/checkpoints",
        embed_dir = "runs/baseline_occOnly/embeddings",
        artifacts_dir = "runs/baseline_occOnly/artifacts",
        config_dump = "runs/baseline_occOnly/config.yaml",
        train_log_csv = "runs/baseline_occOnly/history/train_log.csv",
        summary_json = "runs/baseline_occOnly/history/summary.json",
        best_ckpt = "runs/baseline_occOnly/checkpoints/best.ckpt",
        last_ckpt = "runs/baseline_occOnly/checkpoints/last.ckpt",
        epoch_ckpt_tmpl = "runs/baseline_occOnly/checkpoints/epoch_{:04d}.ckpt"
    ),
)

# Create every output directory up front so later saves cannot fail on a missing folder.
for d in [cfg.out.run_dir, cfg.out.history_dir, cfg.out.ckpt_dir, cfg.out.embed_dir, cfg.out.artifacts_dir]:
    os.makedirs(d, exist_ok=True)


In [None]:
# Helpers
def build_clock(train_range, test_range):
    """Build the hourly master time index for the experiment.

    Args:
        train_range: (start, end) pair; only the start is used.
        test_range: (start, end) pair; only the end is used.

    Returns:
        Hourly ``DatetimeIndex`` from the train start through 23 hours after
        the test end timestamp (i.e. 23:00 of the last test day for date-only ends).
    """
    first_hour = pd.to_datetime(train_range[0])
    last_hour = pd.to_datetime(test_range[1]) + pd.Timedelta(hours=23)
    return pd.date_range(first_hour, last_hour, freq="h")

def read_ts_csv(path):
    """Read a time-indexed CSV into a numeric DataFrame.

    The first CSV column becomes a ``DatetimeIndex``; rows are sorted by time
    and every column is converted to numeric, with unparseable cells set to NaN.
    """
    frame = pd.read_csv(path, index_col=0)
    frame.index = pd.to_datetime(frame.index)
    frame = frame.sort_index()
    return frame.apply(pd.to_numeric, errors="coerce")

def reindex_local(df, clock):
    """Align ``df`` to ``clock``; timestamps missing from ``df`` (and NaNs) become 0."""
    aligned = df.reindex(clock)
    return aligned.fillna(0)
def reindex_price(df, clock):
    """Align ``df`` to ``clock``, filling gaps forward first and then backward."""
    aligned = df.reindex(clock)
    forward_filled = aligned.ffill()
    return forward_filled.bfill()

def reindex_weather(df, clock):
    """Align weather observations to ``clock`` and fill the gaps.

    Columns whose lower-cased name is ``nrain``/``rain``/``n_rain`` are treated
    as rain codes: forward/backward filled, clipped at 0 and cast to int.
    All other columns are time-interpolated, then forward/backward filled.
    """
    rain_names = ("nrain", "rain", "n_rain")
    aligned = df.reindex(clock)

    # Partition the columns in a single pass.
    rain_cols, cont_cols = [], []
    for col in aligned.columns:
        (rain_cols if col.lower() in rain_names else cont_cols).append(col)

    if cont_cols:
        interpolated = aligned[cont_cols].interpolate("time")
        aligned[cont_cols] = interpolated.ffill().bfill()
    if rain_cols:
        filled = aligned[rain_cols].ffill().bfill()
        aligned[rain_cols] = filled.clip(lower=0).astype(int)
    return aligned

def load_inf(path):
    """Load per-zone static information.

    Returns a DataFrame indexed by zone id (as ``str``) with exactly the
    columns longitude, latitude, charge_count, area and perimeter. Values are
    coerced to numeric; columns absent from the CSV are filled with NaN.
    """
    wanted = ["longitude", "latitude", "charge_count", "area", "perimeter"]
    info = pd.read_csv(path, index_col=0)
    for col in wanted:
        raw = info.get(col, np.nan)  # NaN scalar when the column is missing
        info[col] = pd.to_numeric(raw, errors="coerce")
    info.index = info.index.astype(str)
    return info[wanted]

def load_matrix(path, as_float=False):
    """Load a square zone-by-zone matrix (adjacency or distance) from CSV.

    The CSV header supplies the zone ids; the same labels are used for rows and
    columns (so the file must contain as many data rows as columns). Duplicate
    labels are dropped (first kept) and the diagonal is forced to 0.

    Args:
        path: CSV path or file-like object.
        as_float: cast values to ``float`` when True, otherwise to ``int``.

    Returns:
        DataFrame with string row/column labels and a zero diagonal.
    """
    df = pd.read_csv(path)
    labels = df.columns.astype(str)
    df.index = labels
    df.columns = labels
    if df.index.duplicated().any():
        print(f"[warn] duplicated index in {path}, keeping first")
        df = df.loc[~df.index.duplicated(keep='first')]
    if df.columns.duplicated().any():
        print(f"[warn] duplicated columns in {path}, keeping first")
        df = df.loc[:, ~df.columns.duplicated(keep='first')]

    # Zero the diagonal on an explicit, writable copy. Writing through
    # `df.values` relies on it being a mutable view of the frame, which is not
    # guaranteed (and is read-only under pandas Copy-on-Write).
    values = df.to_numpy(dtype=float if as_float else int, copy=True)
    np.fill_diagonal(values, 0)
    return pd.DataFrame(values, index=df.index, columns=df.columns)

def precompute_time(clock):
    """Build calendar features for every timestamp in ``clock``.

    Returns a float32 DataFrame indexed by ``clock`` with columns
    hour_sin, hour_cos, dow_sin, dow_cos, is_weekend and is_holiday_cn
    (the latter from ``chinese_calendar.is_holiday``).
    """
    hours, weekdays = clock.hour, clock.dayofweek

    def _cyclic(values, period):
        # Sine/cosine encoding of a periodic quantity.
        phase = 2*np.pi*values/period
        return np.sin(phase).astype(np.float32), np.cos(phase).astype(np.float32)

    hour_sin, hour_cos = _cyclic(hours, 24)
    dow_sin, dow_cos = _cyclic(weekdays, 7)
    # dayofweek is 0..6 with Saturday=5 and Sunday=6.
    weekend_flag = (weekdays >= 5).astype(np.float32)
    holiday_flag = np.fromiter(
        (float(cc.is_holiday(ts.date())) for ts in clock),
        dtype=np.float32, count=len(clock),
    )

    features = {
        "hour_sin": hour_sin, "hour_cos": hour_cos,
        "dow_sin": dow_sin, "dow_cos": dow_cos,
        "is_weekend": weekend_flag, "is_holiday_cn": holiday_flag,
    }
    return pd.DataFrame(features, index=clock)

def prepare_data(cfg):
    """Load every input table, align it to the hourly clock and restrict it to common zones.

    Returns:
        (clock, tables, zone_ids, hashes) where ``tables`` maps table names to
        DataFrames, ``zone_ids`` is the sorted intersection of the occupancy and
        both price tables' columns, and ``hashes`` carries ``datahash`` (hash of
        the configured path dict — not of file contents) and ``clockhash``.
    """
    paths = cfg.data.paths
    clock = build_clock(cfg.data.train_range, cfg.data.test_range)

    # Time series: local usage tables are zero-filled, prices are ffill/bfill-ed.
    tables = {}
    for name in ("occ", "dur", "vol"):
        tables[name] = reindex_local(read_ts_csv(paths[name]), clock)
    for name in ("e_price", "s_price"):
        tables[name] = reindex_price(read_ts_csv(paths[name]), clock)
    tables["weather"] = reindex_weather(read_ts_csv(paths["weather"]), clock)

    # Static / relational tables.
    tables["inf"] = load_inf(paths["inf"])
    tables["adj"] = load_matrix(paths["adj"])
    tables["distance"] = load_matrix(paths["dist"], as_float=True)
    tables["time"] = precompute_time(clock)

    if cfg.data.use_poi and os.path.exists(paths["poi"]):
        tables["poi_raw"] = pd.read_csv(paths["poi"])

    # Zones present in occupancy and in both price tables.
    common = set(tables["occ"].columns) & set(tables["e_price"].columns) & set(tables["s_price"].columns)
    zone_ids = sorted(list(common))
    if len(zone_ids) != len(set(zone_ids)):
        raise ValueError("zone_ids duplicated")

    for name in ("occ", "dur", "vol", "e_price", "s_price"):
        tables[name] = tables[name][zone_ids]
    # Zones absent from the matrices get "not adjacent" / "infinitely far".
    tables["adj"] = tables["adj"].reindex(index=zone_ids, columns=zone_ids, fill_value=0)
    tables["distance"] = tables["distance"].reindex(index=zone_ids, columns=zone_ids, fill_value=float('inf'))
    tables["inf"] = tables["inf"].reindex(index=zone_ids)

    hashes = SimpleNamespace(
        datahash=hashlib.sha256(str(paths).encode()).hexdigest()[:16],
        clockhash=hashlib.sha256(str(clock).encode()).hexdigest()[:16],
    )
    return clock, tables, zone_ids, hashes

def compute_fit_stats(tables, train_times, zone_ids):
    """Compute normalisation statistics (mean/std, NaN-ignoring).

    * occ / e_price / s_price: one global mean/std over the training timestamps
      and the given zones.
    * weather columns T, P0, P, U, Td: per-column over the training timestamps
      (mean 0 / std 1 if the column is missing).
    * charge_count / area / perimeter: over the rows of ``tables["inf"]``.

    A standard deviation of exactly 0 is replaced by 1.0. The returned dict
    also carries ``norm_version``, a 16-hex-char hash of the statistics.
    """
    def _moments(values):
        mu = float(np.nanmean(values))
        sd = float(np.nanstd(values)) or 1.0
        return {"mean": mu, "std": sd}

    stats = {}
    for name in ("occ", "e_price", "s_price"):
        stats[name] = _moments(tables[name].loc[train_times, zone_ids].values)

    weather_train = tables["weather"].loc[train_times]
    for col in ("T", "P0", "P", "U", "Td"):
        if col in weather_train.columns:
            stats[col] = _moments(weather_train[col])
        else:
            stats[col] = {"mean": 0.0, "std": 1.0}

    static = tables["inf"]
    for col in ("charge_count", "area", "perimeter"):
        stats[col] = _moments(pd.to_numeric(static[col], errors="coerce").values)

    # Fingerprint of the statistics (computed before the key itself is added).
    digest = hashlib.sha256(json.dumps(stats, sort_keys=True).encode()).hexdigest()
    stats["norm_version"] = digest[:16]
    return stats

def standardize_tables(tables, stats):
    """Z-score the occupancy, price and continuous weather data in place.

    ``tables["occ"|"e_price"|"s_price"]`` are replaced by float32 standardized
    frames; weather columns T/P0/P/U/Td (when present) are overwritten inside
    ``tables["weather"]``. Returns None. Calling it twice standardizes twice.

    NOTE(review): ``stats`` also holds charge_count/area/perimeter statistics
    (see compute_fit_stats) but ``tables["inf"]`` is not standardized here —
    confirm whether the static features are meant to stay raw.
    """
    def _zscore(data, stat_key):
        mu, sd = stats[stat_key]["mean"], stats[stat_key]["std"]
        return ((data - mu) / sd).astype(np.float32)

    for name in ("occ", "e_price", "s_price"):
        tables[name] = _zscore(tables[name], name)

    weather = tables["weather"]
    for col in ("T", "P0", "P", "U", "Td"):
        if col not in weather.columns:
            continue
        weather[col] = _zscore(weather[col], col)

def build_sample_indices(tables, clock, zone_ids, cfg):
    """Label splits, standardize the tables and enumerate valid (t_start, zone) samples.

    Side effect: ``tables`` is standardized IN PLACE (via standardize_tables)
    using statistics fitted on the training timestamps only, so this function
    must be called once per freshly prepared ``tables``.

    A sample starting at ``t0`` for a zone is kept when
      * the L-hour input window [t0, t0+L-1] has no missing occupancy, price or
        continuous weather values,
      * the H targets starting at ``t0 + pred_offset`` are present, and
      * the whole input window and the last target time fall in the same split.

    Split ranges: a date-only end (midnight timestamp) covers that whole day,
    consistent with build_clock extending the clock to 23:00 of the last test
    day. Previously ``index <= end`` kept only hour 00:00 of each split's last
    day. An end that carries an explicit time of day is used as an inclusive
    bound unchanged.

    NOTE(review): targets are located at ``t0 + pred_offset`` measured from the
    window START. Whenever pred_offset < L (true for L=24, pred_offset=3) the
    target timestamp lies inside the input window — confirm whether the offset
    was meant to be measured from the window end.

    Returns:
        (out, stats): ``out[split]`` is a DataFrame with columns t_start,
        zone_id, zone_idx, L, H, occ_only, split; ``stats`` is the dict from
        compute_fit_stats.
    """
    L, H = cfg.data.L, cfg.data.H
    O = int(getattr(cfg.data, "pred_offset", 1))

    def _range_mask(date_range):
        # Boolean mask over `clock` for an inclusive (start, end) range.
        start = pd.to_datetime(date_range[0])
        end = pd.to_datetime(date_range[1])
        if end == end.normalize():
            # Date-only end: include every hour of that day.
            return (clock >= start) & (clock < end + pd.Timedelta(days=1))
        return (clock >= start) & (clock <= end)

    # ----- split labelling (later ranges override earlier ones on overlap) -----
    split_sr = pd.Series(index=clock, data="none")
    split_sr.loc[_range_mask(cfg.data.train_range)] = "train"
    split_sr.loc[_range_mask(cfg.data.val_range)] = "val"
    split_sr.loc[_range_mask(cfg.data.test_range)] = "test"

    # ----- statistics / standardization (fit on train timestamps only) -----
    train_times = split_sr[split_sr=="train"].index
    stats = compute_fit_stats(tables, train_times, zone_ids)
    standardize_tables(tables, stats)

    # ----- input-window validity masks (L consecutive valid hours) -----
    # rolling(L) marks the END of a valid window; shift(-(L-1)) moves the flag
    # to the window START.
    local_valid  = tables["occ"].notna().rolling(L).sum().eq(L).shift(-(L-1), fill_value=False)
    price_valid  = (
        tables["e_price"].notna().rolling(L).sum().eq(L) &
        tables["s_price"].notna().rolling(L).sum().eq(L)
    ).shift(-(L-1), fill_value=False)

    w_cont  = tables["weather"][["T","P0","P","U","Td"]].notna().all(axis=1).rolling(L).sum().eq(L)
    w_valid = w_cont.shift(-(L-1), fill_value=False)
    # Weather is shared by all zones: broadcast the per-time flag across zones.
    w_valid = pd.DataFrame(np.repeat(w_valid.values[:,None], len(zone_ids), axis=1),
                           index=clock, columns=zone_ids)

    # ----- target validity mask (H values starting at t0+O) -----
    if H == 1:
        # Single step: the value at t0+O must exist and be valid.
        target_valid = tables["occ"].notna().shift(-O, fill_value=False)
    else:
        # H consecutive steps: align by the offset, then require H valid values in a row.
        target_valid = (
            tables["occ"].notna()
            .shift(-O, fill_value=False)         # align t0 with the value O hours later
            .rolling(H).sum().eq(H)              # H consecutive True values
            .shift(-(H-1), fill_value=False)     # move the flag back to the start
        )

    combined = local_valid & price_valid & w_valid & target_valid

    # ----- per-split filtering of start times -----
    out = {}
    for split in ["train","val","test"]:
        s_mask = (split_sr == split)

        # All L input hours must lie inside the split.
        in_start = s_mask.rolling(L).sum().eq(L).shift(-(L-1), fill_value=False)

        # The last target time (t0 + O + (H-1)) must lie inside the same split.
        last_tgt_shift = -(O + (H - 1))   # -(O) when H == 1
        tgt_ok = s_mask.shift(last_tgt_shift, fill_value=False)

        time_ok = in_start & tgt_ok
        valid = combined.copy()
        valid[~time_ok] = False

        # Long format: one row per valid (t_start, zone) pair.
        flat = valid.stack().reset_index()
        flat = flat[flat[0]].drop(columns=0)
        flat.columns = ["t_start","zone_id"]
        flat["zone_idx"] = flat["zone_id"].map({z:i for i,z in enumerate(zone_ids)})
        flat["L"] = L
        flat["H"] = H
        flat["occ_only"] = int(cfg.data.baseline_occ_only)
        flat["split"] = split
        out[split] = flat.reset_index(drop=True)

    return out, stats



In [None]:
# POI utilities 
def compute_zone_radius(inf, beta=0.7, r_min=300, r_max=2000):
    A = inf["area"].to_numpy(); P = inf["perimeter"].to_numpy()
    r_area  = np.sqrt(np.clip(A,0,None)/np.pi)
    r_perim = P/(2*np.pi)
    r = beta*r_area + (1-beta)*r_perim
    return pd.Series(np.clip(r, r_min, r_max), index=inf.index)

def build_poi_counts(inf, poi, radius):
    """Count POIs of three categories within each zone's radius.

    Args:
        inf: zone table with ``latitude`` / ``longitude`` columns (degrees).
        poi: POI table with ``latitude``, ``longitude`` and ``primary_types``.
        radius: per-zone radius Series, positionally aligned with ``inf``
            (divided by 6371000.0, i.e. presumably metres — confirm units).

    Returns:
        Integer DataFrame indexed like ``inf`` with columns poi_lifestyle,
        poi_business_residential and poi_food_beverage.
    """
    from sklearn.neighbors import BallTree

    earth_radius = 6371000.0
    poi_coords = np.column_stack([
        np.deg2rad(poi["latitude"].values),
        np.deg2rad(poi["longitude"].values),
    ])
    tree = BallTree(poi_coords, metric="haversine")

    # One boolean mask per category, matched on the lower-cased type string.
    types = poi["primary_types"].str.lower()
    category_masks = [
        types.str.contains("lifestyle services"),
        types.str.contains("business and residential"),
        types.str.contains("food and beverage services"),
    ]

    zone_lat = np.deg2rad(inf["latitude"].values)
    zone_lon = np.deg2rad(inf["longitude"].values)
    counts = np.zeros((len(inf), 3), dtype=int)
    for i in tqdm(range(len(inf)), desc="POI"):
        # Haversine distances are in radians on the unit sphere.
        hits = tree.query_radius([[zone_lat[i], zone_lon[i]]], r=radius.iloc[i] / earth_radius)[0]
        for j, mask in enumerate(category_masks):
            counts[i, j] = mask.iloc[hits].sum()

    return pd.DataFrame(counts, index=inf.index,
                        columns=["poi_lifestyle","poi_business_residential","poi_food_beverage"])


In [None]:
# Moirai cache (keying fixed)
# Columns that together identify one cached Moirai embedding.
KEY_COLS = ["zone_id","t_start_iso","L","occ_only","model_id","psig","pool","clock_hash","norm_version","datahash"]

def make_moirai_keys(si, hashes, stats, cfg):
    """Build the embedding-cache key table for a sample-index DataFrame.

    Args:
        si: sample index with at least t_start (datetime), zone_id, L, occ_only.
        hashes: object with ``clockhash`` and ``datahash`` attributes.
        stats: dict containing ``norm_version``.
        cfg: config providing ``moirai.model_id``, ``moirai.patch_sizes``, ``moirai.pool``.

    Returns:
        New DataFrame with exactly the KEY_COLS columns; ``si`` is not modified.
    """
    keyed = si.assign(
        t_start_iso=si["t_start"].dt.strftime("%Y-%m-%dT%H:%M:%S"),
        L=si["L"].astype(int),
        occ_only=si["occ_only"].astype(int),
        model_id=cfg.moirai.model_id,
        psig="-".join(map(str, cfg.moirai.patch_sizes)),
        pool=cfg.moirai.pool,
        clock_hash=hashes.clockhash,
        norm_version=stats["norm_version"],
        datahash=hashes.datahash,
    )
    return keyed[KEY_COLS]

@torch.no_grad()
def encode_batch(x_np, backbone, patch_sizes, pool):
    """Encode a batch of input windows into fixed-size embeddings.

    Args:
        x_np: float array of shape (B, C, L); callers in this notebook pass C=1.
        backbone: model exposing either ``encode(target=, observed_mask=, patch_size=)``
            or a forward call with the Moirai-style keyword arguments used below.
        patch_sizes: iterable of patch sizes; one pass is run per size and the
            results are averaged.
        pool: ``"mean"`` to mean-pool a 3-D output over dim 1; any other value
            max-pools.

    Returns:
        numpy array of shape (B, D_emb).

    Raises:
        ValueError: if the backbone output is neither 2-D nor 3-D.

    Changes from the previous version:
      * ``encode`` is detected up front instead of via ``except AttributeError``,
        so an AttributeError raised inside a real ``encode`` is no longer
        silently converted into a forward-pass fallback.
      * ``output.sample()` is only called when the output has no ``mean``;
        ``getattr(output, "mean", output.sample())`` always drew a sample
        (wasted work and consumed random state).
    """
    x = torch.from_numpy(x_np).to(DEVICE).float()  # (B, C, L); C=1 here
    B, C, L = x.shape
    target = x.transpose(1,2)  # (B, L, C)
    observed_mask = torch.ones_like(target, dtype=torch.bool)
    prediction_mask = torch.zeros(B, L, dtype=torch.bool, device=DEVICE)
    sample_id = torch.arange(B, device=DEVICE).unsqueeze(1).expand(B, L)
    time_id   = torch.arange(L, device=DEVICE).unsqueeze(0).expand(B, L)
    variate_id= torch.zeros(B, L, dtype=torch.long, device=DEVICE)

    has_encode = callable(getattr(backbone, "encode", None))

    hs_list = []
    for ps in patch_sizes:
        patch_size = torch.full((B, L), ps, dtype=torch.long, device=DEVICE)
        if has_encode:
            hs = backbone.encode(target=target, observed_mask=observed_mask, patch_size=patch_size)
        else:
            output = backbone(target=target, observed_mask=observed_mask,
                              prediction_mask=prediction_mask, sample_id=sample_id,
                              time_id=time_id, variate_id=variate_id, patch_size=patch_size)
            if isinstance(output, torch.Tensor):
                hs = output
            else:
                # Distribution-like output: prefer its mean, sample only as a last resort.
                hs = getattr(output, "mean", None)
                if hs is None:
                    hs = output.sample()
        if hs.dim()==3:
            hs = hs.mean(dim=1) if pool=="mean" else hs.max(dim=1)[0]
        elif hs.dim()!=2:
            raise ValueError(f"Unexpected hs shape: {hs.shape}")
        hs_list.append(hs)
    emb = torch.stack(hs_list).mean(dim=0)  # (B, D_emb)
    return emb.cpu().numpy()

def _sanitize(s:str)->str:
    return "".join(ch if (str(ch).isalnum() or ch in "-._") else "_" for ch in str(s))

def _make_cache_path_with_dim(cfg, model_id, patch_sig, pool, emb_dim, hashes, stats):
    """Build the embedding-cache file path under ``cfg.out.embed_dir``.

    The file name encodes the model id, patch signature, pooling mode,
    embedding dimension and the first 8 characters of the clock, normalisation
    and data hashes, so a change in any of them selects a different cache file.
    """
    parts = [
        ("model", model_id),
        ("psig", patch_sig),
        ("pool", pool),
        ("edim", str(emb_dim)),
        ("clock", hashes.clockhash[:8]),
        ("norm", stats["norm_version"][:8]),
        ("data", hashes.datahash[:8]),
    ]
    stem = "_".join(f"{label}-{_sanitize(value)}" for label, value in parts)
    return os.path.join(cfg.out.embed_dir, "moirai_" + stem + ".pt")

def ensure_moirai_cache(keys_df, tables, cfg, hashes, stats):
    """Make sure every key in ``keys_df`` has a cached Moirai embedding.

    Loads the backbone, probes the embedding dimension with the first key,
    loads the matching cache file (if any), encodes only the missing keys and
    writes the merged cache back atomically.

    Args:
        keys_df: key table as produced by make_moirai_keys (KEY_COLS columns).
        tables: table dict; ``tables["occ"]`` supplies the input windows.
        cfg, hashes, stats: used for the model id, window length and cache path.

    Returns:
        (moirai_df, moirai_emb, emb_dim): key table (KEY_COLS columns), CPU
        float32 tensor of shape (N, emb_dim) row-aligned with ``moirai_df``,
        and the probed embedding dimension.

    Raises:
        ValueError: if ``keys_df`` is empty.
        RuntimeError: if a batch yields a different embedding width than the probe.

    Changes from the previous version:
      * ``torch.load(..., weights_only=False)``: the cache is a pickled
        (DataFrame, Tensor) tuple, which the ``weights_only=True`` default of
        recent PyTorch refuses to load. Only load cache files you wrote yourself.
      * New rows are accumulated and the cache is written once, atomically,
        instead of re-saving the whole growing cache to ``.tmp`` after every
        batch (that file was never read back, so it gave no resumability).
      * Key lookup uses ``itertuples`` instead of per-row ``iterrows``.
    """
    if len(keys_df) == 0:
        raise ValueError("[MoiraiCache] keys_df is empty; no samples to encode")

    backbone = MoiraiModule.from_pretrained(cfg.moirai.model_id).to(DEVICE).eval()
    patch_sig = "-".join(map(str, cfg.moirai.patch_sizes)); pool = cfg.moirai.pool
    win_len = cfg.data.L
    occ = tables["occ"]
    zone_pos, time_pos = KEY_COLS.index("zone_id"), KEY_COLS.index("t_start_iso")

    def _window(zone_id, t_start_iso):
        # Occupancy window [t0, t0+L-1] for one zone, as a float32 vector of length L.
        t0 = pd.to_datetime(t_start_iso)
        return occ.loc[slice(t0, t0 + pd.Timedelta(hours=win_len-1)), zone_id].values.astype(np.float32)

    def _empty_cache(dim):
        return pd.DataFrame(columns=KEY_COLS), torch.zeros(0, dim, dtype=torch.float32)

    wanted_keys = list(keys_df[KEY_COLS].itertuples(index=False, name=None))

    # Probe the real embedding dimension with a single window.
    first_key = wanted_keys[0]
    x_probe = _window(first_key[zone_pos], first_key[time_pos])[None, None, :]  # (1,1,L)
    emb_dim = int(encode_batch(x_probe, backbone, cfg.moirai.patch_sizes, pool).shape[1])

    cache_file = _make_cache_path_with_dim(cfg, cfg.moirai.model_id, patch_sig, pool, emb_dim, hashes, stats)
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if os.path.exists(cache_file):
        # SECURITY: full unpickling — the cache file must be trusted.
        moirai_df, moirai_emb = torch.load(cache_file, map_location="cpu", weights_only=False)
        if moirai_emb.ndim!=2 or moirai_emb.shape[1]!=emb_dim:
            print(f"[MoiraiCache] dim mismatch in file: {moirai_emb.shape} vs emb_dim={emb_dim} → reset")
            moirai_df, moirai_emb = _empty_cache(emb_dim)
    else:
        moirai_df, moirai_emb = _empty_cache(emb_dim)

    cached_keys = set(moirai_df[KEY_COLS].itertuples(index=False, name=None))
    # dict.fromkeys de-duplicates while preserving order.
    todo_keys = list(dict.fromkeys(k for k in wanted_keys if k not in cached_keys))

    if todo_keys:
        B = cfg.moirai.batch_size
        new_chunks = []
        for i in tqdm(range(0, len(todo_keys), B), desc="[MoiraiCache] encode"):
            batch_keys = todo_keys[i:i+B]
            x_np = np.stack(
                [_window(k[zone_pos], k[time_pos])[None, :] for k in batch_keys]
            ).astype(np.float32)  # (b,1,L)

            y_np = encode_batch(x_np, backbone, cfg.moirai.patch_sizes, pool)  # (b,E)
            if y_np.shape[1] != emb_dim:
                # Resetting here would drop rows whose keys callers still expect.
                raise RuntimeError(
                    f"[MoiraiCache] runtime dim mismatch: cache={emb_dim} vs new={y_np.shape[1]}")
            new_chunks.append(torch.from_numpy(y_np).float())   # CPU

        new_df = pd.DataFrame(todo_keys, columns=KEY_COLS)
        if len(moirai_df):
            moirai_df = pd.concat([moirai_df[KEY_COLS], new_df], ignore_index=True)
        else:
            moirai_df = new_df
        moirai_emb = torch.cat([moirai_emb] + new_chunks, dim=0)

        # Atomic write: a crash mid-save never leaves a truncated cache file.
        tmp_file = cache_file + ".tmp"
        torch.save((moirai_df, moirai_emb), tmp_file)
        os.replace(tmp_file, cache_file)

    return moirai_df, moirai_emb, emb_dim


In [None]:
# Feature builders
def fb_local(tables, z, t0, L, occ_only=True):
    """Return zone ``z``'s occupancy over the window [t0, t0+L-1] as a (1, L) float32 array.

    ``occ_only`` is accepted for interface compatibility and is not used.
    """
    t_end = t0 + pd.Timedelta(hours=L-1)
    window = tables["occ"].loc[t0:t_end, z]
    return window.values.astype(np.float32)[None, :]  # (1, L)

def fb_price(tables, z, t0, L):
    """Return zone ``z``'s price features over [t0, t0+L-1] as a (2, L) float32 array.

    Row 0 is ``e_price`` and row 1 is ``s_price``.
    """
    t_end = t0 + pd.Timedelta(hours=L-1)
    rows = [
        tables[name].loc[t0:t_end, z].values.astype(np.float32)
        for name in ("e_price", "s_price")
    ]
    return np.stack(rows).astype(np.float32)

def fb_weather(tables, t0, L):
    """Return weather features over [t0, t0+L-1] as a (9, L) float32 array.

    Rows 0-4 are the continuous columns T, P0, P, U, Td; rows 5-8 are a one-hot
    encoding of the ``nRAIN`` code clipped to 0..3 (all-zero code when the
    column is absent). The rain values must be integers to be usable as indices.
    """
    t_end = t0 + pd.Timedelta(hours=L-1)
    window = tables["weather"].loc[t0:t_end]

    continuous = window[["T","P0","P","U","Td"]].values.T.astype(np.float32)

    if "nRAIN" in window.columns:
        rain_series = window["nRAIN"]
    else:
        rain_series = pd.Series(0, index=window.index)
    rain_code = rain_series.clip(0,3).values

    one_hot = np.zeros((4, L), dtype=np.float32)
    one_hot[rain_code, np.arange(L)] = 1.0
    return np.concatenate([continuous, one_hot], axis=0).astype(np.float32)

def fb_spatial(tables, z, t0, L, zone_ids):
    """Return neighbour-context features for zone ``z`` over [t0, t0+L-1].

    Output is a (3, L) float32 array:
      row 0: inverse-distance-weighted mean occupancy of the OTHER zones,
      row 1: gap = own occupancy - neighbour mean,
      row 2: ratio = own / neighbour mean, clipped to [0, 5].

    Assumes the columns of ``tables["occ"]`` and the rows/columns of
    ``tables["distance"]`` are ordered like ``zone_ids`` (prepare_data arranges
    this).

    Changes from the previous version:
      * The zone itself is excluded from the weights. The distance diagonal is
        0, so its weight 1/(0+1e-6)=1e6 dominated and the "neighbour mean"
        collapsed to the zone's own occupancy (gap≈0, ratio≈1).
      * Only the weight row for ``z`` is computed instead of normalising the
        full N×N matrix on every call.
    """
    win = slice(t0, t0 + pd.Timedelta(hours=L-1))
    occ_win = tables["occ"].loc[win].values.astype(np.float32)   # (L, N)
    zi = zone_ids.index(z)

    dist_row = tables["distance"].values[zi].astype(np.float32)  # (N,)
    weights = 1.0 / (dist_row + 1e-6)      # infinite distance -> weight 0
    weights[zi] = 0.0                      # a zone is not its own neighbour
    weights = weights / (weights.sum() + 1e-6)

    mean_occ = occ_win @ weights
    self_occ = occ_win[:, zi]
    gap = self_occ - mean_occ
    ratio = np.clip(self_occ / (mean_occ + 1e-6), 0, 5)
    return np.stack([mean_occ, gap, ratio]).astype(np.float32)

def fb_static(tables, z):
    """Return the static feature vector for zone ``z`` as a length-6 float32 array.

    Layout: charge_count, area, perimeter from ``tables["inf"]`` followed by the
    three POI counts from ``tables["poi_counts"]`` (zeros when that table is absent).
    """
    zone_row = tables["inf"].loc[z]
    meta = zone_row[["charge_count","area","perimeter"]].values.astype(np.float32)
    if "poi_counts" in tables:
        poi = tables["poi_counts"].loc[z].values.astype(np.float32)
    else:
        poi = np.zeros(3, dtype=np.float32)
    return np.concatenate([meta, poi]).astype(np.float32)

def fb_time(tables, t0, L):
    """Return the calendar features over [t0, t0+L-1] as a (n_features, L) float32 array."""
    t_end = t0 + pd.Timedelta(hours=L-1)
    window = tables["time"].loc[t0:t_end]
    return window.values.T.astype(np.float32)

def fb_target(tables, z, t0, H, offset=None):
    """Return the target occupancy values and their validity mask.

    Targets are the ``H`` hourly values of zone ``z`` starting at
    ``t0 + offset`` hours. When ``offset`` is None it is read from the global
    ``cfg.data.pred_offset`` (default 1).

    NOTE(review): the offset is measured from the window START ``t0``; for
    offset < L the target time lies inside the input window built by the other
    feature builders — confirm this is intended.

    Returns:
        (y, mask): float32 array of length H and a boolean array marking finite values.
    """
    if offset is None:
        offset = getattr(cfg.data, "pred_offset", 1)  # default 1
    first = t0 + pd.Timedelta(hours=offset)
    occ = tables["occ"]

    if H == 1:
        # Scalar lookup (raises KeyError if the timestamp is missing).
        value = occ.loc[first, z].astype(np.float32)
        y = np.array([value], dtype=np.float32)
    else:
        last = first + pd.Timedelta(hours=H-1)
        y = occ.loc[first:last, z].values.astype(np.float32)
    return y, np.isfinite(y)




In [None]:
# Dataset / Collate
class EVDataset(Dataset):
    """Map-style dataset producing one feature dict per (t_start, zone) sample.

    Args:
        df: sample index (one split of build_sample_indices' output) with at
            least zone_id, t_start, L, H and — optionally — occ_only columns.
        tables: table dict consumed by the ``fb_*`` feature builders.
        zone_ids: ordered zone ids (column order of the occupancy table).
    """

    def __init__(self, df, tables, zone_ids):
        self.df = df.reset_index(drop=True)
        self.tables = tables
        self.zone_ids = list(zone_ids)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Build the numpy features, target, mask and partial Moirai cache key for row ``idx``.

        The hash fields of ``moirai_key`` are left empty here; collate_fn fills
        them in before the cache lookup.
        """
        r = self.df.iloc[idx]
        z, t0, L, H = r["zone_id"], r["t_start"], int(r["L"]), int(r["H"])
        # Use the sample's own occ_only flag so the key matches make_moirai_keys.
        # It was hard-coded to 1, which breaks the cache lookup whenever
        # cfg.data.baseline_occ_only is False. Fall back to 1 if the column is absent.
        occ_only = int(r["occ_only"]) if "occ_only" in r.index else 1

        x_weather = fb_weather(self.tables, t0, L)
        x_price   = fb_price(self.tables, z, t0, L)
        x_spatial = fb_spatial(self.tables, z, t0, L, self.zone_ids)
        x_static  = fb_static(self.tables, z)
        x_time    = fb_time(self.tables, t0, L)
        y, mask   = fb_target(self.tables, z, t0, H)
        return {
            "zone_id": z,
            "x_weather": x_weather, "x_price": x_price, "x_spatial": x_spatial,
            "x_static": x_static, "x_time": x_time, "y": y, "mask": mask,
            "moirai_key": {
                "zone_id": z, "t_start_iso": pd.to_datetime(t0).strftime("%Y-%m-%dT%H:%M:%S"),
                "L": L, "occ_only": occ_only, "model_id": cfg.moirai.model_id,
                "psig": "-".join(map(str, cfg.moirai.patch_sizes)), "pool": cfg.moirai.pool,
                "clock_hash": "", "norm_version":"", "datahash":""
            }
        }

def collate_fn_builder(moirai_emb, key2row, hashes, stats):
    """Create a DataLoader ``collate_fn`` that batches samples and attaches cached embeddings.

    Args:
        moirai_emb: tensor of cached embeddings, one row per cache entry.
        key2row: dict mapping a KEY_COLS-ordered tuple to its row in ``moirai_emb``.
        hashes, stats: supply the clock/data hashes and ``norm_version`` that
            complete each sample's cache key.

    The returned function stacks every feature onto DEVICE as float32 (the mask
    keeps its dtype) and adds ``h_local_moirai``. A sample whose key is missing
    from ``key2row`` raises KeyError.
    """
    float_fields = ("x_weather", "x_price", "x_spatial", "x_static", "x_time", "y")

    def _stack(batch, field):
        return torch.from_numpy(np.stack([item[field] for item in batch]))

    def collate_fn(batch):
        out = {field: _stack(batch, field).float().to(DEVICE) for field in float_fields}
        out["mask"] = _stack(batch, "mask").to(DEVICE)

        # Complete each sample's key with the run-level hashes, then look up its cache row.
        idxs = []
        for item in batch:
            key = dict(item["moirai_key"])
            key["clock_hash"] = hashes.clockhash
            key["norm_version"] = stats["norm_version"]
            key["datahash"] = hashes.datahash
            idxs.append(key2row[tuple(key[col] for col in KEY_COLS)])

        out["h_local_moirai"] = moirai_emb[idxs].to(DEVICE).float()
        return out

    return collate_fn

In [None]:
# Model
class EVForecastModel(nn.Module):
    """Fuse a cached Moirai embedding with context encoders via cross-attention.

    The adapted Moirai embedding is the attention query; the weather, price,
    spatial, static and time encodings are the five key/value tokens. A linear
    head maps the attended vector to ``cfg.data.H`` outputs.

    NOTE: submodule creation order fixes both the seeded parameter
    initialisation and the state_dict keys — do not reorder.
    """
    def __init__(self, cfg):
        """Build all submodules from ``cfg.model`` (D_model, D_moirai, dropout, nhead, KERNEL_SIZE, nonneg_head) and ``cfg.data.H``."""
        super().__init__()
        D   = cfg.model.D_model
        D_m = cfg.model.D_moirai
        p   = cfg.model.dropout

        # Projects the Moirai embedding (D_m) into the model width (D).
        self.adapter = nn.Sequential(nn.Linear(D_m, D), nn.GELU(), nn.Dropout(p), nn.Linear(D, D))
        ks = cfg.model.KERNEL_SIZE; pad = ks//2

        # Two-layer 1-D conv encoder with a fixed width of 64 channels.
        def conv_block(cin):
            return nn.Sequential(
                nn.Conv1d(cin, 64, ks, padding=pad), nn.GELU(),
                nn.Conv1d(64, 64, ks, padding=pad), nn.GELU()
            )
        # Input channel counts match the feature builders:
        # weather = 5 continuous + 4 rain one-hot, price = 2, spatial = 3, time = 6.
        self.enc_weather = conv_block(9)
        self.enc_price   = conv_block(2)
        self.enc_spatial = conv_block(3)
        self.enc_time    = conv_block(6)

        # Static vector = 3 zone attributes + 3 POI counts (see fb_static).
        self.enc_static  = nn.Sequential(nn.Linear(6, 128), nn.GELU(), nn.Dropout(p), nn.Linear(128, D))
        self.proj_weather= nn.Linear(64, D)
        self.proj_price  = nn.Linear(64, D)
        self.proj_spatial= nn.Linear(64, D)
        self.proj_time   = nn.Linear(64, D)

        # Per-modality gate network; its sigmoid output scales the modality vector in forward().
        def align(): return nn.Sequential(nn.LayerNorm(D), nn.Dropout(p), nn.Linear(D, D))
        self.align_local = align(); self.align_weather = align(); self.align_price = align()
        self.align_spatial = align(); self.align_static = align(); self.align_time = align()

        self.mha  = nn.MultiheadAttention(D, cfg.model.nhead, dropout=p, batch_first=True)
        self.head = nn.Linear(D, cfg.data.H)
        # Optional softplus on the output to force non-negative predictions.
        self.use_softplus = bool(getattr(cfg.model, "nonneg_head", False))
        self.softplus = nn.Softplus()

    def forward(self, batch):
        """Predict ``H`` values per sample.

        Args:
            batch: dict with ``h_local_moirai`` plus ``x_weather``, ``x_price``,
                ``x_spatial``, ``x_time`` (channel-first sequences) and ``x_static``.

        Returns:
            Tensor from the linear head (softplus-ed when ``nonneg_head`` is set).
        """
        h_local   = self.adapter(batch["h_local_moirai"])
        # Conv encoders keep the time axis; mean(-1) pools over it before projection.
        h_weather = self.proj_weather(self.enc_weather(batch["x_weather"]).mean(-1))
        h_price   = self.proj_price  (self.enc_price  (batch["x_price"  ]).mean(-1))
        h_spatial = self.proj_spatial(self.enc_spatial(batch["x_spatial"]).mean(-1))
        h_time    = self.proj_time   (self.enc_time   (batch["x_time"   ]).mean(-1))
        h_static  = self.enc_static(batch["x_static"])

        # Element-wise sigmoid gating of each modality vector.
        def gate(align, h): return torch.sigmoid(align(h)) * h
        h_local_a   = gate(self.align_local,   h_local)
        h_weather_a = gate(self.align_weather, h_weather)
        h_price_a   = gate(self.align_price,   h_price)
        h_spatial_a = gate(self.align_spatial, h_spatial)
        h_static_a  = gate(self.align_static,  h_static)
        h_time_a    = gate(self.align_time,    h_time)

        # Five context tokens attended to by a single query (the local embedding).
        tokens = torch.stack([h_weather_a, h_price_a, h_spatial_a, h_static_a, h_time_a], dim=1)
        q = h_local_a.unsqueeze(1)
        h, _ = self.mha(q, tokens, tokens)
        y_hat = self.head(h.squeeze(1))
        return self.softplus(y_hat) if self.use_softplus else y_hat


In [None]:
# Train / Val
def safe_save(obj, path):
    """Save ``obj`` with ``torch.save`` atomically.

    The object is first written to ``path + ".tmp"`` and then moved over
    ``path`` with ``os.replace``, so ``path`` never holds a partially written file.
    """
    staging_path = path + ".tmp"
    torch.save(obj, staging_path)
    os.replace(staging_path, path)

def train_epoch(model, loader, opt, sched, scaler, use_amp):
    """Run one training epoch and return the mean of the per-batch MSE losses.

    Args:
        model: module called as ``model(batch)``.
        loader: iterable of batch dicts containing ``y`` and a boolean ``mask``.
        opt: optimizer.
        sched: per-step LR scheduler, or None.
        scaler: gradient scaler; only used when ``use_amp`` is true.
        use_amp: enable CUDA autocast and loss scaling.

    Reads the gradient-clipping threshold from the global ``cfg.train.clip_grad``.
    The returned value averages batch means, so a smaller final batch carries
    the same weight as a full one.
    """
    model.train(); total = 0.0
    # Autocast is only constructed when AMP is requested; otherwise a no-op context is used.
    amp_ctx = amp.autocast(device_type="cuda", enabled=use_amp) if use_amp else nullcontext()
    for batch in tqdm(loader, desc="train"):
        opt.zero_grad(set_to_none=True)
        with amp_ctx:
            y_hat = model(batch)
            # Loss over valid (finite) target positions only.
            loss = F.mse_loss(y_hat[batch["mask"]], batch["y"][batch["mask"]])
        if use_amp:
            scaler.scale(loss).backward()
            if cfg.train.clip_grad and cfg.train.clip_grad>0:
                # Gradients must be unscaled before clipping so the threshold applies to true norms.
                scaler.unscale_(opt); torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.train.clip_grad)
            scaler.step(opt); scaler.update()
        else:
            loss.backward()
            if cfg.train.clip_grad and cfg.train.clip_grad>0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.train.clip_grad)
            opt.step()
        # The scheduler advances once per batch.
        if sched is not None: sched.step()
        total += loss.item()
    return total / max(1, len(loader))

@torch.no_grad()
def val_epoch(model, loader, use_amp=False):
    """Evaluate ``model`` on ``loader`` and return the mean of the per-batch masked MSE losses.

    Runs in eval mode without gradients; CUDA autocast is used only when
    ``use_amp`` is true. Returns 0.0 for an empty loader.
    """
    model.eval()
    if use_amp:
        autocast_ctx = amp.autocast(device_type="cuda", enabled=use_amp)
    else:
        autocast_ctx = nullcontext()

    batch_losses = []
    for batch in tqdm(loader, desc="valid"):
        with autocast_ctx:
            pred = model(batch)
            valid = batch["mask"]
            loss = F.mse_loss(pred[valid], batch["y"][valid])
        batch_losses.append(loss.item())
    return sum(batch_losses, 0.0) / max(1, len(loader))

def run_train(cfg):
    """Prepare data, build the Moirai cache for train/val, train the model and save checkpoints.

    Per epoch a checkpoint is written to ``cfg.out.epoch_ckpt_tmpl`` and
    ``cfg.out.last_ckpt``; ``cfg.out.best_ckpt`` is updated whenever the
    validation loss improves.

    Side effects: sets ``cfg.model.D_moirai`` to the probed embedding dimension
    and sets ``cfg.data.use_poi = False`` when POI features cannot be built.
    ``cfg.train.early_stop_patience`` is not consulted here.

    Change from the previous version: when ``use_poi`` is true but no POI table
    was loaded (prepare_data only reads it if the CSV exists), POI features are
    skipped with a warning instead of failing with ``KeyError: 'poi_raw'``.

    Returns:
        dict with clock, tables, zone_ids, hashes, stats and si (sample indices)
        for later use by run_eval.
    """
    clock, tables, zone_ids, hashes = prepare_data(cfg)

    # POI (optional)
    if cfg.data.use_poi:
        if "poi_raw" not in tables:
            print(f"[warn] POI file not found: {cfg.data.paths['poi']}; skip POI")
            cfg.data.use_poi = False
        else:
            try:
                from sklearn.neighbors import BallTree  # import check
                radius = compute_zone_radius(tables["inf"], beta=cfg.data.poi_radius_beta,
                                             r_min=cfg.data.poi_rmin, r_max=cfg.data.poi_rmax)
                tables["poi_counts"] = build_poi_counts(tables["inf"], tables["poi_raw"], radius)
            except ModuleNotFoundError:
                print("[warn] scikit-learn not installed; skip POI"); cfg.data.use_poi = False

    si, stats = build_sample_indices(tables, clock, zone_ids, cfg)
    print(f"[norm_version] {stats['norm_version']}")

    keys_train_val = make_moirai_keys(pd.concat([si["train"], si["val"]], ignore_index=True), hashes, stats, cfg)
    moirai_df, moirai_emb, emb_dim = ensure_moirai_cache(keys_train_val, tables, cfg, hashes, stats)

    # Propagate the actual Moirai embedding dimension to the model config.
    cfg.model.D_moirai = int(emb_dim)
    print(f"[Moirai] emb_dim={emb_dim} -> adapter: {cfg.model.D_moirai} → {cfg.model.D_model}")

    key2row = {tuple(r[KEY_COLS]): i for i, (_, r) in enumerate(moirai_df.iterrows())}

    train_ds = EVDataset(si["train"], tables, zone_ids)
    val_ds   = EVDataset(si["val"],   tables, zone_ids)
    collate_fn = collate_fn_builder(moirai_emb, key2row, hashes, stats)
    train_loader = DataLoader(train_ds, batch_size=cfg.train.batch_size, shuffle=True,  collate_fn=collate_fn)
    val_loader   = DataLoader(val_ds,   batch_size=cfg.train.batch_size, shuffle=False, collate_fn=collate_fn)

    model = EVForecastModel(cfg).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
    if str(cfg.train.scheduler).lower() == "onecycle":
        sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=cfg.train.lr,
                                                    epochs=cfg.train.epochs,
                                                    steps_per_epoch=max(1,len(train_loader)))
    else:
        sched = None

    use_amp = (DEVICE=="cuda" and torch.cuda.is_available())
    scaler  = amp.GradScaler(enabled=use_amp)

    # ----- Train with per-epoch + best/last ckpt -----
    best_val = float("inf"); saved_epochs = []; keep_last_k = 0  # 0: keep every epoch checkpoint
    for epoch in range(cfg.train.epochs):
        train_loss = train_epoch(model, train_loader, opt, sched, scaler, use_amp=use_amp)
        val_loss   = val_epoch(model, val_loader)

        # per-epoch checkpoint payload
        payload = {
            "state_dict": model.state_dict(),
            "epoch": epoch+1,
            "val_loss": float(val_loss),
            "train_loss": float(train_loss),
            "D_moirai": int(cfg.model.D_moirai),
            "cfg": cfg.__dict__,
        }
        ep_path = cfg.out.epoch_ckpt_tmpl.format(epoch+1)
        safe_save(payload, ep_path)            # per-epoch checkpoint
        safe_save(payload, cfg.out.last_ckpt)  # refresh "last"

        saved_epochs.append(ep_path)
        if keep_last_k > 0 and len(saved_epochs) > keep_last_k:
            old = saved_epochs.pop(0)
            try: os.remove(old)
            except FileNotFoundError: pass

        if val_loss < best_val:
            best_val = val_loss
            safe_save(payload, cfg.out.best_ckpt)

        print(f"[Epoch {epoch+1}/{cfg.train.epochs}] "
              f"train={train_loss:.4f} | val={val_loss:.4f} | best={best_val:.4f}")

    return {"clock": clock, "tables": tables, "zone_ids": zone_ids, "hashes": hashes, "stats": stats, "si": si}

def load_model_from_ckpt(cfg, ckpt_path):
    """Rebuild an EVForecastModel from a checkpoint written by run_train.

    If the checkpoint records a ``D_moirai`` different from
    ``cfg.model.D_moirai``, the config is updated (side effect) so the adapter
    is built with the checkpoint's input width.

    Changes from the previous version:
      * ``weights_only=False`` is passed to ``torch.load``: the payload stores
        ``cfg.__dict__`` (SimpleNamespace objects), which the
        ``weights_only=True`` default of recent PyTorch refuses to unpickle.
        SECURITY: this performs full unpickling — only load trusted checkpoints.
      * The payload is read first, so the model is constructed once.

    Returns:
        (model, payload): the model in eval mode on DEVICE and the raw checkpoint dict.
    """
    payload = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    if "D_moirai" in payload and int(payload["D_moirai"]) != int(cfg.model.D_moirai):
        cfg.model.D_moirai = int(payload["D_moirai"])
    model = EVForecastModel(cfg).to(DEVICE)
    model.load_state_dict(payload["state_dict"], strict=True)
    model.eval()
    return model, payload

@torch.no_grad()
def run_eval(cfg, si, tables, zone_ids, hashes, stats, which="best", path_override=None):
    """Evaluate a saved checkpoint on the test split.

    Args:
        cfg: experiment config.
        si, tables, zone_ids, hashes, stats: objects returned by run_train
            (``tables`` are used as passed in, i.e. already standardized).
        which: ``"best"``, ``"last"`` or ``"path"`` (the latter requires ``path_override``).
        path_override: checkpoint path used when ``which == "path"``.

    Returns:
        dict with ``test_mse`` (mean of per-batch masked MSE, in the units of
        the ``y`` tensors), the checkpoint path and the checkpoint's epoch / val_loss.

    Raises:
        ValueError: for an unsupported ``which`` / missing ``path_override``.
    """
    if which=="best":
        ckpt_path = cfg.out.best_ckpt
    elif which=="last":
        ckpt_path = cfg.out.last_ckpt
    elif which=="path" and path_override:
        ckpt_path = path_override
    else:
        raise ValueError("which ∈ {'best','last','path(with path_override)'}")

    model, meta = load_model_from_ckpt(cfg, ckpt_path)
    print(f"[Eval] ckpt: {ckpt_path} | epoch={meta.get('epoch')} | val_loss={meta.get('val_loss')}")

    # Prepare the Moirai keys/cache for the test split (also supports standalone runs).
    keys_test = make_moirai_keys(si["test"], hashes, stats, cfg)
    moirai_df_t, moirai_emb_t, emb_dim_t = ensure_moirai_cache(keys_test, tables, cfg, hashes, stats)

    # Rebuild the model if its D_moirai differs from the probed embedding dimension.
    # NOTE(review): load_model_from_ckpt resets cfg.model.D_moirai to the value stored in
    # the checkpoint, so this reload appears unable to resolve a genuine mismatch — confirm.
    if int(emb_dim_t) != int(cfg.model.D_moirai):
        cfg.model.D_moirai = int(emb_dim_t)
        model, meta = load_model_from_ckpt(cfg, ckpt_path)

    key2row_t = {tuple(r[KEY_COLS]): i for i, (_, r) in enumerate(moirai_df_t.iterrows())}
    collate_fn_t = collate_fn_builder(moirai_emb_t, key2row_t, hashes, stats)

    test_ds = EVDataset(si["test"], tables, zone_ids)
    test_loader = DataLoader(test_ds, batch_size=cfg.train.batch_size, shuffle=False, collate_fn=collate_fn_t)

    # Average of per-batch MSEs over valid target positions.
    total_loss = 0.0
    for batch in tqdm(test_loader, desc="test"):
        y_hat = model(batch)
        loss = F.mse_loss(y_hat[batch["mask"]], batch["y"][batch["mask"]])
        total_loss += loss.item()
    test_mse = total_loss / max(1, len(test_loader))
    print(f"[Test:{which}] MSE={test_mse:.4f}")
    return {"test_mse": test_mse, "ckpt": ckpt_path, "epoch": meta.get("epoch"), "val_loss": meta.get("val_loss")}


In [None]:
# Train the model. `prepared` keeps the clock, tables, zone ids, hashes,
# normalisation stats and sample indices returned by run_train for the evaluation cells.
prepared = run_train(cfg)


In [None]:
# 1) Evaluate the best checkpoint (lowest validation loss) on the test split
res_best = run_eval(cfg,
                    prepared["si"], prepared["tables"], prepared["zone_ids"],
                    prepared["hashes"], prepared["stats"],
                    which="best")




In [None]:
# 2) Evaluate the last checkpoint (final epoch) on the test split
res_last = run_eval(cfg,
                    prepared["si"], prepared["tables"], prepared["zone_ids"],
                    prepared["hashes"], prepared["stats"],
                    which="last")

In [None]:
# 3) Evaluate an arbitrary per-epoch checkpoint by path (example: epoch_0002)
some_ckpt = cfg.out.epoch_ckpt_tmpl.format(2)
if os.path.exists(some_ckpt):
    res_path = run_eval(cfg,
                        prepared["si"], prepared["tables"], prepared["zone_ids"],
                        prepared["hashes"], prepared["stats"],
                        which="path", path_override=some_ckpt)
else:
    # Skip quietly when that epoch's checkpoint was not written.
    print(f"[Info] {some_ckpt} not found; skip path-based eval.")
