In [8]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import chinese_calendar as cc
from tqdm.auto import tqdm
import random
from uni2ts.model.moirai import MoiraiModule
from copy import deepcopy
from types import SimpleNamespace
import chinese_calendar as cc
from pathlib import Path

In [9]:
# Choose the compute backend in priority order: Apple MPS, then CUDA, then CPU.
DEVICE = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print(DEVICE)

cuda


In [3]:
SEED = 0

# cuDNN flags: request deterministic kernels and switch off the autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed every RNG used by the notebook: PyTorch (CPU + CUDA), NumPy, Python.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [4]:
# Default configuration for the whole notebook. Experiment-specific changes go
# into EXP below and are merged on top of these defaults.
BASE = {
    # ===== Paths =====
    "data": {
        "DATA_DIR": "UrbanEV/data",
        "FN": {
            "occ": "occupancy.csv", "dur": "duration.csv", "vol": "volume.csv",
            "e_price": "e_price.csv", "s_price": "s_price.csv",
            "weather": "weather_central.csv",
            "inf": "inf.csv", "adj": "adj.csv", "dist": "distance.csv", "poi": "poi.csv",
        },
    },

    # ===== Data/windows =====
    "data_mode": {
        "USE_DUMMY_DATA": True, "DUMMY_NUM_ZONES": 2, "DUMMY_NUM_HOURS": 48,
        "TIMEZONE": "Asia/Shanghai", "L": 24, "H": 3,  # inputs window / horizon
        "BASELINE_OCC_ONLY": True,  # spatial 3ch vs 9ch
    },

    # ===== Shared dims/invariants =====
    "dims": {
        "D": 128,  # final embedding dim used across modules
    },

    # ===== Backbones & Embedders =====
    "moirai": {
        "use": True,
        "frozen": True,                 # set to False to fine-tune
        "save_path": None,             # Placeholder, will be set later
        "adapter": {                   # optional adapter: Moirai output -> D
            "use": True,
            "out_dim": 128,            # NOTE(review): duplicates dims.D; keep the two in sync
            "hidden": 128, "dropout": 0.10,
        },
    },

    "embed": {
        # Per-modality 1D-CNN embedders (every output is meant to end up at dim D)
        "weather": {"in_ch": 9, "hidden": 64, "k": 5, "dropout": 0.10, "frozen": True},
        "price":   {"in_ch": 2, "hidden": 64, "k": 5, "dropout": 0.10, "frozen": True},
        "spatial": {"in_ch": 3, "hidden": 64, "k": 5, "dropout": 0.10, "frozen": True},  # overwritten in the derive step below (3ch or 9ch)
        "time":    {"in_ch": 5, "hidden": 64, "k": 5, "dropout": 0.10, "frozen": True},
        "static":  {"in_dim": None, "hidden": 128, "dropout": 0.10, "frozen": True},     # in_dim is set after the data is loaded
    },

    # ===== Align (Proj+LN+Drop+Gate) =====
    "align": {
        "use_gate": True, "p_drop": 0.10,
    },

    # ===== Fusion =====
    "fusion": {
        "type": "cross_attn",   # can later be swapped for "film", "concat_mlp", etc.
        "NHEAD": 8,             # D % NHEAD == 0 recommended
        "dropout": 0.10,
    },

    # ===== Head (Multi-Horizon Regression) =====
    "head": {
        # NOTE(review): H is also declared in data_mode; nothing here enforces that they match
        "H": 3, "hidden": 512, "dropout": 0.10, "nonneg": True,
    },

    # ===== Train =====
    "train": {
        # NOTE(review): the seeding cell uses its own SEED = 0 and does not read this value
        "seed": 42,
        "batch_size": 64,
        "lr": 1e-3,
        "epochs": 10,
        "optimizer": "adamw",
        "weight_decay": 1e-4,
    },
}

# Set save_path after BASE is defined
# (the file name depends on BASELINE_OCC_ONLY as it is set in BASE, i.e. before
#  EXP overrides are merged)
BASE["moirai"]["save_path"] = (
    "moirai_occ_embeddings.pt" if BASE["data_mode"]["BASELINE_OCC_ONLY"] else "moirai_other_embeddings.pt"
)

# === Experiment overrides: values changed for this experiment only ===
EXP = {
    # Example: switch spatial to 9ch and make the weather/price embedders trainable
    # "data_mode": {"BASELINE_OCC_ONLY": False},
    # "embed": {
    #     "weather": {"frozen": False, "hidden": 128, "k": 7},
    #     "price":   {"frozen": False},
    # },
    # "fusion": {"NHEAD": 4},
    # "head": {"hidden": 1024, "dropout": 0.2},
}

def deep_update(base: dict, over: dict):
    """
    Recursively merge `over` on top of `base` and return a new dict.

    Nested dicts present on both sides are merged key by key; any other value
    in `over` replaces the value in `base`. Neither input is modified, and the
    result shares no mutable objects with either input, so later in-place edits
    of the merged config cannot leak back into the defaults or the overrides.

    Args:
        base: dict, default values
        over: dict, overriding values
    Returns:
        dict: independent merged copy
    """
    out = deepcopy(base)
    for k, v in over.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_update(out[k], v)
        else:
            # Copy the override so the result never aliases objects owned by `over`.
            out[k] = deepcopy(v)
    return out

# === Merge the overrides into the defaults and fill in derived values ===
_cfg = deep_update(BASE, EXP)

# Derived: spatial embedder input channels (3 in occupancy-only mode, otherwise 9)
spatial_ch = 9 if not _cfg["data_mode"]["BASELINE_OCC_ONLY"] else 3
_cfg["embed"]["spatial"]["in_ch"] = spatial_ch

# Derived: static in_dim is only known after the data is loaded, so it stays None here.
# Derived: sanity-check the number of attention heads against the embedding width.
D = _cfg["dims"]["D"]
if _cfg["fusion"]["type"] == "cross_attn":
    assert D % _cfg["fusion"]["NHEAD"] == 0, "D는 NHEAD로 나눠떨어지는 것이 안전합니다."

# Top-level sections become attributes (cfg.data, cfg.embed, ...); the sections
# themselves remain plain dicts.
cfg = SimpleNamespace(**_cfg)
print("Config ready. D=", cfg.dims["D"], "| spatial in_ch=", cfg.embed["spatial"]["in_ch"])

Config ready. D= 128 | spatial in_ch= 3


In [6]:


# =========================
# UrbanEV data utility functions (extended version)
# =========================
# Note: Assumes `cfg` (SimpleNamespace) is available from the configuration code

def normalize_zone_ids(ids):
    """
    Normalize zone IDs to a uniform string form.

    Each ID is converted with str(), surrounding whitespace is stripped, and a
    single trailing '.0' (as produced when integer IDs are parsed as floats)
    is removed. Only a suffix is removed: '.0' occurring elsewhere in the ID
    (e.g. '10.05') is left untouched.

    Args:
        ids: Iterable of zone IDs (str, int, or float)
    Returns:
        list[str]: Normalized zone IDs
    """
    normalized = []
    for raw in ids:
        text = str(raw).strip()
        if text.endswith(".0"):
            text = text[:-2]
        normalized.append(text)
    return normalized

def align_hourly(df, tz=None):
    """
    Put a DatetimeIndex-ed frame on a regular 1-hour grid in a fixed timezone.

    - A naive index is localized to `tz`; an aware index is converted to `tz`.
    - Missing hours are inserted and forward-filled.
    - Any NaN still left (e.g. before the first observation) becomes 0.
    The input frame is not modified; a new frame is returned.

    Args:
        df: pandas.DataFrame with DatetimeIndex
        tz: str, timezone (default: cfg.data_mode["TIMEZONE"])
    Returns:
        pandas.DataFrame: Aligned DataFrame
    """
    tz = tz or cfg.data_mode["TIMEZONE"]
    # DataFrame.tz_localize / tz_convert return a new object, so the caller's
    # index is left as it was (assigning to df.index would mutate the input).
    if df.index.tz is None:
        df = df.tz_localize(tz)
    else:
        df = df.tz_convert(tz)
    df = df.sort_index().asfreq("1h")  # regular 1-hour grid
    df = df.ffill().fillna(0.0)        # fill gaps
    return df

def subset_dummy_timeseries(df, use_dummy=None, n_zones=None, n_hours=None, is_zone_table=True):
    """
    In dummy mode, keep only the leading zones (columns) and hours (rows).

    Args:
        df: pandas.DataFrame, input data
        use_dummy: bool, whether to use dummy mode (default: cfg.data_mode["USE_DUMMY_DATA"])
        n_zones: int, number of zones to select (default: cfg.data_mode["DUMMY_NUM_ZONES"])
        n_hours: int, number of hours to select (default: cfg.data_mode["DUMMY_NUM_HOURS"])
        is_zone_table: bool, whether df has zone columns (default: True)
    Returns:
        pandas.DataFrame: Subset DataFrame (the input itself when dummy mode is off)
    """
    # Fall back to the config only for arguments that were not supplied.
    if use_dummy is None:
        use_dummy = cfg.data_mode["USE_DUMMY_DATA"]
    if n_zones is None:
        n_zones = cfg.data_mode["DUMMY_NUM_ZONES"]
    if n_hours is None:
        n_hours = cfg.data_mode["DUMMY_NUM_HOURS"]

    if not use_dummy:
        return df

    # Column truncation only applies to tables whose columns are zones.
    if is_zone_table:
        zones_available = df.shape[1]
        if zones_available < n_zones:
            print(f"Warning: Data has {zones_available} zones, but {n_zones} requested")
        df = df.iloc[:, :min(n_zones, zones_available)]

    rows_available = df.shape[0]
    if rows_available < n_hours:
        print(f"Warning: Data has {rows_available} rows, but {n_hours} requested")
    return df.iloc[:min(n_hours, rows_available)]

def check_same_columns(dfs):
    """
    Verify that every given DataFrame has the same zone columns in the same order.

    `None` entries are skipped; the first non-None frame is the reference.

    Args:
        dfs: list[pandas.DataFrame], DataFrames to check
    Raises:
        AssertionError: If column orders differ
    """
    reference = None
    for frame in dfs:
        if frame is None:
            continue
        current = list(frame.columns)
        if reference is None:
            reference = current
            continue
        assert current == reference, "Zone column 순서가 일치하지 않습니다."

def read_timeseries_csv(path, tz=None):
    """
    Read a time-series CSV whose first column is the timestamp and whose
    remaining columns are zone IDs.

    A path that does not already start with cfg.data["DATA_DIR"] is resolved
    relative to that directory.

    Args:
        path: str or Path, path to CSV file
        tz: str, timezone (default: cfg.data_mode["TIMEZONE"])
    Returns:
        pandas.DataFrame: Aligned DataFrame with normalized zone IDs,
        or None if the file does not exist (an error line is printed)
    """
    if str(path).startswith(cfg.data["DATA_DIR"]):
        path = Path(path)
    else:
        path = Path(cfg.data["DATA_DIR"]) / path

    try:
        frame = pd.read_csv(path, index_col=0, parse_dates=[0])
    except FileNotFoundError:
        print(f"Error: File {path} not found")
        return None

    frame = align_hourly(frame, tz=tz)
    frame.columns = normalize_zone_ids(frame.columns)
    return frame

