In [1]:
import torch
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
elif torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.ones(1, device=device)
    print (x)
else:
    device = torch.device("cpu")
    x = torch.ones(1, device=device)

tensor([1.], device='mps:0')


# Loading the Data

In [2]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler, MaxAbsScaler

# ─── configuration ─────────────────────────────────────────────────────────
df_path        = "df_all.csv"
forecast_dir   = "bmrs_csv_filled"
mask_dir       = "bmrs_csv_masks"
start_date     = "2021-07-01"
end_date       = "2025-06-30"
train_end_date = "2025-03-01"
val_end_date   = "2025-05-01"
horizon        = 48
use_time_feats = False   # False to skip trig-based time features

# ─── 1) load & crop historical data ────────────────────────────────────────
df = pd.read_csv(df_path, index_col="startTime", parse_dates=True).loc[start_date:end_date]
print(f"[DEBUG] Historical df loaded, length = {len(df)}")

# drop exactly the columns you specified
cols_to_drop = [
    "Forecast Wind",
    "Forecast Solar",
    "Actual Wind",
    "Actual Solar"
]
df = df.drop(columns=cols_to_drop, errors="ignore")

# ─── 2) embedded time features ─────────────────────────────────
# Month 0–11
df["month_idx"]   = df.index.month - 1
# Day-of-week 0–6
df["weekday_idx"] = df.index.dayofweek
# settlement period 0–47
if "Settlement Period" in df.columns:
    df["sp_idx"] = df["Settlement Period"].astype(int) - 1
    df = df.drop(columns=["Settlement Period"], errors="ignore")
else:
    print("[WARN] 'Settlement Period' not found; setting sp_idx to 0")
    df["sp_idx"] = 0
# Day-type: weekday=0, weekend=1
df["dtype_idx"]   = (df.index.dayofweek >= 5).astype(int)

print(df[["month_idx","weekday_idx","sp_idx","dtype_idx"]].head())

# ─── 3) load & stack future forecasts ──────────────────────────────────────
def load_future(name, prefix, horizon=horizon):
    fdf = (
        pd.read_csv(f"{forecast_dir}/{name}.csv",
                    index_col="startTime", parse_dates=True)
          .loc[start_date:end_date]
    )
    # reindex to match historical df exactly
    fdf = fdf.reindex(df.index)
    print(f"[DEBUG] {name}: loaded forecast df, length = {len(fdf)}")

    cols = [f"{prefix}_f{i}" for i in range(1, horizon+1)]
    existing = [c for c in cols if c in fdf.columns]
    print(f"[DEBUG] {name}: existing columns = {existing[:5]}... (total {len(existing)})")

    # fill any missing forecast steps with zero
    mat = fdf[existing].fillna(0).values
    if mat.shape[1] < horizon:
        pad = np.zeros((len(fdf), horizon - mat.shape[1]), dtype=mat.dtype)
        mat = np.hstack([mat, pad])
        print(f"[DEBUG] {name}: padded to horizon, new shape = {mat.shape}")
    else:
        print(f"[DEBUG] {name}: matrix shape = {mat.shape}")
    return mat

demand_mat = load_future("DEMAND_FORECASTS", "demand")
wind_mat   = load_future("WIND_FORECASTS",   "wind")
drm_mat    = load_future("DRM_FORECASTS",    "drm")

X_fut = np.stack([demand_mat, wind_mat, drm_mat], axis=2)
assert X_fut.shape[0] == len(df), "Future forecasts misaligned with historical df"
print(f"[DEBUG] X_fut stacked shape = {X_fut.shape} (should match len(df))")

# ─── 4) train/val/test split ───────────────────────────────────────────────
target_col = "Imbalance Price"
assert target_col in df.columns, f"Target column '{target_col}' not found in df"

# define column groups
cal_cols   = ["month_idx", "weekday_idx", "sp_idx", "dtype_idx"]
hist_cols  = [c for c in df.columns if c not in cal_cols + [target_col]]

# masks
train_mask = df.index < train_end_date
val_mask   = (df.index >= train_end_date) & (df.index < val_end_date)
test_mask  = df.index >= val_end_date

# slices
df_train, df_val, df_test = df[train_mask], df[val_mask], df[test_mask]
X_fut_train = X_fut[train_mask]
X_fut_val   = X_fut[val_mask]
X_fut_test  = X_fut[test_mask]

# targets
y_train = df_train[target_col].values
y_val   = df_val[target_col].values
y_test  = df_test[target_col].values

# historical features
X_train_hist = df_train[hist_cols].values
X_val_hist   = df_val[hist_cols].values
X_test_hist  = df_test[hist_cols].values

# calendar indices
month_train   = df_train["month_idx"].values
weekday_train = df_train["weekday_idx"].values
sp_train      = df_train["sp_idx"].values
dtype_train   = df_train["dtype_idx"].values

month_val   = df_val["month_idx"].values
weekday_val = df_val["weekday_idx"].values
sp_val      = df_val["sp_idx"].values
dtype_val   = df_val["dtype_idx"].values

month_test   = df_test["month_idx"].values
weekday_test = df_test["weekday_idx"].values
sp_test      = df_test["sp_idx"].values
dtype_test   = df_test["dtype_idx"].values

print(f"train → X_hist {X_train_hist.shape}, cal {(month_train.shape[0],4)}, y {y_train.shape}, X_fut {X_fut_train.shape}")
print(f"val   → X_hist {X_val_hist.shape}, cal {(month_val.shape[0],4)}, y {y_val.shape}, X_fut {X_fut_val.shape}")
print(f"test  → X_hist {X_test_hist.shape}, cal {(month_test.shape[0],4)}, y {y_test.shape}, X_fut {X_fut_test.shape}")



[DEBUG] Historical df loaded, length = 70126
                     month_idx  weekday_idx  sp_idx  dtype_idx
startTime                                                     
2021-07-01 00:00:00          6            3       2          0
2021-07-01 00:30:00          6            3       3          0
2021-07-01 01:00:00          6            3       4          0
2021-07-01 01:30:00          6            3       5          0
2021-07-01 02:00:00          6            3       6          0
[DEBUG] DEMAND_FORECASTS: loaded forecast df, length = 70126
[DEBUG] DEMAND_FORECASTS: existing columns = ['demand_f1', 'demand_f2', 'demand_f3', 'demand_f4', 'demand_f5']... (total 48)
[DEBUG] DEMAND_FORECASTS: matrix shape = (70126, 48)
[DEBUG] WIND_FORECASTS: loaded forecast df, length = 70126
[DEBUG] WIND_FORECASTS: existing columns = ['wind_f1', 'wind_f2', 'wind_f3', 'wind_f4', 'wind_f5']... (total 48)
[DEBUG] WIND_FORECASTS: matrix shape = (70126, 48)
[DEBUG] DRM_FORECASTS: loaded forecast df, length = 

# Hyperparameter Setup

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import TimeSeriesSplit
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler, MaxAbsScaler


# ─── 2) Class Definitions ───────────────────────────────────────────

# ──────────── a. Dataset Definition ─────────────────────────────────

import torch
from torch.utils.data import Dataset

import torch
from torch.utils.data import Dataset

class MultiFeedDataset(Dataset):
    """
    Dataset producing:
      - x_h: past history covariates (seq_len × n_hist_feats)
      - x_f: all known-future covariates (feed_len × n_fut_feats)
      - y_t: target values for the next fut_len steps (fut_len,)
      - mi, wi, si, di: calendar indices for seq_len+fut_len timestamps
    Decouples look‐ahead length (feed_len = l) from forecast horizon (fut_len = τ_max).
    """
    def __init__(self,
                 hist:        torch.Tensor,  # (N, n_hist_feats)
                 full_fut:    torch.Tensor,  # (N, feed_len, n_fut_feats)
                 y:           torch.Tensor,  # (N,)
                 month_idx:   torch.Tensor,  # (N,)
                 weekday_idx: torch.Tensor,  # (N,)
                 sp_idx:      torch.Tensor,  # (N,)
                 dtype_idx:   torch.Tensor,  # (N,)
                 seq_len:     int,
                 feed_len:    int,
                 fut_len:     int):
        assert fut_len <= feed_len, "forecast horizon (fut_len) must be ≤ look-ahead length (feed_len)"
        self.X_hist       = hist.float()           # (N, n_hist_feats)
        self.X_fut        = full_fut.float()       # (N, feed_len, n_fut_feats)
        self.y_full       = y.float()              # (N,)
        self.month_full   = month_idx.long()       # (N,)
        self.weekday_full = weekday_idx.long()     # (N,)
        self.sp_full      = sp_idx.long()          # (N,)
        self.dtype_full   = dtype_idx.long()       # (N,)
        self.seq_len      = seq_len
        self.feed_len     = feed_len
        self.fut_len      = fut_len

    def __len__(self):
        # number of valid windows = N – seq_len – fut_len + 1
        return self.X_hist.size(0) - self.seq_len - self.fut_len + 1

    def __getitem__(self, idx):
        # 1) history window [idx … idx+seq_len-1]
        x_h = self.X_hist[idx : idx + self.seq_len]                      # (seq_len, n_hist_feats)

        # 2) known-future covariates all the way out to feed_len
        anchor = idx + self.seq_len - 1
        x_f    = self.X_fut[anchor, : self.feed_len, :]                  # (feed_len, n_fut_feats)

        # 3) targets for the next fut_len steps [idx+seq_len … idx+seq_len+fut_len-1]
        start_y = idx + self.seq_len
        end_y   = start_y + self.fut_len
        y_t     = self.y_full[start_y : end_y]                          # (fut_len,)

        # 4) calendar indices for the full window [idx … idx+seq_len+fut_len-1]
        ci_start = idx
        ci_end   = idx + self.seq_len + self.fut_len
        mi = self.month_full[ci_start : ci_end]                         # (seq_len+fut_len,)
        wi = self.weekday_full[ci_start : ci_end]
        si = self.sp_full[ci_start : ci_end]
        di = self.dtype_full[ci_start : ci_end]

        return x_h, x_f, y_t, mi, wi, si, di


