In [1]:
import matplotlib.pyplot as plt 
import time 
from tqdm import tqdm 
import random 
import torch 
from torch.utils.data import random_split, DataLoader, Dataset
import gpytorch
from gpytorch.kernels import ScaleKernel, RBFKernel
from sklearn.cluster import KMeans
from linear_operator.settings import max_cholesky_size

import sys 
sys.path.append("../")
from gp.util import dynamic_instantiation, flatten_dict, unflatten_dict, flatten_dataset, split_dataset, filter_param, heatmap

# System/Library imports
from typing import *

# Common data science imports
import numpy as np
import torch

# Gpytorch and linear_operator
import gpytorch 
import gpytorch.constraints
from gpytorch.kernels import ScaleKernel
import linear_operator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky

# Our imports
from gp.soft_gp.mll import HutchinsonPseudoLoss
from linear_solver.cg import linear_cg



  from .autonotebook import tqdm as notebook_tqdm


### Soft GP testing profiling/tuning boilerplate
its recommend to collapse all functions with Ctrl/Cmnd +k +0 

### SoftGP baseline implementation 
---


In [2]:
class SoftGP_baseline(torch.nn.Module):
    def __init__(
        self,
        kernel: Callable,
        inducing_points: torch.Tensor,
        noise=1e-3,
        learn_noise=False,
        use_scale=False,
        device="cpu",
        dtype=torch.float32,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=False,
    ) -> None:
        # Argument checking 
        methods = ["solve", "cholesky", "cg"]
        if not solver in methods:
            raise ValueError(f"Method {solver} should be in {methods} ...")
        
        # Check devices
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices += ["cuda"]
            for i in range(torch.cuda.device_count()):
                devices += [f"cuda:{i}"]
        if not device in devices:
            raise ValueError(f"Device {device} should be in {devices} ...")

        # Create torch module
        super(SoftGP_baseline, self).__init__()

        # Misc
        self.device = device
        self.dtype = dtype
        
        # Mll approximation settings
        self.solve_method = solver
        self.mll_approx = mll_approx

        # Fit settings
        self.use_qr = use_qr
        self.fit_chunk_size = fit_chunk_size

        # Noise
        self.noise_constraint = gpytorch.constraints.Positive()
        noise = torch.tensor([noise], dtype=self.dtype, device=self.device)
        noise = self.noise_constraint.inverse_transform(noise)
        if learn_noise:
            self.register_parameter("raw_noise", torch.nn.Parameter(noise))
        else:
            self.raw_noise = noise

        # Kernel
        self.use_scale = use_scale
        if use_scale:
            self.kernel = ScaleKernel(kernel).to(self.device)
        else:
            self.kernel = kernel.to(self.device)

        # Inducing points
        self.register_parameter("inducing_points", torch.nn.Parameter(inducing_points))

        # Interpolation
        def softmax_interp(X: torch.Tensor, sigma_values: torch.Tensor) -> torch.Tensor:
            distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
            softmax_distances = torch.softmax(-distances, dim=-1)
            return softmax_distances
        self.interp = softmax_interp
        
        # Fit artifacts
        self.alpha = None
        self.K_zz_alpha = None

        # CG solver params
        self.max_cg_iter = max_cg_iter
        self.cg_tol = cg_tolerance
        self.x0 = None
        
    # -----------------------------------------------------
    # Soft GP Helpers
    # -----------------------------------------------------
    
    @property
    def noise(self):
        return self.noise_constraint.transform(self.raw_noise)

    def get_lengthscale(self) -> float:
        if self.use_scale:
            return self.kernel.base_kernel.lengthscale.cpu()
        else:
            return self.kernel.lengthscale.cpu()
        
    def get_outputscale(self) -> float:
        if self.use_scale:
            return self.kernel.outputscale.cpu()
        else:
            return 1.

    def _mk_cov(self, z: torch.Tensor) -> torch.Tensor:
        return self.kernel(z, z).evaluate()
    
    def _interp(self, x: torch.Tensor) -> torch.Tensor:
        x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
        W_xz = self.interp(x_expanded, self.inducing_points)
        return W_xz

    # -----------------------------------------------------
    # Linear solver
    # -----------------------------------------------------

    def _solve_system(
        self,
        kxx: linear_operator.operators.LinearOperator,
        full_rhs: torch.Tensor,
        x0: torch.Tensor = None,
        forwards_matmul: Callable = None,
        precond: torch.Tensor = None,
        return_pinv: bool = False,
    ) -> torch.Tensor:
        use_pinv = False
        with torch.no_grad():
            try:
                if self.solve_method == "solve":
                    solve = torch.linalg.solve(kxx, full_rhs)
                elif self.solve_method == "cholesky":
                    L = torch.linalg.cholesky(kxx)
                    solve = torch.cholesky_solve(full_rhs, L)
                elif self.solve_method == "cg":
                    # Source: https://github.com/AndPotap/halfpres_gps/blob/main/mlls/mixedpresmll.py
                    solve = linear_cg(
                        forwards_matmul,
                        full_rhs,
                        max_iter=self.max_cg_iter,
                        tolerance=self.cg_tol,
                        initial_guess=x0,
                        preconditioner=precond,
                    )
                else:
                    raise ValueError(f"Unknown method: {self.solve_method}")
            except RuntimeError as e:
                print("Fallback to pseudoinverse: ", str(e))
                solve = torch.linalg.pinv(kxx.evaluate()) @ full_rhs
                use_pinv = True

        # Apply torch.nan_to_num to handle NaNs from percision limits 
        solve = torch.nan_to_num(solve)
        return (solve, use_pinv) if return_pinv else solve

    # -----------------------------------------------------
    # Marginal Log Likelihood
    # -----------------------------------------------------

    def mll(self, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the marginal log likelihood of a soft GP:
            
            log p(y) = log N(y | mu_x, Q_xx)

            where
                mu_X: mean of soft GP
                Q_XX = W_xz K_zz W_zx

        Args:
            X (torch.Tensor): B x D tensor of inputs where each row is a point.
            y (torch.Tensor): B tensor of targets.

        Returns:
            torch.Tensor:  log p(y)
        """        
        # Construct covariance matrix components
        K_zz = self._mk_cov(self.inducing_points)
        W_xz = self._interp(X)
        
        if self.mll_approx == "exact":
            # [Note]: Compute MLL with a multivariate normal. Unstable for float.
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = (W_xz L) (L^T W_xz) + noise I  where K_zz = L L^T
            L = psd_safe_cholesky(K_zz)
            LK = (W_xz @ L).to(device=self.device)
            cov_diag = self.noise * torch.ones(len(X), dtype=self.dtype, device=self.device)

            # 3. N(mu, Q_xx)
            normal_dist = torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(mean, LK, cov_diag, validate_args=None)
            
            # 4. log N(y | mu, Q_xx)
            return normal_dist.log_prob(y)
        elif self.mll_approx == "hutchinson":
            # [Note]: Compute MLL with Hutchinson's trace estimator
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = W_xz K_zz K_zx + noise I
            cov_mat = W_xz @ K_zz @ W_xz.T 
            cov_mat += torch.eye(cov_mat.shape[1], dtype=self.dtype, device=self.device) * self.noise

            # 3. log N(y | mu, Q_xx) \appox 
            hutchinson_mll = HutchinsonPseudoLoss(self, num_trace_samples=10)
            return hutchinson_mll(mean, cov_mat, y)
        else:
            raise ValueError(f"Unknown MLL approximation method: {self.mll_approx}")
        
    # -----------------------------------------------------
    # Fit
    # -----------------------------------------------------

    def _direct_solve_fit(self, M, N, X, y, K_zz):
        # Construct A and b for linear solve
        #   A = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)
        #   b = (hat{K}_zx @ noise^{-1}) y
        if X.shape[0] * X.shape[1] <= 32768:
            # Case: "small" X
            # Form estimate \hat{K}_xz ~= W_xz K_zz
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
            hat_K_zx = hat_K_xz.T
            
            # Form A and b
            Lambda_inv_diag = (1 / self.noise) * torch.ones(N, dtype=self.dtype).to(self.device)
            A = K_zz + hat_K_zx @ (Lambda_inv_diag.unsqueeze(1) * hat_K_xz)
            b = hat_K_zx @ (Lambda_inv_diag * y)
        else:
            # Case: "large" X
            with torch.no_grad():
                # Initialize outputs
                A = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                b = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Initialize temporary values
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_inv = (1 / self.noise) * torch.eye(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp1 = torch.zeros(fit_chunk_size, M, dtype=self.dtype, device=self.device)
                tmp2 = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                tmp3 = torch.zeros(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp4 = torch.zeros(M, dtype=self.dtype, device=self.device)
                tmp5 = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Compute batches
                for i in range(batches):
                    # Update A: A += W_zx @ Lambda_inv @ W_xz
                    X_batch = X[i*fit_chunk_size:(i+1)*fit_chunk_size]
                    W_xz = self._interp(X_batch)
                    W_zx = W_xz.T
                    torch.matmul(Lambda_inv, W_xz, out=tmp1)
                    torch.matmul(W_zx, tmp1, out=tmp2)
                    A.add_(tmp2)
                    
                    # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y[i*batch_size:(i+1)*batch_size])
                    torch.matmul(Lambda_inv, y[i*fit_chunk_size:(i+1)*fit_chunk_size], out=tmp3)
                    torch.matmul(W_zx, tmp3, out=tmp4)
                    torch.matmul(K_zz, tmp4, out=tmp5)
                    b.add_(tmp5)
                
                # Compute last batch
                if N - (i+1)*fit_chunk_size > 0:
                    Lambda_inv = (1 / self.noise) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[(i+1)*fit_chunk_size:]
                    W_xz = self._interp(X_batch)
                    A += W_xz.T @ Lambda_inv @ W_xz
                    b += K_zz @ W_xz.T @ Lambda_inv @ y[(i+1)*fit_chunk_size:]

                # Aggregate result
                A = K_zz + K_zz @ A @ K_zz

        # Safe solve A \alpha = b
        A = DenseLinearOperator(A)
        self.alpha, use_pinv = self._solve_system(
            A,
            b.unsqueeze(1),
            x0=torch.zeros_like(b),
            forwards_matmul=A.matmul,
            precond=None,
            return_pinv=True
        )

        # Store for fast prediction
        self.K_zz_alpha = K_zz @ self.alpha
        return use_pinv

    def _qr_solve_fit(self, M, N, X, y, K_zz):
        if X.shape[0] * X.shape[1] <= 32768:
            # Compute: W_xz K_zz
            print("USING QR SMALL")
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
        else:
            # Compute: W_xz K_zz in a batched fashion
            print("USING QR BATCH")
            with torch.no_grad():
                # Compute batches
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(fit_chunk_size, dtype=self.dtype, device=self.device)
                hat_K_xz = torch.zeros((N, M), dtype=self.dtype, device=self.device)
                for i in range(batches):
                    start = i*fit_chunk_size
                    end = (i+1)*fit_chunk_size
                    X_batch = X[start:end,:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:end,:])
                
                start = (i+1)*fit_chunk_size
                if N - start > 0:
                    Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[start:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:,:])
        
        # B^T = [(Lambda^{-1/2} \hat{K}_xz) U_zz ]
        U_zz = psd_safe_cholesky(K_zz, upper=True, max_tries=10)
        Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(N, dtype=self.dtype).to(self.device)
        B = torch.cat([Lambda_half_inv_diag.unsqueeze(1) * hat_K_xz, U_zz], dim=0)

        # B = QR
        Q, R = torch.linalg.qr(B)

        # \alpha = R^{-1} @ Q^T @ Lambda^{-1/2}b
        b = Lambda_half_inv_diag * y
        self.alpha = torch.linalg.solve_triangular(R, (Q.T[:, 0:N] @ b).unsqueeze(1), upper=True).squeeze(1) # (should use triangular solve)
        # self.alpha = ((torch.linalg.inv(R) @ Q.T)[:, :N] @ b)
        
        # Store for fast inference
        self.K_zz_alpha = K_zz @ self.alpha

        return False

    def fit(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Fits a SoftGP to dataset (X, y). That is, solve:

                (hat{K}_zx @ noise^{-1}) y = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz) \alpha
        
            for \alpha where
            1. inducing points z are fixed,
            2. hat{K}_zx = K_zz W_zx, and
            3. hat{K}_xz = hat{K}_zx^T.

        Args:
            X (torch.Tensor): N x D tensor of inputs
            y (torch.Tensor): N tensor of outputs

        Returns:
            bool: Returns true if the pseudoinverse was used, false otherwise.
        """        
        # Prepare inputs
        N = len(X)
        M = len(self.inducing_points)
        X = X.to(self.device, dtype=self.dtype)
        y = y.to(self.device, dtype=self.dtype)

        # Form K_zz
        K_zz = self._mk_cov(self.inducing_points)

        if self.use_qr:
            return self._qr_solve_fit(M, N, X, y, K_zz)
        else:
            return self._direct_solve_fit(M, N, X, y, K_zz)

    # -----------------------------------------------------
    # Predict
    # -----------------------------------------------------

    def pred(self, x_star: torch.Tensor) -> torch.Tensor:
        """Give the posterior predictive:
        
            p(y_star | x_star, X, y) 
                = W_star_z (K_zz \alpha)
                = W_star_z K_zz (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)^{-1} (hat{K}_zx @ noise^{-1}) y

        Args:
            x_star (torch.Tensor): B x D tensor of points to evaluate at.

        Returns:
            torch.Tensor: B tensor of p(y_star | x_star, X, y).
        """        
        W_star_z = self._interp(x_star)
        return torch.matmul(W_star_z, self.K_zz_alpha).squeeze(-1)

### SoftGP test implementation 

---

In [3]:
class SoftGP_test(torch.nn.Module):
    def __init__(
        self,
        kernel: Callable,
        inducing_points: torch.Tensor,
        noise=1e-3,
        learn_noise=False,
        use_scale=False,
        device="cpu",
        dtype=torch.float32,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=False,
    ) -> None:
        # Argument checking 
        methods = ["solve", "cholesky", "cg"]
        if not solver in methods:
            raise ValueError(f"Method {solver} should be in {methods} ...")
        
        # Check devices
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices += ["cuda"]
            for i in range(torch.cuda.device_count()):
                devices += [f"cuda:{i}"]
        if not device in devices:
            raise ValueError(f"Device {device} should be in {devices} ...")

        # Create torch module
        super(SoftGP_test, self).__init__()

        # Misc
        self.device = device
        self.dtype = dtype
        
        # Mll approximation settings
        self.solve_method = solver
        self.mll_approx = mll_approx

        # Fit settings
        self.use_qr = use_qr
        self.fit_chunk_size = fit_chunk_size

        # Noise
        self.noise_constraint = gpytorch.constraints.Positive()
        noise = torch.tensor([noise], dtype=self.dtype, device=self.device)
        noise = self.noise_constraint.inverse_transform(noise)
        if learn_noise:
            self.register_parameter("raw_noise", torch.nn.Parameter(noise))
        else:
            self.raw_noise = noise

        # Kernel
        self.use_scale = use_scale
        if use_scale:
            self.kernel = ScaleKernel(kernel).to(self.device)
        else:
            self.kernel = kernel.to(self.device)

        # Inducing points
        self.register_parameter("inducing_points", torch.nn.Parameter(inducing_points))

        # Interpolation
        #self.T = torch.nn.Parameter(torch.tensor(.005))

        def softmax_interp( X: torch.Tensor, sigma_values: torch.Tensor) -> torch.Tensor:
            X = X / .001 # Use the learnable T
            distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
            softmax_distances = torch.softmax(-distances, dim=-1)
            softmax_distances = torch.where(softmax_distances < 1e-16, torch.tensor(0.0, dtype=softmax_distances.dtype), softmax_distances)
            
            return softmax_distances

        self.interp = softmax_interp
        
        # Fit artifacts
        self.alpha = None
        self.K_zz_alpha = None

        # CG solver params
        self.max_cg_iter = max_cg_iter
        self.cg_tol = cg_tolerance
        self.x0 = None
        
    # -----------------------------------------------------
    # Soft GP Helpers
    # -----------------------------------------------------
    
    @property
    def noise(self):
        return self.noise_constraint.transform(self.raw_noise)

    def get_lengthscale(self) -> float:
        if self.use_scale:
            return self.kernel.base_kernel.lengthscale.cpu()
        else:
            return self.kernel.lengthscale.cpu()
        
    def get_outputscale(self) -> float:
        if self.use_scale:
            return self.kernel.outputscale.cpu()
        else:
            return 1.

    def _mk_cov(self, z: torch.Tensor) -> torch.Tensor:
        return self.kernel(z, z).evaluate()
    
    # def _interp(self, x: torch.Tensor) -> torch.Tensor:
    #     # Expand input x and perform interpolation to get the dense matrix W_xz_dense
    #     x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
    #     W_xz_dense = self.interp(x_expanded, self.inducing_points)
        
    #     non_zero_indices = W_xz_dense.nonzero(as_tuple=False).t()  
    #     non_zero_values = W_xz_dense[non_zero_indices[0], non_zero_indices[1]]  
        
    #     W_xz_sparse = torch.sparse_coo_tensor(
    #         non_zero_indices, non_zero_values, W_xz_dense.size(), device=W_xz_dense.device
    #     )
        
    #     W_xz_sparse = W_xz_sparse.coalesce()
        
    #     return W_xz_sparse

    def _interp(self, x: torch.Tensor) -> torch.Tensor:
        x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
        W_xz = self.interp(x_expanded, self.inducing_points)
        return W_xz

    # -----------------------------------------------------
    # Linear solver
    # -----------------------------------------------------

    def _solve_system(
        self,
        kxx: linear_operator.operators.LinearOperator,
        full_rhs: torch.Tensor,
        x0: torch.Tensor = None,
        forwards_matmul: Callable = None,
        precond: torch.Tensor = None,
        return_pinv: bool = False,
    ) -> torch.Tensor:
        use_pinv = False
        with torch.no_grad():
            try:
                if self.solve_method == "solve":
                    solve = torch.linalg.solve(kxx, full_rhs)
                elif self.solve_method == "cholesky":
                    L = torch.linalg.cholesky(kxx)
                    solve = torch.cholesky_solve(full_rhs, L)
                elif self.solve_method == "cg":
                    # Source: https://github.com/AndPotap/halfpres_gps/blob/main/mlls/mixedpresmll.py
                    solve = linear_cg(
                        forwards_matmul,
                        full_rhs,
                        max_iter=self.max_cg_iter,
                        tolerance=self.cg_tol,
                        initial_guess=x0,
                        preconditioner=precond,
                    )
                else:
                    raise ValueError(f"Unknown method: {self.solve_method}")
            except RuntimeError as e:
                print("Fallback to pseudoinverse: ", str(e))
                solve = torch.linalg.pinv(kxx.evaluate()) @ full_rhs
                use_pinv = True

        # Apply torch.nan_to_num to handle NaNs from percision limits 
        solve = torch.nan_to_num(solve)
        return (solve, use_pinv) if return_pinv else solve

    # -----------------------------------------------------
    # Marginal Log Likelihood
    # -----------------------------------------------------

    def mll(self, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the marginal log likelihood of a soft GP:
            
            log p(y) = log N(y | mu_x, Q_xx)

            where
                mu_X: mean of soft GP
                Q_XX = W_xz K_zz W_zx

        Args:
            X (torch.Tensor): B x D tensor of inputs where each row is a point.
            y (torch.Tensor): B tensor of targets.

        Returns:
            torch.Tensor:  log p(y)
        """        
        # Construct covariance matrix components
        K_zz = self._mk_cov(self.inducing_points)
        W_xz = self._interp(X)
        
        if self.mll_approx == "exact":
            # [Note]: Compute MLL with a multivariate normal. Unstable for float.
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = (W_xz L) (L^T W_xz) + noise I  where K_zz = L L^T
            L = psd_safe_cholesky(K_zz)
            LK = (W_xz @ L).to(device=self.device)
            cov_diag = self.noise * torch.ones(len(X), dtype=self.dtype, device=self.device)

            # 3. N(mu, Q_xx)
            normal_dist = torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(mean, LK, cov_diag, validate_args=None)
            
            # 4. log N(y | mu, Q_xx)
            return normal_dist.log_prob(y)
        elif self.mll_approx == "hutchinson":
            # [Note]: Compute MLL with Hutchinson's trace estimator
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = W_xz K_zz K_zx + noise I
            cov_mat = W_xz @ K_zz @ W_xz.T 
            cov_mat += torch.eye(cov_mat.shape[1], dtype=self.dtype, device=self.device) * self.noise

            # 3. log N(y | mu, Q_xx) \appox 
            hutchinson_mll = HutchinsonPseudoLoss(self, num_trace_samples=10)
            return hutchinson_mll(mean, cov_mat, y)
        else:
            raise ValueError(f"Unknown MLL approximation method: {self.mll_approx}")
        
    # -----------------------------------------------------
    # Fit
    # -----------------------------------------------------
    def _direct_solve_fit(self, M, N, X, y, K_zz):
        # Construct A and b for linear solve
        #   A = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)
        #   b = (hat{K}_zx @ noise^{-1}) y
        if X.shape[0] * X.shape[1] <= 32768:
            # Case: "small" X
            # Form estimate \hat{K}_xz ~= W_xz K_zz
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
            hat_K_zx = hat_K_xz.T
            
            # Form A and b
            Lambda_inv_diag = (1 / self.noise) * torch.ones(N, dtype=self.dtype).to(self.device)
            A = K_zz + hat_K_zx @ (Lambda_inv_diag.unsqueeze(1) * hat_K_xz)
            b = hat_K_zx @ (Lambda_inv_diag * y)
        else:
            # Case: "large" X
            with torch.no_grad():
                # Initialize outputs
                A = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                b = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Initialize temporary values
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_inv = (1 / self.noise) * torch.eye(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp1 = torch.zeros(fit_chunk_size, M, dtype=self.dtype, device=self.device)
                tmp2 = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                tmp3 = torch.zeros(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp4 = torch.zeros(M, dtype=self.dtype, device=self.device)
                tmp5 = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Compute batches
                for i in range(batches):
                    # Update A: A += W_zx @ Lambda_inv @ W_xz
                    X_batch = X[i*fit_chunk_size:(i+1)*fit_chunk_size]
                    W_xz = self._interp(X_batch)
                    W_zx = W_xz.T
                    torch.matmul(Lambda_inv, W_xz, out=tmp1)
                    torch.matmul(W_zx, tmp1, out=tmp2)
                    A.add_(tmp2)
                    
                    # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y[i*batch_size:(i+1)*batch_size])
                    torch.matmul(Lambda_inv, y[i*fit_chunk_size:(i+1)*fit_chunk_size], out=tmp3)
                    torch.matmul(W_zx, tmp3, out=tmp4)
                    torch.matmul(K_zz, tmp4, out=tmp5)
                    b.add_(tmp5)
                
                # Compute last batch
                if N - (i+1)*fit_chunk_size > 0:
                    Lambda_inv = (1 / self.noise) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[(i+1)*fit_chunk_size:]
                    W_xz = self._interp(X_batch)
                    A += W_xz.T @ Lambda_inv @ W_xz
                    b += K_zz @ W_xz.T @ Lambda_inv @ y[(i+1)*fit_chunk_size:]

                # Aggregate result
                A = K_zz + K_zz @ A @ K_zz

        # Safe solve A \alpha = b
        A = DenseLinearOperator(A)
        self.alpha, use_pinv = self._solve_system(
            A,
            b.unsqueeze(1),
            x0=torch.zeros_like(b),
            forwards_matmul=A.matmul,
            precond=None,
            return_pinv=True
        )

        # Store for fast prediction
        self.K_zz_alpha = K_zz @ self.alpha
        return use_pinv
    # def _direct_solve_fit(self, M, N, X, y, K_zz):
    #     # Construct A and b for linear solve
    #     #   A = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)
    #     #   b = (hat{K}_zx @ noise^{-1}) y
    #     if X.shape[0] * X.shape[1] <= 32768:
    #         # Case: "small" X
    #         # Form estimate \hat{K}_xz ~= W_xz K_zz
    #         W_xz = self._interp(X)
    #         hat_K_xz = W_xz @ K_zz
    #         hat_K_zx = hat_K_xz.T
            
    #         # Form A and b
    #         Lambda_inv_diag = (1 / self.noise) * torch.ones(N, dtype=self.dtype).to(self.device)
    #         A = K_zz + hat_K_zx @ (Lambda_inv_diag.unsqueeze(1) * hat_K_xz)
    #         b = hat_K_zx @ (Lambda_inv_diag * y)
    #     else:
    #         # Case: "large" X
    #         with torch.no_grad():
    #             # Initialize outputs
    #             A = torch.zeros(M, M, dtype=self.dtype, device=self.device)
    #             b = torch.zeros(M, dtype=self.dtype, device=self.device)
                
    #             # Initialize temporary values
    #             fit_chunk_size = self.fit_chunk_size
    #             batches = int(np.floor(N / fit_chunk_size))
    #             Lambda_inv = (1 / self.noise) * torch.eye(fit_chunk_size, dtype=self.dtype, device=self.device)
    #             tmp1 = torch.zeros(fit_chunk_size, M, dtype=self.dtype, device=self.device)
    #             tmp2 = torch.zeros(M, M, dtype=self.dtype, device=self.device)
    #             tmp3 = torch.zeros(fit_chunk_size, dtype=self.dtype, device=self.device)
    #             tmp4 = torch.zeros(M, dtype=self.dtype, device=self.device)
    #             tmp5 = torch.zeros(M, dtype=self.dtype, device=self.device)
                
    #         for i in range(batches):
    #             # Update A: A += W_zx @ Lambda_inv @ W_xz
    #             X_batch = X[i*fit_chunk_size:(i+1)*fit_chunk_size]
                
    #             # Compute W_xz as sparse tensor
    #             W_xz = self._interp(X_batch)  # This returns a sparse tensor
    #             W_zx = W_xz.T# Transpose to get W_zx
    #             sparsity = 1.0 - W_xz._nnz() / float(W_xz.numel())
    #             print(sparsity)
    #             # Perform sparse-dense matrix multiplication
    #             tmp1 = torch.sparse.mm(Lambda_inv,W_xz)  # Sparse x Dense
    #             tmp2 = torch.sparse.mm(W_zx, tmp1)  # Sparse x Dense
    #             A.add_(tmp2)  # Update A
                
    #             # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y_batch)
    #             Y_batch = y[i*fit_chunk_size:(i+1)*fit_chunk_size]
    #             tmp3 = torch.matmul(Lambda_inv, Y_batch)  # Dense x Dense
    #             tmp4 = torch.sparse.mm(W_zx, tmp3.unsqueeze(-1) )  # Sparse x Dense
    #             tmp5 = torch.matmul(K_zz, tmp4)  # Dense x Dense
    #             b.add_(tmp5.squeeze())
                
    #             # Compute last batch
    #         if N - (i + 1) * fit_chunk_size > 0:
    #             last_batch_size = N - (i + 1) * fit_chunk_size
                
    #             # Create the identity matrix for Lambda_inv for the last batch
    #             Lambda_inv = (1 / self.noise) * torch.eye(last_batch_size, dtype=self.dtype, device=self.device)
                
    #             # Get the last batch of X and y
    #             X_batch = X[(i + 1) * fit_chunk_size:]
    #             Y_batch = y[(i + 1) * fit_chunk_size:]
                
    #             # Compute W_xz as a sparse tensor for the last batch
    #             W_xz = self._interp(X_batch)
                
    #             # Transpose the sparse tensor W_xz to get W_zx
    #             W_zx = W_xz.transpose(0, 1)
    #             # Update A: A += W_zx @ Lambda_inv @ W_xz
    #             tmp1 = torch.sparse.mm( Lambda_inv,W_xz)  # Sparse x Dense
    #             tmp2 = torch.sparse.mm(W_zx, tmp1)  # Sparse x Dense
    #             A.add_(tmp2)  # Update A (dense)
                
    #             # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y_batch)
    #             tmp3 = torch.matmul(Lambda_inv, Y_batch)  # Dense x Dense
    #             tmp4 = torch.sparse.mm(W_zx, tmp3.unsqueeze(-1))  # Sparse x Dense
    #             tmp5 = torch.matmul(K_zz, tmp4)  # Dense x Dense
    #             b.add_(tmp5.squeeze())  # Update b



    #         # Aggregate result
    #         A = K_zz + K_zz @ A @ K_zz

    #     # Safe solve A \alpha = b
    #     A = DenseLinearOperator(A)
    #     self.alpha, use_pinv = self._solve_system(
    #         A,
    #         b.unsqueeze(1),
    #         x0=torch.zeros_like(b),
    #         forwards_matmul=A.matmul,
    #         precond=None,
    #         return_pinv=True
    #     )

    #     # Store for fast prediction
    #     self.K_zz_alpha = K_zz @ self.alpha
    #     return use_pinv

    def _qr_solve_fit(self, M, N, X, y, K_zz):
        if X.shape[0] * X.shape[1] <= 32768:
            # Compute: W_xz K_zz
            print("USING QR SMALL")
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
        else:
            # Compute: W_xz K_zz in a batched fashion
            print("USING QR BATCH")
            with torch.no_grad():
                # Compute batches
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(fit_chunk_size, dtype=self.dtype, device=self.device)
                hat_K_xz = torch.zeros((N, M), dtype=self.dtype, device=self.device)
                for i in range(batches):
                    start = i*fit_chunk_size
                    end = (i+1)*fit_chunk_size
                    X_batch = X[start:end,:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:end,:])
                
                start = (i+1)*fit_chunk_size
                if N - start > 0:
                    Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[start:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:,:])
        
        # B^T = [(Lambda^{-1/2} \hat{K}_xz) U_zz ]
        U_zz = psd_safe_cholesky(K_zz, upper=True, max_tries=10)
        Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(N, dtype=self.dtype).to(self.device)
        B = torch.cat([Lambda_half_inv_diag.unsqueeze(1) * hat_K_xz, U_zz], dim=0)

        # B = QR
        Q, R = torch.linalg.qr(B)

        # \alpha = R^{-1} @ Q^T @ Lambda^{-1/2}b
        b = Lambda_half_inv_diag * y
        self.alpha = torch.linalg.solve_triangular(R, (Q.T[:, 0:N] @ b).unsqueeze(1), upper=True).squeeze(1) # (should use triangular solve)
        # self.alpha = ((torch.linalg.inv(R) @ Q.T)[:, :N] @ b)
        
        # Store for fast inference
        self.K_zz_alpha = K_zz @ self.alpha

        return False

    def fit(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Fits a SoftGP to dataset (X, y). That is, solve:

                (hat{K}_zx @ noise^{-1}) y = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz) \alpha
        
            for \alpha where
            1. inducing points z are fixed,
            2. hat{K}_zx = K_zz W_zx, and
            3. hat{K}_xz = hat{K}_zx^T.

        Args:
            X (torch.Tensor): N x D tensor of inputs
            y (torch.Tensor): N tensor of outputs

        Returns:
            bool: Returns true if the pseudoinverse was used, false otherwise.
        """        
        # Prepare inputs
        N = len(X)
        M = len(self.inducing_points)
        X = X.to(self.device, dtype=self.dtype)
        y = y.to(self.device, dtype=self.dtype)

        # Form K_zz
        K_zz = self._mk_cov(self.inducing_points)

        if self.use_qr:
            return self._qr_solve_fit(M, N, X, y, K_zz)
        else:
            return self._direct_solve_fit(M, N, X, y, K_zz)

    # -----------------------------------------------------
    # Predict
    # -----------------------------------------------------

    def pred(self, x_star: torch.Tensor) -> torch.Tensor:
        """Give the posterior predictive:
        
            p(y_star | x_star, X, y) 
                = W_star_z (K_zz \alpha)
                = W_star_z K_zz (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)^{-1} (hat{K}_zx @ noise^{-1}) y

        Args:
            x_star (torch.Tensor): B x D tensor of points to evaluate at.

        Returns:
            torch.Tensor: B tensor of p(y_star | x_star, X, y).
        """        
        W_star_z = self._interp(x_star)
        return torch.matmul(W_star_z, self.K_zz_alpha).squeeze(-1)

### Test version profiling stats
---

In [4]:
from line_profiler import LineProfiler

def eval_gp(model, test_dataset: Dataset, device="cuda:0") -> float:
    preds = []
    neg_mlls = []
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,num_workers=1)
    for x_batch, y_batch in tqdm(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        preds += [(model.pred(x_batch) - y_batch).detach().cpu()**2]
        neg_mlls += [-model.mll(x_batch, y_batch).detach().cpu()]
    rmse = torch.sqrt(torch.sum(torch.cat(preds)) / len(test_dataset)).item()
    neg_mll = torch.sum(torch.tensor(neg_mlls))
            
    print("RMSE:", rmse, "NEG_MLL", neg_mll.item(), "NOISE", model.noise.cpu().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale())
    
    return {
        "rmse": rmse,
        "nll": neg_mll,
    }   
    
def profileGP(GP,train_dataset, test_dataset):
    num_inducing = 512
    dtype = torch.float32
    batch_size = 1024
    epochs = 2
    lr = 0.01 
    learn_noise = False
    device = "cpu"
    
    # Initialize inducing points with kmeans
    train_features, train_labels = flatten_dataset(train_dataset)
    kmeans = KMeans(n_clusters=num_inducing)
    kmeans.fit(train_features)
    centers = kmeans.cluster_centers_
    inducing_points = torch.tensor(centers).to(dtype=dtype, device=device)
    
    # Setup model
    kernel = RBFKernel().to(device=device, dtype=dtype)

    model = GP(kernel,
        inducing_points,
        noise=1e-3,
        learn_noise=learn_noise,
        use_scale=False,
        dtype=dtype,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=True,
    )

    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    def train_model():
        for epoch in pbar:
            t1 = time.perf_counter()
            
            neg_mlls = []
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
                y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)

                optimizer.zero_grad()
                with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                    neg_mll = -model.mll(x_batch, y_batch)
                neg_mlls += [-neg_mll.item()]
                neg_mll.backward()
                optimizer.step()

                pbar.set_description(f"Epoch {epoch+1}/{epochs}")
                pbar.set_postfix(MLL=f"{-neg_mll.item()}")
            t2 = time.perf_counter()

            use_pinv = model.fit(train_features, train_labels)
            t3 = time.perf_counter()
            #results = eval_gp(model, test_dataset, device=device)
            #print(results)
    profiler = LineProfiler()
    profiler.add_function(train_model)
    profiler.add_function(model.fit)  
    profiler.add_function(model.interp)  
    profiler.add_function(model._qr_solve_fit) 
    profiler.add_function(model.mll)  


    profiler.enable_by_count()
    train_model() 
    profiler.disable_by_count()
    profiler.print_stats()
    return model


### Notes
- Total time: 76.4511 s QR no changes 

- Total time: 46.756 s sparse 

- baseline + boltz mask
- Total time: 17.1253 s, 15.5868 s, 14.7804 s


baseline no mask 
Total time  22.1084 s  

In [5]:
from data.get_uci import ElevatorsDataset
dataset = ElevatorsDataset("../data/uci_datasets/uci_datasets/pol/data.csv")
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset,
    train_frac=9/10,
    val_frac=0/10
)

softgp = profileGP(SoftGP_test,train_dataset,test_dataset)

SIZE (15000, 27)


100%|██████████| 2/2 [00:01<00:00,  1.89it/s]
Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md

Epoch 1/2:   0%|          | 0/2 [00:16<?, ?it/s, MLL=49.041358947753906]

USING QR BATCH


Epoch 2/2:  50%|█████     | 1/2 [00:39<00:21, 21.99s/it, MLL=74.28733825683594] 

USING QR BATCH


Epoch 2/2: 100%|██████████| 2/2 [00:44<00:00, 22.34s/it, MLL=74.28733825683594]

Timer unit: 1e-07 s

Total time: 44.6828 s
File: C:\Users\chris\AppData\Local\Temp\ipykernel_29896\3236570565.py
Function: train_model at line 63

Line #      Hits         Time  Per Hit   % Time  Line Contents
    63                                               def train_model():
    64         3      39386.0  13128.7      0.0          for epoch in pbar:
    65         2         79.0     39.5      0.0              t1 = time.perf_counter()
    66                                                       
    67         2         22.0     11.0      0.0              neg_mlls = []
    68        30   85255323.0    3e+06     19.1              for x_batch, y_batch in train_loader:
    69        28      36040.0   1287.1      0.0                  x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
    70        28       2243.0     80.1      0.0                  y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)
    71                                           
    72  




### Comparison: Test vs Baseline 
---


In [6]:
def plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes, epochs):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    for i, GP_class in enumerate(GP_classes):
        epochs_range = range(1, epochs + 1)
        axes[0].plot(epochs_range, all_mean_rmse[i], label=f'{GP_class.__name__} RMSE')
        axes[0].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             [m + s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             alpha=0.3)
    axes[0].set_title('RMSE per Epoch')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('RMSE')
    axes[0].legend()

    for i, GP_class in enumerate(GP_classes):
        epochs_range = range(1, epochs + 1)
        axes[1].plot(epochs_range, all_mean_runtimes[i], label=f'{GP_class.__name__} Runtime')
        axes[1].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             [m + s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             alpha=0.3)
    axes[1].set_title('Training Time per Epoch')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Time (s)')
    axes[1].legend()

    plt.tight_layout()
    plt.show()

def train_gp(GP_class, inducing_points, test_dataset, train_features, train_labels, epochs, device, dtype):
    kernel = RBFKernel().to(device=device, dtype=dtype)
    learn_noise = False
    lr = .01
    batch_size = 1024

    model = GP_class(kernel,
                     inducing_points,
                     noise=1e-3,
                     learn_noise=learn_noise,
                     use_scale=False,
                     dtype=dtype,
                     solver="solve",
                     max_cg_iter=50,
                     cg_tolerance=0.5,
                     mll_approx="hutchinson",
                     fit_chunk_size=1024,
                     use_qr=True)

    epoch_runtimes = []
    epoch_rmse = []

    # pbar = tqdm(range(epochs), desc="Optimizing MLL")
    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def train_model():
        #==================Train============================
        for _ in range(epochs):
            epoch_start_time = time.time()
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
                y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)
                optimizer.zero_grad()
                with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                    neg_mll = -model.mll(x_batch, y_batch)
                neg_mll.backward()
                optimizer.step()
                # pbar.set_description(f"Epoch {epoch + 1}/{epochs}")
                # pbar.set_postfix(MLL=f"{-neg_mll.item()}")

            model.fit(train_features, train_labels)
            epoch_end_time = time.time()
            epoch_runtimes.append(epoch_end_time - epoch_start_time)

            #==================Evaluate============================
            eval_results = eval_gp(model, test_dataset, device=device)
            epoch_rmse.append(eval_results['rmse'])
    
    train_model()
    return epoch_rmse, epoch_runtimes

