Add Fantasy Strategy for Variational GPs #1874

Merged: 18 commits, Jun 3, 2022
Changes from 1 commit
24 changes: 24 additions & 0 deletions gpytorch/models/approximate_gp.py
@@ -76,6 +76,30 @@ def pyro_model(self, input, beta=1.0, name_prefix=""):
return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

def get_fantasy_model(self, inputs, targets, **kwargs):
samuelstanton (Contributor) commented on Mar 8, 2022:

I can see why it might be convenient to always have get_fantasy_model return an ExactGP, regardless of the original model class, but it might be worth considering naming this something else and reserving get_fantasy_model for the version of OVC that returns a variational GP (in other words, make a package-level decision to require that get_fantasy_model always returns an instance of the original class).

wjmaddox (Collaborator, Author) replied on Mar 8, 2022:

That was a thought I originally had, but it requires the unstable direct updates to m, S in order to return an instance of its own model class rather than an ExactGP. Although it would potentially be lower overhead in the future for implementing new fantasization strategies.

r"""
Returns a new GP model that incorporates the specified inputs and targets as new training data using
online variational conditioning (OVC).

This function first casts the inducing points and variational parameters into pseudo-points before
returning an equivalent ExactGP model with a specialized likelihood.

.. note::
If `targets` is a batch (e.g. `b x m`), then the GP returned from this method will be a batch mode GP.
If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
are the same for each target batch.

:param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
observations.
:param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
:return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
and all test-time caches have been updated.
:rtype: ~gpytorch.models.ExactGP

Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html

"""
return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

def __call__(self, inputs, prior=False, **kwargs):
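For anyone reading this diff, here is a minimal usage sketch (not part of this PR) of how the new method would be called from a standard SVGP. The model definition below is ordinary GPyTorch boilerplate and purely illustrative, and additional kwargs (e.g. a likelihood) may be needed depending on the variational strategy:

```python
# Illustrative sketch only, not code from this PR.
import torch
import gpytorch


class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


model = SVGPModel(inducing_points=torch.randn(20, 1))
likelihood = gpytorch.likelihoods.GaussianLikelihood()
# ... fit model and likelihood with gpytorch.mlls.VariationalELBO as usual ...

new_x, new_y = torch.randn(5, 1), torch.randn(5)
# Per the docstring above, this returns an ExactGP (not an ApproximateGP) whose
# training data are the pseudo-points derived from the variational posterior
# plus the 5 new fantasy observations.
fantasy_model = model.get_fantasy_model(new_x, new_y)
```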
9 changes: 5 additions & 4 deletions gpytorch/variational/_variational_strategy.py
@@ -13,9 +13,6 @@
from ..utils.broadcasting import _mul_broadcast_shape
from ..utils.memoize import add_to_cache, cached, clear_cache_hook

# from gpytorch.variational.cholesky_variational_distribution import CholeskyVariationalDistribution
# from gpytorch.variational.variational_strategy import VariationalStrategy


class _BaseExactGP(ExactGP):
def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module):
@@ -188,7 +185,11 @@ def get_fantasy_model(
targets,
**kwargs,
):
"""
r"""
Performs the online variational conditioning (OVC) strategy of Maddox et al., '21 to return
an exact GP model that incorporates the inputs and targets alongside the variational model's inducing
points and targets.

Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
Maddox, Stanton, Wilson, NeurIPS, '21
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html
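As background for the docstring above, the Gaussian-conditioning identity behind the pseudo-point construction (as I read the paper) can be sketched as follows. This is expository code under my own assumptions, not this PR's implementation, and it uses naive matrix inverses that the paper and the actual code avoid for numerical stability:

```python
# Expository sketch, not this PR's implementation. If q(u) = N(m, S) is the
# variational posterior over the inducing values with prior u ~ N(0, Kuu), then
# exact conditioning on pseudo-targets y_tilde observed with Gaussian noise
# covariance Sigma_tilde reproduces q(u) when
#   Sigma_tilde = (S^{-1} - Kuu^{-1})^{-1}   and   y_tilde = Sigma_tilde @ S^{-1} @ m.
import torch


def pseudo_observations(m, S, Kuu):
    S_inv = torch.linalg.inv(S)
    Kuu_inv = torch.linalg.inv(Kuu)
    sigma_tilde = torch.linalg.inv(S_inv - Kuu_inv)  # pseudo noise covariance
    y_tilde = sigma_tilde @ S_inv @ m                # pseudo targets
    return y_tilde, sigma_tilde
```

An ExactGP whose training data are the inducing locations with targets y_tilde, paired with a fixed-noise Gaussian likelihood of covariance Sigma_tilde, then matches the variational predictive, which is presumably the "specialized likelihood" referred to in the approximate_gp.py docstring above.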