Commit

New base class, clean up, fix MTGLikelihood
Balandat committed Oct 30, 2018
1 parent cda00da commit 529456c
Showing 10 changed files with 106 additions and 60 deletions.
3 changes: 2 additions & 1 deletion gpytorch/functions/__init__.py
@@ -73,8 +73,8 @@ def exact_predictive_mean(full_covar, full_mean, train_inputs, train_labels, num
Args:
- full_covar ( (n+t) x (n+t) ) - the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]
- train_inputs TODO
- full_mean (n + t) - the training and test prior means, stacked on top of each other
- train_inputs TODO
- train_labels (n) - the training labels minus the training prior mean
- noise (1) - the observed noise (from the likelihood)
- precomputed_cache - speeds up subsequent computations (default: None)
@@ -99,6 +99,7 @@ def exact_predictive_covar(full_covar, train_inputs, num_train, likelihood, prec
Args:
- full_covar ( (n+t) x (n+t) ) - the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]
- train_inputs TODO
- num_train (int) - how many training points are there in the full covariance matrix
- noise (1) - the observed noise (from the likelihood)
- precomputed_cache - speeds up subsequent computations (default: None)
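
For reference, a minimal dense-tensor sketch of what these two helpers compute once the LazyTensor and precomputed-cache machinery is stripped away. The function name and the use of torch.linalg.solve are illustrative, not part of this codebase:

import torch

def dense_exact_predictive(full_mean, full_covar, train_labels, num_train, noise):
    # Split the stacked prior mean and the block covariance
    # [ K_XX, K_XX*; K_X*X, K_X*X* ] into train/test pieces.
    n = num_train
    test_mean = full_mean[n:]
    K_XX = full_covar[:n, :n]
    K_sX = full_covar[n:, :n]
    K_ss = full_covar[n:, n:]
    train_train = K_XX + noise * torch.eye(n)  # K_XX + sigma^2 * I
    # train_labels already has the training prior mean subtracted (see Args).
    pred_mean = test_mean + K_sX @ torch.linalg.solve(train_train, train_labels)
    pred_covar = K_ss - K_sX @ torch.linalg.solve(train_train, K_sX.T)
    return pred_mean, pred_covar
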
12 changes: 6 additions & 6 deletions gpytorch/lazy/diag_lazy_tensor.py
@@ -64,12 +64,6 @@ def evaluate(self):
else:
return super(DiagLazyTensor, self).evaluate()

def exp(self):
return DiagLazyTensor(self._diag.exp())

def sqrt(self):
return DiagLazyTensor(self._diag.sqrt())

def sum_batch(self, sum_batch_size=None):
if sum_batch_size is None:
diag = self._diag.view(-1, self._diag.size(-1))
@@ -78,6 +72,12 @@ def sum_batch(self, sum_batch_size=None):

return self.__class__(diag.sum(-2))

def exp(self):
return DiagLazyTensor(self._diag.exp())

def sqrt(self):
return DiagLazyTensor(self._diag.sqrt())

def zero_mean_mvn_samples(self, num_samples):
if self.ndimension() == 3:
base_samples = torch.randn(
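
The exp() and sqrt() methods above were only moved within the file; their behaviour is unchanged. As a usage sketch (the values and the evaluate() call are purely illustrative), both act entry-wise on the stored diagonal, which is how the likelihoods turn a lazy log-noise diagonal into a noise covariance without materialising an n x n matrix:

import torch
from gpytorch.lazy import DiagLazyTensor

log_noise = torch.tensor([-2.0, -1.0, 0.0])
noise_covar = DiagLazyTensor(log_noise).exp()  # still lazy: diag(exp(log_noise))
print(noise_covar.evaluate())
# tensor([[0.1353, 0.0000, 0.0000],
#         [0.0000, 0.3679, 0.0000],
#         [0.0000, 0.0000, 1.0000]])
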
11 changes: 11 additions & 0 deletions gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -47,6 +47,17 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
def _transpose_nonbatch(self):
return self.__class__(self.kernel, self.x2, self.x1, **self.params)

def _batch_get_indices(self, batch_indices, left_indices, right_indices):
from ..kernels import Kernel

x1 = self.x1[batch_indices, left_indices, :].unsqueeze(0)
x2 = self.x2[batch_indices, right_indices, :].unsqueeze(0)
res = super(Kernel, self.kernel).__call__(x1.transpose(-1, -2), x2.transpose(-1, -2))
if isinstance(res, LazyTensor):
res = res.evaluate()
res = res.view(-1)
return res

def _get_indices(self, left_indices, right_indices):
from ..kernels import Kernel

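
A dense sketch of what the new _batch_get_indices provides: individual kernel entries k(x1[b, i], x2[b, j]) gathered from a batched, lazily evaluated kernel matrix without computing the full matrix. The rbf helper and all tensor values here are illustrative assumptions:

import torch

def rbf(a, b, lengthscale=1.0):
    # Toy stand-in for a gpytorch kernel, applied pointwise.
    return torch.exp(-((a - b) ** 2).sum(-1) / (2 * lengthscale ** 2))

x1 = torch.randn(4, 10, 3)  # batch x n x d
x2 = torch.randn(4, 10, 3)
batch_idx = torch.tensor([0, 1, 3])
left_idx = torch.tensor([2, 5, 7])
right_idx = torch.tensor([1, 1, 9])

# One scalar per (batch, left, right) index triple, as the lazy path returns.
entries = rbf(x1[batch_idx, left_idx], x2[batch_idx, right_idx])
assert entries.shape == (3,)
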
1 change: 1 addition & 0 deletions gpytorch/lazy/lazy_tensor.py
@@ -12,6 +12,7 @@
from .. import beta_features, settings
from .lazy_tensor_representation_tree import LazyTensorRepresentationTree

from IPython.core.debugger import set_trace

class LazyTensor(object):
"""
12 changes: 9 additions & 3 deletions gpytorch/likelihoods/__init__.py
@@ -1,19 +1,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals

from .bernoulli_likelihood import BernoulliLikelihood
from .gaussian_likelihood import GaussianLikelihood, HomoskedasticGaussianLikelihood
from .gaussian_likelihood import _GaussianLikelihoodBase, GaussianLikelihood
from .likelihood import Likelihood
from .multitask_gaussian_likelihood import MultitaskGaussianLikelihood
from .multitask_gaussian_likelihood import (
_MultitaskGaussianLikelihoodBase,
MultitaskGaussianLikelihood,
MultitaskGaussianLikelihood_Kronecker,
)
from .noise_models import HeteroskedasticNoise
from .softmax_likelihood import SoftmaxLikelihood


__all__ = [
"_GaussianLikelihoodBase",
"_MultitaskGaussianLikelihoodBase",
"BernoulliLikelihood",
"GaussianLikelihood",
"HeteroskedasticNoise",
"HomoskedasticGaussianLikelihood",
"Likelihood",
"MultitaskGaussianLikelihood",
"MultitaskGaussianLikelihood_Kronecker",
"SoftmaxLikelihood",
]
29 changes: 9 additions & 20 deletions gpytorch/likelihoods/gaussian_likelihood.py
@@ -10,43 +10,32 @@
from .likelihood import Likelihood
from .noise_models import HomoskedasticNoise

DEPRECATION_WARNING = "'GaussianLikelihood' was renamed to 'HomoskedasticGaussianLikelihood'"


class GaussianLikelihood(Likelihood):
def __init__(self, *args, **kwargs):
if len(args) + len(kwargs) == 0 or "log_noise_prior" in kwargs or "batch_size" in kwargs:
warnings.warn(DEPRECATION_WARNING, DeprecationWarning)
logging.warning(DEPRECATION_WARNING)
self.__init__(log_noise_covar=HomoskedasticNoise(*args, **kwargs))
self._is_homoskedastic = True
else:
super(GaussianLikelihood, self).__init__()
self.log_noise_covar = args[0] if len(kwargs) == 0 else kwargs["log_noise_covar"]
class _GaussianLikelihoodBase(Likelihood):
def __init__(self, log_noise_covar):
super(_GaussianLikelihoodBase, self).__init__()
self.log_noise_covar = log_noise_covar

def forward(self, input, *params):
if not isinstance(input, MultivariateNormal):
raise ValueError("GaussianLikelihood requires a MultivariateNormal input")
raise ValueError("Gaussian Likelihoods require a MultivariateNormal input")
mean, covar = input.mean, input.lazy_covariance_matrix
log_noise_covar = self.log_noise_covar(*params)
if isinstance(log_noise_covar, DiagLazyTensor):
full_covar = AddedDiagLazyTensor(covar, log_noise_covar.exp())
else:
# TODO: Deal with non-diagonal noise covariance models
# TODO: Properly deal with non-diagonal noise covariance models
full_covar = covar + log_noise_covar.exp()
return input.__class__(mean, full_covar)

def variational_log_probability(self, input, target):
if hasattr(self, "_is_homoskedastic"):
return HomoskedasticGaussianLikelihood.variational_log_probability(self, input, target)
else:
raise NotImplementedError
raise NotImplementedError


class HomoskedasticGaussianLikelihood(GaussianLikelihood):
class GaussianLikelihood(_GaussianLikelihoodBase):
def __init__(self, log_noise_prior=None, batch_size=1):
log_noise_covar = HomoskedasticNoise(log_noise_prior=log_noise_prior, batch_size=batch_size)
super(HomoskedasticGaussianLikelihood, self).__init__(log_noise_covar=log_noise_covar)
super(GaussianLikelihood, self).__init__(log_noise_covar=log_noise_covar)

def variational_log_probability(self, input, target):
mean, variance = input.mean, input.variance
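
The net effect of this refactor, sketched below (ignoring the variational_log_probability override): the new base class accepts any log-noise covariance module, and GaussianLikelihood is now just the homoskedastic special case. The construction shown is a simplification, not a drop-in replacement:

from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import HomoskedasticNoise

# Roughly what GaussianLikelihood() now does under the hood:
likelihood = _GaussianLikelihoodBase(log_noise_covar=HomoskedasticNoise())

# Swapping in a different noise module (e.g. HeteroskedasticNoise from this
# same commit's noise_models.py) yields a heteroskedastic likelihood without
# changing the base class.
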
72 changes: 52 additions & 20 deletions gpytorch/likelihoods/multitask_gaussian_likelihood.py
@@ -12,10 +12,8 @@
NonLazyTensor,
RootLazyTensor,
)
from ..likelihoods import GaussianLikelihood, Likelihood


DEPRECATION_WARNING = "'MultitaskGaussianLikelihood' was renamed to 'HomoskedasticMultitaskGaussianLikelihood'"
from ..likelihoods import _GaussianLikelihoodBase, Likelihood
from .noise_models import MultitaskHomoskedasticNoise


def _eval_covar_matrix(task_noise_covar_factor, log_noise):
@@ -31,7 +29,7 @@ def _eval_corr_matrix(task_noise_corr_factor):
return M * dsqrtinv.unsqueeze(-1).matmul(dsqrtinv.unsqueeze(0))


class MultitaskGaussianLikelihood(GaussianLikelihood):
class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase):
"""
A convenient extension of the :class:`gpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank `rank`.
@@ -46,16 +44,19 @@ def __init__(self, num_tasks, log_noise_covar, rank=0, task_correlation_prior=No
Args:
num_tasks (int): Number of tasks.
log_noise_covar TODO
log_noise_covar (:obj:`gpytorch.module.Module`): A model for the log-noise covariance. This can be a
simple homoskedastic model, or a GP that is to be fitted on the observed measurement errors.
rank (int): The rank of the task noise covariance matrix to fit. If `rank` is set to 0,
then a diagonal covariance matrix is fit.
task_prior (:obj:`gpytorch.priors.Prior`): Prior to use over the task noise covariance matrix if
`rank` > 0, or a prior over the log of just the diagonal elements, if `rank` == 0.
task_correlation_prior (:obj:`gpytorch.priors.Prior`): Prior to use over the task noise correlation matrix.
Only used when `rank` > 0.
batch_size (int): Number of batches.
"""
super(MultitaskGaussianLikelihood, self).__init__(log_noise_covar=log_noise_covar)
super(_MultitaskGaussianLikelihoodBase, self).__init__(log_noise_covar=log_noise_covar)
if rank != 0:
self.register_parameter(
name="task_noise_corr_factor", parameter=torch.nn.Parameter(torch.randn(batch_size, num_tasks, rank))
@@ -96,10 +97,11 @@ def forward(self, input, *params):
added.
"""
mean, covar = input.mean, input.lazy_covariance_matrix
batch_shape, n = covar.shape[:-2], covar.shape[-1] // self.num_tasks

if hasattr(self, "task_noise_corr_factor"):
task_noise_corr_factor = self.task_noise_corr_factor
if covar.ndimension() == 2:
if len(batch_shape) > 0:
if settings.debug.on() and task_noise_corr_factor.size(0) > 1:
raise RuntimeError(
"With batch_size > 1, expected a batched MultitaskMultivariateNormal distribution."
@@ -109,20 +111,54 @@
task_corr = NonLazyTensor(_eval_corr_matrix(task_noise_corr_factor))
else:
task_corr = DiagLazyTensor(
torch.ones(covar.shape[:-2] + torch.Size([self.num_tasks]), dtype=covar.dtype, device=covar.device)
torch.ones(batch_shape + torch.Size([self.num_tasks]), dtype=covar.dtype, device=covar.device)
)

log_noise_covar = self.log_noise_covar(*params) # n x num_tasks
log_noise_covar = self.log_noise_covar(*params)
D_sem = log_noise_covar.exp().sqrt()
task_covar_blocks = MatmulLazyTensor(MatmulLazyTensor(D_sem, task_corr.repeat(mean.shape[-2], 1, 1)), D_sem)
task_covar_blocks = MatmulLazyTensor(MatmulLazyTensor(D_sem, task_corr.repeat(n, 1, 1)), D_sem)
task_covar = BlockDiagLazyTensor(task_covar_blocks)
return input.__class__(mean, covar + task_covar)

def variational_log_probability(self, input, target):
raise NotImplementedError
raise NotImplementedError("Variational inference with Multitask Gaussian likelihood is not yet supported")


class HomoskedasticMultitaskGaussianLikelihood(MultitaskGaussianLikelihood):
class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
"""
A convenient extension of the :class:`gpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank `rank`.
If a strictly diagonal task noise covariance matrix is desired, then rank=0 should be set. (This option still
allows for a different `log_noise` parameter for each task.)
Like the Gaussian likelihood, this object can be used with exact inference.
"""

def __init__(self, num_tasks, rank=0, task_correlation_prior=None, batch_size=1, log_noise_prior=None):
"""
Args:
num_tasks (int): Number of tasks.
rank (int): The rank of the task noise covariance matrix to fit. If `rank` is set to 0,
then a diagonal covariance matrix is fit.
task_correlation_prior (:obj:`gpytorch.priors.Prior`): Prior to use over the task noise correlation matrix.
Only used when `rank` > 0.
"""
log_noise_covar = MultitaskHomoskedasticNoise(
num_tasks=num_tasks, log_noise_prior=log_noise_prior, batch_size=batch_size
)
super(MultitaskGaussianLikelihood, self).__init__(
num_tasks=num_tasks,
log_noise_covar=log_noise_covar,
rank=rank,
task_correlation_prior=task_correlation_prior,
batch_size=batch_size,
)


class MultitaskGaussianLikelihood_Kronecker(_MultitaskGaussianLikelihoodBase):
"""
A convenient extension of the :class:`gpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank `rank`.
Expand Down Expand Up @@ -167,7 +203,7 @@ def __init__(self, num_tasks, rank=0, task_prior=None, batch_size=1, log_noise_p
)
self.num_tasks = num_tasks

def forward(self, input):
def forward(self, input, *params):
"""
Adds the log task noises to the diagonal of the covariance matrix of the supplied
:obj:`gpytorch.distributions.MultivariateNormal` or
@@ -233,9 +269,5 @@ def forward(self):
raise RuntimeError("With batch_size > 1, expected a batched MultitaskMultivariateNormal distribution.")
noise = noise.squeeze(0)

# set_trace()
covar = add_diag(covar, noise)
return input.__class__(mean, covar)

def variational_log_probability(self, input, target):
raise NotImplementedError("Variational inference with Multitask Gaussian likelihood is not yet supported")
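
Only the final line of _eval_corr_matrix appears in this diff. A plausible reading, sketched here with the hidden lines treated as assumptions, is the standard low-rank construction: form M = F F^T from the rank-r factor, then rescale so the result has unit diagonal and is therefore a valid correlation matrix:

import torch

def eval_corr_matrix_sketch(task_noise_corr_factor):
    F = task_noise_corr_factor  # num_tasks x rank
    M = F @ F.transpose(-1, -2)  # num_tasks x num_tasks, PSD (assumed)
    dsqrtinv = M.diagonal(dim1=-2, dim2=-1).rsqrt()
    # This line is the one visible in the diff above.
    return M * dsqrtinv.unsqueeze(-1).matmul(dsqrtinv.unsqueeze(0))

corr = eval_corr_matrix_sketch(torch.randn(4, 2))
assert torch.allclose(corr.diagonal(), torch.ones(4))
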
16 changes: 11 additions & 5 deletions gpytorch/likelihoods/noise_models.py
@@ -18,11 +18,17 @@ def __init__(self, log_noise_prior=None, batch_size=1):
def forward(self, params):
log_noise = self.log_noise
p = params[0] if isinstance(params, list) else params
var_shape = p.shape[:-2] + p.shape[-1:]
if len(var_shape) == 1:
log_noise = log_noise.squeeze(0)
variances = log_noise * torch.ones(*var_shape, dtype=log_noise.dtype, device=log_noise.device)
return DiagLazyTensor(variances)
n = p.shape[-2] if len(p.shape) > 1 else p.shape[-1]
log_noise_diag = log_noise.repeat(n, 1)
return DiagLazyTensor(log_noise_diag)


class MultitaskHomoskedasticNoise(HomoskedasticNoise):
def __init__(self, num_tasks, log_noise_prior=None, batch_size=1):
super(HomoskedasticNoise, self).__init__()
self.register_parameter(
name="log_noise", parameter=Parameter(torch.zeros(batch_size, num_tasks)), prior=log_noise_prior
)


class HeteroskedasticNoise(Module):
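
A quick shape sketch of the reworked HomoskedasticNoise.forward: the learned log-noise is tiled once per input point and wrapped in a DiagLazyTensor, which _GaussianLikelihoodBase later exponentiates. The (1, 1) parameter shape below is an assumption for the default batch_size=1 case:

import torch

log_noise = torch.zeros(1, 1)  # stand-in for the module's log_noise parameter
p = torch.randn(20, 3)  # n x d inputs arriving via *params
n = p.shape[-2] if len(p.shape) > 1 else p.shape[-1]
log_noise_diag = log_noise.repeat(n, 1)  # same log-noise for every point
print(log_noise_diag.shape)  # torch.Size([20, 1])
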
4 changes: 2 additions & 2 deletions gpytorch/mlls/exact_marginal_log_likelihood.py
@@ -5,7 +5,7 @@

import torch
from .marginal_log_likelihood import MarginalLogLikelihood
from ..likelihoods import GaussianLikelihood
from ..likelihoods import _GaussianLikelihoodBase
from ..distributions import MultivariateNormal


@@ -18,7 +18,7 @@ def __init__(self, likelihood, model):
- likelihood: (Likelihood) - the likelihood for the model
- model: (Module) - the exact GP model
"""
if not isinstance(likelihood, GaussianLikelihood):
if not isinstance(likelihood, _GaussianLikelihoodBase):
raise RuntimeError("Likelihood must be Gaussian for exact inference")
super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)

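
What the relaxed check buys, sketched with a hypothetical subclass: any likelihood deriving from the new base class, not just the stock homoskedastic GaussianLikelihood, is now accepted for exact inference:

from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import HomoskedasticNoise

class MyGaussianLikelihood(_GaussianLikelihoodBase):
    # Hypothetical custom likelihood reusing the stock homoskedastic noise model.
    def __init__(self):
        super(MyGaussianLikelihood, self).__init__(log_noise_covar=HomoskedasticNoise())

assert isinstance(MyGaussianLikelihood(), _GaussianLikelihoodBase)  # passes the new check
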
6 changes: 3 additions & 3 deletions gpytorch/models/exact_gp.py
@@ -7,7 +7,7 @@
import torch
from ..functions import exact_predictive_mean, exact_predictive_covar
from ..distributions import MultivariateNormal, MultitaskMultivariateNormal
from ..likelihoods import GaussianLikelihood
from ..likelihoods import _GaussianLikelihoodBase
from .. import settings
from .gp import GP

@@ -18,8 +18,8 @@ def __init__(self, train_inputs, train_targets, likelihood):
train_inputs = (train_inputs,)
if train_inputs is not None and not all(torch.is_tensor(train_input) for train_input in train_inputs):
raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
if not isinstance(likelihood, GaussianLikelihood):
raise RuntimeError("ExactGP can only handle GaussianLikelihood")
if not isinstance(likelihood, _GaussianLikelihoodBase):
raise RuntimeError("ExactGP can only handle Gaussian Likelihoods")

super(ExactGP, self).__init__()
if train_inputs is not None:
