-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update methods of likelihood, small pyro fixes
- Forward now defines conditional distribution - Marginal defines marginal - __call__ does either marginal or conditional - variational_log_prob -> expected_log_prob - Default implementations - Use proper plating for pyro - Fix pyro_sample_y for GaussianLikelihood
- Loading branch information
Showing
14 changed files
with
254 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,140 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import torch | ||
from ..module import Module | ||
from ..distributions import MultivariateNormal | ||
from ..utils.quadrature import GaussHermiteQuadrature1D | ||
from .. import settings | ||
import warnings | ||
|
||
|
||
class _Likelihood(Module):
    r"""
    A Likelihood in GPyTorch specifies the mapping from latent function values
    :math:`f` to observed labels :math:`y`.

    For example, in the case of regression this might be a Gaussian
    distribution, as :math:`y(x)` is equal to :math:`f(x)` plus Gaussian noise:

    .. math::
        y(x) = f(x) + \epsilon, \epsilon ~ N(0,\sigma^{2}_{n} I)

    In the case of classification, this might be a Bernoulli distribution,
    where the probability that :math:`y=1` is given by the latent function
    passed through some sigmoid or probit function:

    .. math::
        y(x) = 1 w/ probability \sigma(f(x)), -1 w/ probability 1-\sigma(f(x))

    In either case, to implement a (non-Gaussian) likelihood function, GPyTorch
    requires a :attr:`forward` method that computes the conditional distribution
    :math:`p(y \mid f)`.

    Calling this object does one of two things:

    - If likelihood is called with a :class:`torch.Tensor` object, then it is
      assumed that the input is samples from :math:`f(x)`. This
      returns the *conditional* distribution `p(y|f(x))`.
    - If likelihood is called with a :class:`gpytorch.distribution.MultivariateNormal` object,
      then it is assumed that the input is the distribution :math:`f(x)`.
      This returns the *marginal* distribution `p(y|x)`.
    """

    def __init__(self):
        super().__init__()
        # 1D Gauss-Hermite quadrature, used by the default expected_log_prob
        # implementation to approximate E_{f(x)}[log p(y | f(x))].
        self.quadrature = GaussHermiteQuadrature1D()

    def forward(self, function_samples, *params, **kwargs):
        r"""
        Computes the conditional distribution :math:`p(y \mid f)` that defines
        the likelihood.

        Args:
            :attr:`function_samples` (:class:`torch.Tensor`)
                Samples from the latent function :math:`f`.
            :attr:`kwargs`

        Returns:
            Distribution object (with same shape as :attr:`function_samples`).

        Raises:
            NotImplementedError: must be implemented by subclasses.
        """
        raise NotImplementedError

    def expected_log_prob(self, observations, function_dist, *params, **kwargs):
        r"""
        Computes the expected log likelihood (used for variational inference):

        .. math::
            \mathbb{E}_{f(x)} \left[ \log p \left( y \mid f(x) \right) \right]

        Args:
            :attr:`observations` (:class:`torch.Tensor`)
                Values of :math:`y`.
            :attr:`function_dist` (:class:`gpytorch.distributions.MultivariateNormal`)
                Distribution for :math:`f(x)`.
            :attr:`kwargs`

        Returns:
            `torch.Tensor` (log probability)
        """
        # Default implementation: approximate the expectation by Gauss-Hermite
        # quadrature over the (Gaussian) function distribution.
        def log_prob_lambda(function_samples):
            return self.forward(function_samples).log_prob(observations)

        log_prob = self.quadrature(log_prob_lambda, function_dist)
        # Sum out the event dimensions so a scalar-per-batch log prob remains.
        return log_prob.sum(tuple(range(len(function_dist.event_shape))))

    def marginal(self, function_dist, *params, **kwargs):
        r"""
        Computes a predictive distribution :math:`p(y*|x*)` given either a posterior
        distribution :math:`p(f|D,x)` or a prior distribution :math:`p(f|x)` as input.

        With both exact inference and variational inference, the form of
        :math:`p(f|D,x)` or :math:`p(f|x)` should usually be Gaussian. As a result, input
        should usually be a MultivariateNormal specified by the mean and
        (co)variance of :math:`p(f|...)`.

        Args:
            :attr:`function_dist` (:class:`gpytorch.distributions.MultivariateNormal`)
                Distribution for :math:`f(x)`.
            :attr:`kwargs`

        Returns:
            Distribution object (the marginal distribution, or samples from it).
        """
        # Default implementation: Monte-Carlo marginalization through forward().
        # NOTE(review): assumes settings.num_likelihood_samples is a plain int;
        # if it is a settings context-manager class, .value() is needed — confirm.
        sample_shape = torch.Size((settings.num_likelihood_samples,))
        function_samples = function_dist.rsample(sample_shape)
        return self.forward(function_samples)

    def variational_log_probability(self, function_dist, observations):
        """Deprecated alias of :meth:`expected_log_prob` (note the swapped argument order)."""
        warnings.warn(
            "Likelihood.variational_log_probability is deprecated. Use Likelihood.expected_log_prob instead.",
            DeprecationWarning
        )
        return self.expected_log_prob(observations, function_dist)

    def __call__(self, input, *params, **kwargs):
        # Conditional: a Tensor input is interpreted as samples of f(x).
        if torch.is_tensor(input):
            return super().__call__(input, *params, **kwargs)
        # Marginal: a MultivariateNormal input is interpreted as the distribution f(x).
        elif isinstance(input, MultivariateNormal):
            return self.marginal(input, *params, **kwargs)
        # Anything else is a caller error.
        else:
            raise RuntimeError(
                "Likelihoods expects a MultivariateNormal input to make marginal predictions, or a "
                "torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
            )
|
||
|
||
try:
    import pyro

    class Likelihood(_Likelihood):
        """Likelihood with pyro integration (pyro is installed)."""

        def pyro_sample_outputs(self, observations, function_dist, *params, **kwargs):
            """
            Sample outputs y from the likelihood inside a pyro model/guide,
            conditioning on :attr:`observations` when given.

            Args:
                :attr:`observations` (:class:`torch.Tensor` or None)
                    Observed values of y (passed to pyro.sample as `obs`).
                :attr:`function_dist`
                    Distribution for the latent function values f(x).
                :attr:`kwargs`
                    May contain `name_prefix` (str) used to namespace pyro sites.

            Returns:
                Samples of the outputs y.
            """
            name_prefix = kwargs.pop("name_prefix", "")
            # Plate over the output dimension so pyro treats outputs as
            # conditionally independent given f(x).
            # NOTE(review): block nesting reconstructed from a garbled diff —
            # confirm that output sampling happens inside the plate but outside
            # the poutine.block (which hides the latent function sample site).
            with pyro.plate(name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1):
                with pyro.poutine.block():
                    function_samples = pyro.sample(name_prefix + ".function_values", function_dist)
                output_dist = self(function_samples, *params, **kwargs)
                samples = pyro.sample(name_prefix + ".output_values", output_dist, obs=observations)
            return samples

except ImportError:
    class Likelihood(_Likelihood):
        """Fallback Likelihood used when pyro is not installed."""

        def pyro_sample_outputs(self, *args, **kwargs):
            # Fail loudly only when pyro functionality is actually requested.
            raise ImportError("Failed to import pyro. Is it installed correctly?")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.