def read_square_noindex_csv(path, as_float=False, force_diag_zero=True):
    """
    Read a square-matrix CSV that has a header row but no row-index column.

    Columns whose name starts with "Unnamed" are dropped, the remaining
    headers are normalized as zone IDs and reused as the row index.

    Args:
        path: str or Path, path to CSV file
        as_float: bool, convert to float (default: False -> int)
        force_diag_zero: bool, set diagonal to zero (default: True)
    Returns:
        pandas.DataFrame: Square matrix with normalized zone IDs,
        or None if the file does not exist (an error line is printed)
    Raises:
        ValueError: If the matrix is not square
    """
    path = Path(cfg.data["DATA_DIR"]) / path if not str(path).startswith(cfg.data["DATA_DIR"]) else Path(path)
    try:
        df = pd.read_csv(path, header=0)
    except FileNotFoundError:
        print(f"Error: File {path} not found")
        return None

    dump_cols = [c for c in df.columns if str(c).startswith("Unnamed")]
    if dump_cols:
        df = df.drop(columns=dump_cols)

    # Validate the shape before touching the index so a non-square file is
    # reported with this message rather than a generic length-mismatch error.
    n_rows, n_cols = df.shape
    if n_rows != n_cols:
        raise ValueError(f"정방 행렬이 아님: rows={n_rows}, cols={n_cols}")

    labels = normalize_zone_ids(df.columns)

    # Work on an explicit, writable copy of the data: writing through
    # `df.values` only works when pandas hands out a writable view, which is
    # not guaranteed (e.g. with Copy-on-Write enabled).
    values = df.astype(float if as_float else int).to_numpy(copy=True)
    if force_diag_zero:
        np.fill_diagonal(values, 0)
    return pd.DataFrame(values, index=labels, columns=labels)

def read_inf_csv_rows_are_zones(path):
    """
    Read inf.csv: rows = zone ID, columns =
    ['longitude','latitude','charge_count','area','perimeter'].

    Header names are matched case-insensitively after stripping whitespace.
    Expected columns that are absent are created and filled with 0.0 (with a
    warning); non-numeric cells are coerced to 0.0.

    Args:
        path: str or Path, path to CSV file
    Returns:
        pandas.DataFrame: (Z,5) DataFrame with normalized zone IDs,
        or None if the file does not exist (an error line is printed)
    """
    path = Path(cfg.data["DATA_DIR"]) / path if not str(path).startswith(cfg.data["DATA_DIR"]) else Path(path)
    expected = ["longitude", "latitude", "charge_count", "area", "perimeter"]
    try:
        df = pd.read_csv(path, index_col=0)
    except FileNotFoundError:
        print(f"Error: File {path} not found")
        return None

    df.index = normalize_zone_ids(df.index)
    df.columns = [str(c).strip() for c in df.columns]

    # Map any casing of an expected header ("Longitude", "LATITUDE",
    # "Latitude", ...) to its lowercase form. A column is only renamed when the
    # lowercase name is not already taken, so no duplicate columns are created.
    taken = set(df.columns)
    rename_map = {}
    for col in df.columns:
        target = col.lower()
        if target in expected and target not in taken:
            rename_map[col] = target
            taken.add(target)
    df = df.rename(columns=rename_map)

    for c in expected:
        if c not in df.columns:
            print(f"Warning: Column {c} missing in {path}, filling with 0.0")
            df[c] = 0.0
        df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0)

    return df[expected]

def read_poi_csv(path):
    """
    Read poi.csv: rows = primary_types, columns = ['longitude','latitude'].

    Longitude may be given as 'longitude', 'lon' or 'lng' and latitude as
    'latitude' or 'lat' (case-insensitive); missing coordinate columns are
    created and filled with NA (with a warning).

    Args:
        path: str or Path, path to CSV file
    Returns:
        pandas.DataFrame: DataFrame with ['longitude','latitude'],
        or None if the file does not exist (an error line is printed)
    """
    if str(path).startswith(cfg.data["DATA_DIR"]):
        path = Path(path)
    else:
        path = Path(cfg.data["DATA_DIR"]) / path

    try:
        df = pd.read_csv(path, index_col=0)
    except FileNotFoundError:
        print(f"Error: File {path} not found")
        return None

    # lowercase/stripped header -> header as it appears in the file
    by_lower = {}
    for header in df.columns:
        by_lower[header.lower().strip()] = header

    lon_src = next((by_lower[k] for k in ("longitude", "lon", "lng") if by_lower.get(k)), None)
    lat_src = next((by_lower[k] for k in ("latitude", "lat") if by_lower.get(k)), None)

    renames = {}
    if lon_src and lon_src != "longitude":
        renames[lon_src] = "longitude"
    if lat_src and lat_src != "latitude":
        renames[lat_src] = "latitude"
    if renames:
        df = df.rename(columns=renames)

    for coord in ("longitude", "latitude"):
        if coord in df.columns:
            continue
        print(f"Warning: Column {coord} missing in {path}, filling with NA")
        df[coord] = pd.NA
    return df[["longitude", "latitude"]]

def expand_rain_onehot(df_weather):
    """
    Expand nRAIN (0..3) from the weather table into a 4-channel one-hot vector
    and return it next to the five base weather channels.

    Missing nRAIN values count as 0; values outside [0,3] are clipped (with a
    warning). Base columns absent from the input are filled with 0.0 (with a
    warning). All one-hot columns are float, so the result has one numeric
    dtype regardless of which rain levels occur in the data.

    Args:
        df_weather: pandas.DataFrame with 'nRAIN' column
    Returns:
        pandas.DataFrame: [T,P0,P,U,Td,rain_0,rain_1,rain_2,rain_3]
    Raises:
        AssertionError: If 'nRAIN' is missing or the result is not 9 columns wide
    """
    assert "nRAIN" in df_weather.columns, "weather.csv에 nRAIN이 필요합니다."
    n_rain = df_weather["nRAIN"].fillna(0).astype(int)
    if n_rain.max() > 3 or n_rain.min() < 0:
        print(f"Warning: nRAIN values out of range [0,3]: {n_rain.min()} to {n_rain.max()}")
        n_rain = n_rain.clip(0, 3)

    # dtype=float keeps observed levels (otherwise bool/uint8, depending on the
    # pandas version) consistent with the zero-filled levels added by reindex.
    rain_cols = [f"rain_{k}" for k in range(4)]
    rain_oh = pd.get_dummies(n_rain, prefix="rain", dtype=float)
    rain_oh = rain_oh.reindex(columns=rain_cols, fill_value=0.0)

    base_cols = ["T", "P0", "P", "U", "Td"]
    base = df_weather.copy()
    for c in base_cols:
        if c not in base.columns:
            print(f"Warning: Column {c} missing in weather data, filling with 0.0")
            base[c] = 0.0
    base = base[base_cols]

    df = pd.concat([base, rain_oh], axis=1)
    assert df.shape[1] == 9, f"Weather 채널이 9개여야 합니다, got {df.shape[1]}"
    return df

def make_time_features(index, tz=None):
    """
    Build 5 time channels: [hour_sin, hour_cos, dow_sin, dow_cos, is_off_cn].

    Hour-of-day and day-of-week are encoded cyclically; is_off_cn is 1.0 when
    chinese_calendar reports the timestamp as not a workday, else 0.0.

    Args:
        index: pandas.DatetimeIndex
        tz: str, timezone (default: cfg.data_mode["TIMEZONE"])
    Returns:
        np.ndarray: (T, 5)
    """
    tz = tz or cfg.data_mode["TIMEZONE"]
    if index.tz is None:
        local_idx = index.tz_localize(tz)
    else:
        local_idx = index.tz_convert(tz)

    hours = local_idx.hour.values
    weekdays = local_idx.dayofweek.values

    # Workday lookup: one chinese_calendar call per timestamp (a Python loop,
    # not a vectorized operation).
    off_flags = np.array(
        [0.0 if cc.is_workday(ts.to_pydatetime()) else 1.0 for ts in pd.to_datetime(local_idx)],
        dtype=float,
    )

    channels = [
        np.sin(2 * np.pi * hours / 24),
        np.cos(2 * np.pi * hours / 24),
        np.sin(2 * np.pi * weekdays / 7),
        np.cos(2 * np.pi * weekdays / 7),
        off_flags,
    ]
    return np.stack(channels, axis=1)  # (T,5)

def make_W(adj_df, dist_df, eps=1e-6, clip_max=None):
    """
    Build a row-normalized spatial weight matrix from adjacency and distance.

    Weights are inverse distances restricted to adjacent pairs, with a zero
    diagonal. Each row is then divided by its sum. A zone with no weighted
    neighbour (row sum 0) receives a self-weight of 1 so that its row is still
    a valid distribution and no division by zero occurs. Works for any number
    of zones, including a single one.

    Args:
        adj_df: pandas.DataFrame, adjacency matrix
        dist_df: pandas.DataFrame, distance matrix
        eps: float, small value to avoid division by zero
        clip_max: float, optional max value for weights (applied before normalization)
    Returns:
        np.ndarray: (Z,Z) weight matrix whose rows each sum to 1
    Raises:
        AssertionError: If shapes, column labels or index labels differ
    """
    assert adj_df.shape == dist_df.shape, "Adjacency and distance matrices must have same shape"
    assert list(adj_df.columns) == list(dist_df.columns), "Column names must match"
    assert list(adj_df.index) == list(dist_df.index), "Index names must match"

    adj = adj_df.values.astype(float)
    dist = dist_df.values.astype(float)
    invd = 1.0 / (dist + eps)
    invd *= adj
    np.fill_diagonal(invd, 0.0)
    if clip_max is not None:
        invd = np.clip(invd, 0, clip_max)

    rowsum = invd.sum(axis=1, keepdims=True)
    # Index the kept axis instead of squeeze(): squeeze() would collapse a
    # (1,1) row-sum to a 0-d array when there is a single zone.
    isolated = np.flatnonzero(rowsum[:, 0] == 0)
    if isolated.size:
        invd[isolated, isolated] = 1.0
        rowsum = invd.sum(axis=1, keepdims=True)

    W = invd / rowsum
    return W

def _neighbor_channels(x, W, eps):
    """Return (neighbour average, gap to neighbours, ratio to neighbours) for a (T,Z) array."""
    # W is row-normalized: W[i, j] is the weight of neighbour j for zone i.
    # The neighbour average of zone i at time t is sum_j W[i, j] * x[t, j],
    # i.e. (x @ W.T)[t, i]. Using x @ W would weight by columns of W, which do
    # not sum to 1 unless W happens to be symmetric.
    nbr = x @ W.T
    return nbr, x - nbr, x / (nbr + eps)


def make_spatial_features(occ_df, dur_df, vol_df, W, eps=1e-6):
    """
    Build spatial features: 3ch (occupancy only) or 9ch (occ + dur + vol).

    For each variable three channels are produced: the weighted neighbour
    average, the gap (own value minus neighbour average) and the ratio (own
    value over neighbour average). Channel order is
    [nbr, gap, ratio] for 3ch and
    [nbr_occ, nbr_dur, nbr_vol, gap_occ, gap_dur, gap_vol,
     ratio_occ, ratio_dur, ratio_vol] for 9ch.

    Args:
        occ_df: pandas.DataFrame, occupancy data (T, Z)
        dur_df: pandas.DataFrame, duration data (None if 3ch)
        vol_df: pandas.DataFrame, volume data (None if 3ch)
        W: np.ndarray, row-normalized weight matrix (Z,Z)
        eps: float, small value to avoid division by zero
    Returns:
        np.ndarray: (Z,T,3) or (Z,T,9) depending on cfg.data_mode["BASELINE_OCC_ONLY"]
    Raises:
        AssertionError: In 9ch mode, if dur_df/vol_df are missing or column orders differ
    """
    occ = occ_df.values.astype(float)
    nbr_occ, gap_occ, ratio_occ = _neighbor_channels(occ, W, eps)

    if cfg.data_mode["BASELINE_OCC_ONLY"]:
        arr = np.stack([nbr_occ, gap_occ, ratio_occ], axis=-1)  # (T,Z,3)
        return np.transpose(arr, (1, 0, 2))  # (Z,T,3)

    assert dur_df is not None and vol_df is not None, "dur_df and vol_df required for 9ch"
    check_same_columns([occ_df, dur_df, vol_df])

    nbr_dur, gap_dur, ratio_dur = _neighbor_channels(dur_df.values.astype(float), W, eps)
    nbr_vol, gap_vol, ratio_vol = _neighbor_channels(vol_df.values.astype(float), W, eps)

    chs = [nbr_occ, nbr_dur, nbr_vol,
           gap_occ, gap_dur, gap_vol,
           ratio_occ, ratio_dur, ratio_vol]
    arr = np.stack(chs, axis=-1)  # (T,Z,9)
    return np.transpose(arr, (1, 0, 2))  # (Z,T,9)

