# Check have I implemented the nystrom approximation speed up correctly...

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
import math
import copy


# ******************* return log probability ****************
# - distance based gp prior 
class GP():
    def __init__(self, n_vx, **kwargs):
        """
        Gaussian process        

        N(m,K+nugget)
        -> we construct m, the mean function
        -> we construct K, the covariance function
        -> with a nugget, for stability        

        Args:
            **kwargs: Optional parameters for controlling behavior, such as:
                - psd_control: Method for ensuring positive semidefiniteness.
                - gp_dtype: Data type for tensor conversion.
                - kernel: Choice of covariance function (default: 'RBF').
        """
        self.n_vx = tf.Variable(n_vx, dtype=tf.int32, name="n_vx")

        # Setup distance matrix and positive semidefinite control
        self.psd_control   = kwargs.get('psd_control', 'euclidean')  # 'euclidean' or 'none'
        self.eps           = kwargs.get('eps', 1e-6)
        self.embedding_dim = kwargs.get('embedding_dim', 10)
        self.gp_dtype   = kwargs.get('gp_dtype', tf.float64)

        self.stat_kernel_list = []
        self.lin_kernel_list = []
        self.warp_kernel_list = []
        self.mfunc_list = []
        self.mfunc_bijector = tfb.Identity()
        self.Xs = {}
        self.dXs = {}
        self.n_inducers = None  # if None, use all points
        self.inducer_idx = None  # if None, use all points
        self.nystrom = False
        self.kernel_type = {}
        self.pids = {}
        # Index of parameters to be passed...
        self.pids[0] = 'gpk_nugget' # Global nugget term
        self.pids[1] = 'mfunc_mean' # Global mean term 
        self.pids_inv = {}
        self._update_pids_inv()
        self.return_log_prob = self._return_log_prob_unfixed # by default, return log prob unfixed...
        self.gp_prior_dist = None

    def _update_pids_inv(self):
        self.pids_inv = {}
        self.pids_inv = {v:k for k,v in self.pids.items()}
    
    def update_n_vx(self, new_value):
        self.n_vx.assign(new_value)
                        
    # **************** MEAN FUNCTIONS ***************
    def add_xid_linear_mfunc(self, xid, **kwargs):
        ''' add linear mean function
        '''
        Xs = kwargs.get('Xs', None)
        Xs = tf.convert_to_tensor(Xs, dtype=self.gp_dtype)
        if len(Xs.shape) == 1:
            Xs = tf.expand_dims(Xs, axis=-1)
        self.Xs[xid] = Xs
        self.mfunc_list.append(xid)        
        Ds = self.Xs[xid].shape[1]
        for i in range(Ds):
            self.pids[len(self.pids)] = f'mfunc{xid}_slope{i}'

        # Update the inverse dictionary
        self._update_pids_inv()
    
    @tf.function
    def _return_mfunc(self, **kwargs):
        '''Return the mean function
        '''
        # Start of with zero then add global mean
        m_out = tf.zeros(self.n_vx, dtype=self.gp_dtype) + tf.cast(kwargs['mfunc_mean'], self.gp_dtype) # global mean...
        # then add any regressors...
        for m in self.mfunc_list:
            slopes = tf.stack([kwargs[f'mfunc{m}_slope{i}'] for i in range(self.Xs[m].shape[1])], axis=0)  # [D]
            m_out += tf.reduce_sum(tf.cast(slopes, dtype=self.gp_dtype) * tf.transpose(self.Xs[m]), axis=0) 
        return self.mfunc_bijector(tf.cast(m_out, dtype=tf.float32))
    
    def add_mfunc_bijector(self, bijector_type, **kwargs):
        ''' add transformations to parameters so that they are fit smoothly        
        
        identity        - do nothing
        softplus        - don't let anything be negative

        '''
        if bijector_type == 'identity':
            self.mfunc_bijector = tfb.Identity()        
        elif bijector_type == 'softplus':
            # Don't let anything be negative
            self.mfunc_bijector = tfb.Softplus()
        elif bijector_type == 'sigmoid':
            self.mfunc_bijector = tfb.Sigmoid(
                low=kwargs.get('low'), high=kwargs.get('high'),
            )
        else:
            self.mfunc_bijector = bijector_type

    # *************************************************
    # *************************************************
    # *************************************************
    
    # *** KERNELS ***
    # -> add Stationary kernels 
    def add_xid_stationary_kernel(self, xid, **kwargs):
        ''' add a kernel 
        '''
        Xs = kwargs.get('Xs', None)
        dXs = kwargs.get('dXs', None)
        psd_control = kwargs.get('psd_control', self.psd_control)
        embedding_dim = kwargs.get('embedding_dim', self.embedding_dim)
        self.kernel_type[xid] = kwargs.get('kernel_type', 'RBF')                
        self.stat_kernel_list.append(xid)

        if dXs is None:
            # Get distances from 
            dXs = compute_euclidean_distance_matrix(Xs[...,np.newaxis])        
        if psd_control == 'euclidean':
            print('Embedding in Euclidean space...')
            dXs = mds_embedding(dXs, embedding_dim)
            dXs = compute_euclidean_distance_matrix(dXs)
        self.dXs[xid] = tf.convert_to_tensor(dXs, dtype=self.gp_dtype)
        self.dXs[xid] = (self.dXs[xid] + tf.transpose(self.dXs[xid])) / 2.0        
        # Add a lengthscale & a variance
        self.pids[len(self.pids)] = f'gpk{xid}_l'
        self.pids[len(self.pids)] = f'gpk{xid}_v'

        # Update the inverse dictionary
        self._update_pids_inv()        
    
    # # -> add Linear kernels
    # def add_xid_linear_kernel(self, xid, **kwargs):
    #     ''' add a kernel
    #     '''
    #     Xs = kwargs.get('Xs', None)
    #     self.kernel_type[xid] = 'linear'
    #     self.lin_kernel_list.append(xid)
        
    #     self.Xs[xid] = tf.expand_dims(tf.convert_to_tensor(Xs, dtype=self.gp_dtype), axis=1)
    #     # Add a lengthscale & a variance
    #     self.pids[len(self.pids)] = f'gpk{xid}_slope'
    #     self.pids[len(self.pids)] = f'gpk{xid}_const'

    #     # Update the inverse dictionary
    #     self._update_pids_inv()    
    
    def add_xid_warp_kernel(self, xid, Xs, **kwargs):
        self.Xs[xid] = tf.convert_to_tensor(Xs, dtype=self.gp_dtype)
        self.pids[len(self.pids)] = f'gpk{xid}_v'
        # Distance for RBF type kernel comes from warped LBO
        for i in range(self.Xs[xid].shape[1]):
            self.pids[len(self.pids)] = f'gpk{xid}_w{i}' 
        self.warp_kernel_list.append(xid)
        self._update_pids_inv()
                        

    def add_nystrom_approximation(self, n_inducers, inducer_idx=None):
        ''' Use nystrom approximation to speed up the GP
        '''
        self.n_inducers = n_inducers
        self.inducer_idx = inducer_idx
        self.nystrom = True
        if self.inducer_idx is not None:
            self.inducer_idx = tf.convert_to_tensor(self.inducer_idx, dtype=tf.int32)
        else:
            self.inducer_idx = tf.random.shuffle(tf.range(self.n_vx))[:self.n_inducers]
        self.return_log_prob = self._return_log_prob_nystrom


    @tf.function
    def _return_sigma_full(self, **kwargs):
        ''' Putting all the kernels together - > return the full covariance matrix
        '''
        # Start covariance matrix from zero...
        s_out = tf.zeros((self.n_vx,self.n_vx), dtype=self.gp_dtype)
        
        # Add in any linear kernels
        for s in self.lin_kernel_list:
            s_out += self._return_sigma_xid_linear(
                gpk_slope=kwargs[f'gpk{s}_slope'],
                gpk_const=kwargs[f'gpk{s}_const'],
                Xs=self.Xs[s],
            )
        
        # # Add in any stationary kernels (e.g., RBF)
        # for s in self.stat_kernel_list:
        #     s_out += self._return_sigma_xid_stationary(
        #         gpk_l=kwargs[f'gpk{s}_l'],
        #         gpk_v=kwargs[f'gpk{s}_v'],
        #         dXs=self.dXs[s],
        #         kernel_type=self.kernel_type[s]
        #     )
        for s in self.warp_kernel_list:
            s_kernel_type = s.split('_')[-1]
            dXs = self._return_warp_dXs(
                s, **kwargs,
            )
            s_out += self._return_sigma_xid_stationary(
                gpk_l=1.0,
                gpk_v=kwargs[f'gpk{s}_v'],
                dXs=dXs,
                kernel_type=s_kernel_type
            )            
        # Add the nugget term
        s_out += tf.linalg.diag(tf.ones(self.n_vx, dtype=self.gp_dtype)) * tf.cast(self.eps + kwargs[f'gpk_nugget'], dtype=self.gp_dtype)
        return s_out        
    
    @tf.function
    def _return_warp_dXs(self, xid, **kwargs):
        # [1] Weighted sum of eigenvectors (LBOwarp)
        wX = tf.stack(
            [kwargs[f"gpk{xid}_w{i}"] for i in range(self.Xs[xid].shape[1])],
            axis=0
        )
        wX = tf.cast(wX, dtype=self.gp_dtype)
        # Warped distances
        warp_X = tf.matmul(self.Xs[xid], wX) # [N, 1]
        warp_X = tf.squeeze(warp_X, axis=-1) # [N,]            
        warp_dXs = compute_euclidean_distance_matrix(warp_X[...,tf.newaxis])                
        return warp_dXs 
    
    @tf.function
    def _return_sigma_xid_stationary(self, gpk_l, gpk_v, dXs, kernel_type):
        """
        Computes the covariance matrix using the chosen kernel.

        Args:
            gp_l (float): Lengthscale parameter.
            gp_v (float): Variance parameter.

        Returns:
            tf.Tensor: Covariance matrix.
        """
        gpk_v = tf.cast(gpk_v, dtype=self.gp_dtype)
        gpk_l = tf.cast(gpk_l, dtype=self.gp_dtype)

        if kernel_type == 'RBF':
            cov_matrix = tf.square(gpk_v) * tf.exp(
                -tf.square(dXs) / (2.0 * tf.square(gpk_l))
            )
        elif kernel_type == 'matern52':
            sqrt5 = tf.cast(tf.sqrt(5.0), dtype=self.gp_dtype)
            frac1 = (sqrt5 * dXs) / gpk_l
            frac2 = (5.0 * tf.square(dXs)) / (3.0 * tf.square(gpk_l))
            cov_matrix = tf.square(gpk_v) * (1 + frac1 + frac2) * tf.exp(-frac1)
        elif kernel_type == 'laplace':
            cov_matrix = tf.square(gpk_v) * tf.exp(-dXs / gpk_l)
        else:
            raise ValueError("Unsupported kernel: {}".format(kernel_type))
        # Add nugget term for numerical stability
        return cov_matrix 
    
    # @tf.function
    # def _return_sigma_xid_linear(self, gpk_slope, gpk_const, Xs):
    #     '''linear kernel
    #     '''
    #     gpk_slope = tf.cast(gpk_slope, dtype=self.gp_dtype)
    #     gpk_const = tf.cast(gpk_const, dtype=self.gp_dtype)        
    #     cov_matrix = gpk_slope**2 * (Xs-gpk_const) * (tf.transpose(Xs) - gpk_const)
    #     return cov_matrix

    def set_log_prob_fixed(self,**kwargs):
        # Create a one off covariance matrix -> then use it to get probability each time...
        # Get cov matrix
        self.cov_matrix = self._return_sigma_full(**kwargs)
        self.chol = tf.linalg.cholesky(tf.cast(self.cov_matrix, dtype=self.gp_dtype))
        # Get mean vector
        self.m_vect = self._return_mfunc(**kwargs)

        self.gp_prior_dist = tfd.MultivariateNormalTriL(
            loc=tf.squeeze(tf.cast(self.m_vect, dtype=tf.float32)), 
            scale_tril=tf.cast(self.chol, dtype=tf.float32),
            allow_nan_stats=False,
        )
        self.return_log_prob = self._return_log_prob_fixed
        
    @tf.function
    def _return_log_prob_nystrom(self, parameter, **kwargs):
        ''' Return the log probability using nystrom approximation
        '''
        gpk_nugget = tf.cast(kwargs['gpk_nugget'], dtype=self.gp_dtype)        
        m_vect = self._return_mfunc(**kwargs)        
        parameter_dm = tf.cast(parameter - m_vect, self.gp_dtype) # remove mean function from parameter

        K_full = self._return_sigma_full(**kwargs) # we do the nugget later, so have to remove it here...        
        # might change this later...
        K_full -= tf.linalg.diag(tf.ones(self.n_vx, dtype=self.gp_dtype)) * tf.cast(self.eps + kwargs[f'gpk_nugget'], dtype=self.gp_dtype)
        A=tf.gather(tf.gather(K_full, self.inducer_idx, axis=0), self.inducer_idx, axis=1)
        B=tf.gather(K_full, self.inducer_idx, axis=1)
        # Add small jitter to A for PD-ness
        A += tf.cast(self.eps, self.gp_dtype) * tf.eye(self.n_inducers, dtype=self.gp_dtype)

        # Build S = A + (1/nugget) B^T B  (m x m)
        BtB = tf.matmul(tf.transpose(B), B)   # (m,m)
        S = A + (1.0 / gpk_nugget) * BtB

        # Cholesky S
        Ls = tf.linalg.cholesky(S) # (m,m)
        # Solve S x = B^T y
        Bt_y = tf.matmul(tf.transpose(B), tf.expand_dims(parameter_dm, -1))  # (m,1)
        x = tf.linalg.cholesky_solve(Ls, Bt_y)                       # (m,1)

        # Compute quadratic term via Woodbury K^{-1} y = (1/sigma2) y - (1/sigma2^2) B x
        Bx = tf.matmul(B, x)                                        # (n,1)
        v = (1.0 / gpk_nugget) * tf.expand_dims(parameter_dm, -1) - (1.0 / (gpk_nugget * gpk_nugget)) * Bx
        quad = tf.squeeze(tf.matmul(tf.transpose(tf.expand_dims(parameter_dm, -1)), v))  # scalar
        # Log-determinant via determinant lemma:
        # log|K| = n log sigma2 - log|A| + log|S|
        La = tf.linalg.cholesky(A)
        logdetA = 2.0 * tf.reduce_sum(tf.math.log(tf.linalg.diag_part(La)))
        logdetS = 2.0 * tf.reduce_sum(tf.math.log(tf.linalg.diag_part(Ls)))
        n_float = tf.cast(self.n_vx, self.gp_dtype)
        logdetK = n_float * tf.math.log(gpk_nugget) - logdetA + logdetS        
        log2pi = tf.math.log(2.0 * tf.constant(math.pi, dtype=self.gp_dtype))
        logp = -0.5 * quad - 0.5 * logdetK - 0.5 * n_float * log2pi
        return tf.cast(tf.reshape(logp, []), tf.float32)

    @tf.function
    def _return_log_prob_unfixed(self, parameter, **kwargs):
        """
        Unfixed parameters using TensorFlow distribution.
        Recompute covariance and Cholesky decomposition on the fly.
        Optionally uses random selection of n_inducers for sparse GP approximation.
        """
        # Get cov matrix
        cov_matrix = self._return_sigma_full(**kwargs)
        chol = tf.linalg.cholesky(tf.cast(cov_matrix, dtype=self.gp_dtype))
        # Get mean vector
        m_vect = self._return_mfunc(**kwargs)
        gp_prior_dist = tfd.MultivariateNormalTriL(
            loc=tf.squeeze(tf.cast(m_vect, dtype=tf.float32)), #tf.fill([self.n_vx], tf.squeeze(m_vect)),
            scale_tril=tf.cast(chol, dtype=tf.float32),
            allow_nan_stats=False,
        )
        return gp_prior_dist.log_prob(parameter)

    @tf.function
    def _return_log_prob_fixed(self, parameter, **kwargs):
        """
        Unfixed parameters using TensorFlow distribution.
        Recompute covariance and Cholesky decomposition on the fly.
        Optionally uses random selection of n_inducers for sparse GP approximation.
        """
        return self.gp_prior_dist.log_prob(parameter)


