-
Notifications
You must be signed in to change notification settings - Fork 545
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This is a proof of concept of how heteroskedastic likelihoods may work.
- Loading branch information
Showing
10 changed files
with
106 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
from .bernoulli_likelihood import BernoulliLikelihood | ||
from .gaussian_likelihood import GaussianLikelihood, HeteroskedasticGaussianLikelihood, HomoskedasticGaussianLikelihood | ||
from .likelihood import Likelihood | ||
from .gaussian_likelihood import GaussianLikelihood | ||
from .multitask_gaussian_likelihood import MultitaskGaussianLikelihood | ||
from .bernoulli_likelihood import BernoulliLikelihood | ||
from .softmax_likelihood import SoftmaxLikelihood | ||
|
||
|
||
__all__ = [ | ||
"Likelihood", | ||
"BernoulliLikelihood", | ||
"GaussianLikelihood", | ||
"HeteroskedasticGaussianLikelihood", | ||
"HomoskedasticGaussianLikelihood", | ||
"Likelihood", | ||
"MultitaskGaussianLikelihood", | ||
"BernoulliLikelihood", | ||
"SoftmaxLikelihood", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import torch | ||
from torch.nn import Parameter | ||
|
||
from ..lazy import DiagLazyTensor | ||
from ..module import Module | ||
|
||
|
||
class HomoskedasticNoise(Module): | ||
def __init__(self, log_noise_prior=None, batch_size=1): | ||
super(HomoskedasticNoise, self).__init__() | ||
self.register_parameter( | ||
name="log_noise", parameter=Parameter(torch.zeros(batch_size, 1)), prior=log_noise_prior | ||
) | ||
|
||
def forward(self, params): | ||
noise = self.log_noise.exp() | ||
if isinstance(params, list): | ||
variance_shape = params[0].shape[:-2] + params[0].shape[-1:] | ||
else: | ||
variance_shape = params.shape[:-2] + params.shape[-1:] | ||
if len(variance_shape) == 1: | ||
noise = noise.squeeze(0) | ||
variances = noise * torch.ones(*variance_shape, dtype=noise.dtype, device=noise.device) | ||
return DiagLazyTensor(variances) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,16 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
from .gp import GP | ||
from .additive_grid_inducing_variational_gp import AdditiveGridInducingVariationalGP | ||
from .exact_gp import ExactGP | ||
from .variational_gp import VariationalGP | ||
from .gp import GP | ||
from .grid_inducing_variational_gp import GridInducingVariationalGP | ||
from .additive_grid_inducing_variational_gp import AdditiveGridInducingVariationalGP | ||
from .variational_gp import VariationalGP | ||
|
||
|
||
__all__ = ["GP", "ExactGP", "VariationalGP", "GridInducingVariationalGP", "AdditiveGridInducingVariationalGP"] | ||
__all__ = [ | ||
"AdditiveGridInducingVariationalGP", | ||
"ExactGP", | ||
"GP", | ||
"VariationalGP", | ||
"GridInducingVariationalGP", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters