# Deep Gaussian Processes with Doubly Stochastic VI

In this notebook, we provide a GPyTorch implementation of deep Gaussian processes, where training and inference is performed using the method of Salimbeni et al., 2017 (https://arxiv.org/abs/1705.08933) adapted to CG-based inference.

We'll be training a simple two layer deep GP on the `elevators` UCI dataset.

In [1]:
import os
import fire
import math
import time
import pyro
import torch
from torch.utils.data import TensorDataset, DataLoader
from collections import namedtuple
import gpytorch
from gpytorch.models import AbstractVariationalGP, PyroVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy, WhitenedVariationalStrategy
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.distributions import MultivariateNormal
import numpy as np
import pandas as pd
from bayesian_benchmarks import data as _datasets

try:
    from gpytorch.variational import CholeskyVariationalStrategy, WhitenedCholeskyVariationalStrategy
except:
    pass

## Loading Data

For this example notebook, we'll be using the `elevators` UCI dataset used in the paper. Running the next cell downloads a copy of the dataset that has already been scaled and normalized appropriately. For this notebook, we'll simply be splitting the data using the first 80% of the data as training and the last 20% as testing.

**Note**: Running the next cell will attempt to download a ~400 KB dataset file to the current directory.

In [2]:
import urllib.request
import os.path
from scipy.io import loadmat
from math import floor
import numpy as np

if not os.path.isfile('elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', 'elevators.mat')
    
data = torch.Tensor(loadmat('elevators.mat')['data'])
X = data[:, :-1]
y = data[:, -1]

N = data.shape[0]
np.random.seed(0)
data = data[np.random.permutation(np.arange(N)),:]

train_n = int(floor(0.8*len(X)))

train_x = X[:train_n, :].contiguous().cuda()
train_y = y[:train_n].contiguous().cuda()

test_x = X[train_n:, :].contiguous().cuda()
test_y = y[train_n:].contiguous().cuda()

mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

mean,std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

In [3]:
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=False)
test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=512)

In [19]:
class QuadratureDist(pyro.distributions.Distribution):
    def __init__(self, likelihood, function_dist):
        self.likelihood = likelihood
        self.function_dist = function_dist

    def log_prob(self, target):
        return self.likelihood.expected_log_prob(target, self.function_dist)

    def sample(self, sample_shape=torch.Size()):
        pass


class AbstractPyroHiddenGPLayer(PyroVariationalGP):
    def __init__(self, variational_strategy, input_dims, output_dims, num_samples, num_data, name_prefix=""):
        super().__init__(variational_strategy, None, num_data, name_prefix)
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.num_samples = num_samples

    def model(self, input, output, *params, **kwargs):
        """
        Hidden GP layers do not implement models, only guides.
        """
        raise NotImplementedError

    def __call__(self, inputs):
        """
        input: num_samples x num_data x d
        intermediate: num_samples x output_dims x num_data
        outputs: num_samples x num_data x output_dims
        """
        inputs = inputs.contiguous()
        if inputs.dim() == 2:
            # Assume new input entirely
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(-2), self.input_dims)
        elif inputs.dim() == 3:
            # Assume batch dim is samples, not output_dim
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(1), inputs.size(-2), self.input_dims)

        if inputs.dim() == 4:
            num_samples = inputs.size(-3)
            inputs = inputs.view(self.output_dims, inputs.size(-2) * inputs.size(-3), self.input_dims)
            reshape_output = True
        else:
            reshape_output = False
            num_samples = self.num_samples
        
        qf = super().__call__(inputs)
        
        if reshape_output:
            samples = qf.rsample()
            samples = samples.view(self.output_dims, num_samples, -1).permute(1, 2, 0)
        else:
            samples = qf.rsample(torch.Size([num_samples]))
            samples = samples.transpose(-2, -1)
        
        return samples.contiguous()
        

class PyroDeepGP(AbstractPyroHiddenGPLayer):
    def __init__(self, variational_strategy, likelihood, input_dims, output_dims, num_samples, num_data, hidden_gp_net, name_prefix=""):
        super().__init__(variational_strategy, input_dims, output_dims, num_samples, num_data, name_prefix)
        
        self.likelihood = likelihood
        self.hidden_gp_net = hidden_gp_net
    
    def guide(self, input, output, *params, **kwargs):
        for hidden_layer in self.hidden_gp_net:
            pyro.module(hidden_layer.name_prefix + ".gp_prior", hidden_layer)