def build_windows(index, L=None, H=None, step=1, use_future=False):
    """
    Enumerate sliding-window index pairs over a time index.

    Each window is (start, end) with both ends inclusive and length L. When
    `use_future` is set, the last H positions are kept free after the final
    window end.

    Args:
        index: pandas.DatetimeIndex (only its length is used)
        L: int, input window length (default: cfg.data_mode["L"])
        H: int, horizon length (default: cfg.data_mode["H"])
        step: int, stride
        use_future: bool, reserve H steps after each window end
    Returns:
        list[tuple]: [(start, end), ...]; empty (with a warning) if L exceeds the index length
    """
    if L is None:
        L = cfg.data_mode["L"]
    if H is None:
        H = cfg.data_mode["H"]

    Tlen = len(index)
    if L > Tlen:
        print(f"Warning: Window length {L} exceeds index length {Tlen}")
        return []

    reserve = H if use_future else 0
    final_end = Tlen - 1 - reserve
    return [(end - (L - 1), end) for end in range(L - 1, final_end + 1, step)]

def _zone_radius_km(area_m2, perimeter_m, use_perimeter=True, radius_scale=1.0, max_radius_km=None):
    """
    Calculate zone radius (km) from area (m²) and perimeter (m).
    Args:
        area_m2: np.ndarray, zone areas in square meters
        perimeter_m: np.ndarray, zone perimeters in meters
        use_perimeter: bool, use perimeter-based radius
        radius_scale: float, scaling factor for radius
        max_radius_km: float, optional max radius
    Returns:
        np.ndarray: Radii in km
    """
    r_area = np.sqrt((np.maximum(area_m2, 0.0) / 1e6) / np.pi)
    if use_perimeter:
        r_per = (np.maximum(perimeter_m, 0.0) / 1000.0) / (2 * np.pi)
        r = np.minimum(r_area, r_per)
    else:
        r = r_area
    r = r * float(radius_scale)
    if max_radius_km is not None:
        r = np.minimum(r, float(max_radius_km))
    return r

def _haversine_matrix_km(Plon, Plat, Zlon, Zlat):
    """
    Haversine distance matrix (P×Z, km).
    Args:
        Plon: np.ndarray, POI longitudes
        Plat: np.ndarray, POI latitudes
        Zlon: np.ndarray, zone longitudes
        Zlat: np.ndarray, zone latitudes
    Returns:
        np.ndarray: (P,Z) distance matrix in km
    """
    R = 6371.0
    lon1 = np.radians(Plon)[:, None]  # (P,1)
    lat1 = np.radians(Plat)[:, None]  # (P,1)
    lon2 = np.radians(Zlon)[None, :]  # (1,Z)
    lat2 = np.radians(Zlat)[None, :]  # (1,Z)
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
    return 2 * R * np.arcsin(np.sqrt(a))  # (P,Z)

def poi_to_zone_counts_within_area(df_poi, df_inf_z5, zones,
                                   use_perimeter=True,
                                   radius_scale=1.0,
                                   max_radius_km=None,
                                   ensure_types=None):
    """
    Count POIs per type that fall inside a circular buffer around each zone.

    Each zone is approximated by a circle centred on its (longitude, latitude)
    with the radius given by `_zone_radius_km`. A POI is counted for every zone
    whose circle contains it (so one POI may count for several zones).

    Every return path yields a frame indexed by the normalized zone IDs. When
    `ensure_types` is given, the columns are exactly those types in every
    case, including when there are no POIs or no POI falls inside any zone, so
    the downstream feature width does not depend on the data.

    Args:
        df_poi: pandas.DataFrame, POI data with ['longitude','latitude'], indexed by POI type
        df_inf_z5: pandas.DataFrame, zone info with ['longitude','latitude','charge_count','area','perimeter']
        zones: list[str], zone IDs
        use_perimeter: bool, use perimeter-based radius
        radius_scale: float, scaling factor for radius
        max_radius_km: float, optional max radius
        ensure_types: list[str], ensure these POI types in output
    Returns:
        pandas.DataFrame: Zone×type count matrix (float)
    """
    zones = normalize_zone_ids(zones)
    type_cols = list(ensure_types) if ensure_types is not None else []

    def _zero_counts():
        # Shared shape for all "nothing to count" outcomes.
        return pd.DataFrame(0.0, index=zones, columns=type_cols)

    if df_poi is None or len(df_poi) == 0:
        return _zero_counts()

    # Zones missing from the info table get NaN rows from reindex; these are
    # turned into 0.0 below, giving them a zero radius.
    ztab = df_inf_z5.reindex(zones)
    for c in ["longitude", "latitude", "area", "perimeter"]:
        ztab[c] = pd.to_numeric(ztab[c], errors="coerce").fillna(0.0)

    Zlon = ztab["longitude"].to_numpy(float)
    Zlat = ztab["latitude"].to_numpy(float)
    Zarea = ztab["area"].to_numpy(float)
    Zperi = ztab["perimeter"].to_numpy(float)

    r_km = _zone_radius_km(Zarea, Zperi, use_perimeter=use_perimeter,
                           radius_scale=radius_scale, max_radius_km=max_radius_km)

    poi = df_poi.dropna(subset=["longitude", "latitude"])
    if len(poi) == 0:
        return _zero_counts()

    # Unparseable coordinates become NaN; their distances are NaN and therefore
    # never satisfy the `<=` test below.
    Plon = pd.to_numeric(poi["longitude"], errors="coerce").to_numpy()
    Plat = pd.to_numeric(poi["latitude"], errors="coerce").to_numpy()
    Ptyp = poi.index.astype(str).to_numpy()

    dist = _haversine_matrix_km(Plon, Plat, Zlon, Zlat)  # (P,Z)
    mask = dist <= r_km[None, :]

    p_idx, z_idx = np.where(mask)
    if len(p_idx) == 0:
        return _zero_counts()

    zones_arr = np.array(zones)
    assign = pd.DataFrame({"zone": zones_arr[z_idx], "type": Ptyp[p_idx], "cnt": 1})
    counts = assign.pivot_table(index="zone", columns="type", values="cnt",
                                aggfunc="sum", fill_value=0)

    if ensure_types is not None:
        for t in ensure_types:
            if t not in counts.columns:
                counts[t] = 0
        counts = counts[ensure_types]

    # Zones without any hit are absent from the pivot; add them back as zeros.
    counts = counts.reindex(index=zones).fillna(0).astype(float)
    return counts

In [7]:
# ===== Cell 3: Load data & prepare tensors (dummy/full, weather 9ch, time 5ch, spatial (3ch|9ch), static+POI) =====
# This cell does not build (B, C, L) tensors. It only prepares the raw "ingredients"
# that the next cell slices into windows.
# Required utilities: read_timeseries_csv, subset_dummy_timeseries, check_same_columns, normalize_zone_ids,
#            read_square_noindex_csv, read_inf_csv_rows_are_zones, read_poi_csv,
#            expand_rain_onehot, make_time_features, make_W, make_spatial_features,
#            poi_to_zone_counts_within_area, _zone_radius_km

# 0) Build file paths
occ_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["occ"])
dur_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["dur"])
vol_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["vol"])
ep_path  = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["e_price"])
sp_path  = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["s_price"])
w_path   = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["weather"])
adj_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["adj"])
dst_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["dist"])
inf_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["inf"])
poi_path = os.path.join(cfg.data["DATA_DIR"], cfg.data["FN"]["poi"])

# 1) Load time series + switch between dummy and full data
df_occ = read_timeseries_csv(occ_path)
df_dur = read_timeseries_csv(dur_path)
df_vol = read_timeseries_csv(vol_path)
df_ep  = read_timeseries_csv(ep_path)
df_sp  = read_timeseries_csv(sp_path)
df_w   = read_timeseries_csv(w_path)

# Check for None returns from file loading
# (the readers print an error and return None on a missing file; fail here instead)
for df, name in [(df_occ, "occ"), (df_dur, "dur"), (df_vol, "vol"),
                 (df_ep, "e_price"), (df_sp, "s_price"), (df_w, "weather")]:
    if df is None:
        raise FileNotFoundError(f"Failed to load {name} data")

# Apply dummy mode if enabled
# (weather columns are features, not zones, so only its rows are truncated)
df_occ = subset_dummy_timeseries(df_occ, is_zone_table=True)
df_dur = subset_dummy_timeseries(df_dur, is_zone_table=True)
df_vol = subset_dummy_timeseries(df_vol, is_zone_table=True)
df_ep  = subset_dummy_timeseries(df_ep,  is_zone_table=True)
df_sp  = subset_dummy_timeseries(df_sp,  is_zone_table=True)
df_w   = subset_dummy_timeseries(df_w,   is_zone_table=False)

# Verify that all zone time-series tables share the same zone column order
check_same_columns([df_occ, df_dur, df_vol, df_ep, df_sp])

# Zone list / zone count / number of time steps
zones = list(df_occ.columns)
Z = len(zones)
Tlen = len(df_occ)

# 2) Load the square matrices (adj, distance) and reorder them to match `zones`
df_adj = read_square_noindex_csv(adj_path, as_float=False)
df_dist = read_square_noindex_csv(dst_path, as_float=True)
if df_adj is None or df_dist is None:
    raise FileNotFoundError("Failed to load adjacency or distance matrix")
# .loc raises KeyError if a zone from the time series is absent from the matrix
df_adj = df_adj.loc[zones, zones]
df_dist = df_dist.loc[zones, zones]

# 3) Load the static table (inf) + reindex
df_inf = read_inf_csv_rows_are_zones(inf_path)
if df_inf is None:
    print("Warning: inf.csv failed to load, using default zeros")
    df_inf = pd.DataFrame(0.0, index=zones,
                          columns=["longitude", "latitude", "charge_count", "area", "perimeter"])
df_inf = df_inf.reindex(zones)

# 4) Load the POI table
df_poi = read_poi_csv(poi_path)
if df_poi is None:
    print("Warning: poi.csv failed to load, using None")

# 5) Derived features: Weather 9 channels, Time 5 channels
df_w9 = expand_rain_onehot(df_w)  # (T, 9)
assert df_w9.shape[1] == cfg.embed["weather"]["in_ch"], f"Weather channels must be {cfg.embed['weather']['in_ch']}"
time5 = make_time_features(df_occ.index)  # (T, 5) np.ndarray
assert time5.shape[1] == cfg.embed["time"]["in_ch"], f"Time channels must be {cfg.embed['time']['in_ch']}"
time6 = time5  # Compatibility

# 6) Spatial weights W and spatial features (3ch | 9ch)
W = make_W(df_adj, df_dist, eps=1e-6, clip_max=None)  # (Z, Z)
spatial_Z_T = make_spatial_features(
    df_occ,
    None if cfg.data_mode["BASELINE_OCC_ONLY"] else df_dur,
    None if cfg.data_mode["BASELINE_OCC_ONLY"] else df_vol,
    W
)  # (Z, T, 3|9)
assert spatial_Z_T.shape[2] == cfg.embed["spatial"]["in_ch"], f"Spatial channels must be {cfg.embed['spatial']['in_ch']}"

# 7) Static features: combine inf + POI type counts (circular-buffer approximation)
base_cols = ["charge_count", "area", "perimeter"]  # Include perimeter for consistency
inf_z5 = df_inf  # (Z,5)

# POI settings (could be moved into cfg)
POI_USE_PERIMETER = True
POI_RADIUS_SCALE = 1.0
POI_MAX_RADIUS_KM = None
PRIMARY_TYPES = ['lifestyle services', 'business and residential', 'food and beverage services']

# Count POIs per zone and type; fall back to an all-zero table when no POI data is available.
if df_poi is not None and len(df_poi) > 0:
    poi_counts = poi_to_zone_counts_within_area(
        df_poi=df_poi,
        df_inf_z5=inf_z5,
        zones=zones,
        use_perimeter=POI_USE_PERIMETER,
        radius_scale=POI_RADIUS_SCALE,
        max_radius_km=POI_MAX_RADIUS_KM,
        ensure_types=PRIMARY_TYPES
    )
else:
    # Build the zeros directly as float64. Creating an empty (object-dtype,
    # all-NaN) frame and calling fillna on it relies on implicit downcasting.
    poi_counts = pd.DataFrame(0.0, index=zones, columns=PRIMARY_TYPES)

# Join the selected inf columns with the POI counts column-wise (aligned on the
# zone index); cells left empty by the alignment become 0.0.
static_df = pd.concat([inf_z5[base_cols], poi_counts], axis=1).fillna(0.0)
static_arr = static_df.values.astype(float)  # (Z, F_static)
F_static = static_arr.shape[1]
cfg.embed["static"]["in_dim"] = F_static  # Update cfg with static dimension

# 8) Print summary
print(f"Zones: {Z} | Time steps: {Tlen}")
print(f"Weather {cfg.embed['weather']['in_ch']}ch: {df_w9.shape}   (T, {cfg.embed['weather']['in_ch']})")
print(f"Time {cfg.embed['time']['in_ch']}ch:    {time5.shape}   (T, {cfg.embed['time']['in_ch']})")
print(f"W matrix:    {W.shape}       (Z, Z)")
print(f"Spatial {cfg.embed['spatial']['in_ch']}ch: {spatial_Z_T.shape}   (Z, T, {cfg.embed['spatial']['in_ch']})")
print(f"Static dim:  {F_static}      (inf[{base_cols}] + POI types = {poi_counts.shape[1]})")

# Outputs of this cell:
# - df_occ, df_dur, df_vol        (T, Z)
# - df_ep, df_sp                  (T, Z)
# - df_w9                         (T, 9)
# - time5                         (T, 5)  (np.ndarray)  # (compat: time6 = time5)
# - df_adj, df_dist               (Z, Z)
# - W                             (Z, Z)
# - spatial_Z_T                   (Z, T, 3|9)
# - static_arr                    (Z, F_static)
# - zones, Z, Tlen, F_static

Zones: 2 | Time steps: 48
Weather 9ch: (48, 9)   (T, 9)
Time 5ch:    (48, 5)   (T, 5)
W matrix:    (2, 2)       (Z, Z)
Spatial 3ch: (2, 48, 3)   (Z, T, 3)
Static dim:  6      (inf[['charge_count', 'area', 'perimeter']] + POI types = 3)


In [17]:


# ===== Cell 4: Window builder & lightweight Dataset (no target) =====
# Purpose: build (B, C, L) batch tensors from sliding windows (no targets; used to shape-check the embedding modules)

# --------- helpers: slice functions (return channel-first C x L) ---------
def slice_occ_block(df_occ, s, t):
    """
    Cut one window out of the occupancy table, channel-first.

    Args:
        df_occ: pandas.DataFrame, occupancy data (T, Z)
        s: int, start index
        t: int, end index (inclusive)
    Returns:
        np.ndarray: (Z, 1, L) float32 occupancy data
    Raises:
        ValueError: If df_occ is None or empty
    """
    if df_occ is None or df_occ.empty:
        raise ValueError("df_occ is None or empty")
    window = df_occ.iloc[s:t+1]
    occ_zl = np.transpose(window.to_numpy(dtype=np.float32))  # (Z, L)
    return np.expand_dims(occ_zl, axis=1)                     # (Z, 1, L)

def slice_price_block(df_ep, df_sp, s, t):
    """
    Cut one window out of both price tables and stack them as two channels.

    Args:
        df_ep: pandas.DataFrame, electricity price data (T, Z)
        df_sp: pandas.DataFrame, service price data (T, Z)
        s: int, start index
        t: int, end index (inclusive)
    Returns:
        np.ndarray: (Z, 2, L) float32; channel 0 = electricity, channel 1 = service
    Raises:
        ValueError: If either table is None or empty
    """
    if any(frame is None for frame in (df_ep, df_sp)) or df_ep.empty or df_sp.empty:
        raise ValueError("df_ep or df_sp is None or empty")
    rows = slice(s, t + 1)
    channels = [frame.iloc[rows].to_numpy(dtype=np.float32).T  # each (Z, L)
                for frame in (df_ep, df_sp)]
    return np.stack(channels, axis=1)                          # (Z, 2, L)

def slice_weather_block(df_w9, s, t, Z):
    """
    Cut one window out of the weather table and repeat it for every zone.

    The result is a broadcast view (no data is copied per zone), so it is
    read-only.

    Args:
        df_w9: pandas.DataFrame, weather data (T, 9)
        s: int, start index
        t: int, end index (inclusive)
        Z: int, number of zones
    Returns:
        np.ndarray: (Z, 9, L), broadcasted float32 weather data
    Raises:
        ValueError: If df_w9 is None or empty
    """
    if df_w9 is None or df_w9.empty:
        raise ValueError("df_w9 is None or empty")
    window = df_w9.iloc[s:t+1]
    weather_cl = window.to_numpy(dtype=np.float32).T        # (9, L)
    return np.broadcast_to(weather_cl, (Z, *weather_cl.shape))  # (Z, 9, L)

def slice_time_block(time_feat, s, t, Z):
    """
    Cut one window out of the time-feature array and repeat it for every zone.

    The result is a broadcast view (no data is copied per zone), so it is
    read-only.

    Args:
        time_feat: np.ndarray, time features (T, C_time)
        s: int, start index
        t: int, end index (inclusive)
        Z: int, number of zones
    Returns:
        np.ndarray: (Z, C_time, L), broadcasted float32 time features
    Raises:
        ValueError: If time_feat is None or empty
    """
    if time_feat is None or time_feat.size == 0:
        raise ValueError("time_feat is None or empty")
    window = time_feat[s:t+1]
    time_cl = window.astype(np.float32).T                 # (C_time, L)
    return np.broadcast_to(time_cl, (Z, *time_cl.shape))  # (Z, C_time, L)

def slice_spatial_block(spatial_Z_T, s, t):
    """
    Cut one window out of the spatial feature array and make it channel-first.

    Args:
        spatial_Z_T: np.ndarray, spatial features (Z, T, C_spatial)
        s: int, start index
        t: int, end index (inclusive)
    Returns:
        np.ndarray: (Z, C_spatial, L) float32
    Raises:
        ValueError: If spatial_Z_T is None or empty
    """
    if spatial_Z_T is None or spatial_Z_T.size == 0:
        raise ValueError("spatial_Z_T is None or empty")
    window = spatial_Z_T[:, s:t+1, :].astype(np.float32)  # (Z, L, C_spatial)
    return window.swapaxes(1, 2)                          # (Z, C_spatial, L)

