In [None]:
# © Crown Copyright GCHQ
#
# Licensed under the GNU General Public License, version 3 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# This notebook is excluded from the compiled documentation because it takes too long to run
# to produce a representative analysis. Please run it locally if you wish to see the outputs.

In [None]:
# sphinx ignore

import sys

# Add the directory two levels up to the import path, presumably so the in-repo ``vanguard``
# package imported below resolves when the notebook is run from its own folder (TODO confirm).
sys.path.append("../..")

# Use IPython's non-Jedi completer for tab completion in this notebook.
%config Completer.use_jedi = False

In [None]:
# Shared configuration for every experiment in this notebook.
random_seed = 1989  # seeds each ``np.random.default_rng`` created below
num_iters = 100  # SGD iterations passed to every ``fit`` call

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
from gpytorch import constraints, kernels, likelihoods, means
from tqdm import tqdm

from vanguard.datasets.air_passengers import AirPassengers
from vanguard.datasets.synthetic import SyntheticDataset, complicated_f
from vanguard.hierarchical import BayesianHyperparameters, LaplaceHierarchicalHyperparameters
from vanguard.learning import LearnYNoise
from vanguard.normalise import NormaliseY
from vanguard.vanilla import GaussianGPController

In [None]:
# Seeded synthetic regression problem generated from ``complicated_f``.
DATASET = SyntheticDataset(
    functions=(complicated_f,),
    rng=np.random.default_rng(random_seed),
)
# Number of training points; not referenced by the synthetic-data cells below.
train_test_split_index = len(DATASET.train_x)

In [None]:
class ScaledRBFKernel(kernels.ScaleKernel):
    """An RBF kernel wrapped in a ``ScaleKernel``.

    :param active_dims: Forwarded to the inner RBF kernel.
    :param batch_shape: Batch shape shared by the inner RBF kernel and the scale kernel.
    """

    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        base_kernel = kernels.RBFKernel(batch_shape=batch_shape, active_dims=active_dims)
        super().__init__(base_kernel, batch_shape=batch_shape)


