In [None]:
import numpy as np

np.random.seed(0)

import tqdm

from flax import linen as nn

from jax import jit, random

import numpyro

import numpyro.distributions as dist

from numpyro.contrib.module import random_flax_module, flax_module

from numpyro.infer import (
    Predictive,
    SVI,
    TraceMeanField_ELBO,
    autoguide,
    init_to_feasible,
)

import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# def generate_data(n_samples, plot: bool=False):

#     x = np.random.normal(size=n_samples)


#     if not plot:
#         x = np.random.normal(size=n_samples)
#         y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2

#     else:
#         x = 1.5 * np.random.normal(size=n_samples)
#         x = np.sort(x)
#         y = np.cos(x * 3) #= np.random.normal(size=n_samples) * np.abs(x) / 2

#     return x, y

In [None]:
def generate_data(n_samples: int = 100, latent: bool = False, sigma: float = 0.01):
    """Draw samples of the function ``sin(x) / x``.

    Random numbers come from NumPy's global random state, so results depend
    on the seed set beforehand and on the order of calls.

    Args:
        n_samples: Number of points to draw.
        latent: If ``False``, draw ``x`` as ``2.5 * N(0, 1)`` and add Gaussian
            noise with standard deviation ``sigma`` to ``y``. If ``True``,
            draw ``x`` as ``5.5 * N(0, 1)``, sort it in ascending order and
            return the noise-free function values.
        sigma: Standard deviation of the observation noise; only used when
            ``latent`` is ``False``.

    Returns:
        Tuple ``(x, y)`` of 1-D arrays with ``n_samples`` entries each.
    """
    # np.sinc(t) is sin(pi * t) / (pi * t) with sinc(0) == 1, so
    # np.sinc(x / pi) equals sin(x) / x but stays finite at x == 0, where the
    # explicit division would yield NaN.
    if not latent:
        x = 2.5 * np.random.normal(size=n_samples)
        y = np.sinc(x / np.pi) + sigma * np.random.normal(size=n_samples)
    else:
        x = np.sort(5.5 * np.random.normal(size=n_samples))
        y = np.sinc(x / np.pi)

    return x, y

In [None]:
# Number of points in the training, test and plotting sets.
n_train_data = 1_000
n_test_data = 50
n_plot_data = 5_000

# Every call draws from NumPy's global random state, so the order of these
# calls determines which samples each set receives.
x_train, y_train = generate_data(n_train_data)
x_test, y_test = generate_data(n_test_data)
# latent=True returns sorted inputs together with noise-free targets.
x_plot, y_plot = generate_data(n_plot_data, latent=True)

In [None]:
# Overview of the data: noisy training/test samples and the noise-free curve.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, color="tab:blue", label="Train")
ax.scatter(x_test, y_test, color="red", label="Test")
ax.plot(x_plot, y_plot, color="black", label="True")

# Label the figure so it can be read on its own, and stay on the Axes API
# instead of mixing it with the pyplot state machine.
ax.set(xlabel="x", ylabel="y", title="Training and test data with the true function")
ax.legend()
plt.show()

## Deterministic NN

In [None]:
class Net(nn.Module):
    """Two-hidden-layer MLP with swish activations that predicts a mean.

    Every element of the input is treated as a one-feature sample: a trailing
    feature axis is appended before the first dense layer and removed again
    from the output.
    """

    # Width of each of the two hidden layers.
    n_units: int

    @nn.compact
    def __call__(self, x):
        """Return the predicted mean, shaped like ``x``."""
        x = nn.Dense(self.n_units)(x[..., None])

        x = nn.swish(x)

        x = nn.Dense(self.n_units)(x)

        x = nn.swish(x)

        mean = nn.Dense(1)(x)

        # Drop only the trailing feature axis. A bare ``squeeze()`` would also
        # remove a batch axis of length one and change the output rank.
        return mean.squeeze(-1)

In [None]:
def nn_model(x, y=None, batch_size=None):
    """NumPyro model: ``Net`` output as the mean of a fixed-scale Gaussian.

    Args:
        x: Inputs; the leading axis is the data axis covered by the plate.
        y: Optional targets aligned with ``x``. When ``None`` the "obs" site
            is left unobserved.
        batch_size: Forwarded to the "batch" plate as ``subsample_size``.
    """
    # The network is registered through flax_module under the name "nn".
    network = flax_module("nn", Net(n_units=32), input_shape=())

    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
        x_batch = numpyro.subsample(x, event_dim=0)
        y_batch = None if y is None else numpyro.subsample(y, event_dim=0)

        predicted_mean = network(x_batch)

        # The observation noise scale is fixed at 0.01.
        numpyro.sample(
            "obs",
            dist.Normal(loc=predicted_mean, scale=0.01),
            obs=y_batch,
        )

In [None]:
# Fit nn_model with SVI.
# Alternative guide kept for reference (note that it refers to `model`, a
# name that no visible cell defines):
# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
guide = autoguide.AutoDelta(nn_model, init_loc_fn=init_to_feasible)