def benchmark(GP_classes, train_dataset, test_dataset, epochs=2, seed=42, N=3):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    num_inducing = 512
    dtype = torch.float32
    device = "cpu"
    
    all_mean_rmse = []
    all_mean_runtimes = []
    all_std_rmse = []
    all_std_runtimes = []

    #==================Inducing Points============================
    train_features, train_labels = flatten_dataset(train_dataset)
    kmeans = KMeans(n_clusters=num_inducing)
    kmeans.fit(train_features)
    centers = kmeans.cluster_centers_
    inducing_points = torch.tensor(centers).to(dtype=dtype, device=device)

    for GP_class in GP_classes:
        print(f"Training {GP_class.__name__}...")
        
        # Run the experiment N times and store results for each run
        all_runs_rmse = []
        all_runs_runtimes = []
        
        for run in range(N):
            epoch_rmse, epoch_runtimes = train_gp(
                GP_class,
                inducing_points.clone(),
                test_dataset,
                train_features,
                train_labels,
                epochs,
                device,
                dtype
            )
            all_runs_rmse.append(epoch_rmse)
            all_runs_runtimes.append(epoch_runtimes)

        # Calculate mean and std deviation across the N runs
        mean_rmse = np.mean(all_runs_rmse, axis=0)
        std_rmse = np.std(all_runs_rmse, axis=0)
        mean_runtimes = np.mean(all_runs_runtimes, axis=0)
        std_runtimes = np.std(all_runs_runtimes, axis=0)

        all_mean_rmse.append(mean_rmse)
        all_mean_runtimes.append(mean_runtimes)
        all_std_rmse.append(std_rmse)
        all_std_runtimes.append(std_runtimes)

    return all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes

#==================Dataset============================
from data.get_uci import ElevatorsDataset
dataset = ElevatorsDataset("../data/uci_datasets/uci_datasets/pol/data.csv")
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset,
    train_frac=9/10,
    val_frac=0/10  
)

#==================Benchmark============================
GP_classes = [SoftGP_baseline, SoftGP_test] 
epochs = 10
N = 1  # Number of runs
all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes = benchmark(GP_classes, train_dataset, test_dataset, epochs=epochs, seed=42, N=N)
plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes, epochs)


SIZE (15000, 27)


100%|██████████| 2/2 [00:00<00:00,  2.06it/s]


Training SoftGP_baseline...
USING QR BATCH


100%|██████████| 47/47 [00:05<00:00,  8.09it/s]


RMSE: 0.4898330271244049 NEG_MLL -1650.2716064453125 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.6957]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH


100%|██████████| 47/47 [00:06<00:00,  7.02it/s]


RMSE: 0.45824193954467773 NEG_MLL -1488.86328125 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.6755]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH


100%|██████████| 47/47 [00:05<00:00,  9.15it/s]


RMSE: 0.43424949049949646 NEG_MLL -1380.808349609375 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.6842]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH


  0%|          | 0/47 [00:00<?, ?it/s]

In [None]:
SIZE (15000, 27)
100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
Training SoftGP_baseline...
USING QR BATCH
100%|██████████| 47/47 [00:06<00:00,  7.74it/s]
RMSE: 0.5389598608016968 NEG_MLL 45598.875 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.6435]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH
100%|██████████| 47/47 [00:06<00:00,  7.42it/s]
RMSE: 0.5087882876396179 NEG_MLL 41339.1015625 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.6068]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH
100%|██████████| 47/47 [00:06<00:00,  7.20it/s]
RMSE: 0.48444491624832153 NEG_MLL 38776.59765625 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.5828]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
Training SoftGP_test...
USING QR BATCH
100%|██████████| 47/47 [00:04<00:00,  9.42it/s]
RMSE: 0.49580860137939453 NEG_MLL 29874.55859375 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.7512]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH
100%|██████████| 47/47 [00:04<00:00, 10.42it/s]
RMSE: 0.48252424597740173 NEG_MLL 30173.1953125 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.8222]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0
USING QR BATCH
100%|██████████| 47/47 [00:04<00:00, 10.80it/s]
RMSE: 0.4675379991531372 NEG_MLL 29973.189453125 NOISE 0.0010000000474974513 LENGTHSCALE tensor([[0.9045]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE 1.0


SyntaxError: invalid character '█' (U+2588) (2939239835.py, line 2)