In [None]:
# sphinx ignore

import sys

# Add the directory two levels up to the import path.
# NOTE(review): presumably this is what makes the local `vanguard` package importable - confirm.
sys.path.append("../..")

# Turn off Jedi-based tab completion in IPython.
%config Completer.use_jedi = False
# Seed handed to every `np.random.default_rng(...)` call in the cells below.
random_seed = 1_989

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.mlls import VariationalELBO

from vanguard.datasets.bike import BikeDataset
from vanguard.uncertainty import GaussianUncertaintyGPController
from vanguard.vanilla import GaussianGPController
from vanguard.variational import VariationalInference
from vanguard.warps import SetWarp, warpfunctions

In [None]:
# Build the bike dataset, handing it a NumPy generator seeded with `random_seed`.
dataset_rng = np.random.default_rng(random_seed)
DATASET = BikeDataset(rng=dataset_rng)

In [None]:
# Histogram of the training targets, drawn through the explicit axes interface.
fig, ax = plt.subplots()
ax.hist(DATASET.train_y)
ax.set_xlabel("$y$", fontsize=15)
plt.show()

In [None]:
# Sizes used for the small two-feature demonstration below.
N_DATA_POINTS = 500
N_INDUCING_POINTS = 20

# Rebuild the dataset with `num_samples=N_DATA_POINTS`, using a freshly seeded generator.
subsample_rng = np.random.default_rng(random_seed)
DATASET = BikeDataset(num_samples=N_DATA_POINTS, rng=subsample_rng)


@VariationalInference(
    n_inducing_points=N_INDUCING_POINTS,
    ignore_methods=("__init__",),
)
class GaussianVariationalGPController(GaussianGPController):
    """Gaussian GP controller decorated for variational inference with ``N_INDUCING_POINTS`` inducing points."""


class ScaledMaternKernel(ScaleKernel):
    """A ``ScaleKernel`` wrapping a Matern kernel (``nu=1.5``) with two ARD dimensions."""

    def __init__(self):
        # Two ARD dimensions, matching the two input columns used to train the controller below.
        base_kernel = MaternKernel(nu=1.5, ard_num_dims=2)
        super().__init__(base_kernel)


# TODO: Include a batch_size argument in this example when functionality resolved
# https://github.com/gchq/Vanguard/issues/377

# Train on input columns 4 and 7 only; the next cell plots these two columns in 2D.
two_feature_inputs = DATASET.train_x[:, [4, 7]]
# Value passed as `y_std`: 0.001 times the mean absolute training target.
target_std = 0.001 * np.mean(np.abs(DATASET.train_y))

gp = GaussianVariationalGPController(
    train_x=two_feature_inputs,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=target_std,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

# Fit for 2000 iterations, printing metrics every 200.
with gp.metrics_tracker.print_metrics(every=200):
    gp.fit(n_sgd_iters=2000)

In [None]:
# Inducing-point locations held by the fitted controller's variational strategy, as a NumPy array.
inducing_points = gp._gp.variational_strategy.inducing_points.detach().cpu().numpy()
# The same two input columns (4 and 7) the controller was trained on.
x = DATASET.train_x[:, [4, 7]]

# Label both series and both axes so the figure can be read on its own.
plt.scatter(x[:, 0], x[:, 1], label="training inputs")
plt.scatter(inducing_points[:, 0], inducing_points[:, 1], marker="x", label="inducing points")
plt.xlabel("input column 4")
plt.ylabel("input column 7")
plt.legend()
plt.show()

In [None]:
# Switch for a heavier run: when True, the cells below use 750 inducing points instead of 20
# and a 100x (instead of 10x) multiplier on the SGD iteration count.
SLOW = False

In [None]:
# Number of inducing points for the remaining models; the slow configuration uses far more.
# NOTE(review): 750 exceeds the N_DATA_POINTS = 500 samples DATASET was built with - confirm
# the slow setting is intended to run against this subset.
if SLOW:
    N_INDUCING_POINTS = 750
else:
    N_INDUCING_POINTS = 20


@VariationalInference(
    n_inducing_points=N_INDUCING_POINTS,
    ignore_methods=("__init__",),
)
class GaussianVariationalGPController(GaussianGPController):
    """Variational Gaussian GP controller, re-declared with the current ``N_INDUCING_POINTS`` value."""

In [None]:
BATCH_SIZE = 256

# Whole batches per pass over the training inputs, floored at 15, then scaled by the speed setting.
whole_batches = len(DATASET.train_x) // BATCH_SIZE
iteration_multiplier = 100 if SLOW else 10
NUM_ITERS = max(whole_batches, 15) * iteration_multiplier
print(NUM_ITERS)

In [None]:
class ScaledMaternKernel(ScaleKernel):
    """A ``ScaleKernel`` wrapping a Matern kernel (``nu=1.5``) with one ARD dimension per input column."""

    def __init__(self):
        # Read the column count from the module-level DATASET at construction time.
        n_input_columns = DATASET.train_x.shape[1]
        base_kernel = MaternKernel(nu=1.5, ard_num_dims=n_input_columns)
        super().__init__(base_kernel)

In [None]:
# Value passed as `y_std`: 0.001 times the mean absolute training target.
target_std = 0.001 * np.mean(np.abs(DATASET.train_y))

# Variational controller on all input columns, this time trained in mini-batches.
gp = GaussianVariationalGPController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=target_std,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    batch_size=BATCH_SIZE,
    optim_kwargs={"lr": 0.01},
    rng=np.random.default_rng(random_seed),
)

