In [1]:
import math, os, sys, time, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from typing import Tuple

#___________________________________________________________________________________________________
##
## Reproducibility & Global Setup
##___________________________________________________________________________________________________

def set_seed(s: int = 42):
    """Seed Python's, NumPy's and PyTorch's RNGs, plus every CUDA device when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

#___________________________________________________________________________________________________
##
## Data Utilities
##___________________________________________________________________________________________________

def _squeeze_last_if_singleton(a: np.ndarray) -> np.ndarray:
    """If target is (N, H, 1) or (N, 1), drop the last dim."""
    return a[..., 0] if (a.ndim >= 2 and a.shape[-1] == 1) else a

def load_npz(path: str):
    """
    Load train/val/test splits from an ``.npz`` archive.

    Expected keys: X_train, y_train, X_val, y_val, X_test, y_test. Targets are
    passed through ``_squeeze_last_if_singleton``; inputs are returned as stored.

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test) as NumPy arrays.

    Raises:
        KeyError: if one of the expected keys is missing from the archive.
    """
    # For .npz files np.load returns a lazily-reading NpzFile that holds the file
    # open; indexing reads each array into memory, and the context manager closes
    # the handle once all six arrays have been read.
    with np.load(path) as npz:
        X_train = npz["X_train"]
        y_train = _squeeze_last_if_singleton(npz["y_train"])
        X_val   = npz["X_val"]
        y_val   = _squeeze_last_if_singleton(npz["y_val"])
        X_test  = npz["X_test"]
        y_test  = _squeeze_last_if_singleton(npz["y_test"])
    return X_train, y_train, X_val, y_val, X_test, y_test

#___________________________________________________________________________________________________
##
## Normalization & Dataset (MASKING)
##___________________________________________________________________________________________________

def zscore_fit(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute per-feature mean/std from TRAIN windows only, ignoring NaNs.

    Args:
        X: array whose last axis is the feature axis (e.g. [N, L, F]).

    Returns:
        (mu, sd) as float32 arrays of shape [F]. Features with (near-)zero or
        undefined spread get sd = 1; features with no finite values at all get
        mu = 0, sd = 1 so they do not turn every normalized value into NaN.
    """
    import warnings

    flat = X.reshape(-1, X.shape[-1])
    with warnings.catch_warnings():
        # All-NaN columns make nanmean/nanstd warn and return NaN; handled below.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mu = np.nanmean(flat, axis=0)
        sd = np.nanstd(flat, axis=0)
    mu = np.where(np.isfinite(mu), mu, 0.0)
    # NaN fails every comparison, so it must be tested explicitly alongside the tiny-sd guard.
    sd = np.where(np.isfinite(sd) & (sd >= 1e-8), sd, 1.0)
    return mu.astype(np.float32), sd.astype(np.float32)

def zscore_apply(X: np.ndarray, mu: np.ndarray, sd: np.ndarray) -> np.ndarray:
    """Standardize ``X`` with precomputed ``mu``/``sd``, broadcast against X's trailing axis; NaNs propagate."""
    centered = np.subtract(X, mu)
    return np.divide(centered, sd)

class TSWindowDataset(Dataset):
    """
    Windowed time-series dataset that records NaN masks, then zero-imputes the NaNs.

    Item i is ``(X[i], y[i], x_mask[i], y_mask[i])``, all float32 tensors:
      x_mask: [N, L] with 1.0 where every feature at that timestep is present, else 0.0
      y_mask: same shape as y, 1.0 where the target is present, else 0.0
    X is expected to be 3-D (the any-over-axis-2 reduction requires it). The
    caller's arrays are not modified.

    NOTE(review): with DataLoader(num_workers > 0) under the "spawn" start method,
    workers unpickle this class by its import path, so it must be defined in an
    importable module rather than a notebook/REPL ``__main__``.
    """
    def __init__(self, X, y):
        x_missing = np.isnan(X)
        y_missing = np.isnan(y)

        # Masks are derived before imputation, which would otherwise erase where the NaNs were.
        self.x_mask = torch.as_tensor(~x_missing.any(axis=2), dtype=torch.float32)
        self.y_mask = torch.as_tensor(~y_missing, dtype=torch.float32)

        # np.where builds fresh arrays, so the inputs stay untouched.
        self.X = torch.as_tensor(np.where(x_missing, 0.0, X), dtype=torch.float32)
        self.y = torch.as_tensor(np.where(y_missing, 0.0, y), dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        return self.X[i], self.y[i], self.x_mask[i], self.y_mask[i]

#___________________________________________________________________________________________________
##
## Device Selection & Autotuning Heuristics
##___________________________________________________________________________________________________

def pick_device():
    """Return the Apple-Silicon MPS device when this PyTorch build can use it, otherwise the CPU."""
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = mps_backend is not None and mps_backend.is_available()
    return torch.device("mps" if use_mps else "cpu")

def autotune_hparams(CFG, device):
    tuned = dict(CFG)
    if device.type == "mps":
        tuned["num_workers"] = max(tuned.get("num_workers", 0), 4)
        return tuned
    tuned["batch_size"]  = min(CFG.get("batch_size", 128), 64)
    tuned["channels"]    = min(CFG.get("channels", 64), 48)
    tuned["num_workers"] = max(CFG.get("num_workers", 0), 2)
    tuned["lr"]          = min(CFG.get("lr", 3e-3), 2e-3)
    return tuned

#___________________________________________________________________________________________________
##
## Model: Dilated Causal CNN
##___________________________________________________________________________________________________

class CausalConv1d(nn.Conv1d):
    """
    1-D convolution that only looks backward along the last (time) axis.

    The input is left-padded with ``(k - 1) * dilation`` zeros and the conv itself
    uses no padding. The output therefore keeps the input length, and output step t
    depends only on input steps <= t.
    """
    def __init__(self, in_ch, out_ch, k, dilation=1):
        super().__init__(
            in_ch, out_ch,
            kernel_size=k, dilation=dilation, padding=0, bias=True,
        )
        self.left_pad = dilation * (k - 1)

    def forward(self, x):
        padded = F.pad(x, (self.left_pad, 0)) if self.left_pad > 0 else x
        return super().forward(padded)

class LayerNormChannel(nn.Module):
    """LayerNorm over the channel axis (dim 1) of a channels-first tensor such as [B, C, L]."""
    def __init__(self, C):
        super().__init__()
        self.ln = nn.LayerNorm(C)

    def forward(self, x):
        # nn.LayerNorm normalizes the trailing dim, so swap channels to the end and back.
        channels_last = x.transpose(1, 2).contiguous()
        return self.ln(channels_last).transpose(1, 2).contiguous()

class ResBlock(nn.Module):
    """
    Residual block of two causal dilated convolutions with C channels in and out.

    Each branch stage applies conv -> GELU -> channel LayerNorm -> dropout; the
    block returns the branch output added to the unchanged input.
    """
    def __init__(self, C, k, dilation, dropout=0.1):
        super().__init__()
        # Submodules are created in this fixed order so seeded initialization and
        # state_dict key order stay stable.
        self.conv1 = CausalConv1d(C, C, k, dilation)
        self.act1  = nn.GELU()
        self.norm1 = LayerNormChannel(C)
        self.drop1 = nn.Dropout(dropout)
        self.conv2 = CausalConv1d(C, C, k, dilation)
        self.act2  = nn.GELU()
        self.norm2 = LayerNormChannel(C)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        stages = (
            (self.conv1, self.act1, self.norm1, self.drop1),
            (self.conv2, self.act2, self.norm2, self.drop2),
        )
        h = x
        for conv, act, norm, drop in stages:
            h = drop(norm(act(conv(h))))
        return h + x

class DilatedCausalCNN(nn.Module):
    """
    Stack of dilated causal residual blocks with a direct multi-horizon head.

    The input [B, L, F] is projected to C channels by a 1x1 conv, passed through
    one ResBlock per dilation, channel-normalized, and the final time step's
    C-vector is mapped by a linear layer to ``horizon`` outputs: [B, horizon].
    """
    def __init__(self, in_feats, C=64, k=5, dilations=(1,2,4,8,16,32), horizon=24, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Conv1d(in_feats, C, kernel_size=1)
        self.blocks = nn.ModuleList(
            ResBlock(C, k=k, dilation=d, dropout=dropout) for d in dilations
        )
        self.head_norm = LayerNormChannel(C)
        self.head = nn.Linear(C, horizon)

    def forward(self, x):
        # Conv1d wants channels first: [B, L, F] -> [B, F, L] -> [B, C, L].
        h = self.in_proj(x.transpose(1, 2))
        for block in self.blocks:
            h = block(h)
        final_step = self.head_norm(h)[:, :, -1]  # [B, C]
        return self.head(final_step)              # [B, H]

#___________________________________________________________________________________________________
##
## Metrics, Baselines & MASKED LOSS / METRICS
##___________________________________________________________________________________________________

def masked_huber_loss(yhat, y, mask, delta=1.0):
    """
    Mean Huber loss taken only over entries where ``mask`` is 1.

    yhat, y and mask share one shape (e.g. [B, H]); mask is 1 for valid entries and
    0 for ignored ones. The 1e-9 added to the valid count keeps an all-zero mask
    from dividing by zero (the loss is then 0).
    """
    per_elem = F.huber_loss(yhat, y, reduction='none', delta=delta)
    n_valid = mask.sum() + 1e-9
    return (per_elem * mask).sum() / n_valid

def masked_rmse(yhat, y, mask, eps=1e-9):
    """Root-mean-squared error over entries where ``mask`` is 1; ``eps`` guards an all-zero mask."""
    masked_sq_err = (yhat - y) ** 2 * mask
    return masked_sq_err.sum().div(mask.sum() + eps).sqrt()

def per_horizon_masked_rmse(yhat, y, mask, eps=1e-9):
    """
    Masked RMSE for each horizon column and pooled over all valid entries.

    Args:
        yhat, y, mask: tensors of shape [N, H]; mask is 1 for valid entries.
        eps: added to every valid-count denominator so empty columns yield 0.

    Returns:
        (per_h, overall): per_h has shape [H]; overall is a 0-d tensor.
    """
    sq_err = (yhat - y) ** 2 * mask
    per_h = (sq_err.sum(dim=0) / (mask.sum(dim=0) + eps)).sqrt()
    overall = (sq_err.sum() / (mask.sum() + eps)).sqrt()
    return per_h, overall

def receptive_field(k: int, dilations) -> int:
    """
    Receptive field, in timesteps, of a stack of ResBlocks.

    Each ResBlock holds two causal convs and each conv widens the field by
    (k - 1) * d, so RF = 1 + 2 * (k - 1) * sum(dilations).
    """
    growth_per_block = (2 * (k - 1) * d for d in dilations)
    return 1 + sum(growth_per_block)

# Persistence baseline helpers
def last_valid_target(xb, x_mask_b, target_idx):
    """
    Pick, per sample, the target feature's value at the latest fully valid timestep.

    Args:
        xb: [B, L, F] zero-imputed inputs.
        x_mask_b: [B, L], 1 where the timestep is valid (0 if ANY feature was NaN).
        target_idx: index of the target feature inside X (negative indices allowed).

    Returns:
        [B] tensor. Rows whose mask is all zeros fall back to the final step (L - 1).
    """
    n_batch, seq_len, _ = xb.shape
    # argmax returns the FIRST maximum, so on the time-reversed mask it locates the
    # latest valid step; an all-zero row yields 0, which maps to the final step.
    steps_from_end = torch.argmax(x_mask_b.flip(1), dim=1)
    last_idx = (seq_len - 1) - steps_from_end
    batch_idx = torch.arange(n_batch, device=xb.device)
    return xb[batch_idx, last_idx, target_idx]

#___________________________________________________________________________________________________
##
## Training Loop (MASKED VAL METRIC)
##___________________________________________________________________________________________________

def make_optimizer(model, lr=3e-4, weight_decay=1e-2):
    """
    Build AdamW with two parameter groups.

    Biases and any parameter whose name contains "ln" or "norm" go into a group
    with weight_decay 0. Every other trainable parameter uses ``weight_decay``.
    Parameters with requires_grad=False are left out entirely.
    """
    def _skip_decay(name):
        return name.endswith("bias") or any(tag in name for tag in ("ln", "norm"))

    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    decay = [p for name, p in trainable if not _skip_decay(name)]
    no_decay = [p for name, p in trainable if _skip_decay(name)]
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.999))

