In [None]:
# Lag-Attention tvAR pipeline on Lorenz (publication-style, one chart per figure, no explicit colors)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------- Global plot style -------------
def set_pub_style():
    plt.rcParams.update({
        "figure.dpi": 120,
        "savefig.dpi": 300,
        "font.size": 12,
        "axes.labelsize": 12,
        "axes.titlesize": 13,
        "legend.fontsize": 10,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "lines.linewidth": 1.8,
        "lines.markersize": 5,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.grid": True,
        "grid.alpha": 0.25,
    })

def prettify(ax, title=None, xlabel=None, ylabel=None, add_legend=False):
    if title: ax.set_title(title)
    if xlabel: ax.set_xlabel(xlabel)
    if ylabel: ax.set_ylabel(ylabel)
    if add_legend: ax.legend(frameon=False)
    ax.margins(x=0.02)

set_pub_style()


# ------------- Lorenz (RK4) -------------
def lorenz(T_steps=300, dt=0.005, sigma=10.0, rho=28.0, beta=8/3, x0=(1.0,1.0,1.0)):
    x, y, z = x0
    xs, ys, zs = [], [], []
    for _ in range(T_steps):
        def f(x,y,z):
            dx = sigma*(y-x)
            dy = x*(rho - z) - y
            dz = x*y - beta*z
            return dx, dy, dz
        k1 = f(x,y,z)
        k2 = f(x + 0.5*dt*k1[0], y + 0.5*dt*k1[1], z + 0.5*dt*k1[2])
        k3 = f(x + 0.5*dt*k2[0], y + 0.5*dt*k2[1], z + 0.5*dt*k2[2])
        k4 = f(x + dt*k3[0], y + dt*k3[1], z + dt*k3[2])
        x += (dt/6.0)*(k1[0] + 2*k2[0] + 2*k3[0] + k4[0])
        y += (dt/6.0)*(k1[1] + 2*k2[1] + 2*k3[1] + k4[1])
        z += (dt/6.0)*(k1[2] + 2*k2[2] + 2*k3[2])
        xs.append(x); ys.append(y); zs.append(z)
    return np.array(xs), np.array(ys), np.array(zs)

xs, ys, zs = lorenz()

