Afiq Abdillah Effiezal Aswadi= · Oliver Britton= · Ross Baker= · Matthew Farrugia-Roberts
Abstract: Modern deep learning science often assumes that neural networks learn from a fixed data distribution. However, many practically important learning problems involve data distributions that change throughout training. How does such non-stationarity impact the inductive biases of deep learning towards models with different structural, generalisation, and safety properties? A fruitful testbed for studying inductive bias is in-context linear regression sequence modelling, where small transformers display strikingly different generalisation patterns depending on the diversity of the (fixed) training task distribution. In this paper, we explore the effect of diversifying the task distribution across training time, finding that such temporal diversity leads to an increased bias towards generalisation over memorisation.
Recommended citation:
@inproceedings{EffiezalAswadi+2026,
title={Temporal Task Diversity: Inductive Biases Under Non-Stationarity in
Synthetic Sequence Modelling},
author={Effiezal Aswadi, Afiq Abdillah
and Britton, Oliver
and Baker, Ross
and Farrugia-Roberts, Matthew
},
booktitle={Technical AI Safety Conference (TAIS)},
year={2026},
}
This repository contains code used to replicate the main experiments in the
paper. Namely, training transformers on synthetic in-context linear regression
(ICLR) sequences, following the setting of Raventós et al.
(2023). Each transformer is trained on
sequences of
Each
In this repository, we are interested in determining the following properties of transformers trained in this setting:
- How close the transformer's predictions are to two natural Bayesian baselines: The discrete minimum mean squared error (dMMSE) predictor is the posterior mean under a uniform prior over $\mathcal T_M$. The ridge predictor is the posterior mean under $\mathcal N(\mathbf 0, \mathbf I_D)$ as the prior over the tasks. We measure how close the transformer is to each via $\Delta_\text{PT, dMMSE}$ and $\Delta_\text{PT, Ridge}$, the mean squared distance between the transformer's predictions and these optimal predictors on held-out sequences.
- What is the transformer's implicit prior over task vectors: The transformer has an implicit prior over the latent $\mathbf t$ it is inferring. We approximate this via predictive Monte Carlo: by rolling out unconditionally from the model and fitting OLS to each generated sequence, we obtain an approximate sample of $\mathbf t$ from the transformer's prior. The energy distance between this empirical distribution and the dMMSE / ridge priors tells us how close the transformer is to each of the Bayesian baselines.
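The two Bayesian baselines and the distance between predictors can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's implementation: the function names (`ridge_predict`, `dmmse_predict`, `delta`) are hypothetical, and it assumes Gaussian observation noise with variance `noise_var` and a single query input per context.

```python
import numpy as np


def ridge_predict(xs, ys, x_query, noise_var):
    """Posterior-mean prediction under a N(0, I_D) prior on the task vector."""
    dim = xs.shape[1]
    # Posterior mean of t: (X^T X + sigma^2 I)^{-1} X^T y.
    t_hat = np.linalg.solve(xs.T @ xs + noise_var * np.eye(dim), xs.T @ ys)
    return x_query @ t_hat


def dmmse_predict(xs, ys, x_query, tasks, noise_var):
    """Posterior-mean prediction under a uniform prior over the rows of `tasks`."""
    # Log-likelihood of the context under each candidate task, up to a constant.
    residuals = ys[None, :] - tasks @ xs.T  # shape [M, K]
    log_w = -0.5 * np.sum(residuals**2, axis=1) / noise_var
    log_w -= log_w.max()  # stabilise the softmax
    weights = np.exp(log_w)
    weights /= weights.sum()
    t_hat = weights @ tasks  # posterior-mean task vector, shape [D]
    return x_query @ t_hat


def delta(preds_a, preds_b):
    """Mean squared distance between two sets of predictions."""
    return float(np.mean((np.asarray(preds_a) - np.asarray(preds_b)) ** 2))
```

With few or uninformative examples the two baselines disagree (ridge shrinks towards zero, dMMSE towards the finite task set), which is what makes the distance to each a useful diagnostic.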
The two main entry points are:
- experiment_iclr_task_diversity.py: trains a single transformer, and records the per-training-step metrics (loss, $\Delta_\text{PT, dMMSE}$, $\Delta_\text{PT, Ridge}$) to a JSON file in runs/. With --compute-energy-distance, it additionally records the energy distance between the transformer's implicit prior and the dMMSE / ridge priors at each evaluation step. Non-stationarity is controlled by either --mala-step-size (for the MALA random walk) or --num-resamples (for the number of resampling events as determined by a Dirichlet distribution).
- experiment_iclr_prior_tracking.py: trains a single transformer, and samples the transformer's implicit prior at regular intervals and writes these to a snapshots.npz file, alongside per-snapshot $\Delta_\text{PT}$ metrics. This is used to inspect the prior distribution over training time. Non-stationarity is controlled by either --mala-step-size or --num-resamples (in this script using equispaced resampling).
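For readers unfamiliar with the energy distance recorded by these scripts, the following is a minimal sketch of a sample-based estimate. It assumes the standard definition 2 E|X - Y| - E|X - X'| - E|Y - Y'| and a simple all-pairs (V-statistic) estimator; the function names are hypothetical, and the repository may use a different estimator or take a square root.

```python
import numpy as np


def pairwise_mean_distance(a, b):
    """Mean Euclidean distance over all pairs (a_i, b_j); inputs have shape [n, D]."""
    diffs = a[:, None, :] - b[None, :, :]
    return float(np.mean(np.linalg.norm(diffs, axis=-1)))


def energy_distance(samples_p, samples_q):
    """All-pairs estimate of 2 E|X - Y| - E|X - X'| - E|Y - Y'|."""
    cross = pairwise_mean_distance(samples_p, samples_q)
    within_p = pairwise_mean_distance(samples_p, samples_p)
    within_q = pairwise_mean_distance(samples_q, samples_q)
    return 2.0 * cross - within_p - within_q
```

Here `samples_p` would hold task vectors recovered from the transformer's rollouts and `samples_q` would hold draws from a reference prior. The all-pairs form materialises an array of size n × m × D, so it is best suited to modest sample counts.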
Note: Here "iclr" stands for "in-context linear regression", not the deep learning conference.
Requires Python 3.12 or later. We recommend using uv.
git clone https://github.com/matomatical/temporal-task-diversity.git
cd temporal-task-diversity
uv sync
uv sync --group tpu # On TPU hosts, additionally install the JAX TPU backend

Use experiment_iclr_task_diversity.py
to train a transformer. By default this script uses the hyperparameters
detailed in Appendix A of the paper (see Hyperparameters),
and with this configuration each training step takes about 5ms on a dual-core
TPU v4, so the default 524288-step run takes about 45 minutes.
Each run writes metrics to runs/iclr_task_diversity/<run_name>.json and saves
an orbax checkpoint to runs/iclr_task_diversity/checkpoints/<run_name>/. By
default, re-running with the same --run-name resumes from the latest
checkpoint, but this may be overridden by passing --force-restart.
# Train a transformer on a fixed set of 32 tasks, and record metrics to runs/iclr_task_diversity/first-run.json
uv run experiment_iclr_task_diversity.py --task-diversity 32 --run-name first-run
# Train a transformer and additionally record the energy distance between the
# implicit prior and the Bayesian reference priors every 800 steps. Predictive
# Monte Carlo requires a head that emits a likelihood (override the default
# `point` with `gaussian` or `mog`); each rollout runs for --num-examples steps.
uv run experiment_iclr_task_diversity.py \
--task-diversity 32 \
--num-examples 64 \
--head-type gaussian \
--compute-energy-distance \
--run-name first-stationary-energy-distance-run
# Train a transformer on a non-stationary set of tasks, updating according to a
# MALA random walk with step size 1e-3
uv run experiment_iclr_task_diversity.py \
--task-diversity 32 \
--mala-step-size 1e-3 \
--run-name mala-td64-1e-3
# Train a transformer on a non-stationary set of tasks where each task is
# resampled at 10 randomly chosen points throughout training
uv run experiment_iclr_task_diversity.py \
--task-diversity 32 \
--num-resamples 10 \
--run-name resampling-td64-R10

The main finding of the paper is that non-stationarity lowers the task diversity threshold at which the transformer switches from memorisation-like behaviour (approximating dMMSE) to generalisation (approximating ridge). For a minimal replication, try the following four configurations:
# M=8, stationary: converges to dMMSE.
uv run experiment_iclr_task_diversity.py --task-diversity 8 --run-name td8-stationary
# M=64, stationary: converges to ridge.
uv run experiment_iclr_task_diversity.py --task-diversity 64 --run-name td64-stationary
# M=8, small amount of non-stationarity: still converges to dMMSE.
uv run experiment_iclr_task_diversity.py --task-diversity 8 --mala-step-size 1e-3 --run-name td8-mala-1e-3
# M=8, large amount of non-stationarity: now converges to ridge instead, despite task diversity being 8.
uv run experiment_iclr_task_diversity.py --task-diversity 8 --mala-step-size 1e-2 --run-name td8-mala-1e-2

To verify the result, check the stats in the output JSON. train_delta_dmmse
should be near zero in the td8-stationary and td8-mala-1e-3 runs, and
train_delta_ridge near zero in the td64-stationary and td8-mala-1e-2
runs.
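One way to run this check is with a short script like the sketch below. It relies only on the `stats`, `step`, `train_delta_dmmse`, and `train_delta_ridge` keys of the metrics JSON; the helper name `final_deltas` is hypothetical, and the run names are the ones used in the commands above.

```python
import json
from pathlib import Path


def final_deltas(json_path):
    """Return the last recorded train-set Delta metrics from a run's JSON file."""
    run = json.loads(Path(json_path).read_text())
    last = run["stats"][-1]
    return {
        "step": last["step"],
        "train_delta_dmmse": last["train_delta_dmmse"],
        "train_delta_ridge": last["train_delta_ridge"],
    }


if __name__ == "__main__":
    names = ["td8-stationary", "td64-stationary", "td8-mala-1e-3", "td8-mala-1e-2"]
    for name in names:
        path = Path("runs/iclr_task_diversity") / f"{name}.json"
        if path.exists():  # skip runs that have not been trained yet
            print(name, final_deltas(path))
```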
Very large step sizes (
See Full options for more information.
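As a rough illustration of what a MALA random walk over the task set could look like, the sketch below applies one Metropolis-adjusted Langevin step to each task vector. It assumes that the stationary distribution is the task prior N(0, I_D) and that the step size enters as in the textbook proposal t + γ ∇log p(t) + sqrt(2γ) ξ. Both are assumptions, the function names are hypothetical, and the repository's parameterisation of --mala-step-size may differ.

```python
import numpy as np


def log_target(t):
    """Log-density of N(0, I_D), up to an additive constant."""
    return -0.5 * np.sum(t**2, axis=-1)


def log_proposal(t_to, t_from, step_size):
    """Log-density of the Langevin proposal t_from -> t_to, up to a constant."""
    mean = (1.0 - step_size) * t_from  # t + step_size * grad log_target(t)
    return -np.sum((t_to - mean) ** 2, axis=-1) / (4.0 * step_size)


def mala_step(rng, tasks, step_size):
    """Apply one MALA update independently to each row of `tasks` (shape [M, D])."""
    noise = rng.standard_normal(tasks.shape)
    proposal = (1.0 - step_size) * tasks + np.sqrt(2.0 * step_size) * noise
    # Metropolis-Hastings correction keeps N(0, I_D) exactly stationary.
    log_accept = (
        log_target(proposal) + log_proposal(tasks, proposal, step_size)
        - log_target(tasks) - log_proposal(proposal, tasks, step_size)
    )
    accept = np.log(rng.uniform(size=tasks.shape[0])) < log_accept
    return np.where(accept[:, None], proposal, tasks)
```

Under these assumptions, a small step size moves each task vector only slightly per training step, while the set as a whole remains distributed according to the prior over long horizons.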
Use experiment_iclr_prior_tracking.py
to track a transformer's implicit prior at many points during training. By
default it uses the 1D setting (task_dim=1, task_diversity=1) and writes
config.json and snapshots.npz to runs/iclr_prior_tracking/<run_name>/.
Note that there is no checkpointing for this script.
# Track the prior of a transformer when trained on a stationary set of tasks
uv run experiment_iclr_prior_tracking.py --run-name stationary-tracking
# Track the prior of a transformer when the tasks are changing according to a MALA random walk
uv run experiment_iclr_prior_tracking.py --mala-step-size 1e-2 --run-name mala-1e-2
# Track the prior of a transformer trained on 64 independently sampled task sets, shown over 64 equal-length segments of training
uv run experiment_iclr_prior_tracking.py --num-resamples 64 --run-name resampling-R64
# Track the prior of a transformer, taking denser snapshots in a particular window
uv run experiment_iclr_prior_tracking.py \
--num-steps 100000 \
--dense-step-start 40000 \
--dense-step-end 50000 \
--dense-snapshot-period 64 \
--run-name zoom-40k-50k

See Full options for more information.
The architecture, optimiser, and training schedule are fixed across all runs in
the paper and match the defaults of both training scripts. The per-experiment
variables (iclr/setting.py.
| Category | Hyperparameter | CLI flag | Default |
|---|---|---|---|
| Data | Task vector dimension ($D$) | --task-dim | 8 |
| Data | In-context examples ($K$) | --num-examples | 16 or 64 (higher for prior tracking) |
| Data | Observation noise variance ($\sigma^2$) | --noise-var | 0.25 |
| Model | Number of layers | --num-blocks | 8 |
| Model | Attention heads per layer | --num-heads | 2 |
| Model | Embedding dimension | --embed-size | 128 |
| Model | MLP hidden dimension | --mlp-size | 128 |
| Model | Layer normalisation | — | Pre |
| Model | Positional embeddings | — | Learned |
| Model | Attention mask | — | Causal |
| Model | Prediction head | --head-type | point / gaussian / mog |
| Model | Mixture components ($G$) | --num-components | 4 (only used when --head-type=mog) |
| Optimiser | Type | — | Adam |
| Optimiser | Peak learning rate | --learning-rate | 3e-3 |
| Optimiser | Warm-up | --lr-warmup / --no-lr-warmup | 10% linear, then constant |
| Training | Batch size | --batch-size | 256 |
| Training | Total steps | --num-steps | 524288 |
| Per-experiment | Task diversity ($M$) | --task-diversity | varies |
| Per-experiment | MALA step size ($\gamma$) | --mala-step-size | varies |
| Per-experiment | Resampling events ($R$) | --num-resamples | varies |
| Per-experiment | PRNG seed | --seed | varies |
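The warm-up row describes a schedule that rises linearly from zero over the first 10% of training and then stays at the peak learning rate. A minimal sketch of that schedule follows. The function name `learning_rate` is hypothetical, and rounding the warm-up length down to a whole number of steps is an assumption rather than something the scripts are known to do.

```python
def learning_rate(step, num_steps=524288, peak_lr=3e-3, warmup=True):
    """Linear warm-up from 0 over the first 10% of training, then constant."""
    warmup_steps = num_steps // 10  # assumption: 10% of training, rounded down
    if not warmup or warmup_steps == 0 or step >= warmup_steps:
        return peak_lr
    return peak_lr * step / warmup_steps
```

Passing `warmup=False` mirrors --no-lr-warmup, giving a constant learning rate throughout.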
For experiment_iclr_task_diversity.py,
each run writes
runs/iclr_task_diversity/<run_name>.json # per-step metrics
runs/iclr_task_diversity/checkpoints/<run_name>/ # orbax checkpoint
The JSON metrics have the shape shown at the end of this document.
For experiment_iclr_prior_tracking.py,
each run writes
runs/iclr_prior_tracking/<run_name>/config.json
runs/iclr_prior_tracking/<run_name>/snapshots.npz
where snapshots.npz contains the following arrays in compressed NumPy format:
steps int32 [N]
losses float32 [N]
tasks float32 [N, M, D]
pr float32 [N, num_pr_samples, D]
delta_dmmse float32 [N]
delta_ridge float32 [N]
resample_steps int32 [#events] # only when --num-resamples > 0
config str scalar (JSON)
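The archive can be inspected with plain NumPy. The sketch below is one way to do so, under the assumption that `config` is stored as a zero-dimensional string array holding JSON (as the listing above suggests); the helper names `load_snapshots` and `prior_sample_stats` are hypothetical.

```python
import json

import numpy as np


def load_snapshots(npz_path):
    """Load a snapshots.npz archive into a dict of arrays plus the parsed config."""
    with np.load(npz_path) as archive:
        data = {key: archive[key] for key in archive.files if key != "config"}
        config = json.loads(archive["config"].item())
    return data, config


def prior_sample_stats(data):
    """Per-snapshot mean and standard deviation of the prior samples, pooled over D."""
    pr = data["pr"]  # shape [N, num_pr_samples, D]
    return pr.mean(axis=(1, 2)), pr.std(axis=(1, 2))
```

In the default 1D setting, `data["pr"][i, :, 0]` holds the prior samples for snapshot `i` and can be passed directly to a histogram routine, with `data["steps"][i]` giving the corresponding training step.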
$ uv run experiment_iclr_task_diversity.py --help
usage: experiment_iclr_task_diversity.py [-h] [OPTIONS]
Train an in-context regression transformer at one configuration. Stationary by
default. Set --mala-step-size for MALA random-walk non-stationarity, or
--num-resamples for Dirichlet resampling non-stationarity. With
--compute-energy-distance, also records the energy distance between the
transformer's implicit prior (sampled via predictive Monte Carlo) and the dMMSE
/ ridge priors.
Writes per-step metrics to <output_root>/<run_name>.json and an orbax checkpoint
under <output_root>/checkpoints/<run_name>/. Re-invoking with the same
--run-name resumes training from the latest saved checkpoint, pass
--force-restart to discard this checkpoint and start again with the same
run_name.
╭─ options ────────────────────────────────────────────────────────────────────╮
│ -h, --help show this help message and exit │
│ --task-diversity INT number of latent task vectors (M); task set is │
│ sampled from N(0, I_D) at start. (default: 1) │
│ --mala-step-size FLOAT MALA random-walk step size (gamma) for │
│ non-stationary training; mutually exclusive with │
│ --num-resamples. (default: 0.0) │
│ --num-resamples INT number of Dirichlet resampling events (R) per task; │
│ mutually exclusive with --mala-step-size. (default: │
│ 0) │
│ --seed INT PRNG seed. (default: 42) │
│ --output-root PATH parent directory under which the run JSON and │
│ checkpoints subdirectory are created. (default: │
│ runs/iclr_task_diversity) │
│ --run-name {None}|STR name for <output_root>/<run_name>.json and │
│ <output_root>/checkpoints/<run_name>/; │
│ auto-generated from a timestamp if unset. │
│ Re-invoking with the same name resumes training from │
│ the latest saved checkpoint. (default: None) │
│ --task-dim INT task vector dimension (D in the paper). (default: 8) │
│ --num-examples INT in-context examples per sequence (K). (default: 16) │
│ --noise-var FLOAT observation noise variance (sigma^2). (default: │
│ 0.25) │
│ --head-type {point,gaussian,mog} │
│ prediction head. "point" predicts a scalar; │
│ "gaussian" predicts a mean and variance; "mog" │
│ predicts a mixture of Gaussians. (default: point) │
│ --num-components INT number of mixture components (G); used only when │
│ --head-type=mog. (default: 4) │
│ --num-blocks INT number of transformer layers. (default: 8) │
│ --num-heads INT attention heads per layer. (default: 2) │
│ --embed-size INT transformer embedding dimension. (default: 128) │
│ --mlp-size INT MLP hidden dimension per block. (default: 128) │
│ --num-steps INT total training steps. (default: 524288) │
│ --learning-rate FLOAT Adam peak learning rate. (default: 0.003) │
│ --batch-size INT gradient batch size. (default: 256) │
│ --lr-warmup, --no-lr-warmup │
│ linearly warm up the LR from 0 over the first 10% of │
│ training before holding it constant; set │
│ --no-lr-warmup for a constant LR throughout. │
│ (default: True) │
│ --eval-period INT interval (in training steps) between evaluation │
│ passes. (default: 64) │
│ --eval-batch-size INT batch size for held-out evaluation sequences (drawn │
│ from N(0, I_D)). (default: 1024) │
│ --delta-prompt-length INT │
│ prompt prefix length used for the Delta_PT metric │
│ computations. (default: 16) │
│ --delta-tail-start {None}|INT │
│ if set, additionally compute Delta metrics over │
│ positions delta_tail_start..K-1. (default: None) │
│ --checkpoint-period INT │
│ interval (in training steps) between orbax │
│ checkpoint saves. (default: 8192) │
│ --compute-energy-distance, --no-compute-energy-distance │
│ also compute energy distance to the dMMSE and ridge │
│ priors at each evaluation step. (default: False) │
│ --energy-distance-period INT │
│ interval (in training steps) between energy-distance │
│ evaluations. (default: 800) │
│ --energy-n-samples INT predictive Monte Carlo samples per evaluation. │
│ (default: 5000) │
│ --final-checkpoint, --no-final-checkpoint │
│ keep the orbax checkpoint directory after training │
│ finishes; set --no-final-checkpoint to delete it and │
│ retain only the JSON. (default: True) │
│ --force-restart, --no-force-restart │
│ remove any existing checkpoint at --run-name before │
│ starting and train from scratch; default is to │
│ resume. (default: False) │
╰──────────────────────────────────────────────────────────────────────────────╯
$ uv run experiment_iclr_prior_tracking.py --help
usage: experiment_iclr_prior_tracking.py [-h] [OPTIONS]
Train a transformer and snapshot its implicit prior over task vectors via
predictive Monte Carlo. Defaults configure the 1D scenario (task_dim=1,
task_diversity=1) where the prior can be visualised as a histogram on the real
line. Stationary by default. Set --mala-step-size for MALA random-walk
non-stationarity, or --num-resamples for equispaced resampling non-stationarity
(this is the equispaced scheme; the Dirichlet partition lives in
experiment_iclr_task_diversity.py). MALA takes precedence when both are set.
Snapshots are taken every --snapshot-period training steps; an optional
dense-snapshot window (--dense-step-start, --dense-step-end,
--dense-snapshot-period) takes higher-resolution snapshots within a
sub-interval.
Writes a snapshots.npz archive of prior samples and per-snapshot Delta_PT,dMMSE
/ Delta_PT,Ridge metrics under <output_root>/<run_name>/, where <run_name>
defaults to a timestamp if not specified. This script does not resume from
previous runs.
╭─ options ────────────────────────────────────────────────────────────────────╮
│ -h, --help show this help message and exit │
│ --task-diversity INT number of latent task vectors (M); fixed set sampled │
│ from N(0, I_D) (default: 1) │
│ --mala-step-size FLOAT MALA random-walk step size (gamma) for │
│ non-stationary training; takes precedence over │
│ --num-resamples if both are set (default: 0.0) │
│ --num-resamples INT number of independently-sampled task sets shown over │
│ R equal-length segments (R-1 internal resampling │
│ events); 0 disables resampling (default: 0) │
│ --seed INT PRNG seed (default: 0) │
│ --output-root PATH parent directory under which the run subdirectory is │
│ created (default: runs/iclr_prior_tracking) │
│ --run-name {None}|STR name for the run subdirectory under output_root; │
│ defaults to a timestamp if unset (default: None) │
│ --task-dim INT task vector dimension (D) (default: 1) │
│ --num-examples INT in-context examples per sequence (K) (default: 64) │
│ --noise-var FLOAT observation noise variance (sigma^2) (default: 0.25) │
│ --head-type {gaussian,mog} │
│ prediction head: single Gaussian (mean+variance) or │
│ mixture of Gaussians (default: gaussian) │
│ --num-components INT number of mixture components (G); used only when │
│ --head-type=mog (default: 4) │
│ --num-blocks INT number of transformer layers (default: 8) │
│ --num-heads INT attention heads per layer (default: 2) │
│ --embed-size INT transformer embedding dimension (default: 128) │
│ --mlp-size INT MLP hidden dimension per block (default: 128) │
│ --num-steps INT total training steps (default: 524288) │
│ --learning-rate FLOAT Adam peak learning rate (default: 0.003) │
│ --batch-size INT gradient batch size (default: 256) │
│ --snapshot-period INT interval (in training steps) between uniform │
│ snapshots (default: 512) │
│ --num-pr-samples INT predictive Monte Carlo samples per snapshot │
│ (default: 200) │
│ --pr-generation-steps INT │
│ autoregressive rollout length per PMC sample │
│ (default: 16) │
│ --dense-snapshot-period {None}|INT │
│ interval between dense snapshots within the dense │
│ window; if unset, only uniform snapshots are taken │
│ (default: None) │
│ --dense-step-start {None}|INT │
│ start of an explicit dense-snapshot window; │
│ overrides the resampling-aligned dense window │
│ (default: None) │
│ --dense-step-end {None}|INT │
│ end of an explicit dense-snapshot window (default: │
│ None) │
│ --dense-snapshot-window INT │
│ width of the dense window placed around each │
│ resampling event (used when dense_step_start/end are │
│ unset) (default: 2048) │
│ --force-restart, --no-force-restart │
│ remove any existing snapshots.npz at --run-name │
│ before starting; default is to refuse to overwrite │
│ (default: False) │
╰──────────────────────────────────────────────────────────────────────────────╯


{
  "times": { "start": ..., "end": ... },
  "config": {...},
  "stats": [
    {
      "step": 0,
      "lr": ...,
      "timestamp": ...,
      "train_error_model": ...,       // train_* keys: evaluated on sequences derived from training step task set
      "train_error_ridge": ...,
      "train_error_dmmse": ...,
      "test_error_model": ...,        // test_* keys: evaluated on sequences derived from tasks sampled from N(0, I_D)
      "test_error_ridge": ...,
      "test_error_dmmse": ...,
      "train_delta_dmmse": ...,
      "train_delta_ridge": ...,
      "test_delta_dmmse": ...,
      "test_delta_ridge": ...,
      "train_nll_model": ...,         // only with --head-type gaussian or mog
      "test_nll_model": ...,          // only with --head-type gaussian or mog
      "train_delta_tail_dmmse": ...,  // only with --delta-tail-start
      "train_delta_tail_ridge": ...,  // only with --delta-tail-start
      "test_delta_tail_dmmse": ...,   // only with --delta-tail-start
      "test_delta_tail_ridge": ...,   // only with --delta-tail-start
      "energy_dist/prior/dmmse": ..., // only with --compute-energy-distance
      "energy_dist/prior/ridge": ...  // only with --compute-energy-distance
    },
    ...
  ]
}