diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py
index d2dc355dc..d4fa224a9 100644
--- a/gpytorch/kernels/kernel.py
+++ b/gpytorch/kernels/kernel.py
@@ -261,6 +261,15 @@ def batch_shape(self) -> torch.Size:
     def batch_shape(self, val: torch.Size):
         self._batch_shape = val
 
+    @property
+    def device(self) -> Optional[torch.device]:
+        if self.has_lengthscale:
+            return self.lengthscale.device
+        else:
+            for param in self.parameters():
+                return param.device
+            return None
+
     @property
     def dtype(self) -> torch.dtype:
         if self.has_lengthscale:
diff --git a/gpytorch/test/base_kernel_test_case.py b/gpytorch/test/base_kernel_test_case.py
index 793be1f0a..543759a49 100644
--- a/gpytorch/test/base_kernel_test_case.py
+++ b/gpytorch/test/base_kernel_test_case.py
@@ -179,3 +179,14 @@ def test_kernel_getitem_broadcast(self):
     def test_kernel_pickle_unpickle(self):
         kernel = self.create_kernel_no_ard(batch_shape=torch.Size([]))
         pickle.loads(pickle.dumps(kernel))  # Should be able to pickle and unpickle a kernel
+
+    def test_kernel_dtype_device(self):
+        kernel = self.create_kernel_no_ard(batch_shape=torch.Size([]))
+        self.assertEqual(kernel.dtype, torch.get_default_dtype())
+        self.assertEqual(kernel.device, torch.device("cpu"))
+        new_type = next(iter({torch.float32, torch.float64} - {torch.get_default_dtype()}))
+        kernel.to(dtype=new_type)
+        self.assertEqual(kernel.dtype, new_type)
+        if torch.cuda.is_available():
+            kernel.to(device="cuda")
+            self.assertNotEqual(kernel.device, torch.device("cpu"))