# Custom Gaussian process

## Define Gaussian process

In [1]:
from torch import Tensor
from gpytorch.models import ExactGP
from gpytorch.means import ZeroMean, ConstantMean
from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import Likelihood


class GaussianProcess(ExactGP):
    """Exact Gaussian process with a zero mean and a scaled ARD RBF kernel.

    Parameters
    ----------
    x_train : Tensor
        Training inputs. The size of the last dimension sets the number of
        RBF lengthscales (one per input dimension).
    y_train : Tensor
        Training targets.
    likelihood : Likelihood
        GPyTorch likelihood, passed straight through to ``ExactGP``.
    """

    def __init__(self,
                 x_train: Tensor,
                 y_train: Tensor,
                 likelihood: Likelihood) -> None:

        # register training data and likelihood with the ExactGP base class
        super().__init__(x_train, y_train, likelihood)

        # one RBF lengthscale per input dimension (ARD)
        num_input_dims = x_train.shape[-1]

        # zero prior mean; RBF kernel wrapped in a learnable output scale
        self.mean_module = ZeroMean()
        self.covar_module = ScaleKernel(
            base_kernel=RBFKernel(ard_num_dims=num_input_dims)
        )

    def forward(self, x: Tensor) -> MultivariateNormal:
        """Return the GP prior evaluated at ``x``.

        Parameters
        ----------
        x : Tensor
            Input locations.

        Returns
        -------
        MultivariateNormal
            Distribution whose mean comes from ``mean_module`` and whose
            covariance comes from ``covar_module``, both evaluated at ``x``.
        """
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


## Generate training data

In [2]:
from nubo.test_functions import Hartmann6D
from nubo.utils import gen_inputs


# test function: 6-dimensional Hartmann function, created with minimise=False
func = Hartmann6D(minimise=False)
dims, bounds = func.dims, func.bounds

# training data: five input points per dimension, evaluated on the test function
x_train = gen_inputs(num_points=5 * dims, num_dims=dims, bounds=bounds)
y_train = func(x_train)

## Fit Gaussian process

In [3]:
from nubo.models import fit_gp
from gpytorch.likelihoods import GaussianLikelihood

  
# initialise a Gaussian likelihood and the custom Gaussian process defined above
likelihood = GaussianLikelihood()
gp = GaussianProcess(x_train=x_train, y_train=y_train, likelihood=likelihood)

# fit the Gaussian process to the training data (learning rate 0.1, 200 steps)
fit_gp(x_train, y_train,
       gp=gp,
       likelihood=likelihood,
       lr=0.1,
       steps=200)

## Predict test points

In [4]:
import torch


# sample five test points uniformly in the unit hypercube. Use the same dtype
# as the training inputs so ExactGP does not mix float32 test inputs with
# (possibly float64) training inputs when it builds the posterior.
# NOTE(review): torch.rand ignores `bounds`; these points lie inside the input
# space only if the bounds are the unit hypercube -- confirm for other functions
x_test = torch.rand((5, dims), dtype=x_train.dtype)

# set Gaussian Process to eval mode (ExactGP then returns the posterior)
gp.eval()

# make predictions without building an autograd graph, so no .detach() is
# needed before printing
with torch.no_grad():
    pred = gp(x_test)

    # predictive mean and variance of the latent function (likelihood noise is
    # not added); the variance is clamped to a small positive floor, presumably
    # to guard against round-off producing zero or negative values
    mean = pred.mean
    variance = pred.variance.clamp_min(1e-10)

print(f"Mean: {mean}")
print(f"Variance: {variance}")

Mean: tensor([-0.1095, -0.1708,  0.1029,  0.2947, -0.0896], dtype=torch.float64)
Variance: tensor([0.0838, 0.0431, 0.0119, 0.0370, 0.0544], dtype=torch.float64)