def _masked_rmse_over_loader(model, loader, device):
    """
    Return the masked RMSE pooled over every valid target in ``loader``.

    Squared errors and valid-target counts are summed across all batches before the
    square root, so the value is the RMSE over the whole split. Returns NaN when
    the loader holds no valid targets. Puts ``model`` in eval mode.
    """
    model.eval()
    sq_err_sum = 0.0
    n_valid = 0.0
    with torch.no_grad():
        for xb, yb, _, y_mask_b in loader:
            xb, yb, y_mask_b = xb.to(device), yb.to(device), y_mask_b.to(device)
            yhat = model(xb)
            sq_err_sum += ((yhat - yb) ** 2 * y_mask_b).sum().item()
            n_valid += y_mask_b.sum().item()
    return math.sqrt(sq_err_sum / n_valid) if n_valid > 0 else float("nan")


def train_loop(model, loaders, device, epochs=80, max_lr=3e-3, clip=1.0, patience=8):
    """
    Train with AdamW + OneCycle, masked Huber loss, and early stopping on val RMSE.

    Args:
        model: module whose output is compared element-wise with the target batch.
        loaders: (train_loader, val_loader); each yields (x, y, x_mask, y_mask).
        device: device every batch is moved to.
        epochs: maximum number of epochs (also sizes the OneCycle schedule).
        max_lr: peak learning rate of the OneCycle schedule.
        clip: maximum global gradient norm.
        patience: epochs without validation improvement before stopping early.

    Returns:
        Best validation masked RMSE (float('inf') if no epoch produced a finite
        value). When one exists, the best epoch's weights are loaded back into
        ``model`` in place.
    """
    train_loader, val_loader = loaders
    opt = make_optimizer(model, lr=max_lr, weight_decay=1e-2)
    # The scheduler is stepped once per batch, so it is sized by epochs * batches-per-epoch.
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(train_loader),
        pct_start=0.1, div_factor=25.0, final_div_factor=1e4
    )
    best_val = float('inf')
    best_state = None
    patience_counter = 0

    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        t0 = time.time()
        for xb, yb, _, y_mask_b in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            y_mask_b = y_mask_b.to(device)

            opt.zero_grad(set_to_none=True)
            yhat = model(xb)
            loss = masked_huber_loss(yhat, yb, y_mask_b)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
            opt.step()
            sched.step()
            # Sample-weighted running mean of the per-batch loss (logging only).
            running_loss += loss.item() * xb.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # ---- Validation: masked RMSE pooled over the whole split ----
        # Averaging per-batch RMSEs would be biased: sqrt does not commute with the
        # mean, and batches with few or no valid targets would count as full batches.
        val_rmse = _masked_rmse_over_loader(model, val_loader, device)

        dt = time.time() - t0
        print(f"Epoch {ep:03d}/{epochs} | train_loss={train_loss:.5f} | val_RMSE(masked)={val_rmse:.5f} | dt={dt:.1f}s")

        # A NaN val RMSE (no valid targets) never compares as an improvement.
        if val_rmse < best_val:
            best_val = val_rmse
            # Snapshot on CPU so later optimizer steps cannot mutate the saved weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {ep} (no improvement for {patience} epochs)")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val

