# IMGP - Supervised Learning - RMNIST

## Preamble

This notebook shows how to perform Gaussian process regression on a 1D manifold. We consider a supervised learning scenario: every point sampled from the underlying manifold is labeled, i.e., the number of labeled data points equals the number of points sampled from the manifold.

In [1]:
import torch
import gpytorch
import numpy as np
import math
import gc

%matplotlib widget
import matplotlib.pyplot as plt

from gpytorch.constraints import GreaterThan

from manifold_gp.kernels import RiemannMaternKernel
from manifold_gp.models import RiemannGP, VanillaGP
from manifold_gp.utils import rmnist_dataset, vanilla_train, manifold_informed_train, test_model, NearestNeighbors



## Dataset

In [2]:
# Fraction of the training split that is kept as labeled training data (see the dataset cell).
num_train = 0.5

# Dataset options forwarded to rmnist_dataset; `single_digit` also selects the SRMNIST checkpoints.
scaling = True
single_digit = True
regenerate = False

# Input / target standardization (computed from the training split only).
normalize_x = False
normalize_y = True

# Optional graph-bandwidth constraint and Gamma prior for the Riemann kernel.
graphbandwidth_constraint = True
graphbandwidth_prior = False

# Checkpoint handling: load instead of training, and save after training.
load_manifold_model = True
load_vanilla_model = True
save_manifold_model = True
save_vanilla_model = True

In [3]:
# Load the dataset, keep a random `num_train` fraction of the training split, flatten the
# inputs and move everything to the selected device.
train_x, train_y, test_x, test_y = rmnist_dataset(scaling=scaling, single_digit=single_digit, regenerate=regenerate)

torch.manual_seed(1337)
num_samples = train_x.shape[0]
# Boolean mask selecting the first int(num_train * num_samples) entries of a random permutation.
train_idx = torch.zeros(num_samples).scatter_(0, torch.randperm(num_samples)[:int(num_train * num_samples)], 1).bool()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

train_x = train_x[train_idx].contiguous().to(device).flatten(start_dim=1)
train_y = train_y[train_idx].contiguous().to(device)
test_x = test_x.contiguous().to(device).flatten(start_dim=1)
test_y = test_y.contiguous().to(device)

# Standardize with training-split statistics only. Out-of-place arithmetic is deliberate: on CPU,
# `.contiguous().to(device)` / `.flatten()` may hand back the loader's own tensors (or views of
# them), so in-place `sub_`/`div_` would also modify the data returned by rmnist_dataset.
if normalize_x:
    mu_x, std_x = train_x.mean(dim=-2, keepdim=True), train_x.std(dim=-2, keepdim=True) + 1e-6
    train_x = (train_x - mu_x) / std_x
    test_x = (test_x - mu_x) / std_x
if normalize_y:
    # Test targets use the training mean/std, so the RMSE/NLL reported later are in standardized units.
    mu_y, std_y = train_y.mean(), train_y.std()
    train_y = (train_y - mu_y) / std_y
    test_y = (test_y - mu_y) / std_y

Loading SRMNIST


In [4]:
# Data-driven bounds for the Riemann kernel's graph bandwidth, derived from k-NN statistics of train_x.
# Only needed when the bandwidth constraint or prior is enabled in the config cell.
if graphbandwidth_constraint or graphbandwidth_prior:
    knn = NearestNeighbors(train_x, nlist=1)
    # Query 10 neighbours per point and drop column 0 (presumably each point matching itself — TODO confirm).
    # NOTE(review): the values are assumed to be squared distances (hence `.sqrt()` below) — confirm in NearestNeighbors.search.
    edge_values = knn.search(train_x, 10)[0][:, 1:]

    # Lower bound eps = sqrt(d_max / (-4 ln 1e-4)), i.e. the eps solving exp(-d_max / (4 eps^2)) = 1e-4,
    # where d_max is the largest first-neighbour value (presumably the graph kernel's weight floor — confirm).
    graphbandwidth_min = edge_values[:, 0].max().div(-4 * math.log(1e-4)).sqrt()

    # Median over points of the mean neighbour distance. Index n // 2 equals the old
    # int(round(n * 0.5)) for even n, but the old form overshot the median by one whenever
    # n % 4 == 3 (Python rounds halves to even); n // 2 is the true median for odd n.
    median = edge_values.sqrt().mean(dim=1).sort()[0][edge_values.shape[0] // 2]

    # Gamma(concentration, rate) parameters chosen so the prior mode (concentration - 1) / rate equals `median`.
    gamma_rate = 4 * median / (median - graphbandwidth_min) ** 2
    gamma_concentration = gamma_rate * median + 1

    del knn, edge_values
    gc.collect()  # collect any reference cycles that could keep the k-NN index alive

## Model

In [5]:
%%capture
model_vanilla = VanillaGP(
    train_x, 
    train_y, 
    gpytorch.likelihoods.GaussianLikelihood(), 
    gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
).to(device)

hypers_vanilla = {
    'likelihood.noise_covar.noise': 1e-2,
    'covar_module.base_kernel.lengthscale': 1.0,
    'covar_module.outputscale': 1.0,
}
model_vanilla.initialize(**hypers_vanilla)

In [6]:
%%capture
likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-8),
)