# ──────────── b. Layer Definitions ──────────────────────────────────
    
class TimeFeatureEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb_month = nn.Embedding(12, 4)  # 12 months, 4-dim embedding
        self.emb_day   = nn.Embedding(7, 3)   # 7 days, 3-dim embedding
        self.emb_sp    = nn.Embedding(48, 6)  # 48 settlement periods, 6-dim embedding
        self.emb_dtype = nn.Embedding(2, 2)  # 2 types (weekday, weekend), 2-dim embedding

    def forward(self, month_idx, day_idx, sp_idx, dtype_idx):
        # Inputs are LongTensors of shape (B, T)
        e1 = self.emb_month(month_idx)  
        e2 = self.emb_day(day_idx)
        e3 = self.emb_sp(sp_idx)
        e4 = self.emb_dtype(dtype_idx)
        emb = torch.cat((e1, e2, e3, e4), dim=-1)  # Concatenate along the last dimension
        return emb

class VariableSelection(nn.Module):
    """
    Weights each feature of the concatenated input vector to focus on the most relevant features
    """
    def __init__(self, input_dim):
        """
        Input dim = measured_features + calendar_features
                 or forecast_features + calendar_features
        """
        super().__init__()
        # Projects the input to the same dimension
        self.proj = nn.Linear(input_dim, input_dim)

        # nn.init.eye_(self.proj.weight)
        # nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        """
        x: Tensor of shape (B, T, input_dim)
        returns: Tensor of same shape, but with features reweighted
        """
        # Project to scores
        z = self.proj(x)
        # Normalise into weights
        w = torch.softmax(z, dim=-1)  # softmax along feature dim
        # Re-weight inputs
        x_weighted = w * x
        return x_weighted

class BiLSTMEncoder(nn.Module):
    """
    A bidirectional LSTM encoder that maps input sequences of shape
    (B, T, input_dim) → (B, T, 2*h).
    """
    def __init__(self, input_dim, hidden_size=64, num_layers=1, dropout=0.0):
        """
        input_dim   : dimension of each x_t (e.g. m_h + 15 or m_f + 15)
        hidden_size : h, the per-direction LSTM hidden size
        num_layers  : how many stacked LSTM layers
        dropout     : dropout probability between LSTM layers
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size = input_dim,
            hidden_size= hidden_size,
            num_layers = num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers>1 else 0.0,
        )

    def forward(self, x, hidden=None):
        """
        x: Tensor of shape (B, T, input_dim)
        hidden: optional tuple (h0, c0) each of shape
                (num_layers*2, B, hidden_size)
        returns:
          H : (B, T, 2*hidden_size)  # all time-step outputs
          (h_n, c_n) : final states if you ever need them
        """
        # If no initial state provided, PyTorch defaults to zeros
        H, (h_n, c_n) = self.lstm(x, hidden)
        # H is (B, T, 2*h)
        return H, (h_n, c_n)