#___________________________________________________________________________________________________
##
## Main Execution Block
##___________________________________________________________________________________________________

def main():
    """
    Train and evaluate the dilated causal CNN end to end.

    Steps:
      1. Load the NPZ splits and trim them to lookback/horizon.
      2. Z-score the inputs with TRAIN statistics.
      3. Train with masked loss and early stopping.
      4. Save a checkpoint (weights, normalization stats, config).
      5. Report masked test RMSE for the model and for a last-value persistence
         baseline, both overall and per horizon step.
    """
    # === Configuration ===
    CFG = {
        "npz": "data/processed/sav_0927_v3.npz",
        "horizon": 24,
        "lookback": 168,
        "batch_size": 128,
        "epochs": 80,
        "lr": 3e-4,  # OneCycle max_lr
        "channels": 64,
        "kernel": 5,
        "dropout": 0.1,
        "dilations": "1,2,4,8,16,32",
        "seed": 42,
        "save": "best_model.pt",
        "num_workers": 0,        # 0 for troubleshooting; autotune may bump this
        "target_feat_idx": -1,   # index of target feature inside X (if applicable)
        "use_torch_compile": True
    }
    print("Configuration:")
    print(json.dumps(CFG, indent=2))

    set_seed(CFG["seed"])
    device = pick_device()
    CFG = autotune_hparams(CFG, device)
    print(f"\nUsing device: {device}")
    print("Autotuned Configuration:")
    print(json.dumps(CFG, indent=2))

    # === Load & Prepare Data ===
    Xtr, ytr, Xva, yva, Xte, yte = load_npz(CFG["npz"])

    # Trim to desired lookback/horizon if necessary: keep the most recent L input
    # steps and the first H target steps.
    L, H = CFG["lookback"], CFG["horizon"]
    if Xtr.shape[1] > L:
        Xtr, Xva, Xte = Xtr[:, -L:], Xva[:, -L:], Xte[:, -L:]
    if ytr.shape[1] > H:
        ytr, yva, yte = ytr[:, :H], yva[:, :H], yte[:, :H]

    # Z-score normalization (NaN-aware) using TRAIN stats. Only X is normalized;
    # y stays in its original units.
    mu, sd = zscore_fit(Xtr)
    Xtr = zscore_apply(Xtr, mu, sd)
    Xva = zscore_apply(Xva, mu, sd)
    Xte = zscore_apply(Xte, mu, sd)

    # Create Datasets (masking inside)
    train_ds = TSWindowDataset(Xtr, ytr)
    val_ds   = TSWindowDataset(Xva, yva)
    test_ds  = TSWindowDataset(Xte, yte)

    # Create DataLoaders. persistent_workers keeps worker processes alive across
    # epochs instead of re-spawning them each epoch (only allowed with workers > 0).
    n_workers = CFG["num_workers"]
    loader_kw = dict(
        batch_size=CFG["batch_size"],
        num_workers=n_workers,
        persistent_workers=n_workers > 0,
    )
    train_loader = DataLoader(train_ds, shuffle=True,  **loader_kw)
    val_loader   = DataLoader(val_ds,   shuffle=False, **loader_kw)
    test_loader  = DataLoader(test_ds,  shuffle=False, **loader_kw)

    # Named n_feats (not F) so it does not shadow torch.nn.functional, imported as F.
    n_feats = Xtr.shape[-1]
    print(f"Dataset sizes — train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}")
    print(f"Features (F): {n_feats}, Lookback (L): {L}, Horizon (H): {H}\n")

    # === Initialize Model ===
    dilations = tuple(int(x) for x in CFG["dilations"].split(","))
    RF = receptive_field(CFG["kernel"], dilations)
    if L < RF:
        print(f"⚠️  Lookback ({L}) < receptive field ({RF}). Consider increasing lookback or reducing k/dilations.")

    model = DilatedCausalCNN(
        in_feats=n_feats, C=CFG["channels"], k=CFG["kernel"],
        dilations=dilations, horizon=H, dropout=CFG["dropout"]
    ).to(device)

    # Optional torch.compile for PyTorch 2.x
    if CFG.get("use_torch_compile", True):
        try:
            model = torch.compile(model)  # may no-op on older versions/backends
            print("torch.compile enabled.")
        except Exception as e:
            print(f"torch.compile not available/failed: {e}")

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model params: {n_params:,}\n")

    # === Train Model ===
    print("Starting training...")
    best_val_rmse = train_loop(
        model, (train_loader, val_loader), device,
        epochs=CFG["epochs"], max_lr=CFG["lr"], clip=1.0, patience=8
    )
    print(f"\nTraining finished. Best validation RMSE (masked): {best_val_rmse:.6f}")

    # === Save Model & Normalization Stats ===
    save_path = CFG["save"]
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # torch.compile returns a wrapper whose state_dict keys carry an "_orig_mod."
    # prefix. Save the wrapped module's weights so the checkpoint loads straight
    # into a plain DilatedCausalCNN. The wrapper shares these parameters, so they
    # hold the best-epoch weights restored by train_loop.
    base_model = getattr(model, "_orig_mod", model)
    torch.save({
        "state_dict": base_model.state_dict(),
        "mu": mu, "sd": sd,
        "config": CFG
    }, save_path)
    print(f"Saved best model and stats to: {save_path}")

    # === Evaluate on Test Set (MASKED metrics + robust persistence) ===
    TARGET_FEAT_IDX = CFG["target_feat_idx"]
    model.eval()
    yhats, ys, lasts, y_masks = [], [], [], []
    with torch.no_grad():
        for xb, yb, x_mask_b, y_mask_b in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            x_mask_b, y_mask_b = x_mask_b.to(device), y_mask_b.to(device)

            yhat = model(xb)
            yhats.append(yhat.cpu())
            ys.append(yb.cpu())
            y_masks.append(y_mask_b.cpu())
            lasts.append(last_valid_target(xb, x_mask_b, TARGET_FEAT_IDX).detach().cpu())

    yhat  = torch.cat(yhats,  dim=0)  # [N, H]
    y     = torch.cat(ys,     dim=0)  # [N, H]
    mask  = torch.cat(y_masks,dim=0)  # [N, H]
    last  = torch.cat(lasts,  dim=0)  # [N]

    # The persisted value comes from z-scored X, but y was never normalized, so
    # undo the target feature's normalization before comparing with y. This
    # assumes X[..., target_feat_idx] is the same quantity as y, which is the
    # premise of a persistence baseline.
    last = last * float(sd[TARGET_FEAT_IDX]) + float(mu[TARGET_FEAT_IDX])

    base = last.unsqueeze(1).repeat(1, y.shape[1])  # [N, H]
    ph_m, overall_m = per_horizon_masked_rmse(yhat, y, mask)
    ph_b, overall_b = per_horizon_masked_rmse(base,  y, mask)
    ph_skill = 1.0 - (ph_m / (ph_b + 1e-12))
    overall_skill = 1.0 - (overall_m / (overall_b + 1e-12))

    print("\n=== Test Metrics (MASKED) ===")
    print(f"Overall RMSE (model):       {overall_m.item():.4f}")
    print(f"Overall RMSE (persistence): {overall_b.item():.4f}")
    print(f"Overall Skill vs persist:   {overall_skill.item():.4f}")
    print("\nPer-horizon (t+1..t+H):")
    for h in range(H):
        print(f"h+{h+1:02d}: RMSE_model={ph_m[h].item():.4f}  |  RMSE_persist={ph_b[h].item():.4f}  |  Skill={ph_skill[h].item():.4f}")

