In [1]:
from src.spaces import Grassmannian, OrientedGrassmannian, HyperbolicSpace, SO, \
    SymmetricPositiveDefiniteMatrices, Sphere, Stiefel, SU
from src.spectral_kernel import RandomSpectralKernel, EigenbasisSumKernel, RandomFourierFeatureKernel, RandomPhaseKernel
from src.prior_approximation import RandomPhaseApproximation, RandomFourierApproximation
from src.spectral_measure import MaternSpectralMeasure, SqExpSpectralMeasure
from examples.gpr_model import ExactGPModel, train
from torch.nn import MSELoss
from torch.autograd.functional import _vmap as vmap
import gpytorch
import torch
import sys, os

INFO: Using numpy backend


In [2]:
sys.setrecursionlimit(2000)
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
torch.autograd.set_detect_anomaly(True)
dtype = torch.float64
device = 'cuda' if torch.cuda.is_available() else 'cpu'
%matplotlib widget


In [3]:
# Choose the manifold on which the GP is defined.
# Alternative space (uncomment to switch):
#   n, m = 3, 1
#   space = Grassmannian(n, m, order=10, average_order=10)
n = 3
m = n  # for SO(n) the point shape (n, m) is square
space = SO(n, order=10)

In [4]:
def f(x):
    """Synthetic regression target on ``space``.

    Computes ``sin(d(x_i, id))`` for every point ``x_i`` in the batch ``x``.
    Here ``d`` is ``space.pairwise_dist`` and ``id`` is ``space.id``.

    Args:
        x: batch of points on ``space`` (e.g. as returned by ``space.rand``).

    Returns:
        1-D tensor with one target value per point in ``x``.
    """
    # Add a leading batch axis so the identity acts as a batch of one point.
    identity = space.id.unsqueeze(0)
    # The distances to a single reference point have one singleton axis.
    # reshape(-1) always yields shape (N,), including when N == 1, where
    # .squeeze() would have collapsed the result to a 0-d scalar.
    dist = space.pairwise_dist(x, identity).reshape(-1)
    return torch.sin(dist)

In [5]:
# Configure the geometric kernel.

# `nu` (smoothness) is only consumed by the Matérn alternative below; the
# squared-exponential measure actually used here is not given `nu`.
lengthscale, nu, variance = 1.0, 5.0 + space.dim, 1.0
measure = SqExpSpectralMeasure(space.dim, lengthscale, variance=variance)
# Alternative measure (uncomment to switch):
# measure = MaternSpectralMeasure(space.dim, lengthscale, nu)

kernel = EigenbasisSumKernel(measure, space)
# Alternative kernel (uncomment to switch):
# kernel = RandomPhaseKernel(measure, space, phase_order=10)

# Second kernel instance, wrapped by the prior sampler.
# NOTE(review): it shares `measure` with `kernel`. If training updates the
# measure's hyperparameters in place, the sampler also sees the trained
# values. Confirm this is intended.
kernel_ = EigenbasisSumKernel(measure, space)
sampler = RandomPhaseApproximation(kernel_)

In [6]:
# Sample random train/test points on the manifold and evaluate the target
# while the points still have their manifold shape.
n_train, n_test = 50, 100
train_x = space.rand(n_train)
test_x = space.rand(n_test)
train_y = f(train_x)
test_y = f(test_x)

# Flatten every point into a single row of features for the GP models.
train_x = train_x.reshape(n_train, -1)
test_x = test_x.reshape(n_test, -1)

In [7]:
# Fit the GP that uses the geometric kernel.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(
    train_x,
    train_y,
    likelihood,
    kernel,
    space,
    point_shape=(n, m),  # presumably lets the model un-flatten train_x rows
)
model = model.to(device=device)
# Judging by the log output, this runs 900 iterations and logs every 300.
train(model, train_x, train_y, 900, 300)

torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\BatchLinearAlgebra.cpp:1672.)
  res = torch.triangular_solve(right_tensor, self.evaluate(), upper=self.upper).solution


Iter 300/900 - Loss: -0.209   lengthscale: 1.230 variance: 0.658   noise: 0.001
Iter 600/900 - Loss: -0.209   lengthscale: 1.230 variance: 0.658   noise: 0.001
Iter 900/900 - Loss: -0.209   lengthscale: 1.230 variance: 0.658   noise: 0.001


In [8]:
# Evaluate the geometric-kernel GP on the held-out points.
model.eval()
# Only the predictive mean is needed, so skip computing posterior variances.
with torch.no_grad(), gpytorch.settings.skip_posterior_variances(state=True):
    pred_f = model(test_x)
pred_y = pred_f.mean
error = MSELoss()(pred_y, test_y)
# Both numbers are moved to the CPU so the printed tensors are formatted
# alike. The data variance is roughly the MSE of a constant (mean) predictor,
# which gives a scale for the error.
print("prediction mse error:", error.detach().cpu())
print("data variance:", torch.var(test_y).cpu())

prediction mse error: tensor(0.0035, dtype=torch.float64)
data variance: tensor(0.3996, device='cuda:0', dtype=torch.float64)


In [9]:
# Baseline: the same regression, but with a standard Euclidean RBF kernel
# (wrapped in a ScaleKernel for an output scale) on the flattened points.
likelihood = gpytorch.likelihoods.GaussianLikelihood()
euclidean_kernel = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel()
)
euclidean_model = ExactGPModel(
    train_x, train_y, likelihood, euclidean_kernel, space
).to(device=device)
# Same training schedule as the geometric model.
train(euclidean_model, train_x, train_y, 900, 300)

Iter 300/900 - Loss: -0.155   lengthscale: 2.984 variance: 1.988   noise: 0.000
Iter 600/900 - Loss: -0.156   lengthscale: 2.969 variance: 1.960   noise: 0.000
Iter 900/900 - Loss: -0.156   lengthscale: 2.967 variance: 1.956   noise: 0.000


In [10]:
# Evaluate the Euclidean baseline on the same held-out points.
euclidean_model.eval()
# Only the predictive mean is needed, so skip computing posterior variances.
with torch.no_grad(), gpytorch.settings.skip_posterior_variances(state=True):
    euclidean_f = euclidean_model(test_x)
euclidean_pred_y = euclidean_f.mean
euclidean_error = MSELoss()(euclidean_pred_y, test_y)
# Both numbers are moved to the CPU so the printed tensors are formatted
# alike. The data variance is roughly the MSE of a constant (mean) predictor.
print("euclidean error:", euclidean_error.detach().cpu())
print("data variance:", torch.var(test_y).cpu())

euclidean error: tensor(0.0042, dtype=torch.float64)
data variance: tensor(0.3996, device='cuda:0', dtype=torch.float64)
