# Model Persistence

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/markean/aimz/blob/main/docs/notebooks/model_persistence.ipynb)

Model persistence allows you to save a trained model to disk and reload it later for inference or continued training. 
This notebook shows how to serialize and deserialize an `ImpactModel` instance using [`cloudpickle`](https://pypi.org/project/cloudpickle/), which extends the standard `pickle` module to handle a wider range of Python objects, including closures and local functions.
An alternative is [`dill`](https://pypi.org/project/dill/), which offers similar functionality.
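
To see what `cloudpickle` adds over the standard library, the following minimal sketch pickles a closure. The `make_scaler` helper is hypothetical and exists only for illustration; it is not part of `aimz`. The standard `pickle` module stores functions by reference, so it is expected to reject a locally defined function, whereas `cloudpickle` serializes the function by value.

```python
import pickle

import cloudpickle


def make_scaler(factor):
    """Hypothetical helper: return a closure that captures `factor`."""

    def scale(x):
        return factor * x

    return scale


scale_by_3 = make_scaler(3)

# The standard library cannot look up a local function by name,
# so pickling it raises an error (the exact type varies by Python version).
try:
    pickle.dumps(scale_by_3)
    stdlib_ok = True
except (pickle.PicklingError, AttributeError):
    stdlib_ok = False

# cloudpickle serializes the function body and its captured variables.
payload = cloudpickle.dumps(scale_by_3)
restored = pickle.loads(payload)

print(stdlib_ok, restored(5))
```

The bytes produced by `cloudpickle` can be read back with the standard `pickle.loads`, as long as `cloudpickle` is importable in the loading environment.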

In [1]:
from pathlib import Path

import cloudpickle
import jax.numpy as jnp
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import optim, sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from aimz import ImpactModel

%load_ext watermark

## Model Training

In [2]:
def model(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Linear regression model."""
    w = sample("w", dist.Normal().expand((X.shape[1],)))
    b = sample("b", dist.Normal())
    mu = jnp.dot(X, w) + b
    sigma = sample("sigma", dist.Exponential())
    sample("y", dist.Normal(mu, sigma), obs=y)


rng_key = random.key(42)
rng_key, rng_key_w, rng_key_b, rng_key_x, rng_key_e = random.split(rng_key, 5)
w = random.normal(rng_key_w, (10,))
b = random.normal(rng_key_b)
X = random.normal(rng_key_x, (1000, 10))
e = random.normal(rng_key_e, (1000,))
y = jnp.dot(X, w) + b + e

rng_key, rng_subkey = random.split(rng_key)
im = ImpactModel(
    model,
    rng_key=rng_subkey,
    inference=SVI(
        model,
        guide=AutoNormal(model),
        optim=optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
im.fit_on_batch(X, y, progress=False);

Backend: cpu, Devices: 1
Performing variational inference optimization...
Posterior sampling...


## Serialization

Save a trained `ImpactModel` (and optionally its input data) to disk for later use:

In [3]:
with Path("model.pkl").open("wb") as f:
    cloudpickle.dump((im, X, y), f)

## Deserialization

Load a previously saved `ImpactModel` (and optionally its input data) from disk in a new session or a different runtime environment. For the loaded model to work, that environment must provide the same dependencies and imports, along with any constants or variables that `model` referenced when it was saved. Every JAX array, whether it belongs to the `ImpactModel` or to the input data, is placed on the default device when loaded.

In [4]:
from pathlib import Path

import cloudpickle
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro import sample

with Path("model.pkl").open("rb") as f:
    im, X, y = cloudpickle.load(f)

Backend: cpu, Devices: 1
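
One way to make the dependency requirement easier to diagnose is to store package versions next to the pickled objects and compare them at load time. The sketch below is an illustration under stated assumptions, not an `aimz` feature: `installed_versions`, `dump_with_versions`, and `load_with_versions` are hypothetical helpers, and the standard `pickle` module stands in for `cloudpickle` (whose `dump` and `load` are called the same way here) so that the example runs with the standard library alone.

```python
import io
import pickle
from importlib import metadata


def installed_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def dump_with_versions(obj, f, packages):
    """Hypothetical helper: pickle `obj` together with a version snapshot."""
    bundle = {"versions": installed_versions(packages), "payload": obj}
    pickle.dump(bundle, f)  # swap in cloudpickle.dump for models with closures


def load_with_versions(f):
    """Hypothetical helper: return the payload and any version mismatches."""
    bundle = pickle.load(f)
    saved = bundle["versions"]
    current = installed_versions(saved)
    # Keep only packages whose version differs between save and load time.
    mismatches = {
        name: (saved[name], current[name])
        for name in saved
        if saved[name] != current[name]
    }
    return bundle["payload"], mismatches


# Round trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
dump_with_versions(
    {"weights": [0.1, 0.2]},
    buffer,
    ["aimz", "jax", "numpyro", "cloudpickle"],
)
buffer.seek(0)
payload, mismatches = load_with_versions(buffer)
print(payload, mismatches)
```

A non-empty `mismatches` dictionary does not necessarily mean loading will fail, but it gives a starting point when a reloaded model behaves unexpectedly.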


## Model Usage

In [5]:
# Resume training from the previous SVI state
im.fit_on_batch(X, y)

# Predict using the loaded model
im.predict_on_batch(X)

Performing variational inference optimization...


100%|██████████| 10000/10000 [00:01<00:00, 7239.14it/s, init loss: 1494.5947, avg. loss [9501-10000]: 1491.9969]


Posterior sampling...


## Resources

- [`dill` documentation](https://dill.readthedocs.io/en/latest/)
- [`jax` `Array` serialization](https://docs.jax.dev/en/latest/jax.numpy.html#copying-and-serialization)

In [6]:
%watermark -iv

aimz       : 0.8.0
jax        : 0.7.2
numpyro    : 0.19.0
cloudpickle: 3.1.1

