In [123]:
import torch

import matplotlib.pyplot as plt

import math

In [130]:
class MaternKernel:

    """ Matern kernel functions k(tau) for the 1/2, 3/2 and 5/2 smoothness orders """

    @staticmethod
    def _resolve_lengthscale(
        lengthscale : float,
        legnthscale : float,
        ) -> float:
        """
        Returns the lengthscale, accepting the legacy misspelled keyword `legnthscale`
        that earlier versions of these functions used as the parameter name.

        Raises:
            TypeError: if neither or both spellings are supplied
        """
        if legnthscale is None:
            if lengthscale is None:
                raise TypeError("missing required argument: 'lengthscale'")
            return lengthscale
        if lengthscale is not None:
            raise TypeError("got values for both 'lengthscale' and legacy 'legnthscale'")
        return legnthscale

    @staticmethod
    def _scaled_distance(
        tau : torch.Tensor,
        lengthscale : float,
        ) -> torch.Tensor:
        """ Returns |tau| / lengthscale; abs() keeps the kernel symmetric if a signed difference x - x' is passed """
        return torch.as_tensor(tau).abs() / lengthscale

    @staticmethod
    def matern_12(
        tau : torch.Tensor,
        sigma : float,
        lengthscale : float = None,
        *,
        legnthscale : float = None,
        ) -> torch.Tensor:
        """
        Matern 1/2 kernel:  k(tau) = sigma^2 * exp(-|tau| / lengthscale)

        Arguments:
            tau (torch.Tensor)      : distance between points (i.e. |x - x'|)
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter
            legnthscale (float)     : deprecated keyword-only alias of `lengthscale`

        Returns:
            torch.Tensor: Kernel weighting for distance tau
        """
        ell = MaternKernel._resolve_lengthscale(lengthscale, legnthscale)
        r = MaternKernel._scaled_distance(tau, ell)

        return (sigma ** 2) * torch.exp(-r)

    @staticmethod
    def matern_32(
        tau : torch.Tensor,
        sigma : float,
        lengthscale : float = None,
        *,
        legnthscale : float = None,
        ) -> torch.Tensor:
        """
        Matern 3/2 kernel:  k(tau) = sigma^2 * (1 + s) * exp(-s),  s = sqrt(3) |tau| / lengthscale

        Arguments:
            tau (torch.Tensor)      : distance between points (i.e. |x - x'|)
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter
            legnthscale (float)     : deprecated keyword-only alias of `lengthscale`

        Returns:
            torch.Tensor: Kernel weighting for distance tau
        """
        ell = MaternKernel._resolve_lengthscale(lengthscale, legnthscale)
        # the same scaled distance s appears in both the polynomial factor and the exponent
        s = math.sqrt(3) * MaternKernel._scaled_distance(tau, ell)

        return (sigma ** 2) * (1 + s) * torch.exp(-s)

    @staticmethod
    def matern_52(
        tau : torch.Tensor,
        sigma : float,
        lengthscale : float = None,
        *,
        legnthscale : float = None,
        ) -> torch.Tensor:
        """
        Matern 5/2 kernel:  k(tau) = sigma^2 * (1 + s + s^2 / 3) * exp(-s),  s = sqrt(5) |tau| / lengthscale

        (s^2 / 3 equals 5 tau^2 / (3 lengthscale^2).)

        Arguments:
            tau (torch.Tensor)      : distance between points (i.e. |x - x'|)
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter
            legnthscale (float)     : deprecated keyword-only alias of `lengthscale`

        Returns:
            torch.Tensor: Kernel weighting for distance tau
        """
        ell = MaternKernel._resolve_lengthscale(lengthscale, legnthscale)
        s = math.sqrt(5) * MaternKernel._scaled_distance(tau, ell)

        return (sigma ** 2) * (1 + s + (s ** 2) / 3) * torch.exp(-s)

In [127]:
class SpectralDensity:

    """ Spectral densities S(omega) of the 1-D Matérn covariances (orders 1/2, 3/2, 5/2) """

    @staticmethod
    def matern_12(
        omega : torch.Tensor,
        sigma : float,
        lengthscale : float
        ) -> torch.Tensor:
        """
        Spectral density of the Matérn 1/2 covariance

            S(omega) = 2 sigma^2 lam / (lam^2 + omega^2),   lam = 1 / lengthscale

        Arguments:
            omega (torch.Tensor)    : frequency
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter

        Returns:
            (torch.Tensor)          : spectral density evaluated at omega
        """
        lam = 1 / lengthscale
        scale = 2 * (sigma ** 2) * lam
        return scale / ((lam ** 2) + (omega ** 2))

    @staticmethod
    def matern_32(
        omega : torch.Tensor,
        sigma : float,
        lengthscale : float
        ) -> torch.Tensor:
        """
        Spectral density of the Matérn 3/2 covariance

            S(omega) = 4 sigma^2 lam^3 / (lam^2 + omega^2)^2,   lam = sqrt(3) / lengthscale

        Arguments:
            omega (torch.Tensor)    : frequency
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter (converted internally to lam)

        Returns:
            (torch.Tensor)          : spectral density evaluated at omega
        """
        lam = math.sqrt(3) / lengthscale
        base = (lam ** 2) + (omega ** 2)
        return 4 * (sigma ** 2) * (lam ** 3) / (base ** 2)

    @staticmethod
    def matern_52(
        omega : torch.Tensor,
        sigma : float,
        lengthscale : float
        ) -> torch.Tensor:
        """
        Spectral density of the Matérn 5/2 covariance

            S(omega) = (16/3) sigma^2 lam^5 / (lam^2 + omega^2)^3,   lam = sqrt(5) / lengthscale

        Arguments:
            omega (torch.Tensor)    : frequency
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter (converted internally to lam)

        Returns:
            (torch.Tensor)          : spectral density evaluated at omega
        """
        lam = math.sqrt(5) / lengthscale
        base = (lam ** 2) + (omega ** 2)
        return (16 / 3) * (sigma ** 2) * (lam ** 5) / (base ** 3)

In [128]:
# TODO: the omegas may need to be passed to the class as parameters so that gradient updates can be performed; alternatively, a setter method might also work

class FourierBasis:

    """ Constructs a Fourier Basis for Variational Fourier Features """

    def __init__( self, M : int, a : float, b : float ):
        """
        Arguments:
            M (int)   : number of non-zero frequencies (must be >= 0)
            a (float) : lower bound of the input interval
            b (float) : upper bound of the input interval (must be > a)

        Raises:
            ValueError: if M < 0 or b <= a
        """
        if M < 0:
            raise ValueError(f"M must be non-negative, got {M}")
        if not b > a:
            raise ValueError(f"upper bound b must exceed lower bound a, got a={a}, b={b}")

        self.a = a
        self.b = b
        self.M = M
        # harmonic frequencies omega_m = 2*pi*m / (b - a) for m = 0..M, so omegas[0] == 0
        self.omegas = (2 * torch.pi / (b - a)) * torch.arange(M + 1, dtype=torch.get_default_dtype())

    def __call__( self, x : float ) -> torch.Tensor:
        """
        Evaluates the Fourier Basis at a single point x

        Arguments:
            x (float) : input location (a scalar; multi-element tensors are not supported)

        Returns:
            torch.Tensor: [cos(omega_0 (x - a)), ..., cos(omega_M (x - a)),
                           sin(omega_1 (x - a)), ..., sin(omega_M (x - a))]  -- length 2M + 1
        """
        phase = self.omegas * (x - self.a)
        # the omega_0 cosine is the constant 1; the omega_0 sine is identically 0 and is dropped
        return torch.cat((torch.cos(phase), torch.sin(phase[1:])))

