## Mixture kernel with non-stationary local weights

The Gaussian process under consideration is a mixture of independent processes. Specifically, there is one global process $\mathcal{GP}_{g}$ and a collection of local processes $\{\mathcal{GP}_{l_i}\}$. Each process is multiplied by a non-stationary, position-dependent weight, to balance exploration with the global process against exploitation with the local processes. The weights are modeled as Gaussian functions, which gives the expression below:
\begin{gather*}
    f(\mathbf{x}) = e^{-\frac{\lVert \mathbf{x}-\pmb{\psi_g}\rVert_2^2}{2\sigma_g^2}} f_{g}(\mathbf{x}) + \sum_i e^{-\frac{\lVert \mathbf{x}-\pmb{\psi_l}\rVert_2^2}{2\sigma_{l_i}^2}} f_{l_i}(\mathbf{x}),\\
    f_{g} \sim \mathcal{GP}_{g},\quad f_{l_i} \sim \mathcal{GP}_{l_i},
\end{gather*}

where $\pmb{\psi}$ denotes the center of a process's region of influence and $\sigma$ controls the width of that region.
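
To make the role of the weights concrete, here is a minimal NumPy sketch of the Gaussian weight functions. The helper name `gaussian_weight`, the 1-D grid, and the numerical values of the centers and widths are illustrative assumptions, not part of the model definition.

```python
import numpy as np

def gaussian_weight(x, psi, sigma):
    """Unnormalized weight exp(-||x - psi||^2 / (2 sigma^2)) for each row of x."""
    x = np.atleast_2d(x)
    sq_dist = np.sum((x - psi) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))

# Illustrative values (assumptions): five 1-D inputs on [0, 1].
x = np.linspace(0.0, 1.0, 5).reshape(-1, 1)
w_global = gaussian_weight(x, psi=np.array([0.5]), sigma=10.0)  # wide: nearly flat
w_local = gaussian_weight(x, psi=np.array([0.25]), sigma=0.1)   # narrow: sharply peaked

print("global weights:", np.round(w_global, 4))
print("local weights: ", np.round(w_local, 4))
```

With a wide global weight and a narrow local weight, the global process contributes almost everywhere, while the local process only matters close to its center.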

By additivity, a weighted sum of Gaussian processes is itself a Gaussian process. If we further assume that the component $\mathcal{GP}$s are mutually uncorrelated, then $\mathcal{GP}_{tot}$, with $f(\mathbf{x})\sim\mathcal{GP}_{tot}$, is fully specified by its mean function and covariance kernel:

\begin{align*}
    k(\mathbf x_1, \mathbf x_2) = &\exp\left(-\frac{\lVert \mathbf{x_1}-\pmb{\psi_g}\rVert_2^2 + \lVert \mathbf{x_2}-\pmb{\psi_g}\rVert_2^2}{2\sigma_g^2}\right)k_g(\mathbf x_1, \mathbf x_2)\\
    &+\sum_i \exp\left(-\frac{\lVert \mathbf{x_1}-\pmb{\psi_l}\rVert_2^2 + \lVert \mathbf{x_2}-\pmb{\psi_l}\rVert_2^2}{2\sigma_{l_i}^2}\right)k_{l_i}(\mathbf x_1, \mathbf x_2),\\
    m(\mathbf x) = & e^{-\frac{\lVert \mathbf{x}-\pmb{\psi_g}\rVert_2^2}{2\sigma_g^2}} m_g(\mathbf{x}) + \sum_i e^{-\frac{\lVert \mathbf{x}-\pmb{\psi_l}\rVert_2^2}{2\sigma_{l_i}^2}} m_{l_i}(\mathbf{x})
\end{align*}
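
As a numerical sanity check, the sketch below builds the covariance of $f$ for 1-D inputs, using the Gaussian weights from the definition of $f$, so that each term is $w(\mathbf x_1)\,w(\mathbf x_2)\,k(\mathbf x_1, \mathbf x_2)$. The squared-exponential sub-kernels, the single local process, and all numerical values are assumptions made only for illustration.

```python
import numpy as np

def rbf(x1, x2, lengthscale):
    """Squared-exponential kernel matrix for 1-D inputs (illustrative sub-kernel)."""
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (diff / lengthscale) ** 2)

def weight(x, psi, sigma):
    """Gaussian weight exp(-(x - psi)^2 / (2 sigma^2)) for 1-D inputs."""
    return np.exp(-((x - psi) ** 2) / (2.0 * sigma ** 2))

def mixture_kernel(x1, x2, psi_g=0.5, sigma_g=10.0, psi_l=0.3, sigma_l=0.1):
    """k = w_g(x1) w_g(x2) k_g + w_l(x1) w_l(x2) k_l, with one local process."""
    k_g = rbf(x1, x2, lengthscale=0.5)
    k_l = rbf(x1, x2, lengthscale=0.05)
    outer_g = np.outer(weight(x1, psi_g, sigma_g), weight(x2, psi_g, sigma_g))
    outer_l = np.outer(weight(x1, psi_l, sigma_l), weight(x2, psi_l, sigma_l))
    return outer_g * k_g + outer_l * k_l

x = np.linspace(0.0, 1.0, 20)
K = mixture_kernel(x, x)
min_eig = np.linalg.eigvalsh(K).min()
print("symmetric:", np.allclose(K, K.T), " smallest eigenvalue:", min_eig)
```

Because each term is a valid kernel scaled by the outer product of a weight vector with itself, the sum should remain symmetric and positive semi-definite up to round-off.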


Our assumptions are:

1. The local kernels share a single center, and their regions of influence are isotropic.
2. The input $\mathbf x$ is approximately standardized to $[-1,1]^d$, which can be used to set the priors of the sub-kernel hyperparameters and of the position/weight hyperparameters.
3. The global kernel has a nearly uniform weight, which can be achieved by centering it at $\pmb{\psi_g} = [0.5]_d$ and choosing a large $\sigma_g$, e.g. 10.
4. The global width $\sigma_g$ is not a hyperparameter, whereas the local widths $\sigma_{l_i}$ are. If necessary, a single shared $\sigma_{l}$ can be used.
5. To emphasize the contrast between local and global weights, $\sigma_l \ll \sigma_g$ may also be required. In practice the hyperparameter could be the ratio $\sigma_l/\sigma_g$, constrained to $(0,1)$, or its logarithm, constrained to $(-\infty,0)$, rather than $\sigma_l$ itself.
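
The reparameterization mentioned in assumption 5 can be sketched as follows. Both function names are hypothetical, and the sigmoid and softplus maps are just one possible way to enforce the constraints on the ratio $\sigma_l/\sigma_g$ and on its logarithm; they are not taken from the implementation below.

```python
import math

SIGMA_G = 10.0  # fixed global width, following assumption 3

def sigma_l_from_logit(raw):
    """Map an unconstrained value to sigma_l via the ratio r = sigmoid(raw) in (0, 1)."""
    ratio = 1.0 / (1.0 + math.exp(-raw))
    return ratio * SIGMA_G

def sigma_l_from_log_ratio(raw):
    """Map an unconstrained value to sigma_l via log r = -softplus(raw) in (-inf, 0)."""
    log_ratio = -math.log1p(math.exp(raw))
    return math.exp(log_ratio) * SIGMA_G

for raw in (-5.0, 0.0, 5.0):
    print(raw, sigma_l_from_logit(raw), sigma_l_from_log_ratio(raw))
```

Either map keeps $\sigma_l$ strictly between 0 and $\sigma_g$ for any finite raw value, so the optimizer can work on an unconstrained parameter.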




In [2]:
from math import sqrt
from typing import Iterable, Optional

import torch
from torch.nn import ModuleList
from gpytorch.constraints import Interval
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import Kernel
from gpytorch.priors import Prior
from linear_operator.operators import ZeroLinearOperator

# Example from Spartan Kernel

class SpartanKernel(Kernel):

        has_lengthscale = False

        def __init__(self, global_kernel: Kernel, local_kernels: Iterable[Kernel], 
                     ard_num_dims: int = 1,
                     local_position_prior: Optional[Prior] = None,
                     local_position_constraint: Optional[Interval] = None,
                     eps: float = 0.000001, **kwargs):
                
                
                super(SpartanKernel, self).__init__(ard_num_dims=ard_num_dims)

                self.global_kernel = global_kernel
                self.local_kernels = ModuleList(local_kernels)
                self.register_parameter(
                        name="raw_local_position", 
                        parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape,1,ard_num_dims))
                )
                if local_position_prior is not None:
                        if not isinstance(local_position_prior, Prior):
                                raise TypeError("Expected gpytorch.priors.Prior but got " + type(local_position_prior).__name__)
                        self.register_prior(
                                "local_position_prior",
                                local_position_prior,
                                lambda m: m.local_position,
                                lambda m, v: m._set_local_position(v),
                        )
                if local_position_constraint is None:
                        local_position_constraint = Interval(torch.zeros(1, ard_num_dims).squeeze(), 
                                                             torch.ones(1, ard_num_dims).squeeze())
                        # Constrained between 0 to 1, can be modified for the inputs
                self.register_constraint("raw_local_position", local_position_constraint)
                # Weight parameters other than center:
                weight_params = {'psi' : torch.ones([1,ard_num_dims])*0.5,
                                 'sigma_g' : sqrt(10.),
                                 'Normalized' : True,
                                 'sigma_l' : torch.tensor([sqrt(0.01)])}
                #                 'local_num_samples' : None}
                weight_params.update(kwargs)
                self.eps = eps
                self.register_buffer('psi', weight_params['psi'])
                self.register_buffer('sigma_g', torch.as_tensor(weight_params['sigma_g']))
                self.Normalized = weight_params['Normalized']
                #if weight_params['local_num_samples'] is not None:
                #        self.sigma_l = sqrt(weight_params['local_num_samples']/2)
                #else:
                self.register_buffer('sigma_l', weight_params['sigma_l'])
                # TODO: What if we want a separate sigma_l for each local kernel?

                


        @property
        def local_position(self):
                return self.raw_local_position_constraint.transform(self.raw_local_position)

        @local_position.setter
        def local_position(self, value):
                return self._set_local_position(value)
        
        def _set_local_position(self, value):
                if not torch.is_tensor(value):
                        value = torch.as_tensor(value).to(self.raw_local_position)
                
                self.initialize(raw_local_position = self.raw_local_position_constraint.inverse_transform(value))

        def omega_g(self, x):
                """Helper function for unnormalized weights"""

        def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: bool=False, **params):
                #print(x1.shape,x2.shape)
                # This is just to make mll work:
                res = ZeroLinearOperator() if not diag else 0
                #if diag and torch.equal(x1, x2):
                #       return torch.ones(*x1.shape[:-2], x1.shape[-2], dtype=x1.dtype, device=x1.device)
                _k_g = self.global_kernel(x1, x2, diag=diag).to_dense()
                _w_g_1 = MultivariateNormal(self.psi, torch.eye(self.ard_num_dims, device=self.device)*self.sigma_g**2).log_prob(x1)
                _w_g_2 = MultivariateNormal(self.psi, torch.eye(self.ard_num_dims, device=self.device)*self.sigma_g**2).log_prob(x2)
                #print("w",_w_g_1.shape,_w_g_2.shape)
                _w_sum_1 = torch.exp(_w_g_1) # keep track of total weights
                _w_sum_2 = torch.exp(_w_g_2)
                if not diag:
                        res = res + torch.mul(_k_g, torch.sqrt(torch.matmul(_w_sum_1.unsqueeze(-1),
                                                                     _w_sum_2.unsqueeze(-2)))) # TO DO: decide on whether logsum is needed
                else:
                        # sqrt(w_g(x1) * w_g(x2)), consistent with the non-diag branch and the local terms
                        res = res + torch.mul(_k_g, torch.exp((_w_g_1 + _w_g_2) / 2))
                for _k in self.local_kernels:
                        # Only same sigma_l for different local kernels considered
                        _w_l_dist = MultivariateNormal(self.local_position, torch.eye(self.ard_num_dims, device=self.device)*self.sigma_l**2)
                        _w_l_1 = _w_l_dist.log_prob(x1)
                        _w_l_2 = _w_l_dist.log_prob(x2)
                        _w_sum_1 = _w_sum_1 + torch.exp(_w_l_1)
                        _w_sum_2 = _w_sum_2 + torch.exp(_w_l_2)
                        if not diag:
                                res = res + torch.mul(_k(x1, x2, diag=diag).to_dense(), 
                                                torch.matmul(torch.exp(_w_l_1/2).unsqueeze(-1), 
                                                                torch.exp(_w_l_2/2).unsqueeze(-2)))
                        else:
                                res = res + torch.mul(_k(x1, x2, diag=diag).to_dense(),
                                                                torch.exp(_w_l_1/2 + _w_l_2/2))
                        
                # Now apply normalization
                if not diag:
                        res = torch.div(res, torch.unsqueeze(torch.sqrt(_w_sum_1), -1))
                        if res.dim() > 2:
                                res = torch.div(res, torch.unsqueeze(torch.sqrt(_w_sum_2), -2))
                        else:
                                res = torch.div(res, torch.sqrt(_w_sum_2))
                else:
                        res = torch.div(res, torch.sqrt(torch.mul(_w_sum_1, _w_sum_2)))        
                return res
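
The normalization at the end of `forward` divides each entry by the square root of the summed weights at both inputs. The NumPy sketch below mirrors that computation outside of GPyTorch for 1-D inputs. It is not the class above: the helper names, the squared-exponential sub-kernels, and the center of the local weight are assumptions, while the widths follow the defaults `sqrt(10.)` and `sqrt(0.01)`.

```python
import numpy as np

def normal_pdf(x, mean, sigma):
    """Density of N(mean, sigma^2), playing the role of exp(log_prob(x))."""
    return np.exp(-0.5 * ((x - mean) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def rbf(x1, x2, lengthscale):
    """Unit-variance squared-exponential kernel matrix (illustrative sub-kernel)."""
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (diff / lengthscale) ** 2)

def normalized_mixture_kernel(x1, x2, psi_g=0.5, sigma_g=np.sqrt(10.0),
                              psi_l=0.3, sigma_l=0.1):
    # One (w(x1), w(x2), k(x1, x2)) triple per component: global first, then local.
    comps = [
        (normal_pdf(x1, psi_g, sigma_g), normal_pdf(x2, psi_g, sigma_g), rbf(x1, x2, 0.5)),
        (normal_pdf(x1, psi_l, sigma_l), normal_pdf(x2, psi_l, sigma_l), rbf(x1, x2, 0.05)),
    ]
    # Each component is weighted by sqrt(w(x1) * w(x2)).
    res = sum(np.sqrt(np.outer(w1, w2)) * k for w1, w2, k in comps)
    # Divide by sqrt(W(x1) * W(x2)), where W is the sum of all weights.
    w_sum_1 = sum(w1 for w1, _, _ in comps)
    w_sum_2 = sum(w2 for _, w2, _ in comps)
    return res / np.sqrt(np.outer(w_sum_1, w_sum_2))

x = np.linspace(0.0, 1.0, 11)
K = normalized_mixture_kernel(x, x)
print("diagonal:", np.round(np.diag(K), 6))
```

With unit-variance sub-kernels, the diagonal reduces to the sum of the weights divided by itself, so the normalized prior variance is 1 everywhere and the weights only redistribute variance between the global and local components.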

In [23]:
from math import sqrt
from typing import Iterable, Optional

import torch
from torch.nn import ModuleList
from gpytorch.constraints import Interval
from gpytorch.kernels import Kernel
from gpytorch.priors import Prior
from linear_operator.operators import ZeroLinearOperator

# gpytorch.lazy is deprecated in favour of linear_operator; keep the old name as an alias.
ZeroLazyTensor = ZeroLinearOperator


class Gaussian_Weight_Spartan_Kernel(Kernel):
        has_lengthscale = False
        def __init__(self, global_kernel: Kernel, local_kernels: Iterable[Kernel], 
                     ard_num_dims: int = 1,
                     local_position_prior: Optional[Prior] = None,
                     local_position_constraint: Optional[Interval] = None,
                     local_weight_var_prior: Optional[Prior] = None,
                     local_weight_var_constraint: Optional[Interval] = None,
                     eps: float = 0.000001, **kwargs):
                
                
                super().__init__(ard_num_dims=ard_num_dims)

                self.global_kernel = global_kernel
                self.local_kernels = ModuleList(local_kernels)
                # numbers of local kernels for calculation
                self.local_kernels_num = len(self.local_kernels)
                # hyperparameters for weights
                # Note: the logistic function used to enforce interval constraints can be problematic. The optimizer may settle at either end of the interval, where the gradients are almost zero and convergence appears to have been reached.
                self.register_parameter(
                        name="raw_local_position", 
                        parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape,1,ard_num_dims))
                )
                if local_position_prior is not None:
                        if not isinstance(local_position_prior, Prior):
                                raise TypeError("Expected gpytorch.priors.Prior but got " + type(local_position_prior).__name__)
                        self.register_prior(
                                "local_position_prior",
                                local_position_prior,
                                lambda m: m.local_position,
                                lambda m, v: m._set_local_position(v),
                        )
                if local_position_constraint is None:
                        local_position_constraint = Interval(torch.zeros(1, ard_num_dims).squeeze(), 
                                                             torch.ones(1, ard_num_dims).squeeze())
                self.register_constraint("raw_local_position", local_position_constraint)

                self.register_parameter(
                        name="raw_local_weight_var",
                        parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, self.local_kernels_num))
                )
                if local_weight_var_prior is not None:
                        if not isinstance(local_weight_var_prior, Prior):
                                raise TypeError("Expected gpytorch.priors.Prior but got " + type(local_weight_var_prior).__name__)
                        self.register_prior(
                                "local_weight_var_prior",
                                local_weight_var_prior,
                                lambda m: m.local_weight_var,
                                lambda m, v: m._set_local_weight_var(v),
                        )
                
                if local_weight_var_constraint is None:
                        local_weight_var_constraint = Interval(torch.zeros(self.local_kernels_num).squeeze(),
                                                                torch.ones(self.local_kernels_num).squeeze())       
                self.register_constraint("raw_local_weight_var", local_weight_var_constraint)

                # Weight parameters other than center:
                weight_params = {'psi' : torch.ones([1,ard_num_dims])*0.5,
                                 'sigma_g' : sqrt(10.),
                                 'Normalized' : True,
                                 #'sigma_l' : torch.tensor([sqrt(0.01)])
                                 }
                #                 'local_num_samples' : None}
                weight_params.update(kwargs)
                self.eps = eps
                self.register_buffer('psi', weight_params['psi'])
                self.register_buffer('sigma_g', torch.as_tensor(weight_params['sigma_g']))
                self.Normalized = weight_params['Normalized']
                #
                #if weight_params['local_num_samples'] is not None:
                #        self.sigma_l = sqrt(weight_params['local_num_samples']/2)
                #else:
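                # 'sigma_l' has no default any more (the local widths are the learnable
                # local_weight_var parameter), so only register it if it was passed explicitly.
                if 'sigma_l' in weight_params:
                        self.register_buffer('sigma_l', torch.as_tensor(weight_params['sigma_l']))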
                # TO DO: Incorporate information on samples for local weight priors?

                


        @property
        def local_position(self):
                return self.raw_local_position_constraint.transform(self.raw_local_position)

        @local_position.setter
        def local_position(self, value):
                return self._set_local_position(value)
        
        def _set_local_position(self, value):
                if not torch.is_tensor(value):
                        value = torch.as_tensor(value).to(self.raw_local_position)
                
                self.initialize(raw_local_position = self.raw_local_position_constraint.inverse_transform(value))

        @property
        def local_weight_var(self):
                return self.raw_local_weight_var_constraint.transform(self.raw_local_weight_var)
        
        @local_weight_var.setter
        def local_weight_var(self, value):
                return self._set_local_weight_var(value)
        
        def _set_local_weight_var(self, value):
                if not torch.is_tensor(value):
                        value = torch.as_tensor(value).to(self.raw_local_weight_var)

                self.initialize(raw_local_weight_var = self.raw_local_weight_var_constraint.inverse_transform(value))

        def omega_g(self, x):
                """Helper function for unnormalized weights"""

        def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: bool=False, **params):
                def _weight(center, sigma):
                        # Gaussian weight w(x1) * w(x2) = exp(-(|x1 - c|^2 + |x2 - c|^2) / (2 * sigma^2))
                        _d1 = (x1 - center).pow(2).sum(dim=-1)
                        _d2 = (x2 - center).pow(2).sum(dim=-1)
                        if diag:
                                return torch.exp(-(_d1 + _d2) / (2 * sigma**2))
                        return torch.exp(-(_d1.unsqueeze(-1) + _d2.unsqueeze(-2)) / (2 * sigma**2))

                res = ZeroLazyTensor() if not diag else 0
                _k_g = self.global_kernel(x1, x2, diag=diag).to_dense()
                res = res + _k_g * _weight(self.psi, self.sigma_g)
                for _kernel, local_var in zip(self.local_kernels, self.local_weight_var):
                        _k_l = _kernel(x1, x2, diag=diag).to_dense()
                        # local_var enters squared, i.e. it plays the role of the width sigma_l
                        res = res + _k_l * _weight(self.local_position, local_var)
                return res
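
The note in the constructor about logistic interval constraints can be illustrated numerically. The sketch below assumes a sigmoid-based interval map of the form `lower + (upper - lower) * sigmoid(raw)`. The function names are hypothetical and the code does not call GPyTorch.

```python
import math

def interval_transform(raw, lower=0.0, upper=1.0):
    """Map an unconstrained raw value into (lower, upper) with a sigmoid."""
    return lower + (upper - lower) / (1.0 + math.exp(-raw))

def interval_transform_grad(raw, lower=0.0, upper=1.0):
    """Derivative of the constrained value with respect to the raw value."""
    s = 1.0 / (1.0 + math.exp(-raw))
    return (upper - lower) * s * (1.0 - s)

for raw in (0.0, 5.0, 10.0):
    value = interval_transform(raw)
    grad = interval_transform_grad(raw)
    print(f"raw={raw:5.1f}  value={value:.6f}  grad={grad:.2e}")
```

Near either end of the interval the derivative with respect to the raw parameter becomes very small, so gradient-based optimizers can stall there even if the objective is not actually at an optimum.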