
Add Fantasy Strategy for Variational GPs #1874

Merged · 18 commits · Jun 3, 2022

Conversation

@wjmaddox (Collaborator) commented Dec 27, 2021

Implements the online variational conditioning (OVC) procedure described in https://papers.nips.cc/paper/2021/hash/325eaeac5bef34937cfdc1bd73034d17-Abstract.html for Gaussian models. Broadly, this enables fantasy models and closed-form updates for variational GPs (especially SVGPs) as more data comes in, for example in the streaming setting or during lookaheads in Bayesian optimization.

Currently, this PR only supports the procedure for Gaussian responses (hooking up a Newton iteration strategy for arbitrary one-dimensional likelihoods would be a good follow-up PR).
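For concreteness, a rough usage sketch of what this enables. The SVGP model definition and data shapes below are generic gpytorch boilerplate, not code from this PR; the only signature taken from the PR is `get_fantasy_model(inputs, targets, **kwargs)`, and depending on your model you may need extra kwargs (e.g. an explicit `mean_module`/`covar_module`, discussed below):

```python
import torch
import gpytorch

class SVGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SVGPModel(inducing_points=torch.randn(16, 1))
# ... fit model + likelihood on the initial data with the variational ELBO as usual ...

# New observations arrive (streaming data, or a lookahead/fantasy point in Bayesian optimization)
new_x = torch.randn(4, 1)
new_y = torch.randn(4)

# Closed-form OVC update -- no re-running of the variational optimization.
# The returned fantasy model behaves like an exact GP and can be queried as usual.
fantasy_model = model.get_fantasy_model(new_x, new_y)
fantasy_model.eval()
with torch.no_grad():
    fantasy_posterior = fantasy_model(torch.linspace(0, 1, 10).unsqueeze(-1))
```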

ToDos:

  • test out current version of strategy
  • basic unit tests
  • tutorial notebook (edit: needs some word-smithing)
  • unit tests in examples
  • documentation build doesn't fail

cc @samuelstanton

@wjmaddox wjmaddox changed the title [Draft] Add Fantasy Strategy for Variational GPs Add Fantasy Strategy for Variational GPs Mar 2, 2022
@wjmaddox wjmaddox requested a review from gpleiss March 3, 2022 16:55
@gpleiss (Member) commented Mar 4, 2022

I'll review this early next week!

@gpleiss (Member) left a comment:

I don't love the idea of calling self.model.mean_module and self.model.covar_module in the fantasy model function - because it's not guaranteed that people will follow this naming convention. However, I'm not sure what the best solution is here either...

gpytorch/test/variational_test_case.py
if inducing_points.ndim < inducing_mean.ndim:
inducing_points = inducing_points.expand(*inducing_mean.shape[:-2], *inducing_points.shape)
# TODO: add flag for conditioning into SGPR after building fantasy strategy for SGPR
new_covar_module = deepcopy(self.model.covar_module)
Member commented:

This feels a bit brittle: it's not guaranteed that people use the self.covar_module and self.mean_module convention in their model.

inducing_exact_model = _BaseExactGP(
inducing_points,
inducing_mean,
mean_module=deepcopy(self.model.mean_module),
Member commented:

Same thing here.

Contributor commented:

Seems like this is playing the same role as this line

https://github.com/cornellius-gp/gpytorch/blob/master/gpytorch/models/exact_gp.py#L223

But here we want to copy some of the attributes of one class into a completely different class. Not sure there is a general way to do this without assuming the attribute names.

Contributor commented:

It would be good to add an informative error message if the attributes are missing.

@wjmaddox (Collaborator, Author) commented:

The solution we're going to try here is to require that the mean and covariance modules either be named mean_module / covar_module on the model or be passed in as kwargs, which has the added benefit of allowing some amount of fantasizing through updated hyperparameters.
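For example (a sketch of that kwarg fallback; `my_mean` and `my_kernel` are hypothetical stand-ins for whatever modules the model actually uses, and the kwarg names follow the error messages added later in this PR):

```python
# If the model does not store its modules under the conventional names
# `mean_module` / `covar_module`, pass them to get_fantasy_model explicitly.
fantasy_model = model.get_fantasy_model(
    new_x,
    new_y,
    mean_module=my_mean,     # hypothetical placeholder for the model's mean
    covar_module=my_kernel,  # hypothetical placeholder for the model's kernel
)
```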

@@ -76,6 +76,30 @@ def pyro_model(self, input, beta=1.0, name_prefix=""):
return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

def get_fantasy_model(self, inputs, targets, **kwargs):
Contributor commented:

I can see why it might be convenient to have get_fantasy_model always return an ExactGP regardless of the original model class, but it might be worth considering a different name for this method and reserving get_fantasy_model for a version of OVC that returns a variational GP (in other words, making a package-level decision that get_fantasy_model always returns an instance of the original class).

@wjmaddox (Collaborator, Author) commented Mar 8, 2022:

That was a thought I originally had, but returning the model's own class rather than an ExactGP requires the unstable direct updates to m and S. It would potentially be lower overhead for implementing new fantasization strategies in the future, though.

@samuelstanton (Contributor) commented:

The details of the PR look correct to me, and I think the implementation is really sensible (variational GP --> equivalent ExactGP --> usual ExactGP conditioning machinery). I think it could be made a lot clearer in the docs/variable names that 1) this transformation is happening and 2) the equivalent transform is something you can always compute from a variational GP (I could imagine it might have other uses besides efficient conditioning).
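For readers following the thread, here is a sketch of the standard Gaussian-conjugacy identity behind that "equivalent ExactGP" step. This is textbook algebra rather than the code path in this PR, and all names are illustrative: a variational posterior q(u) = N(m, S) with prior u ~ N(0, Kuu) can be rewritten as the posterior of an exact GP that observed pseudo-targets at the inducing locations under a pseudo-noise covariance.

```python
import torch

def equivalent_pseudo_observations(m, S, Kuu):
    # Sigma_tilde^{-1} = S^{-1} - Kuu^{-1},   y_tilde = Sigma_tilde @ S^{-1} @ m
    # Conditioning an exact GP on (Z, y_tilde) with noise covariance Sigma_tilde
    # recovers q(u) = N(m, S) as the posterior over the inducing values.
    # (Explicit inverses for readability only -- a real implementation would use solves.)
    Sigma_tilde = torch.linalg.inv(torch.linalg.inv(S) - torch.linalg.inv(Kuu))
    y_tilde = Sigma_tilde @ torch.linalg.solve(S, m)
    return y_tilde, Sigma_tilde
```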

@wjmaddox wjmaddox requested a review from gpleiss April 14, 2022 16:59
@wjmaddox (Collaborator, Author) commented:

I'm still not sure what the docs issue is.

@Balandat (Collaborator) commented:

It says

/home/docs/checkouts/readthedocs.org/user_builds/gpytorch/checkouts/1874/docs/source/examples/08_Advanced_Usage/SVGP_Model_Updating.ipynb: WARNING: document isn't included in any toctree

@gpleiss (Member) left a comment:

A few tiny doc string/error message fixes, but then I think it's good to go!

gpytorch/variational/_variational_strategy.py
if mean_module is None:
raise ModuleNotFoundError(
"Either you must provide a mean_module as input to get_fantasy_model",
"or it must be an attribute of the model.",
Member commented:

Could this be a little more explicit? "or it must be an attribute of the model named mean_module." or something like that.
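For example, something along these lines (a sketch of the suggested wording; concatenating into a single string also avoids passing the message to the exception as a tuple):

```python
if mean_module is None:
    raise ModuleNotFoundError(
        "Either you must provide a mean_module as input to get_fantasy_model, "
        "or it must be an attribute of the model named mean_module."
    )
```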

# raise an error
raise ModuleNotFoundError(
"Either you must provide a covar_module as input to get_fantasy_model",
"or it must be an attribute of the model.",
Member commented:

Same thing here "or it must be an attribute of the model named covar_module." or something like that.

@wjmaddox (Collaborator, Author) commented Jun 3, 2022

Sorry for taking a while to clean that up, but hopefully the wording is a lot more clear now.
