In [1]:
import matplotlib.pyplot as plt 
import time 
from tqdm import tqdm 
import random 
import torch 
from torch.utils.data import random_split, DataLoader, Dataset
import gpytorch
from gpytorch.kernels import ScaleKernel, RBFKernel
from sklearn.cluster import KMeans
from linear_operator.settings import max_cholesky_size

import sys 
sys.path.append("../")
from gp.util import dynamic_instantiation, flatten_dict, unflatten_dict, flatten_dataset, split_dataset, filter_param, heatmap

# System/Library imports
from typing import *

# Common data science imports
import numpy as np
import torch

# Gpytorch and linear_operator
import gpytorch 
import gpytorch.constraints
from gpytorch.kernels import ScaleKernel
import linear_operator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky

# Our imports
from gp.soft_gp.mll import HutchinsonPseudoLoss
from linear_solver.cg import linear_cg

import pandas as pd
import wandb
from tqdm.notebook import tqdm
import pickle
import torch
from os.path import exists
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import math
from matplotlib.ticker import MaxNLocator
from seaborn import heatmap
from gpytorch.kernels import ScaleKernel, RBFKernel
from gp.soft_gp.soft_gp import SoftGP

from data.get_uci import all_datasets
from analysis.util import fetch, init_uci_dict, get_uci_info

In [2]:

filters = {
    "group": "benchmark7"
}
raw2= fetch("soft-gp-2", filters)

KeyboardInterrupt: 

In [None]:
uci_info = get_uci_info()
uci_dict2 = {}
for exp in raw2:
    model = exp.config["model.name"]
    dataset = exp.config["dataset.name"]
    dtype = exp.config["model.dtype"]
    seed = exp.config["training.seed"]
    train_frac = float(exp.config["dataset.train_frac"])
    uci_dict2[(dataset, seed, model)] = exp.history

print(uci_dict2.keys())
def load(dataset):
    kernels = []
    for epoch_exact in range(1, 51): 
        kernel = ScaleKernel(RBFKernel())
        l = uci_dict2[(dataset, 6535, "exact")]["lengthscale"][epoch_exact]
        s = uci_dict2[(dataset, 6535, "exact")]["outputscale"][epoch_exact]
        kernel.base_kernel.lengthscale = l
        kernel.outputscale = s
        print("l", l, "s", s)
        kernels += [kernel]

    return kernels


kernels = load("elevators")

dict_keys([('pol', 6535, 'soft-gp'), ('pol', 8830, 'soft-gp'), ('pol', 92357, 'soft-gp'), ('elevators', 6535, 'soft-gp'), ('elevators', 8830, 'soft-gp'), ('elevators', 92357, 'soft-gp'), ('bike', 6535, 'soft-gp'), ('bike', 8830, 'soft-gp'), ('bike', 92357, 'soft-gp'), ('kin40k', 6535, 'soft-gp'), ('kin40k', 8830, 'soft-gp'), ('kin40k', 92357, 'soft-gp'), ('protein', 6535, 'soft-gp'), ('protein', 8830, 'soft-gp'), ('protein', 92357, 'soft-gp'), ('keggdirected', 6535, 'soft-gp'), ('keggdirected', 8830, 'soft-gp'), ('keggdirected', 92357, 'soft-gp'), ('slice', 6535, 'soft-gp'), ('slice', 8830, 'soft-gp'), ('slice', 92357, 'soft-gp'), ('keggundirected', 6535, 'soft-gp'), ('keggundirected', 8830, 'soft-gp'), ('keggundirected', 92357, 'soft-gp'), ('3droad', 6535, 'soft-gp'), ('3droad', 8830, 'soft-gp'), ('3droad', 92357, 'soft-gp'), ('song', 6535, 'soft-gp'), ('song', 8830, 'soft-gp'), ('song', 92357, 'soft-gp'), ('buzz', 6535, 'soft-gp'), ('buzz', 8830, 'soft-gp'), ('buzz', 92357, 'soft-gp'

In [15]:
print(kernels)

[ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (raw_lengthscale_constraint): Positive()
  )
  (raw_outputscale_constraint): Positive()
), ScaleKernel(
  (base_kernel): RBFKernel(
    (r

### Soft GP testing profiling/tuning boilerplate
its recommend to collapse all functions with Ctrl/Cmnd +k +0 

### SoftGP baseline implementation 
---


In [2]:
class SoftGP_baseline(torch.nn.Module):
    def __init__(
        self,
        kernel: Callable,
        inducing_points: torch.Tensor,
        noise=1e-3,
        learn_noise=False,
        use_scale=False,
        device="cpu",
        dtype=torch.float32,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=False,
    ) -> None:
        # Argument checking 
        methods = ["solve", "cholesky", "cg"]
        if not solver in methods:
            raise ValueError(f"Method {solver} should be in {methods} ...")
        
        # Check devices
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices += ["cuda"]
            for i in range(torch.cuda.device_count()):
                devices += [f"cuda:{i}"]
        if not device in devices:
            raise ValueError(f"Device {device} should be in {devices} ...")

        # Create torch module
        super(SoftGP_baseline, self).__init__()

        # Misc
        self.device = device
        self.dtype = dtype
        
        # Mll approximation settings
        self.solve_method = solver
        self.mll_approx = mll_approx

        # Fit settings
        self.use_qr = use_qr
        self.fit_chunk_size = fit_chunk_size

        # Noise
        self.noise_constraint = gpytorch.constraints.Positive()
        noise = torch.tensor([noise], dtype=self.dtype, device=self.device)
        noise = self.noise_constraint.inverse_transform(noise)
        if learn_noise:
            self.register_parameter("raw_noise", torch.nn.Parameter(noise))
        else:
            self.raw_noise = noise

        # Kernel
        self.use_scale = use_scale
        if use_scale:
            self.kernel = ScaleKernel(kernel).to(self.device)
        else:
            self.kernel = kernel.to(self.device)

        # Inducing points
        self.register_parameter("inducing_points", torch.nn.Parameter(inducing_points))

        # Interpolation
        def softmax_interp(X: torch.Tensor, sigma_values: torch.Tensor) -> torch.Tensor:
            distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
            softmax_distances = torch.softmax(-distances, dim=-1)
            return softmax_distances
        self.interp = softmax_interp
        
        # Fit artifacts
        self.alpha = None
        self.K_zz_alpha = None

        # CG solver params
        self.max_cg_iter = max_cg_iter
        self.cg_tol = cg_tolerance
        self.x0 = None
        
    # -----------------------------------------------------
    # Soft GP Helpers
    # -----------------------------------------------------
    
    @property
    def noise(self):
        return self.noise_constraint.transform(self.raw_noise)

    def get_lengthscale(self) -> float:
        if self.use_scale:
            return self.kernel.base_kernel.lengthscale.cpu()
        else:
            return self.kernel.lengthscale.cpu()
        
    def get_outputscale(self) -> float:
        if self.use_scale:
            return self.kernel.outputscale.cpu()
        else:
            return 1.

    def _mk_cov(self, z: torch.Tensor) -> torch.Tensor:
        return self.kernel(z, z).evaluate()
    
    def _interp(self, x: torch.Tensor) -> torch.Tensor:
        x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
        W_xz = self.interp(x_expanded, self.inducing_points)
        return W_xz

    # -----------------------------------------------------
    # Linear solver
    # -----------------------------------------------------

    def _solve_system(
        self,
        kxx: linear_operator.operators.LinearOperator,
        full_rhs: torch.Tensor,
        x0: torch.Tensor = None,
        forwards_matmul: Callable = None,
        precond: torch.Tensor = None,
        return_pinv: bool = False,
    ) -> torch.Tensor:
        use_pinv = False
        with torch.no_grad():
            try:
                if self.solve_method == "solve":
                    solve = torch.linalg.solve(kxx, full_rhs)
                elif self.solve_method == "cholesky":
                    L = torch.linalg.cholesky(kxx)
                    solve = torch.cholesky_solve(full_rhs, L)
                elif self.solve_method == "cg":
                    # Source: https://github.com/AndPotap/halfpres_gps/blob/main/mlls/mixedpresmll.py
                    solve = linear_cg(
                        forwards_matmul,
                        full_rhs,
                        max_iter=self.max_cg_iter,
                        tolerance=self.cg_tol,
                        initial_guess=x0,
                        preconditioner=precond,
                    )
                else:
                    raise ValueError(f"Unknown method: {self.solve_method}")
            except RuntimeError as e:
                print("Fallback to pseudoinverse: ", str(e))
                solve = torch.linalg.pinv(kxx.evaluate()) @ full_rhs
                use_pinv = True

        # Apply torch.nan_to_num to handle NaNs from percision limits 
        solve = torch.nan_to_num(solve)
        return (solve, use_pinv) if return_pinv else solve

    # -----------------------------------------------------
    # Marginal Log Likelihood
    # -----------------------------------------------------

    def mll(self, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the marginal log likelihood of a soft GP:
            
            log p(y) = log N(y | mu_x, Q_xx)

            where
                mu_X: mean of soft GP
                Q_XX = W_xz K_zz W_zx

        Args:
            X (torch.Tensor): B x D tensor of inputs where each row is a point.
            y (torch.Tensor): B tensor of targets.

        Returns:
            torch.Tensor:  log p(y)
        """        
        # Construct covariance matrix components
        K_zz = self._mk_cov(self.inducing_points)
        W_xz = self._interp(X)
        
        if self.mll_approx == "exact":
            # [Note]: Compute MLL with a multivariate normal. Unstable for float.
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = (W_xz L) (L^T W_xz) + noise I  where K_zz = L L^T
            L = psd_safe_cholesky(K_zz)
            LK = (W_xz @ L).to(device=self.device)
            cov_diag = self.noise * torch.ones(len(X), dtype=self.dtype, device=self.device)

            # 3. N(mu, Q_xx)
            normal_dist = torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(mean, LK, cov_diag, validate_args=None)
            
            # 4. log N(y | mu, Q_xx)
            return normal_dist.log_prob(y)
        elif self.mll_approx == "hutchinson":
            # [Note]: Compute MLL with Hutchinson's trace estimator
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = W_xz K_zz K_zx + noise I
            cov_mat = W_xz @ K_zz @ W_xz.T 
            cov_mat += torch.eye(cov_mat.shape[1], dtype=self.dtype, device=self.device) * self.noise

            # 3. log N(y | mu, Q_xx) \appox 
            hutchinson_mll = HutchinsonPseudoLoss(self, num_trace_samples=10)
            return hutchinson_mll(mean, cov_mat, y)
        else:
            raise ValueError(f"Unknown MLL approximation method: {self.mll_approx}")
        
    # -----------------------------------------------------
    # Fit
    # -----------------------------------------------------

    def _direct_solve_fit(self, M, N, X, y, K_zz):
        # Construct A and b for linear solve
        #   A = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)
        #   b = (hat{K}_zx @ noise^{-1}) y
        if X.shape[0] * X.shape[1] <= 32768:
            # Case: "small" X
            # Form estimate \hat{K}_xz ~= W_xz K_zz
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
            hat_K_zx = hat_K_xz.T
            
            # Form A and b
            Lambda_inv_diag = (1 / self.noise) * torch.ones(N, dtype=self.dtype).to(self.device)
            A = K_zz + hat_K_zx @ (Lambda_inv_diag.unsqueeze(1) * hat_K_xz)
            b = hat_K_zx @ (Lambda_inv_diag * y)
        else:
            # Case: "large" X
            with torch.no_grad():
                # Initialize outputs
                A = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                b = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Initialize temporary values
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_inv = (1 / self.noise) * torch.eye(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp1 = torch.zeros(fit_chunk_size, M, dtype=self.dtype, device=self.device)
                tmp2 = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                tmp3 = torch.zeros(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp4 = torch.zeros(M, dtype=self.dtype, device=self.device)
                tmp5 = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Compute batches
                for i in range(batches):
                    # Update A: A += W_zx @ Lambda_inv @ W_xz
                    X_batch = X[i*fit_chunk_size:(i+1)*fit_chunk_size]
                    W_xz = self._interp(X_batch)
                    W_zx = W_xz.T
                    torch.matmul(Lambda_inv, W_xz, out=tmp1)
                    torch.matmul(W_zx, tmp1, out=tmp2)
                    A.add_(tmp2)
                    
                    # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y[i*batch_size:(i+1)*batch_size])
                    torch.matmul(Lambda_inv, y[i*fit_chunk_size:(i+1)*fit_chunk_size], out=tmp3)
                    torch.matmul(W_zx, tmp3, out=tmp4)
                    torch.matmul(K_zz, tmp4, out=tmp5)
                    b.add_(tmp5)
                
                # Compute last batch
                if N - (i+1)*fit_chunk_size > 0:
                    Lambda_inv = (1 / self.noise) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[(i+1)*fit_chunk_size:]
                    W_xz = self._interp(X_batch)
                    A += W_xz.T @ Lambda_inv @ W_xz
                    b += K_zz @ W_xz.T @ Lambda_inv @ y[(i+1)*fit_chunk_size:]

                # Aggregate result
                A = K_zz + K_zz @ A @ K_zz

        # Safe solve A \alpha = b
        A = DenseLinearOperator(A)
        self.alpha, use_pinv = self._solve_system(
            A,
            b.unsqueeze(1),
            x0=torch.zeros_like(b),
            forwards_matmul=A.matmul,
            precond=None,
            return_pinv=True
        )

        # Store for fast prediction
        self.K_zz_alpha = K_zz @ self.alpha
        return use_pinv

    def _qr_solve_fit(self, M, N, X, y, K_zz):
        if X.shape[0] * X.shape[1] <= 32768:
            # Compute: W_xz K_zz
            print("USING QR SMALL")
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
        else:
            # Compute: W_xz K_zz in a batched fashion
            print("USING QR BATCH")
            with torch.no_grad():
                # Compute batches
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(fit_chunk_size, dtype=self.dtype, device=self.device)
                hat_K_xz = torch.zeros((N, M), dtype=self.dtype, device=self.device)
                for i in range(batches):
                    start = i*fit_chunk_size
                    end = (i+1)*fit_chunk_size
                    X_batch = X[start:end,:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:end,:])
                
                start = (i+1)*fit_chunk_size
                if N - start > 0:
                    Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[start:]
                    W_xz = self._interp(X_batch)
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:,:])
        
        # B^T = [(Lambda^{-1/2} \hat{K}_xz) U_zz ]
        U_zz = psd_safe_cholesky(K_zz, upper=True, max_tries=10)
        Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(N, dtype=self.dtype).to(self.device)
        B = torch.cat([Lambda_half_inv_diag.unsqueeze(1) * hat_K_xz, U_zz], dim=0)

        # B = QR
        Q, R = torch.linalg.qr(B)

        # \alpha = R^{-1} @ Q^T @ Lambda^{-1/2}b
        b = Lambda_half_inv_diag * y
        self.alpha = torch.linalg.solve_triangular(R, (Q.T[:, 0:N] @ b).unsqueeze(1), upper=True).squeeze(1) # (should use triangular solve)
        # self.alpha = ((torch.linalg.inv(R) @ Q.T)[:, :N] @ b)
        
        # Store for fast inference
        self.K_zz_alpha = K_zz @ self.alpha

        return False

    def fit(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Fits a SoftGP to dataset (X, y). That is, solve:

                (hat{K}_zx @ noise^{-1}) y = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz) \alpha
        
            for \alpha where
            1. inducing points z are fixed,
            2. hat{K}_zx = K_zz W_zx, and
            3. hat{K}_xz = hat{K}_zx^T.

        Args:
            X (torch.Tensor): N x D tensor of inputs
            y (torch.Tensor): N tensor of outputs

        Returns:
            bool: Returns true if the pseudoinverse was used, false otherwise.
        """        
        # Prepare inputs
        N = len(X)
        M = len(self.inducing_points)
        X = X.to(self.device, dtype=self.dtype)
        y = y.to(self.device, dtype=self.dtype)

        # Form K_zz
        K_zz = self._mk_cov(self.inducing_points)

        if self.use_qr:
            return self._qr_solve_fit(M, N, X, y, K_zz)
        else:
            return self._direct_solve_fit(M, N, X, y, K_zz)

    # -----------------------------------------------------
    # Predict
    # -----------------------------------------------------

    def pred(self, x_star: torch.Tensor) -> torch.Tensor:
        """Give the posterior predictive:
        
            p(y_star | x_star, X, y) 
                = W_star_z (K_zz \alpha)
                = W_star_z K_zz (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)^{-1} (hat{K}_zx @ noise^{-1}) y

        Args:
            x_star (torch.Tensor): B x D tensor of points to evaluate at.

        Returns:
            torch.Tensor: B tensor of p(y_star | x_star, X, y).
        """        
        W_star_z = self._interp(x_star)
        return torch.matmul(W_star_z, self.K_zz_alpha).squeeze(-1)

### SoftGP test implementation 

---

In [86]:
from sklearn.neighbors import NearestNeighbors

class SoftGP_test(torch.nn.Module):
    def __init__(
        self,
        kernel: Callable,
        inducing_points: torch.Tensor,
        noise=1e-3,
        learn_noise=False,
        use_scale=False,
        device="cpu",
        dtype=torch.float32,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=False,
    ) -> None:
        # Argument checking 
        methods = ["solve", "cholesky", "cg"]
        if not solver in methods:
            raise ValueError(f"Method {solver} should be in {methods} ...")
        
        # Check devices
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices += ["cuda"]
            for i in range(torch.cuda.device_count()):
                devices += [f"cuda:{i}"]
        if not device in devices:
            raise ValueError(f"Device {device} should be in {devices} ...")

        # Create torch module
        super(SoftGP_test, self).__init__()

        # Misc
        self.device = device
        self.dtype = dtype
        
        # Mll approximation settings
        self.solve_method = solver
        self.mll_approx = mll_approx

        # Fit settings
        self.use_qr = use_qr
        self.fit_chunk_size = fit_chunk_size

        # Noise
        self.noise_constraint = gpytorch.constraints.Positive()
        noise = torch.tensor([noise], dtype=self.dtype, device=self.device)
        noise = self.noise_constraint.inverse_transform(noise)
        if learn_noise:
            self.register_parameter("raw_noise", torch.nn.Parameter(noise))
        else:
            self.raw_noise = noise

        # Kernel
        self.use_scale = use_scale
        if use_scale:
            self.kernel = ScaleKernel(kernel).to(self.device)
        else:
            self.kernel = kernel.to(self.device)

        # Inducing points
        self.register_parameter("inducing_points", torch.nn.Parameter(inducing_points))

        # Interpolation
        #self.threshold = torch.nn.Parameter(torch.tensor(1e-1))
        self.k = torch.nn.Parameter(torch.tensor(30.0))
        #self.T = torch.nn.Parameter(torch.tensor(1))

       
            
            

        # def softmax_interp(X: torch.Tensor, Z: torch.Tensor, T=1, n_neighbors=40,threshold=1e-7) -> torch.Tensor:
        #     X=X/T
        #     #--------- Distances ---------
        #     X_norm = (X**2).sum(dim=1, keepdim=True)  
        #     Z_norm = (Z**2).sum(dim=1, keepdim=True) 
        #     #||X_i - Z_j||^2 = ||X_i||^2 + ||Z_j||^2 - 2*X_i*Z_j
        #     distances = X_norm + Z_norm.T - 2 * torch.mm(X, Z.T)
            
        #     #--------- Thresholding ---------
        #     distances = torch.sqrt(torch.clamp(distances, min=threshold)) 
            
        #     #--------- K neighbors ---------
        #     _, indices = torch.topk(-distances, k=n_neighbors, dim=1) 
        #     selected_distances = torch.gather(distances, 1, indices)
            
        #     #--------- Softmax ---------
        #     exp_dists = torch.exp(-selected_distances )
        #     W_XZ_local = exp_dists / exp_dists.sum(dim=1, keepdim=True)
            
        #     # Populate sparse matrix with k values per row
        #     W_XZ = torch.zeros((X.size(0), Z.size(0)), device=device)
        #     W_XZ.scatter_(1, indices, W_XZ_local)
        #     return W_XZ
        self.threshold = 1e-32
        # self.desired_sparsity=torch.nn.Parameter(torch.tensor(50.0))

     
        def sparsity_measure(tensor: torch.Tensor) -> float:
            """Calculate the sparsity of a tensor (ratio of zero elements)."""
            num_zeros = (tensor == 0).sum().item()
            total_elements = tensor.numel()
            return num_zeros / total_elements

        def softmax_interp(X: torch.Tensor, sigma_values: torch.Tensor, threshold_factor: float = 1.0) -> torch.Tensor:
            # Compute distances between X and sigma_values
            distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
            
            # Apply softmax to the negative distances
            softmax_distances = torch.softmax(-distances, dim=-1)
            
            # Measure sparsity before applying the threshold
            #print("Sparsity before masking:", sparsity_measure(softmax_distances))
            
            # Set threshold as a factor of the mean of the softmax distances
            threshold_value = torch.mean(softmax_distances) * .5
            
            #print(f"Threshold value: {threshold_value}")
            
            # Apply the threshold: round off (zero out) values below the threshold
            masked_distances = torch.where(softmax_distances < threshold_value, torch.tensor(0.0, device=softmax_distances.device), softmax_distances)
            
            # Measure sparsity after applying the threshold
            #print("Sparsity after masking:", sparsity_measure(masked_distances))
            
            return masked_distances




        
        self.interp = softmax_interp

        
        def sparse_boltz(X: torch.Tensor, sigma_values: torch.Tensor) -> torch.Tensor:
            X = X / .005 
            
            # Calculate distances
            distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
            
            # Apply softmax to distances
            softmax_distances = torch.softmax(-distances, dim=-1)
            # print("Min value:", softmax_distances.min().item())
            # print("Max value:", softmax_distances.max().item())
            # print("Mean value:", softmax_distances.mean().item())
            # # Apply threshold to create sparsity
            threshold = 1e-2
            mask = softmax_distances >= threshold
            
            # Get indices of non-zero (or significant) elements
            nonzero_indices = mask.nonzero(as_tuple=False).t()  # Stack the indices as a 2D tensor

            # Get values of the non-zero elements
            nonzero_values = softmax_distances[mask]

            sparse_matrix = torch.sparse_coo_tensor(nonzero_indices, nonzero_values, size=softmax_distances.shape)
            
            return sparse_matrix
        
        # def softmax_interp(X: torch.Tensor, sigma_values: torch.Tensor, threshold=1e-3) -> torch.Tensor:
        #     X= X/.005
        #     distances = torch.linalg.vector_norm(X - sigma_values, ord=2, dim=-1)
        #     softmax_distances = torch.softmax(-distances, dim=-1)
        #     masked_distances = torch.where(softmax_distances < threshold, torch.tensor(0.0, device=softmax_distances.device), softmax_distances)

        #     return masked_distances
        

        
        
   
        
        # Fit artifacts
        self.alpha = None
        self.K_zz_alpha = None

        # CG solver params
        self.max_cg_iter = max_cg_iter
        self.cg_tol = cg_tolerance
        self.x0 = None
        
    # -----------------------------------------------------
    # Soft GP Helpers
    # -----------------------------------------------------
    
    @property
    def noise(self):
        return self.noise_constraint.transform(self.raw_noise)
    
  

    def seed_threshold(self, X: torch.Tensor, Z: torch.Tensor, desired_sparsity: float = 0.25) -> float:
        def sparsity_measure(tensor: torch.Tensor) -> float:
        # Count number of zeros in the tensor
            num_zeros = (tensor == 0).sum().item()
            # Total number of elements
            total_elements = tensor.numel()
            # Sparsity ratio
            return num_zeros / total_elements
        
        # Randomly select indices from X
        random_indices = np.random.randint(0, X.shape[0], 50)
        X_samples = torch.tensor(X[random_indices])

        # Compute the softmax distances for each selected sample
        softmax_distances_list = [self.interp(x.unsqueeze(0), Z) for x in X_samples]
        
        # Stack the softmax distances into a tensor
        softmax_distances_stack = torch.stack(softmax_distances_list)
        
        # Measure sparsity before applying the threshold
        print("Sparsity before thresholding:", sparsity_measure(softmax_distances_stack))
        
        # Compute the geometric mean of the softmax distances
        geometric_mean_softmax = torch.exp(torch.mean(torch.log(softmax_distances_stack), dim=0))
        
        # Flatten the geometric mean tensor for sorting
        geometric_mean_softmax_flattened = geometric_mean_softmax.flatten()
        
        # Sort the flattened tensor in descending order
        sorted_geometric_mean_softmax = torch.sort(geometric_mean_softmax_flattened, descending=True).values
        
        # Determine the index corresponding to the desired sparsity
        threshold_index = int(.5 * len(sorted_geometric_mean_softmax))
        
        # Get the threshold value
        threshold_value = sorted_geometric_mean_softmax[threshold_index].item()
        
        # Store the threshold value
        self.threshold = threshold_value
        
        # Apply the threshold to the softmax distances to induce sparsity
        softmax_distances_stack[softmax_distances_stack < threshold_value] = 0
        
        # Measure sparsity after applying the threshold
        print("Sparsity after thresholding:", sparsity_measure(softmax_distances_stack))

            
    def get_lengthscale(self) -> float:
        if self.use_scale:
            return self.kernel.base_kernel.lengthscale.cpu()
        else:
            return self.kernel.lengthscale.cpu()
        
    def get_outputscale(self) -> float:
        if self.use_scale:
            return self.kernel.outputscale.cpu()
        else:
            return 1.

    def _mk_cov(self, z: torch.Tensor) -> torch.Tensor:
        return self.kernel(z, z).evaluate()
    
    # def _interp(self, x: torch.Tensor) -> torch.Tensor:
    #     # Expand input x and perform interpolation to get the dense matrix W_xz_dense
    #     x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
    #     W_xz_dense = self.interp(x_expanded, self.inducing_points)
        
    #     non_zero_indices = W_xz_dense.nonzero(as_tuple=False).t()  
    #     non_zero_values = W_xz_dense[non_zero_indices[0], non_zero_indices[1]]  
        
    #     W_xz_sparse = torch.sparse_coo_tensor(
    #         non_zero_indices, non_zero_values, W_xz_dense.size(), device=W_xz_dense.device
    #     )
        
    #     W_xz_sparse = W_xz_sparse.coalesce()
        
    #     return W_xz_sparse

    def _interp(self, x: torch.Tensor) -> torch.Tensor:
        x_expanded = x.unsqueeze(1).expand(-1, self.inducing_points.shape[0], -1)
        W_xz = self.interp(x_expanded, self.inducing_points)
        return W_xz
    # def _interp(self, x: torch.Tensor) -> torch.Tensor:
        
    #     W_xz = self.interp(x, self.inducing_points)
    #     return W_xz

    # -----------------------------------------------------
    # Linear solver
    # -----------------------------------------------------

    def _solve_system(
        self,
        kxx: linear_operator.operators.LinearOperator,
        full_rhs: torch.Tensor,
        x0: torch.Tensor = None,
        forwards_matmul: Callable = None,
        precond: torch.Tensor = None,
        return_pinv: bool = False,
    ) -> torch.Tensor:
        use_pinv = False
        with torch.no_grad():
            try:
                if self.solve_method == "solve":
                    solve = torch.linalg.solve(kxx, full_rhs)
                elif self.solve_method == "cholesky":
                    L = torch.linalg.cholesky(kxx)
                    solve = torch.cholesky_solve(full_rhs, L)
                elif self.solve_method == "cg":
                    # Source: https://github.com/AndPotap/halfpres_gps/blob/main/mlls/mixedpresmll.py
                    solve = linear_cg(
                        forwards_matmul,
                        full_rhs,
                        max_iter=self.max_cg_iter,
                        tolerance=self.cg_tol,
                        initial_guess=x0,
                        preconditioner=precond,
                    )
                else:
                    raise ValueError(f"Unknown method: {self.solve_method}")
            except RuntimeError as e:
                print("Fallback to pseudoinverse: ", str(e))
                solve = torch.linalg.pinv(kxx.evaluate()) @ full_rhs
                use_pinv = True

        # Apply torch.nan_to_num to handle NaNs from percision limits 
        solve = torch.nan_to_num(solve)
        return (solve, use_pinv) if return_pinv else solve

    # -----------------------------------------------------
    # Marginal Log Likelihood
    # -----------------------------------------------------

    def mll(self, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the marginal log likelihood of a soft GP:
            
            log p(y) = log N(y | mu_x, Q_xx)

            where
                mu_X: mean of soft GP
                Q_XX = W_xz K_zz W_zx

        Args:
            X (torch.Tensor): B x D tensor of inputs where each row is a point.
            y (torch.Tensor): B tensor of targets.

        Returns:
            torch.Tensor:  log p(y)
        """        
        # Construct covariance matrix components
        K_zz = self._mk_cov(self.inducing_points)
        W_xz = self._interp(X)
        
        if self.mll_approx == "exact":
            # [Note]: Compute MLL with a multivariate normal. Unstable for float.
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = (W_xz L) (L^T W_xz) + noise I  where K_zz = L L^T
            L = psd_safe_cholesky(K_zz)
            LK = (W_xz @ L).to(device=self.device)
            cov_diag = self.noise * torch.ones(len(X), dtype=self.dtype, device=self.device)

            # 3. N(mu, Q_xx)
            normal_dist = torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(mean, LK, cov_diag, validate_args=None)
            
            # 4. log N(y | mu, Q_xx)
            return normal_dist.log_prob(y)
        
        elif self.mll_approx == "hutchinson":
            # [Note]: Compute MLL with Hutchinson's trace estimator
            # 1. mean: 0
            mean = torch.zeros(len(X), dtype=self.dtype, device=self.device)
            
            # 2. covariance: Q_xx = W_xz K_zz K_zx + noise I
            # cov_mat = W_xz @ K_zz @ W_xz.T 
            cov_mat = W_xz @ K_zz @ W_xz.T 
            # cov_mat = torch.sparse.mm(W_xz, K_zz)  # First sparse-dense multiplication (W_xz @ K_zz)
            # cov_mat = torch.sparse.mm(cov_mat, W_xz.t())  # Second multiplication with W_xz^T (transpose)
            cov_mat += torch.eye(cov_mat.shape[1], dtype=self.dtype, device=self.device) * self.noise

            # 3. log N(y | mu, Q_xx) \appox 
            hutchinson_mll = HutchinsonPseudoLoss(self, num_trace_samples=10,vector_format="sphere")
            return hutchinson_mll(mean, cov_mat, y)
        else:
            raise ValueError(f"Unknown MLL approximation method: {self.mll_approx}")
        
    # -----------------------------------------------------
    # Fit
    # -----------------------------------------------------
    def _direct_solve_fit(self, M, N, X, y, K_zz):
        # Construct A and b for linear solve
        #   A = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)
        #   b = (hat{K}_zx @ noise^{-1}) y
        if X.shape[0] * X.shape[1] <= 32768:
            # Case: "small" X
            # Form estimate \hat{K}_xz ~= W_xz K_zz
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
            
            hat_K_zx = hat_K_xz.T

            # Form A and b
            Lambda_inv_diag = (1 / self.noise) * torch.ones(N, dtype=self.dtype).to(self.device)
            A = K_zz + hat_K_zx @ (Lambda_inv_diag.unsqueeze(1) * hat_K_xz)
            b = hat_K_zx @ (Lambda_inv_diag * y)
        else:
            # Case: "large" X
            with torch.no_grad():
                # Initialize outputs
                A = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                b = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Initialize temporary values
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_inv = (1 / self.noise) * torch.eye(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp1 = torch.zeros(fit_chunk_size, M, dtype=self.dtype, device=self.device)
                tmp2 = torch.zeros(M, M, dtype=self.dtype, device=self.device)
                tmp3 = torch.zeros(fit_chunk_size, dtype=self.dtype, device=self.device)
                tmp4 = torch.zeros(M, dtype=self.dtype, device=self.device)
                tmp5 = torch.zeros(M, dtype=self.dtype, device=self.device)
                
                # Compute batches
                for i in range(batches):
                    # Update A: A += W_zx @ Lambda_inv @ W_xz
                    X_batch = X[i*fit_chunk_size:(i+1)*fit_chunk_size]
                    W_xz = self._interp(X_batch)
                    W_zx = W_xz.T
                    torch.matmul(Lambda_inv, W_xz, out=tmp1)
                    torch.matmul(W_zx, tmp1, out=tmp2)
                    A.add_(tmp2)
                    
                    # Update b: b += K_zz @ W_zx @ (Lambda_inv @ Y[i*batch_size:(i+1)*batch_size])
                    torch.matmul(Lambda_inv, y[i*fit_chunk_size:(i+1)*fit_chunk_size], out=tmp3)
                    torch.matmul(W_zx, tmp3, out=tmp4)
                    torch.matmul(K_zz, tmp4, out=tmp5)
                    b.add_(tmp5)
                
                # Compute last batch
                if N - (i+1)*fit_chunk_size > 0:
                    Lambda_inv = (1 / self.noise) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[(i+1)*fit_chunk_size:]
                    W_xz = self._interp(X_batch)
                    A += W_xz.T @ Lambda_inv @ W_xz
                    b += K_zz @ W_xz.T @ Lambda_inv @ y[(i+1)*fit_chunk_size:]

                # Aggregate result
                A = K_zz + K_zz @ A @ K_zz

        # Safe solve A \alpha = b
        A = DenseLinearOperator(A)
        self.alpha, use_pinv = self._solve_system(
            A,
            b.unsqueeze(1),
            x0=torch.zeros_like(b),
            forwards_matmul=A.matmul,
            precond=None,
            return_pinv=True
        )

        # Store for fast prediction
        self.K_zz_alpha = K_zz @ self.alpha
        return use_pinv
 
    def _qr_solve_fit(self, M, N, X, y, K_zz):
        # W_xz = self._interp(X)
        # hat_K_xz = W_xz @ K_zz
        if X.shape[0] * X.shape[1] <= 32768:
            # Compute: W_xz K_zz
            print("USING QR SMALL")
            W_xz = self._interp(X)
            hat_K_xz = W_xz @ K_zz
        else:
            # Compute: W_xz K_zz in a batched fashion
            print("USING QR BATCH")
            with torch.no_grad():
                # Compute batches
                fit_chunk_size = self.fit_chunk_size
                batches = int(np.floor(N / fit_chunk_size))
                Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(fit_chunk_size, dtype=self.dtype, device=self.device)
                hat_K_xz = torch.zeros((N, M), dtype=self.dtype, device=self.device)
                for i in range(batches):
                    start = i*fit_chunk_size
                    end = (i+1)*fit_chunk_size
                    X_batch = X[start:end,:]
                    W_xz = self._interp(X_batch)
                    
                    # sparsity = 1.0 - W_xz._nnz() / float(W_xz.numel())
                    # print(f"Sparsity: {sparsity}")
               
                    #result = torch.sparse.mm(W_xz, K_zz)
                    #hat_K_xz[start:end, :] = result
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:end,:])
                                
                start = (i+1)*fit_chunk_size
                if N - start > 0:
                    Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.eye(N - (i+1)*fit_chunk_size, dtype=self.dtype, device=self.device)
                    X_batch = X[start:]
                    W_xz = self._interp(X_batch)
                    #result = torch.sparse.mm(W_xz, K_zz)
                    #hat_K_xz[start:, :] = result
                    torch.matmul(W_xz, K_zz, out=hat_K_xz[start:,:])
                    

        
        # B^T = [(Lambda^{-1/2} \hat{K}_xz) U_zz ]
        U_zz = psd_safe_cholesky(K_zz, upper=True, max_tries=10)
        Lambda_half_inv_diag = (1 / torch.sqrt(self.noise)) * torch.ones(N, dtype=self.dtype).to(self.device)
        B = torch.cat([Lambda_half_inv_diag.unsqueeze(1) * hat_K_xz, U_zz], dim=0)

        # B = QR
        Q, R = torch.linalg.qr(B)

        # \alpha = R^{-1} @ Q^T @ Lambda^{-1/2}b
        b = Lambda_half_inv_diag * y
        self.alpha = torch.linalg.solve_triangular(R, (Q.T[:, 0:N] @ b).unsqueeze(1), upper=True).squeeze(1) # (should use triangular solve)
        # Store for fast inference
        self.K_zz_alpha = (K_zz) @ self.alpha 

        #self.K_zz_alpha = (K_zz+ self.noise * torch.eye(K_zz.shape[0],dtype=self.dtype).to(self.device)) @ self.alpha 

        return False

    def fit(self, X: torch.Tensor, y: torch.Tensor) -> bool:
        """Fits a SoftGP to dataset (X, y). That is, solve:

                (hat{K}_zx @ noise^{-1}) y = (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz) \alpha
        
            for \alpha where
            1. inducing points z are fixed,
            2. hat{K}_zx = K_zz W_zx, and
            3. hat{K}_xz = hat{K}_zx^T.

        Args:
            X (torch.Tensor): N x D tensor of inputs
            y (torch.Tensor): N tensor of outputs

        Returns:
            bool: Returns true if the pseudoinverse was used, false otherwise.
        """        
        # Prepare inputs
        N = len(X)
        M = len(self.inducing_points)
        X = X.to(self.device, dtype=self.dtype)
        y = y.to(self.device, dtype=self.dtype)

        # Form K_zz
        K_zz = self._mk_cov(self.inducing_points)

        if self.use_qr:
            return self._qr_solve_fit(M, N, X, y, K_zz)
        else:
            return self._direct_solve_fit(M, N, X, y, K_zz)

    # -----------------------------------------------------
    # Predict
    # -----------------------------------------------------

    def pred(self, x_star: torch.Tensor) -> torch.Tensor:
        """Give the posterior predictive:
        
            p(y_star | x_star, X, y) 
                = W_star_z (K_zz \alpha)
                = W_star_z K_zz (K_zz + hat{K}_zx @ noise^{-1} @ hat{K}_xz)^{-1} (hat{K}_zx @ noise^{-1}) y

        Args:
            x_star (torch.Tensor): B x D tensor of points to evaluate at.

        Returns:
            torch.Tensor: B tensor of p(y_star | x_star, X, y).
        """        
        W_star_z = self._interp(x_star)
        return torch.matmul(W_star_z, self.K_zz_alpha).squeeze(-1)

### Test version profiling stats
---

In [4]:
from line_profiler import LineProfiler

def eval_gp(model, test_dataset: Dataset, device="cuda:0") -> float:
    preds = []
    neg_mlls = []
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,num_workers=1)
    for x_batch, y_batch in tqdm(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        preds += [(model.pred(x_batch) - y_batch).detach().cpu()**2]
        neg_mlls += [-model.mll(x_batch, y_batch).detach().cpu()]
    rmse = torch.sqrt(torch.sum(torch.cat(preds)) / len(test_dataset)).item()
    neg_mll = torch.sum(torch.tensor(neg_mlls))
            
    print("RMSE:", rmse, "NEG_MLL", neg_mll.item(), "NOISE", model.noise.cpu().item(), "LENGTHSCALE", model.get_lengthscale(), "OUTPUTSCALE", model.get_outputscale(),"k",model.k)# "T",model.T)
    
    return {
        "rmse": rmse,
        "nll": neg_mll,
    }   
    
def profileGP(GP,train_dataset, test_dataset):
    num_inducing = 512
    dtype = torch.float32
    batch_size = 1024
    epochs = 3
    lr = 0.01 
    learn_noise = False
    device = "cpu"
    
    # Initialize inducing points with kmeans
    train_features, train_labels = flatten_dataset(train_dataset)
    kmeans = KMeans(n_clusters=num_inducing)
    kmeans.fit(train_features)
    centers = kmeans.cluster_centers_
    inducing_points = torch.tensor(centers).to(dtype=dtype, device=device)
    
    # Setup model
    kernel = RBFKernel().to(device=device, dtype=dtype)

    model = GP(kernel,
        inducing_points,
        noise=1e-3,
        learn_noise=learn_noise,
        use_scale=True,
        dtype=dtype,
        solver="solve",
        max_cg_iter=50,
        cg_tolerance=0.5,
        mll_approx="hutchinson",
        fit_chunk_size=1024,
        use_qr=True,
    )

    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    def train_model():
        print("test")
        for epoch in pbar:
            t1 = time.perf_counter()
            
            neg_mlls = []
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
                y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)

                optimizer.zero_grad()
                with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                    neg_mll = -model.mll(x_batch, y_batch)
                neg_mlls += [-neg_mll.item()]
                neg_mll.backward()
                optimizer.step()

                pbar.set_description(f"Epoch {epoch+1}/{epochs}")
                pbar.set_postfix(MLL=f"{-neg_mll.item()}")
                
            t2 = time.perf_counter()
            #print("fitting",train_features.shape)
            use_pinv = model.fit(train_features, train_labels)
            t3 = time.perf_counter()
            # Calculate time spent on each part
            epoch_time = t2 - t1
            fitting_time = t3 - t2
            
            # Print out the times for the epoch and model fitting
            print(f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.4f} seconds.")
            print(f"Time spent on model fitting: {fitting_time:.4f} seconds.")
        
            #results = eval_gp(model, test_dataset, device=device)
            #print(results)
    profiler = LineProfiler()
    profiler.add_function(train_model)
    profiler.add_function(model.fit)  
    profiler.add_function(model.interp)  
    profiler.add_function(model._qr_solve_fit) 
    profiler.add_function(model.mll)  


    profiler.enable_by_count()
    train_model() 
    profiler.disable_by_count()
    profiler.print_stats()
    return model


### Notes

Total time  22.1084 s  

In [67]:
import faiss

### Profiler 

In [57]:
from data.get_uci import ElevatorsDataset,PoleteleDataset,CTSlicesDataset
# dataset = PoleteleDataset("../data/uci_datasets/uci_datasets/pol/data.csv")
# train_dataset, val_dataset, test_dataset = split_dataset(
#     dataset,
#     train_frac=9/10,
#     val_frac=0/10
# )


dataset = CTSlicesDataset("../data/uci_datasets/uci_datasets/slice/data.csv")
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset,
    train_frac=9/10,
    val_frac=0/10
)

softgp = profileGP(SoftGP_test,train_dataset,test_dataset)

SIZE (53500, 386)


  0%|          | 0/6 [00:00<?, ?it/s]

KeyboardInterrupt: 

### Comparison: Test vs Baseline 
---


In [84]:
#==================Dataset============================
from data.get_uci import ElevatorsDataset,PoleteleDataset
# # dataset = ElevatorsDataset("../data/uci_datasets/uci_datasets/elevators/data.csv")
dataset = PoleteleDataset("../data/uci_datasets/uci_datasets/pol/data.csv")

# train_dataset, val_dataset, test_dataset = split_dataset(
#     dataset,
#     train_frac=9/10,
#     val_frac=0/10  
# )


# dataset = CTSlicesDataset("../data/uci_datasets/uci_datasets/slice/data.csv")
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset,
    train_frac=9/10,
    val_frac=0/10
)

def plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes, epochs):
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    for i, GP_class in enumerate(GP_classes):
        epochs_range = range(1, epochs + 1)
        axes[0].plot(epochs_range, all_mean_rmse[i], label=f'{GP_class.__name__} RMSE')
        axes[0].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             [m + s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             alpha=0.3)
    axes[0].set_title('RMSE per Epoch')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('RMSE')
    axes[0].legend()

    for i, GP_class in enumerate(GP_classes):
        epochs_range = range(1, epochs + 1)
        axes[1].plot(epochs_range, all_mean_runtimes[i], label=f'{GP_class.__name__} Runtime')
        axes[1].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             [m + s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             alpha=0.3)
    axes[1].set_title('Training Time per Epoch')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Time (s)')
    axes[1].legend()

    plt.tight_layout()
    plt.show()

def train_gp(GP_class, inducing_points, test_dataset, train_features, train_labels, epochs, device, dtype):
    kernel = RBFKernel().to(device=device, dtype=dtype)
    learn_noise = False
    lr = .01
    batch_size = 1024

    model = GP_class(kernel,
                     inducing_points,
                     noise=1e-3,
                     learn_noise=learn_noise,
                     
                     use_scale=True,
                     dtype=dtype,
                     solver="solve",
                     max_cg_iter=50,
                     cg_tolerance=0.5,
                     mll_approx="hutchinson",
                     fit_chunk_size=1024,
                     use_qr=True)

    epoch_runtimes = []
    epoch_rmse = []

    # pbar = tqdm(range(epochs), desc="Optimizing MLL")
    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # if GP_class  ==SoftGP_test:
    #     model.seed_threshold(train_features,inducing_points,desired_sparsity=.5)
        
    def train_model():
        #==================Train============================
        for _ in tqdm(range(epochs)):
            print("training current epoch")
            epoch_start_time = time.time()

            for x_batch, y_batch in train_loader:
                x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
                y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)
                optimizer.zero_grad()
                with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                    neg_mll = -model.mll(x_batch, y_batch)
                neg_mll.backward()
                optimizer.step()
                # pbar.set_description(f"Epoch {epoch + 1}/{epochs}")
                # pbar.set_postfix(MLL=f"{-neg_mll.item()}")
            model.fit(train_features, train_labels)
            epoch_end_time = time.time()
            epoch_runtimes.append(epoch_end_time - epoch_start_time)

            #==================Evaluate============================
            print("Running eval")
            eval_results = eval_gp(model, test_dataset, device=device)
            epoch_rmse.append(eval_results['rmse'])
            print("eval finished")    
    train_model()
    return epoch_rmse, epoch_runtimes

SIZE (15000, 27)


### Benchmark 1 

In [87]:
def benchmark(GP_classes, train_dataset, test_dataset, epochs=2, seed=42, N=3):
    torch.manual_seed(6535)
    np.random.seed(6535)
    random.seed(6535)

    num_inducing = 512
    dtype = torch.float32
    device = "cpu"
    
    all_mean_rmse = []
    all_mean_runtimes = []
    all_std_rmse = []
    all_std_runtimes = []

    #==================Inducing Points============================
    train_features, train_labels = flatten_dataset(train_dataset)
    kmeans = KMeans(n_clusters=num_inducing)
    kmeans.fit(train_features)
    centers = kmeans.cluster_centers_
    inducing_points = torch.tensor(centers).to(dtype=dtype, device=device)

    for GP_class in GP_classes:
        print(f"Training {GP_class.__name__}...")
        
        # Run the experiment N times and store results for each run
        all_runs_rmse = []
        all_runs_runtimes = []
        
        for run in range(N):
            epoch_rmse, epoch_runtimes = train_gp(
                GP_class,
                inducing_points.clone(),
                test_dataset,
                train_features,
                train_labels,
                epochs,
                device,
                dtype
            )
            all_runs_rmse.append(epoch_rmse)
            all_runs_runtimes.append(epoch_runtimes)

        # Calculate mean and std deviation across the N runs
        mean_rmse = np.mean(all_runs_rmse, axis=0)
        std_rmse = np.std(all_runs_rmse, axis=0)
        mean_runtimes = np.mean(all_runs_runtimes, axis=0)
        std_runtimes = np.std(all_runs_runtimes, axis=0)

        all_mean_rmse.append(mean_rmse)
        all_mean_runtimes.append(mean_runtimes)
        all_std_rmse.append(std_rmse)
        all_std_runtimes.append(std_runtimes)

    return all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes

#==================Benchmark============================


GP_classes = [SoftGP_test] 
epochs = 50
N = 1 # Number of runs
all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes = benchmark(GP_classes, train_dataset, test_dataset, epochs=epochs, seed=42, N=N)
plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_std_rmse, all_std_runtimes, epochs)


  0%|          | 0/2 [00:00<?, ?it/s]

Training SoftGP_test...


  0%|          | 0/50 [00:00<?, ?it/s]

training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2595203220844269 NEG_MLL -1394.49462890625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6315]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(0.7653, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.24740484356880188 NEG_MLL -1371.0137939453125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6065]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(0.8373, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.24221740663051605 NEG_MLL -1303.1715087890625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6111]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(0.9106, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.23815147578716278 NEG_MLL -1223.810302734375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6287]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(0.9854, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2349194884300232 NEG_MLL -1149.45068359375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6491]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.0600, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2322375625371933 NEG_MLL -1100.5086669921875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6622]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.1348, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2309694141149521 NEG_MLL -1051.3189697265625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6691]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.2104, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.22920764982700348 NEG_MLL -1010.8848266601562 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6721]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.2854, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.22774824500083923 NEG_MLL -971.176513671875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6698]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.3630, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2259245216846466 NEG_MLL -933.2472534179688 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6671]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.4396, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.22519129514694214 NEG_MLL -904.427734375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6631]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.5162, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.22364568710327148 NEG_MLL -876.8345336914062 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6587]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.5919, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.22229516506195068 NEG_MLL -846.5908813476562 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6494]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.6687, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2211017906665802 NEG_MLL -826.352783203125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6392]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.7437, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2193845510482788 NEG_MLL -805.601318359375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6288]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.8181, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21775978803634644 NEG_MLL -787.3208618164062 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6189]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.8929, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21718920767307281 NEG_MLL -767.9110107421875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6073]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(1.9684, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21493051946163177 NEG_MLL -753.1427001953125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5941]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.0432, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21463388204574585 NEG_MLL -734.5127563476562 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5813]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.1185, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21445152163505554 NEG_MLL -720.14501953125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5716]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.1927, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2135833501815796 NEG_MLL -705.0137939453125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5652]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.2660, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2119896113872528 NEG_MLL -689.9107055664062 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5580]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.3378, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21106858551502228 NEG_MLL -678.4403686523438 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5478]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.4088, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.21021433174610138 NEG_MLL -668.0677490234375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5385]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.4798, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20983679592609406 NEG_MLL -655.2451171875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5295]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.5515, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2091773897409439 NEG_MLL -641.8870849609375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5176]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.6225, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20894791185855865 NEG_MLL -635.5533447265625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5069]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.6916, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20779964327812195 NEG_MLL -627.4630737304688 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4986]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.7600, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2075592428445816 NEG_MLL -613.1063842773438 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4921]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.8292, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20726996660232544 NEG_MLL -605.0144653320312 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4848]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.8979, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.2073953002691269 NEG_MLL -595.076416015625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4763]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.9651, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20629826188087463 NEG_MLL -584.47900390625 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4685]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.0307, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20574474334716797 NEG_MLL -577.5623168945312 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4604]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.0953, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20458050072193146 NEG_MLL -569.2518310546875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4516]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.1613, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20519064366817474 NEG_MLL -563.0997924804688 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4446]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.2254, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20553742349147797 NEG_MLL -555.9343872070312 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4391]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.2874, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20518353581428528 NEG_MLL -545.2118530273438 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4328]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.3508, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.20519106090068817 NEG_MLL -537.832763671875 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.4251]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.4144, grad_fn=<SoftplusBackward0>) k Parameter containing:
tensor(30., requires_grad=True)
eval finished
training current epoch
USING QR BATCH
Running eval


  0%|          | 0/47 [00:00<?, ?it/s]

KeyboardInterrupt: 

[Softmax with temperature] All thresholds at 1e-3
- Fixed T=1 (softmax) |   `RMSE : 0.18995238840579987`
- Fixed T=.005        |   `RMSE : 0.15676772594451904`
- Fixed T=1e-8        |   `RMSE : early blow up ~.9`
- Learn T from 1 |`RMSE: 0.17583267390727997 NEG_MLL -95.2702407836914`
- Learn T from 1e-4|`RMSE: 0.15030381083488464 NEG_MLL -177.95814514160156 `
- Learn T from 1e-8|`RMSE: 0.16479668021202087 NEG_MLL -75.8147964477539`


[new data]
- Nearest neighbor trick k=1 T=1 |`RMSE:Early blow up`
- Nearest neighbor trick k=5 T=1 |`RMSE: 0.14919526875019073`
- Nearest neighbor trick k=10 T=1 |`RMSE: 0.1487170159816742`
- Nearest neighbor trick k=26 T=1 |`RMSE: 0.14025409519672394`
- Nearest neighbor trick k=30 T=1 |`RMSE: 0.14733223617076874 `
- Nearest neighbor trick k=40 T=1 |`RMSE: 0.1309099793434143`
- Nearest neighbor trick k=60 T=1 |`RMSE: 0.14369042217731476 `


- Nearest neighbor trick k=40  T=.5 |`RMSE: 0.16032849252223969 `















# 0.1432650238275528 thresh 1e-7

# RMSE: 0.14299191534519196 thresh 1e-3


#RMSE: 0.16022315621376038 NEG_MLL -58.170013427734375 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.5337]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(2.8697, grad_fn=<SoftplusBackward0>) T=.5

RMSE: 0.1340067833662033 NEG_MLL -81.39476013183594 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.3085]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(3.2416, grad_fn=<SoftplusBackward0>) neighbors =40



In [10]:
from gp.exact_gp.exact_gp import ExactGPModel,CGDMLL
from gp.exact_gp.exact_gp import eval_gp as evgp
from gpytorch.constraints import GreaterThan

def train_exact(train_dataset, test_dataset,epochs):

    dtype=torch.float32
    device="cpu"
    lr=.01
    learn_noise = False
    torch.set_default_dtype(dtype)
    torch.manual_seed(6535)
    Kxxs = []
    # Dataset preparation
    train_x, train_y = flatten_dataset(train_dataset)
    train_x = train_x.to(dtype=dtype, device=device)
    train_y = train_y.to(dtype=dtype, device=device)

    # Model
    # inducing_points = train_x[:num_inducing, :].clone() # torch.rand(num_inducing, D).cuda()
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=GreaterThan(1e-1)).to(device=device)
    likelihood.noise = torch.tensor([.5]).to(device=device)
    model = ExactGPModel(train_x, train_y, likelihood, kernel=RBFKernel(), use_scale=True).to(device=device)
    # mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    mll = CGDMLL(likelihood, model,max_cg_iters=50,cg_tolerance=1e-3)

    # Training parameters
    model.train()
    likelihood.train()

    # Set optimizer
    if learn_noise:
        params = model.parameters()
        hypers = likelihood.parameters()
    else:
        params = model.parameters()
        hypers = []
        
    optimizer = torch.optim.Adam([
        {'params': params},
        {'params': hypers}
    ], lr=lr)
    lr_sched = lambda epoch: 1.0
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_sched)
    
    # Training loop
    pbar = tqdm(range(epochs), desc="Optimizing MLL")
    for epoch in pbar:
        t1 = time.perf_counter()

        # Load batch
        optimizer.zero_grad()
        output = likelihood(model(train_x))
        loss = -mll(output, train_y)
        loss.backward()

        # step optimizers and learning rate schedulers
        optimizer.step()
        scheduler.step()
        t2 = time.perf_counter()

        # Log
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix(MLL=f"{-loss.item()}")

        # Evaluate
        #test_rmse, test_nll = evgp(model, likelihood, test_dataset, device=device)
        #print(test_rmse)
        k = model.covar_module(train_x)
        print(torch.linalg.matrix_norm(k.to_dense(),ord='fro'))
        Kxxs += k.to_dense()
        model.train()
        likelihood.train()

    return model, likelihood,Kxxs

model,_,Kxxs = train_exact(train_dataset,test_dataset,epochs=2)

  0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/2:  50%|█████     | 1/2 [00:26<00:26, 26.02s/it, MLL=-5742.74853515625]

tensor(369.2504, grad_fn=<LinalgVectorNormBackward0>)


: 

: 

In [185]:
def plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_mean_relative_errors, all_std_rmse, all_std_runtimes, epochs):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))  # Add a third subplot for relative errors

    for i, GP_class in enumerate(GP_classes):
        epochs_range = range(1, epochs + 1)

        # Plot RMSE
        axes[0].plot(epochs_range, all_mean_rmse[i], label=f'{GP_class.__name__} RMSE')
        axes[0].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             [m + s for m, s in zip(all_mean_rmse[i], all_std_rmse[i])],
                             alpha=0.3)
        axes[0].set_title('RMSE per Epoch')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('RMSE')
        axes[0].legend()

        # Plot runtimes
        axes[1].plot(epochs_range, all_mean_runtimes[i], label=f'{GP_class.__name__} Runtime')
        axes[1].fill_between(epochs_range,
                             [m - s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             [m + s for m, s in zip(all_mean_runtimes[i], all_std_runtimes[i])],
                             alpha=0.3)
        axes[1].set_title('Training Time per Epoch')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Time (s)')
        axes[1].legend()

        # Plot relative errors
        axes[2].plot(epochs_range, all_mean_relative_errors[i], label=f'{GP_class.__name__} Relative Error')
        axes[2].set_title('Relative Error per Epoch')
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('Relative Error')
        axes[2].legend()

    plt.tight_layout()
    plt.show()
    
def train_gp(GP_class, inducing_points, test_dataset, train_features, train_labels, epochs, device, dtype):
    kernel = RBFKernel().to(device=device, dtype=dtype)
    learn_noise = False
    lr = .01
    batch_size = 1024

    model = GP_class(kernel,
                     inducing_points,
                     noise=1e-3,
                     learn_noise=learn_noise,
                     use_scale=True,
                     dtype=dtype,
                     solver="solve",
                     max_cg_iter=50,
                     cg_tolerance=0.5,
                     mll_approx="hutchinson",
                     fit_chunk_size=1024,
                     use_qr=True)

    epoch_runtimes = []
    epoch_rmse = []
    relative_errors = []  # Store relative errors per epoch

    # Compute the exact K_xx kernel once at the start
    # with torch.no_grad():
    #     # sk = ScaleKernel(kernel).to(device)

    if learn_noise:
        params = model.parameters()
    else:
        params = filter_param(model.named_parameters(), "likelihood.noise_covar.raw_noise")
    optimizer = torch.optim.Adam(params, lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def train_model():
        #==================Train============================
        for epoch in tqdm(range(epochs)):
            epoch_start_time = time.time()

            for x_batch, y_batch in train_loader:
                x_batch = x_batch.clone().detach().to(dtype=dtype, device=device)
                y_batch = y_batch.clone().detach().to(dtype=dtype, device=device)
                optimizer.zero_grad()
                with gpytorch.settings.max_root_decomposition_size(100), max_cholesky_size(int(1.e7)):
                    neg_mll = -model.mll(x_batch, y_batch)
                neg_mll.backward()
                optimizer.step()

            model.fit(train_features, train_labels)
            epoch_end_time = time.time()
            epoch_runtimes.append(epoch_end_time - epoch_start_time)

            #==================Evaluate============================
            eval_results = eval_gp(model, test_dataset, device=device)
            epoch_rmse.append(eval_results['rmse'])

            #==================Relative Error========================
            # Compute the approximation and relative error with K_xx
            with torch.no_grad():
                K_xx = kernels[epoch](train_features, train_features).evaluate()
                W_xz = model._interp(train_features)  # Get interpolation weights
                K_zz = model._mk_cov(model.inducing_points)
                Q_xx = W_xz @ K_zz @ W_xz.T
                print(torch.linalg.matrix_norm(K_xx - Q_xx, ord='fro') / torch.linalg.matrix_norm(K_xx, ord='fro'))
                relative_error = torch.linalg.matrix_norm(K_xx - Q_xx, ord='fro') / torch.linalg.matrix_norm(K_xx, ord='fro')
                relative_errors.append(relative_error.item())

    train_model()
    return epoch_rmse, epoch_runtimes, relative_errors  

def benchmark(GP_classes, train_dataset, test_dataset, epochs=2, seed=42, N=3):
    torch.manual_seed(6535)
    np.random.seed(6535)
    random.seed(6535)

    num_inducing = 512
    dtype = torch.float32
    device = "cpu"
    
    all_mean_rmse = []
    all_mean_runtimes = []
    all_mean_relative_errors = []  # Store mean relative errors across runs
    all_std_rmse = []
    all_std_runtimes = []

    #==================Inducing Points============================
    train_features, train_labels = flatten_dataset(train_dataset)
    kmeans = KMeans(n_clusters=num_inducing)
    kmeans.fit(train_features)
    centers = kmeans.cluster_centers_
    inducing_points = torch.tensor(centers).to(dtype=dtype, device=device)

    for GP_class in GP_classes:
        print(f"Training {GP_class.__name__}...")
        
        all_runs_rmse = []
        all_runs_runtimes = []
        all_runs_relative_errors = []  # Store relative errors for each run

        for run in range(N):
            epoch_rmse, epoch_runtimes, relative_errors = train_gp(
                GP_class,
                inducing_points.clone(),
                test_dataset,
                train_features,
                train_labels,
                epochs,
                device,
                dtype
            )
            all_runs_rmse.append(epoch_rmse)
            all_runs_runtimes.append(epoch_runtimes)
            all_runs_relative_errors.append(relative_errors)

        # Calculate mean and std deviation across the N runs
        mean_rmse = np.mean(all_runs_rmse, axis=0)
        std_rmse = np.std(all_runs_rmse, axis=0)
        mean_runtimes = np.mean(all_runs_runtimes, axis=0)
        std_runtimes = np.std(all_runs_runtimes, axis=0)
        mean_relative_errors = np.mean(all_runs_relative_errors, axis=0)

        all_mean_rmse.append(mean_rmse)
        all_mean_runtimes.append(mean_runtimes)
        all_mean_relative_errors.append(mean_relative_errors)  # Store mean relative errors
        all_std_rmse.append(std_rmse)
        all_std_runtimes.append(std_runtimes)

    return all_mean_rmse, all_mean_runtimes, all_mean_relative_errors, all_std_rmse, all_std_runtimes

GP_classes = [SoftGP_baseline,SoftGP_test]
epochs = 5 # Run for 150 epochs
N = 1  # Number of runs
all_mean_rmse, all_mean_runtimes, all_mean_relative_errors, all_std_rmse, all_std_runtimes = benchmark(GP_classes, train_dataset, test_dataset, epochs=epochs, seed=42, N=N)
plot_results(GP_classes, all_mean_rmse, all_mean_runtimes, all_mean_relative_errors, all_std_rmse, all_std_runtimes, epochs)

  0%|          | 0/2 [00:00<?, ?it/s]

Training SoftGP_baseline...


  0%|          | 0/5 [00:00<?, ?it/s]

USING QR BATCH


  0%|          | 0/47 [00:00<?, ?it/s]

RMSE: 0.26991087198257446 NEG_MLL -1346.427001953125 NOISE 0.0010000001639127731 LENGTHSCALE tensor([[0.6336]], grad_fn=<SoftplusBackward0>) OUTPUTSCALE tensor(0.7653, grad_fn=<SoftplusBackward0>)


NameError: name 'kernels' is not defined