In [None]:
%matplotlib inline

In [None]:
import math
import os
import sys
import warnings

import matplotlib as mpl
import matplotlib.pylab as plt
import numpy as np
import torch

import h5py

# Silence every UserWarning for the rest of the session so that warnings do
# not clutter the cell outputs.
# NOTE(review): this is a blanket filter, so it also hides UserWarnings that
# might matter — consider narrowing it to specific modules or messages.
warnings.simplefilter("ignore", UserWarning)

In [None]:
# Run the rest of the notebook from the parent directory: the "data/..." and
# "./exp/..." paths used further down are relative to it.
# The target is resolved only once per kernel session. A bare os.chdir("..")
# climbs one more level every time this cell is re-executed, which silently
# breaks every relative path afterwards.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

In [None]:
from optbnn.bnn.likelihoods import LikCE, LikGaussian
from optbnn.bnn.nets.mlp import MLP
from optbnn.bnn.priors import FixedGaussianPrior, OptimGaussianPrior
from optbnn.bnn.reparam_nets import GaussianMLPReparameterization
from optbnn.gp.models.model import LCFModel
from optbnn.gp.reward_functions import pen_task_reward_prior
from optbnn.metrics.sampling import compute_rhat_regression
from optbnn.prior_mappers.wasserstein_mapper import (
    MapperWasserstein,
    WassersteinDistance,
)
from optbnn.sgmcmc_bayes_net.pref_net import PrefNet
from optbnn.utils import util
from optbnn.utils.rand_generators import DataSetSampler

