
[Bug] torch.float64 raises error in GridInterpolationKernel #2225

Closed
@anjawa

Description


🐛 Bug

I want to sample from the prior distribution with precision torch.float64.
However, during sampling with KISS-GP, a dtype error is raised when I use a
manually designed kernel (very similar to the RBF kernel) wrapped in a
GridInterpolationKernel.

If I change the test data size from 2500x2 to 100x2, no error occurs:

# 10x10 grid -> 100x2 test inputs: this size does not trigger the error
x = torch.meshgrid(
    torch.linspace(0, 10 - 1, 10) * 1.,
    torch.linspace(0, 10 - 1, 10) * 1.,
    indexing="xy",
)
x = torch.cat(
    (
        x[0].contiguous().view(x[0].numel(), 1),
        x[1].contiguous().view(x[1].numel(), 1),
    ),
    dim=1,
)
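
For reference, running the same sampling call from the reproduction script
below on this 100x2 input completes without the dtype error (a quick sanity
sketch of mine, assuming the ExactGP model defined in the script):

# Sketch: with the 10x10 grid above, prior sampling succeeds in float64.
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.fast_pred_samples():
    ok_samples = model(x).rsample(torch.Size([1]))  # shape (1, 100), dtype torch.float64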

To reproduce

import torch
import gpytorch


# Everything created below should default to double precision
torch.set_default_dtype(torch.float64)


def postprocess_rot(dist_mat):
    # exp(-d): maps the (non-squared) distance matrix to covariances
    return dist_mat.mul_(-1.0).exp_()

class TestKernel(gpytorch.kernels.Kernel):

    is_stationary = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x1, x2, **params):
        # In-place division rescales the inputs; the divisor tensor
        # inherits the float64 default dtype set above.
        x1_ = x1.div_(torch.tensor([10., 1.]))
        x2_ = x2.div_(torch.tensor([10., 1.]))
        return self.covar_dist(
            x1_, x2_, square_dist=False, dist_postprocess_func=postprocess_rot, **params
        )

class ExactGP(gpytorch.models.ExactGP):

    def __init__(self, **kwargs):
        # No training data: the model is only used to sample from the prior
        super().__init__(None, None, gpytorch.likelihoods.GaussianLikelihood())
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            TestKernel(ard_num_dims=2, **kwargs),
            grid_size=100,
            num_dims=2,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# 50x50 grid -> 2500x2 test inputs: this size triggers the error
x = torch.meshgrid(
    torch.linspace(0, 50 - 1, 50) * 1.,
    torch.linspace(0, 50 - 1, 50) * 1.,
    indexing="xy",
)
x = torch.cat(
    (
        x[0].contiguous().view(x[0].numel(), 1),
        x[1].contiguous().view(x[1].numel(), 1),
    ),
    dim=1,
)

model = ExactGP()
model.eval()

# Draw one prior sample with the fast predictive variance/sample paths enabled
with torch.no_grad(), gpytorch.settings.fast_pred_var(), gpytorch.settings.max_root_decomposition_size(100):
    with gpytorch.settings.fast_pred_samples():
        samples = model(x).rsample(torch.Size([1]))
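
For what it's worth, the inputs and model parameters themselves follow the
float64 default here; a minimal check one could append to the script (my
addition, not part of the original report):

# Sketch: confirm the test inputs and parameters follow the float64 default.
assert x.dtype == torch.float64
assert next(model.parameters()).dtype == torch.float64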

Error message

expected scalar type Double but found Float
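
One way to narrow this down is to scan the model for any tensor that escaped
the float64 default (a hedged debugging sketch; anything it prints would be a
float32 stowaway created internally):

# Sketch: list any parameter or buffer that is not the default dtype.
for name, t in list(model.named_parameters()) + list(model.named_buffers()):
    if t.dtype != torch.get_default_dtype():
        print(name, t.dtype)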

System information

torch=1.13.0
gpytorch=1.9.0
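
To test whether the mismatch is specific to the fast predictive-sampling path,
one can drop those settings and draw the same sample exactly (an assumption
worth checking rather than a confirmed fix; exact sampling may be slow for
2500 points):

# Sketch: same draw without fast_pred_samples/fast_pred_var; if this runs
# cleanly in float64, the mismatch lives in the fast sampling path.
with torch.no_grad():
    samples = model(x).rsample(torch.Size([1]))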