# ------------- Helpers -------------
def acf_vals(sig, nlxags=60):
    s = sig - np.mean(sig)
    ac = np.correlate(s, s, mode="full")
    ac = ac[ac.size//2:]
    ac0 = ac[0] if ac[0] != 0 else 1.0
    return (ac / ac0)[:nlags+1]

def embed3(s, tau):
    # returns [N-2*tau, 3]: [s(t), s(t-τ), s(t-2τ)]
    return np.column_stack([s[2*tau:], s[tau:-tau], s[:-2*tau]])

# ------------- Train / test split on x(t) -------------
LAG_MAX   = 256     # number of candidate lags (L)
TRAIN_FR  = 0.80    # 80/20 split
REFRESH_EVERY = 20  # controllable rollout knob: 1 => 1-step ; large => free-run
TAU = 15           # state-space embedding delay (samples)

x = xs.copy()
N = len(x)
ntr = int(TRAIN_FR * (N - LAG_MAX))  # leave room for left padding by L
# scale by train stats
mu_tr  = x[:LAG_MAX + ntr].mean()
std_tr = x[:LAG_MAX + ntr].std() + 1e-12
x_n = (x - mu_tr) / std_tr

# ------------- Lag-attention tvAR (univariate) -------------
def build_lag_bank(x_tensor, L):
    # x_tensor: [B, T], returns [B, T, L] with Xlags[:, t, l] = x_{t-(l+1)}
    B, T = x_tensor.shape
    pads = F.pad(x_tensor, (L, 0))
    lags = [pads[:, L - (l+1) : L - (l+1) + T] for l in range(L)]
    return torch.stack(lags, dim=-1)

def topk_mask_logits(logits, k):
    if (k is None) or (k >= logits.size(-1)):
        return logits
    topk = torch.topk(logits, k, dim=-1)
    mask = torch.full_like(logits, float('-inf'))
    mask.scatter_(-1, topk.indices, topk.values)
    return mask

class LagAttentionTVAR(nn.Module):
    def __init__(self, L=256, d_model=128, n_layers=2, n_heads=4, topk=8, use_var=True):
        super().__init__()
        self.L = L
        self.topk = topk
        self.use_var = use_var

        self.lag_embed = nn.Embedding(L+1, d_model)  # lag index 1..L
        self.val_proj  = nn.Linear(1, d_model)

        # small causal conv context
        self.ctx_conv = nn.Conv1d(1, d_model, kernel_size=9, padding=8)
        self.ctx_proj = nn.Linear(d_model, d_model)

        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.lag_encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        self.score = nn.Sequential(
            nn.Linear(d_model*2, d_model),
            nn.GELU(),
            nn.Linear(d_model, 1)
        )
        self.bias_head   = nn.Linear(d_model, 1)
        if use_var:
            self.logvar_head = nn.Linear(d_model, 1)

    def forward(self, x):
        # x: [B, T]
        B, T = x.shape
        L = self.L

        Xlags = build_lag_bank(x, L)               # [B,T,L]
        lag_vals = Xlags.unsqueeze(-1)             # [B,T,L,1]
        lag_ids  = torch.arange(1, L+1, device=x.device).view(1,1,L).expand(B,T,L)

        Hval = self.val_proj(lag_vals).squeeze(-2) # [B,T,L,d]
        Hidx = self.lag_embed(lag_ids)             # [B,T,L,d]
        Hlag = Hval + Hidx                         # [B,T,L,d]

        # per-time encode lag tokens
        Hlag_flat = Hlag.view(B*T, L, -1)
        Henc = self.lag_encoder(Hlag_flat).view(B, T, L, -1)  # [B,T,L,d]

        # causal context from x
        ctx = self.ctx_conv(x.unsqueeze(1))        # [B,d,T+pad]
        ctx = ctx[..., :T].transpose(1,2)          # [B,T,d]
        ctx = self.ctx_proj(ctx)                   # [B,T,d]

        # score lags with context
        ctx_exp = ctx.unsqueeze(2).expand(-1, -1, L, -1)
        pair = torch.cat([Henc, ctx_exp], dim=-1)  # [B,T,L,2d]
        logits = self.score(pair).squeeze(-1)      # [B,T,L]

        logits_masked = topk_mask_logits(logits, k=self.topk)
        w = torch.softmax(logits_masked, dim=-1)   # [B,T,L], sums to 1

        mu_ar = (w * Xlags).sum(dim=-1)            # [B,T]
        c  = self.bias_head(ctx).squeeze(-1)       # [B,T]
        mu = mu_ar + c

        logvar = None
        if self.use_var:
            logvar = self.logvar_head(ctx).squeeze(-1).clamp(-8, 8)
        return mu, logvar, w

def gaussian_nll(y, mu, logvar):
    if logvar is None:
        logvar = torch.zeros_like(mu)
    return 0.5*(logvar + (y - mu)**2 / logvar.exp())

# ------------- Train tvAR on train split (teacher-forced 1-step) -------------
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = LagAttentionTVAR(L=LAG_MAX, d_model=128, n_layers=2, n_heads=4, topk=8, use_var=False).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)

# use contiguous block up to train end; keep last part for test
x_tr = torch.tensor(x_n[:LAG_MAX + ntr], dtype=torch.float32, device=device).unsqueeze(0)  # [1, Ttr]
EPOCHS = 10  # keep small for demo; increase for better fit
for ep in range(EPOCHS):
    model.train()
    opt.zero_grad()
    mu, logvar, w = model(x_tr)
    # predict current index; you can shift by 1 if preferred—here Xlags already uses past values
    nll = gaussian_nll(x_tr, mu, logvar).mean()
    # mild sparsity on weights
    l1 = w.abs().mean()
    loss = nll + 1e-3 * l1
    loss.backward()
    opt.step()