In [None]:
# Default resolution (dots per inch) for the figures created below.
mpl.rcParams["figure.dpi"] = 100

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Root folder for the artefacts written by this notebook (logs, checkpoints,
# sampler outputs, saved predictions) plus a sub-folder reserved for figures.
OUT_DIR = "./exp/reward_learning/pen"
FIG_DIR = os.path.join(OUT_DIR, "figures")

# Make sure both folders are available before anything is written to them.
for _out_path in (OUT_DIR, FIG_DIR):
    util.ensure_dir(_out_path)

In [None]:
def _moving_average(values, window=5):
    """Smooth a 1-D array with a centred moving average of the same length.

    The input is padded with its own edge values before averaging, so the
    result has exactly as many points as the input and is not pulled towards
    zero at either end. Windows larger than the input are clipped to its size.
    """
    values = np.asarray(values, dtype=float)
    window = max(1, min(int(window), values.size))
    if window == 1:
        return values
    left = window // 2
    right = window - 1 - left
    padded = np.pad(values, (left, right), mode="edge")
    return np.convolve(padded, np.full(window, 1.0 / window), mode="valid")


def plot_samples(
    X, samples, var=None, n_keep=12, color="xkcd:bluish", smooth_q=False, ax=None
):
    """Plot sampled functions together with their mean and an uncertainty band.

    Args:
        X: NumPy array of x-axis locations (it is flattened for the band).
        samples: array whose first axis matches the locations in ``X`` and
            whose second axis indexes the individual draws. Arrays with more
            than two dimensions are squeezed first.
        var: optional per-location variance. When given, the band is
            ``mean +/- 3 * sqrt(var)``; otherwise it is the empirical
            ``mean +/- 2 * std`` of ``samples``.
        n_keep: number of randomly chosen draws to overlay as lines. ``None``
            keeps a tenth of the draws.
        color: colour of the band and of the individual draws.
        smooth_q: if True (and ``var`` is None), smooth the empirical band
            edges with a short moving average.
        ax: Matplotlib axes to draw on; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()
    if samples.ndim > 2:
        samples = samples.squeeze()

    n_draws = samples.shape[1]
    if n_keep is None:
        n_keep = int(n_draws / 10)
    # Random subset of draws; slicing tolerates n_keep > n_draws.
    keep_idx = np.random.permutation(n_draws)[:n_keep]

    mu = samples.mean(1)
    if var is None:
        # Empirical two-standard-deviation band around the mean.
        spread = 2 * samples.std(1)
        ub, lb = mu + spread, mu - spread
        if smooth_q:
            # The original called an undefined `moving_average` here and
            # raised NameError; use the local helper instead.
            lb = _moving_average(lb)
            ub = _moving_average(ub)
    else:
        # A supplied variance is drawn as a three-standard-deviation band.
        spread = 3 * np.sqrt(var)
        ub, lb = mu + spread, mu - spread

    ax.fill_between(X.flatten(), ub, lb, color=color, alpha=0.25, lw=0)
    ax.plot(X, samples[:, keep_idx], color=color, alpha=0.8)
    ax.plot(X, mu, color="xkcd:red")

In [None]:
# Reset the seed before the prior is built.
util.set_seed(1)

# Kept for reference (currently unused): an explicit prior mean.
# p_mean = np.array([0.0, -1.0, 1.0, 10.0, 50.0, -5.0])

# 6x6 diagonal covariance with 3 on the diagonal, passed to the LCF model
# together with the pen-task reward prior; the model is then moved to `device`.
# NOTE(review): presumably this is the covariance of six reward weights —
# confirm against LCFModel.
p_covariance = 3 * np.eye(6)
pen_prior = LCFModel(p_covariance, pen_task_reward_prior).to(device)

In [None]:
util.set_seed(1)

# Architecture settings; the MLPs built for sampling further down reuse them.
width = 64  # units per hidden layer
depth = 3  # number of hidden layers
transfer_fn = "relu"  # activation function


def _make_bnn_prior():
    """Build one Gaussian-reparameterised MLP prior (69 inputs, 1 output)."""
    return GaussianMLPReparameterization(
        input_dim=69,
        output_dim=1,
        activation_fn=transfer_fn,
        hidden_dims=[width] * depth,
    )


# Two identically configured networks: `std_bnn` is the fixed-prior baseline,
# `opt_bnn` is the one the Wasserstein mapper optimises below.
# Both are constructed (in this order) before either is moved to the device.
std_bnn = _make_bnn_prior()
opt_bnn = _make_bnn_prior()
std_bnn, opt_bnn = std_bnn.to(device), opt_bnn.to(device)

In [None]:
util.set_seed(1)

# Build the sampler over the tuning set. The `[:]` reads copy both datasets
# into memory, so the sampler stays usable after the file is closed.
# The file is opened read-only explicitly: the default mode differs between
# h5py versions, and this cell never needs write access.
with h5py.File("data/adroit_pen/adroit_pen_tuning_set.hdf5", "r") as f:
    data_generator = DataSetSampler(f["obs"][:], f["aux_obs"][:])

In [None]:
# Number of prior-optimisation iterations. The same value is used further down
# to build the name of the final checkpoint ("it-<mapper_num_iters>.ckpt").
mapper_num_iters = 1000

In [None]:
# Initialize the Wasserstein optimizer
util.set_seed(1)
mapper_settings = dict(
    out_dir=OUT_DIR,
    input_dim=69,
    wasserstein_steps=(0, 1200),
    wasserstein_lr=0.001,
    n_data=512,
    n_gpu=1,
    gpu_gp=True,
)
# Maps the target prior (pen_prior) onto the BNN prior (opt_bnn), drawing
# inputs from data_generator.
mapper = MapperWasserstein(pen_prior, opt_bnn, data_generator, **mapper_settings)

# Start optimizing the prior
optimize_settings = dict(
    num_iters=mapper_num_iters,
    n_samples=1024,
    lr=0.08,
    save_ckpt_every=50,
    print_every=20,
    debug=True,
)
w_hist = mapper.optimize(**optimize_settings)

# Persist the returned history so the next cell can plot it from disk.
path = os.path.join(OUT_DIR, "wsr_values.log")
np.savetxt(path, w_hist, fmt="%.6e")

In [None]:
# Visualize progression of the prior optimization
wdist_file = os.path.join(OUT_DIR, "wsr_values.log")
# np.loadtxt returns a 0-d array for a one-line log; atleast_1d keeps the
# indexing below valid in that case.
wdist_vals = np.atleast_1d(np.loadtxt(wdist_file))

fig, ax = plt.subplots(figsize=(6, 3.5))
# Plot every 5th logged value. The indices come from the loaded log rather
# than from mapper_num_iters, so a log that is shorter (interrupted run) or
# longer than the configured iteration count no longer raises IndexError or
# gets truncated.
indices = np.arange(len(wdist_vals))[::5]
ax.plot(indices, wdist_vals[indices], "-ko", ms=4)
ax.set_ylabel(r"$W_1(p_{gp}, p_{nn})$")
ax.set_xlabel("Iteration")
plt.show()

In [None]:
# Load the optimized prior
util.set_seed(1)
ckpt_path = os.path.join(OUT_DIR, "ckpts", "it-{}.ckpt".format(mapper_num_iters))
# map_location remaps the stored tensors onto the device chosen for this
# session. Without it, a checkpoint written on a GPU cannot be loaded on a
# CPU-only machine.
opt_bnn.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Draw functions from the priors
n_plot = 4000
util.set_seed(8)
X, aux_X = data_generator.get(100)


def _to_numpy(sampled):
    """Detach a sampled tensor and return it as a squeezed NumPy array."""
    return sampled.detach().cpu().numpy().squeeze()


# Sampling order (GP, fixed BNN, optimised BNN) is kept so the random stream
# is consumed exactly as before.
gp_samples = _to_numpy(pen_prior.sample_functions(X, n_plot, aux_X))
std_bnn_samples = _to_numpy(std_bnn.sample_functions(X.float(), n_plot))
opt_bnn_samples = _to_numpy(opt_bnn.sample_functions(X.float(), n_plot))

# The x-axis is simply the index of each of the 100 drawn inputs.
seq = np.arange(100)

fig, axs = plt.subplots(1, 3, figsize=(14, 3))
panels = [
    (gp_samples, "GP Prior", {}),
    (std_bnn_samples, "BNN Prior (Fixed)", {"color": "xkcd:grass"}),
    (
        opt_bnn_samples,
        "BNN Prior (GP-induced)",
        {"color": "xkcd:yellowish orange"},
    ),
]
# One panel per prior, all on the same y-range so they can be compared.
for ax, (panel_samples, title, style) in zip(axs, panels):
    plot_samples(seq, panel_samples, ax=ax, n_keep=5, **style)
    ax.set_title(title)
    ax.set_ylim([-4, 4])

plt.tight_layout()
plt.show()

In [None]:
# SGHMC hyper-parameters, forwarded as keyword arguments to
# `sample_multi_chains` in the sampling cells below.
sampling_configs = dict(
    batch_size=256,  # mini-batch size
    num_samples=40,  # total number of samples for each chain
    n_discarded=10,  # number of initial samples discarded for each chain
    num_burn_in_steps=2000,  # number of burn-in steps
    keep_every=2000,  # thinning interval
    lr=0.01,  # step size
    num_chains=4,  # number of chains
    mdecay=0.01,  # momentum coefficient
    print_every_n_samples=5,
)

In [None]:
# Load the preference data set.
# NOTE(review): 0.8 is presumably the train fraction — confirm against
# util.load_pref_data.
pref_data_path = "data/adroit_pen/AdroitHandPen-v1_pref_b.hdf5"
X_train, y_train, X_test, _ = util.load_pref_data(pref_data_path, 0.8)

# Keep the first 69 entries along the fourth axis of the test inputs, then
# flatten them into rows of 69 features.
n_features = 69
X_test = X_test[:, :, :, :n_features]
X_test = X_test.reshape(-1, n_features)

In [None]:
# Fixed Gaussian prior with unit standard deviation (the baseline).
util.set_seed(1)
prior = FixedGaussianPrior(std=1.0)

# Network and likelihood; the MLP reuses the architecture settings
# (width, depth, activation) defined with the BNN priors.
hidden_dims = [width] * depth
net = MLP(69, 1, hidden_dims, transfer_fn)
likelihood = LikCE()

# Sampler outputs for this run are kept in their own folder.
saved_dir = os.path.join(OUT_DIR, "sampling_std")
util.ensure_dir(saved_dir)
# NOTE(review): n_gpu=0 presumably keeps sampling on the CPU — confirm.
bayes_net_std = PrefNet(net, likelihood, prior, saved_dir, n_gpu=0)

# Run the chains on the training preferences.
bayes_net_std.sample_multi_chains(X_train, y_train, **sampling_configs)

In [None]:
# Predict on the test inputs; only the third returned item is used.
util.set_seed(1)
std_outputs = bayes_net_std.predict(X_test, True)
_, _, bnn_std_preds = std_outputs

# Convergence diagnostics: R-hat statistic across the chains.
n_chains = sampling_configs["num_chains"]
r_hat = compute_rhat_regression(bnn_std_preds, n_chains)
rhat_mean, rhat_std = float(r_hat.mean()), float(r_hat.std())
print(r"R-hat: mean {:.4f} std {:.4f}".format(rhat_mean, rhat_std))

# Squeeze and transpose before storing; the posterior plot at the end of the
# notebook slices this array along its first axis.
bnn_std_preds = bnn_std_preds.squeeze().T

# Save the predictions
posterior_std_path = os.path.join(OUT_DIR, "posterior_std.npz")
np.savez(posterior_std_path, bnn_samples=bnn_std_preds)

In [None]:
# Reset the seed exactly as the fixed-prior sampling cell does. Without it the
# network initialisation and the chains depend on whatever random state the
# previously executed cells left behind, so re-running this cell on its own
# gives different results and the two runs do not start from the same seed.
util.set_seed(1)

# Load the optimized prior
ckpt_path = os.path.join(OUT_DIR, "ckpts", "it-{}.ckpt".format(mapper_num_iters))
prior = OptimGaussianPrior(ckpt_path)

# Setup likelihood
net = MLP(69, 1, [width] * depth, transfer_fn)
likelihood = LikCE()

# Initialize the sampler; its outputs go to a dedicated folder.
saved_dir = os.path.join(OUT_DIR, "sampling_optim")
util.ensure_dir(saved_dir)
bayes_net_optim = PrefNet(net, likelihood, prior, saved_dir, n_gpu=0)

# Start sampling
bayes_net_optim.sample_multi_chains(X_train, y_train, **sampling_configs)

In [None]:
# Predict on the test inputs; only the third returned item is used.
util.set_seed(1)
optim_outputs = bayes_net_optim.predict(X_test, True)
_, _, bnn_optim_preds = optim_outputs

# Convergence diagnostics: R-hat statistic across the chains.
n_chains = sampling_configs["num_chains"]
r_hat = compute_rhat_regression(bnn_optim_preds, n_chains)
rhat_mean, rhat_std = float(r_hat.mean()), float(r_hat.std())
print(r"R-hat: mean {:.4f} std {:.4f}".format(rhat_mean, rhat_std))

# Squeeze and transpose before storing; the posterior plot at the end of the
# notebook slices this array along its first axis.
bnn_optim_preds = bnn_optim_preds.squeeze().T

# Save the predictions
posterior_optim_path = os.path.join(OUT_DIR, "posterior_optim.npz")
np.savez(posterior_optim_path, bnn_samples=bnn_optim_preds)

In [None]:
util.set_seed(8)
fig, axs = plt.subplots(1, 2, figsize=(14, 3))

# Rows 100..199 of each prediction array are plotted against `seq`, the
# 100-element index axis defined in the prior-sampling cell.
posterior_panels = [
    (bnn_std_preds, "BNN Posterior (Fixed)", "xkcd:grass"),
    (bnn_optim_preds, "BNN Posterior (GP-induced)", "xkcd:yellowish orange"),
]
# Same y-range on both panels so the two posteriors can be compared directly.
for ax, (preds, title, color) in zip(axs, posterior_panels):
    plot_samples(seq, preds[100:200], ax=ax, color=color, n_keep=16)
    ax.set_title(title)
    ax.set_ylim([-4, 4])

plt.tight_layout()
plt.show()