#             hidden_layer.sample_inducing_values(hidden_layer.variational_distribution)
            input = hidden_layer(input)  # Propagate input

        # Guide for output layer
        super().guide(input, output, *params, **kwargs)
    
    def model(self, input, output, *params, **kwargs):
        inputs = self.hidden_gp_net(input)
        # inputs is now num_samples x num_data x num_last_hidden
        
        # q(f) = num_samples x num_data
        
        if inputs.dim() == 2:
            # Assume new input entirely
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(-2), self.input_dims)
        elif inputs.dim() == 3:
            # Assume batch dim is samples, not output_dim
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(1), inputs.size(-2), self.input_dims)

        if inputs.dim() == 4:
            num_samples = inputs.size(-3)
            inputs = inputs.view(self.output_dims, inputs.size(-2) * inputs.size(-3), self.input_dims)
            reshape_output = True
        else:
            reshape_output = False
            num_samples = self.num_samples

        pyro.module(self.name_prefix + ".gp_prior", self)

        inducing_points = self.variational_strategy.inducing_points
        num_induc = inducing_points.size(-2)
        full_inputs = torch.cat([inducing_points, inputs], dim=-2)
        full_output = self.forward(full_inputs)
        full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

        # Mean terms
        induc_mean = full_mean[..., :num_induc]
        test_mean = full_mean[..., num_induc:]

        # Covariance terms
        induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter()
        induc_induc_covar = gpytorch.lazy.CholLazyTensor(induc_induc_covar.cholesky())
        induc_data_covar = full_covar[..., :num_induc, num_induc:].evaluate()
        data_data_covar = full_covar[..., num_induc:, num_induc:]

        # Prior distribution + samples
        prior_distribution = full_output.__class__(induc_mean, induc_induc_covar)
        inducing_values_samples = self.sample_inducing_values(prior_distribution)

        means = induc_data_covar.transpose(-2, -1).matmul(induc_induc_covar.inv_matmul(inducing_values_samples.unsqueeze(-1))).squeeze(-1)
        
        means_ = means.reshape(num_samples, -1)
        vars_ = torch.sqrt((
                data_data_covar.diag() - induc_induc_covar.inv_quad(induc_data_covar, reduce_inv_quad=False)
            ).clamp_min(0))
        vars_ = vars_.reshape(num_samples, -1)
        
        
        f_samples = pyro.distributions.Normal(
            means_,
            vars_,
        )
        
        # f_samples is Normal(num_samples x num_data)
        
        with pyro.plate(self.name_prefix + ".samples_plate", f_samples.batch_shape[-2], dim=-2):
            with pyro.plate(self.name_prefix + ".data_plate", f_samples.batch_shape[-1], dim=-1):
                with pyro.poutine.scale(scale=float(self.num_data / input.size(-2))):
                    out_dist = QuadratureDist(self.likelihood, f_samples)
                    return pyro.sample(self.name_prefix + ".output_value", out_dist, obs=output)

        
    def __call__(self, inputs):
        inputs = self.hidden_gp_net(inputs)
        
        if inputs.dim() == 2:
            # Assume new input entirely
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(-2), self.input_dims)
        elif inputs.dim() == 3:
            # Assume batch dim is samples, not output_dim
            inputs = inputs.unsqueeze(0)
            inputs = inputs.expand(self.output_dims, inputs.size(1), inputs.size(-2), self.input_dims)

        if inputs.dim() == 4:
            num_samples = inputs.size(-3)
            inputs = inputs.view(self.output_dims, inputs.size(-2) * inputs.size(-3), self.input_dims)
            reshape_output = True
        else:
            reshape_output = False
            num_samples = self.num_samples
            
        output = PyroVariationalGP.__call__(self, inputs)
        
        mean_ = output.mean
        var_ = output.variance
        
        mean_ = mean_.reshape(num_samples, -1)
        var_ = var_.reshape(num_samples, -1)
        
        return output.__class__(
            mean_,
            gpytorch.lazy.DiagLazyTensor(var_),
        )
        

