# Disk-Backed vs. On-Batch Methods

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/markean/aimz/blob/main/docs/notebooks/disk_and_on_batch.ipynb)

This page explains and compares the two complementary execution styles provided by `ImpactModel`:

* **Disk-backed** (default) methods iterate over the input in chunks, materialize results incrementally, and persist structured artifacts (Zarr‑backed `xarray.DataTree` plus metadata) to a temporary or user-specified output directory.
* **On-batch** (`*_on_batch` suffix) methods execute a single, fully in-memory pass and can optionally return a plain `dict` instead of an `xarray.DataTree`. The naming mirrors the Keras convention to signal an immediate, single-batch, memory-resident operation.

## Why Disk-Backed by Default

The non-`*_on_batch` methods default to a disk-backed (chunked) execution model for several reasons:

* Posterior predictive and prior predictive tensors can scale as `(#samples x #dims x #posterior_samples x ...)`. 
  Even moderate increases in any axis (time, spatial units, parameter samples) can exceed host or accelerator RAM.
* Using `batch_size` with chunked iteration limits peak memory and prevents out-of-memory errors.
* Persisted Zarr arrays with metadata (coords, dims, attrs) create an artifact you can reopen without rerunning inference.
* The `xarray.DataTree` + Zarr format integrates with scientific Python tools such as Dask and ArviZ.
* Summaries (means, HDIs, residual PPC stats) can be computed lazily over chunked storage without first materializing dense arrays.
* One API works for both small experiments and large-scale use cases.
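
To make the memory argument concrete, the following back-of-the-envelope sketch compares the footprint of a dense predictive array with that of a single chunk. The sizes are hypothetical and `tensor_bytes` is an illustrative helper, not part of `aimz`; the only assumption is a dense `float32` array.

```python
from math import prod


def tensor_bytes(shape, itemsize=4):
    """Bytes needed for a dense array of `shape` (float32 by default)."""
    return prod(shape) * itemsize


# Hypothetical sizes: 1,000 posterior draws over 1,000,000 observations,
# processed 10,000 observations at a time.
num_samples, n_obs, batch_size = 1_000, 1_000_000, 10_000

full = tensor_bytes((num_samples, n_obs))       # everything resident at once
chunk = tensor_bytes((num_samples, batch_size))  # one chunk resident at a time

print(f"full: {full / 1e9:.1f} GB, chunk: {chunk / 1e6:.1f} MB")
```

Under these assumed sizes, the full array needs about 4 GB while a single chunk needs about 40 MB, which is why bounding the work by `batch_size` keeps peak memory manageable.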

## Comparison
Disk-backed variants target larger datasets and enable chunked processing, multi-device parallelism, and stable artifact generation.
These methods build internal data loaders, iterate in chunks, and decouple sampling from file I/O, enabling concurrent execution.
Outputs consolidate into a single `xarray.DataTree` backed by Zarr files for post-hoc analysis.
On-batch variants, in contrast, favor minimal overhead, immediate return, and greater flexibility when posterior sample shapes are not shard-friendly.
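
The contrast between the two styles can be illustrated with a deliberately simplified, standard-library sketch. It is not how `aimz` is implemented (the library persists Zarr-backed arrays, not JSON files), and `iter_batches`, `run_on_batch`, and `run_disk_backed` are hypothetical names used only for illustration.

```python
import json
import tempfile
from pathlib import Path


def iter_batches(n_rows, batch_size):
    """Yield (start, stop) index pairs covering n_rows in chunks."""
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


def run_on_batch(data, fn):
    """Single in-memory pass: the whole result is resident at once."""
    return [fn(x) for x in data]


def run_disk_backed(data, fn, batch_size, out_dir):
    """Chunked pass: persist each chunk's result and keep only file paths."""
    paths = []
    for i, (start, stop) in enumerate(iter_batches(len(data), batch_size)):
        chunk_result = [fn(x) for x in data[start:stop]]  # one chunk in memory
        path = Path(out_dir) / f"chunk_{i:04d}.json"
        path.write_text(json.dumps(chunk_result))
        paths.append(path)
    return paths


def square(x):
    return x * x


data = list(range(10))

with tempfile.TemporaryDirectory() as out_dir:
    paths = run_disk_backed(data, square, batch_size=4, out_dir=out_dir)
    # Reopen the persisted chunks without recomputing anything.
    reloaded = [v for p in paths for v in json.loads(p.read_text())]

in_memory = run_on_batch(data, square)
print(len(paths), reloaded == in_memory)
```

Both paths produce the same values; the chunked version trades extra I/O and orchestration for a bounded working set and a reusable on-disk artifact.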

### Feature Summary

| Feature                       | Disk-backed (default)                | On-batch (`*_on_batch`)                                      |
|-------------------------------|-------------------------------------|---------------------------------------------------------------|
| Typical dataset size           | Medium → large                      | Small → moderate                                              |
| Supported use cases            | Standard models                      | Broader model support                                         |
| Peak memory usage              | Chunk-bounded                        | Full batch resident                                           |
| Writes to disk                 | Yes                                  | No                                                            |
| Return type                    | `xarray.DataTree`                    | `xarray.DataTree` or `dict` (via `return_datatree=False`)    |
| Custom batch sizing            | Yes (`batch_size`)                   | No (single pass)                                              |
| Device parallelism (sharding)  | Yes                                  | No                                                            |
| Automatic fallback             | Yes (may auto-delegate to on-batch) | No (final mode)                                               |
| Latency (small data)           | Higher (I/O + orchestration)         | Minimal                                                       |

### Capability Matrix

| Capability                      | Disk-backed (default)                               | On-batch (`*_on_batch`)                                             |
|---------------------------------|----------------------------------------------------|----------------------------------------------------------------------|
| Full dataset training           | `fit()`                                           | `fit_on_batch()`                                                     |
| Single training step            | N/A                                                | `train_on_batch()`                                                   |
| Prior predictive sampling       | `sample_prior_predictive()`                        | `sample_prior_predictive_on_batch()`                                 |
| Posterior sampling              | `sample()`                                        | N/A                                                                  |
| Posterior predictive sampling   | `predict()` or `sample_posterior_predictive()`   | `predict_on_batch()` or `sample_posterior_predictive_on_batch()`     |
| Log-likelihood computation      | `log_likelihood()`                                 | N/A                                                                  |
| Effect estimation               | `estimate_effect()`                                | (consumes outputs above)                                             |

## Quick Recommendations
* Moderate or large data, or need persisted outputs: use disk-backed (e.g., `fit()`, `predict()`).
* Small data, rapid iteration, CI, or read-only / ephemeral filesystem: use on-batch (`*_on_batch`).
* If `predict()` issues a fallback warning, call `.predict_on_batch()` directly.
  This occurs when the model or posterior sample shapes are incompatible with shard-based chunked execution.
* Custom training loop: iterate with `train_on_batch()`.
* Need multi-device (sharding) execution: disk-backed.
* Need raw NumPy/dict outputs (no `xarray.DataTree`): on-batch with `return_datatree=False`.

> **Note:**  
> For MCMC inference, only `.fit_on_batch()` or `.sample()` is supported for training and posterior sampling,  
> as MCMC is incompatible with epoch-based or chunked batch processing. See [MCMC Support](https://aimz.readthedocs.io/stable/user_guide/mcmc.html) for more details.

## Example: `.predict()` with Fallback Warning

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, plate, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

%load_ext watermark


def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    # Model includes a local latent variable
    sigma = sample("sigma", dist.Exponential().expand((X.shape[0],)))
    with plate("data", size=X.shape[0]):
        sample("y", dist.Normal(0.0, sigma), obs=y)


rng_key = random.key(42)
rng_key, rng_key_X, rng_key_y = random.split(rng_key, 3)
X = random.normal(rng_key_X, (100, 2))
y = random.normal(rng_key_y, (100,))


im = ImpactModel(
    model,
    rng_key=rng_key,
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
    # This internally calls the `.run()` method of `SVI`
).fit_on_batch(X, y)

# Calling `.predict()` triggers a fallback warning
im.predict(X)
```

```text
Backend: cpu, Devices: 1
Performing variational inference optimization...
100%|██████████| 10000/10000 [00:00<00:00, 14935.00it/s, init loss: 670.9567, avg. loss [9501-10000]: 166.4817]
Posterior sampling...
  im.predict(X)
```
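
In tests or CI it can be useful to detect programmatically whether a call fell back. The sketch below uses Python's standard `warnings` module with a hypothetical stand-in object; the warning category (`UserWarning`) and message text are assumptions made for illustration and may differ from what `aimz` actually emits.

```python
import warnings


class StandInModel:
    """Hypothetical stand-in for a fitted model; NOT the aimz implementation."""

    def predict(self, X):
        # Assumed behavior: warn, then delegate to the on-batch method.
        warnings.warn("falling back to on-batch execution", UserWarning, stacklevel=2)
        return self.predict_on_batch(X)

    def predict_on_batch(self, X):
        return {"y": list(X)}


def predict_detecting_fallback(model, X):
    """Call model.predict(X) and report whether any UserWarning was raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure nothing is suppressed
        out = model.predict(X)
    fell_back = any(issubclass(w.category, UserWarning) for w in caught)
    return out, fell_back


out, fell_back = predict_detecting_fallback(StandInModel(), [1, 2, 3])
print(fell_back)
```

If a fallback is detected for a given model, calling the `*_on_batch` method directly avoids the extra orchestration and the warning.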


## Performance Tips
* Tune `batch_size` appropriately; it also determines the chunk size for Zarr-backed arrays.
* Monitor disk usage, as chunk sizes scale with `batch_size` and `num_samples`.
* Reduce `num_samples` first for faster iteration.
* Use on-batch methods in tests to minimize I/O overhead.

```python
%watermark -iv
```

```text
numpyro: 0.19.0
jax    : 0.7.2
aimz   : 0.8.0
```