class AdditiveAttention(nn.Module):
    def __init__(self, enc_dim, dec_dim, att_dim):
        """
        enc_dim : dimension of encoder states (2*h)
        dec_dim : dimension of decoder hidden state (d_dec)
        att_dim : size of the attention projection (d_att)
        """
        super().__init__()
        self.W = nn.Linear(enc_dim, att_dim, bias=False)
        self.U = nn.Linear(dec_dim, att_dim, bias=True)
        self.v = nn.Linear(att_dim, 1,       bias=False)

    def forward(self, H, s_prev, mask=None):
        """
        H       : (B, T, enc_dim)     encoder outputs
        s_prev  : (B, dec_dim)        decoder hidden state
        mask    : (B, T) optional (0 for pad, 1 for data)
        Returns:
          c     : (B, enc_dim)        context vector
          alpha : (B, T)              attention weights
        """
        # 1) Project encoder states
        #    → (B, T, att_dim)
        H_proj = self.W(H)

        # 2) Project decoder state, unsqueeze to match T
        #    → (B, 1, att_dim) → (B, T, att_dim)
        S_proj = self.U(s_prev).unsqueeze(1)
        S_proj = S_proj.expand_as(H_proj)

        # 3) Combine & nonlinearity
        E = torch.tanh(H_proj + S_proj)      # (B, T, att_dim)

        # 4) Score with v → (B, T, 1), then squeeze → (B, T)
        e = self.v(E).squeeze(-1)

        # 5) Masking (if provided)
        if mask is not None:
            e = e.masked_fill(mask == 0, float('-inf'))

        # 6) Normalize scores → weights
        alpha = F.softmax(e, dim=1)         # (B, T)

        # 7) Compute context as weighted sum
        #    Expand alpha → (B, T, 1) to match H
        alpha_exp = alpha.unsqueeze(-1)     # (B, T, 1)
        c = (alpha_exp * H).sum(dim=1)       # (B, enc_dim)

        return c, alpha


class DualLSTMDecoder(nn.Module):
    """
    Two parallel LSTMCells processing c_hist and c_fut,
    then combining their hidden states to predict one-step outputs.
    """
    def __init__(self,
                 enc_dim:   int,  # = 2 * lstm_hidden
                 dec_hidden:int   # your chosen decoder hidden size
                ):
        super().__init__()
        # LSTM for past‐attention context c_h
        self.lstm_h = nn.LSTMCell(input_size=enc_dim,
                                  hidden_size=dec_hidden)
        # LSTM for future‐attention context c_f
        self.lstm_f = nn.LSTMCell(input_size=enc_dim,
                                  hidden_size=dec_hidden)
        # Combine [s_h; s_f] → one scalar
        self.ffn = nn.Linear(2*dec_hidden, 1)

    def forward(self,
                hist_ctx: torch.Tensor,   # (B, L, enc_dim)
                fut_ctx:  torch.Tensor    # (B, L, enc_dim)
               ) -> torch.Tensor:          # returns (B, L)
        B, L, D = hist_ctx.shape
        # initialize both LSTMCell states from zeros (or from encoder as discussed)
        h_h = hist_ctx.new_zeros(B, self.lstm_h.hidden_size)
        c_h = h_h.clone()
        h_f = h_h.clone()
        c_f = h_h.clone()

        outputs = []
        for t in range(L):
            # step each decoder
            c_h, h_h = self.lstm_h(hist_ctx[:, t, :], (h_h, c_h))
            c_f, h_f = self.lstm_f(fut_ctx[:,  t, :], (h_f, c_f))

            # combine their hidden states
            comb = torch.cat([h_h, h_f], dim=-1)       # (B, 2*dec_hidden)
            y_t  = self.ffn(comb).squeeze(-1)          # (B,)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)            # (B, L)

# ──────────── c. Model Definitions ──────────────────────────────────

