
# Dilated Causal CNN Forecaster — Notebook Version

This notebook refactors your `train_cnn_forecaster.py` script into modular, commented cells so you can iterate, test, and refine interactively.  
**Workflow**: Configure → Load/Prep Data → Build Model → Train → Evaluate → Save.

> Notes:
> - Replace argparse with a `CFG` dict for parameters.
> - Uses MPS on Apple Silicon if available, otherwise CPU.
> - Huber loss + AdamW + OneCycle. Early stopping on val RMSE.
> - Metrics include per-horizon RMSE and **skill vs. persistence**.


## Imports & Global Setup

In [1]:

import math, os, sys, time, json, random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from typing import Tuple

# Reproducibility
def set_seed(s: int = 42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)


## Data Utilities

In [2]:

def _squeeze_last_if_singleton(a: np.ndarray) -> np.ndarray:
    """If target is (N, H, 1) or (N, 1), drop the last dim."""
    return a[..., 0] if (a.ndim >= 1 and a.shape[-1] == 1) else a

def load_npz(path: str):
    """Load NPZ with keys: X_train, y_train, X_val, y_val, X_test, y_test."""
    npz = np.load(path)
    X_train = npz["X_train"]
    y_train = _squeeze_last_if_singleton(npz["y_train"])
    X_val   = npz["X_val"]
    y_val   = _squeeze_last_if_singleton(npz["y_val"])
    X_test  = npz["X_test"]
    y_test  = _squeeze_last_if_singleton(npz["y_test"])
    return X_train, y_train, X_val, y_val, X_test, y_test


## Normalization & Dataset

In [3]:

def zscore_fit(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-feature μ/σ from TRAIN windows only. X: [N, L, F]."""
    mu = X.reshape(-1, X.shape[-1]).mean(axis=0)
    sd = X.reshape(-1, X.shape[-1]).std(axis=0)
    sd = np.where(sd < 1e-8, 1.0, sd)
    return mu.astype(np.float32), sd.astype(np.float32)

def zscore_apply(X: np.ndarray, mu: np.ndarray, sd: np.ndarray) -> np.ndarray:
    return (X - mu) / sd

class TSWindowDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)  # [N,L,F]
        self.y = torch.as_tensor(y, dtype=torch.float32)  # [N,H]
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.y[i]


## Device Selection & Autotuning Heuristics

In [4]:

def pick_device():
    """Prefer MPS (Apple Silicon), else CPU (no CUDA here)."""
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def autotune_hparams(CFG, device):
    """
    Heuristics:
      - MPS (Apple Silicon): keep defaults, bump workers a bit.
      - CPU (assume ~16 GB RAM): clamp to a lighter profile.
    """
    tuned = dict(CFG)
    if device.type == "mps":
        tuned["num_workers"] = max(tuned.get("num_workers", 2), 4)
        return tuned

    tuned["batch_size"] = min(CFG.get("batch_size", 128), 64)
    tuned["channels"]   = min(CFG.get("channels", 64), 48)
    tuned["epochs"]     = min(CFG.get("epochs", 80), 80)
    tuned["num_workers"] = max(CFG.get("num_workers", 2), 2)
    tuned["lr"]         = min(CFG.get("lr", 3e-3), 2e-3)
    return tuned


## Model: Dilated Causal CNN with Residual Blocks

In [5]:

class CausalConv1d(nn.Conv1d):
    """1D convolution with left-padding only (causal; no future leakage)."""
    def __init__(self, in_ch, out_ch, k, dilation=1):
        super().__init__(in_ch, out_ch, kernel_size=k, dilation=dilation, padding=0, bias=True)
        self.left_pad = (k - 1) * dilation
    def forward(self, x):  # x: [B,C,T]
        if self.left_pad > 0:
            x = torch.nn.functional.pad(x, (self.left_pad, 0))
        return super().forward(x)

class LayerNormChannel(nn.Module):
    def __init__(self, C):
        super().__init__()
        self.ln = nn.LayerNorm(C)

    def forward(self, x):      # x: [B,C,T]
        x = x.transpose(1, 2).contiguous()  # [B,T,C]
        x = self.ln(x)
        x = x.transpose(1, 2).contiguous()  # [B,C,T]
        return x

class ResBlock(nn.Module):
    def __init__(self, C, k, dilation, dropout=0.1):
        super().__init__()
        self.conv1 = CausalConv1d(C, C, k, dilation)
        self.act1  = nn.GELU()
        self.norm1 = LayerNormChannel(C)
        self.drop1 = nn.Dropout(dropout)

        self.conv2 = CausalConv1d(C, C, k, dilation)
        self.act2  = nn.GELU()
        self.norm2 = LayerNormChannel(C)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):      # x: [B,C,T]
        residual = x.contiguous()
        x = self.conv1(x).contiguous()
        x = self.act1(x)
        x = self.norm1(x).contiguous()
        x = self.drop1(x)

        x = self.conv2(x).contiguous()
        x = self.act2(x)
        x = self.norm2(x).contiguous()
        x = self.drop2(x)

        x = (x + residual).contiguous()
        return x

class DilatedCausalCNN(nn.Module):
    def __init__(self, in_feats, C=64, k=5, dilations=(1,2,4,8,16,32), horizon=24, dropout=0.1):
        super().__init__()
        self.in_proj = nn.Conv1d(in_feats, C, kernel_size=1)
        self.blocks = nn.ModuleList([ResBlock(C, k=k, dilation=d, dropout=dropout) for d in dilations])
        self.head_norm = LayerNormChannel(C)
        self.head = nn.Linear(C, horizon)

    def forward(self, x):      # x: [B,L,F]
        x = x.transpose(1, 2).contiguous()  # [B,F,L]
        x = self.in_proj(x).contiguous()    # [B,C,L]
        for b in self.blocks:
            x = b(x).contiguous()           # [B,C,L]
        x = self.head_norm(x).contiguous()  # [B,C,L]
        last = x[:, :, -1].contiguous()     # [B,C]
        yhat = self.head(last)              # [B,H]
        return yhat


## Metrics & Baselines

In [6]:

def rmse(a, b, dim=None, eps=1e-9):
    return torch.sqrt(torch.mean((a - b)**2, dim=dim) + eps)

def persistence_baseline(x_last, H):
    """Repeat last observed target H times."""
    return x_last.unsqueeze(1).repeat(1, H)

def per_horizon_rmse(yhat, y):
    err = (yhat - y) ** 2
    ph = torch.sqrt(err.mean(dim=0))  # [H]
    overall = torch.sqrt(err.mean())  # scalar
    return ph, overall

def skill_vs_persistence(yhat, y, y_last):
    base = persistence_baseline(y_last, y.shape[1])
    ph_m, overall_m = per_horizon_rmse(yhat, y)
    ph_b, overall_b = per_horizon_rmse(base, y)
    ph_skill = 1.0 - (ph_m / (ph_b + 1e-12))
    overall_skill = 1.0 - (overall_m / (overall_b + 1e-12))
    return ph_skill, overall_skill, ph_m, overall_m, ph_b, overall_b


## Configuration (edit me)

In [7]:

# === EDIT THESE ===
CFG = {
    "npz": "data/processed/sav_0927_v3.npz",  # <-- set your NPZ path
    "horizon": 24,     # H
    "lookback": 168,   # L
    "batch_size": 128,
    "epochs": 80,
    "lr": 3e-3,
    "channels": 64,    # C
    "kernel": 5,       # k
    "dropout": 0.1,
    "dilations": "1,2,4,8,16,32",
    "seed": 42,
    "save": "best_model.pt",
    "auto_tune": False,
    "num_workers": 2,
}

#worker troubleshooting
CFG["num_workers"] = 0

print(json.dumps(CFG, indent=2))


{
  "npz": "data/processed/sav_0927_v3.npz",
  "horizon": 24,
  "lookback": 168,
  "batch_size": 128,
  "epochs": 80,
  "lr": 0.003,
  "channels": 64,
  "kernel": 5,
  "dropout": 0.1,
  "dilations": "1,2,4,8,16,32",
  "seed": 42,
  "save": "best_model.pt",
  "auto_tune": false,
  "num_workers": 0
}


## Load Data, Trim to L/H, Normalize (train μ/σ)

In [8]:

set_seed(CFG["seed"])
device = pick_device()
print(f"Device: {device}")

# Load data
Xtr, ytr, Xva, yva, Xte, yte = load_npz(CFG["npz"])

# Basic shape checks / trims to desired L,H if needed
L, H = CFG["lookback"], CFG["horizon"]
if Xtr.shape[1] != L:
    if Xtr.shape[1] > L:
        Xtr = Xtr[:, -L:, :]; Xva = Xva[:, -L:, :]; Xte = Xte[:, -L:, :]
        print(f"Trimmed lookback to last {L} steps.")
    else:
        raise ValueError(f"X lookback {Xtr.shape[1]} < desired L={L}.")
if ytr.shape[1] != H:
    if ytr.shape[1] > H:
        ytr = ytr[:, :H]; yva = yva[:, :H]; yte = yte[:, :H]
        print(f"Trimmed horizon to first {H} steps.")
    else:
        raise ValueError(f"y horizon {ytr.shape[1]} < desired H={H}.")

# Z-score per feature using TRAIN only
mu, sd = zscore_fit(Xtr)
Xtr = zscore_apply(Xtr, mu, sd)
Xva = zscore_apply(Xva, mu, sd)
Xte = zscore_apply(Xte, mu, sd)

# Optional auto-tune
if CFG["auto_tune"]:
    CFG = autotune_hparams(CFG, device)

print(f"[Config] bs={CFG['batch_size']}  epochs={CFG['epochs']}  C={CFG['channels']}  "
      f"workers={CFG['num_workers']}  lr={CFG['lr']:.2e}")

# Dataloaders
train_ds = TSWindowDataset(Xtr, ytr)
val_ds   = TSWindowDataset(Xva, yva)
test_ds  = TSWindowDataset(Xte, yte)

persistent_workers = False  # has no effect with 0 workers, but keep it explicit

train_loader = DataLoader(train_ds, batch_size=CFG["batch_size"], shuffle=True, drop_last=False,
                          num_workers=CFG["num_workers"], pin_memory=False)
val_loader   = DataLoader(val_ds, batch_size=CFG["batch_size"], shuffle=False, drop_last=False,
                          num_workers=CFG["num_workers"], pin_memory=False)
test_loader  = DataLoader(test_ds, batch_size=CFG["batch_size"], shuffle=False, drop_last=False,
                          num_workers=CFG["num_workers"], pin_memory=False)
F = Xtr.shape[-1]
print(f"Dataset sizes — train: {len(train_ds)}, val: {len(val_ds)}, test: {len(test_ds)}; F={F}, L={L}, H={H}")


Device: mps
[Config] bs=128  epochs=80  C=64  workers=0  lr=3.00e-03
Dataset sizes — train: 60718, val: 13011, test: 13012; F=5, L=168, H=24


In [9]:

def train_loop(model, loaders, device, epochs=80, max_lr=3e-3, clip=1.0, patience=8):
    train_loader, val_loader = loaders
    opt = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=1e-2, betas=(0.9, 0.999))
    steps_per_epoch = len(train_loader)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch,
        pct_start=0.1, div_factor=25.0, final_div_factor=1e4, three_phase=False
    )
    criterion = torch.nn.HuberLoss(delta=1.0, reduction='mean')

    best_val = float('inf')
    best_state = None
    bad = 0

    for ep in range(1, epochs+1):
        model.train()
        run_loss = 0.0
        t0 = time.time()
        for xb, yb in train_loader:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad(set_to_none=True)
            yhat = model(xb)
            loss = criterion(yhat, yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
            opt.step()
            sched.step()
            run_loss += loss.item() * xb.size(0)
        train_loss = run_loss / (len(train_loader.dataset))

        # Validate
        model.eval()
        vsum = 0.0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb = xb.to(device).contiguous(); yb = yb.to(device).contiguous()
                yhat = model(xb)
                vsum += torch.sqrt(((yhat - yb) ** 2).mean()).item() * xb.size(0)
        val_rmse = vsum / len(val_loader.dataset)

        dt = time.time() - t0
        print(f"Epoch {ep:03d}/{epochs} | train_loss={train_loss:.5f} | val_RMSE={val_rmse:.5f} | dt={dt:.1f}s")

        if val_rmse < best_val - 1e-6:
            best_val = val_rmse
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping at epoch {ep} (no improvement for {patience} epochs)")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return best_val


## Initialize Model

In [10]:

dilations = tuple(int(x) for x in CFG["dilations"].split(",") if x.strip())
model = DilatedCausalCNN(in_feats=F, C=CFG["channels"], k=CFG["kernel"], dilations=dilations,
                         horizon=H, dropout=CFG["dropout"]).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model params: {n_params:,} | F={F}, L={L}, H={H}, dilations={dilations}")


Model params: 250,136 | F=5, L=168, H=24, dilations=(1, 2, 4, 8, 16, 32)


## Train

In [11]:

best_val = train_loop(
    model, (train_loader, val_loader), device,
    epochs=CFG["epochs"], max_lr=CFG["lr"], clip=1.0, patience=8
)
print(f"Best val RMSE: {best_val:.6f}")


Epoch 001/80 | train_loss=nan | val_RMSE=nan | dt=35.3s
Epoch 002/80 | train_loss=nan | val_RMSE=nan | dt=32.8s
Epoch 003/80 | train_loss=nan | val_RMSE=nan | dt=32.8s
Epoch 004/80 | train_loss=nan | val_RMSE=nan | dt=32.8s


KeyboardInterrupt: 

## Save Best Model & Normalization

In [None]:

torch.save({"state_dict": model.state_dict(),
            "mu": mu, "sd": sd,
            "config": CFG}, CFG["save"])
print(f"Saved: {CFG['save']}")


## Evaluate on Test: Per-Horizon RMSE & Skill vs Persistence

In [None]:

model.eval()
yhats, ys, lasts = [], [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device); yb = yb.to(device)
        yhat = model(xb)
        yhats.append(yhat.cpu()); ys.append(yb.cpu())

        # Baseline: using last observed 'temp' proxy
        # NOTE: adjust this index if your 'temp' feature isn't the last feature!
        temp_proxy = xb[:, -1, -1].detach().cpu()  # [B]
        lasts.append(temp_proxy)

yhat = torch.cat(yhats, dim=0)  # [N,H]
y    = torch.cat(ys, dim=0)     # [N,H]
last = torch.cat(lasts, dim=0)  # [N]

ph_skill, overall_skill, ph_rmse_m, overall_rmse_m, ph_rmse_b, overall_rmse_b = skill_vs_persistence(yhat, y, last)

print("\n=== Test Metrics ===")
print(f"Overall RMSE (model):       {overall_rmse_m.item():.4f}")
print(f"Overall RMSE (persistence): {overall_rmse_b.item():.4f}")
print(f"Overall Skill vs persist:   {overall_skill.item():.4f}")
print("\nPer-horizon (t+1..t+H):")
for h in range(H):
    print(f"h+{h+1:02d}: RMSE_model={ph_rmse_m[h].item():.4f}  "
          f"RMSE_persist={ph_rmse_b[h].item():.4f}  Skill={ph_skill[h].item():.4f}")



## Tips for Iteration
- **Swap loss**: try `nn.MSELoss()` or quantile loss heads.
- **Alternate schedules**: CosineAnnealingLR or ReduceLROnPlateau.
- **Change baseline**: persistence on other feature(s) or climatology.
- **Diagnostics**: plot learning rates, gradient norms, or per-horizon charts.
- **Stability**: try smaller `lr`, change `channels`, or increase patience.


In [26]:
import torch
print(torch.backends.mps.is_available())  # Should return True
print(torch.backends.mps.is_built())      # Should also return True

True
True
