In [1]:
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.lazy import NonLazyTensor
import torch

In [2]:
# Fixed seed so the random inputs (and every later torch.randn draw in this
# notebook) are the same after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

N = 10000  # number of rows drawn for X
d = 1  # number of columns drawn for X
jitter_val = 0.1  # constant passed to add_diag when the solves are set up
lengthscale = 1  # value the RBF kernel's lengthscale is initialized to

rbf = RBFKernel()
rbf.initialize(lengthscale=lengthscale)

# Kernel module and inputs both live on the GPU in half precision.
kernel = ScaleKernel(rbf).cuda().half()
X = torch.randn(N, d).cuda().half()

In [3]:
class HalfsiesNonLazyTensor(NonLazyTensor):
    """NonLazyTensor wrapper that presents a float32 interface over its tensor.

    Every overridden operation delegates to the parent implementation and
    casts the result with ``.float()``.  The two matmul variants additionally
    cast the right-hand side with ``.half()`` before delegating, and ``dtype``
    always reports ``torch.float32``.
    """

    def _get_indices(self, left_indices, right_indices, *batch_indices):
        """Return the parent's indexed entries cast with ``.float()``."""
        return super()._get_indices(left_indices, right_indices, *batch_indices).float()

    def _getitem(self, *indices):
        """Return the parent's ``_getitem`` result cast with ``.float()``."""
        return super()._getitem(*indices).float()

    def _matmul(self, rhs):
        """Multiply by ``rhs`` after casting it to half; return the product as float."""
        return super()._matmul(rhs.half()).float()

    def _t_matmul(self, rhs):
        """Transpose-multiply by ``rhs`` after casting it to half; return the product as float."""
        return super()._t_matmul(rhs.half()).float()

    def _quad_form_derivative(self, left_vecs, right_vecs):
        """Return the parent's derivative terms, each cast with ``.float()``.

        The result is a tuple rather than a generator, so callers can index it,
        take its length, or iterate over it more than once.
        """
        return tuple(res.float() for res in super()._quad_form_derivative(left_vecs, right_vecs))

    def diag(self):
        """Return the parent's diagonal cast with ``.float()``."""
        return super().diag().float()

    @property
    def dtype(self):
        """Always report ``torch.float32``, whatever the wrapped tensor holds."""
        return torch.float32

    def evaluate(self):
        """Return the wrapped tensor cast with ``.float()``."""
        return self.tensor.float()

# add_jitter is avoided here because it makes a float (see the original note);
# the jitter is attached through add_diag with an explicit tensor instead.
lt = HalfsiesNonLazyTensor(
    kernel(X).evaluate()
).add_diag(
    torch.tensor(jitter_val, device=X.device)
)

In [5]:
import copy

import gpytorch

# Right-hand side shared by both solves (created on X's device, default dtype).
rhs = torch.randn(N, device=X.device)

# nn.Module.float() casts a module in place, so calling it on `kernel` would
# leave the shared kernel in fp32 for any cell that is (re-)run afterwards.
# The fp32 baseline therefore works on a deep copy and `kernel` stays in half.
kernel_fp32 = copy.deepcopy(kernel).float()

with torch.no_grad(), gpytorch.settings.max_cg_iterations(1000), gpytorch.settings.cg_tolerance(0.0001):
    print('Running mixed precision...')
    solve_half = lt.inv_matmul(rhs)
    print('Running fp32...')
    solve_float = kernel_fp32(X.float()).add_diag(torch.tensor(jitter_val, device=X.device)).inv_matmul(rhs)

Running mixed precision...
Final CG residual norm after 8.329125557793304e-05 iterations: 14
Running fp32...
Final CG residual norm after 1.0387577020765093e-07 iterations: 10
