In [1]:
import numpy as np

from sklearn.gaussian_process.kernels import Matern
from abc import ABC, abstractmethod
import gpytorch
import torch

from gpytorch.lazy import KroneckerProductLazyVariable, MatmulLazyVariable
from gpytorch.lazy import NonLazyVariable
from gpytorch.lazy import LazyVariable
from gpytorch.lazy import DiagLazyVariable


In [2]:
class SpatialGrid(ABC):
    """
    Abstract base class for a 2-D grid described by the unique coordinates along each of its two dimensions.

    Concrete subclasses supply the covariance structure by implementing ``get_cov_matrix``.
    """

    def __init__(self, coord_x, coord_y, step):
        """
        Store the grid coordinates and record how many points lie along each dimension.

        :param coord_x: numpy array of unique coordinates in dimension 1
        :param coord_y: numpy array of unique coordinates in dimension 2
        :param step:    distance between two neighbouring points along a single dimension
        """
        self.coord_x, self.coord_y = coord_x, coord_y

        # Number of grid points per dimension, taken from the leading axis of each coordinate array.
        self.dimX, self.dimY = coord_x.shape[0], coord_y.shape[0]

        self.step = step

    @abstractmethod
    def get_cov_matrix(self, theta, kroneckerised=False, **kwargs):
        """
        Build the covariance of the grid for the hyper-parameters ``theta``.

        Must be overridden by subclasses; this base implementation does nothing and returns None.

        :param theta: hyper-parameters of the covariance function (interpretation is left to the subclass)
        :param kroneckerised: boolean flag selecting the form of the result (see the subclass documentation)
        :param kwargs: additional subclass-specific options
        """

class SpatialGridMatern(SpatialGrid):
    """
    2-D spatial grid whose covariance is built from one Matern kernel per dimension.
    """

    def __init__(self, coord_x, coord_y, step):
        super().__init__(coord_x, coord_y, step)

    def get_cov_matrix(self, theta, kroneckerised=False, **kwargs):
        """
        Compute the covariance matrix using the locations of the observations. The locations are observed in 2-D here.

        :param theta: indexable hyper-parameters: theta[0] is the signal variance, theta[1] the Matern
                      lengthscale in the x ('longitude') direction and theta[2] the Matern lengthscale in the
                      y ('latitude') direction
        :param kroneckerised: a boolean flag indicating whether to return the Kronecker product of the
                              per-dimension covariance matrices (True) or the matrices themselves (False)
        :param kwargs: optional 'smoothness' entry, passed as ``nu`` to both Matern kernels (default 2.5)
        :return: If kroneckerised is False, the list [K_x, K_y] of per-dimension covariance matrices. Otherwise
                 their Kronecker product, in the order (x, y).
        """
        signal_variance, lengthscaleX, lengthscaleY = theta[0], theta[1], theta[2]
        nu = kwargs.get('smoothness', 2.5)

        # The signal variance is split evenly between the two factors: each one is scaled by
        # sqrt(variance), so their Kronecker product carries the full variance.
        amplitude = np.sqrt(signal_variance)

        per_axis = (
            (self.coord_x, self.dimX, lengthscaleX),
            (self.coord_y, self.dimY, lengthscaleY),
        )

        factors = []
        for coords, n_points, lengthscale in per_axis:
            kernel = Matern(length_scale=lengthscale, nu=nu)
            # The kernel is evaluated on a column of coordinates: one row per grid point.
            locations = np.reshape(coords, (n_points, 1))
            factors.append(amplitude * kernel(locations))

        if kroneckerised:
            return np.kron(factors[0], factors[1])
        return factors


In [3]:
def compute_log_det(np_Ks, np_W):
    """
    Compute logDet(I + sqrt(W) * K * sqrt(W)) using the custom LgcpLogDetVar lazy variable,
    where K is the Kronecker product of the matrices in np_Ks.

    The computation runs under fixed gpytorch settings: at most 200 CG iterations, at most 180
    Lanczos quadrature iterations and 180 trace samples.

    :param np_Ks: list of 2-D numpy arrays (the Kronecker factors of K)
    :param np_W: a diagonal of a matrix passed as a 1-D array
    :return: logDet(I + sqrt(W) * K * sqrt(W)) as a Python number
    """
    with gpytorch.settings.max_cg_iterations(200), \
            gpytorch.settings.max_lanczos_quadrature_iterations(180), \
            gpytorch.settings.num_trace_samples(180):
        # Wrap every Kronecker factor so that gpytorch can combine them lazily.
        kron_factors = []
        for factor in np_Ks:
            kron_factors.append(NonLazyVariable(torch.from_numpy(factor)))

        prior_cov = KroneckerProductLazyVariable(*kron_factors)
        root_w = DiagLazyVariable(torch.from_numpy(np.sqrt(np_W)))

        operator = LgcpLogDetVar(prior_cov, root_w)
        return operator.log_det().item()


