
Add Fantasy Strategy for Variational GPs #1874

Merged · 18 commits · Jun 3, 2022 · changes shown from 16 commits
476 changes: 476 additions & 0 deletions examples/08_Advanced_Usage/SVGP_Model_Updating.ipynb

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions examples/08_Advanced_Usage/index.rst
@@ -59,6 +59,14 @@ See the `1D derivatives GP example`_ or the `2D derivatives GP example`_ for examples ...
Simple_Batch_Mode_GP_Regression.ipynb:


Variational Fantasization
----------------------------------
We also include an example of how to perform fantasy modeling (i.e., efficient, closed-form updates) for variational
Gaussian process models, enabling their use in lookahead optimization.

.. _Variational fantasization:
SVGP_Model_Updating.ipynb

Converting Models to TorchScript
----------------------------------

@@ -73,3 +81,4 @@ how to convert both an exact GP and a variational GP to a ScriptModule that can ...

TorchScript_Exact_Models.ipynb
TorchScript_Variational_Models.ipynb
SVGP_Model_Updating.ipynb
27 changes: 27 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -75,6 +75,33 @@ def pyro_model(self, input, beta=1.0, name_prefix=""):
"""
return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

def get_fantasy_model(self, inputs, targets, **kwargs):
samuelstanton (Contributor) commented on Mar 8, 2022:

    I can see why it might be convenient to always have get_fantasy_model return an ExactGP, regardless of
    the original model class, but it might be worth considering naming this something else, reserving
    get_fantasy_model for the version of OVC that returns a variational GP (in other words, make a
    package-level decision that get_fantasy_model always returns an instance of the original class).

wjmaddox (Collaborator, Author) replied on Mar 8, 2022:

    That was a thought I originally had, but returning the model's own class rather than an ExactGP requires
    the numerically unstable direct updates to m and S. It could, however, lower the overhead of implementing
    new fantasization strategies in the future.

r"""
Returns a new GP model that incorporates the specified inputs and targets as new training data using
online variational conditioning (OVC).

This function first casts the inducing points and variational parameters into pseudo-points before
returning an equivalent ExactGP model with a specialized likelihood.

.. note::
If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
are the same for each target batch.

:param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
observations.
:param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
:return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
and all test-time caches have been updated.
:rtype: ~gpytorch.models.ExactGP

Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html

"""
return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

def __call__(self, inputs, prior=False, **kwargs):
if inputs.dim() == 1:
inputs = inputs.unsqueeze(-1)
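For context, a minimal usage sketch of the new method. The model class, training loop, and data names below
are illustrative (a standard GPyTorch SVGP setup), not part of this PR; note that OVC reads the likelihood
from `model.likelihood`, so a GaussianLikelihood must be attached to the model.

import torch
import gpytorch

class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

model = SVGPModel(inducing_points=torch.randn(16, 2))
model.likelihood = gpytorch.likelihoods.GaussianLikelihood()  # OVC requires a Gaussian likelihood

# ... fit the model with the usual variational ELBO loop, then condition on new observations:
new_x = torch.randn(10, 2)
new_y = torch.randn(10)
fantasy_model = model.get_fantasy_model(new_x, new_y)  # returns a gpytorch.models.ExactGP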
121 changes: 121 additions & 0 deletions gpytorch/test/variational_test_case.py
@@ -95,6 +95,28 @@ def _eval_iter(self, model, batch_shape=torch.Size([]), cuda=False):

return output

def _fantasy_iter(
self,
model,
likelihood,
batch_shape=torch.Size([]),
cuda=False,
num_fant=10,
covar_module=None,
mean_module=None,
):
        model.likelihood = likelihood
        # fantasy inputs: random points, clamped to a moderate range
        val_x = torch.randn(*batch_shape, num_fant, 2).clamp(-2.5, 2.5)
        # fantasy targets: a linspace reshaped and expanded to match the batch and event shapes
        val_y = torch.linspace(-1, 1, num_fant)
        val_y = val_y.view(num_fant, *([1] * (len(self.event_shape) - 1)))
        val_y = val_y.expand(*batch_shape, num_fant, *self.event_shape[1:])
if cuda:
model = model.cuda()
val_x = val_x.cuda()
val_y = val_y.cuda()
updated_model = model.get_fantasy_model(val_x, val_y, covar_module=covar_module, mean_module=mean_module)
return updated_model

@abstractproperty
def batch_shape(self):
raise NotImplementedError
@@ -272,3 +294,102 @@ def test_training_all_batch_zero_mean(self):
expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
constant_mean=False,
)

def test_fantasy_call(
self,
data_batch_shape=None,
inducing_batch_shape=None,
model_batch_shape=None,
expected_batch_shape=None,
constant_mean=True,
):
# Batch shapes
model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape

num_inducing = 16
num_fant = 10
# Make model and likelihood
model, likelihood = self._make_model_and_likelihood(
batch_shape=model_batch_shape,
inducing_batch_shape=inducing_batch_shape,
distribution_cls=self.distribution_cls,
strategy_cls=self.strategy_cls,
constant_mean=constant_mean,
num_inducing=num_inducing,
)

        # iterate through the possible covar_module / mean_module settings
covar_mean_options = [
{"covar_module": None, "mean_module": None},
{"covar_module": gpytorch.kernels.MaternKernel(), "mean_module": gpytorch.means.ZeroMean()},
]
for cm_dict in covar_mean_options:
fant_model = self._fantasy_iter(
model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant, **cm_dict
)
self.assertTrue(isinstance(fant_model, gpytorch.models.ExactGP))

            # check that overriding covar_module and mean_module actually changes the fantasy model's modules
if cm_dict["covar_module"] is None:
self.assertEqual(type(fant_model.covar_module), type(model.covar_module))
else:
self.assertNotEqual(type(fant_model.covar_module), type(model.covar_module))
if cm_dict["mean_module"] is None:
self.assertEqual(type(fant_model.mean_module), type(model.mean_module))
else:
self.assertNotEqual(type(fant_model.mean_module), type(model.mean_module))

            # check that the fantasy model's prediction strategy caches have the correct shapes
self.assertTrue(fant_model.prediction_strategy is not None)
for key in fant_model.prediction_strategy._memoize_cache.keys():
if key[0] == "mean_cache":
break
mean_cache = fant_model.prediction_strategy._memoize_cache[key]
self.assertEqual(mean_cache.shape, torch.Size([*expected_batch_shape, num_inducing + num_fant]))

        # removing the mean_module / covar_module (without passing replacements) should raise errors
del model.mean_module
with self.assertRaises(ModuleNotFoundError):
self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)

model.mean_module = gpytorch.means.ZeroMean()
del model.covar_module
with self.assertRaises(ModuleNotFoundError):
self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)

        # finally, fantasization with a non-Gaussian likelihood should raise NotImplementedError
with self.assertRaises(NotImplementedError):
self._fantasy_iter(
model,
gpytorch.likelihoods.BernoulliLikelihood(),
data_batch_shape,
self.cuda,
num_fant=num_fant,
)

def test_fantasy_call_batch_inducing(self):
return self.test_fantasy_call(
model_batch_shape=(torch.Size([3]) + self.batch_shape),
data_batch_shape=self.batch_shape,
inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)

def test_fantasy_call_batch_data(self):
return self.test_fantasy_call(
model_batch_shape=self.batch_shape,
inducing_batch_shape=self.batch_shape,
data_batch_shape=(torch.Size([3]) + self.batch_shape),
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)

def test_fantasy_call_batch_model(self):
return self.test_fantasy_call(
model_batch_shape=(torch.Size([3]) + self.batch_shape),
inducing_batch_shape=self.batch_shape,
data_batch_shape=self.batch_shape,
expected_batch_shape=(torch.Size([3]) + self.batch_shape),
)
166 changes: 165 additions & 1 deletion gpytorch/variational/_variational_strategy.py
@@ -1,21 +1,47 @@
#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy

import torch

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..likelihoods import GaussianLikelihood
from ..models import ExactGP
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import cached, clear_cache_hook
from ..utils.memoize import add_to_cache, cached, clear_cache_hook


class _BaseExactGP(ExactGP):
def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
super().__init__(train_inputs, train_targets, likelihood)
self.mean_module = mean_module
self.covar_module = covar_module

def forward(self, x):
mean = self.mean_module(x)
covar = self.covar_module(x)
return MultivariateNormal(mean, covar)


def _add_cache_hook(tsr, pred_strat):
    # register a backward hook so the prediction strategy's caches are cleared
    # whenever gradients flow through the cached tensor
    if tsr.grad_fn is not None:
        wrapper = functools.partial(clear_cache_hook, pred_strat)
        functools.update_wrapper(wrapper, clear_cache_hook)
        tsr.grad_fn.register_hook(wrapper)
    return tsr
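A standalone sketch of the grad_fn-hook pattern used by _add_cache_hook (the names and cache here are
illustrative, not from this PR): a hook registered on a tensor's grad_fn fires during the backward pass,
which is used above to invalidate stale prediction caches.

import functools
import torch

def clear_cache(store, *unused_grads):
    # called during the backward pass; drop now-stale cache entries
    store.clear()

cache = {"mean_cache": torch.ones(3)}
x = torch.randn(3, requires_grad=True)
y = 2.0 * x
if y.grad_fn is not None:
    # bind the cache as the first argument, as _add_cache_hook binds pred_strat
    y.grad_fn.register_hook(functools.partial(clear_cache, cache))

y.sum().backward()  # fires the hook and empties the cache
assert not cache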


class _VariationalStrategy(Module, ABC):
"""
Abstract base class for all Variational Strategies.
"""

has_fantasy_strategy = False

def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True):
super().__init__()

@@ -97,6 +123,144 @@ def kl_divergence(self):
kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
return kl_divergence

@cached(name="amortized_exact_gp")
def amortized_exact_gp(self, mean_module=None, covar_module=None):
mean_module = self.model.mean_module if mean_module is None else mean_module
covar_module = self.model.covar_module if covar_module is None else covar_module

with torch.no_grad():
# from here on down, we refer to the inducing points as pseudo_inputs
pseudo_target_covar, pseudo_target_mean = self.pseudo_points
pseudo_inputs = self.inducing_points.detach()
if pseudo_inputs.ndim < pseudo_target_mean.ndim:
pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape)
# TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
new_covar_module = deepcopy(covar_module)

# update inducing mean if necessary
pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs)

inducing_exact_model = _BaseExactGP(
pseudo_inputs,
pseudo_target_mean,
mean_module=deepcopy(mean_module),
covar_module=new_covar_module,
likelihood=deepcopy(self.model.likelihood),
)

            # now fantasize around this model
            # as the model is new, we compute a posterior on fake points purely to construct
            # the prediction strategy, which holds the likelihood pseudo caches
faked_points = torch.randn(
*pseudo_target_mean.shape[:-2],
1,
pseudo_inputs.shape[-1],
device=pseudo_inputs.device,
dtype=pseudo_inputs.dtype,
)
inducing_exact_model.eval()
_ = inducing_exact_model(faked_points)

# then we overwrite the likelihood to take into account the multivariate normal term
pred_strat = inducing_exact_model.prediction_strategy
pred_strat._memoize_cache = {}
with torch.no_grad():
updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar
pred_strat.lik_train_train_covar = updated_lik_train_train_covar

                # recompute the mean cache by hand, since the default mean cache does not solve against lik_train_train_covar
train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
mean_cache = updated_lik_train_train_covar.inv_matmul(train_labels_offset).squeeze(-1)
mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
add_to_cache(pred_strat, "mean_cache", mean_cache)
# TODO: check to see if we need to do the covar_cache?

inducing_exact_model.prediction_strategy = pred_strat
return inducing_exact_model
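For intuition, the pseudo-point algebra this method relies on can be sketched as follows. This is an
editorial derivation under an unwhitened parameterization with zero prior mean, not a transcription of the
paper's equations. Conditioning an exact GP on pseudo targets \tilde{y} at the inducing points Z with pseudo
noise covariance \tilde{\Sigma} must reproduce the variational distribution q(u) = N(m, S):

    K_{ZZ} - K_{ZZ} (K_{ZZ} + \tilde{\Sigma})^{-1} K_{ZZ} = S,
    \qquad
    K_{ZZ} (K_{ZZ} + \tilde{\Sigma})^{-1} \tilde{y} = m,

which solve in closed form to

    \tilde{\Sigma} = K_{ZZ} (K_{ZZ} - S)^{-1} K_{ZZ} - K_{ZZ},
    \qquad
    \tilde{y} = K_{ZZ} (K_{ZZ} - S)^{-1} m .

Each concrete strategy's pseudo_points returns such a (pseudo covariance, pseudo mean) pair, which
amortized_exact_gp then treats as ordinary training data with a modified likelihood covariance.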

    @property
    def pseudo_points(self):
        raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

def get_fantasy_model(
self,
inputs,
targets,
mean_module=None,
covar_module=None,
**kwargs,
):
r"""
Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return
an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
points and targets.

Currently, instead of directly updating the variational parameters (and inducing points), we instead
return an ExactGP model rather than an updated variational GP model. This is done primarily for
numerical stability.

Unlike the ExactGP's call for get_fantasy_model, we enable options for mean_module and covar_module
that allow specification of the mean / covariance.

gpleiss marked this conversation as resolved.
Show resolved Hide resolved
Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
"""

        # currently, we only support fantasization for CholeskyVariationalDistribution and
        # whitened / unwhitened variational strategies
        if not self.has_fantasy_strategy:
            raise NotImplementedError(
                f"No fantasy model support for {self.__class__.__name__}. "
                "Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
            )
        if not isinstance(self.model.likelihood, GaussianLikelihood):
            raise NotImplementedError(
                f"No fantasy model support for {self.model.likelihood}. "
                "Only GaussianLikelihoods are currently supported."
            )
        # we assume that either the user has given the model a mean_module and a covar_module,
        # or that they will be passed into get_fantasy_model; we check for these here
        if mean_module is None:
            mean_module = getattr(self.model, "mean_module", None)
            if mean_module is None:
                raise ModuleNotFoundError(
                    "Either you must provide a mean_module as input to get_fantasy_model "
                    "or it must be an attribute of the model."
                )

Member commented on this line:

    Could this be a little more explicit? "or it must be an attribute of the model named mean_module." or
    something like that.

        if covar_module is None:
            covar_module = getattr(self.model, "covar_module", None)
            if covar_module is None:
                raise ModuleNotFoundError(
                    "Either you must provide a covar_module as input to get_fantasy_model "
                    "or it must be an attribute of the model."
                )

Member commented on this line:

    Same thing here: "or it must be an attribute of the model named covar_module." or something like that.

# first we construct an exact model over the inducing points with the inducing covariance
# matrix
inducing_exact_model = self.amortized_exact_gp(mean_module=mean_module, covar_module=covar_module)

# then we update this model by adding in the inputs and pseudo targets
# finally we fantasize wrt targets
fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
fant_pred_strat = fantasy_model.prediction_strategy

        # recompute the mean cache, since constructing the fantasy model resets it via the likelihood forward;
        # here we solve against lik_train_train_covar using its root inverse decomposition
train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
add_to_cache(fant_pred_strat, "mean_cache", mean_cache)
# TODO: should we update the covar_cache?

fantasy_model.prediction_strategy = fant_pred_strat
return fantasy_model
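A short sketch of the mean_module / covar_module override path described in the docstring above, assuming
`model` is a trained ApproximateGP with a GaussianLikelihood attached (as in the earlier sketch) and
`new_x` / `new_y` are illustrative fantasy data:

import gpytorch

# default: the fantasy ExactGP reuses the model's own mean_module and covar_module
fant_default = model.get_fantasy_model(new_x, new_y)

# override: the returned ExactGP is built around the supplied modules instead,
# mirroring the covar_mean_options exercised in variational_test_case.py
fant_override = model.get_fantasy_model(
    new_x,
    new_y,
    mean_module=gpytorch.means.ZeroMean(),
    covar_module=gpytorch.kernels.MaternKernel(),
)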

def __call__(self, x, prior=False, **kwargs):
# If we're in prior mode, then we're done!
if prior: