Skip to content

Commit

Permalink
Add device property to Kernels, add unit tests
Browse files Browse the repository at this point in the history
This is kinda useful at times.
  • Loading branch information
Balandat committed Dec 23, 2022
1 parent 071314a commit 0c35f58
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
9 changes: 9 additions & 0 deletions gpytorch/kernels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ def batch_shape(self) -> torch.Size:
def batch_shape(self, val: torch.Size):
self._batch_shape = val

@property
def device(self) -> Optional[torch.device]:
if self.has_lengthscale:
return self.lengthscale.device
else:
for param in self.parameters():
return param.device
return None

@property
def dtype(self) -> torch.dtype:
if self.has_lengthscale:
Expand Down
11 changes: 11 additions & 0 deletions gpytorch/test/base_kernel_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,14 @@ def test_kernel_getitem_broadcast(self):
def test_kernel_pickle_unpickle(self):
kernel = self.create_kernel_no_ard(batch_shape=torch.Size([]))
pickle.loads(pickle.dumps(kernel)) # Should be able to pickle and unpickle a kernel

def test_kernel_dtype_device(self):
kernel = self.create_kernel_no_ard(batch_shape=torch.Size([]))
self.assertEqual(kernel.dtype, torch.get_default_dtype())
self.assertEqual(kernel.device, torch.device("cpu"))
new_type = next(iter({torch.float32, torch.float64} - {torch.get_default_dtype()}))
kernel.to(dtype=new_type)
self.assertEqual(kernel.dtype, new_type)
if torch.cuda.is_available():
kernel.to(device="cuda")
self.assertNotEqual(kernel.device, torch.device("cpu"))

0 comments on commit 0c35f58

Please sign in to comment.