In [None]:
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pyro
import torch
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import os
# Threading and RNG configuration for the whole notebook.
# NOTE(review): torch is already imported above, so this OMP_NUM_THREADS value presumably
# only reaches processes started later (e.g. MCMC chain workers created with
# mp_context='spawn'); this process uses the explicit torch.set_num_threads(4) — confirm intent.
os.environ.update({"OMP_NUM_THREADS": "1"})
torch.set_num_threads(4)
pyro.set_rng_seed(9)  # global seed for pyro/torch sampling

In [None]:
def f(x):
    """Objective used throughout the notebook: sin(20x) + 2cos(14x) - 2sin(6x).

    Works element-wise on a torch tensor of any shape and returns a tensor of
    the same shape.
    """
    fast_term = torch.sin(20 * x)
    mid_term = 2 * torch.cos(14 * x)
    slow_term = 2 * torch.sin(6 * x)
    return fast_term + mid_term - slow_term

In [None]:
# Define the Gaussian process: five equally spaced training points on [-1, 1],
# an RBF kernel, and (near-)noiseless observations.
X = torch.tensor([-1, -.5, 0, .5, 1])
y = f(X)
kernel = gp.kernels.RBF(input_dim=1)
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(10**-4))

# Put LogNormal hyperpriors on the kernel hyperparameters, which makes them
# latent sample sites that NUTS can sample below.
kernel_hyperpriors = {
    "variance": dist.LogNormal(torch.tensor(0.), torch.tensor(2.0)),
    "lengthscale": dist.LogNormal(torch.tensor(-1.0), torch.tensor(1.0)),
}
for hp_name, hp_prior in kernel_hyperpriors.items():
    setattr(gpr.kernel, hp_name, pyro.nn.PyroSample(hp_prior))

In [None]:
# Sample the posterior over the kernel hyperparameters with NUTS, using several
# chains run in spawned worker processes.
# NOTE(review): 10 warmup steps is very short for NUTS step-size/mass adaptation;
# consider increasing it if the QC diagnostics below look poor.
num_chains = 4
hmc_kernel = NUTS(gpr.model)
mcmc = MCMC(
    hmc_kernel,
    num_samples=500,
    warmup_steps=10,
    num_chains=num_chains,
    mp_context='spawn',
)
mcmc.run()
# Keep the per-chain dimension so ArviZ can compute between-chain diagnostics.
samples = mcmc.get_samples(group_by_chain=True)

In [None]:
# MCMC quality control: ArviZ summary table, then three diagnostic figures saved to disk.
print(az.summary(samples))

# Trace plots, shown on a log scale (the hyperparameters are positive, LogNormal-distributed).
ax1 = az.plot_trace(samples, transform=np.log, figsize=(12, 8))
plt.savefig('MCMC_QC1.png')

# Autocorrelation with chains combined, up to lag 20.
az.plot_autocorr(samples, combined=True, max_lag=20)
plt.savefig('MCMC_QC2.png')

# Local effective sample size.
az.plot_ess(samples, kind='local')
plt.savefig('MCMC_QC3.png')

In [None]:
# B.1.3: compute the predictive mean and variance at each test point for many
# hyperparameter draws. The next cell combines them into posterior-averaged moments.
N_POST = 500  # number of hyperparameter draws taken from the pooled MCMC chains
Xtest = torch.linspace(-1.0, 1., 200)
subsample = mcmc.get_samples(N_POST)
mean_list = []
var_list = []
# The predictive is only used for averaging and plotting. Running it without autograd
# keeps the results free of requires_grad, so .numpy() works on them later.
with torch.no_grad():
    for post_samp in range(N_POST):
        pyro.clear_param_store()
        kernel = gp.kernels.RBF(input_dim=1)
        kernel.variance = subsample['kernel.variance'][post_samp]
        kernel.lengthscale = subsample['kernel.lengthscale'][post_samp]
        gpr_post = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(0.0001))
        # full_cov=False gives the per-point variance; noiseless=False includes observation noise.
        post_mean, post_var = gpr_post(Xtest, full_cov=False, noiseless=False)
        mean_list.append(post_mean)
        var_list.append(post_var)