In [20]:
class SimpleHiddenLayer(AbstractPyroHiddenGPLayer):
    def __init__(self, input_dims, output_dims, num_samples, num_data, num_inducing=512, name_prefix=""):
        inducing_points = torch.randn(output_dims, num_inducing, input_dims)

        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_size=output_dims
        )

        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )
        
        super().__init__(variational_strategy, input_dims, output_dims, num_samples, num_data, name_prefix)

        self.mean_module = ConstantMean(batch_size=output_dims)
        self.covar_module = ScaleKernel(RBFKernel(batch_size=output_dims,
                                                  ard_num_dims=input_dims), batch_size=output_dims,
                                        ard_num_dims=None)
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

class SimpleDeepGP(PyroDeepGP):
    def __init__(self, input_dims, output_dims, num_samples, num_data, hidden_gp_net, num_inducing=512, name_prefix=""):
        inducing_points = torch.randn(output_dims, num_inducing, input_dims)

        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_size=output_dims
        )

        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )
        
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        
        super().__init__(
            variational_strategy,
            likelihood,
            input_dims,
            output_dims,
            num_samples,
            num_data,
            hidden_gp_net,
            name_prefix,
        )

        self.mean_module = ConstantMean(batch_size=output_dims)
        self.covar_module = ScaleKernel(RBFKernel(batch_size=output_dims,
                                                  ard_num_dims=input_dims), batch_size=output_dims,
                                    ard_num_dims=None)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

In [21]:
train_x.shape

torch.Size([13279, 18])

In [22]:
hidden_gp1 = SimpleHiddenLayer(
    input_dims=train_x.size(-1), 
    output_dims=10,
    num_samples=5,
    num_data=train_x.size(-2),
    name_prefix="hidden_gp1"
).cuda()

hidden_gp2 = SimpleHiddenLayer(
    input_dims=10,
    output_dims=10,
    num_samples=5,
    num_data=train_x.size(-2),
    name_prefix="hidden_gp2"
).cuda()

hidden_gp_net = torch.nn.Sequential(
    hidden_gp1,
    hidden_gp2,
)

deep_gp = SimpleDeepGP(10, 1, 5, train_x.size(-2), hidden_gp_net, name_prefix="output_gp").cuda()

512
512
512


In [23]:
# deep_gp(train_x[:1024, :])

In [24]:
# deep_gp.model(train_x[:1024, :], train_y[:1024])

In [25]:
# deep_gp.guide(train_x[:1024, :], train_y[:1024])

In [26]:
num_epochs = 30

from pyro import optim
# optimizer = optim.Adam({"lr": lr})
scheduler = pyro.optim.MultiStepLR({
    'optimizer': torch.optim.Adam,
    'optim_args': {'lr': 0.01},
    'milestones': [0.75 * num_epochs],
    'gamma': 0.5}
)
elbo = pyro.infer.Trace_ELBO(num_particles=1, vectorize_particles=True)
svi = pyro.infer.SVI(deep_gp.model, deep_gp.guide, scheduler, elbo)

In [27]:
num_iter_per_epoch = len(train_loader)
maxiter = num_iter_per_epoch * num_epochs

total_iter = 0

for i in range(num_epochs):
    for minibatch_i, (x_batch, y_batch) in enumerate(train_loader):
        total_iter += 1
        
        deep_gp.zero_grad()
        loss = torch.tensor(svi.step(x_batch, y_batch))
        print(
            f"Iter: {total_iter}/{maxiter}\t"
            f"loss: {loss.item():.3f}"
        )

INFO:root:Guessed max_plate_nesting = 2


Iter: 1/780	loss: 121024.531
Iter: 2/780	loss: 124045.125
Iter: 3/780	loss: 127748.148
Iter: 4/780	loss: 128445.391
Iter: 5/780	loss: 130713.203
Iter: 6/780	loss: 128700.000
Iter: 7/780	loss: 131327.391
Iter: 8/780	loss: 125699.977
Iter: 9/780	loss: 126386.445
Iter: 10/780	loss: 132934.094
Iter: 11/780	loss: 133077.172
Iter: 12/780	loss: 131425.547
Iter: 13/780	loss: 140960.859
Iter: 14/780	loss: 127055.781
Iter: 15/780	loss: 133786.188
Iter: 16/780	loss: 135585.562
Iter: 17/780	loss: 128565.656
Iter: 18/780	loss: 128022.211
Iter: 19/780	loss: 133863.188
Iter: 20/780	loss: 131440.781
Iter: 21/780	loss: 128582.406
Iter: 22/780	loss: 131624.719
Iter: 23/780	loss: 127233.883
Iter: 24/780	loss: 134449.062
Iter: 25/780	loss: 125803.367
Iter: 26/780	loss: 133936.703
Iter: 27/780	loss: 121528.586
Iter: 28/780	loss: 124806.398
Iter: 29/780	loss: 128201.875
Iter: 30/780	loss: 128603.617
Iter: 31/780	loss: 130918.320
Iter: 32/780	loss: 128439.867
Iter: 33/780	loss: 131332.578
Iter: 34/780	loss: 

Iter: 269/780	loss: 126699.953
Iter: 270/780	loss: 132634.438
Iter: 271/780	loss: 133576.547
Iter: 272/780	loss: 131214.422
Iter: 273/780	loss: 140694.203
Iter: 274/780	loss: 127249.188
Iter: 275/780	loss: 134529.250
Iter: 276/780	loss: 135162.328
Iter: 277/780	loss: 128867.969
Iter: 278/780	loss: 128598.203
Iter: 279/780	loss: 134341.266
Iter: 280/780	loss: 131320.734
Iter: 281/780	loss: 128159.828
Iter: 282/780	loss: 131204.547
Iter: 283/780	loss: 127319.906
Iter: 284/780	loss: 134853.578
Iter: 285/780	loss: 125131.406
Iter: 286/780	loss: 133238.672
Iter: 287/780	loss: 120747.500
Iter: 288/780	loss: 124683.984
Iter: 289/780	loss: 128303.672
Iter: 290/780	loss: 128924.477
Iter: 291/780	loss: 130751.094
Iter: 292/780	loss: 128381.703
Iter: 293/780	loss: 130955.008
Iter: 294/780	loss: 125804.492
Iter: 295/780	loss: 126421.883
Iter: 296/780	loss: 132835.641
Iter: 297/780	loss: 133674.891
Iter: 298/780	loss: 131071.578
Iter: 299/780	loss: 140531.469
Iter: 300/780	loss: 127142.578
Iter: 30

Iter: 534/780	loss: 127175.344
Iter: 535/780	loss: 134441.297
Iter: 536/780	loss: 135248.250
Iter: 537/780	loss: 128473.047
Iter: 538/780	loss: 128750.312
Iter: 539/780	loss: 134571.078
Iter: 540/780	loss: 130849.344
Iter: 541/780	loss: 128564.812
Iter: 542/780	loss: 131494.812
Iter: 543/780	loss: 127536.312
Iter: 544/780	loss: 134853.484
Iter: 545/780	loss: 125547.438
Iter: 546/780	loss: 133832.375
Iter: 547/780	loss: 121418.086
Iter: 548/780	loss: 124354.883
Iter: 549/780	loss: 128310.430
Iter: 550/780	loss: 128668.695
Iter: 551/780	loss: 130742.094
Iter: 552/780	loss: 128499.766
Iter: 553/780	loss: 131237.094
Iter: 554/780	loss: 126411.945
Iter: 555/780	loss: 126403.688
Iter: 556/780	loss: 132737.078
Iter: 557/780	loss: 133179.062
Iter: 558/780	loss: 131155.141
Iter: 559/780	loss: 140747.719
Iter: 560/780	loss: 126862.438
Iter: 561/780	loss: 134323.797
Iter: 562/780	loss: 135377.953
Iter: 563/780	loss: 128302.047
Iter: 564/780	loss: 128689.758
Iter: 565/780	loss: 133983.531
Iter: 56

In [None]:
# torch.matmul(A, B)
# A = b1 x b2 ... bk x n x m
# B = 1 x 1 x ... x bk x m x k