class LgcpLogDetVar(LazyVariable):
    """
    Lazy representation of the matrix I + sqrt(W) * K * sqrt(W).

    The matrix is never formed explicitly; only its product with a right-hand side is implemented,
    by delegating to the ``_matmul`` of the wrapped K and sqrtW lazy variables.
    """

    def __init__(self, K, sqrtW):
        """
        :param K: lazy variable for K; its size is reported as the size of this variable
        :param sqrtW: lazy variable for sqrt(W)
        """
        super().__init__(K, sqrtW)
        self.K = K
        self.sqrtW = sqrtW

    def _size(self):
        # I + sqrt(W) * K * sqrt(W) has the same size as K.
        return self.K.size()

    def _matmul(self, rhs):
        """
        Returns: (I + sqrt(W) * K * sqrt(W)) * rhs
        """
        # Work from the inside out: sqrtW * rhs, then K * (sqrtW * rhs).
        scaled_rhs = self.sqrtW._matmul(rhs)
        k_scaled_rhs = self.K._matmul(scaled_rhs)

        # The identity term contributes rhs itself; add the outer sqrtW * (K * sqrtW * rhs).
        return rhs + self.sqrtW._matmul(k_scaled_rhs)


In [4]:
def compute_log_det2(np_Ks, np_W):
    """
    Compute logDet(I + sqrt(W) * K * sqrt(W)) using only gpytorch's built-in lazy variables:
    two nested MatmulLazyVariable products with a unit diagonal added on top.

    Runs under the same gpytorch settings as the custom variant: at most 200 CG iterations, at most
    180 Lanczos quadrature iterations and 180 trace samples.

    :param np_Ks: list of 2-D numpy arrays whose Kronecker product is K
    :param np_W: a diagonal of a matrix passed as a 1-D array
    :return: the log-determinant as a Python number
    """
    with gpytorch.settings.max_cg_iterations(200), \
            gpytorch.settings.max_lanczos_quadrature_iterations(180), \
            gpytorch.settings.num_trace_samples(180):
        kron_factors = []
        for factor in np_Ks:
            kron_factors.append(NonLazyVariable(torch.from_numpy(factor)))

        prior_cov = KroneckerProductLazyVariable(*kron_factors)
        root_w = DiagLazyVariable(torch.from_numpy(np.sqrt(np_W)))

        # sqrt(W) * (K * sqrt(W)), built as two nested lazy products.
        k_root_w = MatmulLazyVariable(prior_cov, root_w)
        sandwich = MatmulLazyVariable(root_w, k_root_w)

        # Double-precision one, matching the float64 tensors produced by torch.from_numpy above.
        unit_diag = torch.Tensor([1]).double()
        full_matrix = sandwich.add_diag(unit_diag)
        return full_matrix.log_det().item()

In [5]:
# Grid size: M points along x, N points along y.
M = 20
N = 20

# Fix the random state so that W below is identical on every Restart & Run All.
# torch is seeded as well: NOTE(review) the gpytorch log-det estimators presumably draw their
# trace-sample probe vectors from torch's global RNG -- confirm for the installed version.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

x_coordinates = np.linspace(1, M, M)
y_coordinates = np.linspace(1, N, N)

# Matern hyper-parameters: signal variance and a lengthscale shared by both dimensions.
sigma_sq = 4
lengthscale = 1.2

# Make a grid
grid = SpatialGridMatern(x_coordinates, y_coordinates, 1)

# Generate a list of covariance matrices (one per dimension) with specified parameters
Ks = grid.get_cov_matrix((sigma_sq, lengthscale, lengthscale))

# Randomly generate W -> roughly corresponds to our real use case
W = np.exp(np.random.normal(0, 4, M*N))

W_sqrt = np.sqrt(W)

# exact_2_sign, exact_2_value = np.linalg.slogdet(np.identity(M*N) + np.dot(np.diagflat(W_sqrt), np.dot(np.kron(Ks[0], Ks[1]), np.diagflat(W_sqrt))))

# matrix = np.identity(M*N) + np.dot(np.diagflat(W_sqrt), np.dot(np.kron(Ks[0], Ks[1]), np.diagflat(W_sqrt)))
# print("Is symmetric: {}".format(np.allclose(matrix, matrix.T)))


# exact_sign, exact_value = np.linalg.slogdet(np.identity(M*N) + np.dot(np.kron(Ks[0], Ks[1]), np.diagflat(W)))

# print("Lanczos: {}".format(compute_log_det(Ks, W)))
# print("Exact: {}".format(exact_sign * exact_value))
# print("Exact other: {}".format(exact_2_sign * exact_2_value))


In [6]:
from time import process_time as timer

In [10]:
# Time the custom LgcpLogDetVar-based estimator; `timer` is time.process_time, so this is CPU
# time of the current process in seconds.
t_start = timer()
custom_method_result = compute_log_det(Ks, W)
t_end = timer()

custom_method_seconds = t_end - t_start
print(custom_method_seconds)

4.292205000000003


In [11]:
# Time the estimator built from gpytorch's stock lazy variables; `timer` is time.process_time,
# so this is CPU time of the current process in seconds.
t_start2 = timer()
default_method_result = compute_log_det2(Ks, W)
t_end2 = timer()

default_method_seconds = t_end2 - t_start2
print(default_method_seconds)

23.037360000000007


In [12]:
# Show both log-determinant estimates, one per line: custom method first, then the default one.
print(custom_method_result, default_method_result, sep="\n")

791.8336858465867
800.5521823259061