In [None]:
# Law of total variance over the hyperparameter draws s:
#   E[f(x)]   = avg_s m_s(x)
#   Var[f(x)] = avg_s (v_s(x) + m_s(x)^2) - E[f(x)]^2
mean = sum(mean_list) / len(mean_list)
sum_of_squared_means = sum(m ** 2 for m in mean_list)
var = (sum(var_list) + sum_of_squared_means) / len(var_list) - mean ** 2

In [None]:
# B.1.(1/3): panel A compares prior and posterior kernel hyperparameters;
# panel B shows the posterior-averaged GP fit against the true objective.
fig, (ax1, ax2) = plt.subplots(figsize=(10, 4), ncols=2)
fig.tight_layout(pad=4)

# Panel A: draws from the same LogNormal hyperpriors used in the model vs. MCMC draws.
ax1.set_title('A', loc='left')
priorlen = dist.LogNormal(-1, 1).sample(sample_shape=(500,))
priorvar = dist.LogNormal(0, 2).sample(sample_shape=(500,))
ax1.scatter(priorvar, priorlen, s=7, label='prior')
ax1.scatter(subsample['kernel.variance'],
            subsample['kernel.lengthscale'], s=7, label='posterior')
ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.set_xlabel('kernel variance')
ax1.set_ylabel('kernel lengthscale')
ax1.legend()

# Panel B: data, posterior mean, two-sigma band and the true function.
ax2.set_title('B', loc='left')
with torch.no_grad():
    # detach(): .numpy() refuses tensors that require grad, even inside no_grad().
    mean_plot = mean.detach()
    sd = var.detach().sqrt()  # standard deviation at each input point x
    ax2.plot(X.numpy(), y.numpy(), "kx", label=r'$\mathcal{D}$')
    ax2.plot(Xtest.numpy(), mean_plot.numpy(), "r", lw=2, label='m(x)')  # plot the mean
    ax2.fill_between(
        Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
        (mean_plot - 2.0 * sd).numpy(),
        (mean_plot + 2.0 * sd).numpy(),
        color="C0",
        alpha=0.3,
        label=r'2$\sigma$ CI',
    )
ax2.plot(Xtest, f(Xtest), "b", lw=2, label='f(x)')
ax2.set_xlabel('x')
ax2.set_ylabel('y')
ax2.legend(loc='lower right')

plt.savefig('GP.png')
plt.show()

