# RNAM: Recurrent Neural Additive Model

A first-order [generalized additive model (GAM)](https://hastie.su.domains/Papers/gam.pdf) operating on tabular data $\mathbf{x} \in \mathbb{R}^p$ can be written as

$$
\Gamma(\mathbf{x}) = b + \sum_{i=1}^p f_i(x_i),
$$

and is interpretable under certain conditions, since each shape function $f_i$ depends on a single feature $x_i$ and can be inspected in isolation.
Constraints such as those described in [Hastie and Tibshirani's seminal paper](https://hastie.su.domains/Papers/gam.pdf) or [Hooker's generalized functional ANOVA](https://www.jstor.org/stable/27594267) are not applied in this notebook.
The core focus is the efficient estimation of unconstrained shape functions $f_i$ from longitudinal data using neural networks.
The data in this example are generated from uniform latents: a [Friedman-like function](https://www.slac.stanford.edu/pubs/slacpubs/2250/slac-pub-2336.pdf) of the latents produces the targets, and a wide and shallow [rnam.minGRU](rnam/gru.py) produces the inputs.
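
As a minimal sketch of the formula above, the snippet below evaluates a first-order GAM with hand-picked shape functions in plain Python. The names `shape_functions`, `bias`, and `gam` are illustrative and not part of `rnam`; the shape functions are chosen only for concreteness.

```python
import math

# Hypothetical shape functions f_i, one per input feature x_i.
shape_functions = [
    math.sin,
    math.cos,
    lambda v: math.sqrt(abs(v)),
    lambda v: v * v,
]
bias = 0.5  # intercept b (arbitrary value for illustration)


def gam(x):
    """Return the prediction and the per-feature contributions f_i(x_i)."""
    contributions = [f(x_i) for f, x_i in zip(shape_functions, x)]
    return bias + sum(contributions), contributions


prediction, contributions = gam([0.0, 0.0, 4.0, 3.0])
print(prediction, contributions)
```

Returning the contributions alongside the prediction is what makes the model inspectable: each entry shows how much one feature adds to the output.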

Additional related work includes:
- [Berhane and Tibshirani's GAMs for longitudinal data](https://doi.org/10.2307/3315715).
- [Agarwal et al.'s Neural Additive Model](https://arxiv.org/pdf/2004.13912v2).

Last accessed: 2025-12-15

## Globals

Model:

In [1]:
HIDDEN_DIM = 8
BLOCKS = 4

Data:

In [2]:
TERMS = 8
WARMUP = 8
SEQ_LEN = 1_024
LATENT_DECAY = 0.95
GENERATOR_BATCH_SIZE = 100
SAMPLES = 10_000
SPLIT_POINT = 8_000

Training:

In [3]:
LR = 1e-3
BATCH_SIZE = 128
EPOCHS = 25

## Setup

In [4]:
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import rnam

In [5]:
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Model

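Additive layer: adds a learned bias to the sum of the per-term contributions (dimension 0) and also returns the detached contributions for inspection.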

In [7]:
class AdditiveLayer(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        return self.bias + input.sum(0), input.detach()

Simple pre-norm residual block.

In [8]:
class RNAMBlock(nn.Module):
    def __init__(self, terms: int, dim: int) -> None:
        super().__init__()
        self.gru = rnam.minGRU(terms, dim)
        self.ff = nn.Sequential(
            nn.LayerNorm(dim),
            rnam.Linear(terms, dim, dim),
            nn.ReLU(),
            rnam.Linear(terms, dim, dim),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input + self.ff(input + self.gru(input)[0])
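
The internals of `rnam.minGRU` are not shown in this notebook. As a hedged illustration, the sketch below assumes the commonly published minGRU recurrence $h_t = (1 - z_t) h_{t-1} + z_t \tilde{h}_t$, in which the gate $z_t$ and the candidate $\tilde{h}_t$ depend only on the current input $x_t$. It handles a single scalar feature; `min_gru_scan` and its weights are hypothetical names, not the library's API.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def min_gru_scan(xs, w_z, b_z, w_h, b_h, h0=0.0):
    """Scalar minGRU-style recurrence (assumed form, not rnam's code)."""
    h, states = h0, []
    for x in xs:
        z = sigmoid(w_z * x + b_z)  # update gate, depends on x_t only
        h_tilde = w_h * x + b_h     # candidate state, depends on x_t only
        h = (1.0 - z) * h + z * h_tilde
        states.append(h)
    return states


# With a constant gate z = 0.05 and an identity candidate, the recurrence
# reduces to an exponential moving average with decay 0.95.
b_z = math.log(0.05 / 0.95)  # logit of 0.05
states = min_gru_scan([1.0, 0.0, 1.0], w_z=0.0, b_z=b_z, w_h=1.0, b_h=0.0)
print(states)
```

Under this assumption, a gate close to `1 - LATENT_DECAY` would let a single hidden unit track the kind of moving average that defines the target.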

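First-order additive sequence regressor/classifier. Only the hidden state at the last time step of each term is projected to the output before the additive layer sums over the terms.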

In [9]:
class FirstOrderRNAM(nn.Module):
    def __init__(
        self,
        terms: int,
        hidden_dim: int,
        out_dim: int,
        blocks: int,
    ) -> None:
        super().__init__()
        self.emb = rnam.Linear(terms=terms, in_dim=1, out_dim=hidden_dim)
        self.blocks = nn.ModuleList(
            RNAMBlock(terms, hidden_dim) for _ in range(blocks)
        )
        self.proj = nn.Sequential(
            rnam.Linear(terms=terms, in_dim=hidden_dim, out_dim=hidden_dim),
            nn.ReLU(),
            rnam.Linear(terms=terms, in_dim=hidden_dim, out_dim=out_dim),
        )
        self.gam = AdditiveLayer(dim=out_dim)

    def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        input = self.emb(input)
        for block in self.blocks:
            input = block(input)
        input = self.proj(input[:, :, -1])
        return self.gam(input)

In [10]:
model = FirstOrderRNAM(
    terms=TERMS,
    hidden_dim=HIDDEN_DIM,
    out_dim=1,
    blocks=BLOCKS,
).to(device)

## Data

In [11]:
assert 4 <= TERMS
assert SAMPLES % GENERATOR_BATCH_SIZE == 0
assert SPLIT_POINT < SAMPLES

In [12]:
def collate_fn(
    batch: list[tuple[torch.Tensor, torch.Tensor]],
) -> tuple[torch.Tensor, torch.Tensor]:
    features, targets = [], []
    for x, y in batch:
        features.append(x)
        targets.append(y)

    return torch.stack(features, dim=1), torch.stack(targets)

In [13]:
class DataWrapper(Dataset):
    def __init__(self, features: torch.Tensor, targets: torch.Tensor) -> None:
        super().__init__()
        self.features = features
        self.targets = targets

    def __len__(self) -> int:
        return self.targets.size(0)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        return self.features[:, idx], self.targets[idx]

The target is computed from the exponential moving average (EMA) over time of the latents $\mathbf{z} \sim \text{Uniform}(0, 1)$. Only the first four terms contribute, which is why `TERMS` must be at least 4.

In [14]:
def friedman(z: torch.Tensor) -> torch.Tensor:
    ema = torch.zeros_like(z[:, :, 0])
    for t in range(z.size(-2)):
        ema = LATENT_DECAY * ema + (1 - LATENT_DECAY) * z[:, :, t]
    return ema[0].sin() + ema[1].cos() + ema[2].abs().sqrt() + ema[3].square()
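
The loop in `friedman` starts from a zero state, so its final value has the closed form $\text{ema}_T = (1 - \alpha) \sum_{t=0}^{T-1} \alpha^{T-1-t} z_t$ with $\alpha$ equal to `LATENT_DECAY`. The sketch below checks this on a short scalar sequence in plain Python; `ema_recursive` and `ema_closed_form` are illustrative helper names, not part of the notebook.

```python
import random

LATENT_DECAY = 0.95  # same value as in the globals above


def ema_recursive(z, decay=LATENT_DECAY):
    """EMA computed step by step, starting from a zero state."""
    ema = 0.0
    for z_t in z:
        ema = decay * ema + (1 - decay) * z_t
    return ema


def ema_closed_form(z, decay=LATENT_DECAY):
    """Same quantity as a geometrically weighted sum."""
    steps = len(z)
    return (1 - decay) * sum(
        decay ** (steps - 1 - t) * z_t for t, z_t in enumerate(z)
    )


rng = random.Random(0)
z = [rng.random() for _ in range(64)]
assert abs(ema_recursive(z) - ema_closed_form(z)) < 1e-12

# Share of the total weight carried by the most recent n steps: 1 - decay**n.
recent_weight = 1 - LATENT_DECAY**20
print(recent_weight)
```

Because the weights decay geometrically, the last 20 steps carry roughly 64% of the total weight, so the target depends mostly on the end of each sequence.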

Inputs $\mathbf{X}$ are generated from $\text{Generator}(\mathbf{Z})$.

In [15]:
class Generator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        dim = HIDDEN_DIM * BLOCKS
        self.emb = rnam.Linear(terms=TERMS, in_dim=1, out_dim=dim)
        self.gru = rnam.minGRU(terms=TERMS, dim=dim)
        self.proj = rnam.Linear(terms=TERMS, in_dim=dim, out_dim=1)

    def forward(
        self, input: torch.Tensor, prev_hidden: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        input = self.emb(input)
        input, prev_hidden = self.gru(input, prev_hidden)
        return self.proj(input), prev_hidden

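The first `WARMUP` steps of uniform noise are fed to the recurrent generator to stabilize its hidden state; its outputs for those steps are discarded.
Both inputs and targets are z-score normalized, using statistics computed over the full data set before the train/validation split.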

In [16]:
@torch.no_grad()
def generate_data() -> tuple[Dataset, Dataset]:
    generator = Generator()
    generator.eval()

    features, targets = [], []
    for _ in range(SAMPLES // GENERATOR_BATCH_SIZE):
        z = torch.rand(TERMS, GENERATOR_BATCH_SIZE, WARMUP + SEQ_LEN, 1)
        targets.append(friedman(z))

        batch, prev_hidden = [], None
        for t in range(WARMUP + SEQ_LEN):
            hidden, prev_hidden = generator(z[:, :, t : t + 1], prev_hidden)
            if t >= WARMUP:  # keep only post-warmup steps
                batch.append(hidden.detach())

        features.append(torch.cat(batch, dim=-2))

    features = torch.cat(features, dim=1)
    mu = torch.mean(features, (-3, -2, -1), keepdim=True)
    sigma = torch.std(features, (-3, -2, -1), correction=0, keepdim=True)
    features = (features - mu) / sigma

    targets = torch.cat(targets)
    targets = (targets - targets.mean()) / targets.std()

    train_data = DataWrapper(features[:, :SPLIT_POINT], targets[:SPLIT_POINT])
    val_data = DataWrapper(features[:, SPLIT_POINT:], targets[SPLIT_POINT:])

    return train_data, val_data
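
As a small, hedged illustration of the normalization step, the sketch below z-scores a list in plain Python. It also contrasts the population standard deviation (what `correction=0` selects for the features) with the sample standard deviation (the default of `Tensor.std`, used for the targets). `z_score` is an illustrative helper, not part of the notebook.

```python
import statistics


def z_score(values, population=True):
    """Center and scale values using the population or sample std."""
    mu = statistics.fmean(values)
    std = statistics.pstdev(values) if population else statistics.stdev(values)
    return [(v - mu) / std for v in values]


values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
normalized = z_score(values)
normalized_sample = z_score(values, population=False)
print(normalized)
```

The two variants differ by a factor of $\sqrt{n / (n - 1)}$, which is negligible for the sample counts used here.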

In [17]:
train_data, val_data = generate_data()

train_loader = DataLoader(
    train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True
)
val_loader = DataLoader(
    val_data, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False
)

## Training

In [18]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)
criterion = nn.MSELoss()

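The run below was executed on the Google Colab free tier (T4 GPU). Each output line shows the epoch followed by the mean training and validation MSE; the total time is in seconds.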

In [19]:
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()

for epoch in range(EPOCHS):
    model.train()

    train_loss = 0
    for x, y in train_loader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)

        y_hat, _ = model(x)

        loss = criterion(y_hat, y)
        train_loss += loss.item()
        loss.backward()

        optimizer.step()

    model.eval()

    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)

            y_hat, _ = model(x)

            loss = criterion(y_hat, y)
            val_loss += loss.item()

    train_loss /= len(train_loader)
    val_loss /= len(val_loader)

    print(f"{epoch + 1:03d}: {train_loss:.6f}/{val_loss:.6f}")

if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"total time: {time.perf_counter() - start}")

001: 1.274356/0.934273
002: 0.817840/0.823477
003: 0.728004/0.733901
004: 0.625222/0.639761
005: 0.510634/0.497099
006: 0.383550/0.314465
007: 0.241109/0.198362
008: 0.156576/0.124920
009: 0.112902/0.102335
010: 0.101714/0.100229
011: 0.100948/0.095167
012: 0.098225/0.095685
013: 0.094544/0.086277
014: 0.089156/0.107411
015: 0.087195/0.116670
016: 0.087054/0.082870
017: 0.087589/0.084297
018: 0.089546/0.090741
019: 0.089298/0.081064
020: 0.083044/0.080673
021: 0.083245/0.079981
022: 0.078800/0.079017
023: 0.080868/0.078642
024: 0.078069/0.079202
025: 0.080394/0.077141
total time: 495.89863707200004
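
The reported losses are means of per-batch means. Neither the 8,000 training samples nor the 2,000 validation samples divide evenly by `BATCH_SIZE = 128`, so each sample in the final, smaller batch is weighted slightly more than the others. The sketch below shows a sample-weighted alternative on made-up batch losses; `weighted_mean_loss` is a hypothetical helper, not part of the notebook.

```python
def weighted_mean_loss(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each batch by its size."""
    total = sum(loss * size for loss, size in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Validation split of the notebook: 2,000 samples in batches of 128.
samples, batch_size = 2_000, 128
full, remainder = divmod(samples, batch_size)
batch_sizes = [batch_size] * full + ([remainder] if remainder else [])

# Made-up per-batch losses, for illustration only.
batch_losses = [0.08] * full + [0.10]

unweighted = sum(batch_losses) / len(batch_losses)
weighted = weighted_mean_loss(batch_losses, batch_sizes)
print(unweighted, weighted)
```

In the training loop this would amount to accumulating `loss.item() * y.size(0)` and dividing by the number of samples instead of the number of batches.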
