Add Fantasy Strategy for Variational GPs (#1874)
* drop in basic parts of strategies

* add fantasization notebook

* add basic unit tests

* remove extraneous prints

* fixup unit tests

* add fantasy tests in var. examples

* attempt to fix documentation

* update rst

* add mean/covar flags

* add svgp model updating to toctree

* update docs
wjmaddox committed Jun 3, 2022
1 parent 1eb9dbd commit 52bf07a
Showing 17 changed files with 1,005 additions and 4 deletions.
476 changes: 476 additions & 0 deletions examples/08_Advanced_Usage/SVGP_Model_Updating.ipynb

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions examples/08_Advanced_Usage/index.rst
@@ -59,6 +59,14 @@ See the `1D derivatives GP example`_ or the `2D derivatives GP example`_ for exa
Simple_Batch_Mode_GP_Regression.ipynb:


Variational Fantasization
----------------------------------
We also include an example of how to perform fantasy modelling (i.e. efficient, closed-form updates) for variational
Gaussian process models, enabling their use in lookahead optimization.

.. _Variational fantasization:
SVGP_Model_Updating.ipynb

Converting Models to TorchScript
----------------------------------

@@ -73,3 +81,4 @@ how to convert both an exact GP and a variational GP to a ScriptModule that can

TorchScript_Exact_Models.ipynb
TorchScript_Variational_Models.ipynb
SVGP_Model_Updating.ipynb
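
The workflow the new notebook covers is roughly the following. This is an illustrative sketch, not part of the diff: `SVGPModel` stands in for any user-defined ApproximateGP that uses a standard VariationalStrategy with a CholeskyVariationalDistribution, carries mean_module / covar_module attributes, and has a GaussianLikelihood attached as model.likelihood.

import torch
import gpytorch

train_x = torch.randn(100, 2)
train_y = torch.sin(train_x.sum(-1))

# hypothetical user-defined ApproximateGP (not part of this diff)
model = SVGPModel(inducing_points=train_x[:16])
model.likelihood = gpytorch.likelihoods.GaussianLikelihood()
# ... optimize the ELBO as usual to fit the variational parameters ...

# condition on newly observed points without re-running variational inference;
# get_fantasy_model returns an ExactGP with updated prediction caches
new_x = torch.randn(5, 2)
new_y = torch.sin(new_x.sum(-1))
fantasy_model = model.get_fantasy_model(new_x, new_y)

fantasy_model.eval()
with torch.no_grad():
    lookahead_posterior = fantasy_model(torch.randn(3, 2))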
27 changes: 27 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -75,6 +75,33 @@ def pyro_model(self, input, beta=1.0, name_prefix=""):
"""
return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

def get_fantasy_model(self, inputs, targets, **kwargs):
r"""
Returns a new GP model that incorporates the specified inputs and targets as new training data using
online variational conditioning (OVC).
This function first casts the inducing points and variational parameters into pseudo-points before
returning an equivalent ExactGP model with a specialized likelihood.
.. note::
If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
are the same for each target batch.
:param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
observations.
:param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
:return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
and all test-time caches have been updated.
:rtype: ~gpytorch.models.ExactGP
Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
"""
return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

def __call__(self, inputs, prior=False, **kwargs):
if inputs.dim() == 1:
inputs = inputs.unsqueeze(-1)
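
As a hedged illustration (not part of the diff) of the batch semantics described in the docstring above, assuming `svgp` is a fitted ApproximateGP with mean_module, covar_module, and a GaussianLikelihood attached as svgp.likelihood:

fant_x = torch.randn(10, 2)    # m x d fantasy locations, shared across the target batch
fant_y = torch.randn(4, 10)    # b x m fantasy targets -> a batch-mode fantasy GP

batch_fantasy_gp = svgp.get_fantasy_model(fant_x, fant_y)
batch_fantasy_gp.eval()
with torch.no_grad():
    # the posterior is batched over the 4 fantasy target sets
    posterior = batch_fantasy_gp(torch.randn(25, 2))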
121 changes: 121 additions & 0 deletions gpytorch/test/variational_test_case.py
@@ -95,6 +95,28 @@ def _eval_iter(self, model, batch_shape=torch.Size([]), cuda=False):

        return output

    def _fantasy_iter(
        self,
        model,
        likelihood,
        batch_shape=torch.Size([]),
        cuda=False,
        num_fant=10,
        covar_module=None,
        mean_module=None,
    ):
        model.likelihood = likelihood
        val_x = torch.randn(*batch_shape, num_fant, 2).clamp(-2.5, 2.5)
        val_y = torch.linspace(-1, 1, num_fant)
        val_y = val_y.view(num_fant, *([1] * (len(self.event_shape) - 1)))
        val_y = val_y.expand(*batch_shape, num_fant, *self.event_shape[1:])
        if cuda:
            model = model.cuda()
            val_x = val_x.cuda()
            val_y = val_y.cuda()
        updated_model = model.get_fantasy_model(val_x, val_y, covar_module=covar_module, mean_module=mean_module)
        return updated_model

    @abstractproperty
    def batch_shape(self):
        raise NotImplementedError
@@ -272,3 +294,102 @@ def test_training_all_batch_zero_mean(self):
            expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape),
            constant_mean=False,
        )

    def test_fantasy_call(
        self,
        data_batch_shape=None,
        inducing_batch_shape=None,
        model_batch_shape=None,
        expected_batch_shape=None,
        constant_mean=True,
    ):
        # Batch shapes
        model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape
        data_batch_shape = data_batch_shape if data_batch_shape is not None else self.batch_shape
        inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape
        expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape

        num_inducing = 16
        num_fant = 10
        # Make model and likelihood
        model, likelihood = self._make_model_and_likelihood(
            batch_shape=model_batch_shape,
            inducing_batch_shape=inducing_batch_shape,
            distribution_cls=self.distribution_cls,
            strategy_cls=self.strategy_cls,
            constant_mean=constant_mean,
            num_inducing=num_inducing,
        )

        # we iterate through the covar and mean module possible settings
        covar_mean_options = [
            {"covar_module": None, "mean_module": None},
            {"covar_module": gpytorch.kernels.MaternKernel(), "mean_module": gpytorch.means.ZeroMean()},
        ]
        for cm_dict in covar_mean_options:
            fant_model = self._fantasy_iter(
                model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant, **cm_dict
            )
            self.assertTrue(isinstance(fant_model, gpytorch.models.ExactGP))

            # we check to ensure setting the covar_module and mean_modules are okay
            if cm_dict["covar_module"] is None:
                self.assertEqual(type(fant_model.covar_module), type(model.covar_module))
            else:
                self.assertNotEqual(type(fant_model.covar_module), type(model.covar_module))
            if cm_dict["mean_module"] is None:
                self.assertEqual(type(fant_model.mean_module), type(model.mean_module))
            else:
                self.assertNotEqual(type(fant_model.mean_module), type(model.mean_module))

            # now we check to ensure the shapes of the fantasy strategy are correct
            self.assertTrue(fant_model.prediction_strategy is not None)
            for key in fant_model.prediction_strategy._memoize_cache.keys():
                if key[0] == "mean_cache":
                    break
            mean_cache = fant_model.prediction_strategy._memoize_cache[key]
            self.assertEqual(mean_cache.shape, torch.Size([*expected_batch_shape, num_inducing + num_fant]))

        # we remove the mean_module and covar_module and check for errors
        del model.mean_module
        with self.assertRaises(ModuleNotFoundError):
            self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)

        model.mean_module = gpytorch.means.ZeroMean()
        del model.covar_module
        with self.assertRaises(ModuleNotFoundError):
            self._fantasy_iter(model, likelihood, data_batch_shape, self.cuda, num_fant=num_fant)

        # finally we check to ensure failure for a non-gaussian likelihood
        with self.assertRaises(NotImplementedError):
            self._fantasy_iter(
                model,
                gpytorch.likelihoods.BernoulliLikelihood(),
                data_batch_shape,
                self.cuda,
                num_fant=num_fant,
            )

    def test_fantasy_call_batch_inducing(self):
        return self.test_fantasy_call(
            model_batch_shape=(torch.Size([3]) + self.batch_shape),
            data_batch_shape=self.batch_shape,
            inducing_batch_shape=(torch.Size([3]) + self.batch_shape),
            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
        )

    def test_fantasy_call_batch_data(self):
        return self.test_fantasy_call(
            model_batch_shape=self.batch_shape,
            inducing_batch_shape=self.batch_shape,
            data_batch_shape=(torch.Size([3]) + self.batch_shape),
            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
        )

    def test_fantasy_call_batch_model(self):
        return self.test_fantasy_call(
            model_batch_shape=(torch.Size([3]) + self.batch_shape),
            inducing_batch_shape=self.batch_shape,
            data_batch_shape=self.batch_shape,
            expected_batch_shape=(torch.Size([3]) + self.batch_shape),
        )
180 changes: 179 additions & 1 deletion gpytorch/variational/_variational_strategy.py
@@ -1,21 +1,47 @@
#!/usr/bin/env python3

import functools
from abc import ABC, abstractproperty
from copy import deepcopy

import torch

from .. import settings
from ..distributions import Delta, MultivariateNormal
from ..likelihoods import GaussianLikelihood
from ..models import ExactGP
from ..module import Module
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import cached, clear_cache_hook
from ..utils.memoize import add_to_cache, cached, clear_cache_hook


class _BaseExactGP(ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
        super().__init__(train_inputs, train_targets, likelihood)
        self.mean_module = mean_module
        self.covar_module = covar_module

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return MultivariateNormal(mean, covar)


def _add_cache_hook(tsr, pred_strat):
    if tsr.grad_fn is not None:
        wrapper = functools.partial(clear_cache_hook, pred_strat)
        functools.update_wrapper(wrapper, clear_cache_hook)
        tsr.grad_fn.register_hook(wrapper)
    return tsr


class _VariationalStrategy(Module, ABC):
    """
    Abstract base class for all Variational Strategies.
    """

    has_fantasy_strategy = False

    def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True):
        super().__init__()

@@ -97,6 +123,158 @@ def kl_divergence(self):
        kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution)
        return kl_divergence

    @cached(name="amortized_exact_gp")
    def amortized_exact_gp(self, mean_module=None, covar_module=None):
        mean_module = self.model.mean_module if mean_module is None else mean_module
        covar_module = self.model.covar_module if covar_module is None else covar_module

        with torch.no_grad():
            # from here on down, we refer to the inducing points as pseudo_inputs
            pseudo_target_covar, pseudo_target_mean = self.pseudo_points
            pseudo_inputs = self.inducing_points.detach()
            if pseudo_inputs.ndim < pseudo_target_mean.ndim:
                pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape)
            # TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
            new_covar_module = deepcopy(covar_module)

            # update inducing mean if necessary
            pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs)

            inducing_exact_model = _BaseExactGP(
                pseudo_inputs,
                pseudo_target_mean,
                mean_module=deepcopy(mean_module),
                covar_module=new_covar_module,
                likelihood=deepcopy(self.model.likelihood),
            )

            # now fantasize around this model
            # as this model is new, we need to compute a posterior to construct the prediction strategy
            # which uses the likelihood pseudo caches
            faked_points = torch.randn(
                *pseudo_target_mean.shape[:-2],
                1,
                pseudo_inputs.shape[-1],
                device=pseudo_inputs.device,
                dtype=pseudo_inputs.dtype,
            )
            inducing_exact_model.eval()
            _ = inducing_exact_model(faked_points)

            # then we overwrite the likelihood to take into account the multivariate normal term
            pred_strat = inducing_exact_model.prediction_strategy
            pred_strat._memoize_cache = {}
            with torch.no_grad():
                updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar
                pred_strat.lik_train_train_covar = updated_lik_train_train_covar

            # do the mean cache because the mean cache doesn't solve against lik_train_train_covar
            train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs)
            train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1)
            mean_cache = updated_lik_train_train_covar.inv_matmul(train_labels_offset).squeeze(-1)
            mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy)
            add_to_cache(pred_strat, "mean_cache", mean_cache)
            # TODO: check to see if we need to do the covar_cache?

        inducing_exact_model.prediction_strategy = pred_strat
        return inducing_exact_model

    def pseudo_points(self):
        raise NotImplementedError("Each variational strategy must implement its own pseudo points method")

    def get_fantasy_model(
        self,
        inputs,
        targets,
        mean_module=None,
        covar_module=None,
        **kwargs,
    ):
        r"""
        Performs the online variational conditioning (OVC) strategy of Maddox et al., '21 to return
        an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
        points and targets.

        Currently, rather than directly updating the variational parameters (and inducing points), we
        return an ExactGP model instead of an updated variational GP model. This is done primarily for
        numerical stability.

        Unlike ExactGP's get_fantasy_model, this method accepts mean_module and covar_module arguments
        that allow explicit specification of the mean / covariance. We expect that the mean and covariance
        modules are either attributes of the model itself, called mean_module and covar_module respectively,
        or that they are passed into this method explicitly.

        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :param torch.nn.Module mean_module: torch module describing the mean function of the GP model. Optional if
            `mean_module` is already an attribute of the variational GP.
        :param torch.nn.Module covar_module: torch module describing the covariance function of the GP model. Optional
            if `covar_module` is already an attribute of the variational GP.
        :return: An `ExactGP` model with `k + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated. We assume that there are `k` inducing points in this
            variational GP. Note that we return an `ExactGP` rather than a variational GP.
        :rtype: ~gpytorch.models.ExactGP

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
        Maddox, Stanton, Wilson, NeurIPS, '21
        https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
        """

        # currently, we only support fantasization for CholeskyVariationalDistribution and
        # whitened / unwhitened variational strategies
        if not self.has_fantasy_strategy:
            raise NotImplementedError(
                "No fantasy model support for ",
                self.__name__,
                ". Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported.",
            )
        if not isinstance(self.model.likelihood, GaussianLikelihood):
            raise NotImplementedError(
                "No fantasy model support for ",
                self.model.likelihood,
                ". Only GaussianLikelihoods are currently supported.",
            )
        # we assume that either the user has given the model a mean_module and a covar_module
        # or that it will be passed into the get_fantasy_model function. we check for these.
        if mean_module is None:
            mean_module = getattr(self.model, "mean_module", None)
            if mean_module is None:
                raise ModuleNotFoundError(
                    "Either you must provide a mean_module as input to get_fantasy_model",
                    "or it must be an attribute of the model called mean_module.",
                )
        if covar_module is None:
            covar_module = getattr(self.model, "covar_module", None)
            if covar_module is None:
                # raise an error
                raise ModuleNotFoundError(
                    "Either you must provide a covar_module as input to get_fantasy_model",
                    "or it must be an attribute of the model called covar_module.",
                )

        # first we construct an exact model over the inducing points with the inducing covariance
        # matrix
        inducing_exact_model = self.amortized_exact_gp(mean_module=mean_module, covar_module=covar_module)

        # then we update this model by adding in the inputs and pseudo targets
        # finally we fantasize wrt targets
        fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs)
        fant_pred_strat = fantasy_model.prediction_strategy

        # first we update the lik_train_train_covar
        # do the mean cache again because the mean cache resets the likelihood forward
        train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs)
        train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1)
        fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition()
        mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1)
        mean_cache = _add_cache_hook(mean_cache, fant_pred_strat)
        add_to_cache(fant_pred_strat, "mean_cache", mean_cache)
        # TODO: should we update the covar_cache?

        fantasy_model.prediction_strategy = fant_pred_strat
        return fantasy_model

    def __call__(self, x, prior=False, **kwargs):
        # If we're in prior mode, then we're done!
        if prior:

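To illustrate the mean_module / covar_module handling described in the get_fantasy_model docstring of _variational_strategy.py above, here is a minimal sketch (not part of the diff); `new_x` and `new_y` are hypothetical fantasy data and `svgp` a variational GP whose likelihood is Gaussian.

# explicit modules are used when the variational model does not carry
# mean_module / covar_module attributes of its own
fantasy_gp = svgp.variational_strategy.get_fantasy_model(
    inputs=new_x,
    targets=new_y,
    mean_module=gpytorch.means.ConstantMean(),
    covar_module=gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()),
)

# omitting the modules when the model also lacks the attributes raises ModuleNotFoundError,
# and a non-Gaussian likelihood (e.g. BernoulliLikelihood) raises NotImplementedError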