In [29]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

In [30]:
def make_windows_from_dataset(ds, p_max=20):
    """
    Slice every trajectory into one-step-ahead autoregressive windows.

    For each trajectory i and each time t in p_max..T-1 one window is produced
    whose regressors are the p_max samples preceding x_t, most recent first.
    Windows are ordered trajectory by trajectory, with t ascending inside each
    trajectory.

    ds: mapping with keys X [N,T], A [N,T,p_max_data], p_true [N]
    p_max: number of lags per window (must satisfy 0 <= p_max <= T)

    Returns windowed arrays, with M = N * (T - p_max):
      Z: [M, p_max] float32        [x_{t-1},...,x_{t-p_max}]
      y: [M] float32               x_t
      a_true: [M, p_max] float32   A[i, t, :p_max], zero padded when
                                   p_max > p_max_data
      p_true_win: [M] int64        (order label per window)
      sample_id: [M] int64         (which trajectory each window came from)
      time_idx: [M] int64          (absolute t index in the original series)

    Raises:
      ValueError: if p_max is outside 0..T, or p_true does not hold exactly
                  one label per trajectory.
    """
    X = np.asarray(ds["X"])
    A = np.asarray(ds["A"])
    p_true = np.asarray(ds["p_true"]).reshape(-1)

    N, T = X.shape
    p_max_data = A.shape[2]  # infer from data
    assert A.shape == (N, T, p_max_data)

    if not 0 <= p_max <= T:
        raise ValueError(f"p_max must be in 0..T (T={T}), got {p_max}")
    if p_true.size != N:
        raise ValueError(f"p_true must have one entry per trajectory ({N}), got {p_true.size}")

    M_per = T - p_max
    M = N * M_per

    # Target times shared by all trajectories, and for each of them the
    # column indices of its lags: row j holds [t-1, t-2, ..., t-p_max].
    t_range = np.arange(p_max, T, dtype=np.int64)                                # [M_per]
    lag_idx = t_range[:, None] - 1 - np.arange(p_max, dtype=np.int64)[None, :]    # [M_per, p_max]

    # One gather replaces the former Python loop over all N * M_per windows.
    Z = X[:, lag_idx].astype(np.float32).reshape(M, p_max)
    y = X[:, p_max:].astype(np.float32).reshape(M)

    # take only the first p_max coefficients (or pad if p_max > p_max_data)
    k = min(p_max, p_max_data)
    a_true = np.zeros((M, p_max), dtype=np.float32)
    a_true[:, :k] = A[:, p_max:, :k].reshape(M, k)

    # Per-trajectory values are repeated for each of the trajectory's windows.
    p_true_win = np.repeat(p_true.astype(np.int64), M_per)
    sample_id = np.repeat(np.arange(N, dtype=np.int64), M_per)
    time_idx = np.tile(t_range, N)

    return Z, y, a_true, p_true_win, sample_id, time_idx

In [31]:
def split_trajectories(pilot, train_frac=0.6, val_frac=0.2, seed=0, method="global"):
    """
    Partition trajectory indices into train / validation / test subsets.

    The test subset is whatever remains after the (rounded) train and
    validation counts have been taken.

    method:
      - "global": one random permutation of all trajectories, cut in three
      - "by_family": the same cut applied inside every distinct value of
        pilot["class_id"], then each subset is reshuffled so families mix

    Returns index arrays: tr_ids, va_ids, te_ids

    Raises:
      ValueError: unknown method, or "by_family" without a usable class_id.
    """
    n_traj = pilot["X"].shape[0]
    gen = np.random.default_rng(seed)

    assert 0 < train_frac < 1
    assert 0 <= val_frac < 1
    assert train_frac + val_frac < 1

    def three_way(order):
        # Rounded cut points; everything past train + val becomes test.
        size = len(order)
        cut_tr = int(round(train_frac * size))
        cut_va = cut_tr + int(round(val_frac * size))
        return order[:cut_tr], order[cut_tr:cut_va], order[cut_va:]

    if method == "global":
        return three_way(gen.permutation(n_traj))

    if method != "by_family":
        raise ValueError(f"Unknown method: {method}")

    if "class_id" not in pilot or pilot["class_id"] is None:
        raise ValueError("pilot must contain 'class_id' for method='by_family'")

    families = pilot["class_id"]
    buckets = ([], [], [])  # train, val, test

    for family in np.unique(families):
        members = np.nonzero(families == family)[0]
        pieces = three_way(gen.permutation(members))
        for bucket, piece in zip(buckets, pieces):
            bucket.extend(piece.tolist())

    # Final reshuffle (train, then val, then test) so batches mix families.
    shuffled = [gen.permutation(np.array(bucket, dtype=int)) for bucket in buckets]
    return shuffled[0], shuffled[1], shuffled[2]

In [32]:
def subset_pilot(pilot, ids):
    """
    Build a new dict holding only the trajectories selected by `ids`.

    "X", "A" and "p_true" are always indexed with `ids`; "class_id" and
    "noise_std" are carried over the same way only when they are present
    in `pilot` and not None.
    """
    required = ("X", "A", "p_true")
    optional = ("class_id", "noise_std")

    subset = {key: pilot[key][ids] for key in required}
    for key in optional:
        if key in pilot and pilot[key] is not None:
            subset[key] = pilot[key][ids]
    return subset

In [33]:
class MLPTVAR_MultiHead(nn.Module):
    """
    MLP that maps a window of lagged values to per-window AR coefficients,
    a bias and order logits, and forms the one-step prediction from them.

    Backbone: `depth` blocks of Linear -> GELU -> LayerNorm -> Dropout.
    Heads (all linear, fed by the backbone output):
      coeff_head: p_max coefficients
      bias_head:  scalar bias
      p_head:     p_max logits for p in {1..p_max}

    p_max: number of lags (input width and number of coefficients / classes)
    hidden: width of every backbone block
    depth: number of backbone blocks; 0 gives an identity backbone and
           purely linear heads on the raw window
    dropout: dropout probability inside each block
    """

    def __init__(self, p_max, hidden=128, depth=3, dropout=0.1):
        super().__init__()
        layers = []
        in_dim = p_max
        for _ in range(depth):
            layers += [
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.LayerNorm(hidden),
                nn.Dropout(dropout),
            ]
            in_dim = hidden
        self.backbone = nn.Sequential(*layers)

        # Heads take the backbone's actual output width: `hidden` whenever
        # depth >= 1 (as before) and p_max for depth == 0, where sizing them
        # to `hidden` would make forward() fail on a shape mismatch.
        self.coeff_head = nn.Linear(in_dim, p_max)   # coeffs in standardized space
        self.bias_head  = nn.Linear(in_dim, 1)       # bias in standardized space
        self.p_head     = nn.Linear(in_dim, p_max)   # logits for p in {1..p_max}

        self.p_max = p_max

    def forward(self, Zs):
        """
        Zs: [..., p_max] window of lagged inputs (typically [B, p_max]).

        Returns (y_pred_s, coeffs_s, bias_s, p_logits) with shapes
        [...], [..., p_max], [...], [..., p_max].
        """
        h = self.backbone(Zs)
        coeffs_s = self.coeff_head(h)                # [..., p_max]
        bias_s = self.bias_head(h).squeeze(-1)       # [...]
        p_logits = self.p_head(h)                    # [..., p_max]

        # Reduce over the lag axis (the last one) so that the result lines up
        # with bias_s for any number of leading dims, not only for [B, p_max].
        y_pred_s = (coeffs_s * Zs).sum(dim=-1) + bias_s
        return y_pred_s, coeffs_s, bias_s, p_logits

