<a href="https://colab.research.google.com/github/neeluu876/nsf-hdr-ml-challenge-2026/blob/main/Monkeys.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neural Activity Forecasting with Cross-Day Generalization

This notebook implements a multi-step neural forecasting model for μECoG data.
The focus is on predicting future neural activity while handling day-to-day
recording drift through normalization and delta-based prediction.

In [None]:
!pip -q install timm torchinfo
import os, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary

# Use the GPU when the runtime exposes one, otherwise stay on the CPU.
# The bare expression on the last line makes the notebook display the choice.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda')

In [None]:
# Directory that holds the challenge .npz archives.
DATA_DIR = "/content"

# The header is built from DATA_DIR so the message cannot drift from the
# directory that is actually listed below.
print(f"Files in {DATA_DIR}:")
for f in sorted(os.listdir(DATA_DIR)):
    # Only the NumPy archives are relevant; skip everything else in the folder.
    if f.endswith(".npz"):
        print(" -", f)


Files in /content:
 - train_data_affi.npz
 - train_data_affi_2024-03-20_private.npz
 - train_data_beignet.npz
 - train_data_beignet_2022-06-01_private.npz
 - train_data_beignet_2022-06-02_private.npz


In [None]:
def load_npz(path):
    """Load the first array stored in an ``.npz`` archive as ``float32``.

    Args:
        path: Location of the ``.npz`` file, passed straight to ``np.load``.

    Returns:
        A ``float32`` copy of the first array listed in the archive.

    Raises:
        ValueError: If the archive contains no arrays.
    """
    # np.load returns a lazily-read NpzFile that keeps the file open; the
    # context manager closes it once the array has been materialised.
    with np.load(path) as archive:
        names = archive.files
        if not names:
            raise ValueError(f"No arrays found in {path!r}")
        # astype() produces an independent in-memory copy, so the result stays
        # valid after the archive is closed.
        return archive[names[0]].astype(np.float32)

# Map a short key to the full path of every archive used below.
paths = {
    key: os.path.join(DATA_DIR, file_name)
    for key, file_name in (
        ("affi_train", "train_data_affi.npz"),
        ("affi_val", "train_data_affi_2024-03-20_private.npz"),
        ("be_train", "train_data_beignet.npz"),
        ("be_val1", "train_data_beignet_2022-06-01_private.npz"),
        ("be_val2", "train_data_beignet_2022-06-02_private.npz"),
    )
}

# Affi: one training archive and one held-out archive.
affi_train = load_npz(paths["affi_train"])
affi_val = load_npz(paths["affi_val"])

# Beignet: the two held-out archives are stacked along the sample axis.
be_train = load_npz(paths["be_train"])
be_val = np.concatenate(
    [load_npz(paths[key]) for key in ("be_val1", "be_val2")],
    axis=0,
)

print("affi_train:", affi_train.shape, "affi_val:", affi_val.shape)
print("be_train:", be_train.shape, "be_val:", be_val.shape)


affi_train: (985, 20, 239, 9) affi_val: (162, 20, 239, 9)
be_train: (700, 20, 89, 9) be_val: (158, 20, 89, 9)


In [None]:
class NeuralForecastDataset(Dataset):
    """Windows of neural activity split into a context and a forecast target.

    Each item is cut at ``init_steps`` along its first axis. Both halves are
    standardised with the mean and standard deviation of the context half
    (taken over that first axis), and the target is expressed as the difference
    from the last normalised context step.

    Args:
        data_np: Array indexed by sample first; the notebook passes arrays of
            shape (samples, 20, channels, features).
        init_steps: Number of leading steps used as context.
        use_all_features: When False, only index 0 of the last axis is kept.
        eps: Added to the standard deviation to avoid division by zero.
    """

    def __init__(self, data_np, init_steps=10, use_all_features=True, eps=1e-6):
        self.data = data_np.astype(np.float32)
        self.init_steps = init_steps
        self.use_all_features = use_all_features
        self.eps = eps

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(context_n, target_delta, last_context_n, mean, std)`` as tensors."""
        sample = self.data[idx]
        if not self.use_all_features:
            # Slice rather than index so the feature axis is kept (size 1).
            sample = sample[..., :1]

        context = sample[:self.init_steps]
        target = sample[self.init_steps:]

        # Statistics come from the context only and are reused for the target.
        mu = context.mean(axis=0, keepdims=True)
        sigma = context.std(axis=0, keepdims=True) + self.eps

        context_n = (context - mu) / sigma
        target_n = (target - mu) / sigma

        # The learning target is the offset from the last observed step.
        anchor = context_n[-1:]
        target_delta = target_n - anchor

        pieces = (context_n, target_delta, anchor, mu, sigma)
        return tuple(torch.from_numpy(piece) for piece in pieces)


In [None]:
class TemporalBlock(nn.Module):
    """Residual block of two dilated 1-D convolutions along the last axis.

    Every convolution pads by ``(kernel_size - 1) * dilation`` and its output
    is trimmed back to the input length by keeping the leading positions, so
    output step ``t`` depends only on input steps ``<= t``. GELU and dropout
    follow each convolution. The block input is added to the result, passing
    through a 1x1 convolution first when ``in_ch != out_ch``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=0.15):
        super().__init__()
        conv_args = dict(
            kernel_size=kernel_size,
            padding=dilation * (kernel_size - 1),
            dilation=dilation,
        )
        self.conv1 = nn.Conv1d(in_ch, out_ch, **conv_args)
        self.conv2 = nn.Conv1d(out_ch, out_ch, **conv_args)
        self.dropout = nn.Dropout(dropout)
        if in_ch == out_ch:
            self.downsample = None
        else:
            # 1x1 projection so the skip path matches the output width.
            self.downsample = nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x):
        steps = x.size(-1)
        hidden = x
        for conv in (self.conv1, self.conv2):
            # Drop the trailing positions produced by the padding.
            hidden = conv(hidden)[..., :steps]
            hidden = self.dropout(F.gelu(hidden))

        skip = x if self.downsample is None else self.downsample(x)
        return hidden + skip

class TCNForecaster(nn.Module):
    """Stack of ``TemporalBlock`` layers followed by a 1x1 output convolution.

    Block ``i`` uses dilation ``2**i``. The first block maps ``in_dim`` channels
    to ``hidden``; the rest keep ``hidden`` channels. The head projects back to
    ``in_dim`` so the output has the same shape as the input.
    """

    def __init__(self, in_dim, hidden=512, levels=5, kernel_size=3, dropout=0.15):
        super().__init__()
        # Channel width entering each block, plus the width leaving the last one.
        widths = [in_dim] + [hidden] * levels
        blocks = [
            TemporalBlock(
                widths[i],
                widths[i + 1],
                kernel_size=kernel_size,
                dilation=2 ** i,
                dropout=dropout,
            )
            for i in range(levels)
        ]
        self.net = nn.Sequential(*blocks)
        self.head = nn.Conv1d(hidden, in_dim, kernel_size=1)

    def forward(self, x):
        # Conv1d wants channels before time: (B, T, D) -> (B, D, T).
        features = self.net(x.transpose(1, 2))
        # Project back to D channels and restore the (B, T, D) layout.
        return self.head(features).transpose(1, 2)


In [None]:
def train_one_epoch(model, loader, opt):
    """Run one optimisation pass over ``loader`` and return the mean MSE.

    Args:
        model: Module mapping a ``(B, T, C*F)`` tensor to a tensor of the same
            shape.
        loader: Iterable of 5-tuples ``(x_in, y_delta, last, mean, std)``; only
            the first two entries are used here, both shaped ``(B, T, C, F)``.
        opt: Optimiser that owns ``model``'s parameters.

    Returns:
        The sample-weighted mean of the batch losses, or ``0.0`` when the
        loader yields nothing.
    """
    model.train()
    # Move inputs to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device

    loss_sum = 0.0
    n_samples = 0
    for x_in, y_delta, _last, _mean, _std in loader:
        x_in = x_in.to(model_device)
        y_delta = y_delta.to(model_device)

        # Fold channels and features into one axis for the 1-D convolutions.
        B, T, C, Fdim = x_in.shape
        x_flat = x_in.reshape(B, T, C * Fdim)
        y_flat = y_delta.reshape(B, T, C * Fdim)

        loss = F.mse_loss(model(x_flat), y_flat)

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # Weight by batch size so a short final batch is not over-counted.
        loss_sum += loss.item() * B
        n_samples += B

    return loss_sum / max(1, n_samples)

