In [None]:
%matplotlib inline

In [None]:
import math
import os
import sys
import warnings

import jax
import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pylab as plt
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

warnings.simplefilter("ignore", UserWarning)
sys.path.insert(0, os.path.abspath("../.."))

import transformers.models.mean_functions as mean_functions
from transformers.models import kernels, normalization
from transformers.models.b_mlp import GaussianMLPReparameterization
from transformers.models.bnn_likelihoods import LikGaussian
from transformers.models.bnn_priors import FixedGaussianPrior
from transformers.models.gp import GPR
from transformers.models.mlp import MLP
from transformers.models.rand_generators import GridGenerator
from transformers.training import utils
from transformers.training.jax_utils import vgrad
from transformers.training.wasserstein_mapper import MapperWasserstein

In [None]:
# Default resolution (dots per inch) for every figure created below.
mpl.rcParams["figure.dpi"] = 100

In [None]:
# Root folder for this experiment's artifacts, resolved under the user's home
# directory, plus a "figures" sub-folder inside it.
OUT_DIR = os.path.expanduser(
    "~/busy-beeway/transformers/exp/1D_synthetic/tanh_gaussian_new"
)
FIG_DIR = os.path.join(OUT_DIR, "figures")
# presumably creates each directory when it is missing — confirm against utils.ensure_dir
utils.ensure_dir(OUT_DIR)
utils.ensure_dir(FIG_DIR)

In [None]:
# Seeded NumPy generator shared by the data-generation and plotting cells, so
# the synthetic dataset and the plotted sample curves are reproducible.
rng = np.random.default_rng(12345)

In [None]:
def make_random_gap(X, gap_ratio=0.2, rng=None):
    """Carve a random empty interval out of ``X`` by relocating points in place.

    A gap covering ``gap_ratio`` of the range ``[X.min(), X.max()]`` is placed
    at a random position inside that range. Every entry lying strictly inside
    the gap is redrawn uniformly from whichever remaining side (left or right
    of the gap) is wider.

    Args:
        X: NumPy array of inputs. It is modified in place.
        gap_ratio: Fraction of the data range occupied by the gap.
        rng: ``numpy.random.Generator`` used for every draw. When ``None``, a
            fresh unseeded generator is created for this call (a generator in
            the signature default would be built once at definition time and
            silently shared between calls).

    Returns:
        None. The result is written into ``X``.
    """
    if rng is None:
        rng = np.random.default_rng()
    a, b = X.min(), X.max()
    # Left edge is drawn so that the whole gap still fits inside [a, b].
    gap_a = a + rng.random() * (b - a) * (1 - gap_ratio)
    gap_b = gap_a + (b - a) * gap_ratio
    idx = np.logical_and(gap_a < X, X < gap_b)
    if gap_a - a > b - gap_b:
        # Left side is wider: move the gap points into [a, gap_a).
        X[idx] = a + rng.random(idx.sum()) * (gap_a - a)
    else:
        # Right side is at least as wide: move them into [gap_b, b).
        X[idx] = gap_b + rng.random(idx.sum()) * (b - gap_b)


def gp_sample(X, ampl=1, leng=1, sn2=0.1, rng=None):
    """Draw one noisy sample from a zero-mean GP with a squared-exponential kernel.

    The covariance between rows ``i`` and ``j`` of ``X`` is
    ``ampl**2 * exp(-0.5 * ||x_i - x_j||**2 / leng**2)``, with ``sn2`` added on
    the diagonal.

    Args:
        X: 2-D array of shape ``(n, d)`` holding the input locations.
        ampl: Kernel amplitude (it is squared inside the covariance).
        leng: Kernel lengthscale; the inputs are divided by it.
        sn2: Variance added to the diagonal of the covariance.
        rng: ``numpy.random.Generator`` used for the draw. When ``None``, a
            fresh unseeded generator is created for this call (a generator in
            the signature default would be built once at definition time and
            silently shared between calls).

    Returns:
        Array of shape ``(n, 1)`` with the sampled values.
    """
    if rng is None:
        rng = np.random.default_rng()
    n = X.shape[0]
    x = X / leng
    # Pairwise squared distances via ||xi||^2 + ||xj||^2 - 2 xi.xj; broadcasting
    # replaces the explicit n-by-n repeat without changing any value.
    sq_norms = np.sum(x * x, 1)
    D = sq_norms[:, None] + sq_norms[None, :] - 2 * np.matmul(x, x.transpose())
    C = ampl**2 * np.exp(-0.5 * D) + np.eye(n) * sn2
    return rng.multivariate_normal(np.zeros(n), C).reshape(-1, 1)


def plot_samples(
    X,
    samples,
    var=None,
    n_keep=12,
    color="xkcd:bluish",
    smooth_q=False,
    ax=None,
    rng=None,
):
    """Plot sampled functions together with their mean and an uncertainty band.

    Args:
        X: Input locations; flattened for the band and used as the x values of
            the curves.
        samples: Function draws with one column per draw, i.e. shape
            ``(len(X), n_samples)``. Arrays with more than two dimensions are
            squeezed first.
        var: Optional per-location variance. When given, the band is
            ``mean +/- 3 * sqrt(var)``; otherwise it is ``mean +/- 2 * std`` of
            the samples.
        n_keep: Number of individual sample curves to draw. ``None`` keeps one
            tenth of the available samples.
        color: Matplotlib colour used for the band and the sample curves.
        smooth_q: When ``var`` is ``None``, pass both band edges through
            ``moving_average`` before plotting.
        ax: Axes to draw on; defaults to ``plt.gca()``.
        rng: ``numpy.random.Generator`` used to pick which curves are drawn.
            When ``None``, a fresh unseeded generator is created for this call
            (a generator in the signature default would be built once at
            definition time and silently shared between calls).

    Returns:
        None. Everything is drawn on ``ax``.
    """
    if rng is None:
        rng = np.random.default_rng()
    if ax is None:
        ax = plt.gca()
    if samples.ndim > 2:
        samples = samples.squeeze()
    n_keep = int(samples.shape[1] / 10) if n_keep is None else n_keep
    keep_idx = rng.permutation(samples.shape[1])[:n_keep]
    mu = samples.mean(1)
    if var is None:
        # Empirical band: two standard deviations of the samples around the mean.
        spread = 2 * samples.std(1)
        ub, lb = mu + spread, mu - spread
        if smooth_q:
            # NOTE(review): moving_average is not defined in the visible part of
            # this notebook; smooth_q=True needs it to be in scope — confirm.
            lb = moving_average(lb)
            ub = moving_average(ub)
    else:
        # Band from the supplied variance: three standard deviations.
        spread = 3 * np.sqrt(var)
        ub, lb = mu + spread, mu - spread
    ax.fill_between(X.flatten(), ub, lb, color=color, alpha=0.25, lw=0)
    ax.plot(X, samples[:, keep_idx], color=color, alpha=0.8)
    ax.plot(X, mu, color="xkcd:red")

In [None]:
# Number of training points, number of test points, and the input interval.
N, M = 64, 100
a, b = -10, 10

# Training inputs: uniform on [a, b], then a gap spanning 40% of the range is
# opened by relocating points in place; targets are one GP sample at X.
# The three statements consume the shared `rng` in this order.
X = a + (b - a) * rng.random((N, 1))
make_random_gap(X, rng=rng, gap_ratio=0.4)
y = gp_sample(X, leng=1.8, ampl=1.6, rng=rng)

# Evenly spaced test grid reaching 5 units beyond the training interval on each side.
Xtest = np.linspace(a - 5, b + 5, M)[:, np.newaxis]

# z-score the data; the test grid reuses the training input mean and std.
X_, X_mean, X_std = normalization.zscore_normalization(X)
y_, y_mean, y_std = normalization.zscore_normalization(y)
Xtest_, _, _ = normalization.zscore_normalization(Xtest, X_mean, X_std)

In [None]:
# Scatter the raw (un-normalised) training points as black dots.
fig = plt.figure()
plt.plot(X, y, "ko", ms=5)
plt.title("Dataset")
plt.show()

In [None]:
# GP hyper-parameters
# NNX RNG container; the same `rngs` object is also handed to the BNN
# constructors in the following cells.
rngs = nnx.Rngs(12345)
sn2 = 0.1  # noise variance
leng = 0.6  # lengthscale
ampl = 1.0  # amplitude

# Initialize GP Prior
# NOTE(review): `ampl` is passed un-squared as the kernel `variance`; this makes
# no difference while ampl == 1.0 — confirm the intent before changing ampl.
kernel = kernels.RBF(
    input_dim=1,
    ARD=True,
    lengthscales=jnp.array([leng], dtype=jnp.double),
    variance=jnp.array([ampl], dtype=jnp.double),
)

# GP regression model on the z-scored training data with a zero mean function.
gpmodel = GPR(
    X=X_,
    Y=y_.reshape((-1, 1)),
    kern=kernel,
    mean_function=mean_functions.Zero(),
    jitter_level=5e-5,
    rngs=rngs,
)
# Set the likelihood's variance to the noise variance chosen above.
gpmodel.likelihood.variance.set(sn2)

In [None]:
# Initialize BNN Priors
# Architecture shared by every network built in this notebook.
width = 50  # Number of units in each hidden layer
depth = 3  # Number of hidden layers
transfer_fn = "tanh"  # Activation function

# Two Gaussian-prior BNNs with identical architecture:
#   std_bnn - fixed prior
#   opt_bnn - prior to be optimized
# The generator is consumed left to right, so std_bnn is constructed first and
# opt_bnn second, both drawing from the shared `rngs`.
std_bnn, opt_bnn = (
    GaussianMLPReparameterization(
        input_dim=1,
        output_dim=1,
        activation_fn=transfer_fn,
        hidden_dims=[width] * depth,
        rngs=rngs,
    )
    for _ in range(2)
)

In [None]:
# Measurement-point generator passed to the (currently commented-out)
# MapperWasserstein cell. presumably -6 and 6 are the grid bounds in
# normalised input space — confirm against GridGenerator.
data_generator = GridGenerator(-6, 6)

In [None]:
# Number of prior-optimisation iterations. It also selects which checkpoint
# ("it-<n>.ckpt") is restored and how much of the Wasserstein log is plotted.
mapper_num_iters = 800

In [None]:
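# NOTE: the prior-optimization run below is disabled. Later cells read
# "wsr_values.log" (and, presumably, the checkpoints under "ckpts/") from
# OUT_DIR, so those files must already exist on disk; uncomment this cell to
# regenerate them.
# mapper = MapperWasserstein(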
#     gpmodel,
#     opt_bnn,
#     data_generator,
#     out_dir=OUT_DIR,
#     wasserstein_steps=(0, 1000),
#     wasserstein_lr=0.08,
#     n_data=200,
#     rngs=rngs,
# )

# # Start optimizing the prior
# w_hist = mapper.optimize(
#     num_iters=mapper_num_iters,
#     n_samples=512,
#     lr=0.01,
#     save_ckpt_every=50,
#     print_every=20,
# )
# path = os.path.join(OUT_DIR, "wsr_values.log")
# np.savetxt(path, w_hist, fmt="%.6e")

In [None]:
# Visualize progression of the prior optimization
wdist_file = os.path.join(OUT_DIR, "wsr_values.log")
# np.loadtxt returns a 0-d array for a single-value log; keep it indexable.
wdist_vals = np.atleast_1d(np.loadtxt(wdist_file))

fig = plt.figure(figsize=(6, 3.5))
# Plot every 5th iteration. The range is clamped to the number of logged values
# so a log shorter than mapper_num_iters (e.g. from an interrupted run) is
# plotted as far as it goes instead of raising IndexError.
indices = np.arange(min(mapper_num_iters, len(wdist_vals)))[::5]
plt.plot(indices, wdist_vals[indices], "-ko", ms=4)
plt.ylabel(r"$W_1(p_{gp}, p_{nn})$")
plt.xlabel("Iteration")
plt.show()

In [None]:
# Path of the checkpoint for iteration `mapper_num_iters` under OUT_DIR/ckpts.
ckpt_path = os.path.join(OUT_DIR, "ckpts", "it-{}.ckpt".format(mapper_num_iters))
checkpointer = ocp.StandardCheckpointer()
# Template network with the same architecture as opt_bnn; it only supplies the
# structure that the checkpoint is restored into.
empty_bnn = GaussianMLPReparameterization(
    input_dim=1,
    output_dim=1,
    activation_fn=transfer_fn,
    hidden_dims=[width] * depth,
    rngs=rngs,
)
# presumably converts PRNG-key state into raw arrays the checkpointer can
# handle — confirm against utils.prng_to_raw
utils.prng_to_raw(empty_bnn)
# Abstract version of the template, split into graph definition and state.
abstract_bnn = nnx.eval_shape(lambda: empty_bnn)
graphdef, abstract_state = nnx.split(abstract_bnn)
# Restore the saved state against the abstract state and rebuild the model;
# this replaces the freshly initialised opt_bnn defined earlier.
state_restored = checkpointer.restore(ckpt_path, abstract_state)
opt_bnn = nnx.merge(graphdef, state_restored)
# presumably the inverse of prng_to_raw — confirm against utils.raw_to_prng
utils.raw_to_prng(opt_bnn)
checkpointer.wait_until_finished()
checkpointer.close()

In [None]:
# Draw functions from the priors
n_plot = 4000

# One batch of draws per model on the normalised test grid, squeezed and mapped
# back to the original y scale. The generator is consumed left to right, so the
# models are sampled in the order GP -> fixed BNN -> GP-induced BNN.
gp_samples, std_bnn_samples, opt_bnn_samples = (
    normalization.zscore_unnormalization(
        _model.sample_functions(jnp.float32(Xtest_), n_plot).squeeze(),
        y_mean,
        y_std,
    )
    for _model in (gpmodel, std_bnn, opt_bnn)
)

# One panel per prior: (samples, colour, title), drawn left to right.
fig, axs = plt.subplots(1, 3, figsize=(14, 3))
_panels = (
    (gp_samples, "xkcd:bluish", "GP Prior"),
    (std_bnn_samples, "xkcd:grass", "BNN Prior (Fixed)"),
    (opt_bnn_samples, "xkcd:yellowish orange", "BNN Prior (GP-induced)"),
)
for _ax, (_samples, _color, _title) in zip(axs, _panels):
    plot_samples(Xtest, _samples, ax=_ax, color=_color, n_keep=5, rng=rng)
    _ax.set_title(_title)
    _ax.set_ylim([-5, 5])

plt.tight_layout()
plt.show()

In [None]:
# Draw 1000 function samples from gpmodel.predict_f_samples at the normalised
# test inputs, squeeze them, and map them back to the original y scale.
gp_preds = normalization.zscore_unnormalization(
    gpmodel.predict_f_samples(Xtest_, 1000).squeeze(), y_mean, y_std
)

In [None]:
# SGHMC Hyper-parameters
# Keyword arguments intended for the sampler's `sample_multi_chains` call
# (that call is currently commented out in the sampling cell).
sampling_configs = {
    "batch_size": 32,  # Mini-batch size
    "num_samples": 30,  # Total number of samples for each chain
    "n_discarded": 10,  # Number of the first samples to be discarded for each chain
    "num_burn_in_steps": 2000,  # Number of burn-in steps
    "keep_every": 200,  # Thinning interval
    "lr": 0.01,  # Step size
    "num_chains": 4,  # Number of chains
    "mdecay": 0.01,  # Momentum coefficient
    "print_every_n_samples": 5,
}

In [None]:
# New RNG container for the sampling experiment; rebinds the `rngs` name used
# by the earlier cells.
rngs = nnx.Rngs(8675)
prior = FixedGaussianPrior(std=1.0)

# Setup likelihood
# MLP with the same hidden sizes and activation as the BNN priors defined earlier.
net = MLP(1, 1, [width] * depth, transfer_fn, rngs)
likelihood = LikGaussian(sn2)

# Initialize the sampler
saved_dir = os.path.join(OUT_DIR, "sampling_std")
utils.ensure_dir(saved_dir)
# NOTE(review): RegressionNet is not imported in any cell visible here, so this
# line raises NameError on a fresh kernel unless it is imported elsewhere —
# add the import to the imports cell.
bayes_net_std = RegressionNet(net, likelihood, prior, saved_dir, n_gpu=0)

# # Start sampling
# bayes_net_std.sample_multi_chains(X, y, **sampling_configs)