In [34]:
def train_mlp_tvar(
    train_ds, val_ds,
    p_max=20,
    hidden=128, depth=3, dropout=0.05,
    batch_size=512,
    epochs=30,
    lr=2e-3,
    weight_decay=1e-4,
    coeff_loss_w=0.2,     # how hard to match true coefficients
    order_loss_w=0.2,     # how hard to predict p_true
    l1_out_w=1e-5,        # L1 on output coeffs (encourages sparsity)
    eval_every=1,
    patience=5,
    seed=0,
):
    """
    Train MLPTVAR_MultiHead on windowed trajectories with early stopping.

    Both datasets are windowed with make_windows_from_dataset, inputs and
    targets are standardized with statistics of the TRAIN windows only, and
    the model is optimized with AdamW on

        MSE(y) + coeff_loss_w * MSE(coeffs) + order_loss_w * CE(order)
               + l1_out_w * mean|coeffs|

    The same composite loss on the validation windows drives early stopping:
    training stops after `patience` consecutive evaluations without an
    improvement of more than 1e-6, and the best weights seen are restored.

    train_ds, val_ds: mappings accepted by make_windows_from_dataset
    p_max: number of lags / order classes
    hidden, depth, dropout: model size, forwarded to MLPTVAR_MultiHead
    batch_size, epochs, lr, weight_decay: optimization settings
    eval_every: validate every this many epochs
    patience: number of non-improving evaluations tolerated
    seed: torch RNG seed (model init and batch shuffling)

    Returns a dict with keys: model, device, Z_mean, Z_std, y_mean, y_std, p_max

    Raises:
      ValueError: if either split yields no windows, or an order label is
                  outside 1..p_max.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(seed)

    # Windowize (sample ids / time indices are not needed for training)
    Z_tr, y_tr, a_tr, p_tr, _, _ = make_windows_from_dataset(train_ds, p_max=p_max)
    Z_va, y_va, a_va, p_va, _, _ = make_windows_from_dataset(val_ds,   p_max=p_max)

    # Fail early with a readable message instead of a ZeroDivisionError /
    # sampler error further down.
    if len(y_tr) == 0 or len(y_va) == 0:
        raise ValueError(
            "train_ds and val_ds must each yield at least one window "
            f"(got {len(y_tr)} train and {len(y_va)} val windows for p_max={p_max})"
        )

    # The order head has exactly p_max classes; a label outside 1..p_max would
    # otherwise only surface inside CrossEntropyLoss (a device-side assert on CUDA).
    for split_name, labels in (("train", p_tr), ("val", p_va)):
        lo, hi = int(labels.min()), int(labels.max())
        if lo < 1 or hi > p_max:
            raise ValueError(
                f"{split_name} p_true labels must lie in 1..{p_max}, found range {lo}..{hi}"
            )

    # Standardize using TRAIN windows only
    Z_mean = Z_tr.mean(axis=0, keepdims=True)
    Z_std  = Z_tr.std(axis=0, keepdims=True) + 1e-8
    y_mean = y_tr.mean()
    y_std  = y_tr.std() + 1e-8

    Z_tr_s = (Z_tr - Z_mean) / Z_std
    Z_va_s = (Z_va - Z_mean) / Z_std
    y_tr_s = (y_tr - y_mean) / y_std
    y_va_s = (y_va - y_mean) / y_std

    # Coeff targets in standardized space:
    # coeffs_s_k = a_k * Z_std_k / y_std
    coeff_scale = Z_std.reshape(-1) / y_std
    a_tr_s = a_tr * coeff_scale
    a_va_s = a_va * coeff_scale

    # p labels: map 1..p_max -> 0..p_max-1
    p_tr_cls = (p_tr - 1).astype(np.int64)
    p_va_cls = (p_va - 1).astype(np.int64)

    def make_loader(Zs, ys, a_s, p_cls, shuffle):
        # Wrap one split's arrays into a DataLoader over (Z, y, a, p) tuples.
        return DataLoader(
            TensorDataset(
                torch.tensor(Zs, dtype=torch.float32),
                torch.tensor(ys, dtype=torch.float32),
                torch.tensor(a_s, dtype=torch.float32),
                torch.tensor(p_cls, dtype=torch.long),
            ),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False,
        )

    # Torch datasets/loaders
    tr_loader = make_loader(Z_tr_s, y_tr_s, a_tr_s, p_tr_cls, shuffle=True)
    va_loader = make_loader(Z_va_s, y_va_s, a_va_s, p_va_cls, shuffle=False)

    model = MLPTVAR_MultiHead(p_max=p_max, hidden=hidden, depth=depth, dropout=dropout).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    mse = nn.MSELoss()
    ce  = nn.CrossEntropyLoss()

    def composite_loss(batch):
        # Single definition of the objective, shared by training and validation.
        Zb, yb, ab, pb = (tensor.to(device) for tensor in batch)
        yhat_s, coeffs_s, _, p_logits = model(Zb)

        loss_y = mse(yhat_s, yb)
        loss_a = mse(coeffs_s, ab)
        loss_p = ce(p_logits, pb)
        loss_l1 = coeffs_s.abs().mean()

        loss = loss_y + coeff_loss_w * loss_a + order_loss_w * loss_p + l1_out_w * loss_l1
        return loss, Zb.size(0)

    best_val = float("inf")
    best_state = None
    bad = 0

    for epoch in range(1, epochs + 1):
        model.train()
        tr_loss = 0.0
        for batch in tr_loader:
            loss, n_rows = composite_loss(batch)

            opt.zero_grad()
            loss.backward()
            opt.step()
            # weight by batch size so the epoch figure is a per-window mean
            tr_loss += float(loss.item()) * n_rows

        tr_loss /= len(tr_loader.dataset)

        if epoch % eval_every == 0:
            model.eval()
            va_loss = 0.0
            with torch.no_grad():
                for batch in va_loader:
                    loss, n_rows = composite_loss(batch)
                    va_loss += float(loss.item()) * n_rows

            va_loss /= len(va_loader.dataset)

            print(f"Epoch {epoch:03d} | train={tr_loss:.6f} | val={va_loss:.6f}")

            if va_loss < best_val - 1e-6:
                best_val = va_loss
                # keep a CPU copy so later updates cannot overwrite the snapshot
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                bad = 0
            else:
                bad += 1
                if bad >= patience:
                    print(f"Early stopping at epoch {epoch}. Best val={best_val:.6f}")
                    break

    if best_state is not None:
        model.load_state_dict(best_state)

    pack = {
        "model": model,
        "device": device,
        "Z_mean": Z_mean.astype(np.float32),
        "Z_std":  Z_std.astype(np.float32),
        "y_mean": float(y_mean),
        "y_std":  float(y_std),
        "p_max":  int(p_max),
    }
    return pack

In [35]:
def evaluate_on_test(pack, test_ds, batch_size=8192):
    """
    Score a trained model (as returned by train_mlp_tvar) on held-out trajectories.

    pack: dict with keys model, device, Z_mean, Z_std, y_mean, y_std, p_max
    test_ds: mapping accepted by make_windows_from_dataset
    batch_size: number of windows per forward pass; inference is chunked so
                that the whole test set never has to sit on the device at once

    Returns a dict:
      pred_mse: window-level MSE of the one-step prediction, original units
      order_acc: fraction of trajectories whose order is predicted correctly,
                 using the argmax of the logits averaged over the
                 trajectory's windows
      coeff_mse_masked: squared coefficient error averaged over the first
                        p_true lags of each window only
      p_hat: [N] predicted order per trajectory (1..p_max)
      p_true: [N] true order per trajectory

    Raises:
      ValueError: if batch_size < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model = pack["model"]
    device = pack["device"]
    Z_mean = pack["Z_mean"]
    Z_std  = pack["Z_std"]
    y_mean = pack["y_mean"]
    y_std  = pack["y_std"]
    p_max  = pack["p_max"]

    Z_te, y_te, a_te, _, sid_te, _ = make_windows_from_dataset(test_ds, p_max=p_max)

    # standardize features with the training statistics stored in the pack
    Z_te_s = (Z_te - Z_mean) / Z_std

    model.eval()
    Zt_all = torch.tensor(Z_te_s, dtype=torch.float32)
    yhat_chunks, coeff_chunks, logit_chunks = [], [], []
    with torch.no_grad():
        for start in range(0, Zt_all.shape[0], batch_size):
            Zb = Zt_all[start:start + batch_size].to(device)
            yhat_b, coeffs_b, _, logits_b = model(Zb)
            yhat_chunks.append(yhat_b.cpu().numpy())
            coeff_chunks.append(coeffs_b.cpu().numpy())
            logit_chunks.append(logits_b.cpu().numpy())

    if yhat_chunks:
        yhat_s = np.concatenate(yhat_chunks, axis=0)
        coeffs_s = np.concatenate(coeff_chunks, axis=0)
        p_logits = np.concatenate(logit_chunks, axis=0)
    else:
        # no windows at all: keep the shapes the code below expects
        yhat_s = np.zeros((0,), dtype=np.float32)
        coeffs_s = np.zeros((0, p_max), dtype=np.float32)
        p_logits = np.zeros((0, p_max), dtype=np.float32)

    # unstandardize predictions
    yhat = yhat_s * y_std + y_mean

    # unstandardize coefficients to original AR coefficients:
    # a_hat = coeffs_s * y_std / Z_std
    a_hat = coeffs_s * (y_std / Z_std.reshape(-1))

    # window-level prediction mse
    pred_mse = float(np.mean((yhat - y_te) ** 2))

    # trajectory-level order prediction: aggregate logits per trajectory
    N = test_ds["X"].shape[0]
    logits_sum = np.zeros((N, p_max), dtype=np.float64)
    # unbuffered scatter-add: rows sharing a trajectory id accumulate correctly
    np.add.at(logits_sum, sid_te, p_logits)
    counts = np.bincount(sid_te, minlength=N)

    # max(count, 1) avoids 0/0 for trajectories that produced no window
    logits_mean = logits_sum / np.maximum(counts[:, None], 1)
    p_hat = np.argmax(logits_mean, axis=1) + 1  # back to 1..p_max

    p_true_traj = test_ds["p_true"].astype(int)
    order_acc = float(np.mean(p_hat == p_true_traj))

    # coefficient MSE masked to active lags per trajectory:
    # lag k (0-based) of trajectory i counts only if k < p_true[i]
    masks = (np.arange(p_max)[None, :] < p_true_traj[:, None]).astype(np.float32)

    # apply window-wise mask
    mask_w = masks[sid_te]  # [M, p_max]
    coeff_mse_masked = float(np.sum(((a_hat - a_te) ** 2) * mask_w) / (np.sum(mask_w) + 1e-8))

    out = {
        "pred_mse": pred_mse,
        "order_acc": order_acc,
        "coeff_mse_masked": coeff_mse_masked,
        "p_hat": p_hat,
        "p_true": p_true_traj,
    }
    return out