In [None]:
# Functions for running one step of Bayesian optimisation and plotting it.
def bayesian_opt(X, y, loss_method, kappa=2, max_attempts=10):
    """Run one Bayesian-optimisation step on the global objective ``f``.

    Kernel hyperparameters are sampled with NUTS (one chain, 100 samples). If
    the R-hat reported by ``mcmc.diagnostics()`` exceeds 1.05 for either
    hyperparameter, sampling is retried with a new seed. A single posterior
    hyperparameter draw then defines the GP predictive on the global grid
    ``Xtest``, and the next query point is picked from it.

    Side effects: increments the global ``seed`` and reseeds pyro with it
    (once per failed attempt and once at the end).

    Args:
        X: 1-D tensor of observed inputs.
        y: 1-D tensor of observed values ``f(X)``.
        loss_method: ``'fmin'`` picks argmin of a sample f* from the predictive
            (Thompson sampling); ``'LCM'`` picks argmin of ``mean - kappa * sd``.
        kappa: Lower-confidence-bound width multiplier (``'LCM'`` only).
        max_attempts: Maximum number of MCMC runs before giving up.

    Returns:
        Tuple ``(X, y, mean, cov, Xstar, fstar)``:
        - ``X``, ``y``: the data extended with the new point.
        - ``mean``, ``cov``: the predictive on ``Xtest`` (``cov`` includes 1e-3 jitter).
        - ``Xstar``: the chosen input.
        - ``fstar``: the predictive sample.

    Raises:
        Exception: if ``loss_method`` is unknown, or if no run converges within
            ``max_attempts`` attempts.
    """
    global seed
    # Reject a bad loss method before spending time on MCMC.
    if loss_method not in ('fmin', 'LCM'):
        raise Exception('Loss method must be one of "fmin" or "LCM".')

    # Estimate theta | X, y; retry with a fresh seed until both hyperparameters converge.
    for _ in range(max_attempts):
        pyro.clear_param_store()
        kernel = gp.kernels.RBF(input_dim=1)
        kernel.variance = pyro.nn.PyroSample(dist.LogNormal(torch.tensor(0.), torch.tensor(2.0)))
        kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(torch.tensor(-1.0), torch.tensor(1.0)))
        gpr_opt = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(10**-4))
        hmc_kernel = NUTS(gpr_opt.model)
        mcmc = MCMC(hmc_kernel, num_samples=100, num_chains=1,
                    mp_context='spawn', warmup_steps=10, disable_progbar=True)
        mcmc.run()
        diagn = mcmc.diagnostics()
        if all(diagn[site]['r_hat'] <= 1.05
               for site in ('kernel.variance', 'kernel.lengthscale')):
            break
        seed += 1
        pyro.set_rng_seed(seed)
    else:
        raise Exception(f'No convergence after {max_attempts} attempts.')

    # Plug one posterior draw of theta into a fresh GP to get p(f* | X*, theta).
    posterior = mcmc.get_samples(1)
    pyro.clear_param_store()
    kernel = gp.kernels.RBF(input_dim=1)
    kernel.variance = posterior['kernel.variance']
    kernel.lengthscale = posterior['kernel.lengthscale']
    gpr_opt = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(10**-4))
    # The predictive is only used for selection and plotting. Without autograd the
    # returned tensors can be converted with .numpy() directly.
    with torch.no_grad():
        mean, cov = gpr_opt(Xtest, full_cov=True, noiseless=False)
    sd = cov.diag().sqrt()  # taken before the jitter below is added
    cov = cov + 1e-3 * torch.eye(len(Xtest))  # jitter for numerical stability

    # Sample f* (also drawn for 'LCM' so the returned tuple always includes it).
    fstar = dist.MultivariateNormal(mean, covariance_matrix=cov).sample()
    if loss_method == 'fmin':
        Xstar = Xtest[torch.argmin(fstar)]
    else:  # 'LCM': lower confidence bound
        LCB = mean - kappa * sd
        Xstar = Xtest[torch.argmin(LCB)]

    X = torch.cat((X, Xstar.reshape(1)))
    y = f(X)
    seed += 1
    pyro.set_rng_seed(seed)
    return X, y, mean, cov, Xstar, fstar

def bay_opt_plot(X, y, mean, cov, Xstar, fstar, timestep, ax, legend=None):
    """Draw one Bayesian-optimisation iteration on the matplotlib axes ``ax``.

    Plots the following against the global grid ``Xtest``:
    - the predictive sample ``fstar``;
    - the true objective ``f``;
    - the observed data ``(X, y)``;
    - the predictive mean with a two-sigma band derived from ``diag(cov)``;
    - the chosen point ``(Xstar, f(Xstar))``.

    Args:
        X, y: observed inputs/values (tensors).
        mean, cov: predictive mean and covariance on ``Xtest``.
        Xstar: selected input (0-d tensor).
        fstar: predictive sample on ``Xtest``.
        timestep: iteration number shown in the title.
        ax: axes to draw on.
        legend: legend location string, or ``None`` for no legend.
    """
    # detach(): .numpy() refuses tensors that require grad, even inside torch.no_grad().
    mean = mean.detach()
    sd = cov.detach().diag().sqrt()  # standard deviation at each input point x
    ax.plot(Xtest.numpy(), fstar.detach().numpy().T, lw=2, alpha=0.4, label='f*(x)')
    ax.plot(Xtest.numpy(), f(Xtest), lw=2, alpha=0.4, label='f(x)')
    ax.plot(X.numpy(), y.numpy(), "kx", label=r'$\mathcal{D}$')
    ax.plot(Xtest.numpy(), mean.numpy(), "r", lw=2, label='m(x)')  # plot the mean
    ax.fill_between(
        Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
        (mean - 2.0 * sd).numpy(),
        (mean + 2.0 * sd).numpy(),
        color="C0",
        alpha=0.3,
        label=r'2$\sigma$ CI',
    )
    ax.plot(Xstar.numpy(), f(Xstar).numpy(), "bo", markersize=10,
            label='$(x^*_p, f(x^*_p))$')
    if legend is not None:
        ax.legend(loc=legend, prop={'size': 6})
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f"Iteration: {timestep}", loc='left')

