In [22]:
import numpy as np

from sklearn.gaussian_process.kernels import Matern
from abc import ABC, abstractmethod
import gpytorch
import torch

from gpytorch.lazy import KroneckerProductLazyVariable, MatmulLazyVariable
from gpytorch.lazy import NonLazyVariable
from gpytorch.lazy import LazyVariable
from gpytorch.lazy import DiagLazyVariable


In [23]:
class SpatialGrid(ABC):
    
    def __init__(self, coord_x, coord_y, step):
        """
        :param coord_x: numpy array of unique coordinates in dimension 1
        :param coord_y: numpy array of unique coordinates in dimension 2
        :param step:    distance between two neighbouring points along a single dimension
        """
        self.coord_x = coord_x
        self.coord_y = coord_y

        self.dimX = coord_x.shape[0]
        self.dimY = coord_y.shape[0]

        self.step = step

    @abstractmethod
    def get_cov_matrix(self, theta, kroneckerised=False, **kwargs):
        pass

class SpatialGridMatern(SpatialGrid):

    def __init__(self, coord_x, coord_y, step):
        super(SpatialGridMatern, self).__init__(coord_x, coord_y, step)

    def get_cov_matrix(self, theta, kroneckerised=False, **kwargs):
        """
        Compute the covariance matrix using the locations of the observations. The locations are observed in 2-D here.
        :param lengthscaleX: lengtscale of the Matern covariance function in the 'longitude' direction
        :param lengthscaleY: lengtscale of the Matern covariance function in the 'latitude' direction
        :param kroneckerised: a boolean flag indicating whether to return a list of covariance matrices for each
                              dimension separately, or take the kronecker product of them.
        :param smoothness: Smmothness parameter used for both latitude and longitude covariance functions.
        :return: If kroeckerised is False return the list of covariance matrices (one for each direction). Otherwise
                 take the kronecker product of the matrices (in the order of 'longitude, latitude, temporal'.
        """

        variance = theta[0]
        lengthscaleX = theta[1]
        lengthscaleY = theta[2]
        smoothness = kwargs['smoothness'] if 'smoothness' in kwargs else 2.5

        k1 = Matern(length_scale=lengthscaleX, nu=smoothness)
        k2 = Matern(length_scale=lengthscaleY, nu=smoothness)

        x_coordinates = np.reshape(self.coord_x, (self.dimX, 1))
        y_coordinates = np.reshape(self.coord_y, (self.dimY, 1))

        # we need to split the signal variance into two parts
        K1 = np.sqrt(variance) * k1(x_coordinates)
        K2 = np.sqrt(variance) * k2(y_coordinates)
        return np.kron(K1, K2) if kroneckerised else [K1, K2]


In [24]:
def compute_log_det(np_Ks, np_W):
    """
    Compute logDet(I + KW)
    :param np_Ks: list of 2-D numpy arrays
    :param np_W: a diagonal of a matrix passed as a 1-D array
    :return: logDet(I + KW)
    """
    with gpytorch.settings.max_cg_iterations(50), gpytorch.settings.max_lanczos_quadrature_iterations(
            30), gpytorch.settings.num_trace_samples(30):
        Ks = [NonLazyVariable(torch.from_numpy(Ki)) for Ki in np_Ks]

        K = KroneckerProductLazyVariable(*Ks)
        W = DiagLazyVariable(torch.from_numpy(np_W))

        # return LgcpLogDetVar(K, W).log_det()
        #
        temp_var1 = MatmulLazyVariable(K, W).add_diag(torch.Tensor([1]).double())
        return temp_var1.log_det().item()

In [46]:
M = 30
N = 30

x_coordinates = np.linspace(1, M, M)
y_coordinates = np.linspace(1, N, N)

sigma_sq = 4
lengthscale = 1.2

# Make a grid
grid = SpatialGridMatern(x_coordinates, y_coordinates, 1)

# Generate a list of covariance matrices (one per dimension) with specified parameters
Ks = grid.get_cov_matrix((sigma_sq, lengthscale, lengthscale))

# Randomly generate W -> roughly corresponds to our real use case
W = np.exp(np.random.normal(0, 4, M*N))

exact_sign, exact_value = np.linalg.slogdet(np.identity(M*N) + np.dot(np.kron(Ks[0], Ks[1]), np.diagflat(W)))

print("Lanczos: {}".format(compute_log_det(Ks, W)))
print("Exact: {}".format(exact_sign * exact_value))


Lanczos: 111.23140137548211
Exact: 1929.1047405503373
