move mtgl kronecker into mtgl
wjmaddox committed Feb 10, 2021
1 parent 50d3358 commit 40dbece
Showing 2 changed files with 58 additions and 190 deletions.
244 changes: 55 additions & 189 deletions gpytorch/likelihoods/multitask_gaussian_likelihood.py
@@ -8,17 +8,13 @@
from ..constraints import GreaterThan
from ..distributions import base_distributions
from ..lazy import (
    BlockDiagLazyTensor,
    ConstantDiagLazyTensor,
    DiagLazyTensor,
    KroneckerProductDiagLazyTensor,
    KroneckerProductLazyTensor,
    MatmulLazyTensor,
    RootLazyTensor,
    lazify,
)
from ..likelihoods import Likelihood, _GaussianLikelihoodBase
from .noise_models import MultitaskHomoskedasticNoise


class _MultitaskGaussianLikelihoodBase(_GaussianLikelihoodBase):
@@ -66,141 +62,74 @@ def _eval_corr_matrix(self):
        C = Cfac / Cfac.pow(2).sum(dim=-1, keepdim=True).sqrt()
        return C @ C.transpose(-1, -2)

    def _shaped_noise_covar(self, base_shape, *params):
        if len(base_shape) >= 2:
            *batch_shape, n, _ = base_shape
        else:
            *batch_shape, n = base_shape

        # compute the noise covariance
        if len(params) > 0:
            shape = None
        else:
            shape = base_shape if len(base_shape) == 1 else base_shape[:-1]
        noise_covar = self.noise_covar(*params, shape=shape)

        if self.rank > 0:
            # if rank > 0, compute the task correlation matrix
            # TODO: This is inefficient, change repeat so it can repeat LazyTensors w/ multiple batch dimensions
            task_corr = self._eval_corr_matrix()
            exp_shape = torch.Size([*batch_shape, n]) + task_corr.shape[-2:]
            task_corr_exp = lazify(task_corr.unsqueeze(-3).expand(exp_shape))
            noise_sem = noise_covar.sqrt()
            task_covar_blocks = MatmulLazyTensor(MatmulLazyTensor(noise_sem, task_corr_exp), noise_sem)
        else:
            # otherwise tasks are uncorrelated
            if isinstance(noise_covar, DiagLazyTensor):
                flattened_diag = noise_covar._diag.view(*noise_covar._diag.shape[:-2], -1)
                return DiagLazyTensor(flattened_diag)
            task_covar_blocks = noise_covar
        if len(batch_shape) == 1:
            # TODO: Properly support general batch shapes in BlockDiagLazyTensor (no shape arithmetic)
            tcb_eval = task_covar_blocks.evaluate()
            task_covar = BlockDiagLazyTensor(lazify(tcb_eval), block_dim=-3)
        else:
            task_covar = BlockDiagLazyTensor(task_covar_blocks)

        return task_covar

    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
        noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
        noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
        return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)

    def marginal(self, function_dist, *params, **kwargs):
        r"""
        Adds the task noises to the diagonal of the covariance matrix of the supplied
        :obj:`gpytorch.distributions.MultivariateNormal` or :obj:`gpytorch.distributions.MultitaskMultivariateNormal`,
        in case of `rank` == 0. Otherwise, adds a rank `rank` covariance matrix to it.

        To accomplish this, we form a new :obj:`gpytorch.lazy.KroneckerProductLazyTensor` between :math:`I_{n}`,
        an identity matrix with size equal to the data and a (not necessarily diagonal) matrix containing the task
        noises :math:`D_{t}`.
        We also incorporate a shared `noise` parameter from the base
        :class:`gpytorch.likelihoods.GaussianLikelihood` that we extend.
        The final covariance matrix after this method is then :math:`K + D_{t} \otimes I_{n} + \sigma^{2}I_{nt}`.
        Args:
            function_dist (:obj:`gpytorch.distributions.MultitaskMultivariateNormal`): Random variable whose covariance
                matrix is a :obj:`gpytorch.lazy.LazyTensor` we intend to augment.
        Returns:
            :obj:`gpytorch.distributions.MultitaskMultivariateNormal`: A new random variable whose covariance
                matrix is a :obj:`gpytorch.lazy.LazyTensor` with :math:`D_{t} \otimes I_{n}` and :math:`\sigma^{2}I_{nt}`
                added.
        """
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix

        covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
        covar = covar + covar_kron_lt

        return function_dist.__class__(mean, covar)

    def _shaped_noise_covar(self, shape, add_noise=True, *params, **kwargs):
        if not self.has_task_noise:
            noise = ConstantDiagLazyTensor(self.noise, diag_shape=shape[-2] * self.num_tasks)
            return noise

        if self.rank == 0:
            task_noises = self.raw_task_noises_constraint.transform(self.raw_task_noises)
            task_var_lt = DiagLazyTensor(task_noises)
            dtype, device = task_noises.dtype, task_noises.device
            ckl_init = KroneckerProductDiagLazyTensor
        else:
            task_noise_covar_factor = self.task_noise_covar_factor
            task_var_lt = RootLazyTensor(task_noise_covar_factor)
            dtype, device = task_noise_covar_factor.dtype, task_noise_covar_factor.device
            ckl_init = KroneckerProductLazyTensor

        eye_lt = ConstantDiagLazyTensor(torch.ones(*shape[:-2], 1, dtype=dtype, device=device), diag_shape=shape[-2])
        task_var_lt = task_var_lt.expand(*shape[:-2], *task_var_lt.matrix_shape)

        # to add the latent noise we exploit the fact that
        # I \kron D_T + \sigma^2 I_{NT} = I \kron (D_T + \sigma^2 I)
        # which allows us to move the latent noise inside the task dependent noise
        # thereby allowing exploitation of Kronecker structure in this likelihood.
        if add_noise and self.has_global_noise:
            noise = ConstantDiagLazyTensor(self.noise, diag_shape=task_var_lt.shape[-1])
            task_var_lt = task_var_lt + noise

        covar_kron_lt = ckl_init(eye_lt, task_var_lt)

        return covar_kron_lt

    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
        noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
        noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
        return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)


class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
    r"""
    A convenient extension of the :class:`gpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
    for a full cross-task covariance structure for the noise.
    The likelihood is given by

    .. math::
        p(\mathbf y \mid \mathbf f) = \mathcal N \left( \mathbf f, \mathbf B \mathbf B^\top + \sigma^2 \mathbf I \right)

    where :math:`\sigma^2` is a constant noise term, and the covariance matrix :math:`\mathbf B \mathbf B^\top`
    captures inter-task noise.
    The fitted covariance matrix has rank `rank`.
    If a strictly diagonal task noise covariance matrix is desired, then rank=0 should be set.
    This likelihood assumes homoskedastic noise.

    .. note::
        Like the Gaussian likelihood, this object can be used with exact inference.

    :param int num_tasks: Number of tasks.
    :param int rank: The rank of the task noise covariance matrix :math:`\mathbf B \mathbf B^\top` to fit.
        If `rank` is set to 0, then a diagonal covariance matrix is fit.
    :param task_correlation_prior: Prior to use over the task noise correlation matrix :math:`\mathbf B`.
        Only used when `rank` > 0.
    :type task_correlation_prior: ~gpytorch.priors.Prior, optional
    :param batch_shape: The batch shape of the learned noise parameter (default: []).
    :type batch_shape: torch.Size, optional
    :param noise_prior: Prior for noise parameter :math:`\sigma^2`.
    :type noise_prior: ~gpytorch.priors.Prior, optional
    :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
    :type noise_constraint: ~gpytorch.constraints.Interval, optional
    :var torch.Tensor noise: :math:`\sigma^2` parameter (constant diagonal noise)
    """

    def __init__(
        self,
        num_tasks,
        rank=0,
        task_correlation_prior=None,
        batch_shape=torch.Size(),
        noise_prior=None,
        noise_constraint=None,
    ):
        """
        Args:
            num_tasks (int): Number of tasks.
            rank (int):
                The rank of the task noise covariance matrix to fit. If `rank` is set to 0,
                then a diagonal covariance matrix is fit.
            task_correlation_prior (:obj:`gpytorch.priors.Prior`):
                Prior to use over the task noise correlation matrix.
                Only used when `rank` > 0.
        """
        if noise_constraint is None:
            noise_constraint = GreaterThan(1e-4)

        noise_covar = MultitaskHomoskedasticNoise(
            num_tasks=num_tasks, noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape
        )
        super().__init__(
            num_tasks=num_tasks,
            noise_covar=noise_covar,
            rank=rank,
            task_correlation_prior=task_correlation_prior,
            batch_shape=batch_shape,
        )

        self.register_parameter(name="raw_noise", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
        self.register_constraint("raw_noise", noise_constraint)

    @property
    def noise(self):
        return self.raw_noise_constraint.transform(self.raw_noise)

    @noise.setter
    def noise(self, value):
        self._set_noise(value)

    def _set_noise(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_noise)
        self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

    def _shaped_noise_covar(self, base_shape, *params):
        noise_covar = super()._shaped_noise_covar(base_shape, *params)
        noise = self.noise
        return noise_covar.add_diag(noise)


class MultitaskGaussianLikelihoodKronecker(_MultitaskGaussianLikelihoodBase):
class MultitaskGaussianLikelihood(_MultitaskGaussianLikelihoodBase):
"""
A convenient extension of the :class:`gpytorch.likelihoods.GaussianLikelihood` to the multitask setting that allows
for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank `rank`.
Expand All @@ -209,8 +138,6 @@ class MultitaskGaussianLikelihoodKronecker(_MultitaskGaussianLikelihoodBase):
Like the Gaussian likelihood, this object can be used with exact inference.
Note: This Likelihood is scheduled to be deprecated and replaced by an improved version of
`MultitaskGaussianLikelihood`. Use this only for compatibility with batched Multitask models.
"""

    def __init__(
@@ -300,68 +227,7 @@ def _eval_covar_matrix(self):
        D = noise * torch.eye(self.num_tasks, dtype=noise.dtype, device=noise.device)
        return covar_factor.matmul(covar_factor.transpose(-1, -2)) + D

    def marginal(self, function_dist, *params, **kwargs):
        r"""
        Adds the task noises to the diagonal of the covariance matrix of the supplied
        :obj:`gpytorch.distributions.MultivariateNormal` or :obj:`gpytorch.distributions.MultitaskMultivariateNormal`,
        in case of `rank` == 0. Otherwise, adds a rank `rank` covariance matrix to it.

        To accomplish this, we form a new :obj:`gpytorch.lazy.KroneckerProductLazyTensor` between :math:`I_{n}`,
        an identity matrix with size equal to the data and a (not necessarily diagonal) matrix containing the task
        noises :math:`D_{t}`.
        We also incorporate a shared `noise` parameter from the base
        :class:`gpytorch.likelihoods.GaussianLikelihood` that we extend.
        The final covariance matrix after this method is then :math:`K + D_{t} \otimes I_{n} + \sigma^{2}I_{nt}`.
        Args:
            function_dist (:obj:`gpytorch.distributions.MultitaskMultivariateNormal`): Random variable whose covariance
                matrix is a :obj:`gpytorch.lazy.LazyTensor` we intend to augment.
        Returns:
            :obj:`gpytorch.distributions.MultitaskMultivariateNormal`: A new random variable whose covariance
                matrix is a :obj:`gpytorch.lazy.LazyTensor` with :math:`D_{t} \otimes I_{n}` and :math:`\sigma^{2}I_{nt}`
                added.
        """
        mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix

        covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
        covar = covar + covar_kron_lt

        return function_dist.__class__(mean, covar)

    def _shaped_noise_covar(self, shape, add_noise=True, *params, **kwargs):
        if not self.has_task_noise:
            noise = ConstantDiagLazyTensor(self.noise, diag_shape=shape[-2] * self.num_tasks)
            return noise

        if self.rank == 0:
            task_noises = self.raw_task_noises_constraint.transform(self.raw_task_noises)
            task_var_lt = DiagLazyTensor(task_noises)
            dtype, device = task_noises.dtype, task_noises.device
            ckl_init = KroneckerProductDiagLazyTensor
        else:
            task_noise_covar_factor = self.task_noise_covar_factor
            task_var_lt = RootLazyTensor(task_noise_covar_factor)
            dtype, device = task_noise_covar_factor.dtype, task_noise_covar_factor.device
            ckl_init = KroneckerProductLazyTensor

        eye_lt = ConstantDiagLazyTensor(torch.ones(*shape[:-2], 1, dtype=dtype, device=device), diag_shape=shape[-2])
        task_var_lt = task_var_lt.expand(*shape[:-2], *task_var_lt.matrix_shape)

        # to add the latent noise we exploit the fact that
        # I \kron D_T + \sigma^2 I_{NT} = I \kron (D_T + \sigma^2 I)
        # which allows us to move the latent noise inside the task dependent noise
        # thereby allowing exploitation of Kronecker structure in this likelihood.
        if add_noise and self.has_global_noise:
            noise = ConstantDiagLazyTensor(self.noise, diag_shape=task_var_lt.shape[-1])
            task_var_lt = task_var_lt + noise

        covar_kron_lt = ckl_init(eye_lt, task_var_lt)

        return covar_kron_lt

    def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
        noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
        noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
        return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)
# TODO: remove in a later commit
# MultitaskGaussianLikelihoodKronecker has replaced MultitaskGaussianLikelihood
MultitaskGaussianLikelihoodKronecker = MultitaskGaussianLikelihood
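
The comment inside the new `_shaped_noise_covar` carries the key algebraic step: the shared noise term can be folded into the task covariance because I_n ⊗ D_T + σ² I_{nT} = I_n ⊗ (D_T + σ² I_T). A minimal standalone sketch checking that identity numerically (illustrative only, not part of the commit; it assumes a PyTorch recent enough to ship `torch.kron`):

import torch

# Check: I_n kron D_T + sigma^2 I_{nT} == I_n kron (D_T + sigma^2 I_T).
# This is what lets the global noise move inside the task covariance
# while the overall noise stays a Kronecker product.
n, num_tasks, sigma2 = 5, 3, 0.25
B = torch.randn(num_tasks, 2)
D_T = B @ B.T + torch.diag(torch.rand(num_tasks))  # task noise covariance (rank > 0 case)
lhs = torch.kron(torch.eye(n), D_T) + sigma2 * torch.eye(n * num_tasks)
rhs = torch.kron(torch.eye(n), D_T + sigma2 * torch.eye(num_tasks))
assert torch.allclose(lhs, rhs)

Keeping the sum in Kronecker form means downstream solves and log-determinants can work with the small num_tasks × num_tasks factor instead of the dense (n·num_tasks) × (n·num_tasks) covariance.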
4 changes: 3 additions & 1 deletion gpytorch/likelihoods/noise_models.py
@@ -76,7 +76,9 @@ def forward(self, *params: Any, shape: Optional[torch.Size] = None, **kwargs: Any
        noise_diag = noise.expand(*batch_shape, 1, num_tasks).contiguous()
        if num_tasks == 1:
            noise_diag = noise_diag.view(*batch_shape, 1)
        return ConstantDiagLazyTensor(noise_diag, diag_shape=n * num_tasks)
        if noise_diag.shape[-1] != 1:
            noise_diag = noise_diag.unsqueeze(-1)
        return ConstantDiagLazyTensor(noise_diag, diag_shape=n)


class HomoskedasticNoise(_HomoskedasticNoiseBase):
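
For intuition, a rough shape walk-through of the fixed branch using plain tensors (the batch size, n, and the raw noise shape are assumed for illustration; `ConstantDiagLazyTensor` is taken to interpret the trailing singleton dimension as a constant diagonal of length `diag_shape`):

import torch

# Assumed sizes: batch of 2, n = 4 data points per task, 3 tasks.
batch_shape, n, num_tasks = torch.Size([2]), 4, 3
noise = torch.rand(*batch_shape, 1, num_tasks)  # assumed parameter shape

noise_diag = noise.expand(*batch_shape, 1, num_tasks).contiguous()
if num_tasks == 1:
    noise_diag = noise_diag.view(*batch_shape, 1)
if noise_diag.shape[-1] != 1:
    noise_diag = noise_diag.unsqueeze(-1)  # -> (*batch_shape, 1, num_tasks, 1)
print(noise_diag.shape)  # torch.Size([2, 1, 3, 1])

# Old: ConstantDiagLazyTensor(noise_diag, diag_shape=n * num_tasks) -> one constant diagonal of length 12
# New: ConstantDiagLazyTensor(noise_diag, diag_shape=n) -> num_tasks constant diagonals of length 4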

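Finally, a minimal usage sketch of the renamed likelihood (hypothetical sizes; assumes the gpytorch API as of this commit, including the temporary alias above):

import torch
import gpytorch

# Hypothetical setup: 10 data points, 3 tasks, rank-1 inter-task noise.
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3, rank=1)

# The deprecated name now points at the same class.
assert gpytorch.likelihoods.MultitaskGaussianLikelihoodKronecker is gpytorch.likelihoods.MultitaskGaussianLikelihood

# marginal() adds I_n kron (D_T + sigma^2 I_T) to the prior covariance.
prior = gpytorch.distributions.MultitaskMultivariateNormal(torch.zeros(10, 3), torch.eye(30))
marginal = likelihood.marginal(prior)
print(marginal.covariance_matrix.shape)  # torch.Size([30, 30])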