2025-09-03 16:19:36.504628: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-09-03 16:19:36.509357: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-09-03 16:19:36.519479: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756909176.536406 3586931 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756909176.541528 3586931 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1756909176.556251 3586931 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linkin

In [7]:
n_loop = 10
nvx = 1000
m = 90

x = np.random.rand(nvx)
dXs = x[None,:] - x[...,None]

gpN = GP(n_vx=nvx)
gpN.add_xid_stationary_kernel(
    xid='s', dXs=dXs, psd_control=None, 
)
gpN.add_nystrom_approximation(n_inducers=m)

# Generate all random parameters in advance
random_params = [np.random.rand(nvx) for _ in range(n_loop)]

# --- Timing the Nystrom approximation loop ---
from time import time
print("Timing Nystrom approximation...")
start_nystrom = time()
No = []
for i in range(n_loop):
    eg_rand = random_params[i]
    No.append(gpN._return_log_prob_nystrom(
        parameter=tf.cast(eg_rand, tf.float32),
        mfunc_mean=0.0,
        gpks_l=1.0,
        gpks_v=1.0,
        gpk_nugget=0.1,
    ).numpy())
end_nystrom = time()
nystrom_time = end_nystrom - start_nystrom
print(f"Nystrom approximation loop took: {nystrom_time:.4f} seconds")