# --------- Dataset (lightweight, no target) ---------
class EVEmbeddingDataset(Dataset):
    """
    Each sample = one zone-window. Returns channel-first tensors (C, L).

    Samples are ordered zone-major: index ``idx`` maps to zone
    ``idx // len(windows)`` and window ``idx % len(windows)``.

    Keys:
        - zone_id:      torch.LongTensor scalar
        - window:       torch.LongTensor scalar
        - x_local_occ:  (1, L), occupancy for Moirai (occ-only)
        - x_local:      (3, L), occ+dur+vol; falls back to the occupancy-only
                        block when BASELINE_OCC_ONLY is set
        - x_price:      (2, L), electricity and service prices
        - x_weather:    (9, L), weather features
        - x_time:       (C_time, L), time features (usually 5)
        - x_spatial:    (C_spatial, L), spatial features (3 or 9)
        - x_static:     (F_static,), static features

    Raises:
        ValueError: if df_occ is None/empty, or if any input that every sample
            needs (prices, weather, time, spatial, static) is None.
        AssertionError: if time lengths, zone counts or channel counts disagree
            with each other or with cfg.
    """
    def __init__(self, windows, df_occ, df_dur, df_vol, df_ep, df_sp,
                 df_w9, time_feat, spatial_Z_T, static_arr, zones):
        super().__init__()
        # Snapshot the mode once: __init__ decides whether df_dur/df_vol are kept,
        # so __getitem__ must use the same decision even if cfg is edited later.
        self.occ_only = bool(cfg.data_mode["BASELINE_OCC_ONLY"])

        self.windows = windows
        self.df_occ = df_occ
        self.df_dur = None if self.occ_only else df_dur
        self.df_vol = None if self.occ_only else df_vol
        self.df_ep, self.df_sp = df_ep, df_sp
        self.df_w9 = df_w9
        self.time_feat = time_feat
        self.spatial = spatial_Z_T
        self.static_arr = static_arr
        self.zones = zones
        self.Z = len(zones)

        # Sanity checks
        if self.df_occ is None or self.df_occ.empty:
            raise ValueError("df_occ is None or empty")

        # Every sample reads these inputs, so reject None here with a clear
        # message instead of failing on `.shape` below or on the first __getitem__.
        required = {
            "df_ep": self.df_ep,
            "df_sp": self.df_sp,
            "df_w9": self.df_w9,
            "time_feat": self.time_feat,
            "spatial_Z_T": self.spatial,
            "static_arr": self.static_arr,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(f"Required inputs are None: {', '.join(missing)}")

        T = self.df_occ.shape[0]
        # df_dur / df_vol are legitimately None in occupancy-only mode
        assert all(df is None or df.shape[0] == T for df in (self.df_dur, self.df_vol, self.df_ep, self.df_sp, self.df_w9)), \
            "All time-series must share the same T length"
        assert self.time_feat.shape[0] == T, "Time features T must match occ T"
        assert self.spatial.shape[0] == self.Z == self.static_arr.shape[0], "Z mismatch among spatial/static/zones"
        assert self.spatial.shape[2] == cfg.embed["spatial"]["in_ch"], f"Spatial channels must be {cfg.embed['spatial']['in_ch']}"
        assert self.df_w9.shape[1] == cfg.embed["weather"]["in_ch"], f"Weather channels must be {cfg.embed['weather']['in_ch']}"
        assert self.time_feat.shape[1] == cfg.embed["time"]["in_ch"], f"Time channels must be {cfg.embed['time']['in_ch']}"
        assert self.static_arr.shape[1] == cfg.embed["static"]["in_dim"], f"Static dim must be {cfg.embed['static']['in_dim']}"
        if not self.occ_only:
            assert self.df_dur is not None and self.df_vol is not None, "df_dur and df_vol required for multi-channel"

    def __len__(self):
        """Number of (zone, window) pairs."""
        return self.Z * len(self.windows)

    def __getitem__(self, idx):
        """Build the sample dict for flat index ``idx`` (zone-major order)."""
        nW = len(self.windows)
        z = idx // nW
        w = idx % nW
        s, t = self.windows[w]

        # Slice data: each helper returns a block for all zones, from which
        # this sample's zone is selected.
        occ_only = slice_occ_block(self.df_occ, s, t)[z]                      # (1, L)
        if not self.occ_only:
            local = slice_local_block(self.df_occ, self.df_dur, self.df_vol, s, t)[z]  # (3, L)
        else:
            local = occ_only  # Fallback to occ-only
        price = slice_price_block(self.df_ep, self.df_sp, s, t)[z]            # (2, L)
        weather = slice_weather_block(self.df_w9, s, t, self.Z)[z]            # (9, L)
        timeb = slice_time_block(self.time_feat, s, t, self.Z)[z]             # (C_time, L)
        spatial = slice_spatial_block(self.spatial, s, t)[z]                  # (C_spatial, L)
        static = self.static_arr[z].astype(np.float32)                        # (F_static,)

        sample = {
            "zone_id": torch.tensor(z, dtype=torch.long),
            "window": torch.tensor(w, dtype=torch.long),
            "x_local_occ": torch.tensor(occ_only, dtype=torch.float32),
            "x_local": torch.tensor(local, dtype=torch.float32),
            "x_price": torch.tensor(price, dtype=torch.float32),
            "x_weather": torch.tensor(weather, dtype=torch.float32),
            "x_time": torch.tensor(timeb, dtype=torch.float32),
            "x_spatial": torch.tensor(spatial, dtype=torch.float32),
            "x_static": torch.tensor(static, dtype=torch.float32),
        }
        return sample

# Update collate_fn to handle x_local
def collate_fn(batch):
    """
    Merge EVEmbeddingDataset samples into one batch of channel-first tensors.

    Each sequence entry is stacked along a new batch axis. If axis 1 of the
    stacked tensor does not match the channel count expected from cfg, axes 1
    and 2 are swapped (treating the input as (B, L, C)) before the channel
    count is asserted.

    Args:
        batch: list[dict], samples from EVEmbeddingDataset
    Returns:
        dict: "zone_id"/"window" as (B,), sequence keys as (B, C, L),
              "x_static" as (B, F_static)
    Raises:
        AssertionError: if a channel count or the static dim disagrees with cfg.
    """
    def _stack(key):
        # New leading batch axis over the per-sample tensors
        return torch.stack([sample[key] for sample in batch], dim=0)

    collated = {
        "zone_id": _stack("zone_id"),  # (B,)
        "window": _stack("window"),    # (B,)
    }

    channels_by_key = {
        "x_local_occ": 1,
        "x_local": 1 if cfg.data_mode["BASELINE_OCC_ONLY"] else 3,
        "x_weather": cfg.embed["weather"]["in_ch"],
        "x_price": cfg.embed["price"]["in_ch"],
        "x_spatial": cfg.embed["spatial"]["in_ch"],
        "x_time": cfg.embed["time"]["in_ch"],
    }
    for k, expected_ch in channels_by_key.items():
        stacked = _stack(k)  # (B, C, L) or (B, L, C)
        if stacked.shape[1] != expected_ch:
            stacked = stacked.transpose(1, 2)  # (B, L, C) -> (B, C, L)
        assert stacked.shape[1] == expected_ch, f"{k} must have {expected_ch} channels, got {stacked.shape[1]}"
        collated[k] = stacked

    collated["x_static"] = _stack("x_static")  # (B, F_static)
    assert collated["x_static"].shape[1] == cfg.embed["static"]["in_dim"], f"x_static shape mismatch: expected {cfg.embed['static']['in_dim']}, got {collated['x_static'].shape[1]}"
    return collated

# Build the sliding windows, the dataset and the (shuffled) DataLoader
windows = build_windows(df_occ.index)
if not len(windows):
    raise ValueError(f"No windows built. Check L({cfg.data_mode['L']}) <= T({len(df_occ)}).")

dataset = EVEmbeddingDataset(
    windows=windows,
    df_occ=df_occ, df_dur=df_dur, df_vol=df_vol,
    df_ep=df_ep, df_sp=df_sp,
    df_w9=df_w9,
    time_feat=time5,
    spatial_Z_T=spatial_Z_T,
    static_arr=static_arr,
    zones=zones,
)

loader = DataLoader(dataset, batch_size=cfg.train["batch_size"], shuffle=True, collate_fn=collate_fn)

# --------- Smoke test: take one batch and print shapes ---------
# Each row: printed label, batch key, and the shape cfg says we should see.
batch = next(iter(loader))
shape_report = [
    ("zone_id     :", "zone_id", "(B,)"),
    ("window      :", "window", "(B,)"),
    ("x_local_occ :", "x_local_occ", f"(B, 1, {cfg.data_mode['L']})"),
    ("x_price     :", "x_price", f"(B, {cfg.embed['price']['in_ch']}, {cfg.data_mode['L']})"),
    ("x_weather   :", "x_weather", f"(B, {cfg.embed['weather']['in_ch']}, {cfg.data_mode['L']})"),
    ("x_time      :", "x_time", f"(B, {cfg.embed['time']['in_ch']}, {cfg.data_mode['L']})"),
    ("x_spatial   :", "x_spatial", f"(B, {cfg.embed['spatial']['in_ch']}, {cfg.data_mode['L']})"),
    ("x_static    :", "x_static", f"(B, {cfg.embed['static']['in_dim']})"),
]
for label, batch_key, expected_shape in shape_report:
    print(label, tuple(batch[batch_key].shape), expected_shape)

zone_id     : (50,) (B,)
window      : (50,) (B,)
x_local_occ : (50, 1, 24) (B, 1, 24)
x_price     : (50, 2, 24) (B, 2, 24)
x_weather   : (50, 9, 24) (B, 9, 24)
x_time      : (50, 5, 24) (B, 5, 24)
x_spatial   : (50, 3, 24) (B, 3, 24)
x_static    : (50, 6) (B, 6)


In [19]:
import torch
import torch.nn as nn
from tqdm import tqdm
from contextlib import contextmanager
from torch.utils.data import Dataset, DataLoader
import numpy as np


# ===== Cell 5: Moirai(frozen) Feature Extractor with Adapter =====
# Supports: 
# - Single-channel (occ-only) or multi-channel (occ, dur, vol)
# - Multi-scale patch sizes
# - Adapter to map Moirai embeddings to cfg.dims["D"]

# 1) Device setup
# Rebinds the notebook-level DEVICE (a plain string in the earlier setup cell,
# which also considered "mps") to a torch.device. Only CUDA and CPU are
# considered from here on.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2) Adapter module
class MoiraiAdapter(nn.Module):
    """
    Two-layer MLP (Linear -> ReLU -> Dropout -> Linear) that projects Moirai
    embeddings to the shared embedding size.

    Args:
        in_dim: int, size of the incoming Moirai embedding
        hidden: int, width of the hidden layer (cfg.moirai["adapter"]["hidden"])
        out_dim: int, size of the projected embedding (cfg.dims["D"])
        dropout: float, dropout probability (cfg.moirai["adapter"]["dropout"])
    """
    def __init__(self, in_dim, hidden, out_dim, dropout):
        super().__init__()
        stages = [
            nn.Linear(in_features=in_dim, out_features=hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden, out_features=out_dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Project x (..., in_dim) to (..., out_dim)."""
        projected = self.layers(x)
        return projected

# 3) Load the pretrained Moirai backbone, switch it to eval mode and freeze
#    every parameter so no gradients are tracked for it.
backbone = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-base")
backbone = backbone.to(DEVICE)
backbone.eval()
backbone.requires_grad_(False)

# 4) Pick top-level encoder
def pick_top_encoder(m):
    """
    Locate the encoder submodule of a backbone.

    Lookup order:
      1. a direct ``encoder`` attribute on ``m``;
      2. submodules whose qualified name is "encoder" or ends in ".encoder";
      3. submodules whose qualified name contains "encoder" (case-insensitive).
    Within steps 2 and 3 the match with the shortest qualified name wins
    (first one in traversal order on ties).

    Args:
        m: nn.Module, Moirai backbone
    Returns:
        tuple: (encoder_name, encoder_module)
    Raises:
        RuntimeError: if no submodule matches.
    """
    if hasattr(m, "encoder"):
        return "encoder", m.encoder

    name_filters = (
        lambda name: name == "encoder" or name.endswith(".encoder"),
        lambda name: "encoder" in name.lower(),
    )
    for accepts in name_filters:
        matches = [(name, sub) for name, sub in m.named_modules() if accepts(name)]
        if matches:
            # min() keeps the first of equally short names, like a stable sort
            return min(matches, key=lambda pair: len(pair[0]))

    raise RuntimeError(f"Could not find an encoder module in Moirai backbone: {list(m.named_modules())}")

@contextmanager
def encoder_hook(module):
    """
    Temporarily attach a forward hook that records the module's output.

    While the context is active, every forward pass of ``module`` whose output
    is a 3-dimensional tensor stores a detached copy of it under ``"h"`` in the
    yielded dict (later passes overwrite earlier ones). Other outputs are
    ignored. The hook is removed on exit, even if the body raises.

    Args:
        module: nn.Module, encoder module
    Yields:
        dict: {"h": None} initially; "h" holds the latest captured output
    """
    captured = {"h": None}

    def _store_hidden(_mod, _inputs, output):
        is_3d_tensor = torch.is_tensor(output) and output.ndim == 3
        if is_3d_tensor:
            captured["h"] = output.detach()

    registration = module.register_forward_hook(_store_hidden)
    try:
        yield captured
    finally:
        registration.remove()

# 5) Encode a batch with adapter
@torch.no_grad()
def encode_batch(batch_dict: dict, patch_sizes=(1,), use_multi_channel=False, adapter=None) -> torch.Tensor:
    """
    Encode a batch using Moirai backbone and apply adapter.

    For each patch size the backbone is run once, the encoder output is
    captured through a forward hook and mean-pooled to one vector per sample.
    The per-patch-size vectors are then averaged.

    Args:
        batch_dict: dict, batch from DataLoader; must contain "x_local" when
            use_multi_channel is True, otherwise "x_local_occ"
        patch_sizes: iterable[int], patch sizes for multi-scale encoding
            (must be non-empty)
        use_multi_channel: bool, use occ+dur+vol (True) or occ-only (False)
        adapter: nn.Module, adapter to map embeddings to cfg.dims["D"] (optional)
    Returns:
        torch.Tensor: (B, D) on CPU when an adapter is given; otherwise the
        un-adapted (B, D_moirai) tensor, which is not moved to CPU here.
    Raises:
        ValueError: if patch_sizes is empty.
        RuntimeError: if the hook captured nothing, or the captured tensor has
            no axis of size B among its first two axes.

    NOTE(review): target is passed as (B, L, C) with every variate_id set to 0
    and patch_size filled with the raw values from patch_sizes. Confirm this
    matches the input contract of MoiraiModule.forward.
    """
    patch_sizes = list(patch_sizes)
    if not patch_sizes:
        raise ValueError("patch_sizes must contain at least one patch size")

    source_key = "x_local" if use_multi_channel else "x_local_occ"
    x_local = batch_dict[source_key].to(DEVICE)  # (B, C, L)
    B, C, L = x_local.shape
    target = x_local.transpose(1, 2).contiguous()  # (B, L, C)

    # These inputs do not depend on the patch size, so build them once.
    observed_mask = torch.ones((B, L, C), dtype=torch.bool, device=DEVICE)  # (B, L, C)
    prediction_mask = torch.zeros((B, L), dtype=torch.bool, device=DEVICE)  # (B, L)
    sample_id = torch.arange(B, device=DEVICE).unsqueeze(1).expand(B, L)    # (B, L)
    time_id = torch.arange(L, device=DEVICE).unsqueeze(0).expand(B, L)      # (B, L)
    variate_id = torch.zeros((B, L), dtype=torch.long, device=DEVICE)       # (B, L)

    # Get encoder module
    _, enc_mod = pick_top_encoder(backbone)

    # Multi-scale encoding
    all_hs = []
    for ps in patch_sizes:
        patch_size = torch.full((B, L), ps, dtype=torch.long, device=DEVICE)  # (B, L)
        with encoder_hook(enc_mod) as cache:
            # The return value is unused; only the hooked encoder output matters.
            backbone(
                target=target,
                observed_mask=observed_mask,
                sample_id=sample_id,
                time_id=time_id,
                variate_id=variate_id,
                prediction_mask=prediction_mask,
                patch_size=patch_size,
            )
        hs = cache["h"]
        if hs is None:
            raise RuntimeError(f"Encoder hidden states not captured for patch_size={ps}")

        # Handle (B, N, D_moirai) or (N, B, D_moirai): pool over the non-batch axis
        if hs.shape[0] == B:
            hs = hs.mean(dim=1)  # (B, D_moirai)
        elif hs.shape[1] == B:
            hs = hs.mean(dim=0)  # (B, D_moirai)
        else:
            raise RuntimeError(f"Unexpected hidden shape: {tuple(hs.shape)} (B={B})")
        all_hs.append(hs)

    # Aggregate multi-scale embeddings
    emb = torch.stack(all_hs, dim=0).mean(dim=0)  # (B, D_moirai)

    # Apply adapter if provided
    if adapter is not None:
        emb = adapter(emb.to(DEVICE)).cpu()  # (B, D)
        assert emb.shape[1] == cfg.dims["D"], f"Adapted embedding dim must be {cfg.dims['D']}, got {emb.shape[1]}"

    return emb

# 6) Prepare the encoding settings and the adapter
if cfg.moirai.get("multi_scale", False):
    patch_sizes = [1, 4, 8]
else:
    patch_sizes = [1]
use_multi_channel = not cfg.data_mode["BASELINE_OCC_ONLY"]

# Encode one batch without the adapter just to read off the Moirai embedding width
first_batch = next(iter(loader))
with torch.no_grad():
    first_emb = encode_batch(
        first_batch,
        patch_sizes=patch_sizes,
        use_multi_channel=use_multi_channel,
    )  # (B, D_moirai)
D_moirai = first_emb.size(1)

# Build the adapter (kept in eval mode) only when the config enables it
adapter = (
    MoiraiAdapter(
        in_dim=D_moirai,
        hidden=cfg.moirai["adapter"]["hidden"],
        out_dim=cfg.dims["D"],
        dropout=cfg.moirai["adapter"]["dropout"],
    ).to(DEVICE).eval()
    if cfg.moirai["adapter"]["use"]
    else None
)



Moirai Embeddings: 100%|██████████| 1/1 [00:00<00:00,  9.61it/s]

h_local total: (50, 128) | zones idx shape: (50,)
Saved to: moirai_occ_embeddings.pt





In [20]:
# Export in a fixed order: `loader` shuffles, so every pass over it yields a
# different row order and the saved rows could not be matched with embeddings
# exported by other cells. Iterate an unshuffled view of the same dataset and
# store the window index next to the zone index so each row is identifiable.
embed_loader = DataLoader(
    loader.dataset,
    batch_size=loader.batch_size,
    shuffle=False,
    collate_fn=loader.collate_fn,
)

all_embs, all_zone_ids, all_windows = [], [], []
loader_iterator = tqdm(embed_loader, desc="Moirai Embeddings")
for batch in loader_iterator:
    emb = encode_batch(
        batch,
        patch_sizes=patch_sizes,
        use_multi_channel=use_multi_channel,
        adapter=adapter
    )  # (B, D) or (B, D_moirai) if no adapter
    # encode_batch only moves the result to CPU when an adapter is used;
    # always store CPU tensors so the saved file loads on any machine.
    all_embs.append(emb.cpu())
    all_zone_ids.append(batch["zone_id"].cpu())
    all_windows.append(batch["window"].cpu())

emb = torch.cat(all_embs, dim=0)  # (N_samples, D) or (N_samples, D_moirai)
zones_taken = torch.cat(all_zone_ids, dim=0)  # (N_samples,)
windows_taken = torch.cat(all_windows, dim=0)  # (N_samples,)
print(f"h_local total: {tuple(emb.shape)} | zones idx shape: {tuple(zones_taken.shape)}")

torch.save({"h_local": emb, "zone_id": zones_taken, "window": windows_taken}, cfg.moirai["save_path"])
# Report success only after the file has actually been written
print(f"Saved to: {cfg.moirai['save_path']}")

Moirai Embeddings: 100%|██████████| 1/1 [00:00<00:00,  7.27it/s]

h_local total: (50, 128) | zones idx shape: (50,)
Saved to: moirai_occ_embeddings.pt





In [21]:
class CNN1dEmbedder(nn.Module):
    """
    Sequence embedder: two Conv1d -> BatchNorm1d -> ReLU stages, dropout,
    global average pooling over the length axis and a linear projection.

    Args:
        in_ch: int, input channels (from cfg.embed[...]["in_ch"])
        out_dim: int, output dimension (from cfg.dims["D"])
        hidden: int, hidden channels (from cfg.cnn["hidden"])
        kernel_size: int, convolutional kernel size (from cfg.cnn["kernel_size"])
        dropout: float, dropout rate (from cfg.cnn["dropout"])
    Input: (B, C_in, L)
    Output: (B, out_dim)
    """
    def __init__(self, in_ch: int, out_dim: int, hidden: int, kernel_size: int, dropout: float):
        super().__init__()
        pad = kernel_size // 2
        stages = []
        # First stage maps in_ch -> hidden, second stage hidden -> hidden
        for stage_in in (in_ch, hidden):
            stages += [
                nn.Conv1d(stage_in, hidden, kernel_size=kernel_size, padding=pad),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
            ]
        stages.append(nn.Dropout(dropout))
        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool1d(1)  # Global average over L
        self.proj = nn.Linear(hidden, out_dim)

    def forward(self, x):
        """
        Args:
            x: torch.Tensor, (B, C_in, L)
        Returns:
            torch.Tensor, (B, out_dim)
        Raises:
            ValueError: if x is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input (B, C_in, L), got shape {x.shape}")
        features = self.conv(x)                    # (B, hidden, L')
        pooled = self.pool(features).squeeze(-1)   # (B, hidden)
        return self.proj(pooled)                   # (B, out_dim)

In [25]:
import torch
import torch.nn as nn
from types import SimpleNamespace

# ===== Cell 6: CNN1d Embedders for Weather, Price, Spatial, Time =====
# Aligns with cfg.dims["D"] and previous cells

# 1) CNN1d Embedder class
# NOTE: this redefines (and shadows) the CNN1dEmbedder class from the previous cell.
class CNN1dEmbedder(nn.Module):
    """
    1D-CNN embedder: Conv1d/BatchNorm1d/ReLU twice, dropout, average pooling
    over the length axis, then a linear layer to the target width.

    Args:
        in_ch: int, input channels (from cfg.embed[...]["in_ch"])
        out_dim: int, output dimension (from cfg.dims["D"])
        hidden: int, hidden channels (from cfg.cnn.hidden)
        kernel_size: int, convolutional kernel size (from cfg.cnn.kernel_size)
        dropout: float, dropout rate (from cfg.cnn.dropout)
    Input: (B, C_in, L)
    Output: (B, out_dim)
    """
    def __init__(self, in_ch: int, out_dim: int, hidden: int, kernel_size: int, dropout: float):
        super().__init__()
        conv_kwargs = {"kernel_size": kernel_size, "padding": kernel_size // 2}
        layers = [
            nn.Conv1d(in_ch, hidden, **conv_kwargs),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, **conv_kwargs),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)  # Global average over L
        self.proj = nn.Linear(hidden, out_dim)

    def forward(self, x):
        """
        Args:
            x: torch.Tensor, (B, C_in, L)
        Returns:
            torch.Tensor, (B, out_dim)
        Raises:
            ValueError: if x is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input (B, C_in, L), got shape {x.shape}")
        conv_out = self.conv(x)                      # (B, hidden, L')
        averaged = self.pool(conv_out).squeeze(-1)   # (B, hidden)
        embedding = self.proj(averaged)              # (B, out_dim)
        return embedding

# 2) Initialize embedders
# Rebinds DEVICE for this cell: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CNN configuration: install defaults only when cfg has no `cnn` entry yet
if not hasattr(cfg, "cnn"):
    setattr(cfg, "cnn", SimpleNamespace(hidden=64, kernel_size=5, dropout=0.1))

# Validate configuration (fail on the first missing entry)
required_embed_keys = ["weather", "price", "spatial", "time"]
for key in required_embed_keys:
    if not (key in cfg.embed and "in_ch" in cfg.embed[key]):
        raise ValueError(f"cfg.embed['{key}']['in_ch'] is missing")
if not ("D" in cfg.dims):
    raise ValueError("cfg.dims['D'] is missing")
if not ("L" in cfg.data_mode):
    raise ValueError("cfg.data_mode['L'] is missing")
if not ("save_path" in cfg.moirai):
    raise ValueError("cfg.moirai['save_path'] is missing")

# One embedder per modality, built in a fixed order (weather, price, spatial,
# time) and put straight into eval mode. Arguments are positional:
# (in_ch, out_dim, hidden, kernel_size, dropout).
embed_weather = CNN1dEmbedder(
    cfg.embed["weather"]["in_ch"], cfg.dims["D"],
    cfg.cnn.hidden, cfg.cnn.kernel_size, cfg.cnn.dropout,
).to(DEVICE).eval()

embed_price = CNN1dEmbedder(
    cfg.embed["price"]["in_ch"], cfg.dims["D"],
    cfg.cnn.hidden, cfg.cnn.kernel_size, cfg.cnn.dropout,
).to(DEVICE).eval()

embed_spatial = CNN1dEmbedder(
    cfg.embed["spatial"]["in_ch"], cfg.dims["D"],
    cfg.cnn.hidden, cfg.cnn.kernel_size, cfg.cnn.dropout,
).to(DEVICE).eval()

embed_time = CNN1dEmbedder(
    cfg.embed["time"]["in_ch"], cfg.dims["D"],
    cfg.cnn.hidden, cfg.cnn.kernel_size, cfg.cnn.dropout,
).to(DEVICE).eval()

# 3) Run one batch through the four CNN embedders
batch = next(iter(loader))  # From Cell 4 DataLoader
x_weather, x_price, x_spatial, x_time = (
    batch[f"x_{name}"].to(DEVICE) for name in ("weather", "price", "spatial", "time")
)  # (B, 9, L), (B, 2, L), (B, 3|9, L), (B, 5, L)

# Validate input shapes against cfg
for modality, tensor in (
    ("weather", x_weather),
    ("price", x_price),
    ("spatial", x_spatial),
    ("time", x_time),
):
    expected_ch = cfg.embed[modality]["in_ch"]
    assert tensor.shape[1] == expected_ch, f"x_{modality} channels must be {expected_ch}"
assert x_weather.shape[2] == cfg.data_mode["L"], f"Input length must be {cfg.data_mode['L']}"

with torch.no_grad():
    h_weather, h_price, h_spatial, h_time = (
        embed_weather(x_weather),   # (B, D)
        embed_price(x_price),       # (B, D)
        embed_spatial(x_spatial),   # (B, D)
        embed_time(x_time),         # (B, D)
    )

# 4) Print shapes next to the expected embedding width
for label, h in (
    ("h_weather", h_weather),
    ("h_price", h_price),
    ("h_spatial", h_spatial),
    ("h_time", h_time),
):
    print(f"{label:<9}: {h.shape} (expected: (B, {cfg.dims['D']}))")

# 5) Save embeddings (optional, aligned with Cell 5)
# `loader` shuffles, so each pass over it yields a different row order and the
# saved rows could not be matched with embeddings exported by other cells.
# Iterate an unshuffled view of the same dataset and store the window index
# next to the zone index so every row is identifiable.
embed_loader = DataLoader(
    loader.dataset,
    batch_size=loader.batch_size,
    shuffle=False,
    collate_fn=loader.collate_fn,
)

all_embs = {
    "h_weather": [],
    "h_price": [],
    "h_spatial": [],
    "h_time": [],
    "zone_id": [],
    "window": []
}
for batch in embed_loader:
    x_weather = batch["x_weather"].to(DEVICE)
    x_price = batch["x_price"].to(DEVICE)
    x_spatial = batch["x_spatial"].to(DEVICE)
    x_time = batch["x_time"].to(DEVICE)

    with torch.no_grad():
        all_embs["h_weather"].append(embed_weather(x_weather).cpu())
        all_embs["h_price"].append(embed_price(x_price).cpu())
        all_embs["h_spatial"].append(embed_spatial(x_spatial).cpu())
        all_embs["h_time"].append(embed_time(x_time).cpu())
        all_embs["zone_id"].append(batch["zone_id"].cpu())
        all_embs["window"].append(batch["window"].cpu())

# Concatenate and save
for k in all_embs:
    all_embs[k] = torch.cat(all_embs[k], dim=0)
save_path = cfg.moirai["save_path"].replace(".pt", "_cnn.pt")
torch.save(all_embs, save_path)
print(f"Saved CNN embeddings to: {save_path}")

h_weather: torch.Size([50, 128]) (expected: (B, 128))
h_price  : torch.Size([50, 128]) (expected: (B, 128))
h_spatial: torch.Size([50, 128]) (expected: (B, 128))
h_time   : torch.Size([50, 128]) (expected: (B, 128))
Saved CNN embeddings to: moirai_occ_embeddings_cnn.pt


In [26]:
import torch
import torch.nn as nn
from types import SimpleNamespace

# ===== Cell 7: StaticMLP Embedder for Static Features =====
# Aligns with cfg.dims["D"] and previous cells

# 1) StaticMLP class
class StaticMLP(nn.Module):
    """
    Two-layer MLP (Linear -> GELU -> Dropout -> Linear) that embeds the
    per-zone static feature vector.

    Args:
        in_dim: int, input dimension (from cfg.embed["static"]["in_dim"])
        out_dim: int, output dimension (from cfg.dims["D"])
        hidden: int, hidden dimension (from cfg.mlp.hidden or cfg.cnn.hidden)
        dropout: float, dropout rate (from cfg.mlp.dropout or cfg.cnn.dropout)
    Input: (B, in_dim)
    Output: (B, out_dim)
    """
    def __init__(self, in_dim: int, out_dim: int, hidden: int, dropout: float):
        super().__init__()
        stages = [
            nn.Linear(in_features=in_dim, out_features=hidden),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=hidden, out_features=out_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """
        Args:
            x: torch.Tensor, (B, in_dim)
        Returns:
            torch.Tensor, (B, out_dim)
        Raises:
            ValueError: if x is not 2-dimensional.
        """
        if x.ndim == 2:
            return self.net(x)  # (B, out_dim)
        raise ValueError(f"Expected 2D input (B, in_dim), got shape {x.shape}")

# 2) Initialize embedder
# Rebinds DEVICE for this cell: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MLP configuration: when cfg has no `mlp` entry, alias the CNN settings
# (cfg.mlp then refers to the very same object as cfg.cnn) or fall back to
# built-in defaults.
if not hasattr(cfg, "mlp"):
    cfg.mlp = cfg.cnn if hasattr(cfg, "cnn") else SimpleNamespace(hidden=128, dropout=0.1)

# Validate configuration
if not ("static" in cfg.embed and "in_dim" in cfg.embed["static"]):
    raise ValueError("cfg.embed['static']['in_dim'] is missing")
if not ("D" in cfg.dims):
    raise ValueError("cfg.dims['D'] is missing")
if not ("save_path" in cfg.moirai):
    raise ValueError("cfg.moirai['save_path'] is missing")

# Get F_static from cfg
F_static = cfg.embed["static"]["in_dim"]

# Arguments are positional: (in_dim, out_dim, hidden, dropout)
embed_static = StaticMLP(F_static, cfg.dims["D"], cfg.mlp.hidden, cfg.mlp.dropout).to(DEVICE).eval()

# 3) Run one batch through the static embedder
batch = next(iter(loader))  # From Cell 4 DataLoader
x_static = batch["x_static"].to(DEVICE)  # (B, F_static)

# Validate input shape
assert x_static.shape[1] == cfg.embed["static"]["in_dim"], f"x_static dim must be {cfg.embed['static']['in_dim']}, got {x_static.shape[1]}"

with torch.no_grad():
    h_static = embed_static(x_static)  # (B, D)

# 4) Print shape next to the expected embedding width
print(f"h_static: {h_static.shape} (expected: (B, {cfg.dims['D']}))")

# 5) Save embeddings (aligned with Cell 5 and Cell 6)
# `loader` shuffles, so each pass over it yields a different row order and the
# saved rows could not be matched with embeddings exported by other cells.
# Iterate an unshuffled view of the same dataset and store the window index
# next to the zone index so every row is identifiable.
embed_loader = DataLoader(
    loader.dataset,
    batch_size=loader.batch_size,
    shuffle=False,
    collate_fn=loader.collate_fn,
)

all_embs = {
    "h_static": [],
    "zone_id": [],
    "window": []
}
for batch in embed_loader:
    x_static = batch["x_static"].to(DEVICE)

    with torch.no_grad():
        all_embs["h_static"].append(embed_static(x_static).cpu())
        all_embs["zone_id"].append(batch["zone_id"].cpu())
        all_embs["window"].append(batch["window"].cpu())

# Concatenate and save
for k in all_embs:
    all_embs[k] = torch.cat(all_embs[k], dim=0)
save_path = cfg.moirai["save_path"].replace(".pt", "_static.pt")
torch.save(all_embs, save_path)
print(f"Saved static embeddings to: {save_path}")

h_static: torch.Size([50, 128]) (expected: (B, 128))
Saved static embeddings to: moirai_occ_embeddings_static.pt


In [27]:
import torch
import torch.nn as nn
from types import SimpleNamespace
from tqdm import tqdm

# ===== Cell 8: EV Forecasting Model =====
# Integrates embeddings from Cells 5–7 and predicts multi-horizon output

# 1) Model components
class AlignProjNorm(nn.Module):
    """
    Per-modality alignment block: Linear -> LayerNorm -> Dropout, optionally
    followed by a learned scalar gate in (0, 1) that rescales the whole vector.

    Args:
        D: int, input/output dimension (from cfg.dims["D"])
        p_drop: float, dropout rate (from cfg.model.dropout)
        use_gate: bool, whether to use gating (from cfg.model.use_gate)
    Input: (B, D)
    Output: (B, D)
    """
    def __init__(self, D: int, p_drop: float, use_gate: bool):
        super().__init__()
        self.proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.drop = nn.Dropout(p_drop)
        self.use_gate = use_gate
        if not use_gate:
            return
        # Bottleneck MLP producing one sigmoid gate value per sample
        bottleneck = D // 4
        self.gate = nn.Sequential(
            nn.Linear(D, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Sigmoid(),
        )

    def forward(self, h):  # h: (B, D)
        aligned = self.drop(self.norm(self.proj(h)))
        if not self.use_gate:
            return aligned  # (B, D)
        gate_value = self.gate(aligned)  # (B, 1)
        return aligned * gate_value  # (B, D)

class CrossAttnFusion(nn.Module):
    """
    Cross-attention fusion: the local embedding is the single query and the
    stacked context embeddings are the keys/values. The attended vector goes
    through a linear layer, is added back to the local embedding (residual)
    and layer-normalized.

    Args:
        D: int, embedding dimension (from cfg.dims["D"])
        nhead: int, number of attention heads (from cfg.model.nhead)
        dropout: float, dropout rate (from cfg.model.dropout)
    Input:
        h_local: (B, D)
        context_list: list of (B, D) tensors
    Output:
        f: (B, D), fused embedding
        w: (B, M), attention weights as returned by nn.MultiheadAttention
           (M = len(context_list))
    """
    def __init__(self, D: int, nhead: int, dropout: float):
        super().__init__()
        self.q_proj = nn.Linear(D, D)
        self.k_proj = nn.Linear(D, D)
        self.v_proj = nn.Linear(D, D)
        self.attn = nn.MultiheadAttention(D, nhead, dropout=dropout, batch_first=True)
        self.out = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)

    def forward(self, h_local, context_list):
        # Stack the M context vectors once; keys and values share this tensor
        context = torch.stack(context_list, dim=1)      # (B, M, D)
        query = self.q_proj(h_local).unsqueeze(1)       # (B, 1, D)
        keys = self.k_proj(context)                     # (B, M, D)
        values = self.v_proj(context)                   # (B, M, D)

        attended, weights = self.attn(query, keys, values)  # (B, 1, D), (B, 1, M)

        fused = self.out(attended.squeeze(1))           # (B, D)
        fused = self.norm(fused + h_local)              # Residual connection
        return fused, weights.squeeze(1)                # (B, D), (B, M)

class MultiHorizonHead(nn.Module):
    """
    Prediction head: LayerNorm -> Linear -> GELU -> Dropout -> Linear, mapping
    a fused embedding to H values. With ``nonneg`` the outputs are passed
    through softplus.

    Args:
        D: int, input dimension (from cfg.dims["D"])
        H: int, output horizon (from cfg.model.H)
        hidden: int, hidden dimension (from cfg.model.head_hidden)
        p_drop: float, dropout rate (from cfg.model.dropout)
        nonneg: bool, enforce non-negative output (from cfg.model.nonneg)
    Input: (B, D)
    Output: (B, H)
    """
    def __init__(self, D: int, H: int, hidden: int, p_drop: float, nonneg: bool):
        super().__init__()
        stages = [
            nn.LayerNorm(D),
            nn.Linear(D, hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, H),
        ]
        self.net = nn.Sequential(*stages)
        self.nonneg = nonneg

    def forward(self, x):
        raw = self.net(x)  # (B, H)
        if not self.nonneg:
            return raw
        return nn.functional.softplus(raw)  # Enforce non-negative, (B, H)

class EVForecastModel(nn.Module):
    """
    End-to-end forecaster: aligns each modality embedding, fuses them with
    cross-attention (local embedding as query), and predicts every horizon step.

    Args:
        D: int, embedding dimension (from cfg.dims["D"])
        H: int, output horizon (from cfg.model.H)
        nhead: int, number of attention heads (from cfg.model.nhead)
        dropout: float, dropout rate (from cfg.model.dropout)
        use_gate: bool, whether to use gating (from cfg.model.use_gate)
        head_hidden: int, hidden dimension for head (from cfg.model.head_hidden)
        nonneg: bool, enforce non-negative output (from cfg.model.nonneg)
    Input: dict of embeddings (h_local, h_weather, h_price, h_spatial, h_time, h_static)
    Output:
        y_hat: (B, H), predictions
        attn_w: (B, M), attention weights (M=5: weather, price, spatial, time, static)
    """
    # Modalities in the order their aligners are created and applied.
    # The first entry is the attention query; the rest form the context, in this order.
    _MODALITY_ORDER = ("local", "weather", "price", "spatial", "time", "static")

    def __init__(self, D: int, H: int, nhead: int, dropout: float, use_gate: bool, head_hidden: int, nonneg: bool):
        super().__init__()
        # Registers align_local, align_weather, ... as submodules, one per modality.
        for modality in self._MODALITY_ORDER:
            setattr(self, f"align_{modality}", AlignProjNorm(D, dropout, use_gate))
        self.fusion = CrossAttnFusion(D, nhead, dropout)
        self.head = MultiHorizonHead(D, H, head_hidden, dropout, nonneg)

    def forward(self, feats):
        """Run alignment, fusion and the prediction head on a dict of ``h_<modality>`` tensors."""
        aligned = [
            getattr(self, f"align_{modality}")(feats[f"h_{modality}"])
            for modality in self._MODALITY_ORDER
        ]
        query, context_list = aligned[0], aligned[1:]
        fused, attn_w = self.fusion(query, context_list)
        y_hat = self.head(fused)  # (B, H)
        return y_hat, attn_w  # (B, H), (B, M)

# 2) Initialize device and configuration
# NOTE(review): this rebinds the global DEVICE as a torch.device chosen from CUDA/CPU only.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Add model configuration if missing
# NOTE(review): the fallback horizon is H=6; confirm it matches the horizon the
# dataset windows/targets were built with before training against these outputs.
if not hasattr(cfg, "model"):
    cfg.model = SimpleNamespace(
        H=6,
        nhead=8,
        dropout=0.1,
        use_gate=True,
        head_hidden=512,
        nonneg=True
    )

# Validate configuration
# Sequence modalities must declare "in_ch"; the static modality declares "in_dim".
required_embed_keys = ["weather", "price", "spatial", "time", "static"]
for key in required_embed_keys:
    required_field = "in_dim" if key == "static" else "in_ch"
    if key not in cfg.embed or required_field not in cfg.embed[key]:
        raise ValueError(f"cfg.embed['{key}'] missing required field")
if "D" not in cfg.dims:
    raise ValueError("cfg.dims['D'] is missing")
if "L" not in cfg.data_mode:
    raise ValueError("cfg.data_mode['L'] is missing")
if "save_path" not in cfg.moirai:
    raise ValueError("cfg.moirai['save_path'] is missing")
# The key can be present but still hold an empty placeholder (None / "");
# fail here with a clear message instead of inside torch.load.
if not cfg.moirai["save_path"]:
    raise ValueError("cfg.moirai['save_path'] is not set; point it at the saved Moirai embeddings file")

# 3) Load Moirai embeddings
# map_location remaps stored tensors onto DEVICE, so a checkpoint written from GPU
# tensors still loads on a machine without that device.
checkpoint = torch.load(cfg.moirai["save_path"], map_location=DEVICE, weights_only=True)
h_local = checkpoint["h_local"].to(DEVICE)  # (N_samples, cfg.dims["D"])
zone_ids = checkpoint["zone_id"].to(DEVICE)  # (N_samples,)
assert h_local.shape[1] == cfg.dims["D"], f"h_local dim must be {cfg.dims['D']}, got {h_local.shape[1]}"
# zone_ids is used to pick rows of h_local, so the two must describe the same samples.
assert zone_ids.shape[0] == h_local.shape[0], (
    f"zone_id has {zone_ids.shape[0]} entries but h_local has {h_local.shape[0]} rows"
)

# 4) Define embedders
# The four sequence modalities share the CNN hyper-parameters in cfg.cnn and differ
# only in their input channel count. They are built in the order weather, price,
# spatial, time (then static), so parameter initialisation consumes the RNG in a
# fixed sequence. Nothing in this section loads trained weights into the embedders.
embed_weather, embed_price, embed_spatial, embed_time = (
    CNN1dEmbedder(
        in_ch=cfg.embed[modality]["in_ch"],
        out_dim=cfg.dims["D"],
        hidden=cfg.cnn.hidden,
        kernel_size=cfg.cnn.kernel_size,
        dropout=cfg.cnn.dropout,
    ).to(DEVICE)
    for modality in ("weather", "price", "spatial", "time")
)

# Static (non-sequence) features go through an MLP configured by cfg.mlp.
embed_static = StaticMLP(
    in_dim=cfg.embed["static"]["in_dim"],
    out_dim=cfg.dims["D"],
    hidden=cfg.mlp.hidden,
    dropout=cfg.mlp.dropout,
).to(DEVICE)

# Set eval mode
for _embedder in (embed_weather, embed_price, embed_spatial, embed_time, embed_static):
    _embedder.eval()

# 5) Process all batches and save predictions
model = EVForecastModel(
    D=cfg.dims["D"],
    H=cfg.model.H,
    nhead=cfg.model.nhead,
    dropout=cfg.model.dropout,
    use_gate=cfg.model.use_gate,
    head_hidden=cfg.model.head_hidden,
    nonneg=cfg.model.nonneg
).to(DEVICE)
model.eval()

all_preds = {
    "y_pred": [],
    "attn_w": [],
    "zone_id": [],
    "window": []
}
modalities = ["weather", "price", "spatial", "time", "static"]

# Rows of h_local grouped by zone id, in checkpoint order, plus a per-zone cursor
# counting how many rows have been handed out. Both live OUTSIDE the batch loop:
# the k-th sample seen for a zone (across all batches) receives that zone's k-th
# stored row. Resetting this bookkeeping per batch would make every batch restart
# at the first row of each zone.
# NOTE(review): the pairing relies on the loader yielding each zone's samples in
# the same order the embeddings were saved; the window id is not used for matching.
rows_by_zone = {}
for row_idx, stored_zone in enumerate(zone_ids.flatten().tolist()):
    rows_by_zone.setdefault(stored_zone, []).append(row_idx)
rows_consumed = dict.fromkeys(rows_by_zone, 0)

for batch in tqdm(loader, desc="Forecasting"):
    # Extract batch data
    x_weather = batch["x_weather"].to(DEVICE)  # (B, 9, L)
    x_price = batch["x_price"].to(DEVICE)      # (B, 2, L)
    x_spatial = batch["x_spatial"].to(DEVICE)  # (B, 3|9, L)
    x_time = batch["x_time"].to(DEVICE)        # (B, 5, L)
    x_static = batch["x_static"].to(DEVICE)    # (B, F_static)
    batch_zone_ids = batch["zone_id"].to(DEVICE)  # (B,)
    batch_windows = batch["window"].to(DEVICE)    # (B,)

    # Validate input shapes
    assert x_weather.shape[1] == cfg.embed["weather"]["in_ch"], f"x_weather channels must be {cfg.embed['weather']['in_ch']}"
    assert x_price.shape[1] == cfg.embed["price"]["in_ch"], f"x_price channels must be {cfg.embed['price']['in_ch']}"
    assert x_spatial.shape[1] == cfg.embed["spatial"]["in_ch"], f"x_spatial channels must be {cfg.embed['spatial']['in_ch']}"
    assert x_time.shape[1] == cfg.embed["time"]["in_ch"], f"x_time channels must be {cfg.embed['time']['in_ch']}"
    assert x_static.shape[1] == cfg.embed["static"]["in_dim"], f"x_static dim must be {cfg.embed['static']['in_dim']}"
    assert x_weather.shape[2] == cfg.data_mode["L"], f"Input length must be {cfg.data_mode['L']}"

    # Generate embeddings
    with torch.no_grad():
        h_weather = embed_weather(x_weather)  # (B, D)
        h_price = embed_price(x_price)       # (B, D)
        h_spatial = embed_spatial(x_spatial) # (B, D)
        h_time = embed_time(x_time)          # (B, D)
        h_static = embed_static(x_static)    # (B, D)

        # Select h_local for the batch: next not-yet-used stored row of each sample's zone.
        # Ids are pulled to Python once per batch to avoid a device sync per element.
        picked_rows = []
        for sample_zone, sample_window in zip(batch_zone_ids.flatten().tolist(),
                                              batch_windows.flatten().tolist()):
            zone_rows = rows_by_zone.get(sample_zone)
            if zone_rows is None:
                raise RuntimeError(f"No matching index found for batch_zone_id {sample_zone}")
            cursor = rows_consumed[sample_zone]
            if cursor >= len(zone_rows):
                raise RuntimeError(f"No unused index found for batch_zone_id {sample_zone}, window {sample_window}")
            picked_rows.append(zone_rows[cursor])
            rows_consumed[sample_zone] = cursor + 1

        selected_indices = torch.tensor(picked_rows, dtype=torch.long, device=DEVICE)
        h_local_batch = h_local[selected_indices]  # (B, D)

        # Validate embedding shapes
        assert h_local_batch.shape[1] == cfg.dims["D"], f"h_local_batch dim must be {cfg.dims['D']}"
        for h, name in [(h_weather, "h_weather"), (h_price, "h_price"), (h_spatial, "h_spatial"), 
                        (h_time, "h_time"), (h_static, "h_static")]:
            assert h.shape[1] == cfg.dims["D"], f"{name} dim must be {cfg.dims['D']}"

        # Create feats dictionary
        feats = {
            "h_local": h_local_batch,
            "h_weather": h_weather,
            "h_price": h_price,
            "h_spatial": h_spatial,
            "h_time": h_time,
            "h_static": h_static
        }

        # Run model
        y_pred, attn_w = model(feats)  # y_pred: (B, H), attn_w: (B, M)
        all_preds["y_pred"].append(y_pred.cpu())
        all_preds["attn_w"].append(attn_w.cpu())
        all_preds["zone_id"].append(batch_zone_ids.cpu())
        all_preds["window"].append(batch_windows.cpu())

# 6) Concatenate and save predictions
# An empty loader leaves every list empty; fail with a clear message rather than
# the generic torch.cat error.
if not all_preds["y_pred"]:
    raise RuntimeError("No batches were processed; there are no predictions to concatenate or save")
for k in all_preds:
    all_preds[k] = torch.cat(all_preds[k], dim=0)

# Analyze attention weights
attn_w_mean = all_preds["attn_w"].mean(dim=0)  # (M,)
for modality_name, mean_weight in zip(modalities, attn_w_mean.tolist()):
    print(f"{modality_name} attention weight: {mean_weight:.4f}")

# Save predictions
# Derive "<stem>_predictions<suffix>" next to the embeddings file. Building the
# name from stem/suffix (instead of str.replace on ".pt") guarantees the output
# path differs from the embeddings checkpoint even when the path has no ".pt",
# leaves directory names untouched, and also accepts Path objects.
embeddings_path = Path(cfg.moirai["save_path"])
predictions_name = f"{embeddings_path.stem}_predictions{embeddings_path.suffix or '.pt'}"
save_path = str(embeddings_path.with_name(predictions_name))
torch.save(all_preds, save_path)
print(f"Saved predictions to: {save_path}")

# 7) Print shapes for verification
print(f"y_pred shape: {tuple(all_preds['y_pred'].shape)} (expected: (N_samples, {cfg.model.H}))")
print(f"attn_w shape: {tuple(all_preds['attn_w'].shape)} (expected: (N_samples, {len(modalities)}))")
print(f"zone_id shape: {tuple(all_preds['zone_id'].shape)}")
print(f"window shape: {tuple(all_preds['window'].shape)}")

Forecasting: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]

weather attention weight: 0.2026
price attention weight: 0.1994
spatial attention weight: 0.2017
time attention weight: 0.1954
static attention weight: 0.2010
Saved predictions to: moirai_occ_embeddings_predictions.pt
y_pred shape: (50, 6) (expected: (N_samples, 6))
attn_w shape: (50, 5) (expected: (N_samples, 5))
zone_id shape: (50,)
window shape: (50,)