class BiAttnPointForecaster(nn.Module):
    """
    5-stage bi-attentional forecaster with:
      - hist look-back k = hist_len
      - fut look-ahead l = feed_len
      - forecast horizon τ_max = fut_len
    """
    def __init__(self,
                 num_hist_feats,
                 num_fut_feats,
                 num_time_feats,
                 lstm_hidden,
                 dec_hidden,
                 attn_dim,
                 hist_len,
                 feed_len,    # ← new
                 fut_len):    # ← renamed τ_max
        super().__init__()
        self.hist_len  = hist_len
        self.feed_len  = feed_len
        self.fut_len   = fut_len

        # 1) Time feature embedding
        self.time_embed = TimeFeatureEmbedding()

        # 2) Variable selection
        self.var_select_past   = VariableSelection(num_hist_feats  + num_time_feats)
        self.var_select_future = VariableSelection(num_fut_feats   + num_time_feats)

        # 3) Bidirectional LSTM encoders
        self.enc_hist = BiLSTMEncoder(input_dim=num_hist_feats  + num_time_feats, hidden_size=lstm_hidden)
        self.enc_fut  = BiLSTMEncoder(input_dim=num_fut_feats   + num_time_feats, hidden_size=lstm_hidden)

        # 4) Dual additive attention
        enc_dim = 2 * lstm_hidden
        self.attn_hist = AdditiveAttention(enc_dim=enc_dim, dec_dim=dec_hidden, att_dim=attn_dim)
        self.attn_fut  = AdditiveAttention(enc_dim=enc_dim, dec_dim=dec_hidden, att_dim=attn_dim)

        # 4a) Initial state projection
        self.init_h = nn.Linear(enc_dim, dec_hidden)
        self.init_c = nn.Linear(enc_dim, dec_hidden)

        # 5) Decoder
        self.decoder = DualLSTMDecoder(enc_dim=2*lstm_hidden, dec_hidden=dec_hidden)


    def forward(self,
                cont_hist,    # (B, hist_len, num_hist_feats)
                cont_fut,     # (B, feed_len, num_fut_feats)
                month_idx,    # (B, hist_len+feed_len)
                weekday_idx,  # (B, hist_len+feed_len)
                sp_idx,       # (B, hist_len+feed_len)
                dtype_idx,    # (B, hist_len+feed_len)
                mask_hist=None,
                mask_fut=None  # (B, feed_len)
               ) -> torch.Tensor:  # returns (B, fut_len)
        B = cont_hist.size(0)

        # 1) Embed all calendar features in one go (length hist_len+feed_len)
        emb = self.time_embed(month_idx, weekday_idx, sp_idx, dtype_idx)
        emb_hist = emb[:, :self.hist_len, :]                  # (B, hist_len,  time_dim)
        emb_fut  = emb[:, self.hist_len : self.hist_len+self.feed_len, :]  # (B, feed_len, time_dim)

        # 2) Variable‐selection
        x_hist = torch.cat([cont_hist, emb_hist], dim=-1)     # (B, hist_len,  m_h+time_dim)
        x_hist = self.var_select_past(x_hist)                 # same shape

        x_fut  = torch.cat([cont_fut, emb_fut], dim=-1)       # (B, feed_len, m_f+time_dim)
        x_fut  = self.var_select_future(x_fut)

        # 3) Bidir LSTM encode
        H_hist, (h_n, c_n) = self.enc_hist(x_hist)            # (B, hist_len, 2*h)
        H_fut,  _         = self.enc_fut(x_fut)               # (B, feed_len, 2*h)

        # 3a) Seed decoder from final hist-encoder state
        h_fwd, h_bwd = h_n[-2], h_n[-1]   # each (B, h)
        c_fwd, c_bwd = c_n[-2], c_n[-1]
        h_cat = torch.cat([h_fwd, h_bwd], dim=-1)  # (B, 2*h)
        c_cat = torch.cat([c_fwd, c_bwd], dim=-1)
        h_h = self.init_h(h_cat);  c_h = self.init_c(c_cat)
        h_f = h_h.clone();        c_f = c_h.clone()

        # 4&5) Decode τ_max steps, attending over full look-ahead l each time
        outputs = []
        for t in range(self.fut_len):
            ctx_h, _ = self.attn_hist(H_hist, h_h, mask=mask_hist)
            ctx_f, _ = self.attn_fut( H_fut,  h_f, mask=mask_fut)  # attends over all l

            h_h, c_h = self.decoder.lstm_h(ctx_h, (h_h, c_h))
            h_f, c_f = self.decoder.lstm_f(ctx_f, (h_f, c_f))

            comb = torch.cat([h_h, h_f], dim=-1)
            y_t  = self.decoder.ffn(comb).squeeze(-1)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)  # (B, fut_len)


# ─── 3) Device ───────────────────────────────────────────────────────

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")


# ─── 4) Time‐feature lists and scaler ────────────────────────────────


# ─── 5) CV Splitter & Hyperparameters ───────────────────────────────

TRANSFORMER_FACTORY = {
    "MinMax": MinMaxScaler,
    "Robust": RobustScaler,
    "Standard": StandardScaler,
    "MaxAbs": MaxAbsScaler
}
MODEL_FACTORY = {
    "BiAttnPointForecaster": BiAttnPointForecaster
}
LOSS_FACTORY = {
    "MAE": nn.L1Loss,
    "MSE": nn.MSELoss,
    "Huber": nn.SmoothL1Loss
}

# look-back, look-ahead, forecast horizon
seq_len   = 24
feed_len  = 24
fut_len   = 24 

# model widths
lstm_hidden = 24  
dec_hidden  = 12  
attn_dim    = 24
num_layers  = 1   
dropout     = 0.0

# training
batch_size = 48       
lr         = 1e-4     
patience   = 20       
max_epochs = 200      
scaler_used = "MaxAbs"
model_used  = "BiAttnPointForecaster"
loss_used   = "Huber"
beta        = 0.01 # For Huber loss only
notes       = None

# ─── 6) Model Tag & Directory ────────────────────────────────────────         

