In [None]:
from baynes import *
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import cmdstanpy
import logging
cmdstanpy.utils.get_logger().setLevel(logging.ERROR)

# Gaussian Processes

In [None]:
def cov_exp_quad(x1, x2, alpha, rho):
    """Exponentiated-quadratic (squared-exponential) covariance matrix.

    K[i, j] = alpha**2 * exp(-(x1[i] - x2[j])**2 / (2 * rho**2))

    Parameters
    ----------
    x1, x2 : array_like
        1-D input locations; scalars are treated as length-1 arrays.
    alpha : float
        Output scale; the diagonal value K(x, x) equals alpha**2.
    rho : float
        Length scale.

    Returns
    -------
    numpy.ndarray
        Matrix of shape (len(x1), len(x2)).
    """
    x1 = np.atleast_1d(np.asarray(x1, dtype=float))
    # Previously this read np.asarray(x1), silently discarding x2.
    x2 = np.atleast_1d(np.asarray(x2, dtype=float))
    # Outer subtraction gives all pairwise differences: diff[i, j] = x1[i] - x2[j].
    sq_dist = np.subtract.outer(x1, x2) ** 2
    return alpha**2 * np.exp(-sq_dist / (2 * rho**2))

def cov_periodic(x1, x2, alpha, rho, p):
    """Periodic covariance matrix.

    K[i, j] = alpha**2 * exp(-2 * sin(pi * |x1[i] - x2[j]| / p)**2 / rho**2)

    Parameters
    ----------
    x1, x2 : array_like
        1-D input locations; scalars are treated as length-1 arrays.
    alpha : float
        Output scale; the diagonal value K(x, x) equals alpha**2.
    rho : float
        Length scale.
    p : float
        Period: inputs whose distance is a multiple of p have K = alpha**2.

    Returns
    -------
    numpy.ndarray
        Matrix of shape (len(x1), len(x2)).
    """
    x1 = np.atleast_1d(np.asarray(x1, dtype=float))
    # Previously this read np.asarray(x1), silently discarding x2.
    x2 = np.atleast_1d(np.asarray(x2, dtype=float))
    # Pairwise absolute distances: dist[i, j] = |x1[i] - x2[j]|.
    dist = np.abs(np.subtract.outer(x1, x2))
    return alpha**2 * np.exp(-2 * np.sin(np.pi * dist / p)**2 / rho**2)

## Generate data

In [None]:
# Fixed seed so that Restart & Run All regenerates the same synthetic data.
SEED = 42
rng = np.random.default_rng(SEED)

N2 = 31                        # number of grid points
x = np.linspace(0, 10, N2)     # full input grid (later passed to Stan as x2)
N1 = 11                        # number of observed points
# Pick the observed inputs from the grid *without* replacement so no x value
# appears twice in the training set.
x1 = np.sort(rng.choice(x, N1, replace=False))
# Observations: linear trend + sinusoid + Gaussian noise (sd 0.1).
y1 = 0.3 * x1 + rng.normal(0, 0.1, N1) + 0.5 * np.sin(x1)

fig, ax = plt.subplots()
ax.scatter(x1, y1)
ax.set(xlabel='x', ylabel='y');

## Exponentiated quadratic kernel

In [None]:
# Stan data: the N1 observed points (x1, y1) plus the full N2-point grid x as x2.
data = dict(N1=N1, N2=N2, x1=x1, x2=x, y1=y1)
quadratic_model = get_model('GP_quadratic.stan')
fit = quadratic_model.sample(
    data,
    chains=8,
    iter_warmup=500,
    iter_sampling=500,
    save_warmup=True,
    adapt_delta=0.9,
    inits=1,
)
print(fit.diagnose())


In [None]:
fplot = FitPlotter(fit)
# Long format: one row per (grid point x, posterior draw) of y2. Keep only the
# y2 columns -- depending on the cmdstanpy version, draws_pd may also return
# chain__/iter__/draw__ bookkeeping columns, which would make the row count
# differ from len(x) and break the df['x'] = x assignment.
df = fit.draws_pd(['y2']).filter(regex=r'^y2[\[.]').transpose()
df['x'] = x
df = df.melt(id_vars='x')
# Explicit figure/axes instead of relying on pyplot's implicit current axes.
fig, ax = plt.subplots()
# seaborn.lineplot aggregates with the mean by default; request the median so
# the curve matches its 'GP median' label. hdi (from baynes) draws the band.
sns.lineplot(data=df, x='x', y='value', estimator='median', errorbar=hdi,
             ax=ax, label='GP median')
ax.set_ylabel('y')
sns.scatterplot(x=x1, y=y1, ax=ax, label='data')
fplot.new_figure('GP', fig)
fplot.update_legend(edgecolor='white', bbox_to_anchor=(0.6, 0.85))


In [None]:
# Posterior means of the kernel hyperparameters.
means = fit.draws_pd(['alpha', 'rho']).mean(axis=0)
z = cov_exp_quad(x, x, means['alpha'], means['rho'])
f = fplot.new_figure('kernel')
ax = f.subplots()
# Draw on the explicit axes: plt.pcolor would target pyplot's *current* axes,
# which is not guaranteed to be `ax` (matches the periodic-kernel cell).
c = ax.pcolor(x, x, z)
c.set_edgecolor('face')  # draw cell edges in the face colour
f.colorbar(c)
ax.set_xlabel('x')
ax.set_ylabel(r"x'");

## Periodic kernel

In [None]:
model = get_model('GP_periodic.stan')
fit = model.sample(data,
                   chains=4,
                   iter_warmup=500,
                   iter_sampling=500,
                   save_warmup=True,
                   show_progress=True)
print(fit.diagnose())
fplot.add_fit(fit)
# Long format: one row per (grid point x, posterior draw) of y2. Keep only the
# y2 columns -- depending on the cmdstanpy version, draws_pd may also return
# chain__/iter__/draw__ bookkeeping columns, which would make the row count
# differ from len(x) and break the df['x'] = x assignment.
df = fit.draws_pd(['y2']).filter(regex=r'^y2[\[.]').transpose()
df['x'] = x
df = df.melt(id_vars='x')
# Explicit figure/axes instead of relying on pyplot's implicit current axes.
fig, ax = plt.subplots()
# seaborn.lineplot aggregates with the mean by default; request the median so
# the curve matches its 'GP median' label. hdi (from baynes) draws the band.
sns.lineplot(data=df, x='x', y='value', estimator='median', errorbar=hdi,
             ax=ax, label='GP median')
ax.set_ylabel('y')
sns.scatterplot(x=x1, y=y1, ax=ax, label='data')
fplot.new_figure('GP', fig)
fplot.update_legend(edgecolor='white', bbox_to_anchor=(1.2, 0.65))

In [None]:
f = fplot.new_figure('kernel')
ax = f.subplots()
# Evaluate the periodic kernel at the posterior-mean hyperparameters.
posterior_mean = fit.summary()['Mean']
alpha = posterior_mean['alpha']
rho = posterior_mean['rho']
p = posterior_mean['p']
z = cov_periodic(x, x, alpha, rho, p)
# Heat map of K(x, x') over the grid, with a colour bar.
c = ax.pcolor(x, x, z)
c.set_edgecolor('face')
f.colorbar(c)
ax.set_xlabel('x')
ax.set_ylabel(r"x'")