# Benchmark Manifold GP Supervised Learning

## Preamble

This notebook provides an example of how to perform Gaussian Process Regression on a 1D manifold. In this example we consider a supervised learning scenario, namely the number of labeled data points is equivalent to the number of the sampled points from the underlying manifold.

In [1]:
import torch
import gpytorch
import numpy as np
import os
import scipy.spatial as ss
from time import time
from manifold_gp.kernels.riemann_matern_kernel import RiemannMaternKernel
from manifold_gp.models.riemann_gp import RiemannGP
from gpytorch.priors import NormalPrior, GammaPrior

## Dataset Preprocessing

### Load & Settings

In [2]:
torch.manual_seed(1337)
dataset = 'mnist'

data = np.loadtxt('datasets/'+dataset+'_train.csv')
sampled_x, sampled_y = data[:, 2:], data[:, 1]
rand_idx = torch.randperm(sampled_x.shape[0])
sampled_x, sampled_y = sampled_x[rand_idx], sampled_y[rand_idx]
del rand_idx

data = np.loadtxt('datasets/'+dataset+'_test.csv')
test_x, test_y = data[:, 2:], data[:, 1]
    
preprocess = False
normalize_features = False
normalize_labels = True

In [3]:
if preprocess:
    # remove coincident points
    sampled_x, id_unique = np.unique(sampled_x, axis=0, return_index=True)
    sampled_y = sampled_y[id_unique]

    # cut between 0.1 and 0.9 percentile of distances
    import faiss
    res = faiss.StandardGpuResources()
    knn = faiss.GpuIndexIVFFlat(res, sampled_x.shape[1], 1, faiss.METRIC_L2)
    knn.train(sampled_x)
    knn.add(sampled_x)
    v = np.sqrt(knn.search(sampled_x, 51)[0][:,1:])
    idx = np.argsort(v.mean(axis=1).ravel())
    percentile_start = int(np.round(idx.shape[0]*0.10))
    percentile_end = int(np.round(idx.shape[0]*0.90))
    sampled_x = sampled_x[idx[percentile_start:percentile_end], :]
    sampled_y = sampled_y[idx[percentile_start:percentile_end]]
    del knn
m = sampled_x.shape[0]

### Trainset & Testset

In [4]:
split = int(0.1 * m)
train_x, train_y = sampled_x[:split], sampled_y[:split]

train_x, train_y = torch.from_numpy(train_x).float(), torch.from_numpy(train_y).float()
test_x, test_y = torch.from_numpy(test_x).float(), torch.from_numpy(test_y).float()

if normalize_features:
    mu_x, std_x = train_x.mean(dim=-2, keepdim=True), train_x.std(dim=-2, keepdim=True) + 1e-6
    train_x.sub_(mu_x).div_(std_x)
    test_x.sub_(mu_x).div_(std_x)
    
if normalize_labels:
    mu_y, std_y = train_y.mean(), train_y.std()
    train_y.sub_(mu_y).div_(std_y)
    test_y.sub_(mu_y).div_(std_y)

### Move Data to Device

In [5]:
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_x, train_y = train_x.to(device), train_y.to(device)
test_x, test_y = test_x.to(device), test_y.to(device)

## Model

In [6]:
%%capture
likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-8),
    noise_prior=None  # NormalPrior(torch.tensor([0.0]).to(device),  torch.tensor([1/9]).sqrt().to(device))
)

kernel = gpytorch.kernels.ScaleKernel(
    RiemannMaternKernel(
        nu=3,
        nodes=train_x,
        neighbors=50,
        operator="randomwalk",
        modes=100,
        ball_scale=3.0,
        prior_bandwidth=False,
    ),
    outputscale_prior=None  # NormalPrior(torch.tensor([1.0]).to(device),  torch.tensor([1/9]).sqrt().to(device))
)

model = RiemannGP(train_x, train_y, likelihood, kernel).to(device)

## Train

In [7]:
%%capture
hypers = {
    'likelihood.noise_covar.noise': 1e-2,
    'covar_module.base_kernel.epsilon': 0.5,
    'covar_module.base_kernel.lengthscale': 0.5,
    'covar_module.outputscale': 1.0,
}
model.initialize(**hypers)

In [8]:
t0 = time()
model.manifold_informed_train(lr=1e-2, iter=100, 
                              decay_step_size=1000, decay_magnitude=1.0, 
                              norm_step_size=10, norm_rand_vec=100, 
                              verbose=True, save=False)
t1 = time()
print("Time: %.2g sec" % (t1 - t0))

Iter: 0, LR: 0.100, Loss: 327.548, NoiseVar: 0.010, SignalVar: 49688.289, Lengthscale: 0.500, Epsilon: 0.500
Iter: 1, LR: 0.100, Loss: 337.466, NoiseVar: 0.009, SignalVar: 49688.188, Lengthscale: 0.462, Epsilon: 0.462
Iter: 2, LR: 0.100, Loss: 284.070, NoiseVar: 0.008, SignalVar: 49688.188, Lengthscale: 0.488, Epsilon: 0.427
Iter: 3, LR: 0.100, Loss: 273.730, NoiseVar: 0.008, SignalVar: 49688.148, Lengthscale: 0.507, Epsilon: 0.394
Iter: 4, LR: 0.100, Loss: 260.749, NoiseVar: 0.007, SignalVar: 49688.082, Lengthscale: 0.511, Epsilon: 0.363
Iter: 5, LR: 0.100, Loss: 210.878, NoiseVar: 0.006, SignalVar: 49688.004, Lengthscale: 0.503, Epsilon: 0.334
Iter: 6, LR: 0.100, Loss: 173.602, NoiseVar: 0.006, SignalVar: 49687.918, Lengthscale: 0.489, Epsilon: 0.307
Iter: 7, LR: 0.100, Loss: 143.143, NoiseVar: 0.005, SignalVar: 49687.828, Lengthscale: 0.473, Epsilon: 0.282
Iter: 8, LR: 0.100, Loss: 96.233, NoiseVar: 0.005, SignalVar: 49687.738, Lengthscale: 0.457, Epsilon: 0.258
Iter: 9, LR: 0.100, 



Iter: 13, LR: 0.100, Loss: -53.572, NoiseVar: 0.003, SignalVar: 49687.316, Lengthscale: 0.455, Epsilon: 0.165




Iter: 14, LR: 0.100, Loss: -68.785, NoiseVar: 0.003, SignalVar: 49687.234, Lengthscale: 0.462, Epsilon: 0.154




Iter: 15, LR: 0.100, Loss: -69.465, NoiseVar: 0.002, SignalVar: 49687.148, Lengthscale: 0.469, Epsilon: 0.150




Iter: 16, LR: 0.100, Loss: -88.464, NoiseVar: 0.002, SignalVar: 49687.066, Lengthscale: 0.475, Epsilon: 0.149




Iter: 17, LR: 0.100, Loss: -68.979, NoiseVar: 0.002, SignalVar: 49686.984, Lengthscale: 0.479, Epsilon: 0.151




Iter: 18, LR: 0.100, Loss: -62.318, NoiseVar: 0.002, SignalVar: 49686.902, Lengthscale: 0.481, Epsilon: 0.156




Iter: 19, LR: 0.100, Loss: -70.639, NoiseVar: 0.002, SignalVar: 49686.816, Lengthscale: 0.481, Epsilon: 0.163




Iter: 20, LR: 0.100, Loss: -63.866, NoiseVar: 0.001, SignalVar: 49686.727, Lengthscale: 0.477, Epsilon: 0.171




Iter: 21, LR: 0.100, Loss: -39.225, NoiseVar: 0.001, SignalVar: 49686.633, Lengthscale: 0.471, Epsilon: 0.179
Iter: 22, LR: 0.100, Loss: -40.765, NoiseVar: 0.001, SignalVar: 49686.535, Lengthscale: 0.463, Epsilon: 0.187
Iter: 23, LR: 0.100, Loss: -7.560, NoiseVar: 0.001, SignalVar: 49686.438, Lengthscale: 0.456, Epsilon: 0.193
Iter: 24, LR: 0.100, Loss: -19.060, NoiseVar: 0.001, SignalVar: 49686.340, Lengthscale: 0.450, Epsilon: 0.199
Iter: 25, LR: 0.100, Loss: -11.025, NoiseVar: 0.001, SignalVar: 49686.242, Lengthscale: 0.447, Epsilon: 0.203
Iter: 26, LR: 0.100, Loss: 7.043, NoiseVar: 0.001, SignalVar: 49686.148, Lengthscale: 0.448, Epsilon: 0.205
Iter: 27, LR: 0.100, Loss: -8.931, NoiseVar: 0.001, SignalVar: 49686.055, Lengthscale: 0.451, Epsilon: 0.205
Iter: 28, LR: 0.100, Loss: -2.412, NoiseVar: 0.001, SignalVar: 49685.961, Lengthscale: 0.456, Epsilon: 0.203
Iter: 29, LR: 0.100, Loss: -27.460, NoiseVar: 0.001, SignalVar: 49685.867, Lengthscale: 0.462, Epsilon: 0.201
Iter: 30, LR: 0



