-
Notifications
You must be signed in to change notification settings - Fork 546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Fantasy Strategy for Variational GPs #1874
Changes from 16 commits
f7a9dda
176ab53
7d018ad
6c5b694
0e4ca31
dd78889
45d97dd
906b40b
65fc084
e26a029
7394641
cb3cce7
1787a73
46ceffb
1a47887
704da18
785688b
9bc2bc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,47 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import functools | ||
from abc import ABC, abstractproperty | ||
from copy import deepcopy | ||
|
||
import torch | ||
|
||
from .. import settings | ||
from ..distributions import Delta, MultivariateNormal | ||
from ..likelihoods import GaussianLikelihood | ||
from ..models import ExactGP | ||
from ..module import Module | ||
from ..utils.broadcasting import _mul_broadcast_shape | ||
from ..utils.memoize import cached, clear_cache_hook | ||
from ..utils.memoize import add_to_cache, cached, clear_cache_hook | ||
|
||
|
||
class _BaseExactGP(ExactGP): | ||
def __init__(self, train_inputs, train_targets, likelihood, mean_module, covar_module): | ||
super().__init__(train_inputs, train_targets, likelihood) | ||
self.mean_module = mean_module | ||
self.covar_module = covar_module | ||
|
||
def forward(self, x): | ||
mean = self.mean_module(x) | ||
covar = self.covar_module(x) | ||
return MultivariateNormal(mean, covar) | ||
|
||
|
||
def _add_cache_hook(tsr, pred_strat): | ||
if tsr.grad_fn is not None: | ||
wrapper = functools.partial(clear_cache_hook, pred_strat) | ||
functools.update_wrapper(wrapper, clear_cache_hook) | ||
tsr.grad_fn.register_hook(wrapper) | ||
return tsr | ||
|
||
|
||
class _VariationalStrategy(Module, ABC): | ||
""" | ||
Abstract base class for all Variational Strategies. | ||
""" | ||
|
||
has_fantasy_strategy = False | ||
|
||
def __init__(self, model, inducing_points, variational_distribution, learn_inducing_locations=True): | ||
super().__init__() | ||
|
||
|
@@ -97,6 +123,144 @@ def kl_divergence(self): | |
kl_divergence = torch.distributions.kl.kl_divergence(self.variational_distribution, self.prior_distribution) | ||
return kl_divergence | ||
|
||
@cached(name="amortized_exact_gp") | ||
def amortized_exact_gp(self, mean_module=None, covar_module=None): | ||
mean_module = self.model.mean_module if mean_module is None else mean_module | ||
covar_module = self.model.covar_module if covar_module is None else covar_module | ||
|
||
with torch.no_grad(): | ||
# from here on down, we refer to the inducing points as pseudo_inputs | ||
pseudo_target_covar, pseudo_target_mean = self.pseudo_points | ||
pseudo_inputs = self.inducing_points.detach() | ||
if pseudo_inputs.ndim < pseudo_target_mean.ndim: | ||
pseudo_inputs = pseudo_inputs.expand(*pseudo_target_mean.shape[:-2], *pseudo_inputs.shape) | ||
# TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR | ||
new_covar_module = deepcopy(covar_module) | ||
|
||
# update inducing mean if necessary | ||
pseudo_target_mean = pseudo_target_mean.squeeze() + mean_module(pseudo_inputs) | ||
|
||
inducing_exact_model = _BaseExactGP( | ||
pseudo_inputs, | ||
pseudo_target_mean, | ||
mean_module=deepcopy(mean_module), | ||
covar_module=new_covar_module, | ||
likelihood=deepcopy(self.model.likelihood), | ||
) | ||
|
||
# now fantasize around this model | ||
# as this model is new, we need to compute a posterior to construct the prediction strategy | ||
# which uses the likelihood pseudo caches | ||
faked_points = torch.randn( | ||
*pseudo_target_mean.shape[:-2], | ||
1, | ||
pseudo_inputs.shape[-1], | ||
device=pseudo_inputs.device, | ||
dtype=pseudo_inputs.dtype, | ||
) | ||
inducing_exact_model.eval() | ||
_ = inducing_exact_model(faked_points) | ||
|
||
# then we overwrite the likelihood to take into account the multivariate normal term | ||
pred_strat = inducing_exact_model.prediction_strategy | ||
pred_strat._memoize_cache = {} | ||
with torch.no_grad(): | ||
updated_lik_train_train_covar = pred_strat.train_prior_dist.lazy_covariance_matrix + pseudo_target_covar | ||
pred_strat.lik_train_train_covar = updated_lik_train_train_covar | ||
|
||
# do the mean cache because the mean cache doesn't solve against lik_train_train_covar | ||
train_mean = inducing_exact_model.mean_module(*inducing_exact_model.train_inputs) | ||
train_labels_offset = (inducing_exact_model.prediction_strategy.train_labels - train_mean).unsqueeze(-1) | ||
mean_cache = updated_lik_train_train_covar.inv_matmul(train_labels_offset).squeeze(-1) | ||
mean_cache = _add_cache_hook(mean_cache, inducing_exact_model.prediction_strategy) | ||
add_to_cache(pred_strat, "mean_cache", mean_cache) | ||
# TODO: check to see if we need to do the covar_cache? | ||
|
||
inducing_exact_model.prediction_strategy = pred_strat | ||
return inducing_exact_model | ||
|
||
def pseudo_points(self): | ||
raise NotImplementedError("Each variational strategy must implement its own pseudo points method") | ||
|
||
def get_fantasy_model( | ||
self, | ||
inputs, | ||
targets, | ||
mean_module=None, | ||
covar_module=None, | ||
**kwargs, | ||
): | ||
r""" | ||
Performs the online variational conditioning (OVC) strategy of Maddox et al, '21 to return | ||
an exact GP model that incorporates the inputs and targets alongside the variational model's inducing | ||
points and targets. | ||
|
||
Currently, instead of directly updating the variational parameters (and inducing points), we instead | ||
return an ExactGP model rather than an updated variational GP model. This is done primarily for | ||
numerical stability. | ||
|
||
Unlike the ExactGP's call for get_fantasy_model, we enable options for mean_module and covar_module | ||
that allow specification of the mean / covariance. | ||
|
||
gpleiss marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making," | ||
Maddox, Stanton, Wilson, NeurIPS, '21 | ||
https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html | ||
""" | ||
|
||
# currently, we only support fantasization for CholeskyVariationalDistribution and | ||
# whitened / unwhitened variational strategies | ||
if not self.has_fantasy_strategy: | ||
raise NotImplementedError( | ||
"No fantasy model support for ", | ||
self.__name__, | ||
". Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported.", | ||
) | ||
if not isinstance(self.model.likelihood, GaussianLikelihood): | ||
raise NotImplementedError( | ||
"No fantasy model support for ", | ||
self.model.likelihood, | ||
". Only GaussianLikelihoods are currently supported.", | ||
) | ||
# we assume that either the user has given the model a mean_module and a covar_module | ||
# or that it will be passed into the get_fantasy_model function. we check for these. | ||
if mean_module is None: | ||
mean_module = getattr(self.model, "mean_module", None) | ||
if mean_module is None: | ||
raise ModuleNotFoundError( | ||
"Either you must provide a mean_module as input to get_fantasy_model", | ||
"or it must be an attribute of the model.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could this be a little more explicit? "or it must be an attribute of the model named |
||
) | ||
if covar_module is None: | ||
covar_module = getattr(self.model, "covar_module", None) | ||
if covar_module is None: | ||
# raise an error | ||
raise ModuleNotFoundError( | ||
"Either you must provide a covar_module as input to get_fantasy_model", | ||
"or it must be an attribute of the model.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same thing here "or it must be an attribute of the model named |
||
) | ||
|
||
# first we construct an exact model over the inducing points with the inducing covariance | ||
# matrix | ||
inducing_exact_model = self.amortized_exact_gp(mean_module=mean_module, covar_module=covar_module) | ||
|
||
# then we update this model by adding in the inputs and pseudo targets | ||
# finally we fantasize wrt targets | ||
fantasy_model = inducing_exact_model.get_fantasy_model(inputs, targets, **kwargs) | ||
wjmaddox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fant_pred_strat = fantasy_model.prediction_strategy | ||
|
||
# first we update the lik_train_train_covar | ||
# do the mean cache again because the mean cache resets the likelihood forward | ||
train_mean = fantasy_model.mean_module(*fantasy_model.train_inputs) | ||
train_labels_offset = (fant_pred_strat.train_labels - train_mean).unsqueeze(-1) | ||
fantasy_lik_train_root_inv = fant_pred_strat.lik_train_train_covar.root_inv_decomposition() | ||
mean_cache = fantasy_lik_train_root_inv.matmul(train_labels_offset).squeeze(-1) | ||
mean_cache = _add_cache_hook(mean_cache, fant_pred_strat) | ||
add_to_cache(fant_pred_strat, "mean_cache", mean_cache) | ||
# TODO: should we update the covar_cache? | ||
|
||
fantasy_model.prediction_strategy = fant_pred_strat | ||
return fantasy_model | ||
|
||
def __call__(self, x, prior=False, **kwargs): | ||
# If we're in prior mode, then we're done! | ||
if prior: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can see why it might be convenient to always have
get_fantasy_model
return anExactGP
, regardless of the original model class, but it might be worth considering naming this something else, reservingget_fantasy_model
for the version of OVC that returns a variational GP (in other words make a package-level decision to require thatget_fantasy_model
always returns an instance of the original class).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was a thought I originally had but it requires the unstable direct updates to
m
,S
in order to return its own model class itself rather than the exactGP. Although potentially lower overhead in the future to implement new fantasization strategies.