# ------------- Controlled rollout on test (refresh-every-k) -------------
def hybrid_rollout(model, series_norm, start_idx, n_steps, refresh_every, L):
    """
    series_norm: normalized full series (numpy)
    start_idx: absolute start index in series_norm
    returns denormalized predictions over test window
    """
    model.eval()
    cur = series_norm.copy()
    cur_t = torch.tensor(cur, dtype=torch.float32, device=device).unsqueeze(0)
    preds = []
    with torch.no_grad():
        # we will iterate step-by-step to respect refresh schedule
        for t in range(n_steps):
            # take a slice up to current absolute time (inclusive) for causality
            t_abs = start_idx + t
            x_slice = cur_t[:, :t_abs]  # [1, t_abs]
            mu, _, _ = model(x_slice)
            xhat_t = mu[0, -1].item()
            preds.append(xhat_t)
            if ((t+1) % max(int(refresh_every),1)) != 0:
                # open-loop: append prediction into working series
                cur[t_abs] = xhat_t
                cur_t[0, t_abs] = xhat_t
            # else: refresh with truth implicitly by leaving series as-is
    preds = np.array(preds)
    return preds * std_tr + mu_tr  # de-normalize




KeyboardInterrupt: 

In [2]:
# ======== FAST Lag-Attention tvAR (no Transformer) ========
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

def build_lag_bank_fast(x_tensor, L):
    """
    x_tensor: [B, T]
    Returns Xlags[B, T, L] with Xlags[:, t, l] = x_{t-(l+1)} (left-padded with zeros).
    Vectorized (no Python loops over L).
    """
    B, T = x_tensor.shape
    pads = F.pad(x_tensor, (L, 0))                 # [B, T+L]
    # Build indices so that for each t we gather T rows that start at offset (t .. t+L-1)
    idx_base = torch.arange(T, device=x_tensor.device).view(1, T, 1)  # [1,T,1]
    lag_offsets = torch.arange(L, device=x_tensor.device).view(1, 1, L)  # [1,1,L]
    gather_idx = idx_base + lag_offsets            # [1,T,L]
    # shift by offset (L-1) so the last column corresponds to x_{t-1}
    gather_idx = L - 1 + gather_idx                # [1,T,L]
    Xlags = pads[:, gather_idx.squeeze(0)]         # [B,T,L]
    return Xlags

class LagAttentionTVARFast(nn.Module):
    def __init__(self, L=256, d_model=128, topk=8, use_var=False):
        super().__init__()
        self.L = L
        self.topk = topk
        self.use_var = use_var

        # embeddings for lag indices (1..L) and projected lag values
        self.lag_embed = nn.Embedding(L+1, d_model)     # id 1..L
        self.val_proj  = nn.Linear(1, d_model)

        # strictly causal context (left pad only)
        self.ctx_pad  = nn.ConstantPad1d((8, 0), 0)     # pad left for k=9
        self.ctx_conv = nn.Conv1d(1, d_model, kernel_size=9, padding=0)
        self.ctx_proj = nn.Linear(d_model, d_model)

        # bilinear scorer: score_{t,ℓ} = < Wq * ctx_t , Wk * Hlag_{t,ℓ} >
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)

        # bias term from context
        self.bias_head = nn.Linear(d_model, 1)
        if use_var:
            self.logvar_head = nn.Linear(d_model, 1)

        # init a bit smaller for stability
        for m in [self.Wq, self.Wk, self.val_proj, self.bias_head]:
            if hasattr(m, 'weight'):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """
        x: [B, T]
        returns: mu[B,T], logvar[B,T] or None, w[B,T,L]
        """
        B, T = x.shape
        L = self.L

        # build lag bank (past-only)
        Xlags = build_lag_bank_fast(x, L)                # [B,T,L]
        lag_vals = Xlags.unsqueeze(-1)                   # [B,T,L,1]
        lag_ids  = torch.arange(1, L+1, device=x.device).view(1,1,L).expand(B,T,L)

        Hval = self.val_proj(lag_vals).squeeze(-2)       # [B,T,L,d]
        Hidx = self.lag_embed(lag_ids)                   # [B,T,L,d]
        Hlag = Hval + Hidx                               # [B,T,L,d]

        # causal context from x
        ctx = self.ctx_conv(self.ctx_pad(x.unsqueeze(1)))   # [B,d,T]
        ctx = ctx.transpose(1, 2)                            # [B,T,d]
        ctx = self.ctx_proj(ctx)                             # [B,T,d]

        # bilinear scores
        q = self.Wq(ctx)                                     # [B,T,d]
        k = self.Wk(Hlag)                                    # [B,T,L,d]
        # score_{t,ℓ} = sum_d q_{t,d} * k_{t,ℓ,d}
        logits = torch.einsum('btd,btld->btl', q, k)         # [B,T,L]

        # optional top-k masking for sparsity/speed in softmax
        if (self.topk is not None) and (self.topk < L):
            topk = torch.topk(logits, self.topk, dim=-1)
            mask = torch.full_like(logits, float('-inf'))
            logits = mask.scatter(-1, topk.indices, topk.values)

        w = torch.softmax(logits, dim=-1)                    # [B,T,L], sums to 1
        mu_ar = (w * Xlags).sum(dim=-1)                      # [B,T]
        c = self.bias_head(ctx).squeeze(-1)                  # [B,T]
        mu = mu_ar + c

        logvar = None
        if self.use_var:
            logvar = self.logvar_head(ctx).squeeze(-1).clamp(-8, 8)
        return mu, logvar, w

def gaussian_nll(y, mu, logvar):
    if logvar is None:
        logvar = torch.zeros_like(mu)
    return 0.5*(logvar + (y - mu)**2 / (logvar.exp() + 1e-8))

# ======== TRAINING (with prints) ========
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

model = LagAttentionTVARFast(L=LAG_MAX, d_model=128, topk=8, use_var=False).to(device)
opt   = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))

x_tr = torch.tensor(x_n[:LAG_MAX + ntr], dtype=torch.float32, device=device).unsqueeze(0)  # [1,Ttr]
x_val = torch.tensor(x_n[LAG_MAX + ntr - 2048 : LAG_MAX + ntr], dtype=torch.float32, device=device).unsqueeze(0)  # last 2048 as val context

EPOCHS = 20
PRINT_EVERY = 1

for ep in range(1, EPOCHS+1):
    t0 = time.time()
    model.train()
    opt.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast(enabled=(device=="cuda")):
        mu, logvar, w = model(x_tr)  # teacher forced; predicts each t from its past
        # drop warm-up LAG_MAX steps from loss (left padding period)
        nll = gaussian_nll(x_tr[:, LAG_MAX:], mu[:, LAG_MAX:], logvar).mean()
        l1  = w.abs().mean()
        loss = nll + 1e-3 * l1

    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

    # quick val (no grad)
    model.eval()
    with torch.no_grad():
        mu_v, lv_v, w_v = model(x_val)
        val_loss = gaussian_nll(x_val[:, LAG_MAX:], mu_v[:, LAG_MAX:], lv_v).mean().item()

    dt = time.time() - t0
    if ep % PRINT_EVERY == 0:
        print(f"[Epoch {ep:02d}] train_loss={loss.item():.5f}  val_loss={val_loss:.5f}  |T|={x_tr.shape[1]}  L={model.L}  time={dt:.2f}s")


  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
  with torch.cuda.amp.autocast(enabled=(device=="cuda")):


IndexError: index 2707 is out of bounds for dimension 0 with size 2707

In [None]:
# indices for test region (use last 20% after L left pad)
t0 = LAG_MAX + ntr
N_te = N - t0
y_true = x[t0 : t0 + N_te]
yhat_ctrl = hybrid_rollout(model, x_n, start_idx=t0, n_steps=N_te, refresh_every=REFRESH_EVERY, L=LAG_MAX)

# ------------- (3) Test-window prediction -------------
fig, ax = plt.subplots(figsize=(6,3))
ax.plot(y_true, lw=1.6, label="Truth")
ax.plot(yhat_ctrl, lw=1.4, label=f"tvAR (refresh={REFRESH_EVERY})")
prettify(ax, title="Lag-Attention tvAR — test window", xlabel="Time (test index)", ylabel="x(t)", add_legend=True)
plt.tight_layout(); plt.show()

# ------------- (4) Residual ACF (test) -------------
resid = y_true - yhat_ctrl
acf_v = acf_vals(resid, nlags=60)
bound = 1.96 / np.sqrt(max(len(resid),1))
fig, ax = plt.subplots(figsize=(4,3))
markerline, stemlines, baseline = ax.stem(range(len(acf_v)), acf_v, basefmt=" ")
plt.setp(stemlines, linewidth=1.2)
plt.setp(markerline, marker='o', markersize=5)
ax.axhline(bound, linestyle='--', linewidth=1)
ax.axhline(-bound, linestyle='--', linewidth=1)
prettify(ax, title="Residual ACF — tvAR on test", xlabel="Lag", ylabel="Autocorrelation")
plt.tight_layout(); plt.show()

# ------------- (5) y_pred vs y_true (test) -------------
fig, ax = plt.subplots(figsize=(3,3))
ax.scatter(y_true, yhat_ctrl, alpha=0.35, s=10, edgecolors="none")
vmin = min(np.min(y_true), np.min(yhat_ctrl))
vmax = max(np.max(y_true), np.max(yhat_ctrl))
ax.plot([vmin, vmax], [vmin, vmax], lw=1)
yt = y_true - np.mean(y_true)
yp = yhat_ctrl - np.mean(yhat_ctrl)
corr = float((yt @ yp) / np.sqrt((yt @ yt) * (yp @ yp) + 1e-12))
ss_res = np.sum((y_true - yhat_ctrl)**2)
ss_tot = np.sum((y_true - np.mean(y_true))**2)
r2 = 1 - ss_res / (ss_tot + 1e-12)
ax.text(0.05, 0.95, f"$R^2$ = {r2:.3f}\n$\\rho$ = {corr:.3f}",
        transform=ax.transAxes, ha='left', va='top',
        bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, boxstyle="round"))
ax.set_aspect('equal', adjustable='box')
prettify(ax, title="Test scatter: y_pred vs y_true", xlabel="y_true", ylabel="y_pred")
plt.tight_layout(); plt.show()

# ------------- (6) State space — TRUE vs tvAR (variance-matched) -------------
# variance-match for fair geometry comparison
hy = (yhat_ctrl - yhat_ctrl.mean())
hy = hy * (y_true.std() / (hy.std() + 1e-12)) + y_true.mean()
X3_true = embed3(y_true, TAU)
X3_hyb  = embed3(hy,      TAU)

from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
fig = plt.figure(figsize=(7.5, 5.5))
ax = fig.add_subplot(111, projection='3d')
ax.plot(X3_true[:,0], X3_true[:,1], X3_true[:,2], label="TRUE", alpha=0.85)
ax.plot(X3_hyb[:,0],  X3_hyb[:,1],  X3_hyb[:,2],  label=f"tvAR (refresh={REFRESH_EVERY})", alpha=0.85)
ax.set_title(f"State space — TRUE vs tvAR (τ={TAU}, refresh={REFRESH_EVERY})")
ax.set_xlabel("x(t)"); ax.set_ylabel(f"x(t-{TAU})"); ax.set_zlabel(f"x(t-{2*TAU})")
ax.legend(frameon=False)
plt.tight_layout(); plt.show()

# ------------- (7) Lag-weight heatmap over test window -------------
# recompute weights over test with teacher-forcing (to visualize selection cleanly)
model.eval()
with torch.no_grad():
    x_test_tf = torch.tensor(x_n[:t0+N_te], dtype=torch.float32, device=device).unsqueeze(0)
    _, _, w_full = model(x_test_tf)             # [1, T, L]
    w_test = w_full[0, t0:, :].cpu().numpy()    # [N_te, L]

fig, ax = plt.subplots(figsize=(7,3))
# show newest lags on the left if you prefer: w_test[:, :]; here index 0 = lag-1
im = ax.imshow(w_test.T, aspect='auto', origin='lower', interpolation='nearest')
ax.set_ylabel("Lag index (ℓ = 1..L)")
ax.set_xlabel("Time (test index)")
ax.set_title("Lag-attention weights w_ℓ(t) on test")
plt.tight_layout(); plt.show()