# Adam with step size 5e-3 and the TraceMeanField_ELBO loss.
svi = SVI(nn_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

n_iterations = 10_000

# batch_size=32 is handed on to nn_model, which uses it as the plate's
# subsample_size.
svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=32)

# NOTE: `guide`, `svi`, `params` and `losses` are reassigned by the later
# training cells, so the notebook has to be run top to bottom.
params, losses = svi_result.params, svi_result.losses


# NOTE(review): the disabled checks below reference `model` and `x_plt`,
# which are not defined in the visible cells; they would need `nn_model` and
# `x_plot` (and matching shapes for the RMSE check) before being re-enabled.
# predictive = Predictive(model, guide=guide, params=params, num_samples=1000)

# y_pred = predictive(random.PRNGKey(1), x_plt)["obs"].copy()

# assert losses[-1] < 3000

# assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1

In [None]:
# Predictive sampler for nn_model built from the fitted guide and parameters;
# it is configured to draw 1000 samples.
predictive = Predictive(nn_model, guide=guide, params=params, num_samples=1000)

In [None]:
# Sample the "obs" site at the plotting inputs. No targets are passed, so
# nn_model leaves "obs" unobserved.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

In [None]:
# np.quantile returns its rows in the order of the requested quantiles
# (5 %, 50 %, 95 %): the first row is the lower bound, the last the upper one.
# Together they bracket a 90 % band around the median, taken over the sample
# axis (axis 0).
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Predictive median and quantile bounds on top of the training and test data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, color="tab:blue", label="Train")
ax.scatter(x_test, y_test, color="red", label="Test")
# Draw on the Axes object directly instead of going through the pyplot
# state machine, consistent with the scatter calls above.
ax.plot(x_plot, y_mu, color="black", label="Predictions")
ax.plot(x_plot, y_upper, color="tab:orange", label="upper bound")
ax.plot(x_plot, y_lower, color="tab:orange", label="lower bound")
# ax.plot(x_plot, y_plot, color='black', label='True')

# Label the figure so it can be read on its own.
ax.set(xlabel="x", ylabel="y", title="Deterministic NN predictions")
ax.legend()
plt.show()

## Probabilistic NN

In [None]:
class Net(nn.Module):
    """Two-hidden-layer MLP with swish activations and two output heads.

    Both heads read the same hidden representation: one produces ``mean`` and
    the other ``rho``, an unconstrained value that the models below turn into
    a standard deviation with softplus.

    NOTE: this cell redefines ``Net``. Once it has run, the mean-only class of
    the same name defined earlier in the notebook is shadowed.
    """

    # Width of each of the two hidden layers.
    n_units: int

    @nn.compact
    def __call__(self, x):
        """Return ``(mean, rho)``, each shaped like ``x``."""
        x = nn.Dense(self.n_units)(x[..., None])

        x = nn.swish(x)

        x = nn.Dense(self.n_units)(x)

        x = nn.swish(x)

        mean = nn.Dense(1)(x)

        rho = nn.Dense(1)(x)

        # Drop only the trailing feature axis. A bare ``squeeze()`` would also
        # remove a batch axis of length one and change the output rank.
        return mean.squeeze(-1), rho.squeeze(-1)

In [None]:
def probnn_model(x, y=None, batch_size=None):
    """NumPyro model: ``Net`` predicts both mean and scale of a Gaussian.

    Args:
        x: Inputs; the leading axis is the data axis covered by the plate.
        y: Optional targets aligned with ``x``. When ``None`` the "obs" site
            is left unobserved.
        batch_size: Forwarded to the "batch" plate as ``subsample_size``.
    """
    module = Net(n_units=32)

    # net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
    net = flax_module("nn", module, input_shape=())

    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):

        batch_x = numpyro.subsample(x, event_dim=0)

        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None

        mean, rho = net(batch_x)

        # Add the epsilon after softplus so it acts as a lower bound on the
        # scale. Inside the softplus it only shifts rho and cannot stop the
        # scale from underflowing to zero for very negative rho.
        sigma = nn.softplus(rho) + 1e-10

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)

In [None]:
# Fit probnn_model with SVI; this reassigns `guide`, `svi`, `params` and
# `losses` from the previous section.
guide = autoguide.AutoNormal(probnn_model, init_loc_fn=init_to_feasible)

# Adam with step size 5e-3 and the TraceMeanField_ELBO loss.
svi = SVI(probnn_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

n_iterations = 10_000

# batch_size=32 is handed on to probnn_model, which uses it as the plate's
# subsample_size.
svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=32)

params, losses = svi_result.params, svi_result.losses


# NOTE(review): the disabled lines below reference `model` and `x_plt`, which
# are not defined in the visible cells.
# predictive = Predictive(model, guide=guide, params=params, num_samples=1000)