# Script entry point. The guard also matters for DataLoader workers started with
# the "spawn" start method: each worker re-imports the main module, and the guard
# keeps that import from re-running main().
if __name__ == '__main__':
    main()


Configuration:
{
  "npz": "data/processed/sav_0927_v3.npz",
  "horizon": 24,
  "lookback": 168,
  "batch_size": 128,
  "epochs": 80,
  "lr": 0.0003,
  "channels": 64,
  "kernel": 5,
  "dropout": 0.1,
  "dilations": "1,2,4,8,16,32",
  "seed": 42,
  "save": "best_model.pt",
  "num_workers": 0,
  "target_feat_idx": -1,
  "use_torch_compile": true
}

Using device: mps
Autotuned Configuration:
{
  "npz": "data/processed/sav_0927_v3.npz",
  "horizon": 24,
  "lookback": 168,
  "batch_size": 128,
  "epochs": 80,
  "lr": 0.0003,
  "channels": 64,
  "kernel": 5,
  "dropout": 0.1,
  "dilations": "1,2,4,8,16,32",
  "seed": 42,
  "save": "best_model.pt",
  "num_workers": 4,
  "target_feat_idx": -1,
  "use_torch_compile": true
}
Dataset sizes — train: 60718, val: 13011, test: 13012
Features (F): 5, Lookback (L): 168, Horizon (H): 24

⚠️  Lookback (168) < receptive field (505). Consider increasing lookback or reducing k/dilations.
torch.compile enabled.
Model params: 250,136

Starting training...


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/ai/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/miniconda3/envs/ai/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'TSWindowDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/ai/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/miniconda3/envs/ai/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'TSWindowDataset' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/miniconda3/envs/ai/lib/python3.

RuntimeError: DataLoader worker (pid(s) 9017) exited unexpectedly