In [None]:
# Reload imported modules automatically (mode 2 = all modules) so edits to the
# local easyBO package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off jedi-based tab completion (presumably a workaround for slow/broken
# jedi completion in this environment — TODO confirm it is still needed).
%config Completer.use_jedi = False

In [None]:
from IPython.display import clear_output

In [None]:
import math
import torch
import numpy as np
import gpytorch
from matplotlib import pyplot as plt

# Classification with the Dirichlet classification likelihood, BoTorch, and easyBO

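We adapt the GPyTorch tutorial [GP Regression on Classification Labels](https://docs.gpytorch.ai/en/stable/examples/01_Exact_GPs/GP_Regression_on_Classification_Labels.html), reproduce its model with easyBO, and then use the trained classifier to propose new points with BoTorch.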

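First, generate the data. Each point has two input features ($x$, $y$) and a single label taking one of three classes, $\{0, 1, 2\}$ (three-class classification). Labels are obtained by rounding the latent function $\sin(0.15\,\pi u\,(x + y)) + 1$, where $u \sim \mathcal{U}(0, 1)$ is drawn once per dataset.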

In [None]:
def gen_data(num_data, seed=2019):
    """Generate a synthetic 2-D, three-class classification dataset.

    Inputs are i.i.d. standard-normal points ``(x, y)``. Labels are obtained by
    rounding the latent function ``sin(0.15 * u * 3.1415 * (x + y)) + 1``, whose
    range is [0, 2], so every label lies in {0, 1, 2}. ``u`` is a single uniform
    draw shared by the whole dataset; it sets the frequency of the class bands.

    Note: this seeds torch's *global* RNG (``torch.random.manual_seed``) as a
    side effect.

    Args:
        num_data: Number of points to draw.
        seed: Seed passed to ``torch.random.manual_seed``.

    Returns:
        A tuple ``(inputs, labels, data_fn)``:
        - inputs: tensor of shape ``(num_data, 2)`` with columns ``(x, y)``.
        - labels: long tensor of shape ``(num_data,)``.
        - data_fn: the latent function ``data_fn(x, y)``. Round its output to
          get labels, e.g. on a test grid.
    """
    torch.random.manual_seed(seed)

    x = torch.randn(num_data, 1)
    y = torch.randn(num_data, 1)

    # One random frequency for the whole dataset; data_fn closes over it so that
    # labels computed later (e.g. on a test grid) come from the same function.
    u = torch.rand(1)
    # 3.1415 (not math.pi) is kept to match the constant of the original tutorial.
    data_fn = lambda x, y: 1 * torch.sin(0.15 * u * 3.1415 * (x + y)) + 1
    latent_fn = data_fn(x, y)
    # squeeze(-1) rather than squeeze(): with num_data == 1 the labels must still
    # have shape (1,), not collapse to a 0-d scalar tensor.
    z = torch.round(latent_fn).long().squeeze(-1)
    return torch.cat((x, y), dim=1), z, data_fn

In [None]:
# 500 labelled training points. seed=2019 is gen_data's default, spelled out so
# the source of reproducibility is visible here.
train_x, train_y, genfn = gen_data(num_data=500, seed=2019)

In [None]:
# Training inputs coloured by their class label.
fig_train, ax_train = plt.subplots()
train_scatter = ax_train.scatter(
    train_x[:, 0].numpy(), train_x[:, 1].numpy(), c=train_y.numpy()
)
# One legend entry per distinct label colour.
ax_train.legend(*train_scatter.legend_elements(), title="Class")
ax_train.set(title="Training data", xlabel="$x$", ylabel="$y$");

In [None]:
# Evaluation grid: GRID_SIZE x GRID_SIZE points covering [-3, 3] x [-3, 3].
GRID_SIZE = 20
test_d1 = np.linspace(-3, 3, GRID_SIZE)
test_d2 = np.linspace(-3, 3, GRID_SIZE)

test_x_mat, test_y_mat = np.meshgrid(test_d1, test_d2)
# torch.as_tensor with the default dtype replaces the legacy torch.Tensor(...)
# constructor. It gives the same dtype as train_x, which torch.randn creates with
# the default dtype.
test_x_mat = torch.as_tensor(test_x_mat, dtype=torch.get_default_dtype())
test_y_mat = torch.as_tensor(test_y_mat, dtype=torch.get_default_dtype())

# Flatten the grid into an (N, 2) input matrix with the same column order as train_x.
test_x = torch.cat((test_x_mat.reshape(-1, 1), test_y_mat.reshape(-1, 1)), dim=1)
# Ground-truth labels on the grid, stored as integer class ids like train_y.
test_labels = torch.round(genfn(test_x_mat, test_y_mat)).long()
test_y = test_labels.reshape(-1)

In [None]:
# Ground-truth class regions on the test grid, for comparison with the model below.
fig_truth, ax_truth = plt.subplots()
truth_contours = ax_truth.contourf(
    test_x_mat.numpy(), test_y_mat.numpy(), test_labels.numpy()
)
fig_truth.colorbar(truth_contours, ax=ax_truth, label="Class")
ax_truth.set(title="True classes on the test grid", xlabel="$x$", ylabel="$y$");

## Model initialization and training

In [None]:
from easyBO import gp, bo

In [None]:
# Build easyBO's classification GP on the training data (per the notebook title,
# presumably a Dirichlet-likelihood model — TODO confirm in easyBO).
model = gp.get_gp(
    gp_type="classification",
    train_x=train_x,
    train_y=train_y,
)

In [None]:
# Fit the GP. The trailing underscore presumably means `model` is trained in place
# (TODO confirm in easyBO). `losses` is whatever train_gp_ returns; presumably the
# training-loss history.
losses = gp.train_gp_(model=model)

In [None]:
# Posterior of the trained model on every grid point. parsed=False keeps the raw
# distribution object; the cells below use its .mean and .sample().
test_dist = gp.infer(grid=test_x, model=model, parsed=False)

In [None]:
# Posterior mean of the latent functions. The plots below index it by class
# (pred_means[i]), with one value per test-grid point.
pred_means = test_dist.mean

In [None]:
# One panel per class: the posterior mean of that class's latent function on the grid.
# The class count and grid shape are read from the data instead of being hard-coded
# (previously 3 and (20, 20)), so this cell stays correct if either changes.
num_classes = pred_means.shape[0]
fig, axes = plt.subplots(1, num_classes, figsize=(5 * num_classes, 5), squeeze=False)
ax = axes[0]  # squeeze=False keeps `ax` indexable even for a single class

for i in range(num_classes):
    im = ax[i].contourf(
        test_x_mat.numpy(),
        test_y_mat.numpy(),
        pred_means[i].detach().numpy().reshape(test_x_mat.shape),
    )
    fig.colorbar(im, ax=ax[i])
    ax[i].set_title("Logits: Class " + str(i), fontsize=20)
fig.tight_layout()

In [None]:
# Monte Carlo estimate of class probabilities:
# 1. Draw latent-function samples at every grid point.
# 2. Normalise each sample over the class dimension (dim=-2).
# 3. Average over samples.
# softmax(f) equals exp(f) / sum(exp(f)), but unlike a bare .exp() it cannot
# overflow for large latent values.
NUM_PROB_SAMPLES = 256
logit_samples = test_dist.sample(torch.Size((NUM_PROB_SAMPLES,)))
probabilities = logit_samples.softmax(dim=-2).mean(0)

In [None]:
# One panel per class: the Monte Carlo estimate of P(class | x, y) on the grid.
# The class count and grid shape come from the data (previously 3 and (20, 20)).
num_classes = probabilities.shape[0]
fig, axes = plt.subplots(1, num_classes, figsize=(5 * num_classes, 5), squeeze=False)
ax = axes[0]  # squeeze=False keeps `ax` indexable even for a single class

# Fixed contour levels spanning [0, 1.05] so every panel shares one colour scale
# covering all valid probabilities.
levels = np.linspace(0, 1.05, 20)
for i in range(num_classes):
    im = ax[i].contourf(
        test_x_mat.numpy(),
        test_y_mat.numpy(),
        probabilities[i].detach().numpy().reshape(test_x_mat.shape),
        levels=levels,
    )
    fig.colorbar(im, ax=ax[i])
    ax[i].set_title("Probabilities: Class " + str(i), fontsize=20)
fig.tight_layout()

In [None]:
from easyBO import bo
from botorch.acquisition.objective import ScalarizedPosteriorTransform

In [None]:
# Collapse the per-class latent outputs into one scalar for the acquisition function.
# Unit weights sum the class outputs.
# The weights used to be a hard-coded float32 CPU tensor of length 3. They are now
# sized from the posterior mean computed above (one entry per class) and take its
# dtype/device, so `Y @ weights` also works for double-precision or GPU models.
posterior_transform = ScalarizedPosteriorTransform(
    weights=torch.ones(
        pred_means.shape[0], dtype=pred_means.dtype, device=pred_means.device
    )
)

In [None]:
# Seed torch's global RNG so the random restarts of the acquisition optimisation
# are identical on every run.
torch.manual_seed(0)
# Search box for new points: [-3, 3] in each of the two input dimensions.
bounds = [(-3, 3)] * 2
pt = bo.ask(
    model=model,
    bounds=bounds,
    acquisition_function="qMaxVar",
    acquisition_function_kwargs=dict(posterior_transform=posterior_transform),
    # Batch of q=3 candidates; num_restarts/raw_samples are presumably forwarded to
    # BoTorch's optimize_acqf (TODO confirm in easyBO).
    optimize_acqf_kwargs=dict(q=3, num_restarts=20, raw_samples=512),
)