md = {
    "model":        model_used,
    "seq_len":      seq_len,
    "feed_len":     feed_len,
    "horizon":      fut_len,
    "lstm_hidden":  lstm_hidden,
    "dec_hidden":   dec_hidden,
    "num_layers":   num_layers,
    "batch_size":   batch_size,
    "learning_rate":lr,
    "max_epochs":   max_epochs,
    "patience":     patience,
    "scaler":       scaler_used,
    "loss":         loss_used,
    **({"notes": notes} if notes is not None else {}),
}

initials = lambda s: "".join(w[0] for w in s.split("_"))

parts = []
for k, v in md.items():
    if k in ("model", "scaler", "loss", "notes"):
        s = str(v)
    else:
        s = str(v)
        if isinstance(v, float) and s.startswith("0."):
            s = s.replace("0.", ".")
        s = f"{initials(k)}{s}"
    parts.append(s)

tag = "_".join(parts)

models_root = "models"
candidate   = os.path.join(models_root, tag)
version     = 0
while os.path.exists(candidate):
    version   += 1
    candidate = f"{os.path.join(models_root, tag)}_v{version}"


Using device: mps


# Training

In [4]:
# ─── TEST-EVALUATION + EARLY STOPPING ────────────────────────────────
import os, json, copy, joblib
import numpy as np
import torch, random, sys
from datetime import datetime, timezone
from sklearn.metrics import mean_absolute_error, mean_squared_error

# ─── Create directory ──────────────────────────────────────────────
os.makedirs(candidate, exist_ok=True)
base_dir = candidate
print(f"→ Saving run in: {base_dir}")

# ─── 1) SCALE TRAIN/VAL/TEST ──────────────────────────────────────────────
# scale historical
scaler_X = TRANSFORMER_FACTORY[scaler_used]()
X_train_hist_scaled = scaler_X.fit_transform(X_train_hist)
X_val_hist_scaled   = scaler_X.transform(X_val_hist)
X_test_hist_scaled  = scaler_X.transform(X_test_hist)

# scale future
n_fut_feats = X_fut_train.shape[2]
scaler_F    = TRANSFORMER_FACTORY[scaler_used]()
flat_F_train = X_fut_train.reshape(-1, n_fut_feats)
flat_F_train = scaler_F.fit_transform(flat_F_train)
X_fut_train_scaled = flat_F_train.reshape(X_fut_train.shape)

flat_F_val   = X_fut_val.reshape(-1, n_fut_feats)
flat_F_val   = scaler_F.transform(flat_F_val)
X_fut_val_scaled = flat_F_val.reshape(X_fut_val.shape)

flat_F_test  = X_fut_test.reshape(-1, n_fut_feats)
flat_F_test  = scaler_F.transform(flat_F_test)
X_fut_test_scaled = flat_F_test.reshape(X_fut_test.shape)

# scale targets
scaler_y     = TRANSFORMER_FACTORY[scaler_used]()
y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1,1)).flatten()
y_val_scaled   = scaler_y.transform(y_val.reshape(-1,1)).flatten()
y_test_scaled  = scaler_y.transform(y_test.reshape(-1,1)).flatten()

# ─── 2) BUILD DATALOADERS ──────────────────────────────────────────────
def to_tensor(x, dtype):
    return torch.tensor(x, dtype=dtype)

train_ds = MultiFeedDataset(
    hist       = to_tensor(X_train_hist_scaled, torch.float32),
    full_fut   = to_tensor(X_fut_train_scaled,  torch.float32),
    y          = to_tensor(y_train_scaled,      torch.float32),
    month_idx  = to_tensor(month_train,         torch.long),
    weekday_idx= to_tensor(weekday_train,       torch.long),
    sp_idx     = to_tensor(sp_train,            torch.long),
    dtype_idx  = to_tensor(dtype_train,         torch.long),
    seq_len    = seq_len,
    feed_len   = feed_len,
    fut_len    = fut_len
)
val_ds = MultiFeedDataset(
    hist       = to_tensor(X_val_hist_scaled,   torch.float32),
    full_fut   = to_tensor(X_fut_val_scaled,    torch.float32),
    y          = to_tensor(y_val_scaled,        torch.float32),
    month_idx  = to_tensor(month_val,           torch.long),
    weekday_idx= to_tensor(weekday_val,         torch.long),
    sp_idx     = to_tensor(sp_val,              torch.long),
    dtype_idx  = to_tensor(dtype_val,           torch.long),
    seq_len    = seq_len,
    feed_len   = feed_len,
    fut_len    = fut_len
)
test_ds = MultiFeedDataset(
    hist       = to_tensor(X_test_hist_scaled,  torch.float32),
    full_fut   = to_tensor(X_fut_test_scaled,   torch.float32),
    y          = to_tensor(y_test_scaled,       torch.float32),
    month_idx  = to_tensor(month_test,          torch.long),
    weekday_idx= to_tensor(weekday_test,        torch.long),
    sp_idx     = to_tensor(sp_test,             torch.long),
    dtype_idx  = to_tensor(dtype_test,          torch.long),
    seq_len    = seq_len,
    feed_len   = feed_len,
    fut_len    = fut_len
)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, pin_memory=True)