# --- Timing the unfixed loop ---
print("\nTiming unfixed calculation...")
start_unfixed = time()
Fo = []
for i in range(n_loop):
    # Use the same random parameter as the Nystrom loop
    eg_rand = random_params[i]
    Fo.append(gpN._return_log_prob_unfixed(
        parameter=tf.cast(eg_rand, tf.float32),
        mfunc_mean=0.0,
        gpks_l=1.0,
        gpks_v=1.0,
        gpk_nugget=0.1,
    ).numpy())
end_unfixed = time()
unfixed_time = end_unfixed - start_unfixed
print(f"Unfixed calculation loop took: {unfixed_time:.4f} seconds")

# --- Comparing the results ---
print("\n--- Comparison ---")
if nystrom_time < unfixed_time:
    print(f"The Nystrom approximation was faster by a factor of: {unfixed_time / nystrom_time:.2f}")
elif unfixed_time < nystrom_time:
    print(f"The unfixed calculation was faster by a factor of: {nystrom_time / unfixed_time:.2f}")
else:
    print("The two calculations took approximately the same amount of time.")
No, Fo = np.array(No), np.array(Fo)
# You can now calculate the correlation if you wish
correlation = np.corrcoef(No, Fo)[0, 1]
mse=(np.diff(No-Fo)**2).mean()
print(f"\nCorrelation between Nystrom and Unfixed results: corr={correlation:.4f}, mse={mse}")

Timing Nystrom approximation...
Nystrom approximation loop took: 0.2023 seconds

Timing unfixed calculation...
Unfixed calculation loop took: 0.5916 seconds

--- Comparison ---
The Nystrom approximation was faster by a factor of: 2.92

Correlation between Nystrom and Unfixed results: corr=1.0000, mse=4.437234792931122e-07