# Fit for NUM_ITERS iterations, printing metrics every 150.
metrics_tracker = gp.metrics_tracker
with metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)

In [None]:
# Predictive likelihood of the un-warped model on the test inputs; kept for the comparison plot later.
posterior = gp.predictive_likelihood(DATASET.test_x)
unwarped_interval = posterior.confidence_interval()
DATASET.plot_prediction(*unwarped_interval)
plt.show()

In [None]:
# Compose an affine warp with a Box-Cox warp (lambda_=0) using the `@` operator.
affine_warp = warpfunctions.AffineWarpFunction()
box_cox_warp = warpfunctions.BoxCoxWarpFunction(lambda_=0)
warp = affine_warp @ box_cox_warp


@SetWarp(
    warp_function=warp,
    ignore_methods=("fit", "__init__"),
)
@VariationalInference(
    n_inducing_points=N_INDUCING_POINTS,
    ignore_methods=("__init__",),
)
class WarpedGaussianVariationalGPController(GaussianGPController):
    """Variational Gaussian GP controller with the composed ``warp`` function attached via ``SetWarp``."""

In [None]:
# Same settings as the un-warped model, collected in one mapping and unpacked into the warped controller.
warped_controller_kwargs = {
    "train_x": DATASET.train_x,
    "train_y": DATASET.train_y,
    "kernel_class": ScaledMaternKernel,
    "y_std": 0.001 * np.mean(np.abs(DATASET.train_y)),
    "marginal_log_likelihood_class": VariationalELBO,
    "likelihood_kwargs": {"learn_additional_noise": True},
    "batch_size": BATCH_SIZE,
    "optim_kwargs": {"lr": 0.01},
    "rng": np.random.default_rng(random_seed),
}
gp = WarpedGaussianVariationalGPController(**warped_controller_kwargs)

# Fit for NUM_ITERS iterations, printing metrics every 150.
with gp.metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)

In [None]:
# Predictive likelihood of the warped model on the test inputs; kept for the comparison plot below.
warp_posterior = gp.predictive_likelihood(DATASET.test_x)
warped_interval = warp_posterior.confidence_interval()
DATASET.plot_prediction(*warped_interval)
plt.show()

In [None]:
# Side-by-side comparison of the warped (left) and un-warped (right) predictions.
plt.figure(figsize=(12, 4))
comparison_panels = (
    ("Warping. ", warp_posterior),
    ("No warping. ", posterior),
)
for panel_number, (title_prefix, panel_posterior) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 2, panel_number)
    DATASET.plot_prediction(*panel_posterior.confidence_interval(), y_upper_bound=0.5)
    # Prepend the model label to whatever title the axes carry after plot_prediction.
    plt.title(title_prefix + plt.gca().title.get_text())
plt.show()

In [None]:
# Build a new affine @ Box-Cox (lambda_=0) warp instance for the next controller.
# NOTE(review): presumably re-created so it does not share state with the earlier `warp` - confirm.
box_cox_component = warpfunctions.BoxCoxWarpFunction
affine_component = warpfunctions.AffineWarpFunction
warp = affine_component() @ box_cox_component(lambda_=0)


@SetWarp(
    warp_function=warp,
    ignore_all=True,
)
@VariationalInference(
    n_inducing_points=N_INDUCING_POINTS,
    ignore_all=True,
)
class WarpedGaussianUncertaintyVariationalGPController(GaussianUncertaintyGPController):
    """``GaussianUncertaintyGPController`` decorated with variational inference and the ``warp`` function."""

In [None]:
# Value passed as `y_std`: 0.001 times the mean absolute training target.
target_std = 0.001 * np.mean(np.abs(DATASET.train_y))
fit_rng = np.random.default_rng(random_seed)

# Same configuration as the warped model, plus a constant `train_x_std` of 0.1.
gp = WarpedGaussianUncertaintyVariationalGPController(
    train_x=DATASET.train_x,
    train_x_std=0.1,
    train_y=DATASET.train_y,
    kernel_class=ScaledMaternKernel,
    y_std=target_std,
    marginal_log_likelihood_class=VariationalELBO,
    likelihood_kwargs={"learn_additional_noise": True},
    batch_size=BATCH_SIZE,
    optim_kwargs={"lr": 0.01},
    rng=fit_rng,
)

# Fit for NUM_ITERS iterations, printing metrics every 150.
with gp.metrics_tracker.print_metrics(every=150):
    gp.fit(n_sgd_iters=NUM_ITERS)

In [None]:
# Predictive likelihood of the most recently fitted controller on the test inputs.
posterior = gp.predictive_likelihood(DATASET.test_x)
# Plot the posterior just computed, not the earlier `warp_posterior` from the previous model.
DATASET.plot_prediction(*posterior.confidence_interval())
plt.show()