Commit: Params encapsulate modules
gpleiss committed Aug 16, 2017
1 parent 25c51af commit b61a525
Showing 16 changed files with 193 additions and 129 deletions.
2 changes: 1 addition & 1 deletion gpytorch/inference/inference.py
@@ -59,7 +59,7 @@ def run_(self, train_x, train_y, inducing_points=None, max_inference_steps=20, *
likelihood = self.gp_model.likelihood
if isinstance(likelihood, GaussianLikelihood):
output = self.gp_model.forward(*train_x, **kwargs)
- if len(output) == 2 and isinstance(output[0], GaussianRandomVariable):
+ if isinstance(output, GaussianRandomVariable):
if isinstance(self.gp_model, _ExactGPPosterior):
raise RuntimeError('Updating existing GP posteriors is not yet supported.')
else:
17 changes: 9 additions & 8 deletions gpytorch/inference/posterior_models/exact_gp_posterior.py
@@ -46,31 +46,32 @@ def forward(self, *inputs, **params):
full_inputs = [torch.cat([train_x_var, input]) for train_x_var, input in zip(train_x_vars, inputs)]
else:
full_inputs = inputs
- gaussian_rv_output, log_noise = self.prior_model.forward(*full_inputs, **params)
+ gaussian_rv_output = self.prior_model.forward(*full_inputs, **params)
full_mean, full_covar = gaussian_rv_output.representation()

# If there's data, use it
if n and not self.training:
# Get mean/covar components
train_mean = full_mean[:n]
test_mean = full_mean[n:]
- train_train_covar = gpytorch.add_diag(full_covar[:n, :n], log_noise.exp())
+ train_train_covar = gpytorch.add_diag(full_covar[:n, :n], self.likelihood.log_noise.exp())
train_test_covar = full_covar[:n, n:]
test_train_covar = full_covar[n:, :n]
test_test_covar = full_covar[n:, n:]

if hasattr(self, 'alpha') and self.alpha is not None:
- GRV, alpha = gpytorch._exact_predict(test_mean, test_test_covar, self.train_y, train_mean,
- train_train_covar, train_test_covar, test_train_covar, self.alpha)
+ output, alpha = gpytorch._exact_predict(test_mean, test_test_covar, self.train_y, train_mean,
+ train_train_covar, train_test_covar, test_train_covar,
+ self.alpha)
else:
- GRV, alpha = gpytorch._exact_predict(test_mean, test_test_covar, self.train_y, train_mean,
- train_train_covar, train_test_covar, test_train_covar)
+ output, alpha = gpytorch._exact_predict(test_mean, test_test_covar, self.train_y, train_mean,
+ train_train_covar, train_test_covar, test_train_covar)

self.alpha = alpha
else:
- GRV = GaussianRandomVariable(full_mean, full_covar)
+ output = GaussianRandomVariable(full_mean, full_covar)

- return GRV, log_noise
+ return output

def marginal_log_likelihood(self, output, train_y):
mean, covar = output.representation()
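Note: the net effect of this file's change is that forward now returns a single GaussianRandomVariable, with observation noise read from the likelihood the model owns rather than threaded through every call. A minimal caller sketch under that contract (posterior and test_x are illustrative names, not code from this commit):

# Hypothetical caller of the new interface
output = posterior.forward(test_x)  # a single GaussianRandomVariable
mean, covar = output.representation()
# observation noise now comes from the encapsulated likelihood
noisy_covar = gpytorch.add_diag(covar, posterior.likelihood.log_noise.exp())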
14 changes: 11 additions & 3 deletions gpytorch/kernels/index_kernel.py
@@ -1,8 +1,16 @@
+ import torch
+ from torch import nn
from .kernel import Kernel


class IndexKernel(Kernel):
- def forward(self, i1, i2, index_covar_factor, index_log_var):
- index_covar_matrix = index_covar_factor.mm(index_covar_factor.t()) + index_log_var.exp().diag()
- output_covar = index_covar_matrix.index_select(0, i1.view(-1)).index_select(1, i2.view(-1))
+ def __init__(self, n_tasks, rank=1, covar_factor_bounds=(-100, 100), log_var_bounds=(-100, 100)):
+ super(IndexKernel, self).__init__()
+ self.register_parameter('covar_factor', nn.Parameter(torch.zeros(n_tasks, rank)),
+ bounds=covar_factor_bounds)
+ self.register_parameter('log_var', nn.Parameter(torch.zeros(n_tasks)), bounds=log_var_bounds)
+
+ def forward(self, i1, i2):
+ covar_matrix = self.covar_factor.mm(self.covar_factor.t()) + self.log_var.exp().diag()
+ output_covar = covar_matrix.index_select(0, i1.view(-1)).index_select(1, i2.view(-1))
return output_covar
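A rough usage sketch of the encapsulated kernel (two tasks, rank 1; i1 and i2 are illustrative index tensors, not code from this commit):

import torch
from torch.autograd import Variable
from gpytorch.kernels import IndexKernel

# Hypothetical usage: the task covariance parameters now live on the kernel
task_kernel = IndexKernel(n_tasks=2, rank=1)
i1 = Variable(torch.zeros(4).long())  # four points, all task 0
i2 = Variable(torch.ones(4).long())   # compared against task 1
covar = task_kernel.forward(i1, i2)   # 4 x 4 slice of the task covariance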
10 changes: 8 additions & 2 deletions gpytorch/kernels/rbf_kernel.py
@@ -1,4 +1,5 @@
import torch
+ from torch import nn
from torch.autograd import Function, Variable
from .kernel import Kernel

@@ -43,7 +44,12 @@ def backward(self, grad_output):


class RBFKernel(Kernel):
- def forward(self, x1, x2, log_lengthscale):
+ def __init__(self, log_lengthscale_bounds=(-10000, 10000)):
+ super(RBFKernel, self).__init__()
+ self.register_parameter('log_lengthscale', nn.Parameter(torch.zeros(1, 1)),
+ bounds=log_lengthscale_bounds)
+
+ def forward(self, x1, x2):
n, _ = x1.size()
m, _ = x2.size()
- return RBFFunction(x1, x2)(log_lengthscale.expand(n, m))
+ return RBFFunction(x1, x2)(self.log_lengthscale.expand(n, m))
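A sketch of constructing the kernel and seeding its now-internal parameter (illustrative values; assumes Kernel inherits the new Module.initialize shown later in this commit):

from gpytorch.kernels import RBFKernel

kernel = RBFKernel(log_lengthscale_bounds=(-3, 3))
kernel.initialize(log_lengthscale=0.0)  # fills the 1 x 1 parameter, checked against bounds
# covar = kernel.forward(x1, x2) no longer takes a log_lengthscale argument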
21 changes: 16 additions & 5 deletions gpytorch/kernels/spectral_mixture_kernel.py
@@ -1,10 +1,21 @@
- import torch
import math
+ import torch
+ from torch import nn
from .kernel import Kernel


class SpectralMixtureKernel(Kernel):
- def forward(self, x1, x2, log_mixture_weights, log_mixture_means, log_mixture_scales):
+ def __init__(self, n_mixtures, log_mixture_weight_bounds=(-100, 100),
+ log_mixture_mean_bounds=(-100, 100), log_mixture_scale_bounds=(-100, 100)):
+ super(SpectralMixtureKernel, self).__init__()
+ self.register_parameter('log_mixture_weights', nn.Parameter(torch.zeros(n_mixtures)),
+ bounds=log_mixture_weight_bounds)
+ self.register_parameter('log_mixture_means', nn.Parameter(torch.zeros(n_mixtures)),
+ bounds=log_mixture_mean_bounds)
+ self.register_parameter('log_mixture_scales', nn.Parameter(torch.zeros(n_mixtures)),
+ bounds=log_mixture_scale_bounds)
+
+ def forward(self, x1, x2):
n, d = x1.size()
m, _ = x2.size()

@@ -15,9 +26,9 @@ def forward(self, x1, x2, log_mixture_weights, log_mixture_means, log_mixture_sc
'use a product of SM kernels, one for each dimension.'
]))

- mixture_weights = log_mixture_weights.exp()
- mixture_means = log_mixture_means.exp()
- mixture_scales = log_mixture_scales.mul(2).exp_()
+ mixture_weights = self.log_mixture_weights.exp()
+ mixture_means = self.log_mixture_means.exp()
+ mixture_scales = self.log_mixture_scales.mul(2).exp_()

sq_distance = torch.mm(x1, x2.transpose(0, 1)).mul_(2)

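The same pattern for the spectral mixture kernel, as a sketch (1-D inputs only, which is what this forward supports; the truncated error message above recommends a product of SM kernels for higher dimensions):

from gpytorch.kernels import SpectralMixtureKernel

kernel = SpectralMixtureKernel(n_mixtures=3)
# all three hyperparameter vectors are owned by the kernel now
kernel.initialize(log_mixture_weights=0.0,
                  log_mixture_means=0.0,
                  log_mixture_scales=0.0)
# covar = kernel.forward(x1, x2) with x1 of size n x 1 and x2 of size m x 1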
10 changes: 8 additions & 2 deletions gpytorch/likelihoods/gaussian_likelihood.py
@@ -1,11 +1,17 @@
+ import torch
import gpytorch
+ from torch import nn
from gpytorch.random_variables import GaussianRandomVariable
from .likelihood import Likelihood


class GaussianLikelihood(Likelihood):
- def forward(self, input, log_noise):
+ def __init__(self, log_noise_bounds=(-1000, 1000)):
+ super(GaussianLikelihood, self).__init__()
+ self.register_parameter('log_noise', nn.Parameter(torch.zeros(1)), bounds=log_noise_bounds)
+
+ def forward(self, input):
assert(isinstance(input, GaussianRandomVariable))
mean, covar = input.representation()
- noise = gpytorch.add_diag(covar, log_noise.exp())
+ noise = gpytorch.add_diag(covar, self.log_noise.exp())
return GaussianRandomVariable(mean, noise)
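A minimal sketch of the new call pattern, where the likelihood adds its own noise to a latent prediction (latent_pred is an illustrative GaussianRandomVariable from some model's forward):

from gpytorch.likelihoods import GaussianLikelihood

likelihood = GaussianLikelihood(log_noise_bounds=(-3, 3))
likelihood.initialize(log_noise=-2.0)  # start with small observation noise
observed_pred = likelihood.forward(latent_pred)  # adds exp(log_noise) to the covariance diagonal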
10 changes: 8 additions & 2 deletions gpytorch/means/constant_mean.py
@@ -1,6 +1,12 @@
+ import torch
+ from torch import nn
from .mean import Mean


class ConstantMean(Mean):
- def forward(self, input, constant):
- return constant.expand(input.size())
+ def __init__(self, constant_bounds=(-1e10, 1e10)):
+ super(ConstantMean, self).__init__()
+ self.register_parameter('constant', nn.Parameter(torch.zeros(1)), bounds=constant_bounds)
+
+ def forward(self, input):
+ return self.constant.expand(input.size())
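And the same idea for the mean, as a sketch (x is an illustrative input Variable):

from gpytorch.means import ConstantMean

mean_module = ConstantMean(constant_bounds=(-1, 1))
mean_module.initialize(constant=0.5)  # the constant is a registered parameter now
# mean_x = mean_module.forward(x) expands the scalar to x.size()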
80 changes: 66 additions & 14 deletions gpytorch/module.py
@@ -47,23 +47,75 @@ def register_parameter(self, name, param, bounds, prior=None):
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
super(Module, self).register_parameter(name, param)
+ # Get bound
+ lower_bound_tensor = param.data.new().resize_as_(param.data)
+ upper_bound_tensor = param.data.new().resize_as_(param.data)
+ self._bounds[name] = (lower_bound_tensor, upper_bound_tensor)
+ kwargs = {}
+ kwargs[name] = bounds
+ self.set_bounds(**kwargs)

- # Set bounds
- lower_bound, upper_bound = bounds
- if torch.is_tensor(lower_bound) and torch.is_tensor(upper_bound):
- if lower_bound.size() != upper_bound.size() or \
- lower_bound.size() != param.size():
- raise AttributeError('Lower bound, upper bound, and param should have the same size')
- self._bounds[name] = (lower_bound, upper_bound)
- elif (isinstance(lower_bound, int) or isinstance(lower_bound, float)) and \
- (isinstance(upper_bound, int) or isinstance(upper_bound, float)):
- lower_bound_tensor = param.data.new().resize_as_(param.data).fill_(lower_bound)
- upper_bound_tensor = param.data.new().resize_as_(param.data).fill_(upper_bound)
- self._bounds[name] = (lower_bound_tensor, upper_bound_tensor)
- else:
- raise AttributeError('Unsupported argument types for parameter %s' % name)
+ def initialize(self, **kwargs):
+ """
+ Set a value for a parameter
+ kwargs: (param_name, value) - parameter to initialize
+ Value can take the form of a tensor, a float, or an int
+ """
+ for name, val in kwargs.items():
+ if name not in self._parameters:
+ raise AttributeError('Unknown parameter %s for %s' % (name, self.__class__.__name__))
+ if torch.is_tensor(val):
+ self.__getattr__(name).data.copy_(val)
+ elif isinstance(val, float) or isinstance(val, int):
+ self.__getattr__(name).data.fill_(val)
+ else:
+ raise AttributeError('Type %s not valid to initialize parameter %s' % (type(val), name))
+
+ # Ensure initialization is within bounds
+ param = self._parameters[name]
+ lower_bound, upper_bound = self._bounds[name]
+ lower_mask = param.data < lower_bound
+ if any(lower_mask.view(-1)):
+ raise AttributeError('Parameter %s exceeds lower bound' % name)
+ upper_mask = param.data > upper_bound
+ if any(upper_mask.view(-1)):
+ raise AttributeError('Parameter %s exceeds upper bound' % name)
+ return self

+ def set_bounds(self, **kwargs):
+ """
+ Set bounds for a parameter
+ kwargs: (param_name, value) - parameter to initialize
+ Value can take the form of a tensor, a float, or an int
+ """
+ for name, bounds in kwargs.items():
+ if name not in self._parameters:
+ raise AttributeError('Unknown parameter %s for %s' % (name, self.__class__.__name__))
+ param = self._parameters[name]
+ # Set bounds
+ lower_bound, upper_bound = bounds
+ if torch.is_tensor(lower_bound) and torch.is_tensor(upper_bound):
+ if lower_bound.size() != upper_bound.size() or \
+ lower_bound.size() != param.size():
+ raise AttributeError('Lower bound, upper bound, and param should have the same size')
+ self._bounds[name][0].copy_(lower_bound)
+ self._bounds[name][1].copy_(upper_bound)
+ elif (isinstance(lower_bound, int) or isinstance(lower_bound, float)) and \
+ (isinstance(upper_bound, int) or isinstance(upper_bound, float)):
+ self._bounds[name][0].fill_(lower_bound)
+ self._bounds[name][1].fill_(upper_bound)
+ else:
+ raise AttributeError('Unsupported argument types for parameter %s' % name)
+ return self

def bound_for(self, name):
"""
Get bounds for parameter
name (str): parameter name
"""
if '.' in name:
module, name = name.split('.', 1)
if module in self._modules:
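Taken together, a sketch of the new parameter workflow on a hypothetical module (MyModule and log_scale are illustrative; assumes only the register_parameter/initialize/set_bounds API shown above):

import torch
from torch import nn
from gpytorch.module import Module

class MyModule(Module):
    def __init__(self):
        super(MyModule, self).__init__()
        # scalar bounds are broadcast to the parameter's shape
        self.register_parameter('log_scale', nn.Parameter(torch.zeros(2)), bounds=(-5, 5))

m = MyModule()
m.initialize(log_scale=1.0)      # fills both entries; raises AttributeError if out of bounds
m.set_bounds(log_scale=(-2, 2))  # overwrites the stored bound tensors in place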
3 changes: 1 addition & 2 deletions gpytorch/utils/lanczos_quadrature.py
@@ -60,7 +60,6 @@ def lanczos(self, mv_closure, b):
beta[k] = beta_k
Q[:, k] = u


if math.fabs(beta[k]) < 1e-4 or math.fabs(alpha[k]) < 1e-4:
break

@@ -73,7 +72,7 @@ def lanczos(self, mv_closure, b):

Q = Q[:, :k]
T = torch.diag(alpha) + torch.diag(beta, 1) + torch.diag(beta, -1)

return Q, T

def _lanczos_step(self, u, v, mv_closure, Q):
19 changes: 8 additions & 11 deletions test/examples/kissgp_gp_regression_test.py
@@ -1,7 +1,7 @@
import math
import torch
import gpytorch
- from torch import nn, optim
+ from torch import optim
from torch.autograd import Variable
from gpytorch.kernels import RBFKernel, GridInterpolationKernel
from gpytorch.means import ConstantMean
@@ -20,19 +20,16 @@
# All tests that pass with the exact kernel should pass with the interpolated kernel.
class KissGPModel(gpytorch.GPModel):
def __init__(self):
- super(KissGPModel, self).__init__(GaussianLikelihood())
- self.mean_module = ConstantMean()
- covar_module = RBFKernel()
+ likelihood = GaussianLikelihood(log_noise_bounds=(-3, 3))
+ super(KissGPModel, self).__init__(likelihood)
+ self.mean_module = ConstantMean(constant_bounds=(-1, 1))
+ covar_module = RBFKernel(log_lengthscale_bounds=(-3, 3))
self.grid_covar_module = GridInterpolationKernel(covar_module, 50)
- self.register_parameter('log_noise', nn.Parameter(torch.Tensor([-2])), bounds=(-5, 5)),
- self.register_parameter('log_lengthscale', nn.Parameter(torch.Tensor([0])), bounds=(-3, 5)),

def forward(self, x):
- mean_x = self.mean_module(x, constant=Variable(torch.Tensor([0])))
- covar_x = self.grid_covar_module(x, log_lengthscale=self.log_lengthscale)
-
- latent_pred = GaussianRandomVariable(mean_x, covar_x)
- return latent_pred, self.log_noise
+ mean_x = self.mean_module(x)
+ covar_x = self.grid_covar_module(x)
+ return GaussianRandomVariable(mean_x, covar_x)


def test_kissgp_gp_mean_abs_error():
30 changes: 10 additions & 20 deletions test/examples/multitask_gp_regression_test.py
@@ -1,7 +1,7 @@
import math
import torch
import gpytorch
- from torch import nn, optim
+ from torch import optim
from torch.autograd import Variable
from gpytorch.kernels import RBFKernel, IndexKernel
from gpytorch.means import ConstantMean
@@ -25,28 +25,18 @@

class MultitaskGPModel(gpytorch.GPModel):
def __init__(self):
- super(MultitaskGPModel, self).__init__(GaussianLikelihood())
- self.mean_module = ConstantMean()
- self.covar_module = RBFKernel()
- self.task_covar_module = IndexKernel()
- self.register_parameter('constant_mean', nn.Parameter(torch.randn(1)), bounds=(-1, 1))
- self.register_parameter('log_noise', nn.Parameter(torch.randn(1)), bounds=(-6, 6))
- self.register_parameter('log_lengthscale', nn.Parameter(torch.randn(1)), bounds=(-6, 6))
- self.register_parameter('task_matrix', nn.Parameter(torch.randn(2, 1)), bounds=(-6, 6))
- self.register_parameter('task_log_vars', nn.Parameter(torch.randn(2)), bounds=(-6, 6))
+ likelihood = GaussianLikelihood(log_noise_bounds=(-6, 6))
+ super(MultitaskGPModel, self).__init__(likelihood)
+ self.mean_module = ConstantMean(constant_bounds=(-1, 1))
+ self.covar_module = RBFKernel(log_lengthscale_bounds=(-6, 6))
+ self.task_covar_module = IndexKernel(n_tasks=2, rank=1, covar_factor_bounds=(-6, 6), log_var_bounds=(-6, 6))

def forward(self, x, i):
- mean_x = self.mean_module(x, constant=self.constant_mean)
-
- covar_x = self.covar_module(x, log_lengthscale=self.log_lengthscale)
- covar_i = self.task_covar_module(i,
- index_covar_factor=self.task_matrix,
- index_log_var=self.task_log_vars)
-
+ mean_x = self.mean_module(x)
+ covar_x = self.covar_module(x)
+ covar_i = self.task_covar_module(i)
covar_xi = covar_x.mul(covar_i)

- latent_pred = GaussianRandomVariable(mean_x, covar_xi)
- return latent_pred, self.log_noise
+ return GaussianRandomVariable(mean_x, covar_xi)


def test_multitask_gp_mean_abs_error():