Iter: 32, LR: 0.100, Loss: -15.181, NoiseVar: 0.001, SignalVar: 49685.570, Lengthscale: 0.471, Epsilon: 0.188




Iter: 33, LR: 0.100, Loss: -53.031, NoiseVar: 0.001, SignalVar: 49685.469, Lengthscale: 0.471, Epsilon: 0.184




Iter: 34, LR: 0.100, Loss: -40.269, NoiseVar: 0.001, SignalVar: 49685.367, Lengthscale: 0.469, Epsilon: 0.179




Iter: 35, LR: 0.100, Loss: -95.092, NoiseVar: 0.000, SignalVar: 49685.262, Lengthscale: 0.464, Epsilon: 0.175




Iter: 36, LR: 0.100, Loss: -74.543, NoiseVar: 0.000, SignalVar: 49685.160, Lengthscale: 0.460, Epsilon: 0.172




Iter: 37, LR: 0.100, Loss: -41.265, NoiseVar: 0.000, SignalVar: 49685.059, Lengthscale: 0.458, Epsilon: 0.170




Iter: 38, LR: 0.100, Loss: -66.361, NoiseVar: 0.000, SignalVar: 49684.957, Lengthscale: 0.455, Epsilon: 0.169




Iter: 39, LR: 0.100, Loss: -58.863, NoiseVar: 0.000, SignalVar: 49684.855, Lengthscale: 0.454, Epsilon: 0.169




Iter: 40, LR: 0.100, Loss: -52.951, NoiseVar: 0.000, SignalVar: 49684.758, Lengthscale: 0.454, Epsilon: 0.171




Iter: 41, LR: 0.100, Loss: -80.405, NoiseVar: 0.000, SignalVar: 49684.660, Lengthscale: 0.457, Epsilon: 0.174




Iter: 42, LR: 0.100, Loss: -39.190, NoiseVar: 0.000, SignalVar: 49684.566, Lengthscale: 0.461, Epsilon: 0.177




Iter: 43, LR: 0.100, Loss: -52.447, NoiseVar: 0.000, SignalVar: 49684.469, Lengthscale: 0.465, Epsilon: 0.180




Iter: 44, LR: 0.100, Loss: -21.066, NoiseVar: 0.000, SignalVar: 49684.371, Lengthscale: 0.467, Epsilon: 0.183




Iter: 45, LR: 0.100, Loss: -24.688, NoiseVar: 0.000, SignalVar: 49684.273, Lengthscale: 0.467, Epsilon: 0.186




Iter: 46, LR: 0.100, Loss: -29.299, NoiseVar: 0.000, SignalVar: 49684.172, Lengthscale: 0.465, Epsilon: 0.188
Iter: 47, LR: 0.100, Loss: -11.478, NoiseVar: 0.000, SignalVar: 49684.070, Lengthscale: 0.463, Epsilon: 0.190
Iter: 48, LR: 0.100, Loss: -25.434, NoiseVar: 0.000, SignalVar: 49683.969, Lengthscale: 0.460, Epsilon: 0.191
Iter: 49, LR: 0.100, Loss: -33.660, NoiseVar: 0.000, SignalVar: 49683.867, Lengthscale: 0.457, Epsilon: 0.191




Iter: 50, LR: 0.100, Loss: -33.833, NoiseVar: 0.000, SignalVar: 49683.766, Lengthscale: 0.456, Epsilon: 0.189




Iter: 51, LR: 0.100, Loss: -40.594, NoiseVar: 0.000, SignalVar: 49683.664, Lengthscale: 0.455, Epsilon: 0.187




Iter: 52, LR: 0.100, Loss: -36.504, NoiseVar: 0.000, SignalVar: 49683.562, Lengthscale: 0.456, Epsilon: 0.185




Iter: 53, LR: 0.100, Loss: -38.254, NoiseVar: 0.000, SignalVar: 49683.461, Lengthscale: 0.458, Epsilon: 0.182




Iter: 54, LR: 0.100, Loss: -47.629, NoiseVar: 0.000, SignalVar: 49683.359, Lengthscale: 0.460, Epsilon: 0.180