@torch.no_grad()
def eval_one_epoch(model, loader):
    """Return the mean MSE of ``model`` over ``loader`` without updating it.

    Args:
        model: Module mapping a ``(B, T, C*F)`` tensor to a tensor of the same
            shape.
        loader: Iterable of 5-tuples ``(x_in, y_delta, last, mean, std)``; only
            the first two entries are used here, both shaped ``(B, T, C, F)``.

    Returns:
        The sample-weighted mean of the batch losses, or ``0.0`` when the
        loader yields nothing.
    """
    model.eval()
    # Move inputs to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device

    loss_sum = 0.0
    n_samples = 0
    for x_in, y_delta, _last, _mean, _std in loader:
        x_in = x_in.to(model_device)
        y_delta = y_delta.to(model_device)

        # Fold channels and features into one axis for the 1-D convolutions.
        B, T, C, Fdim = x_in.shape
        x_flat = x_in.reshape(B, T, C * Fdim)
        y_flat = y_delta.reshape(B, T, C * Fdim)

        loss = F.mse_loss(model(x_flat), y_flat)

        # Weight by batch size: averaging per-batch means would over-count a
        # short final batch.
        loss_sum += loss.item() * B
        n_samples += B

    return loss_sum / max(1, n_samples)

@torch.no_grad()
def future_feature0_mse(model, loader):
    """Return the MSE on feature 0 of the forecast, in the original data units.

    The predicted deltas are added to the last normalised context step and
    de-normalised with the per-sample ``mean``/``std`` before comparison with
    the ground truth rebuilt from ``y_delta`` in the same way.

    Args:
        model: Module mapping a ``(B, T, C*F)`` tensor to a tensor of the same
            shape.
        loader: Iterable of 5-tuples ``(x_in, y_delta, last, mean, std)`` where
            ``x_in``/``y_delta`` are ``(B, T, C, F)`` and the remaining entries
            broadcast against them.

    Returns:
        Sum of squared errors divided by the number of compared elements, or
        ``nan`` when the loader yields nothing.
    """
    model.eval()
    # Move inputs to wherever the model lives instead of relying on a global.
    model_device = next(model.parameters()).device

    sq_err_sum = 0.0
    n_elems = 0
    for x_in, y_delta, last, mean, std in loader:
        x_in = x_in.to(model_device)
        y_delta = y_delta.to(model_device)
        last = last.to(model_device)
        mean = mean.to(model_device)
        std = std.to(model_device)

        B, T, C, Fdim = x_in.shape
        x_flat = x_in.reshape(B, T, C * Fdim)
        pred_delta = model(x_flat).reshape(B, T, C, Fdim)

        # Only feature 0 is scored, so only that slice is de-normalised.
        std0 = std[..., 0]
        mean0 = mean[..., 0]
        pred0 = (last + pred_delta)[..., 0] * std0 + mean0
        true0 = (last + y_delta)[..., 0] * std0 + mean0

        sq_err_sum += F.mse_loss(pred0, true0, reduction="sum").item()
        n_elems += pred0.numel()

    if n_elems == 0:
        # An MSE over zero elements is undefined; nan never compares as an
        # improvement, so it cannot trigger a checkpoint save.
        return float("nan")
    return sq_err_sum / n_elems


In [None]:
def train_model(monkey_name, train_np, val_np, use_all_features=True,
                epochs=60, batch_size=32, hidden=512, levels=5, lr=2e-4, wd=1e-4, dropout=0.15):
    """Train a ``TCNForecaster`` and checkpoint the best validation epoch.

    The checkpoint is written to ``tcn_<monkey_name>.pth`` in the working
    directory every time the validation feature-0 MSE improves. Besides the
    weights it stores the settings needed to rebuild the network.

    Args:
        monkey_name: Label used in log lines and in the checkpoint file name.
        train_np: Training array handed to ``NeuralForecastDataset``.
        val_np: Validation array handed to ``NeuralForecastDataset``.
        use_all_features: Forwarded to the datasets.
        epochs: Number of passes over the training data; also the cosine
            schedule length.
        batch_size: Training batch size (validation always uses 64).
        hidden: Channel width of the TCN blocks.
        levels: Number of TCN blocks.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        dropout: Dropout probability inside the TCN blocks.

    Returns:
        The checkpoint path.
    """
    init_steps = 10

    dataset_args = dict(init_steps=init_steps, use_all_features=use_all_features)
    train_ds = NeuralForecastDataset(train_np, **dataset_args)
    val_ds = NeuralForecastDataset(val_np, **dataset_args)

    worker_args = dict(num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **worker_args)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, **worker_args)

    # The network width follows from the shape of one normalised context window.
    first_context = train_ds[0][0]
    C = first_context.shape[1]
    Fdim = first_context.shape[2]
    D = C * Fdim

    model = TCNForecaster(in_dim=D, hidden=hidden, levels=levels, dropout=dropout).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    save_path = f"tcn_{monkey_name}.pth"

    # Everything except the weights is fixed for the whole run.
    checkpoint_meta = {
        "monkey_name": monkey_name,
        "use_all_features": use_all_features,
        "init_steps": init_steps,
        "trained_C": C,
        "trained_Fdim": Fdim,
        "trained_D": D,
        "hidden": hidden,
        "levels": levels,
        "dropout": dropout,
    }

    best = float("inf")
    final_epoch = epochs - 1
    for ep in range(epochs):
        train_loss = train_one_epoch(model, train_loader, opt)
        val_loss = eval_one_epoch(model, val_loader)
        val_f0 = future_feature0_mse(model, val_loader)
        sched.step()

        # Model selection uses the de-normalised feature-0 error.
        if val_f0 < best:
            best = val_f0
            torch.save({"state_dict": model.state_dict(), **checkpoint_meta}, save_path)

        # Log every fifth epoch and always the last one.
        if ep % 5 == 0 or ep == final_epoch:
            print(f"[{monkey_name}] ep {ep:02d} train_mse(delta)={train_loss:.6f} "
                  f"val_mse(delta)={val_loss:.6f} val_future_f0_mse={val_f0:.6f} best={best:.6f}")

    print("Saved best:", save_path)
    return save_path


In [None]:
# Train the beignet forecaster; every hyper-parameter is passed explicitly so
# the cell documents the exact run configuration.
be_ckpt = train_model(
    "beignet", be_train, be_val,
    use_all_features=True,
    epochs=60, batch_size=32,
    hidden=512, levels=5,
    lr=2e-4, wd=1e-4, dropout=0.15,
)


[beignet] ep 00 train_mse(delta)=11.579854 val_mse(delta)=18.320062 val_future_f0_mse=69491.913640 best=69491.913640
[beignet] ep 05 train_mse(delta)=10.300733 val_mse(delta)=18.355276 val_future_f0_mse=64090.525729 best=64090.525729
[beignet] ep 10 train_mse(delta)=8.917003 val_mse(delta)=18.583972 val_future_f0_mse=65564.767686 best=64090.525729
[beignet] ep 15 train_mse(delta)=8.142272 val_mse(delta)=18.601347 val_future_f0_mse=68425.131674 best=64090.525729
[beignet] ep 20 train_mse(delta)=7.555680 val_mse(delta)=18.775513 val_future_f0_mse=69865.545840 best=64090.525729
[beignet] ep 25 train_mse(delta)=7.085939 val_mse(delta)=18.809107 val_future_f0_mse=70054.267103 best=64090.525729
[beignet] ep 30 train_mse(delta)=6.738588 val_mse(delta)=18.871667 val_future_f0_mse=69755.492419 best=64090.525729
[beignet] ep 35 train_mse(delta)=6.462330 val_mse(delta)=18.833643 val_future_f0_mse=69730.397611 best=64090.525729
[beignet] ep 40 train_mse(delta)=6.263483 val_mse(delta)=18.877181 val

In [None]:
# Train the affi forecaster; every hyper-parameter is passed explicitly so the
# cell documents the exact run configuration.
affi_ckpt = train_model(
    "affi", affi_train, affi_val,
    use_all_features=True,
    epochs=60,
    batch_size=16,   # smaller to avoid OOM
    hidden=512, levels=5,
    lr=2e-4, wd=1e-4, dropout=0.15,
)


[affi] ep 00 train_mse(delta)=13.696932 val_mse(delta)=13.723002 val_future_f0_mse=62698.405496 best=62698.405496
[affi] ep 05 train_mse(delta)=11.614823 val_mse(delta)=14.932104 val_future_f0_mse=61763.957229 best=59531.340751
[affi] ep 10 train_mse(delta)=10.223691 val_mse(delta)=14.454315 val_future_f0_mse=64737.936960 best=57095.671140
[affi] ep 15 train_mse(delta)=9.219204 val_mse(delta)=14.312815 val_future_f0_mse=56419.205124 best=56419.205124
[affi] ep 20 train_mse(delta)=8.288452 val_mse(delta)=14.269698 val_future_f0_mse=61208.918973 best=56419.205124
[affi] ep 25 train_mse(delta)=7.653831 val_mse(delta)=14.303302 val_future_f0_mse=60432.249434 best=56419.205124
[affi] ep 30 train_mse(delta)=7.119524 val_mse(delta)=14.341990 val_future_f0_mse=59165.721866 best=56419.205124
[affi] ep 35 train_mse(delta)=6.717621 val_mse(delta)=14.331668 val_future_f0_mse=59254.636211 best=56419.205124
[affi] ep 40 train_mse(delta)=6.416847 val_mse(delta)=14.345446 val_future_f0_mse=58841.98679

