Closed
Description
🐛 Bug
I want to sample from the prior distribution with precision torch.float64.
However, sampling with KISS-GP raises a dtype error when I wrap a manually designed kernel (very similar to the RBF kernel) in a GridInterpolationKernel.
If the test data size is reduced from 2500x2 to 100x2, as in the snippet below, no error occurs.
# Build a 10 x 10 grid of test points and flatten it into a (100, 2) tensor,
# one column per coordinate.
axis = torch.linspace(0, 10 - 1, 10) * 1.
grid_cols = [
    g.contiguous().view(g.numel(), 1)
    for g in torch.meshgrid(axis, axis, indexing="xy")
]
x = torch.cat(grid_cols, dim=1)
To reproduce
import torch
import gpytorch
# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)
def postprocess_rot(dist_mat):
    """Map a distance tensor to ``exp(-distance)``, overwriting it in place.

    Args:
        dist_mat: Tensor of distances. It is modified in place.

    Returns:
        The same tensor object, now holding ``exp(-dist_mat)``.
    """
    negated = dist_mat.mul_(-1.0)  # in-place sign flip; returns dist_mat itself
    return negated.exp_()
class TestKernel(gpytorch.kernels.Kernel):
    """Stationary kernel that rescales each input dimension before taking distances.

    The first input dimension is divided by 10 and the second by 1. The scaled
    inputs are handed to ``covar_dist`` with ``square_dist=False`` and
    ``postprocess_rot`` as the distance post-processing function.
    """

    is_stationary = True

    def forward(self, x1, x2, **params):
        """Evaluate the kernel between ``x1`` and ``x2``.

        Args:
            x1: First set of inputs; the last dimension must broadcast
                against the two-element scale vector.
            x2: Second set of inputs, same convention as ``x1``.
            **params: Extra keyword arguments forwarded to ``covar_dist``.

        Returns:
            Whatever ``covar_dist`` returns for the rescaled inputs.
        """
        # Create the scale with the inputs' dtype and device instead of the
        # global defaults, so the division cannot introduce a dtype mismatch.
        scale = torch.tensor([10., 1.], dtype=x1.dtype, device=x1.device)
        # Divide out of place: ``div_`` would overwrite the caller's tensors
        # and, when x1 and x2 are the same object, would scale it twice.
        x1_ = x1.div(scale)
        x2_ = x2.div(scale)
        return self.covar_dist(
            x1_, x2_, square_dist=False, dist_postprocess_func=postprocess_rot, **params
        )
class ExactGP(gpytorch.models.ExactGP):
    """Exact GP built without training data: zero mean, grid-interpolated kernel.

    Keyword arguments given to the constructor are forwarded to ``TestKernel``.
    """

    def __init__(self, **kwargs):
        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        # No training inputs or targets are supplied.
        super().__init__(None, None, likelihood)

        base_kernel = TestKernel(ard_num_dims=2, **kwargs)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            base_kernel,
            grid_size=100,
            num_dims=2,
        )

    def forward(self, x):
        """Return the multivariate normal given by the mean and covariance at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )
# Build a 50 x 50 grid of test points and flatten it into a (2500, 2) tensor,
# one column per coordinate.
axis = torch.linspace(0, 50 - 1, 50) * 1.
grid_cols = [
    g.contiguous().view(g.numel(), 1)
    for g in torch.meshgrid(axis, axis, indexing="xy")
]
x = torch.cat(grid_cols, dim=1)
model = ExactGP()
model.eval()

# Draw one sample from the model output at the grid points, with gradients
# disabled and the fast predictive-variance / fast-sampling settings enabled.
with torch.no_grad():
    with gpytorch.settings.fast_pred_var(), \
            gpytorch.settings.max_root_decomposition_size(100), \
            gpytorch.settings.fast_pred_samples():
        output = model(x)
        samples = output.rsample(torch.Size([1]))
**Error message**
expected scalar type Double but found Float
System information
torch=1.13.0
gpytorch=1.9.0