Iter: 55, LR: 0.100, Loss: -33.437, NoiseVar: 0.000, SignalVar: 49683.258, Lengthscale: 0.462, Epsilon: 0.179




Iter: 56, LR: 0.100, Loss: -55.665, NoiseVar: 0.000, SignalVar: 49683.156, Lengthscale: 0.463, Epsilon: 0.178




Iter: 57, LR: 0.100, Loss: -60.837, NoiseVar: 0.000, SignalVar: 49683.055, Lengthscale: 0.465, Epsilon: 0.178




Iter: 58, LR: 0.100, Loss: -30.918, NoiseVar: 0.000, SignalVar: 49682.953, Lengthscale: 0.464, Epsilon: 0.177




Iter: 59, LR: 0.100, Loss: -52.992, NoiseVar: 0.000, SignalVar: 49682.852, Lengthscale: 0.463, Epsilon: 0.176




Iter: 60, LR: 0.100, Loss: -41.760, NoiseVar: 0.000, SignalVar: 49682.750, Lengthscale: 0.460, Epsilon: 0.175




Iter: 61, LR: 0.100, Loss: -93.453, NoiseVar: 0.000, SignalVar: 49682.648, Lengthscale: 0.457, Epsilon: 0.175




Iter: 62, LR: 0.100, Loss: -46.976, NoiseVar: 0.000, SignalVar: 49682.547, Lengthscale: 0.456, Epsilon: 0.175




Iter: 63, LR: 0.100, Loss: -47.172, NoiseVar: 0.000, SignalVar: 49682.449, Lengthscale: 0.457, Epsilon: 0.175




Iter: 64, LR: 0.100, Loss: -65.197, NoiseVar: 0.000, SignalVar: 49682.352, Lengthscale: 0.461, Epsilon: 0.176




Iter: 65, LR: 0.100, Loss: -54.234, NoiseVar: 0.000, SignalVar: 49682.254, Lengthscale: 0.464, Epsilon: 0.178




Iter: 66, LR: 0.100, Loss: -74.659, NoiseVar: 0.000, SignalVar: 49682.156, Lengthscale: 0.466, Epsilon: 0.179




Iter: 67, LR: 0.100, Loss: -44.563, NoiseVar: 0.000, SignalVar: 49682.059, Lengthscale: 0.467, Epsilon: 0.180




Iter: 68, LR: 0.100, Loss: -46.575, NoiseVar: 0.000, SignalVar: 49681.957, Lengthscale: 0.466, Epsilon: 0.181




Iter: 69, LR: 0.100, Loss: -19.730, NoiseVar: 0.000, SignalVar: 49681.855, Lengthscale: 0.463, Epsilon: 0.182




Iter: 70, LR: 0.100, Loss: -58.147, NoiseVar: 0.000, SignalVar: 49681.754, Lengthscale: 0.459, Epsilon: 0.183
Iter: 71, LR: 0.100, Loss: -62.628, NoiseVar: 0.000, SignalVar: 49681.652, Lengthscale: 0.456, Epsilon: 0.184




Iter: 72, LR: 0.100, Loss: -8.635, NoiseVar: 0.000, SignalVar: 49681.551, Lengthscale: 0.454, Epsilon: 0.184




Iter: 73, LR: 0.100, Loss: -42.361, NoiseVar: 0.000, SignalVar: 49681.449, Lengthscale: 0.455, Epsilon: 0.185
Iter: 74, LR: 0.100, Loss: -48.158, NoiseVar: 0.000, SignalVar: 49681.348, Lengthscale: 0.456, Epsilon: 0.185




Iter: 75, LR: 0.100, Loss: -43.540, NoiseVar: 0.000, SignalVar: 49681.250, Lengthscale: 0.458, Epsilon: 0.184




Iter: 76, LR: 0.100, Loss: -45.387, NoiseVar: 0.000, SignalVar: 49681.152, Lengthscale: 0.461, Epsilon: 0.183




Iter: 77, LR: 0.100, Loss: -62.971, NoiseVar: 0.000, SignalVar: 49681.055, Lengthscale: 0.465, Epsilon: 0.182




Iter: 78, LR: 0.100, Loss: -25.337, NoiseVar: 0.000, SignalVar: 49680.957, Lengthscale: 0.467, Epsilon: 0.180




Iter: 79, LR: 0.100, Loss: -46.026, NoiseVar: 0.000, SignalVar: 49680.855, Lengthscale: 0.466, Epsilon: 0.179




Iter: 80, LR: 0.100, Loss: -34.072, NoiseVar: 0.000, SignalVar: 49680.754, Lengthscale: 0.464, Epsilon: 0.179




Iter: 81, LR: 0.100, Loss: -42.953, NoiseVar: 0.000, SignalVar: 49680.652, Lengthscale: 0.460, Epsilon: 0.179




Iter: 82, LR: 0.100, Loss: -56.494, NoiseVar: 0.000, SignalVar: 49680.547, Lengthscale: 0.455, Epsilon: 0.179




Iter: 83, LR: 0.100, Loss: -47.856, NoiseVar: 0.000, SignalVar: 49680.441, Lengthscale: 0.451, Epsilon: 0.180




Iter: 84, LR: 0.100, Loss: -51.222, NoiseVar: 0.000, SignalVar: 49680.340, Lengthscale: 0.450, Epsilon: 0.180




Iter: 85, LR: 0.100, Loss: -59.276, NoiseVar: 0.000, SignalVar: 49680.242, Lengthscale: 0.454, Epsilon: 0.181




Iter: 86, LR: 0.100, Loss: -65.048, NoiseVar: 0.000, SignalVar: 49680.145, Lengthscale: 0.459, Epsilon: 0.182




Iter: 87, LR: 0.100, Loss: -39.335, NoiseVar: 0.000, SignalVar: 49680.047, Lengthscale: 0.464, Epsilon: 0.183




Iter: 88, LR: 0.100, Loss: -37.573, NoiseVar: 0.000, SignalVar: 49679.945, Lengthscale: 0.467, Epsilon: 0.183




Iter: 89, LR: 0.100, Loss: -30.615, NoiseVar: 0.000, SignalVar: 49679.844, Lengthscale: 0.467, Epsilon: 0.184




Iter: 90, LR: 0.100, Loss: -55.797, NoiseVar: 0.000, SignalVar: 49679.738, Lengthscale: 0.465, Epsilon: 0.185




Iter: 91, LR: 0.100, Loss: -10.842, NoiseVar: 0.000, SignalVar: 49679.633, Lengthscale: 0.463, Epsilon: 0.185
Iter: 92, LR: 0.100, Loss: -16.907, NoiseVar: 0.000, SignalVar: 49679.527, Lengthscale: 0.459, Epsilon: 0.186




Iter: 93, LR: 0.100, Loss: -52.252, NoiseVar: 0.000, SignalVar: 49679.422, Lengthscale: 0.455, Epsilon: 0.186




Iter: 94, LR: 0.100, Loss: -40.461, NoiseVar: 0.000, SignalVar: 49679.316, Lengthscale: 0.453, Epsilon: 0.185




Iter: 95, LR: 0.100, Loss: -41.877, NoiseVar: 0.000, SignalVar: 49679.215, Lengthscale: 0.455, Epsilon: 0.185




Iter: 96, LR: 0.100, Loss: -32.062, NoiseVar: 0.000, SignalVar: 49679.113, Lengthscale: 0.458, Epsilon: 0.184




Iter: 97, LR: 0.100, Loss: -50.724, NoiseVar: 0.000, SignalVar: 49679.012, Lengthscale: 0.462, Epsilon: 0.183




Iter: 98, LR: 0.100, Loss: -56.323, NoiseVar: 0.000, SignalVar: 49678.910, Lengthscale: 0.464, Epsilon: 0.182
Iter: 99, LR: 0.100, Loss: -22.734, NoiseVar: 0.000, SignalVar: 49678.809, Lengthscale: 0.466, Epsilon: 0.182
Time: 45 sec




## Evaluation

In [9]:
%%capture
likelihood.eval()
model.eval()

## Metrics

In [10]:
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.cg_tolerance(10000):
    preds_test = likelihood(model(test_x))
        
    error = test_y - preds_test.mean
    covar = preds_test.lazy_covariance_matrix.evaluate_kernel()
    inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=error.unsqueeze(-1), logdet=True)
    
    rmse = (error.square().sum()/test_y.shape[0]).sqrt()
    nll = 0.5 * sum([inv_quad, logdet, error.size(-1)* np.log(2 * np.pi)])/test_y.shape[0]
    model._clear_cache()
    
print("RMSE: ", rmse)
print("NLL: ", nll)

RMSE:  tensor(0.5869, device='cuda:0')
NLL:  tensor(1617.9510, device='cuda:0')