kernel = gpytorch.kernels.ScaleKernel(
    RiemannMaternKernel(
        nu=2,
        x=train_x,
        nearest_neighbors=10,
        laplacian_normalization="randomwalk",
        num_modes=1100,
        bump_scale=3.0,
        bump_decay=1.0,
        graphbandwidth_constraint=gpytorch.constraints.GreaterThan(graphbandwidth_min) if graphbandwidth_constraint else None,
        graphbandwidth_prior=gpytorch.priors.GammaPrior(gamma_concentration, gamma_rate) if graphbandwidth_prior else None
    )
)

model = RiemannGP(train_x, train_y, likelihood, kernel).to(device)

hypers = {
    'likelihood.noise_covar.noise': 1e-2,
    'covar_module.base_kernel.graphbandwidth': 1.0,
    'covar_module.base_kernel.lengthscale': 1.0,
    'covar_module.outputscale': 1.0,
}
model.initialize(**hypers)

## Train

In [7]:
# Train the Riemann GP from scratch, or restore the saved checkpoint for this dataset variant.
manifold_ckpt = '../models/srmnist_manifold_supervised.pth' if single_digit else '../models/rmnist_manifold_supervised.pth'

if not load_manifold_model:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50, threshold=1e-3, threshold_mode='rel',
                                                           cooldown=0, min_lr=0, eps=1e-8, verbose=True)
    loss = manifold_informed_train(model, optimizer, max_iter=200, tolerance=1e-2, update_norm=None, num_rand_vec=100,
                                   max_cholesky=1000, cg_tolerance=1e-2, cg_max_iter=2000, scheduler=scheduler, verbose=True)
    if save_manifold_model:
        torch.save(model.state_dict(), manifold_ckpt)
else:
    # `device` is picked at runtime, so a checkpoint saved from a CUDA run must be remapped
    # to load on a CPU-only machine.
    model.load_state_dict(torch.load(manifold_ckpt, map_location=device))

In [8]:
# Train the Euclidean baseline from scratch, or restore the saved checkpoint for this dataset variant.
vanilla_ckpt = '../models/srmnist_vanilla_supervised.pth' if single_digit else '../models/rmnist_vanilla_supervised.pth'

if not load_vanilla_model:
    optimizer_vanilla = torch.optim.Adam(model_vanilla.parameters(), lr=1e-2, weight_decay=0.0)
    # NOTE(review): this scheduler is built but `scheduler=None` is passed to vanilla_train below,
    # so it is never used. Left unchanged to preserve training behaviour; pass it in or delete it.
    scheduler_vanilla = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_vanilla, mode='min', factor=0.5, patience=200, threshold=1e-3,
                                                                   threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8, verbose=False)
    loss = vanilla_train(model_vanilla, optimizer_vanilla, max_iter=200, max_cholesky=1000, tolerance=1e-2, cg_tolerance=1e-2,
                         cg_max_iter=1000, scheduler=None, verbose=True)
    if save_vanilla_model:
        torch.save(model_vanilla.state_dict(), vanilla_ckpt)
else:
    # `device` is picked at runtime, so a checkpoint saved from a CUDA run must be remapped
    # to load on a CPU-only machine.
    model_vanilla.load_state_dict(torch.load(vanilla_ckpt, map_location=device))

## Evaluation

In [9]:
# Held-out RMSE and NLL of the Euclidean baseline (no base model).
rmse_vanilla, nll_vanilla = test_model(
    model_vanilla, test_x, test_y,
    noisy_test=True, base_model=None,
    max_cholesky=1000, cg_tolerance=1e-2, cg_iterations=1000,
)
for metric_label, metric_value in (("RMSE Vanilla: ", rmse_vanilla), ("NLL Vanilla: ", nll_vanilla)):
    print(metric_label, metric_value)

RMSE Vanilla:  tensor(0.0093, device='cuda:0')
NLL Vanilla:  tensor(-1.7322, device='cuda:0')


In [10]:
# Held-out RMSE and NLL of the Riemann GP; the trained baseline is passed to test_model as `base_model`.
geometric_eval_kwargs = {
    'noisy_test': True,
    'base_model': model_vanilla,
    'max_cholesky': 1000,
    'cg_tolerance': 1e-2,
    'cg_iterations': 1000,
}
rmse, nll = test_model(model, test_x, test_y, **geometric_eval_kwargs)
print("RMSE Geometric: ", rmse)
print("NLL Geometric: ", nll)

RMSE Geometric:  tensor(0.0720, device='cuda:0')
NLL Geometric:  tensor(-0.7352, device='cuda:0')