In [36]:
# Load the pilot archive and partition it at the trajectory level, so that
# all windows of one series end up in exactly one of the three subsets.
pilot = np.load("tvar_pilot_T10000.npz")

# 1) Split by trajectory
tr_ids, va_ids, te_ids = split_trajectories(pilot, seed=0, method="global")

train_ds, val_ds, test_ds = [
    subset_pilot(pilot, ids) for ids in (tr_ids, va_ids, te_ids)
]

In [37]:
# 2) Train the MLP baseline
# Hyperparameters are collected in one place; anything not listed here
# falls back to the defaults of train_mlp_tvar.
mlp_hparams = {
    "p_max": 6,
    "hidden": 128,
    "depth": 3,
    "dropout": 0.05,
    "batch_size": 512,
    "epochs": 30,
    "lr": 2e-3,
    "coeff_loss_w": 0.2,
    "order_loss_w": 0.2,
    "l1_out_w": 1e-5,
    "patience": 5,
    "seed": 0,
}
pack = train_mlp_tvar(train_ds, val_ds, **mlp_hparams)

# 3) Evaluate on the held-out trajectories
metrics = evaluate_on_test(pack, test_ds)
print(metrics)

Epoch 001 | train=1.313946 | val=1.285309
Epoch 002 | train=1.261396 | val=1.290521
Epoch 003 | train=1.259574 | val=1.288601
Epoch 004 | train=1.258712 | val=1.290433
Epoch 005 | train=1.258545 | val=1.290193
Epoch 006 | train=1.258485 | val=1.286990
Early stopping at epoch 6. Best val=1.285309
{'pred_mse': 0.11363424360752106, 'order_acc': 0.4375, 'coeff_mse_masked': 0.023323909327213517, 'p_hat': array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), 'p_true': array([4, 4, 4, 2, 4, 6, 4, 4, 1, 4, 6, 2, 1, 1, 2, 2])}