In [None]:
# B.2.1 progression plot: run 11 BO iterations with Thompson sampling ('fmin')
# and show iterations 0, 5 and 10 side by side.
seed = 1
Xtest = torch.linspace(-1., 1, 200)
loss_method = 'fmin'
fig, ax = plt.subplots(figsize=(8, 3), ncols=3)
X = torch.tensor([-1, -.5, 0, .5, 1])
y = f(X)
snapshot_every = 5
for i in range(11):
    X, y, mean, cov, Xstar, fstar = bayesian_opt(X, y, loss_method, kappa=2)
    pyro.set_rng_seed(seed)
    if i % snapshot_every:
        continue  # only every 5th iteration gets a panel
    panel = i // snapshot_every
    with torch.no_grad():
        bay_opt_plot(X, y, mean, cov, Xstar, fstar, timestep=i, ax=ax[panel],
                     legend='lower left' if panel == 0 else None)
plt.savefig('bayopt.png')

In [None]:
# Ten independent BO runs with Thompson sampling ('fmin'). Each run has 11 iterations,
# and only the final iteration is plotted. `seed` is reset only once here, and
# bayesian_opt keeps incrementing it, so every run sees different random draws.
os.makedirs('PML_finals', exist_ok=True)  # savefig does not create missing directories
seed = 1
Xtest = torch.linspace(-1., 1, 200)
loss_method = 'fmin'
n_iter = 11
for j in range(10):
    fig, ax = plt.subplots(figsize=(6, 4), nrows=1)
    X = torch.tensor([-1, -.5, 0, .5, 1])
    y = f(X)
    for i in range(n_iter):
        X, y, mean, cov, Xstar, fstar = bayesian_opt(X, y, loss_method, kappa=2)
        pyro.set_rng_seed(seed)
    with torch.no_grad():
        bay_opt_plot(X, y, mean, cov, Xstar, fstar, timestep=n_iter - 1, ax=ax,
                     legend='lower left')
    plt.savefig(f'PML_finals/bayopt{j}.png')
    plt.show()

In [None]:
# Ten independent BO runs with the lower-confidence-bound rule ('LCM', kappa=2).
# Each run has 11 iterations, and only the final iteration is plotted. `seed` is
# reset only once here, and bayesian_opt keeps incrementing it across runs.
os.makedirs('PML_finals', exist_ok=True)  # savefig does not create missing directories
seed = 1
Xtest = torch.linspace(-1., 1, 200)
loss_method = 'LCM'
n_iter = 11
for j in range(10):
    fig, ax = plt.subplots(figsize=(6, 4), nrows=1)
    X = torch.tensor([-1, -.5, 0, .5, 1])
    y = f(X)
    for i in range(n_iter):
        X, y, mean, cov, Xstar, fstar = bayesian_opt(X, y, loss_method, kappa=2)
        pyro.set_rng_seed(seed)
    with torch.no_grad():
        bay_opt_plot(X, y, mean, cov, Xstar, fstar, timestep=n_iter - 1, ax=ax,
                     legend='lower left')
    plt.savefig(f'PML_finals/bayopt_LCM{j}.png')
    plt.show()