Update methods of likelihood, small pyro fixes
- Forward now defines conditional distribution
- Marginal defines marginal
- __call__ does either marginal or conditional
- variational_log_probability -> expected_log_prob
- Default implementations of `expected_log_prob` and `marginal`
- Use proper plating for pyro
- Fix pyro_sample_y for GaussianLikelihood
gpleiss committed Mar 26, 2019
1 parent 1fd1916 commit cbfd111
Showing 14 changed files with 254 additions and 198 deletions.
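In short, the reworked API dispatches on input type: a tensor of function samples yields the conditional p(y|f), while a MultivariateNormal yields the marginal p(y|x). A minimal usage sketch, assuming a trained model `model`, test inputs `test_x`, and training targets `train_y` (hypothetical names):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood()

f_dist = model(test_x)                 # p(f|x) as a MultivariateNormal
y_marginal = likelihood(f_dist)        # marginal p(y|x)

f_samples = f_dist.rsample(torch.Size([10]))
y_conditional = likelihood(f_samples)  # conditional p(y|f)

# Replaces the deprecated variational_log_probability:
ell = likelihood.expected_log_prob(train_y, f_dist)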
10 changes: 9 additions & 1 deletion gpytorch/distributions/__init__.py
@@ -4,5 +4,13 @@
from .multivariate_normal import MultivariateNormal
from .multitask_multivariate_normal import MultitaskMultivariateNormal

# Get the set of distributions from either PyTorch or Pyro
try:
# If pyro is installed, use that set of base distributions
import pyro.distributions as base_distributions
except ImportError:
# Otherwise, use PyTorch
import torch.distributions as base_distributions

__all__ = ["Distribution", "MultivariateNormal", "MultitaskMultivariateNormal"]

__all__ = ["Distribution", "MultivariateNormal", "MultitaskMultivariateNormal", "base_distributions"]
55 changes: 25 additions & 30 deletions gpytorch/likelihoods/bernoulli_likelihood.py
100755 → 100644
@@ -1,11 +1,10 @@
#!/usr/bin/env python3

import torch
from torch.distributions import Bernoulli
from ..distributions import MultivariateNormal
from ..utils.quadrature import GaussHermiteQuadrature1D
from ..functions import log_normal_cdf, normal_cdf
from ..distributions import base_distributions
from ..functions import log_normal_cdf
from .likelihood import Likelihood
import warnings


class BernoulliLikelihood(Likelihood):
@@ -20,31 +19,27 @@ class BernoulliLikelihood(Likelihood):
p(Y=y|f)=\Phi(yf)
\end{equation*}
"""
def __init__(self):
super().__init__()
self.quadrature = GaussHermiteQuadrature1D()
def forward(self, function_samples, **kwargs):
output_probs = base_distributions.Normal(0, 1).cdf(function_samples)
return base_distributions.Bernoulli(probs=output_probs)

def forward(self, input):
if not isinstance(input, MultivariateNormal):
raise RuntimeError(
"BernoulliLikelihood expects a multi-variate normally distributed latent function to make predictions"
)

mean = input.mean
var = input.variance
def marginal(self, function_dist, **kwargs):
mean = function_dist.mean
var = function_dist.variance
link = mean.div(torch.sqrt(1 + var))
output_probs = normal_cdf(link)
return Bernoulli(probs=output_probs)

def variational_log_probability(self, latent_func, target):
likelihood_func = lambda locs: log_normal_cdf(locs.mul(target.unsqueeze(-1)))
res = self.quadrature(likelihood_func, latent_func)
return res.sum()

def pyro_sample_y(self, variational_dist_f, y_obs, sample_shape, name_prefix=""):
import pyro

f_samples = variational_dist_f(sample_shape)
y_prob_samples = torch.distributions.Normal(0, 1).cdf(f_samples)
y_dist = pyro.distributions.Bernoulli(y_prob_samples)
pyro.sample(name_prefix + "._training_labels", y_dist.independent(1), obs=y_obs)
output_probs = base_distributions.Normal(0, 1).cdf(link)
return base_distributions.Bernoulli(probs=output_probs)

def expected_log_prob(self, observations, function_dist, *params, **kwargs):
if torch.any(observations.eq(-1)):
warnings.warn(
"BernoulliLikelihood.expected_log_prob expects observations with labels in {0, 1}. "
"Observations with labels in {-1, 1} are deprecated.", DeprecationWarning
)
else:
observations = observations.mul(2).sub(1)
# Custom function here so we can use log_normal_cdf rather than Normal.cdf
# This is going to be less prone to overflow errors
log_prob_lambda = lambda function_samples: log_normal_cdf(function_samples.mul(observations))
log_prob = self.quadrature(log_prob_lambda, function_dist)
return log_prob.sum(tuple(range(len(function_dist.event_shape))))
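
For reference, the `marginal` method above uses the standard closed-form probit marginal: for f ~ N(\mu, \sigma^2),

\begin{equation*}
\int \Phi(f) \, \mathcal{N}(f \mid \mu, \sigma^2) \, df = \Phi\left(\frac{\mu}{\sqrt{1 + \sigma^2}}\right),
\end{equation*}

which is exactly what `mean.div(torch.sqrt(1 + var))` followed by the standard normal CDF computes.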
47 changes: 14 additions & 33 deletions gpytorch/likelihoods/gaussian_likelihood.py
@@ -5,8 +5,8 @@
from torch.nn.functional import softplus

from .. import settings
from ..distributions import MultivariateNormal
from ..likelihoods import Likelihood
from ..distributions import base_distributions
from .noise_models import HomoskedasticNoise


@@ -17,22 +17,24 @@ def __init__(self, noise_covar):
super().__init__()
self.noise_covar = noise_covar

def forward(self, input, *params):
if not isinstance(input, MultivariateNormal):
raise ValueError("Gaussian likelihoods require a MultivariateNormal input")
mean, covar = input.mean, input.lazy_covariance_matrix
def _shaped_noise_covar(self, base_shape, *params):
if len(params) > 0:
# we can infer the shape from the params
shape = None
else:
# here shape[:-1] is the batch shape requested, and shape[-1] is `n`, the number of points
shape = mean.shape
noise_covar = self.noise_covar(*params, shape=shape)
full_covar = covar + noise_covar
return input.__class__(mean, full_covar)
shape = base_shape
return self.noise_covar(*params, shape=shape)

def variational_log_probability(self, input, target):
raise NotImplementedError
def forward(self, function_samples, *params, **kwargs):
return base_distributions.Normal(
function_samples, self._shaped_noise_covar(function_samples.shape, *params).diag()
)

def marginal(self, function_dist, *params, **kwargs):
mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
full_covar = covar + self._shaped_noise_covar(mean.shape, *params)
return function_dist.__class__(mean, full_covar)


class GaussianLikelihood(_GaussianLikelihoodBase):
@@ -67,7 +69,7 @@ def raw_noise(self):
def raw_noise(self, value):
self.noise_covar.initialize(raw_noise=value)

def variational_log_probability(self, input, target):
def expected_log_prob(self, target, input, *params, **kwargs):
mean, variance = input.mean, input.variance
noise = self.noise_covar.noise

@@ -82,24 +84,3 @@ def variational_log_probability(self, input, target):
res = -0.5 * ((target - mean) ** 2 + variance) / noise
res += -0.5 * noise.log() - 0.5 * math.log(2 * math.pi)
return res.sum(-1)

def pyro_sample_y(self, variational_dist_f, y_obs, sample_shape, name_prefix=""):
import pyro

noise = self.noise
var_f = variational_dist_f.lazy_covariance_matrix.diag()
y_mean = variational_dist_f.mean
if y_mean.dim() == 1:
noise = noise.squeeze(0)

y_dist = pyro.distributions.Independent(
pyro.distributions.Normal(y_mean, (var_f + noise.expand_as(var_f)).sqrt()),
reinterpreted_batch_ndims=y_mean.dim(),
)

# See if we're using a sampled GP distribution
# Samples will occur in the first batch dimension
sample_shape = y_dist.shape()[:-y_obs.dim()]
y_obs = y_obs.expand(y_dist.shape())
with pyro.poutine.scale(scale=float(1. / sample_shape.numel())):
pyro.sample(name_prefix + "._training_labels", y_dist, obs=y_obs)
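
The closed-form `expected_log_prob` kept in `GaussianLikelihood` follows from integrating the Gaussian log density against q(f) = N(\mu, \sigma_f^2), using E[(y - f)^2] = (y - \mu)^2 + \sigma_f^2:

\begin{equation*}
\mathbb{E}_{q(f)} \left[ \log \mathcal{N}(y \mid f, \sigma^2) \right]
= -\frac{(y - \mu)^2 + \sigma_f^2}{2 \sigma^2} - \frac{1}{2} \log (2 \pi \sigma^2),
\end{equation*}

which matches the `res` terms above, summed over the data dimension.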
136 changes: 108 additions & 28 deletions gpytorch/likelihoods/likelihood.py
@@ -1,60 +1,140 @@
#!/usr/bin/env python3

import torch
from ..module import Module
from ..distributions import MultivariateNormal
from ..utils.quadrature import GaussHermiteQuadrature1D
from .. import settings
import warnings


class Likelihood(Module):
class _Likelihood(Module):
"""
A Likelihood in GPyTorch specifies the mapping from latent function values
f to observed labels y.
For example, in the case of regression this might be a Gaussian
distribution, as y(x) is equal to f(x) plus Gaussian noise:
y(x) = f(x) + \epsilon, \epsilon ~ N(0,\sigma^{2}_{n} I)
.. math::
y(x) = f(x) + \epsilon, \epsilon ~ N(0,\sigma^{2}_{n} I)
In the case of classification, this might be a Bernoulli distribution,
where the probability that y=1 is given by the latent function
passed through some sigmoid or probit function:
y(x) = 1 w/ probability \sigma(f(x)), -1 w/ probability 1-\sigma(f(x))
y(x) = 1 w/ probability \sigma(f(x)), -1 w/ probability 1-\sigma(f(x))
In either case, to implement a (non-Gaussian) likelihood function, GPyTorch
requires that two methods be implemented:
requires a :attr:`forward` method that computes the conditional distribution
:math:`p(y \mid f)`.
1. A forward method that computes predictions p(y*|x*) given a distribution
over the latent function p(f*|x*). Typically, this solves or
approximates the integral:
Calling this object does one of two things:
p(y*|x*) = \int p(y*|f*)p(f*|x*) df*
2. A variational_log_probability method that computes the log probability
\log p(y|f) from a set of samples of f. This is only used for variational
inference.
- If likelihood is called with a :class:`torch.Tensor` object, then it is
assumed that the input is samples from :math:`f(x)`. This
returns the *conditional* distribution `p(y|f(x))`.
- If likelihood is called with a :class:`gpytorch.distribution.MultivariateNormal` object,
then it is assumed that the input is the distribution :math:`f(x)`.
This returns the *marginal* distribution `p(y|x)`.
"""
def __init__(self):
super().__init__()
self.quadrature = GaussHermiteQuadrature1D()

def forward(self, *inputs, **kwargs):
def forward(self, function_samples, *params, **kwargs):
"""
Computes a predictive distribution p(y*|x*) given either a posterior
distribution p(f|D,x) or a prior distribution p(f|x) as input.
Computes the conditional distribution p(y|f) that defines the likelihood.
With both exact inference and variational inference, the form of
p(f|D,x) or p(f|x) should usually be Gaussian. As a result, input
should usually be a MultivariateNormal specified by the mean and
(co)variance of p(f|...).
Args:
:attr:`function_samples`
Samples from the function `f`
:attr:`kwargs`
Returns:
Distribution object (with same shape as :attr:`function_samples`)
"""
raise NotImplementedError

def variational_log_probability(self, f, y):
def expected_log_prob(self, observations, function_dist, *params, **kwargs):
"""
Compute the log likelihood p(y|f) given y and averaged over a set of
latent function samples.
Computes the expected log likelihood (used for variational inference):
For the purposes of our variational inference implementation, y is an
n-by-1 label vector, and f is an n-by-s matrix of s samples from the
variational posterior, q(f|D).
.. math::
\mathbb{E}_{f(x)} \left[ \log p \left( y \mid f(x) \right) \right]
Args:
:attr:`function_dist` (:class:`gpytorch.distributions.MultivariateNormal`)
Distribution for :math:`f(x)`.
:attr:`observations` (:class:`torch.Tensor`)
Values of :math:`y`.
:attr:`kwargs`
Returns
`torch.Tensor` (log probability)
"""
raise NotImplementedError
log_prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations)
log_prob = self.quadrature(log_prob_lambda, function_dist)
return log_prob.sum(tuple(range(len(function_dist.event_shape))))

def pyro_sample_y(self, variational_dist_f, y_obs, sample_shape, name_prefix=""):
raise NotImplementedError
def marginal(self, function_dist, *params, **kwargs):
"""
Computes a predictive distribution :math:`p(y*|x*)` given either a posterior
distribution :math:`p(f|D,x)` or a prior distribution :math:`p(f|x)` as input.
With both exact inference and variational inference, the form of
:math:`p(f|D,x)` or :math:`p(f|x)` should usually be Gaussian. As a result, input
should usually be a MultivariateNormal specified by the mean and
(co)variance of :math:`p(f|...)`.
Args:
:attr:`function_dist` (:class:`gpytorch.distributions.MultivariateNormal`)
Distribution for :math:`f(x)`.
:attr:`kwargs`
Returns
Distribution object (the marginal distribution, or samples from it)
"""
sample_shape = torch.Size((settings.num_likelihood_samples,))
function_samples = function_dist.rsample(sample_shape)
return self.forward(function_samples)

def variational_log_probability(self, function_dist, observations):
warnings.warn(
"Likelihood.variational_log_probability is deprecated. Use Likelihood.expected_log_prob instead.",
DeprecationWarning
)
return self.expected_log_prob(observations, function_dist)

def __call__(self, input, *params, **kwargs):
# Conditional
if torch.is_tensor(input):
return super().__call__(input, *params, **kwargs)
# Marginal
elif isinstance(input, MultivariateNormal):
return self.marginal(input, *params, **kwargs)
# Error
else:
raise RuntimeError(
"Likelihoods expects a MultivariateNormal input to make marginal predictions, or a "
"torch.Tensor for conditional predictions. Got a {}".format(input.__class__.__name__)
)


try:
import pyro

class Likelihood(_Likelihood):
def pyro_sample_outputs(self, observations, function_dist, *params, **kwargs):
name_prefix = kwargs.pop("name_prefix", "")
with pyro.plate(name_prefix + ".output_values_plate", function_dist.batch_shape[-1], dim=-1):
with pyro.poutine.block():
function_samples = pyro.sample(name_prefix + ".function_values", function_dist)
output_dist = self(function_samples, *params, **kwargs)
samples = pyro.sample(name_prefix + ".output_values", output_dist, obs=observations)
return samples

except ImportError:
class Likelihood(_Likelihood):
def pyro_sample_outputs(self, *args, **kwargs):
raise ImportError("Failed to import pyro. Is it installed correctly?")
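
The default `expected_log_prob` in `_Likelihood` approximates the expectation with Gauss-Hermite quadrature. A standalone sketch of the underlying 1-D rule (independent of GPyTorch's `GaussHermiteQuadrature1D` internals; `log_prob` here is any callable):

import math
import numpy as np

def expected_log_prob_1d(log_prob, mu, sigma, n_points=20):
    # E_{f ~ N(mu, sigma^2)}[log_prob(f)] via Gauss-Hermite quadrature:
    # substituting f = mu + sqrt(2) * sigma * x turns the Gaussian
    # expectation into (1/sqrt(pi)) * sum_i w_i * log_prob(f_i).
    xs, ws = np.polynomial.hermite.hermgauss(n_points)
    fs = mu + math.sqrt(2.0) * sigma * xs
    return float(np.dot(ws, [log_prob(f) for f in fs]) / math.sqrt(math.pi))

# e.g. the expected Gaussian log-density term for y = 1.0 under f ~ N(0, 1):
# expected_log_prob_1d(lambda f: -0.5 * (1.0 - f) ** 2, 0.0, 1.0)  # ~ -1.0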
16 changes: 14 additions & 2 deletions gpytorch/likelihoods/likelihood_list.py
@@ -17,5 +17,17 @@ def __init__(self, *likelihoods):
super().__init__()
self.likelihoods = ModuleList(likelihoods)

def forward(self, *args):
return [likelihood.forward(*args_) for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))]
def expected_log_prob(self, *args, **kwargs):
return [
likelihood.expected_log_prob(*args_, **kwargs)
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
]

def pyro_sample_outputs(self, *args, **kwargs):
return [
likelihood.pyro_sample_outputs(*args_, **kwargs)
for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))
]

def __call__(self, *args, **kwargs):
return [likelihood(*args_, **kwargs) for likelihood, args_ in zip(self.likelihoods, _get_tuple_args_(*args))]
