# Large-Scale Stochastic Variational GP Regression (CUDA)

## Overview

In this notebook, we give an overview of how to use stochastic variational GP (SVGP) regression (https://arxiv.org/pdf/1411.2005.pdf) to train rapidly with minibatches on a UCI dataset. The code below loads the `pol` dataset, although the same approach applies to datasets with hundreds of thousands of training examples, such as `3droad`.

In [1]:
import math
import torch
import tqdm
import pyro
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

## Loading Data

For this example notebook, we load the `pol` UCI dataset from a local `.mat` file. The next cell rescales each input feature to the range [-1, 1] and standardizes the targets to zero mean and unit variance. We then split the data, using the first 80% for training and the last 20% for testing.

**Note**: The next cell reads `pol.mat` from a hard-coded local path, so edit that path to point at your own copy of the data before running it.

In [17]:
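from math import floor
from scipy.io import loadmat

# Load the dataset from a local .mat file; the last column is the regression target.
data = torch.Tensor(loadmat('/home/jake.gardner/data/pol.mat')['data'])
X = data[:, :-1]
# Rescale every input feature to the range [-1, 1].
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1
y = data[:, -1]

# Standardize the targets to zero mean and unit variance.
y = (y - y.mean()) / y.std()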

# Use the first 80% of the data for training, and the last 20% for testing.
train_n = int(floor(0.8*len(X)))

train_x = X[:train_n, :].contiguous().cuda()
train_y = y[:train_n].contiguous().cuda()

test_x = X[train_n:, :].contiguous().cuda()
test_y = y[train_n:].contiguous().cuda()
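
The preprocessing in the cell above can be checked on a toy example. The following is a minimal sketch in plain Python (no PyTorch or GPU needed) of the same two transformations: rescaling every input column to [-1, 1] and standardizing the targets. The helper names `rescale_columns` and `standardize` and the toy numbers are assumptions made for illustration, not part of the notebook.

```python
import statistics


def rescale_columns(rows):
    """Rescale each column of a list-of-rows table to the range [-1, 1]."""
    scaled_columns = []
    for column in zip(*rows):
        lowest = min(column)
        shifted = [value - lowest for value in column]
        highest = max(shifted)
        # A constant column gives highest == 0, so this line raises
        # ZeroDivisionError (the tensor version would yield NaN instead).
        scaled_columns.append([2 * (value / highest) - 1 for value in shifted])
    return [list(row) for row in zip(*scaled_columns)]


def standardize(values):
    """Shift to zero mean and scale to unit sample standard deviation."""
    mean = statistics.mean(values)
    std = statistics.stdev(values)  # sample std (n - 1), like torch's default
    return [(value - mean) / std for value in values]


# Hypothetical toy data: three rows with two features, and four targets.
toy_inputs = [[0.0, 10.0], [5.0, 20.0], [10.0, 40.0]]
toy_targets = [1.0, 2.0, 3.0, 4.0]

scaled = rescale_columns(toy_inputs)
standardized = standardize(toy_targets)
```

After rescaling, the smallest value in every column maps to -1 and the largest to 1, which keeps all features on a comparable scale for the kernel lengthscale.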

## Creating a DataLoader

The next step is to create a torch `DataLoader` that will handle getting us random minibatches of data. This involves using the standard `TensorDataset` and `DataLoader` modules provided by PyTorch.

In this notebook we use a fairly large batch size of 1024 to make optimization run faster; feel free to change it.

In [18]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)
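
To make the batching behavior concrete, here is a minimal plain-Python sketch of what a shuffling loader does with indices: it permutes them once per pass and slices them into consecutive chunks, keeping a final shorter chunk. The function `iterate_minibatches`, the seed, and the example size of 2500 are assumptions for illustration; this is not how `DataLoader` is implemented internally.

```python
import math
import random


def iterate_minibatches(n_examples, batch_size, shuffle, seed=0):
    """Yield lists of example indices, one list per minibatch."""
    indices = list(range(n_examples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n_examples, batch_size):
        # The last slice may be shorter than batch_size; it is kept, not dropped.
        yield indices[start:start + batch_size]


n_examples = 2500   # hypothetical dataset size
batch_size = 1024
batches = list(iterate_minibatches(n_examples, batch_size, shuffle=True))
batch_sizes = [len(batch) for batch in batches]
expected_num_batches = math.ceil(n_examples / batch_size)
```

The number of minibatches per epoch is the dataset size divided by the batch size, rounded up, and every example appears exactly once per pass.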

## Defining the SVGP Model

We now define the GP regression model. Because we are doing variational inference and *not* exact inference, the model is built on a set of inducing points rather than on the full training set. The class below subclasses `GaussianPredictiveGP`, imported from `gpytorch.models` in the next cell, and is constructed from the inducing point locations, the likelihood, and the number of training points (`num_data`).

The inducing points are initialized as 512 draws from a standard normal distribution, with one column per input feature.

In [25]:
from gpytorch.models import GaussianPredictiveGP, SteinVariationalGP


class GPModel(GaussianPredictiveGP):
    def __init__(self, inducing_points, likelihood, num_data):
        super().__init__(inducing_points, likelihood=likelihood, num_data=num_data)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    
inducing_points = torch.randn(512, train_x.size(-1))
likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
model = GPModel(inducing_points=inducing_points, likelihood=likelihood, num_data=train_n).cuda()
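
The covariance module above wraps an RBF kernel in a `ScaleKernel`, so the covariance between two inputs is an outputscale times a squared-exponential function of their distance. A minimal plain-Python sketch of that formula follows; the function name `scaled_rbf` and the example values are assumptions for illustration, and the real modules additionally handle batching and parameter constraints.

```python
import math


def scaled_rbf(x1, x2, lengthscale=1.0, outputscale=1.0):
    """outputscale * exp(-||x1 - x2||^2 / (2 * lengthscale^2))."""
    squared_distance = sum((a - b) ** 2 for a, b in zip(x1, x2))
    return outputscale * math.exp(-0.5 * squared_distance / lengthscale ** 2)


# Hypothetical inputs: three points in two dimensions.
points = [[0.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
covariance = [
    [scaled_rbf(p, q, lengthscale=1.0, outputscale=2.0) for q in points]
    for p in points
]
```

The diagonal of `covariance` equals the outputscale, and entries shrink toward zero as points move apart relative to the lengthscale. These are the two quantities shown as `os` and `ls` during training.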

## Training the Model

The cell below trains the model with Pyro's Stein variational gradient descent (`infer.SVGD`), using an `RBFSteinKernel`, 16 particles, and the Adam optimizer with a learning rate of 0.01.

Unlike when using the exact GP marginal log likelihood, performing variational inference allows us to make use of stochastic optimization techniques. For this example, we train for 100 epochs.

The optimization loop differs from the one seen in our simpler tutorials in that it loops over both a number of epochs *and* minibatches of the data. However, the basic process is the same: for each minibatch, we call `svi.step`, which evaluates the model on the batch, computes gradients, and takes a step of optimization. The progress bar reports the value returned for `.inducing_values` divided by the number of training points, together with the current kernel lengthscale (`ls`) and outputscale (`os`).
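
The nested loop structure described above (epochs on the outside, minibatches on the inside, one optimization step per minibatch) can be sketched on a toy problem. The example below is an assumption-laden illustration in plain Python: it uses ordinary minibatch SGD on a one-parameter least-squares fit, not the optimizer or objective used in this notebook, and the name `minibatch_sgd` is hypothetical.

```python
import random


def minibatch_sgd(xs, ys, batch_size, num_epochs, lr, seed=0):
    """Fit y = w * x by minibatch SGD and return the learned slope w."""
    rng = random.Random(seed)
    w = 0.0
    indices = list(range(len(xs)))
    for _ in range(num_epochs):  # outer loop: full passes over the data
        rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):  # inner loop: minibatches
            batch = indices[start:start + batch_size]
            # Gradient of the batch mean squared error with respect to w.
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w -= lr * grad  # one optimization step per minibatch
    return w


xs = [i / 10 for i in range(1, 41)]  # toy inputs 0.1, 0.2, ..., 4.0
ys = [3.0 * x for x in xs]           # noiseless targets with true slope 3
w = minibatch_sgd(xs, ys, batch_size=8, num_epochs=20, lr=0.05)
```

Each step only touches one minibatch, so the cost per step does not grow with the size of the dataset.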

In [26]:
from pyro import optim
from pyro import infer
import pyro

pyro.clear_param_store()

optimizer = optim.Adam({"lr": 0.01})
num_epochs = 100

kernel = infer.RBFSteinKernel()
svi = infer.SVGD(model.model, kernel, optimizer, num_particles=16, max_plate_nesting=1)

for i in range(num_epochs):
    # Within each iteration, we will go over each minibatch of data
    loader = tqdm.tqdm_notebook(train_loader, desc=f"Train (Epoch {i + 1})")
    for x_batch, y_batch in loader:
        loss = svi.step(x_batch, y_batch)
        loader.set_postfix(
            loss=(loss['.inducing_values'] / train_n),
            ls=model.covar_module.base_kernel.lengthscale.item(),
            os=model.covar_module.outputscale.item()
        )


In [27]:
# model.variational_distribution.chol_variational_covar
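
The training cell above relies on Stein variational gradient descent, which moves a set of particles so that they jointly approximate a target distribution: each particle is pulled toward high-density regions by kernel-weighted gradients and pushed away from its neighbors by the gradient of the kernel. The following is a minimal one-dimensional sketch in plain Python for a standard-deviation-1 Gaussian target. The function `svgd_step`, the fixed kernel bandwidth, the step size, and the initial particles are all assumptions for illustration, and Pyro's implementation may set the bandwidth and updates differently.

```python
import math


def svgd_step(particles, grad_log_p, step_size, bandwidth):
    """Apply one SVGD update with an RBF kernel to a list of 1-D particles."""
    n = len(particles)
    updated = []
    for x_i in particles:
        phi = 0.0
        for x_j in particles:
            k = math.exp(-((x_j - x_i) ** 2) / (2 * bandwidth ** 2))
            # Derivative of k with respect to x_j: the repulsive term.
            grad_k = -(x_j - x_i) / bandwidth ** 2 * k
            # Attractive term (kernel-weighted score) plus repulsive term.
            phi += k * grad_log_p(x_j) + grad_k
        updated.append(x_i + step_size * phi / n)
    return updated


target_mean = 2.0


def grad_log_target(x):
    """Score of a Gaussian with mean target_mean and unit variance."""
    return -(x - target_mean)


particles = [-1.0 + 0.25 * i for i in range(8)]  # start away from the target
for _ in range(500):
    particles = svgd_step(particles, grad_log_target, step_size=0.1, bandwidth=1.0)

particle_mean = sum(particles) / len(particles)
particle_spread = max(particles) - min(particles)
```

The particles end up centered near the target mean while staying spread out, which is what distinguishes SVGD from running plain gradient ascent on each particle independently.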

## Making Predictions

The next cell computes the predictive mean and variance for the test set (available as `preds.mean` and `preds.variance`), iterating over `test_loader` in minibatches. It also evaluates the log probability of each test target under a Gaussian with the predicted mean and variance, which is used for the test NLL below.

In [29]:
model.eval()
likelihood.eval()
all_means = []
all_vars = []
all_log_probs = []
# Condition the model on the particles found by SVGD and predict batch by batch.
with torch.no_grad(), pyro.condition(data=svi.get_named_particles()):
    for x_batch, y_batch in test_loader:
        preds = model.likelihood(model(x_batch))
        means = preds.mean.cpu().squeeze()
        vars = preds.variance.cpu().squeeze()
        # Clamp the variance away from zero before taking the square root.
        log_probs = torch.distributions.Normal(
            means, vars.clamp_min(1e-5).sqrt()
        ).log_prob(y_batch.cpu())

        all_means.append(means)
        all_vars.append(vars)
        all_log_probs.append(log_probs)

means = torch.cat(all_means)
vars = torch.cat(all_vars)
log_probs = torch.cat(all_log_probs)

In [30]:
print('Test MAE: {}'.format(torch.mean(torch.abs(means - test_y.cpu()))))

Test MAE: 0.24857251346111298


In [31]:
print('Test NLL: {}'.format(-torch.mean(log_probs)))

Test NLL: -0.20033620297908783
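
A negative NLL is not an error: a Gaussian density can exceed 1 when the predictive variance is small, so its log can be positive. The sketch below reproduces both test metrics in plain Python on made-up numbers. The function names and the toy values are assumptions for illustration; the variance floor mirrors the `clamp_min(1e-5)` used in the prediction cell.

```python
import math


def gaussian_log_prob(y, mean, variance, min_variance=1e-5):
    """Log density of y under a normal with the given mean and variance."""
    variance = max(variance, min_variance)  # same floor as clamp_min(1e-5)
    return -0.5 * math.log(2 * math.pi * variance) - (y - mean) ** 2 / (2 * variance)


def mean_absolute_error(targets, means):
    """Average absolute difference between predictions and targets."""
    return sum(abs(m - t) for t, m in zip(targets, means)) / len(targets)


def mean_nll(targets, means, variances):
    """Average negative log likelihood over all test points."""
    log_probs = [
        gaussian_log_prob(t, m, v) for t, m, v in zip(targets, means, variances)
    ]
    return -sum(log_probs) / len(log_probs)


# Hypothetical predictions that are close to the targets.
toy_targets = [0.0, 1.0]
toy_means = [0.1, 0.9]
mae = mean_absolute_error(toy_targets, toy_means)
confident_nll = mean_nll(toy_targets, toy_means, [0.01, 0.01])  # small variance
cautious_nll = mean_nll(toy_targets, toy_means, [1.0, 1.0])     # large variance
```

With identical means, the MAE is the same in both cases, while the NLL rewards the accurate, low-variance prediction with a negative value and gives the high-variance one a positive value.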


In [14]:
print(vars.min(), vars.mean(), vars.max())

tensor(0.0063) tensor(0.3196) tensor(1.3654)


In [7]:
preds = model.likelihood(model(x_batch))

In [13]:
all_means[0].shape

torch.Size([1, 1024])

In [13]:
model.variational_distribution.variational_distribution.mean

Parameter containing:
tensor([ -4.6352,  -5.4083,  -5.6489,  -5.3201,  -5.3477,  -5.0642,  -5.4948,
         -3.7757,  -4.2214,  -4.7031,  -3.0722,  -5.5411,  -3.8173,   4.2127,
         -5.8280,  -5.5388,  -5.7286,  -5.9319,  -4.5326,  -4.8363,  -5.8946,
         -3.1866,  -6.4119,  -6.1779,  -3.8048,  -5.2828,   6.9590,  -5.0885,
         -4.7191,  -6.3420,  -5.7104,  -6.3354,  -4.4375,  -6.0836,  -5.8743,
          7.8812,  -4.2562,   5.9606,  -5.8739,  -6.1699,   5.6891,  -5.5693,
         -1.7132,  -6.1152,  -0.3793,  -6.0570,  -6.5817,  -5.8898,  -7.1212,
         -5.7330,   5.9369,  -6.6238,  -7.3095,   3.9049,   7.0287,  -6.5150,
          5.8446,  -6.0420,  -5.0433,  -6.2936,   7.1456,  -6.5200,   0.0787,
         -5.5423,  -8.6788,  -7.1170,   4.6330,   3.7760,  -5.1240,  -5.1637,
         -5.3866,  -5.6956,  -5.9857,  -7.3920,  -5.9660,  -7.0695,   7.3482,
         -5.8691,  -5.8548,  -7.2414,  -4.7209,  -5.6269,  -6.5979,  -9.3246,
         -5.1241,  -5.8267,  -7.1856,  -7.