# Didactic introduction: active learning of a log-pdf with a Gaussian process

In [None]:
import shutil
from pathlib import Path

import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
from scipy.special import logsumexp
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

# Input parameters
# Plotting range [xmin, xmax] and number of grid points for the true log-pdf
xmin, xmax = -2, 2
nx = 100
# Target density: Gaussian mixture, one entry per component in each list
weights, means, stds = [0.7, 0.3], [-1, 1.2], [0.25, 0.35]
# Mismatched lengths would be silently truncated by zip() when the mixture is built
assert len(weights) == len(means) == len(stds), "weights, means and stds must have equal length"
assert np.isclose(sum(weights), 1), "mixture weights must sum to 1"
# Initial training points, hand-picked for a clear illustration
# (found by trying uniform random draws in [xmin, xmax])
x_initial = [-1.3, 0.65, 1.80]
# Initial length scale of the GP's RBF kernel
corr_length_initial = 0.5

# Plot parameters
truth_color = "0.25"      # grey level of the true log-pdf curve
truth_ls = "--"
samples_marker = "o"      # points where the truth was evaluated
samples_color = "k"
gpr_color = "tab:blue"    # GP mean and confidence band
gpr_alpha = 0.25
acq_color = "tab:orange"  # acquisition function and proposed points
prop_ls = ":"             # vertical line marking the proposed next evaluation
prop_lw = 2

In [None]:
# Figure settings

# LaTeX text rendering needs a `latex` executable, plus `dvipng` for raster (PNG) output.
# Fall back to Matplotlib's built-in mathtext when either is missing, so that
# Restart & Run All still works on machines without a TeX installation.
usetex_available = all(shutil.which(tool) for tool in ("latex", "dvipng"))

params = {'axes.labelsize': 14,
          'axes.titlesize': 22,
          'font.size': 16,
          'legend.fontsize': 14,
          'font.family': 'serif',
          'font.sans-serif': ['Bitstream Vera Sans'],
          # Full family name, as spelled in Matplotlib's default serif list
          'font.serif': ['Bitstream Vera Serif'],
          'xtick.labelsize': 16,
          'ytick.labelsize': 16,
          'text.usetex': usetex_available,
          # Only used when text.usetex is True
          'text.latex.preamble': r"""\usepackage{amsmath} \usepackage{amssymb} \usepackage{amsfonts}""",
         }
plt.rcParams.update(params)

In [None]:
# Create truth: log-density of the target, evaluated on a regular grid for plotting
xs = np.linspace(xmin, xmax, nx)
if len(means) == 1:
    logpdf = st.norm(means[0], stds[0]).logpdf
else:
    def logpdf(x):
        """Log-density of the weighted Gaussian mixture at ``x`` (scalar or array-like).

        Components are combined with log-sum-exp rather than ``log(sum(pdf))`` so the
        result stays finite far in the tails, where every component pdf underflows to 0.
        """
        log_terms = [np.log(weight) + st.norm(mean, std).logpdf(x)
                     for weight, mean, std in zip(weights, means, stds)]
        return logsumexp(log_terms, axis=0)

# Create GP model
kernel = RBF(length_scale=corr_length_initial, length_scale_bounds=(1e-1, 10))
# The optimizer restarts start from random length scales; fixing the seed makes the
# fitted hyperparameters, and therefore the saved figure, reproducible.
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9, random_state=0)
y_initial = logpdf(x_initial)
gpr.fit(np.atleast_2d(x_initial).T, y_initial)


# Create acquisition function
def logacq(x, mean, std):
    """Log of the acquisition function, log(a) = 2 * mean + log(std).

    ``mean`` and ``std`` are the GP predictive mean and standard deviation at ``x``.
    ``x`` itself is not used (kept for the call signature). Where ``std == 0`` the
    result is -inf; the divide-by-zero warning is suppressed.
    """
    with np.errstate(divide="ignore"):
        return 2 * mean + np.log(std)


# Kriging believer: set to an int n to replace the truth evaluation of a new point
# by the GP prediction on iterations i with i % n == 0; False disables it.
use_kg = False

legend_labels = {}  # legend handle -> label, in drawing order

n_panels = 4  # number of active-learning iterations shown, one column each
fig, axes = plt.subplots(nrows=2, ncols=n_panels, sharex=True, sharey="row",
                         gridspec_kw={"height_ratios": [2, 1]}, figsize=(12, 4),
                         layout="constrained"
                        )
for ax in axes[0]:
    legsymbol_truth = ax.plot(xs, logpdf(xs), c=truth_color, ls=truth_ls)
legend_labels[legsymbol_truth[0]] = "True log-pdf"


def _plot_iteration(ax_gp, ax_acq, x_train, y_train):
    """Draw one active-learning iteration from the current fit of ``gpr``.

    Top axis (``ax_gp``): training points, GP mean and 95% band, proposed next point.
    Bottom axis (``ax_acq``): log-acquisition on the grid ``xs`` and the proposal.

    Returns
    -------
    x_prop : float
        Grid point of ``xs`` that maximizes the acquisition function.
    handles : dict
        Legend label -> artist, used to build the figure legend.
    """
    mean, std = gpr.predict(np.atleast_2d(xs).T, return_std=True)
    log_a = logacq(xs, mean, std)
    x_prop = xs[np.argmax(log_a)]
    # Artists are created in the same order as in each original panel (stacking order).
    handles = {
        "Truth evals.": ax_gp.scatter(x_train, y_train, c=samples_color, marker=samples_marker),
        "GPR mean": ax_gp.plot(xs, mean, c=gpr_color)[0],
        r"GPR $95\%$ c.i.": ax_gp.fill_between(xs, mean - 1.96 * std, mean + 1.96 * std,
                                               alpha=gpr_alpha, color=gpr_color),
        "Acquisition func.": ax_acq.plot(xs, log_a, c=acq_color)[0],
        "Proposed next eval.": ax_gp.axvline(x_prop, c=acq_color, ls=prop_ls, lw=prop_lw),
    }
    ax_acq.axvline(x_prop, c=acq_color, ls=prop_ls, lw=prop_lw)
    return x_prop, handles


# The GP has already been fitted to the initial points
x_train, y_train = list(x_initial), np.asarray(y_initial)
for i in range(n_panels):
    x_prop, handles = _plot_iteration(axes[0][i], axes[1][i], x_train, y_train)
    if i == 0:
        legend_labels.update({handle: label for label, handle in handles.items()})
    if i == n_panels - 1:
        break  # the last panel only shows the next proposal; nothing left to fit
    # Add the proposal, evaluate (or believe) it, and refit
    x_train = x_train + [x_prop]
    if use_kg and i % use_kg == 0:
        # Kriging believer: keep existing values, impute only the new point with the GP mean
        y_train = np.append(y_train, gpr.predict([[x_prop]]))
    else:
        y_train = logpdf(x_train)
    gpr.fit(np.atleast_2d(x_train).T, y_train)

# Adjustments
axes[0][0].set_ylim(-10, 3)
axes[1][0].set_ylim(-20, None)
for ax in axes[1]:
    ax.set_xlabel("x")
axes[0][0].set_ylabel(r"$\log(p)$")
axes[1][0].set_ylabel(r"$\log(a)$")
fig.legend(list(legend_labels), list(legend_labels.values()), ncols=len(legend_labels),
           loc="upper center", bbox_to_anchor=(0, 1, 1, 0.125))
output_path = Path("images") / "active_learning.png"
output_path.parent.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
fig.savefig(output_path, dpi=200, bbox_inches="tight")