class ScaledMaternKernel(kernels.ScaleKernel):
    """A Matern kernel wrapped in a ``ScaleKernel``.

    :param active_dims: Forwarded to the inner Matern kernel.
    :param batch_shape: Batch shape shared by the inner Matern kernel and the scale kernel.
    :param nu: Smoothness parameter forwarded to ``kernels.MaternKernel``. Defaults to 0.5,
        the value that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, active_dims=None, batch_shape=torch.Size([]), nu=0.5):
        super().__init__(
            kernels.MaternKernel(nu=nu, active_dims=active_dims, batch_shape=batch_shape),
            batch_shape=batch_shape,
        )


class ScaledPeriodicKernel(kernels.ScaleKernel):
    """A periodic kernel wrapped in a ``ScaleKernel``.

    :param active_dims: Forwarded to the inner periodic kernel.
    :param batch_shape: Batch shape shared by the inner periodic kernel and the scale kernel.
    """

    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        base_kernel = kernels.PeriodicKernel(batch_shape=batch_shape, active_dims=active_dims)
        super().__init__(base_kernel, batch_shape=batch_shape)


class Kernel(kernels.ProductKernel):
    """Product of a scaled RBF kernel and an unscaled periodic kernel.

    :param batch_shape: Batch shape forwarded to both factors.
    """

    def __init__(self, batch_shape=torch.Size([])):
        rbf_factor = ScaledRBFKernel(batch_shape=batch_shape)
        periodic_factor = kernels.PeriodicKernel(batch_shape=batch_shape)
        super().__init__(rbf_factor, periodic_factor)

In [None]:
@LearnYNoise(ignore_all=True)
class PointEstimateController(GaussianGPController):
    """Baseline Gaussian GP controller decorated with ``LearnYNoise(ignore_all=True)``."""


gp = PointEstimateController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=Kernel,
    y_std=DATASET.train_y_std,
    optim_kwargs={"lr": 0.5},
    rng=np.random.default_rng(random_seed),
)

# Fit the point-estimate baseline, printing metrics every 20 iterations with warnings muted.
with gp.metrics_tracker.print_metrics(every=20):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        gp.fit(n_sgd_iters=num_iters)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(DATASET.test_x)
    likelihood = gp.predictive_likelihood(DATASET.test_x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()

# Convert from tensors to numpy arrays before plotting: matplotlib cannot plot tensors that
# still track gradients or live on a GPU. (Assumes these are torch tensors, as the dataset
# attributes are — TODO confirm for other controllers.)
plt_x = DATASET.test_x.ravel().detach().cpu().numpy()
l_mu = l_mu.detach().cpu().numpy()
l_lower = l_lower.detach().cpu().numpy()
l_upper = l_upper.detach().cpu().numpy()
plt_y = DATASET.test_y.detach().cpu().numpy()

plt.figure(figsize=(15, 7))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(plt_x, plt_y, "x", label="data")
plt.title("Point-estimate GP: predictive likelihood on synthetic test data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(DATASET.test_y)}")

In [None]:
@BayesianHyperparameters()
class BayesianRBFKernel(kernels.RBFKernel):
    """``kernels.RBFKernel`` decorated with ``BayesianHyperparameters()``."""


@BayesianHyperparameters()
class BayesianPeriodicKernel(kernels.PeriodicKernel):
    """``kernels.PeriodicKernel`` decorated with ``BayesianHyperparameters()``."""


@BayesianHyperparameters()
class BayesianScaleKernel(kernels.ScaleKernel):
    """``kernels.ScaleKernel`` decorated with ``BayesianHyperparameters()``."""


class BayesianScaledRBFKernel(BayesianScaleKernel):
    """A ``BayesianRBFKernel`` wrapped in a ``BayesianScaleKernel``.

    :param active_dims: Forwarded to the inner RBF kernel.
    :param batch_shape: Batch shape shared by the inner kernel and the scale kernel.
    """

    def __init__(self, active_dims=None, batch_shape=torch.Size([])):
        inner_kernel = BayesianRBFKernel(batch_shape=batch_shape, active_dims=active_dims)
        super().__init__(inner_kernel, batch_shape=batch_shape)


class BayesianScaledPeriodicKernel(BayesianScaleKernel):
    """A ``BayesianPeriodicKernel`` wrapped in a ``BayesianScaleKernel``.

    Unlike ``BayesianScaledRBFKernel``, ``batch_shape`` precedes ``active_dims`` in this
    signature, so pass both by keyword.

    :param batch_shape: Batch shape shared by the inner kernel and the scale kernel.
    :param active_dims: Forwarded to the inner periodic kernel.
    """

    def __init__(self, batch_shape=torch.Size([]), active_dims=None):
        inner_kernel = BayesianPeriodicKernel(batch_shape=batch_shape, active_dims=active_dims)
        super().__init__(inner_kernel, batch_shape=batch_shape)


class BayesianKernel(kernels.ProductKernel):
    """Product of a Bayesian scaled RBF kernel and an unscaled Bayesian periodic kernel.

    :param batch_shape: Batch shape forwarded to both factors.
    """

    def __init__(self, batch_shape=torch.Size([])):
        factors = (
            BayesianScaledRBFKernel(batch_shape=batch_shape),
            BayesianPeriodicKernel(batch_shape=batch_shape),
        )
        super().__init__(*factors)


@BayesianHyperparameters()
class BayesianConstantMean(means.ConstantMean):
    """``means.ConstantMean`` decorated with ``BayesianHyperparameters()``."""


@BayesianHyperparameters()
class BayesianFixedNoiseGaussianLikelihood(likelihoods.FixedNoiseGaussianLikelihood):
    """``likelihoods.FixedNoiseGaussianLikelihood`` decorated with ``BayesianHyperparameters()``."""

In [None]:
@LaplaceHierarchicalHyperparameters(num_mc_samples=100, ignore_all=True)
class FullBayesianController(PointEstimateController):
    """``PointEstimateController`` extended with ``LaplaceHierarchicalHyperparameters`` (``num_mc_samples=100``)."""


# Hierarchical model: Bayesian kernel, mean and likelihood components.
gp = FullBayesianController(
    train_x=DATASET.train_x,
    train_y=DATASET.train_y,
    kernel_class=BayesianKernel,
    y_std=DATASET.train_y_std,
    mean_class=BayesianConstantMean,
    likelihood_class=BayesianFixedNoiseGaussianLikelihood,
    optim_kwargs={"lr": 0.5},
    rng=np.random.default_rng(random_seed),
)

# Enter the metrics printer first, then mute warnings, for the duration of training.
with gp.metrics_tracker.print_metrics(every=20), warnings.catch_warnings():
    warnings.simplefilter("ignore")
    gp.fit(n_sgd_iters=num_iters)

In [None]:
# Visualise the covariance of the hyperparameter posterior and print its mean.
# Fetch the property once so the plot and the printout come from the same object.
hyperparameter_posterior = gp.hyperparameter_posterior
plt.imshow(hyperparameter_posterior.covariance_matrix.detach().cpu().numpy())
plt.colorbar(label="covariance")
plt.title("Hyperparameter posterior covariance (synthetic data)")
plt.xlabel("hyperparameter index")
plt.ylabel("hyperparameter index")
plt.show()
print(hyperparameter_posterior.mean.detach().cpu().numpy())

In [None]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(DATASET.test_x)
    likelihood = gp.predictive_likelihood(DATASET.test_x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()

# Detach each tensor and move it to the CPU as a numpy array so matplotlib can plot it.
plt_x, l_mu, l_lower, l_upper, plt_y = (
    tensor.detach().cpu().numpy()
    for tensor in (DATASET.test_x.ravel(), l_mu, l_lower, l_upper, DATASET.test_y)
)

plt.figure(figsize=(10, 4))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(plt_x, plt_y, "x", label="data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(DATASET.test_y)}")

In [None]:
# Plot draws from the posterior over the synthetic test inputs, one line per draw
# (assumes ``sample`` returns (num_samples, num_points), hence the transpose — TODO confirm).
num_posterior_samples = 500
posterior_samples = posterior.sample(num_posterior_samples)
if isinstance(posterior_samples, torch.Tensor):
    # matplotlib cannot plot tensors that track gradients or live on a GPU.
    posterior_samples = posterior_samples.detach().cpu().numpy()

plt.figure(figsize=(10, 4))
plt.plot(posterior_samples.T)
plt.title(f"{num_posterior_samples} posterior samples (synthetic test data)")
plt.xlabel("test point index")
plt.show()

In [None]:
# Sweep the posterior temperature and record the test log probability at each value. The
# sweep is repeated, presumably because the Monte Carlo predictive likelihood
# (``num_mc_samples=100``) varies between calls, so the spread can be inspected.
num_temperatures = 20
num_repeats = 20
temps = np.logspace(-5, 0, num_temperatures)
log_probs = []
for _ in tqdm(range(num_repeats), desc="temperature sweep"):
    lp = []
    for temperature in temps:
        gp.temperature = temperature
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            likelihood = gp.predictive_likelihood(DATASET.test_x)
        # ``float`` turns a 0-d tensor (or plain number) into a Python float, so ``np.array``
        # below builds a plain float matrix.
        lp.append(float(likelihood.log_probability(DATASET.test_y)))
    log_probs.append(lp)
# NOTE: ``gp.temperature`` is left at the last swept value, temps[-1] == 1.0.

log_probs = np.array(log_probs)
plt.plot(temps, log_probs.T)
plt.xscale("log")
plt.xlabel("temperature")
plt.ylabel("log probability")
plt.grid()
plt.show()

In [None]:
# Average the sweep over repeats and mark the controller's automatically chosen temperature.
# NaN-aware reductions keep a single failed evaluation from poisoning a column. Builtin
# ``min``/``max`` are also order-dependent when NaNs are present.
mean_log_probs = np.nanmean(log_probs, axis=0)
plt.plot(temps, mean_log_probs, label="empirical mean")
plt.vlines(
    [gp.auto_temperature()],
    [np.nanmin(mean_log_probs)],
    [np.nanmax(mean_log_probs)],
    linestyles="--",
    color="r",
    label="auto temperature",
)
plt.xscale("log")
plt.ylabel("log probability")
plt.xlabel("temperature")
plt.legend()
plt.grid()
plt.show()

In [None]:
# Load the AirPassengers series and split it into the first 100 points (train) and the rest (test).
data = AirPassengers()
df = data._load_data()  # NOTE(review): private loader — confirm there is no public accessor.

train_test_split_index = 100
x = df.index.values.astype(float)
y = df.y.values.astype(float)
train_x, test_x = np.split(x, [train_test_split_index])
train_y, test_y = np.split(y, [train_test_split_index])

In [None]:
# Interval constraint on the linear kernel variance; this single instance is shared by both airline kernels.
linear_co_constraint = constraints.Interval(lower_bound=0.0, upper_bound=1.0)


class AirlineKernel(kernels.AdditiveKernel):
    """Sum of a locally periodic term, a constrained linear term and a scaled RBF term.

    :param batch_shape: Batch shape forwarded to every component kernel.
    """

    def __init__(self, batch_shape=torch.Size([])):
        # Locally periodic component: scaled RBF multiplied by a periodic kernel.
        local_period = ScaledRBFKernel(batch_shape=batch_shape) * kernels.PeriodicKernel(
            batch_shape=batch_shape
        )
        # Linear trend whose variance is bounded by the shared interval constraint.
        linear = kernels.LinearKernel(batch_shape=batch_shape, variance_constraint=linear_co_constraint)
        rbf = ScaledRBFKernel(batch_shape=batch_shape)
        super().__init__(local_period, linear, rbf)

In [None]:
@NormaliseY()
@LearnYNoise(ignore_all=True)
class PointEstimateController(GaussianGPController):
    """Redefinition of the earlier baseline controller with ``NormaliseY()`` added for the airline data."""


gp = PointEstimateController(
    train_x=train_x,
    train_y=train_y,
    kernel_class=AirlineKernel,
    y_std=0,
    optim_kwargs={"lr": 0.1},
    rng=np.random.default_rng(random_seed),
)

# Train with metrics printed every 20 iterations and warnings muted.
with gp.metrics_tracker.print_metrics(every=20), warnings.catch_warnings():
    warnings.simplefilter("ignore")
    gp.fit(n_sgd_iters=num_iters)

In [None]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = gp.posterior_over_point(x)
    likelihood = gp.predictive_likelihood(x)

mu, lower, upper = posterior.confidence_interval()
l_mu, l_lower, l_upper = likelihood.confidence_interval()

# Convert the interval bounds to numpy before plotting: matplotlib cannot plot tensors that still
# track gradients or live on a GPU. (Assumes these are torch tensors — TODO confirm.)
l_mu, l_lower, l_upper = (bound.detach().cpu().numpy() for bound in (l_mu, l_lower, l_upper))

plt_x = x.ravel()
plt.figure(figsize=(15, 7))
plt.plot(plt_x, l_mu, label="likelihood")
plt.fill_between(plt_x, l_lower, l_upper, alpha=0.2, label="likelihood CI")
plt.plot(train_x, train_y, "x", label="train data")
plt.plot(test_x, test_y, "o", label="test data")
plt.title("Point-estimate GP: predictive likelihood on AirPassengers")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {likelihood.log_probability(torch.tensor(y))}")

In [None]:
@BayesianHyperparameters()
class BayesianLinearKernel(kernels.LinearKernel):
    """``kernels.LinearKernel`` decorated with ``BayesianHyperparameters()``."""


class BayesianAirlineKernel(kernels.AdditiveKernel):
    """Bayesian analogue of ``AirlineKernel``: locally periodic + constrained linear + scaled RBF.

    :param batch_shape: Batch shape forwarded to every component kernel.
    """

    def __init__(self, batch_shape=torch.Size([])):
        # The periodic factor is built before the scaled RBF factor, as in the original ordering.
        periodic_factor = BayesianPeriodicKernel(batch_shape=batch_shape)
        rbf_factor = BayesianScaledRBFKernel(batch_shape=batch_shape)
        locally_periodic = rbf_factor * periodic_factor
        linear_trend = BayesianLinearKernel(batch_shape=batch_shape, variance_constraint=linear_co_constraint)
        smooth_residual = BayesianScaledRBFKernel(batch_shape=batch_shape)
        super().__init__(locally_periodic, linear_trend, smooth_residual)

In [None]:
@LaplaceHierarchicalHyperparameters(num_mc_samples=100, ignore_all=True)
class FullBayesianController(PointEstimateController):
    """Redefinition built on the ``NormaliseY`` version of ``PointEstimateController`` (``num_mc_samples=100``)."""


# Hierarchical airline model: Bayesian kernel, mean and likelihood components.
laplace_gp = FullBayesianController(
    train_x=train_x,
    train_y=train_y,
    kernel_class=BayesianAirlineKernel,
    y_std=0,
    mean_class=BayesianConstantMean,
    likelihood_class=BayesianFixedNoiseGaussianLikelihood,
    optim_kwargs={"lr": 0.1},
    rng=np.random.default_rng(random_seed),
)

# Enter the metrics printer first, then mute warnings, for the duration of training.
with laplace_gp.metrics_tracker.print_metrics(every=20), warnings.catch_warnings():
    warnings.simplefilter("ignore")
    laplace_gp.fit(n_sgd_iters=num_iters)

In [None]:
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    posterior = laplace_gp.posterior_over_point(x)
    laplace_likelihood = laplace_gp.predictive_likelihood(x)

laplace_mu, laplace_lower, laplace_upper = laplace_likelihood.confidence_interval()

# Plot the predictive mean and interval (converted to numpy) over the full series.
plt_x = x.ravel()
plt.figure(figsize=(15, 7))
plt.plot(plt_x, laplace_mu.detach().cpu().numpy(), label="likelihood")
plt.fill_between(
    x=plt_x,
    y1=laplace_lower.detach().cpu().numpy(),
    y2=laplace_upper.detach().cpu().numpy(),
    alpha=0.2,
    label="likelihood CI",
)
plt.plot(train_x, train_y, "x", label="train data")
plt.plot(test_x, test_y, "o", label="test data")
plt.grid(which="both")
plt.legend()
print(f"Log probability: {laplace_likelihood.log_probability(torch.tensor(y))}")

In [None]:
# Visualise the covariance of the airline model's hyperparameter posterior and print its mean.
# Fetch the property once so the plot and the printout come from the same object.
laplace_hyperparameter_posterior = laplace_gp.hyperparameter_posterior
plt.imshow(laplace_hyperparameter_posterior.covariance_matrix.detach().cpu().numpy())
plt.colorbar(label="covariance")
plt.title("Hyperparameter posterior covariance (AirPassengers)")
plt.xlabel("hyperparameter index")
plt.ylabel("hyperparameter index")
plt.show()
print(laplace_hyperparameter_posterior.mean.detach().cpu().numpy())

In [None]:
# Sweep the posterior temperature and record the held-out log probability at each value,
# repeating the sweep to expose run-to-run variability. Evaluations that raise are
# recorded as NaN so the sweep completes; the reason is reported.
num_temperatures = 20
num_repeats = 20
temps = np.logspace(-5, 0, num_temperatures)
test_y_tensor = torch.tensor(test_y)  # loop-invariant: build once
log_probs = []
for run_index in tqdm(range(num_repeats), desc="temperature sweep"):
    lp = []
    for temperature in temps:
        try:
            laplace_gp.temperature = temperature
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                likelihood = laplace_gp.predictive_likelihood(test_x)
            # ``float`` gives a plain Python number so ``np.array`` below builds a float matrix.
            lp.append(float(likelihood.log_probability(test_y_tensor)))
        except Exception as error:  # deliberate best effort: keep sweeping past unstable temperatures
            # ``tqdm.write`` prints without corrupting the progress bar.
            tqdm.write(
                f"Skipping temperature {temperature} run {run_index+1} due to numerical issues "
                f"({type(error).__name__}: {error})"
            )
            lp.append(np.nan)

    log_probs.append(lp)
# NOTE: ``laplace_gp.temperature`` is left at the last swept value, temps[-1] == 1.0.

log_probs = np.array(log_probs)
plt.plot(temps, log_probs.T)
plt.xscale("log")
plt.xlabel("temperature")
plt.ylabel("log probability")
plt.grid()
plt.show()

In [None]:
# Average the sweep over repeats and mark the controller's automatically chosen temperature.
# The sweep stores NaN for failed evaluations, so use NaN-aware reductions. With ``np.mean``,
# any failure turns that column's mean into NaN. Builtin ``min``/``max`` then give
# order-dependent results for the marker extent.
mean_log_probs = np.nanmean(log_probs, axis=0)
plt.plot(temps, mean_log_probs, label="empirical mean")
plt.vlines(
    [laplace_gp.auto_temperature()],
    [np.nanmin(mean_log_probs)],
    [np.nanmax(mean_log_probs)],
    linestyles="--",
    color="r",
    label="auto temperature",
)
plt.xscale("log")
plt.ylabel("log probability")
plt.xlabel("temperature")
plt.legend()
plt.grid()
plt.show()