diff --git a/botorch/models/kernels/__init__.py b/botorch/models/kernels/__init__.py
index 1196f74bc4..14b9662ced 100644
--- a/botorch/models/kernels/__init__.py
+++ b/botorch/models/kernels/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from botorch.models.kernels.categorical import CategoricalKernel
 from botorch.models.kernels.downsampling import DownsamplingKernel
 from botorch.models.kernels.exponential_decay import ExponentialDecayKernel
 from botorch.models.kernels.linear_truncated_fidelity import (
@@ -12,6 +13,7 @@
 
 
 __all__ = [
+    "CategoricalKernel",
     "DownsamplingKernel",
     "ExponentialDecayKernel",
     "LinearTruncatedFidelityKernel",
diff --git a/botorch/models/kernels/categorical.py b/botorch/models/kernels/categorical.py
new file mode 100644
index 0000000000..d6920d4706
--- /dev/null
+++ b/botorch/models/kernels/categorical.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from gpytorch.kernels.kernel import Kernel
+from torch import Tensor
+
+
+class CategoricalKernel(Kernel):
+    r"""A Kernel for categorical features.
+
+    Computes `exp(-(dist(x1, x2) / lengthscale)**2)`, where
+    `dist(x1, x2)` is zero if `x1 == x2` and one if `x1 != x2`.
+    The comparison is made separately for each feature. Unless
+    `last_dim_is_batch=True`, the scaled squared distances are averaged
+    over the features before exponentiating.
+
+    Note: This kernel is NOT differentiable w.r.t. the inputs.
+    """
+
+    has_lengthscale = True
+
+    def forward(
+        self,
+        x1: Tensor,
+        x2: Tensor,
+        diag: bool = False,
+        last_dim_is_batch: bool = False,
+        **kwargs
+    ) -> Tensor:
+        r"""Evaluate the kernel between `x1` and `x2`.
+
+        Args:
+            x1: A `... x n x d` tensor of categorical features.
+            x2: A `... x m x d` tensor of categorical features.
+            diag: If True, return only the diagonal of the kernel matrix.
+            last_dim_is_batch: If True, treat each of the `d` features as a
+                separate batch and return one kernel matrix per feature.
+
+        Returns:
+            A `... x n x m` tensor, or `... x d x n x m` if
+            `last_dim_is_batch=True`. If `diag=True`, the diagonal over the
+            last two dimensions of that tensor.
+        """
+        # `... x n x m x d` indicator of whether the features differ
+        delta = x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        dists = (delta / self.lengthscale.unsqueeze(-2)).pow(2)
+        if last_dim_is_batch:
+            # move the feature dimension in front of the `n x m` dimensions
+            # while keeping their order: `... x n x m x d -> ... x d x n x m`
+            dists = dists.transpose(-3, -1).transpose(-2, -1)
+        else:
+            dists = dists.mean(-1)
+        res = torch.exp(-dists)
+        if diag:
+            res = torch.diagonal(res, dim1=-1, dim2=-2)
+        return res
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index cb5e71b5e4..b8ae775b0a 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -140,13 +140,24 @@ def optimize_acqf(
 
     options = options or {}
 
+    # Handle the trivial case when all features are fixed
+    if fixed_features is not None and len(fixed_features) == bounds.shape[-1]:
+        X = torch.tensor(
+            [fixed_features[i] for i in range(bounds.shape[-1])],
+            device=bounds.device,
+            dtype=bounds.dtype,
+        )
+        X = X.expand(q, *X.shape)
+        with torch.no_grad():
+            acq_value = acq_function(X)
+        return X, acq_value
+
     if batch_initial_conditions is None:
         ic_gen = (
             gen_one_shot_kg_initial_conditions
             if isinstance(acq_function, qKnowledgeGradient)
             else gen_batch_initial_conditions
         )
-        # TODO: Generating initial candidates should use parameter constraints.
         batch_initial_conditions = ic_gen(
             acq_function=acq_function,
             bounds=bounds,
diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst
index e5aaf75d31..765ab551db 100644
--- a/sphinx/source/models.rst
+++ b/sphinx/source/models.rst
@@ -80,6 +80,9 @@ Model Components
 
 Kernels
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.models.kernels.categorical
+.. autoclass:: CategoricalKernel
+
 .. automodule:: botorch.models.kernels.downsampling
 .. autoclass:: DownsamplingKernel
 
diff --git a/test/models/kernels/test_categorical.py b/test/models/kernels/test_categorical.py
new file mode 100644
index 0000000000..d26f565f71
--- /dev/null
+++ b/test/models/kernels/test_categorical.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from botorch.models.kernels.categorical import CategoricalKernel
+from botorch.utils.testing import BotorchTestCase
+from gpytorch.test.base_kernel_test_case import BaseKernelTestCase
+
+
+class TestCategoricalKernel(BotorchTestCase, BaseKernelTestCase):
+    def create_kernel_no_ard(self, **kwargs):
+        return CategoricalKernel(**kwargs)
+
+    def create_data_no_batch(self):
+        return torch.randint(3, size=(5, 10)).to(dtype=torch.float)
+
+    def create_data_single_batch(self):
+        return torch.randint(3, size=(2, 5, 3)).to(dtype=torch.float)
+
+    def create_data_double_batch(self):
+        return torch.randint(3, size=(3, 2, 5, 3)).to(dtype=torch.float)
+
+    def test_initialize_lengthscale(self):
+        kernel = CategoricalKernel()
+        kernel.initialize(lengthscale=1)
+        actual_value = torch.tensor(1.0).view_as(kernel.lengthscale)
+        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)
+
+    def test_initialize_lengthscale_batch(self):
+        kernel = CategoricalKernel(batch_shape=torch.Size([2]))
+        ls_init = torch.tensor([1.0, 2.0])
+        kernel.initialize(lengthscale=ls_init)
+        actual_value = ls_init.view_as(kernel.lengthscale)
+        self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)
+
+    def test_forward(self):
+        x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
+        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
+        lengthscale = 2
+        kernel = CategoricalKernel().initialize(lengthscale=lengthscale)
+        kernel.eval()
+        sq_sc_dists = (x1.unsqueeze(-2) != x2.unsqueeze(-3)) ** 2 / lengthscale ** 2
+        actual = torch.exp(-sq_sc_dists.mean(-1))
+        res = kernel(x1, x2).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+    def test_active_dims(self):
+        x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
+        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
+        lengthscale = 2
+        kernel = CategoricalKernel(active_dims=[0]).initialize(lengthscale=lengthscale)
+        kernel.eval()
+        dists = x1[:, :1].unsqueeze(-2) != x2[:, :1].unsqueeze(-3)
+        sq_sc_dists = dists ** 2 / lengthscale ** 2
+        actual = torch.exp(-sq_sc_dists.mean(-1))
+        res = kernel(x1, x2).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+    def test_ard(self):
+        x1 = torch.tensor([[4, 2], [3, 1], [8, 5]], dtype=torch.float)
+        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
+        lengthscales = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)
+
+        kernel = CategoricalKernel(ard_num_dims=2)
+        kernel.initialize(lengthscale=lengthscales)
+        kernel.eval()
+
+        sq_sc_dists = (
+            x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        ) ** 2 / lengthscales.unsqueeze(-2) ** 2
+        actual = torch.exp(-sq_sc_dists.mean(-1))
+        res = kernel(x1, x2).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+        # diag
+        res = kernel(x1, x2).diag()
+        actual = torch.diagonal(actual, dim1=-1, dim2=-2)
+        self.assertTrue(torch.allclose(res, actual))
+
+        # batch_dims (expected layout is `d x n x m`)
+        actual = torch.exp(-sq_sc_dists).transpose(-1, -3).transpose(-1, -2)
+        res = kernel(x1, x2, last_dim_is_batch=True).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+        # batch_dims + diag
+        res = kernel(x1, x2, last_dim_is_batch=True).diag()
+        self.assertTrue(torch.allclose(res, torch.diagonal(actual, dim1=-1, dim2=-2)))
+
+        # batch_dims with a different number of points in x1 and x2
+        x2_short = x2[:2]
+        sq_sc_dists = (
+            x1.unsqueeze(-2) != x2_short.unsqueeze(-3)
+        ) ** 2 / lengthscales.unsqueeze(-2) ** 2
+        actual = torch.exp(-sq_sc_dists).transpose(-1, -3).transpose(-1, -2)
+        res = kernel(x1, x2_short, last_dim_is_batch=True).evaluate()
+        self.assertEqual(res.shape, torch.Size([2, 3, 2]))
+        self.assertTrue(torch.allclose(res, actual))
+
+    def test_ard_batch(self):
+        x1 = torch.tensor(
+            [
+                [[4, 2, 1], [3, 1, 5]],
+                [[3, 2, 3], [6, 1, 7]],
+            ],
+            dtype=torch.float,
+        )
+        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
+        lengthscales = torch.tensor([[[1, 2, 1]]], dtype=torch.float)
+
+        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
+        kernel.initialize(lengthscale=lengthscales)
+        kernel.eval()
+
+        sq_sc_dists = (
+            x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        ) ** 2 / lengthscales.unsqueeze(-2) ** 2
+        actual = torch.exp(-sq_sc_dists.mean(-1))
+        res = kernel(x1, x2).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+    def test_ard_separate_batch(self):
+        x1 = torch.tensor(
+            [
+                [[4, 2, 1], [3, 1, 5]],
+                [[3, 2, 3], [6, 1, 7]],
+            ],
+            dtype=torch.float,
+        )
+        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
+        lengthscales = torch.tensor([[[1, 2, 1]], [[2, 1, 0.5]]], dtype=torch.float)
+
+        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
+        kernel.initialize(lengthscale=lengthscales)
+        kernel.eval()
+
+        sq_sc_dists = (
+            x1.unsqueeze(-2) != x2.unsqueeze(-3)
+        ) ** 2 / lengthscales.unsqueeze(-2) ** 2
+        actual = torch.exp(-sq_sc_dists.mean(-1))
+        res = kernel(x1, x2).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+        # diag
+        res = kernel(x1, x2).diag()
+        actual = torch.diagonal(actual, dim1=-1, dim2=-2)
+        self.assertTrue(torch.allclose(res, actual))
+
+        # batch_dims (expected layout is `b x d x n x m`)
+        actual = torch.exp(-sq_sc_dists).transpose(-1, -3).transpose(-1, -2)
+        res = kernel(x1, x2, last_dim_is_batch=True).evaluate()
+        self.assertTrue(torch.allclose(res, actual))
+
+        # batch_dims + diag
+        res = kernel(x1, x2, last_dim_is_batch=True).diag()
+        self.assertTrue(torch.allclose(res, torch.diagonal(actual, dim1=-1, dim2=-2)))
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index dbacaa24b5..c3680104df 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -53,12 +53,12 @@ def test_optimize_acqf_joint(
         raw_samples = 10
         options = {}
         mock_acq_function = MockAcquisitionFunction()
-        cnt = 1
+        cnt = 0
         for dtype in (torch.float, torch.double):
             mock_gen_batch_initial_conditions.return_value = torch.zeros(
                 num_restarts, q, 3, device=self.device, dtype=dtype
             )
-            base_cand = torch.ones(1, q, 3, device=self.device, dtype=dtype)
+            base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
             mock_candidates = torch.cat(
                 [i * base_cand for i in range(num_restarts)], dim=0
             )
@@ -82,7 +82,10 @@ def test_optimize_acqf_joint(
             )
             self.assertTrue(torch.equal(candidates, mock_candidates[0]))
             self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
+            cnt += 1
+            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
 
+            # test generation with provided initial conditions
             candidates, acq_vals = optimize_acqf(
                 acq_function=mock_acq_function,
                 bounds=bounds,
@@ -98,7 +101,46 @@ def test_optimize_acqf_joint(
             self.assertTrue(torch.equal(candidates, mock_candidates))
             self.assertTrue(torch.equal(acq_vals, mock_acq_values))
             self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
+
+            # test fixed features
+            fixed_features = {0: 0.1}
+            mock_candidates[..., 0] = 0.1
+            mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
+            candidates, acq_vals = optimize_acqf(
+                acq_function=mock_acq_function,
+                bounds=bounds,
+                q=q,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+                options=options,
+                fixed_features=fixed_features,
+            )
+            self.assertEqual(
+                mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
+            )
+            self.assertTrue(torch.equal(candidates, mock_candidates[0]))
             cnt += 1
+            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
+
+            # test trivial case when all features are fixed
+            candidates, acq_vals = optimize_acqf(
+                acq_function=mock_acq_function,
+                bounds=bounds,
+                q=q,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+                options=options,
+                fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
+            )
+            self.assertTrue(
+                torch.equal(
+                    candidates,
+                    torch.tensor(
+                        [0.1, 0.2, 0.3], device=self.device, dtype=dtype
+                    ).expand(q, 3),
+                )
+            )
+            self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)
 
         # test OneShotAcquisitionFunction
         mock_acq_function = MockOneShotAcquisitionFunction()