# ─── 3) INSTANTIATE MODEL, OPTIMIZER, CRITERION ──────────────────────
model     = MODEL_FACTORY[model_used](
                num_hist_feats = X_train_hist_scaled.shape[1],
                num_fut_feats  = n_fut_feats,
                num_time_feats = sum([4,3,6,2]),
                lstm_hidden    = lstm_hidden,
                dec_hidden     = dec_hidden,
                attn_dim       = attn_dim,
                hist_len       = seq_len,
                feed_len       = feed_len,
                fut_len        = fut_len
            ).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = (LOSS_FACTORY[loss_used](beta=beta)
             if loss_used.startswith("Huber")
             else LOSS_FACTORY[loss_used]())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-7
            )
last_lrs = scheduler.get_last_lr()
# ─── 4) TRAIN w/ EARLY STOPPING ON VAL ───────────────────────────────
best_val, epochs_no_improve, best_ckpt = float('inf'), 0, None

for epoch in range(1, max_epochs+1):
    model.train()
    train_loss = 0.0
    for x_h, x_f, y_t, mi, wi, si, di in train_loader:
        x_h, x_f, y_t = x_h.to(device), x_f.to(device), y_t.to(device)
        mi, wi, si, di = mi.to(device), wi.to(device), si.to(device), di.to(device)

        optimizer.zero_grad()
        out  = model(x_h, x_f, mi, wi, si, di)
        loss = criterion(out, y_t)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x_h, x_f, y_t, mi, wi, si, di in val_loader:
            x_h, x_f, y_t = x_h.to(device), x_f.to(device), y_t.to(device)
            mi, wi, si, di = mi.to(device), wi.to(device), si.to(device), di.to(device)
            out = model(x_h, x_f, mi, wi, si, di)
            val_loss += criterion(out, y_t).item()
    val_loss /= len(val_loader)

    scheduler.step(val_loss)
    if scheduler.get_last_lr()[0] != last_lrs[0]:
        print(f"→ LR reduced from {last_lrs[0]:.2e} to {scheduler.get_last_lr()[0]:.2e}")
    last_lrs = scheduler.get_last_lr()

    print(f"[Epoch {epoch:03d}] train={train_loss:.5f}  val={val_loss:.5f}")
    if val_loss < best_val:
        best_val, epochs_no_improve = val_loss, 0
        best_ckpt = {
            'model':     copy.deepcopy(model.state_dict()),
            'optimizer': copy.deepcopy(optimizer.state_dict())
        }
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"→ early stopping after {epoch} epochs")
            break

# load best checkpoint
model.load_state_dict(best_ckpt['model'])
optimizer.load_state_dict(best_ckpt['optimizer'])

# ─── 5) INFERENCE ON TEST ──────────────────────────────────────────
model.eval()
preds_all, trues_all = [], []
with torch.no_grad():
    for x_h, x_f, y_t, mi, wi, si, di in test_loader:
        x_h, x_f = x_h.to(device), x_f.to(device)
        mi, wi, si, di = mi.to(device), wi.to(device), si.to(device), di.to(device)
        out = model(x_h, x_f, mi, wi, si, di)
        preds_all.append(out.cpu().numpy())
        trues_all.append(y_t.numpy())

preds_all = np.concatenate(preds_all, axis=0)
trues_all = np.concatenate(trues_all, axis=0)

preds = scaler_y.inverse_transform(preds_all.reshape(-1,1)).flatten()
trues = scaler_y.inverse_transform(trues_all.reshape(-1,1)).flatten()
err   = trues - preds

mae   = mean_absolute_error(trues, preds)
rmse  = np.sqrt(mean_squared_error(trues, preds))
smape = np.mean(2.0 * np.abs(err) / (np.abs(trues) + np.abs(preds) + 1e-8)) * 100
huber_vals = np.where(np.abs(err) <= beta,
                      0.5 * err**2 / beta,
                      np.abs(err) - 0.5 * beta)
huber = huber_vals.mean()

print(f"\nTEST → MAE={mae:.4f}, RMSE={rmse:.4f}, SMAPE={smape:.4f}%, Huber={huber:.4f}")
# ─── 6) BUILD METADATA & SAVE ────────────────────────────────────────
class NpTorchJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, np.generic):   return o.item()
        if isinstance(o, np.ndarray):   return o.tolist()
        if isinstance(o, torch.Tensor): return o.detach().cpu().tolist()
        if isinstance(o, torch.device): return str(o)
        if isinstance(o, datetime):     return o.isoformat()
        return super().default(o)

