In [None]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from jax.typing import ArrayLike
from numpyro import sample
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from aimz.model import ImpactModel


def lm(X: ArrayLike, y: ArrayLike | None = None) -> None:
    """Bayesian linear regression model for NumPyro.

    Generative structure (NumPyro default parameters for ``Normal`` and
    ``Exponential``)::

        w     ~ Normal(0, 1) per feature, treated as one vector event
        b     ~ Normal(0, 1)
        sigma ~ Exponential(1)
        y     ~ Normal(X @ w + b, sigma)

    Args:
        X: Design matrix. Its second axis is taken as the feature axis, so it
            must be at least 2-D. Any array-like (including nested lists) is
            accepted and converted with ``jnp.asarray``.
        y: Observed targets. When ``None`` the ``"y"`` site is left
            unobserved (e.g. when drawing predictions).

    Returns:
        None. The model is defined entirely by its sample sites ``"w"``,
        ``"b"``, ``"sigma"`` and ``"y"``.
    """
    # Honour the ``ArrayLike`` annotation: plain Python sequences have no
    # ``.shape`` and are rejected by ``jnp.dot``, so normalize up front.
    X = jnp.asarray(X)
    if y is not None:
        y = jnp.asarray(y)
    n_features = X.shape[1]

    # Priors for weights and bias. ``to_event(1)`` declares the weight vector
    # as a single multivariate draw instead of ``n_features`` independent
    # batch entries, so it cannot clash with a data plate added later.
    w = sample(
        "w",
        dist.Normal(jnp.zeros(n_features), jnp.ones(n_features)).to_event(1),
    )
    b = sample("b", dist.Normal())

    # Likelihood
    mu = jnp.dot(X, w) + b
    sigma = sample("sigma", dist.Exponential())
    sample("y", dist.Normal(mu, sigma), obs=y)


# Problem size and mini-batch size. The weight vector, the design matrix and
# the noise vector must agree on these, so they are defined once.
N_SAMPLES = 100_000
N_FEATURES = 10
BATCH_SIZE = 100

# Every key comes from one split of the root seed. JAX keys must not be
# reused: the model gets its own subkey rather than a second
# ``random.key(42)``, which would be the very parent key the data keys were
# split from.
rng_key = random.key(42)
key_w, key_b, key_x, key_e, key_model = random.split(rng_key, 5)

# Synthetic ground truth: y = X @ w + b + standard-normal noise.
w = random.normal(key_w, (N_FEATURES,))
b = random.normal(key_b)

X = random.normal(key_x, (N_SAMPLES, N_FEATURES))
e = random.normal(key_e, (N_SAMPLES,))
y = jnp.dot(X, w) + b + e


im = ImpactModel(
    lm,
    rng_key=key_model,
    inference=SVI(
        lm,
        guide=AutoNormal(lm),
        optim=numpyro.optim.Adam(step_size=1e-3),
        loss=Trace_ELBO(),
    ),
)
try:
    im.fit(X=X, y=y, epochs=1, batch_size=BATCH_SIZE, progress=True)
    im.predict(X=X, batch_size=BATCH_SIZE)
finally:
    # Run cleanup even if fit/predict raises.
    # NOTE(review): presumably releases resources created by fit/predict;
    # confirm against the aimz ImpactModel docs.
    im.cleanup()