# 2D mixture toy demonstration of VSD

In [None]:
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as td
from botorch import fit_gpytorch_mll
from botorch.acquisition.analytic import LogProbabilityOfImprovement
from botorch.models import SingleTaskGP
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import GammaPrior
from scipy.stats import multinomial, multivariate_normal
from sklearn.metrics import balanced_accuracy_score, r2_score

from vsd.acquisition import (LogPIClassiferAcquisition,
                             VariationalSearchAcquisition)
from vsd.generation import generate_candidates_reinforce
from vsd.proposals import GaussianKDEProposal
from vsd.surrogates import ContinuousCPEModel, fit_cpe
from vsd.thresholds import AnnealedThreshold

# Seed both NumPy's and PyTorch's global generators from one constant
SEED = 42
np.random.seed(SEED)
torch.manual_seed(seed=SEED)

# Simulation properties

In [None]:
# Number of points sampled from the prior to form the initial training set
N_TRAINING = 50
# Passed as `gradient_samples` to generate_candidates_reinforce
GRADIENT_SAMPLES = 1024
# Passed as `candidate_samples` to generate_candidates_reinforce
N_CANDIDATES = 30
# Number of rounds run by the optimisation loop
N_ROUNDS = 10
USE_CLASSIFIER = True  # GP or CPE
# Scale and mean of the Normal prior over the 2D inputs
INITIAL_STD = 6
INITIAL_MEAN = [0, 0]
# Fitness threshold handed to the surrogate/acquisition as `best_f`
THRESH = 2.2
INCREASING_THRESH = False  # Switch between FL (false) and BBO (true)
# Arguments of AnnealedThreshold (only used when INCREASING_THRESH is True)
PERCENTILE = 0.5
TEMPERATURE = 0.7
# Passed as the classifier optimiser's `weight_decay`
CLASSIFIER_REG = 1e-4

## Simulate data

Simulate a "fitness landscape" from a mixture of Gaussians

In [None]:
# Component means of the Gaussian mixture, one row per component
mus = np.array([
    [-1, 3],
    [1, 2],
    [0, 4],
    [-3, -1.5],
    [-3, -3.5],
    [2, -2],
    [-1, 0],
])

# Equal mixture weights, normalised to sum to one
w = np.ones(7)
w /= w.sum()

# Every component has an identity covariance
covs = np.array([np.eye(2) for _ in mus])

# Frozen scipy distributions, one per mixture component
norms = np.array([
    multivariate_normal(mean=mean, cov=cov) for mean, cov in zip(mus, covs)
])

def mixture_px(X):
    """Evaluate the weighted mixture density at X, scaled by 100.

    Each component in `norms` is evaluated at X, the per-component densities
    are combined with the weights `w`, and the result is multiplied by 100.
    """
    component_pdfs = np.array([component.pdf(X) for component in norms])
    component_pdfs = component_pdfs.squeeze()
    mixture = w @ component_pdfs
    return mixture * 100

def mixture_sample(size=1):
    """Draw `size` samples from the mixture defined by `w` and `norms`.

    A one-hot component indicator is drawn per sample from a multinomial with
    probabilities `w`; each sample is then drawn from the selected component.
    """
    one_hot = multinomial(n=1, p=w).rvs(size=size).astype(bool)
    # Boolean-mask `norms` with each one-hot row to pick that row's component
    draws = [norms[one_hot[k]][0].rvs() for k in range(size)]
    return np.array(draws)

In [None]:
# Isotropic Normal prior over the 2D inputs; the last batch dimension is
# reinterpreted as the event dimension
prior = td.Independent(
    td.Normal(
        loc=torch.tensor(INITIAL_MEAN, dtype=float),
        scale=torch.tensor(INITIAL_STD, dtype=float),
    ),
    reinterpreted_batch_ndims=1,
)

# Initial training inputs are prior samples, cast to single precision
X = prior.sample([N_TRAINING]).float()

# Square evaluation grid on [-gbound, gbound]^2 with ngrid points per axis
ngrid = 300
gbound = 10
mX, mY = np.meshgrid(
    np.linspace(-gbound, gbound, ngrid),
    np.linspace(-gbound, gbound, ngrid),
)
gX = np.stack([mX.ravel(), mY.ravel()]).T
gXT = torch.Tensor(gX)

# True landscape on the grid, and its maximum there
pX = mixture_px(gX)
pmax = pX.max()

def plot_contours(f, Xsamples=None, ax=None, title=None):
    """Draw labelled contours of grid values `f` on the (mX, mY) mesh.

    Args:
        f: values on the flattened grid; reshaped to (ngrid, ngrid).
        Xsamples: optional points whose first two columns are overlaid as
            red dots.
        ax: axes to draw on; a new figure is created when omitted.
        title: title passed to the axes.
    """
    if ax is None:
        ax = plt.subplots(dpi=150, figsize=(10, 8))[1]
    surface = f.reshape([ngrid, ngrid])
    contour_set = ax.contour(mX, mY, surface, levels=7)
    if Xsamples is not None:
        xs, ys = Xsamples[:, 0], Xsamples[:, 1]
        ax.plot(xs, ys, 'r.')
    ax.clabel(contour_set, inline=True, fontsize=8)
    ax.grid()
    ax.set_title(title)

# Contours of the true landscape on the grid, with the initial prior samples
# X overlaid as red dots
plot_contours(pX, X, title="True f")
plt.show()

In [None]:
# Finer-range grid on [-7, 7]^2 used only for the illustrative figures
nX, nY = np.meshgrid(
    np.linspace(-7, 7, ngrid),
    np.linspace(-7, 7, ngrid),
)
ngX = np.stack([nX.ravel(), nY.ravel()]).T

# Prior density and true landscape evaluated on that grid
prX = np.exp(prior.log_prob(torch.tensor(ngX)).detach().numpy())
pnX = mixture_px(ngX)

# Contour levels for the landscape, and the [threshold, max] band
nlevels = 20
gmax = pnX.argmax()
levels = np.linspace(0, pnX.max(), nlevels)
thresh = [THRESH, pnX.max()]

# Prior density where the landscape reaches THRESH; -1 marks everything else
pXgY = np.where(pnX < THRESH, -1, prX)
xlevels = np.linspace(0, pXgY.max(), nlevels)

# Figure 1: the prior density on the fine grid
_, ax = plt.subplots(dpi=150, figsize=(4, 4))
ax.contourf(nX, nY, prX.reshape([ngrid, ngrid]), levels=nlevels,
            cmap="Blues")
ax.set_axis_off()
plt.show()

# Figure 2: the true landscape, with a white cross at the grid point where it
# is largest
_, ax = plt.subplots(dpi=150, figsize=(4, 4))
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=levels, cmap="Greys")
ax.plot(*ngX[gmax], 'x', color="white", markersize=7, label="max")
ax.set_axis_off()
plt.show()

# Figure 3: the landscape with the [THRESH, max] band outlined in white and
# hatched on top
plt.rcParams['hatch.color'] = "white"
_, ax = plt.subplots(dpi=150, figsize=(4, 4))
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=levels, cmap="Greys")
ax.contour(nX, nY, pnX.reshape([ngrid, ngrid]), levels=thresh,
              colors="white")
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=thresh,
               colors="white", hatches=["///", None], alpha=0.5)
ax.set_axis_off()
plt.show()

# Figure 4: the landscape overlaid with the prior density restricted to where
# the landscape reaches THRESH (pXgY is -1 elsewhere, below the lowest level
# of 0, so presumably left unfilled)
# NOTE(review): the white outline is drawn before the filled contours, so it
# may be hidden beneath them -- confirm this ordering is intended
_, ax = plt.subplots(dpi=150, figsize=(4, 4))
ax.contour(nX, nY, pnX.reshape([ngrid, ngrid]), levels=thresh,
              colors="white")
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=levels,
               cmap="Greys")
ax.contourf(nX, nY, pXgY.reshape([ngrid, ngrid]), levels=xlevels,
               cmap="Blues")
ax.set_axis_off()
plt.show()

# Figure 5: the landscape with the levels at or above THRESH redrawn in green,
# using the full level range for the colour scale
_, ax = plt.subplots(dpi=150, figsize=(4, 4))
ax.contour(nX, nY, pnX.reshape([ngrid, ngrid]), levels=thresh,
              colors="white")
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=levels,
               cmap="Greys")
ax.contourf(nX, nY, pnX.reshape([ngrid, ngrid]), levels=levels[levels >= THRESH],
               cmap="Greens", vmin=levels[0], vmax=levels[-1])
ax.set_axis_off()
plt.show()

## Bayesian Optimisation loop

### Surrogate model initial training

In [None]:
# Training
Xt = X
Yt = torch.tensor(mixture_px(X)).float()

# Testing
Xs_np = mixture_sample(100)
Xs = torch.tensor(Xs_np).float()
Ys = torch.tensor(mixture_px(Xs_np)).float()


# Adaptive threshold?
if INCREASING_THRESH:
    thresh = AnnealedThreshold(PERCENTILE, TEMPERATURE)
    THRESH = thresh(Yt)
print(f"Initial threshold: {THRESH:.3f}")


# Model initialisation
if USE_CLASSIFIER:
    model = ContinuousCPEModel(x_dim=2, latent_dim=100, dropoutp=0)
    fit_cpe(
        model,
        Xt.float(), Yt.float(),
        best_f=THRESH,
        batch_size=len(Yt),
        optimizer_options=dict(weight_decay=CLASSIFIER_REG),
        stop_options=dict(n_window=2000)
    )

    for n, x, y in (("train", Xt, Yt), ("test", Xs, Ys)):
        z = (y > 0).detach().type(torch.int).flatten()
        ez = (model(x) > np.log(.5)).type(torch.int).detach().flatten()
        score = balanced_accuracy_score(ez, z)
        print(f"Clf {n} balanced acc. = {score:.3f}")
else:
    kernel = ScaleKernel(MaternKernel(nu=1.5, lengthscale_prior=GammaPrior(3, 1)))
    model = SingleTaskGP(train_X=Xt, train_Y=Yt.unsqueeze(-1),
                         covar_module=kernel)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    mll = fit_gpytorch_mll(mll, max_attempts=10)

    for n, x, y in (("train", Xt, Yt), ("test", Xs, Ys)):
        ey = np.asarray(model(x).loc.detach())
        score = r2_score(ey, y.flatten().numpy())
        print(f"GP {n} R^2 = {score:.3f}")

In [None]:
# Two side-by-side panels summarising the freshly fitted surrogate on the grid
_, axs = plt.subplots(1, 2, dpi=150, figsize=(10, 5))
if USE_CLASSIFIER:
    # The classifier output is exponentiated to get a probability; p * (1 - p)
    # is shown in the second panel
    logPY = model(gXT).detach().numpy()
    PY = np.exp(logPY)
    plot_contours(PY, X, ax=axs[0],
                  title="Initial classifier $p(y > \\tau | x)$")
    plot_contours(PY * (1 - PY), X, ax=axs[1],
                  title="Initial classifier variance")
else:
    # GP predictive mean and the square root of its predictive variance
    pred = model(gXT)
    epX = pred.mean.detach()
    spX = pred.variance.sqrt().detach()
    plot_contours(epX, X, ax=axs[0], title="Initial GP predictive mean")
    plot_contours(spX, X, ax=axs[1],
                  title="Initial GP predictive confidence")
plt.show()

In [None]:
# Proposal (search) distribution over the 2D inputs
prop = GaussianKDEProposal(2, k_components=10, scale=1., mu_scale_init=5)

# Expected log likelihood term: classifier-based or GP-based log-PI
acq = (
    LogPIClassiferAcquisition(model=model)
    if USE_CLASSIFIER
    else LogProbabilityOfImprovement(model=model, best_f=THRESH)
)

# ELBO: wraps the acquisition together with the prior
elbo = VariationalSearchAcquisition(acq, prior, kl_weight=1)

# Log buffers appended to by `callback` during optimisation
iter = []
acquis = []
normgrad = []

def callback(i: int, loss: torch.Tensor, grad: Tuple[torch.Tensor]):
    """Log one optimisation step into the module-level buffers.

    Appends the iteration index to `iter`, the negated loss to `acquis`, and
    the mean of the per-tensor gradient means to `normgrad`.
    """
    global iter, acquis, normgrad
    iter.append(i)
    acquis.append(-loss)
    per_tensor_means = [g.mean() for g in grad]
    normgrad.append(np.mean(per_tensor_means))