env_meta = {
    "seed_torch":  torch.initial_seed(),
    "seed_numpy":  np.random.get_state()[1][0],
    "seed_python": random.getstate()[1][0],
    "run_time":    datetime.now(timezone.utc).isoformat()
}

data_meta = {
    "start":     df_train.index.min().strftime('%Y-%m-%d'),
    "train_end": train_end_date,
    "val_end":   val_end_date,
    "end":       df_test.index.max().strftime('%Y-%m-%d'),
    "seq_len":   seq_len,
    "feed_len":  feed_len,
    "horizon":   fut_len,
    "n_train":   len(train_loader.dataset),
    "n_val":     len(val_loader.dataset),
    "n_test":    len(test_loader.dataset)
}

feat_meta = {
    "hist_feats": hist_cols,
    "time_feats": cal_cols,
    "n_fut_feats": n_fut_feats
}

loader_meta = {
    "batch_size":  batch_size,
    "pin_memory":  True,
    "shuffle":     {"train": True, "val": False, "test": False}
}

hyperparams = {
    "model":        model_used,
    "seq_len":      seq_len,
    "feed_len":     feed_len,
    "horizon":      fut_len,
    "lstm_hidden":  lstm_hidden,
    "dec_hidden":   dec_hidden,
    "attn_dim":     attn_dim,
    "num_layers":   1,
    "batch_size":   batch_size,
    "learning_rate":lr,
    "max_epochs":   max_epochs,
    "patience":     patience,
    "scaler":       scaler_used,
    "loss":         loss_used,
    **({"beta": beta} if loss_used.startswith("Huber") else {})
}

optim_meta = {
    "type": optimizer.__class__.__name__,
    "lr":   optimizer.defaults.get("lr")
}

sched_meta = {
    "type":     scheduler.__class__.__name__,
    "factor":   getattr(scheduler, "factor", None),
    "patience": getattr(scheduler, "patience", None)
}

earlystop_meta = {
    "max_epochs": max_epochs,
    "patience":   patience
}

metrics_meta = {
    "mae":   mae,
    "rmse":  rmse,
    "smape": smape,
    "huber": huber
}

# Save model & scalers
torch.save({
    "model_state_dict":     model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict()
}, os.path.join(base_dir, "torch_model.pt"))

joblib.dump({
    "scaler_X": scaler_X,
    "scaler_F": scaler_F,
    "scaler_y": scaler_y
}, os.path.join(base_dir, "scalers.joblib"))

# Save summary
with open(os.path.join(base_dir, "test_summary.json"), "w") as f:
    json.dump({
        "environment": env_meta,
        "data":        data_meta,
        "features":    feat_meta,
        "dataloader":  loader_meta,
        "hyperparams": hyperparams,
        "optimizer":   optim_meta,
        "scheduler":   sched_meta,
        "early_stop":  earlystop_meta,
        "metrics":     metrics_meta
    }, f, indent=2, cls=NpTorchJSONEncoder)

print(f"✅ Saved all outputs to {base_dir}")


→ Saving run in: models/BiAttnPointForecaster_sl24_fl24_h24_lh24_dh12_nl1_bs48_lr.0001_me200_p20_MaxAbs_Huber_v1
[Epoch 001] train=0.01302  val=0.00389
[Epoch 002] train=0.01035  val=0.00423
[Epoch 003] train=0.00905  val=0.00420
[Epoch 004] train=0.00881  val=0.00400
[Epoch 005] train=0.00868  val=0.00395
[Epoch 006] train=0.00858  val=0.00375
[Epoch 007] train=0.00848  val=0.00373
[Epoch 008] train=0.00841  val=0.00387
[Epoch 009] train=0.00836  val=0.00383
[Epoch 010] train=0.00831  val=0.00401
[Epoch 011] train=0.00827  val=0.00354
[Epoch 012] train=0.00824  val=0.00356
[Epoch 013] train=0.00821  val=0.00391
[Epoch 014] train=0.00819  val=0.00423
[Epoch 015] train=0.00815  val=0.00372
[Epoch 016] train=0.00812  val=0.00353
[Epoch 017] train=0.00810  val=0.00355
[Epoch 018] train=0.00808  val=0.00351
[Epoch 019] train=0.00805  val=0.00367
[Epoch 020] train=0.00803  val=0.00354
[Epoch 021] train=0.00802  val=0.00356
[Epoch 022] train=0.00823  val=0.00349
[Epoch 023] train=0.00798  va