
Add Fantasy Strategy for Variational GPs #1874

Merged · 18 commits · Jun 3, 2022
Changes from 10 commits
476 changes: 476 additions & 0 deletions examples/08_Advanced_Usage/SVGP_Model_Updating.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions examples/08_Advanced_Usage/index.rst
@@ -59,6 +59,14 @@ See the `1D derivatives GP example`_ or the `2D derivatives GP example`_ for exa
Simple_Batch_Mode_GP_Regression.ipynb:


Variational Fantasization
----------------------------------
We also include an example of how to perform fantasy modeling (i.e., efficient, closed-form updates) for variational
Gaussian process models, enabling their use for lookahead optimization.

.. _Variational fantasization:
SVGP_Model_Updating.ipynb

Converting Models to TorchScript
----------------------------------

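As a rough sketch of the workflow the new notebook and index entry describe (all class, variable, and data names below are illustrative rather than taken from the notebook; the model follows the usual mean_module / covar_module convention and has a Gaussian likelihood attached):

```python
import torch
import gpytorch


# An ordinary SVGP model; the class and tensor names are illustrative.
class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = SVGPModel(inducing_points=torch.randn(16, 1))
model.likelihood = gpytorch.likelihoods.GaussianLikelihood()  # OVC reads model.likelihood
# ... fit model / likelihood with gpytorch.mlls.VariationalELBO as usual ...

# Closed-form update: condition on new observations without re-running variational inference.
new_x, new_y = torch.randn(5, 1), torch.randn(5)
fantasy_model = model.get_fantasy_model(new_x, new_y)  # returns a gpytorch.models.ExactGP

fantasy_model.eval()
with torch.no_grad():
    pred = fantasy_model(torch.linspace(-1, 1, 50).unsqueeze(-1))
```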
27 changes: 27 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -75,6 +75,33 @@ def pyro_model(self, input, beta=1.0, name_prefix=""):
"""
return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

def get_fantasy_model(self, inputs, targets, **kwargs):
samuelstanton (Contributor) commented on Mar 8, 2022:
I can see why it might be convenient to always have get_fantasy_model return an ExactGP, regardless of the original model class, but it might be worth considering naming this something else, reserving get_fantasy_model for the version of OVC that returns a variational GP (in other words, make a package-level decision to require that get_fantasy_model always returns an instance of the original class).

wjmaddox (Collaborator, Author) replied on Mar 8, 2022:
That was a thought I originally had, but it requires the unstable direct updates to m, S in order to return the model's own class rather than an ExactGP. Although that route would potentially have lower overhead for implementing new fantasization strategies in the future.

r"""
Returns a new GP model that incorporates the specified inputs and targets as new training data using
online variational conditioning (OVC).

This function first casts the inducing points and variational parameters into pseudo-points before
returning an equivalent ExactGP model with a specialized likelihood.

.. note::
If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
are the same for each target batch.

:param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
observations.
:param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
:return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
and all test-time caches have been updated.
:rtype: ~gpytorch.models.ExactGP

Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html

"""
return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

def __call__(self, inputs, prior=False, **kwargs):
if inputs.dim() == 1:
inputs = inputs.unsqueeze(-1)
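To make the docstring's note on batch semantics concrete, a small hedged sketch (continuing from the illustrative `model` above; the shapes follow the docstring rather than additional tested behavior):

```python
import torch

# Plain fantasy update: m = 10 new points with d = 1 input dimension.
fant_x = torch.randn(10, 1)   # m x d
fant_y = torch.randn(10)      # m
single = model.get_fantasy_model(fant_x, fant_y)        # batch-free ExactGP

# Batched targets, shared locations: `fant_x` has one fewer batch dimension
# than `fant_y`, so (per the note above) the same 10 locations are reused for
# each of the b = 3 target batches and a batch-mode ExactGP is returned.
fant_y_batched = torch.randn(3, 10)                     # b x m
batched = model.get_fantasy_model(fant_x, fant_y_batched)
```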
72 changes: 72 additions & 0 deletions gpytorch/test/variational_test_case.py
@@ -95,6 +95,19 @@ def _eval_iter(self, model, batch_shape=torch.Size([]), cuda=False):

return output

def _fantasy_iter(self, model, likelihood, batch_shape=torch.Size([]), cuda=False, num_fant=10):
model.likelihood = likelihood
val_x = torch.randn(*batch_shape, num_fant, 2).clamp(-2.5, 2.5)
val_y = torch.linspace(-1, 1, num_fant)
val_y = val_y.view(num_fant, *([1] * (len(self.event_shape) - 1)))
val_y = val_y.expand(*batch_shape, num_fant, *self.event_shape[1:])
if cuda:
model = model.cuda()
val_x = val_x.cuda()
val_y = val_y.cuda()
updated_model = model.get_fantasy_model(val_x, val_y)
return updated_model

@abstractproperty
def batch_shape(self):
raise NotImplementedError
@@ -272,3 +285,62 @@ def test_training_all_batch_zero_mean(self):
expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
constant_mean=False,
)

def test_fantasy_call(
self,
data_batch_shape=None,
inducing_batch_shape=None,
model_batch_shape=None,
expected_batch_shape=None,
constant_mean=True,
):
# Batch shapes
model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape

num_inducing = 16
num_fant = 10
# Make model and likelihood
model, likelihood = self._make_model_and_likelihood(
batch_shape=model_batch_shape,
inducing_batch_shape=inducing_batch_shape,
distribution_cls=self.distribution_cls,
strategy_cls=self.strategy_cls,
constant_mean=constant_mean,
num_inducing=num_inducing,
)

fant_model = self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)
self.assertTrue(isinstance(fant_model, gpytorch.models.ExactGP))
self.assertTrue(fant_model.prediction_strategy is not None)
for key in fant_model.prediction_strategy._memoize_cache.keys():
if key[0] == "mean_cache":
break
mean_cache = fant_model.prediction_strategy._memoize_cache[key]
self.assertEqual(mean_cache.shape, torch.Size([*expected_batch_shape, num_inducing + num_fant]))

def test_fantasy_call_batch_inducing(self):
return self.test_fantasy_call(
model_batch_shape=(torch.Size([3]) + self.batch_shape),
data_batch_shape=self.batch_shape,
inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)

def test_fantasy_call_batch_data(self):
return self.test_fantasy_call(
model_batch_shape=self.batch_shape,
inducing_batch_shape=self.batch_shape,
data_batch_shape=(torch.Size([3]) + self.batch_shape),
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)

def test_fantasy_call_batch_model(self):
return self.test_fantasy_call(
model_batch_shape=(torch.Size([3]) + self.batch_shape),
inducing_batch_shape=self.batch_shape,
data_batch_shape=self.batch_shape,
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)
134 changes: 133 additions & 1 deletion gpytorch/variational/_variational_strategy.py
@@ -1,21 +1,46 @@
#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy

import torch

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..models import ExactGP
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import cached, clear_cache_hook
from ..utils.memoize import add_to_cache, cached, clear_cache_hook


class _BaseExactGP(ExactGP):
def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
super().__init__(train_inputs, train_targets, likelihood)
self.mean_module = mean_module
self.covar_module = covar_module

def forward(self, x):
mean = self.mean_module(x)
covar = self.covar_module(x)
return MultivariateNormal(mean, covar)


def _add_cache_hook(tsr, pred_strat):
if tsr.grad_fn is not None:
wrapper = functools.partial(clear_cache_hook, pred_strat)
functools.update_wrapper(wrapper, clear_cache_hook)
tsr.grad_fn.register_hook(wrapper)
return tsr


class _VariationalStrategy(Module, ABC):
"""
Abstract base class for all Variational Strategies.
"""

has_fantasy_strategy = False

def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True):
super().__init__()

Expand Down Expand Up @@ -97,6 +122,113 @@ def kl_divergence(self):
kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
return kl_divergence

@cached(name="inducing_model")
def inducing_model(self):
with torch.no_grad():
inducing_noise_covar, inducing_mean = self.pseudo_points
inducing_points = self.inducing_points.detach()
if inducing_points.ndim < inducing_mean.ndim:
inducing_points = inducing_points.expand(*inducing_mean.shape[:-2], *inducing_points.shape)
# TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
new_covar_module = deepcopy(self.model.covar_module)
Member commented:
This feels a bit brittle: it's not guaranteed that people use the self.covar_module and self.mean_module convention in their model.


# update inducing mean if necessary
inducing_mean = inducing_mean.squeeze() + self.model.mean_module(inducing_points)

inducing_exact_model = _BaseExactGP(
inducing_points,
inducing_mean,
mean_module=deepcopy(self.model.mean_module),
Member commented:
Same thing here.

Contributor commented:
Seems like this is playing the same role as this line

https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/models/exact_gp.py#L223

But here we want to copy some of the attributes of one class into a completely different class. Not sure there is a general way to do this without assuming the attribute names.

Contributor commented:
Would be good to add an informative error message if the attributes are missing.

Collaborator (Author) replied:
The solution we're going to try here is to require either mean / covar to be so named in the model or to require it in the kwargs, which gives the added benefit of some amount of fantasizing through updated hypers.

covar_module=new_covar_module,
likelihood=deepcopy(self.model.likelihood),
)

# now fantasize around this model
# as this model is new, we need to compute a posterior to construct the prediction strategy
# which uses the likelihood pseudo caches
faked_points = torch.randn(
*inducing_mean.shape[:-2],
1,
inducing_points.shape[-1],
device=inducing_points.device,
dtype=inducing_points.dtype,
)
inducing_exact_model.eval()
_ = inducing_exact_model(faked_points)

# then we overwrite the likelihood to take into account the multivariate normal term
pred_strat = inducing_exact_model.prediction_strategy
pred_strat._memoize_cache = {}
with torch.no_grad():
updated_lik_train_train_covar = (
pred_strat.train_prior_dist.lazy_covariance_matrix + inducing_noise_covar
)
pred_strat.lik_train_train_covar = updated_lik_train_train_covar

# do the mean cache because the mean cache doesn't solve against lik_train_train_covar
train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
mean_cache = updated_lik_train_train_covar.inv_matmul(train_labels_offset).squeeze(-1)
mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
add_to_cache(pred_strat, "mean_cache", mean_cache)
# TODO: check to see if we need to do the covar_cache?

inducing_exact_model.prediction_strategy = pred_strat
return inducing_exact_model
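Reading the method above (this summary is not text from the PR): with the pseudo-targets ỹ stored as the train labels and the pseudo-noise covariance D_a from pseudo_points(), overwriting lik_train_train_covar and the mean cache makes the patched ExactGP's predictive mean

```latex
\text{mean\_cache} = (K_{ZZ} + D_a)^{-1}\big(\tilde{y} - \mu(Z)\big),
\qquad
\mathbb{E}\big[f(x_\ast)\big] = \mu(x_\ast) + K_{x_\ast Z}\,(K_{ZZ} + D_a)^{-1}\big(\tilde{y} - \mu(Z)\big),
```

i.e. the pseudo-points act as heteroskedastic observations of the inducing values, which is the conditioning identity the OVC construction relies on.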

def pseudo_points(self):
raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

def get_fantasy_model(
self,
inputs,
targets,
**kwargs,
):
r"""
Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
points and targets.

Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
"""

# currently, we only support fantasization for CholeskyVariationalDistribution and
# whitened / unwhitened variational strategies
# from .variational_strategy import VariationalStrategy
# from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
if not self.has_fantasy_strategy:
raise NotImplementedError(
f"No fantasy model support for {self.__class__.__name__}. "
"Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
)
# first we construct an exact model over the inducing points with the inducing covariance
# matrix
inducing_exact_model = self.inducing_model()

# then we update this model by adding in the inputs and pseudo targets
# if inputs.shape[-2] == 1 or targets.shape[-1] != 1:
# targets = targets.unsqueeze(-1)
# put on a trailing bdim for bs of 1
# finally we fantasize wrt targets
fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
fant_pred_strat = fantasy_model.prediction_strategy

# first we update the lik_train_train_covar
# do the mean cache again because the mean cache resets the likelihood forward
train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
add_to_cache(fant_pred_strat, "mean_cache", mean_cache)

fantasy_model.prediction_strategy = fant_pred_strat
return fantasy_model
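Because the returned model is an ordinary ExactGP with its prediction caches populated, further conditioning for multi-step lookahead can presumably reuse the standard ExactGP fantasy path. A hedged sketch, continuing from the illustrative `model` above (the tensors here are placeholders):

```python
import torch

x1, y1 = torch.randn(4, 1), torch.randn(4)
x2, y2 = torch.randn(4, 1), torch.randn(4)

exact_fantasy = model.get_fantasy_model(x1, y1)             # SVGP -> ExactGP via OVC
two_step_fantasy = exact_fantasy.get_fantasy_model(x2, y2)  # ExactGP -> ExactGP (standard path)
```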

def __call__(self, x, prior=False, **kwargs):
# If we're in prior mode, then we're done!
if prior:
56 changes: 56 additions & 0 deletions gpytorch/variational/unwhitened_variational_strategy.py
@@ -4,6 +4,8 @@

import torch

from gpytorch.variational.cholesky_variational_distribution import CholeskyVariationalDistribution

from .. import settings
from ..distributions import MultivariateNormal
from ..lazy import (
@@ -17,6 +19,7 @@
)
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.cholesky import psd_safe_cholesky
from ..utils.errors import NotPSDError
from ..utils.memoize import add_to_cache, cached
from ._variational_strategy import _VariationalStrategy

@@ -44,6 +47,7 @@ class UnwhitenedVariationalStrategy(_VariationalStrategy):
the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
parameters of the model).
"""
has_fantasy_strategy = True

@cached(name="cholesky_factor", ignore_args=True)
def _cholesky_factor(self, induc_induc_covar):
@@ -58,6 +62,58 @@ def prior_distribution(self):
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter())
return res

@property
@cached(name="pseudo_points_memo")
def pseudo_points(self):
# TODO: implement for other distributions
# retrieve the variational mean, m and covariance matrix, S.
if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
raise NotImplementedError(
"Only CholeskyVariationalDistribution currently has pseudo-point support, "
f"but your _variational_distribution is a {self._variational_distribution.__class__.__name__}."
)

# retrieve the variational mean, m and covariance matrix, S.
var_cov_root = TriangularLazyTensor(self._variational_distribution.chol_variational_covar)
var_cov = CholLazyTensor(var_cov_root)
var_mean = self.variational_distribution.mean # .unsqueeze(-1)
if var_mean.shape[-1] != 1:
var_mean = var_mean.unsqueeze(-1)

# R = K - S
Kmm = self.model.covar_module(self.inducing_points)
res = Kmm - var_cov

cov_diff = res

# D_a = (S^{-1} - K^{-1})^{-1} = S + S R^{-1} S
# note that in the whitened case R = I - S, unwhitened R = K - S
# we compute (R R^{T})^{-1} R^T S for stability reasons as R is probably not PSD.
eval_lhs = var_cov.evaluate()
eval_rhs = cov_diff.transpose(-1, -2).matmul(eval_lhs)
inner_term = cov_diff.matmul(cov_diff.transpose(-1, -2))
# TODO: flag the jitter here
inner_solve = inner_term.add_jitter(1e-3).inv_matmul(eval_rhs, eval_lhs.transpose(-1, -2))
inducing_covar = var_cov + inner_solve

# mean term: D_a S^{-1} m
# unwhitened: (S + S R^{-1} S) S^{-1} m = (I + S R^{-1}) m
rhs = cov_diff.transpose(-1, -2).matmul(var_mean)
inner_rhs_mean_solve = inner_term.add_jitter(1e-3).inv_matmul(rhs)
new_mean = var_mean + var_cov.matmul(inner_rhs_mean_solve)

# ensure inducing covar is psd
try:
final_inducing_covar = CholLazyTensor(inducing_covar.add_jitter(1e-3).cholesky()).evaluate()
except NotPSDError:
from gpytorch.lazy import DiagLazyTensor

evals, evecs = inducing_covar.symeig(eigenvectors=True)
final_inducing_covar = evecs.matmul(DiagLazyTensor(evals + 1e-4)).matmul(evecs.transpose(-1, -2)).evaluate()

return final_inducing_covar, new_mean
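The identities in the comments above follow in two lines; a sketch of the algebra with R = K - S (the unwhitened case):

```latex
S^{-1} - K^{-1} = K^{-1}(K - S)S^{-1} = K^{-1} R\, S^{-1}
\;\Longrightarrow\;
D_a = (S^{-1} - K^{-1})^{-1} = S R^{-1} K = S R^{-1}(R + S) = S + S R^{-1} S,
\qquad
D_a S^{-1} m = m + S R^{-1} m .
```

This matches the code: `inducing_covar = var_cov + inner_solve` and `new_mean = var_mean + var_cov.matmul(inner_rhs_mean_solve)`, with (R R^T)^{-1} R^T standing in for R^{-1} because R need not be PSD.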

def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None):
# If our points equal the inducing points, we're done
if torch.equal(x, inducing_points):