In [129]:
class Matern12Kuf:

    """ Cross-covariance Kuf between the VFF inducing variables u and f(x) for the Matérn 1/2 kernel """

    def __call__( self, fourier_basis : "FourierBasis", x : float ) -> torch.Tensor:

        """
        Returns the cross-covariance cov(u, f(x)) between the Fourier (inducing) domain and the input domain

        Arguments:
            fourier_basis (FourierBasis) : callable basis; fourier_basis(x) evaluates every feature at x
            x (float)                    : input location

        Returns:
            torch.Tensor: fourier_basis(x), i.e. the basis features evaluated at x

        NOTE(review): cov(u_m, f(x)) = phi_m(x) holds for x inside [a, b]; points outside
        the interval are not treated specially here.
        """
        return fourier_basis(x)

In [None]:
class Matern12Kuu:

    """ Returns the covariances between the inducing variables in the transformed domain using the Matérn 1/2 kernel """

    @staticmethod
    def _alpha(
            omegas : torch.Tensor,
            sigma : float,
            lengthscale : float,
            a : float,
            b : float
            ) -> torch.Tensor:
        """
        Computes the diagonal (alpha) part of the Kuu representation for the Matérn 1/2 covariances

            alpha = (b - a)/2 * [2 S(0)^-1, S(omega_1)^-1, ..., S(omega_M)^-1, S(omega_1)^-1, ..., S(omega_M)^-1]

        Arguments:
            omegas (torch.Tensor)   : frequency (! omegas[0] = 0 !)
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter
            a (float)               : lower bound of the input space
            b (float)               : upper bound of the input space

        Returns:
            (torch.Tensor)          : alpha, of length 2 * len(omegas) - 1
        """
        # the constant (omega = 0) feature gets the doubled weight, so it must come first
        assert omegas[0] == 0, "The first element of omegas must be 0"

        S_inv = 1 / SpectralDensity.matern_12(omegas, sigma, lengthscale)

        # cosine block uses all frequencies; sine block omits omega_0 (same ordering as FourierBasis)
        return ((b - a) / 2) * torch.cat((2 * S_inv[:1], S_inv[1:], S_inv[1:]))

    @staticmethod
    def _beta(
            omegas : torch.Tensor,
            sigma : float
            ) -> torch.Tensor:
        """
        Computes the rank-one (beta) part of the Kuu representation for the Matérn 1/2 covariances

            beta = [1/sigma, ..., 1/sigma (len(omegas) entries), 0, ..., 0 (len(omegas) - 1 entries)]

        The 1/sigma entries come from the boundary term f(a) g(a) / sigma^2 of the Matérn 1/2
        RKHS inner product (VFF paper): cosine features equal 1 at x = a, sine features equal 0.

        Arguments:
            omegas (torch.Tensor)   : frequency
            sigma (float)           : amplitude hyperparameter

        Returns:
            (torch.Tensor)          : beta
        """
        n = len(omegas)
        dtype = omegas.dtype if omegas.is_floating_point() else torch.get_default_dtype()

        # dividing by sigma (instead of wrapping it in torch.tensor) keeps a tensor sigma in the autograd graph
        cos_half = torch.ones(n, dtype=dtype, device=omegas.device) / sigma
        sin_half = torch.zeros(n - 1, dtype=dtype, device=omegas.device)

        return torch.cat((cos_half, sin_half))

    def __call__( self, omegas : torch.Tensor, sigma : float, lengthscale : float, a : float, b : float ) -> torch.Tensor:

        """
        Computes the Kuu using the representation given by (62) in the VFF paper for the Matérn 1/2 covarainces

            Kuu = diag(alpha) + beta beta^T

        Arguments:
            omegas (torch.Tensor)   : frequency (! omegas[0] = 0 !)
            sigma (float)           : amplitude hyperparameter
            lengthscale (float)     : lengthscale hyperparameter
            a (float)               : lower bound of the input space
            b (float)               : upper bound of the input space

        Returns:
            (torch.Tensor)          : Kuu, a square matrix of side 2 * len(omegas) - 1
        """
        alphas = self._alpha(omegas, sigma, lengthscale, a, b)
        betas = self._beta(omegas, sigma)

        # the low-rank term is the outer product beta beta^T, not the scalar inner product
        return torch.diag(alphas) + torch.outer(betas, betas)