In [None]:
%%writefile model.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Inference runs on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class TemporalBlock(nn.Module):
    """Two dilated 1-D convolutions with GELU, dropout and a residual path.

    Both convolutions pad by ``(kernel_size - 1) * dilation``; the surplus
    trailing outputs are cut so the length matches the input, which makes
    output step ``t`` depend only on input steps ``<= t``. When the channel
    counts differ the residual goes through a 1x1 convolution.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, dilation=1, dropout=0.0):
        super().__init__()
        pad = dilation * (kernel_size - 1)

        def make_conv(channels_in):
            return nn.Conv1d(channels_in, out_ch, kernel_size, padding=pad, dilation=dilation)

        self.conv1 = make_conv(in_ch)
        self.conv2 = make_conv(out_ch)
        self.dropout = nn.Dropout(dropout)
        # Identity skip when widths already agree, otherwise a 1x1 projection.
        self.downsample = None if in_ch == out_ch else nn.Conv1d(in_ch, out_ch, 1)

    def forward(self, x):
        length = x.size(-1)

        out = self.conv1(x)[..., :length]
        out = self.dropout(F.gelu(out))

        out = self.conv2(out)[..., :length]
        out = self.dropout(F.gelu(out))

        if self.downsample is None:
            return out + x
        return out + self.downsample(x)

class TCNForecaster(nn.Module):
    """Sequence of ``TemporalBlock`` layers with a 1x1 convolutional head.

    Dilation doubles with every block (``1, 2, 4, ...``). Input and output
    share the layout ``(B, T, in_dim)``.
    """

    def __init__(self, in_dim, hidden=512, levels=5, kernel_size=3, dropout=0.0):
        super().__init__()
        blocks = []
        for level in range(levels):
            # Only the first block sees the raw input width.
            width_in = in_dim if level == 0 else hidden
            blocks.append(
                TemporalBlock(
                    width_in,
                    hidden,
                    kernel_size=kernel_size,
                    dilation=2 ** level,
                    dropout=dropout,
                )
            )
        self.net = nn.Sequential(*blocks)
        self.head = nn.Conv1d(hidden, in_dim, kernel_size=1)

    def forward(self, x):
        # Conv1d expects (B, channels, T), so swap the last two axes on the
        # way in and again on the way out.
        channels_first = x.transpose(1, 2)
        projected = self.head(self.net(channels_first))
        return projected.transpose(1, 2)

def _normalize_per_sample(x_in, eps=1e-6):
    # x_in: (N,10,C,F)
    mean = x_in.mean(axis=1, keepdims=True)  # (N,1,C,F)
    std  = x_in.std(axis=1, keepdims=True) + eps
    return (x_in - mean) / std, mean, std

class Model:
    """Inference wrapper around a trained ``TCNForecaster`` checkpoint.

    ``load`` reads ``tcn_<monkey_name>.pth`` from the directory of this file
    and rebuilds the network. ``predict`` copies the first 10 steps of feature
    0, forecasts the next 10, and falls back to repeating the last observed
    value wherever the network cannot be applied.
    """

    def __init__(self, monkey_name=""):
        self.monkey_name = monkey_name
        self.init_steps = 10
        self.use_all_features = True
        # Filled in by load().
        self.model = None
        self.trained_C = None
        self.trained_Fdim = None

    def load(self):
        """Load the checkpoint for ``self.monkey_name`` and build the network.

        Raises:
            ValueError: If ``monkey_name`` is neither ``"affi"`` nor
                ``"beignet"``.
        """
        if self.monkey_name == "affi":
            w = "tcn_affi.pth"
        elif self.monkey_name == "beignet":
            w = "tcn_beignet.pth"
        else:
            raise ValueError(f"No such monkey: {self.monkey_name}")

        path = os.path.join(os.path.dirname(__file__), w)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")

        self.trained_C = int(ckpt["trained_C"])
        self.trained_Fdim = int(ckpt["trained_Fdim"])
        # Honour the feature setting recorded at training time.
        self.use_all_features = bool(ckpt.get("use_all_features", True))
        D = int(ckpt["trained_D"])
        hidden = int(ckpt.get("hidden", 512))
        levels = int(ckpt.get("levels", 5))

        # Dropout is irrelevant in eval mode, so it is fixed at 0.
        self.model = TCNForecaster(in_dim=D, hidden=hidden, levels=levels, dropout=0.0).to(DEVICE)
        self.model.load_state_dict(ckpt["state_dict"], strict=True)
        self.model.eval()

    @torch.no_grad()
    def predict(self, X):
        """Forecast feature 0 for steps 10-19 of every sample in ``X``.

        Args:
            X: Array-like of shape ``(N, 20, C, F)``. Only the first 10 steps
                are used as model input.

        Returns:
            ``float32`` array of shape ``(N, 20, C)``: steps 0-9 are feature 0
            of the input, steps 10-19 are the forecast. Channels beyond the
            trained channel count keep the persistence fallback. If the input
            has fewer channels or features than the checkpoint, the whole
            forecast is the persistence fallback.

        Raises:
            ValueError: If ``X`` is not 4-D or does not have 20 time steps.
            RuntimeError: If ``load`` has not been called.
        """
        X = np.asarray(X, dtype=np.float32)
        if X.ndim != 4:
            raise ValueError(f"Expected a 4-D array (N, 20, C, F), got shape {X.shape}")
        N, T, C_in, F_in = X.shape
        if T != 20:
            raise ValueError("Expected 20 time steps")
        if self.model is None:
            raise RuntimeError("Model weights are not loaded; call load() before predict()")

        out = np.zeros((N, 20, C_in), dtype=np.float32)
        out[:, :10, :] = X[:, :10, :, 0]      # copy observed
        out[:, 10:, :] = X[:, 9:10, :, 0]     # fallback persistence

        Xf = X if self.use_all_features else X[..., :1]

        # The network has a fixed input width of trained_C * trained_Fdim, so
        # it cannot run on narrower input. Keep the persistence forecast
        # instead of failing on a shape mismatch. An empty batch has nothing
        # to forecast.
        if N == 0 or C_in < self.trained_C or Xf.shape[-1] < self.trained_Fdim:
            return out

        C_common = self.trained_C

        x_in = Xf[:, :10, :C_common, :self.trained_Fdim]  # (N,10,C,F')
        x_in_n, mean, std = _normalize_per_sample(x_in)

        last = x_in_n[:, -1:, :, :]  # (N,1,C,F')
        Nn, _, Cc, Fc = x_in_n.shape

        x_flat = torch.from_numpy(x_in_n).to(DEVICE).reshape(Nn, 10, Cc * Fc)
        pred_delta = self.model(x_flat).reshape(Nn, 10, Cc, Fc).cpu().numpy()

        # Undo the delta encoding, then the per-sample normalisation.
        pred_future_n = last + pred_delta
        pred_future = pred_future_n * std + mean

        out[:, 10:, :C_common] = pred_future[..., 0]
        return out


Writing model.py
