# Neural Networks with NumPyro

* `stax`
* `flax`
* `haiku`
* `equinox`

In [None]:
import numpy as np

np.random.seed(0)

import tqdm

from flax import linen as nn

from jax import jit, random

import numpyro

import numpyro.distributions as dist

from numpyro.contrib.module import random_flax_module, flax_module

from numpyro.infer import (
    Predictive,
    SVI,
    TraceMeanField_ELBO,
    autoguide,
    init_to_feasible,
)

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
def generate_data(n_samples: int = 100, latent: bool = False, sigma: float = 0.01):
    """Draw a one-dimensional regression data set from ``sin(x) / x``.

    :param int n_samples: number of points to draw.
    :param bool latent: if ``True``, return the noise-free curve evaluated at
        sorted inputs drawn as ``5.5 * N(0, 1)``; otherwise return noisy
        observations at unsorted inputs drawn as ``2.5 * N(0, 1)``.
    :param float sigma: standard deviation of the additive Gaussian noise
        (only used when ``latent`` is ``False``).
    :return: tuple ``(x, y)`` of column arrays, each of shape ``(n_samples, 1)``.
    """
    if latent:
        # Sorted inputs so the curve can be drawn as a line plot.
        inputs = np.sort(5.5 * np.random.normal(size=n_samples))
        targets = np.sin(inputs) / inputs
    else:
        inputs = 2.5 * np.random.normal(size=n_samples)
        noise = sigma * np.random.normal(size=n_samples)
        targets = np.sin(inputs) / inputs + noise

    return inputs[:, None], targets[:, None]

In [None]:
# Number of points in the training, held-out test and dense plotting sets.
n_train_data = 1_000
n_test_data = 50
n_plot_data = 5_000

# Train and test sets use the default noisy branch of `generate_data`; the
# plotting set uses `latent=True`, i.e. the noise-free curve at sorted inputs.
# The call order matters because all three draw from the global NumPy RNG.
x_train, y_train = generate_data(n_train_data)
x_test, y_test = generate_data(n_test_data)
x_plot, y_plot = generate_data(n_plot_data, latent=True)

In [None]:
# Overlay the noisy train/test samples on the noise-free target curve.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, label="Train", color="tab:blue")
ax.scatter(x_test, y_test, label="Test", color="red")
ax.plot(x_plot, y_plot, label="True", color="black")

ax.legend()
plt.show()

## Model

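Every implementation below uses the same fully connected network with layer widths `[32, 32, 1]`: two hidden layers of 32 units with `tanh` activations, followed by a single linear output unit.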

In `keras`-like syntax:

```python
x = Dense(1, 32)(x)
x = Tanh(x)
x = Dense(32, 32)(x)
x = Tanh(x)
x = Dense(32, 1)(x)
```

## Stax

In [None]:
from jax.example_libraries import stax

In [None]:
def stax_net(hidden_dim, out_dim):
    """Build a ``stax`` MLP with two ``tanh`` hidden layers and a linear output.

    :param int hidden_dim: width of each of the two hidden layers.
    :param int out_dim: width of the output layer.
    :return: the ``(init_fn, apply_fn)`` pair produced by ``stax.serial``.
    """
    layers = []
    for _ in range(2):
        layers += [stax.Dense(hidden_dim), stax.Tanh]
    layers.append(stax.Dense(out_dim))
    return stax.serial(*layers)

In [None]:
def stax_model(x, y=None, batch_size=None):
    """NumPyro regression model whose mean function is a ``stax`` MLP.

    The network uses 32 hidden units so that it matches the ``[32, 32, 1]``
    architecture described in this notebook and used by the other models.

    :param x: input array; the first axis indexes data points.
    :param y: optional observations aligned with ``x``; when ``None`` the
        ``"obs"`` site is sampled instead of observed.
    :param int batch_size: number of rows to subsample per call; ``None``
        means use every row of ``x``.
    :return: the value of the ``"obs"`` sample site for the (sub)sampled rows.
    """
    n_total, *n_dims = x.shape
    if batch_size is None:
        batch_size = n_total

    # Registers the network weights as a NumPyro param named "nn$params".
    net = numpyro.module("nn", stax_net(32, 1), input_shape=(batch_size, *n_dims))

    with numpyro.plate("batch", n_total, subsample_size=batch_size):

        # Rows of x and y are subsampled with the indices chosen by the plate.
        batch_x = numpyro.subsample(x, event_dim=1)

        batch_y = numpyro.subsample(y, event_dim=1) if y is not None else None

        mean = net(batch_x)

        return numpyro.sample(
            "obs", dist.Normal(loc=mean, scale=0.01).to_event(1), obs=batch_y
        )

In [None]:
# Smoke tests for `stax_model`: run it under a seeded handler and check the
# shape (or value) of what the "obs" site returns.

# Unconditional, full data: one prediction per training row.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = stax_model(x_train, batch_size=None)

assert y_pred.shape == x_train.shape

# Unconditional, subsampled: output has `batch_size` rows.
batch_size = 32
with numpyro.handlers.seed(rng_seed=0):

    y_pred = stax_model(x_train, batch_size=batch_size)

assert y_pred.shape == x_train[:batch_size].shape

# Conditional, full data: the "obs" site returns the observations themselves.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = stax_model(x_train, y_train)

np.testing.assert_array_almost_equal(y_pred, y_train)


# Conditional, subsampled: only the shape is checked.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = stax_model(x_train, y_train, batch_size=batch_size)

np.testing.assert_equal(y_pred.shape, y_train[:batch_size].shape)

In [None]:
# Fit the stax model with stochastic variational inference.
guide = autoguide.AutoDelta(stax_model, init_loc_fn=init_to_feasible)

# Adam optimiser with step size 5e-3.
svi = SVI(stax_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

n_iterations = 10_000

# `x`, `y` and `batch_size` are forwarded to the model and guide on every step,
# so each step uses a subsample of 256 rows.
svi_result = svi.run(
    random.PRNGKey(0), n_iterations, x=x_train, y=y_train, batch_size=256
)

# Optimised parameters and the recorded losses.
params, losses = svi_result.params, svi_result.losses

In [None]:
# Predictive sampler for the fitted stax model, configured for 1000 samples.
predictive = Predictive(stax_model, guide=guide, params=params, num_samples=1000)

In [None]:
# Sample the "obs" site on the dense plotting grid.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

# np.quantile returns the quantiles in the order requested (5%, 50%, 95%),
# so they unpack as lower bound, median, upper bound.
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Compare the stax model's predictive median and quantile band with the data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, label="Train", color="tab:blue")
ax.scatter(x_test, y_test, label="Test", color="red")
ax.plot(x_plot, y_mu, label="Predictions", color="black")
ax.plot(x_plot, y_upper, label="upper bound", color="tab:orange")
ax.plot(x_plot, y_lower, label="lower bound", color="tab:orange")
# Uncomment to overlay the noise-free target curve:
# ax.plot(x_plot, y_plot, color='black', label='True')

ax.legend()
plt.show()

## Flax

In [None]:
class Net(nn.Module):
    """Flax MLP: two ``tanh`` hidden layers of ``n_units`` and one linear output unit."""

    # Width of each hidden layer.
    n_units: int

    @nn.compact
    def __call__(self, x):
        """Apply the network to ``x`` and return the predicted mean."""
        hidden = x
        # Submodules are created in the same order as before, so the
        # automatically assigned layer names are unchanged.
        for _ in range(2):
            hidden = nn.tanh(nn.Dense(self.n_units)(hidden))
        return nn.Dense(1)(hidden)

In [None]:
def flax_model(x, y=None, batch_size=None):
    """NumPyro regression model whose mean function is the Flax ``Net``.

    :param x: input array; the first axis indexes data points.
    :param y: optional observations aligned with ``x``; when ``None`` the
        ``"obs"`` site is sampled instead of observed.
    :param int batch_size: number of rows to subsample per call; ``None``
        means use every row of ``x``.
    :return: the value of the ``"obs"`` sample site for the (sub)sampled rows.
    """
    n_total, *n_dims = x.shape
    if batch_size is None:
        batch_size = n_total

    # Registers the Flax module's weights as NumPyro params under the name "nn".
    net = flax_module("nn", Net(n_units=32), input_shape=(batch_size, *n_dims))

    with numpyro.plate("batch", n_total, subsample_size=batch_size):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = None if y is None else numpyro.subsample(y, event_dim=1)

        likelihood = dist.Normal(loc=net(batch_x), scale=0.01).to_event(1)
        return numpyro.sample("obs", likelihood, obs=batch_y)

In [None]:
# Smoke tests for `flax_model` under a seeded handler.

# Unconditional, with the batch size set to the full data size.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = flax_model(x_train, batch_size=x_train.shape[0])

assert y_pred.shape == x_train.shape

# Conditional: the "obs" site returns the observations themselves.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = flax_model(x_train, y_train)

np.testing.assert_array_almost_equal(y_pred, y_train)

In [None]:
# Fit the Flax model with stochastic variational inference.
guide = autoguide.AutoDelta(flax_model, init_loc_fn=init_to_feasible)

# Adam optimiser with step size 5e-3.
svi = SVI(flax_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

n_iterations = 10_000

# Positional and keyword arguments after the step count are forwarded to the
# model and guide, so each step uses a subsample of 256 rows.
svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)

# Optimised parameters and the recorded losses.
params, losses = svi_result.params, svi_result.losses


# NOTE(review): disabled checks below refer to `model` and `x_plt`, which are
# not defined in the visible part of this notebook.
# predictive = Predictive(model, guide=guide, params=params, num_samples=1000)

# y_pred = predictive(random.PRNGKey(1), x_plt)["obs"].copy()

# assert losses[-1] < 3000

# assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1

In [None]:
# Predictive sampler for the fitted Flax model, configured for 1000 samples.
predictive = Predictive(flax_model, guide=guide, params=params, num_samples=1000)

In [None]:
# Sample the "obs" site on the dense plotting grid.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

In [None]:
# np.quantile returns the quantiles in the order requested (5%, 50%, 95%),
# so they unpack as lower bound, median, upper bound.
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Compare the Flax model's predictive median and quantile band with the data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, label="Train", color="tab:blue")
ax.scatter(x_test, y_test, label="Test", color="red")
ax.plot(x_plot, y_mu, label="Predictions", color="black")
ax.plot(x_plot, y_upper, label="upper bound", color="tab:orange")
ax.plot(x_plot, y_lower, label="lower bound", color="tab:orange")
# Uncomment to overlay the noise-free target curve:
# ax.plot(x_plot, y_plot, color='black', label='True')

ax.legend()
plt.show()

## Equinox

In [None]:
import equinox as eqx
import jax
from typing import Optional

In [None]:
class Tanh(eqx.Module):
    """Parameter-free ``tanh`` activation wrapped as an Equinox module."""

    def __call__(self, x, *, key: Optional["jax.random.PRNGKey"] = None):
        """Return ``tanh(x)``; ``key`` is accepted but never used."""
        activated = jax.nn.tanh(x)
        return activated

In [None]:
# Stand-alone Equinox MLP used to explore the API: 1 input, 1 output,
# width 32, depth 2, `Tanh` activations and an identity final activation.
mlp = eqx.nn.MLP(
    in_size=1,
    out_size=1,
    width_size=32,
    depth=2,
    activation=Tanh(),
    final_activation=eqx.nn.Identity(),
    key=jax.random.PRNGKey(42),  # fixed key for the weight initialisation
)

In [None]:
# Map the MLP over the leading (batch) axis of the training inputs and
# display the resulting shape.
out = jax.vmap(mlp)(x_train)
out.shape

In [None]:
def make_eq_net(in_size=1, out_size=1, hidden_dim=32, depth=2, activation=Tanh()):
    """Wrap an Equinox MLP in a ``stax``-style ``(init_fn, apply_fn)`` pair.

    :param int in_size: number of input features.
    :param int out_size: number of output features.
    :param int hidden_dim: width of each hidden layer.
    :param int depth: ``depth`` argument forwarded to ``eqx.nn.MLP``.
    :param activation: activation applied after each hidden layer.
    :return: tuple ``(init_fn, apply_fn)``. ``init_fn(rng, input_shape)``
        returns ``(None, params)``; both arguments are accepted only for
        ``stax`` compatibility and are ignored, because the weights are always
        initialised from ``PRNGKey(42)``. ``apply_fn(params, x)`` applies the
        network to every row of a batched input ``x``.
    """
    mlp = eqx.nn.MLP(
        in_size=in_size,
        out_size=out_size,
        width_size=hidden_dim,
        depth=depth,
        activation=activation,
        final_activation=eqx.nn.Identity(),
        key=jax.random.PRNGKey(42),
    )
    # Split the module into its trainable arrays and everything else.
    params, static = eqx.partition(mlp, eqx.is_inexact_array)

    # init function for compatibility
    def init_fn(rng=None, input_shape=None):
        return None, params

    def apply_fn(_params, x):
        model = eqx.combine(_params, static)
        # The MLP acts on a single example, so map it over the batch axis.
        return jax.vmap(model)(x)

    return init_fn, apply_fn

In [None]:
# Reference `(init_fn, apply_fn)` pair from stax (hidden width 1, output width 1),
# used to inspect the interface that the Equinox wrapper imitates.
init_fn, apply_fn = stax_net(1, 1)

In [None]:
# Call the stax initialiser with a key and the training input shape, then
# display what it returns.
out = init_fn(jax.random.PRNGKey(45), x_train.shape)
out

In [None]:
# Equinox-backed pair built with the default arguments; this rebinds
# `init_fn` and `apply_fn` from the stax reference above.
init_fn, apply_fn = make_eq_net()

In [None]:
# The first element returned by `init_fn` is discarded; only the params are kept.
_, params = init_fn()

# Apply the network to the first 32 training rows and display the output shape.
out = apply_fn(params, x_train[:32])
out.shape

In [None]:
import functools


def my_module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.example_libraries.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    The parameters are stored in a single param site named ``name + "$params"``.
    They are initialized with ``nn``'s ``init_fn`` only when that site has no
    value yet.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.example_libraries.stax`
        constructor function.
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a `apply_fn` with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    :raises ValueError: if the parameters need initializing and
        ``input_shape`` is ``None``.
    """
    module_key = name + "$params"
    nn_init, nn_apply = nn
    # Look up an existing value first.
    nn_params = numpyro.param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError("Valid value for `input_shape` needed to initialize.")
        rng_key = numpyro.prng_key()
        # The output shape reported by `nn_init` is not needed here.
        _, nn_params = nn_init(rng_key, input_shape)
        numpyro.param(module_key, nn_params)
    return functools.partial(nn_apply, nn_params)

In [None]:
def eqx_model(x, y=None, batch_size=None):
    """NumPyro regression model whose mean function is an Equinox MLP.

    :param x: input array; the first axis indexes data points.
    :param y: optional observations aligned with ``x``; when ``None`` the
        ``"obs"`` site is sampled instead of observed.
    :param int batch_size: number of rows to subsample per call; ``None``
        means use every row of ``x``.
    :return: the value of the ``"obs"`` sample site for the (sub)sampled rows.
    """
    if batch_size is None:
        batch_size, *n_dims = x.shape
    else:
        n_dims = x.shape[1:]

    # stax-style (init_fn, apply_fn) pair wrapping the Equinox network.
    eqx_net = make_eq_net(1, 1, 32, 2, Tanh())
    input_shape = (batch_size, *n_dims)
    # Registers the network weights as a NumPyro param named "nn$params".
    net = my_module("nn", eqx_net, input_shape=input_shape)

    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):

        batch_x = numpyro.subsample(x, event_dim=1)

        batch_y = numpyro.subsample(y, event_dim=1) if y is not None else None

        mean = net(batch_x)

        return numpyro.sample(
            "obs", dist.Normal(loc=mean, scale=0.01).to_event(1), obs=batch_y
        )

In [None]:
# Smoke tests for `eqx_model` under a seeded handler.

# Unconditional, with the batch size set to the full data size.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = eqx_model(x_train, batch_size=x_train.shape[0])

assert y_pred.shape == x_train.shape

# Conditional: the "obs" site returns the observations themselves.
with numpyro.handlers.seed(rng_seed=0):

    y_pred = eqx_model(x_train, y_train)

np.testing.assert_array_almost_equal(y_pred, y_train)

In [None]:
guide = autoguide.AutoDelta(flax_model, init_loc_fn=init_to_feasible)

svi = SVI(flax_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

n_iterations = 10_000

svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=256)

params, losses = svi_result.params, svi_result.losses


# predictive = Predictive(model, guide=guide, params=params, num_samples=1000)

# y_pred = predictive(random.PRNGKey(1), x_plt)["obs"].copy()

# assert losses[-1] < 3000

# assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1

In [None]:
predictive = Predictive(flax_model, guide=guide, params=params, num_samples=1000)

In [None]:
# Sample the "obs" site on the dense plotting grid.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

In [None]:
# np.quantile returns the quantiles in the order requested (5%, 50%, 95%),
# so they unpack as lower bound, median, upper bound.
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Compare the predictive median and quantile band with the data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, label="Train", color="tab:blue")
ax.scatter(x_test, y_test, label="Test", color="red")
ax.plot(x_plot, y_mu, label="Predictions", color="black")
ax.plot(x_plot, y_upper, label="upper bound", color="tab:orange")
ax.plot(x_plot, y_lower, label="lower bound", color="tab:orange")
# Uncomment to overlay the noise-free target curve:
# ax.plot(x_plot, y_plot, color='black', label='True')

ax.legend()
plt.show()