def plot_optimisation(iter, acquis, meangrad, ax=None):
    """Plot the mean gradient and the acquisition value against iteration.

    The mean gradient goes on the left y-axis and the acquisition, in red,
    on a twin right y-axis. A new figure is created when `ax` is omitted.
    """
    if ax is None:
        ax = plt.subplots(dpi=150)[1]
    acq_colour = "tab:red"

    # Left axis: mean gradient
    ax.plot(iter, meangrad)
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Mean Gradient")
    ax.grid()

    # Right axis: acquisition value
    twin = ax.twinx()
    twin.set_ylabel("Acquisition", color=acq_colour)
    twin.plot(iter, acquis, color=acq_colour)
    twin.tick_params(axis="y", labelcolor=acq_colour)


In [None]:
# Main loop: each round optimises the proposal, draws candidates, evaluates
# them on the true landscape, updates the surrogate, and plots diagnostics on
# a 3x2 figure.
for r in range(1, N_ROUNDS+1):
    print(f"----- Round {r} -----")

    # Optimise search distribution
    # Reset the log buffers that `callback` appends to
    iter, acquis, normgrad = [], [], []
    Xc, cand_acquis = generate_candidates_reinforce(
        proposal_distribution=prop,
        acquisition_function=elbo,
        callback=callback,
        stop_options=dict(maxiter=30000, n_window=5000),
        optimizer_options=dict(lr=1e-3),
        cv_smoothing=0.7,
        gradient_samples=GRADIENT_SAMPLES,
        candidate_samples=N_CANDIDATES
    )

    fig, axs = plt.subplots(3, 2, dpi=150, figsize=(10, 10))
    plot_optimisation(iter, acquis, normgrad, axs[0, 0])

    # Plot candidates on real density
    plot_contours(pX, Xc, title=f"R{r} q samples", ax=axs[0, 1])

    # Plot candidates on the proposal (q) density evaluated on the grid
    logqX = prop.log_prob(gXT).detach()
    qX = torch.exp(logqX)
    plot_contours(qX, Xc, title=f"R{r} q density", ax=axs[2, 0])

    # Update model and plot
    # Evaluate the candidates on the true landscape and append them to the
    # training set
    Yc = torch.Tensor(mixture_px(Xc))
    Xt = torch.concat((Xt, Xc))
    Yt = torch.concat((Yt, Yc))
    if INCREASING_THRESH:
        # Recompute the threshold from all targets observed so far
        THRESH = thresh(Yt)
        print(f"New threshold: {THRESH:.3f}")
    if USE_CLASSIFIER:
        # The plots use the acquisition as it was during this round's
        # optimisation, i.e. before the classifier is refitted below
        logPY = acq(gXT).detach().numpy()
        PY = np.exp(logPY)
        plot_contours(PY, Xc, title=f"R{r} classifier $p(y > \\tau | x)$",
                      ax=axs[1, 0])
        plot_contours(PY*(1-PY), Xc, title=f"R{r} classifier variance",
                      ax=axs[1, 1])
        # Refit the same model object on the enlarged training set
        # NOTE(review): unlike the initial fit, no stop_options are passed
        # here -- confirm the default stopping rule is intended
        fit_cpe(
            model,
            Xt.float(), Yt.float(),
            best_f=THRESH,
            batch_size=len(Yt),
            optimizer_options=dict(weight_decay=CLASSIFIER_REG)
        )

    else:
        # The plots use the GP and acquisition as they were before the new
        # observations are conditioned on below
        logPY = acq(gXT.unsqueeze(1)).detach().numpy()
        pred = model(gXT)
        spX = torch.sqrt(pred.variance).detach()
        plot_contours(np.exp(logPY), Xc, title=f"R{r} GP $p(y > \\tau | x)$",
                      ax=axs[1, 0])
        plot_contours(spX, Xc, title=f"R{r} GP predictive confidence", ax=axs[1, 1])
        # Condition the GP on the new candidates, then point the acquisition
        # held by the ELBO at the new model and the current threshold
        model = model.condition_on_observations(Xc, Yc.unsqueeze(-1))
        elbo.acq.model = model
        elbo.acq.best_f = torch.tensor(THRESH)


    # Plot candidates on acquisition function
    # Sum of the log acquisition computed above and the prior log-density on
    # the grid
    logPYX = logPY + prior.log_prob(gXT).detach().numpy()
    plot_contours(logPYX, Xc, title=f"R{r} $\\log p(y > \\tau, x)$",
                  ax=axs[2, 1])

    fig.tight_layout()
    plt.show()