# y_pred = predictive(random.PRNGKey(1), x_plt)["obs"].copy()

# Sanity check on the loss of the final optimisation step only.
assert losses[-1] < 10_000

# assert np.sqrt(np.mean(np.square(y_test - y_pred))) < 1

In [None]:
# Predictive sampler for probnn_model built from the fitted guide and
# parameters; it is configured to draw 1000 samples.
predictive = Predictive(probnn_model, guide=guide, params=params, num_samples=1000)

In [None]:
# Sample the "obs" site at the plotting inputs. No targets are passed, so
# probnn_model leaves "obs" unobserved.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

In [None]:
# np.quantile returns its rows in the order of the requested quantiles
# (5 %, 50 %, 95 %): the first row is the lower bound, the last the upper one.
# Together they bracket a 90 % band around the median, taken over the sample
# axis (axis 0).
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Predictive median and quantile bounds on top of the training and test data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, color="tab:blue", label="Train")
ax.scatter(x_test, y_test, color="red", label="Test")
# Draw on the Axes object directly instead of going through the pyplot
# state machine, consistent with the scatter calls above.
ax.plot(x_plot, y_mu, color="black", label="Predictions")
ax.plot(x_plot, y_upper, color="tab:orange", label="upper bound")
ax.plot(x_plot, y_lower, color="tab:orange", label="lower bound")
# ax.plot(x_plot, y_plot, color='black', label='True')

# Label the figure so it can be read on its own.
ax.set(xlabel="x", ylabel="y", title="Probabilistic NN predictions")
ax.legend()
plt.show()

## Bayesian NN

In [None]:
def bnn_model(x, y=None, batch_size=None):
    """NumPyro model: ``Net`` wrapped with random_flax_module and a Gaussian likelihood.

    ``dist.Normal(0, 0.1)`` is passed to ``random_flax_module`` as the prior
    for the network registered under the name "nn". The network returns a
    mean and an unconstrained ``rho`` that is mapped to a positive scale.

    Args:
        x: Inputs; the leading axis is the data axis covered by the plate.
        y: Optional targets aligned with ``x``. When ``None`` the "obs" site
            is left unobserved.
        batch_size: Forwarded to the "batch" plate as ``subsample_size``.
    """
    module = Net(n_units=32)

    net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
    # net = flax_module("nn", module, input_shape=())

    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):

        batch_x = numpyro.subsample(x, event_dim=0)

        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None

        mean, rho = net(batch_x)

        # Add the epsilon after softplus so it acts as a lower bound on the
        # scale. Inside the softplus it only shifts rho and cannot stop the
        # scale from underflowing to zero for very negative rho.
        sigma = nn.softplus(rho) + 1e-5

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)

In [None]:
# Fit bnn_model with SVI; this reassigns `guide`, `svi`, `params` and
# `losses` from the previous section.
guide = autoguide.AutoNormal(bnn_model, init_loc_fn=init_to_feasible)

# Adam with step size 5e-3 and the TraceMeanField_ELBO loss.
svi = SVI(bnn_model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())

# Fewer iterations than the 10_000 used for the two previous models.
n_iterations = 3000

# batch_size=32 is handed on to bnn_model, which uses it as the plate's
# subsample_size.
svi_result = svi.run(random.PRNGKey(0), n_iterations, x_train, y_train, batch_size=32)

params, losses = svi_result.params, svi_result.losses

In [None]:
# Predictive sampler for bnn_model built from the fitted guide and
# parameters; it is configured to draw 1000 samples.
predictive = Predictive(bnn_model, guide=guide, params=params, num_samples=1000)

# Sample the "obs" site at the plotting inputs. No targets are passed, so
# bnn_model leaves "obs" unobserved.
y_pred = predictive(random.PRNGKey(1), x_plot)["obs"].copy()

In [None]:
# np.quantile returns its rows in the order of the requested quantiles
# (5 %, 50 %, 95 %): the first row is the lower bound, the last the upper one.
# Together they bracket a 90 % band around the median, taken over the sample
# axis (axis 0).
y_lower, y_mu, y_upper = np.quantile(y_pred, [0.05, 0.5, 0.95], axis=0)

In [None]:
# Predictive median and quantile bounds on top of the training and test data.
fig, ax = plt.subplots()

ax.scatter(x_train, y_train, color="tab:blue", label="Train")
ax.scatter(x_test, y_test, color="red", label="Test")
# Draw on the Axes object directly instead of going through the pyplot
# state machine, consistent with the scatter calls above.
ax.plot(x_plot, y_mu, color="black", label="Predictions")
ax.plot(x_plot, y_upper, color="tab:orange", label="upper bound")
ax.plot(x_plot, y_lower, color="tab:orange", label="lower bound")
# ax.plot(x_plot, y_plot, color='black', label='True')

# Label the figure so it can be read on its own.
ax.set(xlabel="x", ylabel="y", title="Bayesian NN predictions")
ax.legend()
plt.show()