Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions botorch/models/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from botorch.models.kernels.categorical import CategoricalKernel
from botorch.models.kernels.downsampling import DownsamplingKernel
from botorch.models.kernels.exponential_decay import ExponentialDecayKernel
from botorch.models.kernels.linear_truncated_fidelity import (
Expand All @@ -12,6 +13,7 @@


__all__ = [
"CategoricalKernel",
"DownsamplingKernel",
"ExponentialDecayKernel",
"LinearTruncatedFidelityKernel",
Expand Down
40 changes: 40 additions & 0 deletions botorch/models/kernels/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from gpytorch.kernels.kernel import Kernel
from torch import Tensor


class CategoricalKernel(Kernel):
    r"""Kernel over categorical (discrete-valued) features.

    Each feature contributes a Hamming-style indicator `delta` that is zero
    when the two inputs agree on that feature and one when they differ. The
    indicator is divided by the lengthscale and squared; by default the
    squared terms are averaged over the feature dimension, so the kernel
    value is `exp(-mean((delta / lengthscale)**2))`.

    Note: This kernel is NOT differentiable w.r.t. the inputs, since it
    only looks at whether feature values are equal.
    """

    has_lengthscale = True

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
        **kwargs
    ) -> Tensor:
        r"""Evaluate the kernel between `x1` and `x2`.

        Args:
            x1: First set of inputs (presumably `... x n x d`; the shape is
                not validated here).
            x2: Second set of inputs (presumably `... x m x d`).
            diag: If True, return only the diagonal of the computed matrix.
            last_dim_is_batch: If True, do not average over the feature
                dimension; move it to the third-to-last position so each
                feature yields its own kernel matrix.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The kernel values, or their diagonal when `diag` is set.
        """
        # Pairwise elementwise mismatch indicator between rows of x1 and x2.
        differs = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled = differs / self.lengthscale.unsqueeze(-2)
        sq_dists = scaled.pow(2)
        if not last_dim_is_batch:
            # Collapse the feature dimension into one averaged distance.
            sq_dists = sq_dists.mean(-1)
        else:
            # Keep per-feature distances, exposing features as a batch dim.
            sq_dists = sq_dists.transpose(-3, -1)
        covar = torch.exp(-sq_dists)
        return torch.diagonal(covar, dim1=-1, dim2=-2) if diag else covar
13 changes: 12 additions & 1 deletion botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,24 @@ def optimize_acqf(

options = options or {}

# Handle the trivial case when all features are fixed
if fixed_features is not None and len(fixed_features) == bounds.shape[-1]:
X = torch.tensor(
[fixed_features[i] for i in range(bounds.shape[-1])],
device=bounds.device,
dtype=bounds.dtype,
)
X = X.expand(q, *X.shape)
with torch.no_grad():
acq_value = acq_function(X)
return X, acq_value

if batch_initial_conditions is None:
ic_gen = (
gen_one_shot_kg_initial_conditions
if isinstance(acq_function, qKnowledgeGradient)
else gen_batch_initial_conditions
)
# TODO: Generating initial candidates should use parameter constraints.
batch_initial_conditions = ic_gen(
acq_function=acq_function,
bounds=bounds,
Expand Down
3 changes: 3 additions & 0 deletions sphinx/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ Model Components

Kernels
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.models.kernels.categorical
.. autoclass:: CategoricalKernel

.. automodule:: botorch.models.kernels.downsampling
.. autoclass:: DownsamplingKernel

Expand Down
148 changes: 148 additions & 0 deletions test/models/kernels/test_categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.models.kernels.categorical import CategoricalKernel
from botorch.utils.testing import BotorchTestCase
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestCategoricalKernel(BotorchTestCase, BaseKernelTestCase):
    """Tests for `CategoricalKernel`.

    The `create_*` methods are fixtures for the inherited base kernel test
    case; the `test_*` methods compare kernel outputs against values
    computed directly from the mismatch indicator.
    """

    def create_kernel_no_ard(self, **kwargs):
        """Return a `CategoricalKernel` built from the given keyword args."""
        return CategoricalKernel(**kwargs)

    def create_data_no_batch(self):
        """Return random integer-valued float data of shape `5 x 10`."""
        return torch.randint(3, size=(5, 10)).to(dtype=torch.float)

    def create_data_single_batch(self):
        """Return random integer-valued float data of shape `2 x 5 x 3`."""
        return torch.randint(3, size=(2, 5, 3)).to(dtype=torch.float)

    def create_data_double_batch(self):
        """Return random integer-valued float data of shape `3 x 2 x 5 x 3`."""
        return torch.randint(3, size=(3, 2, 5, 3)).to(dtype=torch.float)

    def test_initialize_lengthscale(self):
        """A scalar lengthscale passed to `initialize` is stored as given."""
        kernel = CategoricalKernel()
        kernel.initialize(lengthscale=1)
        expected = torch.tensor(1.0).view_as(kernel.lengthscale)
        deviation = torch.norm(kernel.lengthscale - expected)
        self.assertLess(deviation, 1e-5)

    def test_initialize_lengthscale_batch(self):
        """A batched lengthscale passed to `initialize` is stored as given."""
        kernel = CategoricalKernel(batch_shape=torch.Size([2]))
        ls_init = torch.tensor([1.0, 2.0])
        kernel.initialize(lengthscale=ls_init)
        expected = ls_init.view_as(kernel.lengthscale)
        deviation = torch.norm(kernel.lengthscale - expected)
        self.assertLess(deviation, 1e-5)

    def test_forward(self):
        """The kernel matrix matches the directly computed reference."""
        x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
        lengthscale = 2
        kernel = CategoricalKernel().initialize(lengthscale=lengthscale)
        kernel.eval()

        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_sq = mismatch ** 2 / lengthscale ** 2
        expected = torch.exp(-scaled_sq.mean(-1))

        computed = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(computed, expected))

    def test_active_dims(self):
        """With `active_dims=[0]` only the first feature affects the result."""
        x1 = torch.tensor([[4, 2], [3, 1], [8, 5], [7, 6]], dtype=torch.float)
        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
        lengthscale = 2
        kernel = CategoricalKernel(active_dims=[0]).initialize(lengthscale=lengthscale)
        kernel.eval()

        # Reference uses only the first column of each input.
        mismatch = x1[:, :1].unsqueeze(-2) != x2[:, :1].unsqueeze(-3)
        scaled_sq = mismatch ** 2 / lengthscale ** 2
        expected = torch.exp(-scaled_sq.mean(-1))

        computed = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(computed, expected))

    def test_ard(self):
        """Per-dimension lengthscales, including `diag` and batch-dim modes."""
        x1 = torch.tensor([[4, 2], [3, 1], [8, 5]], dtype=torch.float)
        x2 = torch.tensor([[4, 2], [3, 0], [4, 4]], dtype=torch.float)
        lengthscales = torch.tensor([1, 2], dtype=torch.float).view(1, 1, 2)

        kernel = CategoricalKernel(ard_num_dims=2)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_sq = mismatch ** 2 / lengthscales.unsqueeze(-2) ** 2
        expected_full = torch.exp(-scaled_sq.mean(-1))
        computed = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(computed, expected_full))

        # diag
        computed = kernel(x1, x2).diag()
        expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(computed, expected_diag))

        # batch_dims
        expected_batch = torch.exp(-scaled_sq).transpose(-1, -3)
        computed = kernel(x1, x2, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(computed, expected_batch))

        # batch_dims + diag
        computed = kernel(x1, x2, last_dim_is_batch=True).diag()
        expected_batch_diag = torch.diagonal(expected_batch, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(computed, expected_batch_diag))

    def test_ard_batch(self):
        """ARD with a batch of kernels sharing one set of lengthscales."""
        x1 = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
        lengthscales = torch.tensor([[[1, 2, 1]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_sq = mismatch ** 2 / lengthscales.unsqueeze(-2) ** 2
        expected = torch.exp(-scaled_sq.mean(-1))
        computed = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(computed, expected))

    def test_ard_separate_batch(self):
        """ARD with a separate set of lengthscales for each batch element."""
        x1 = torch.tensor(
            [
                [[4, 2, 1], [3, 1, 5]],
                [[3, 2, 3], [6, 1, 7]],
            ],
            dtype=torch.float,
        )
        x2 = torch.tensor([[[4, 2, 1], [6, 0, 0]]], dtype=torch.float)
        lengthscales = torch.tensor([[[1, 2, 1]], [[2, 1, 0.5]]], dtype=torch.float)

        kernel = CategoricalKernel(batch_shape=torch.Size([2]), ard_num_dims=3)
        kernel.initialize(lengthscale=lengthscales)
        kernel.eval()

        mismatch = x1.unsqueeze(-2) != x2.unsqueeze(-3)
        scaled_sq = mismatch ** 2 / lengthscales.unsqueeze(-2) ** 2
        expected_full = torch.exp(-scaled_sq.mean(-1))
        computed = kernel(x1, x2).evaluate()
        self.assertTrue(torch.allclose(computed, expected_full))

        # diag
        computed = kernel(x1, x2).diag()
        expected_diag = torch.diagonal(expected_full, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(computed, expected_diag))

        # batch_dims
        expected_batch = torch.exp(-scaled_sq).transpose(-1, -3)
        computed = kernel(x1, x2, last_dim_is_batch=True).evaluate()
        self.assertTrue(torch.allclose(computed, expected_batch))

        # batch_dims + diag
        computed = kernel(x1, x2, last_dim_is_batch=True).diag()
        expected_batch_diag = torch.diagonal(expected_batch, dim1=-1, dim2=-2)
        self.assertTrue(torch.allclose(computed, expected_batch_diag))
46 changes: 44 additions & 2 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def test_optimize_acqf_joint(
raw_samples = 10
options = {}
mock_acq_function = MockAcquisitionFunction()
cnt = 1
cnt = 0
for dtype in (torch.float, torch.double):
mock_gen_batch_initial_conditions.return_value = torch.zeros(
num_restarts, q, 3, device=self.device, dtype=dtype
)
base_cand = torch.ones(1, q, 3, device=self.device, dtype=dtype)
base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3)
mock_candidates = torch.cat(
[i * base_cand for i in range(num_restarts)], dim=0
)
Expand All @@ -82,7 +82,10 @@ def test_optimize_acqf_joint(
)
self.assertTrue(torch.equal(candidates, mock_candidates[0]))
self.assertTrue(torch.equal(acq_vals, mock_acq_values[0]))
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test generation with provided initial conditions
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
Expand All @@ -98,7 +101,46 @@ def test_optimize_acqf_joint(
self.assertTrue(torch.equal(candidates, mock_candidates))
self.assertTrue(torch.equal(acq_vals, mock_acq_values))
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test fixed features
fixed_features = {0: 0.1}
mock_candidates[:, 0] = 0.1
mock_gen_candidates.return_value = (mock_candidates, mock_acq_values)
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
fixed_features=fixed_features,
)
self.assertEqual(
mock_gen_candidates.call_args[1]["fixed_features"], fixed_features
)
self.assertTrue(torch.equal(candidates, mock_candidates[0]))
cnt += 1
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test trivial case when all features are fixed
candidates, acq_vals = optimize_acqf(
acq_function=mock_acq_function,
bounds=bounds,
q=q,
num_restarts=num_restarts,
raw_samples=raw_samples,
options=options,
fixed_features={0: 0.1, 1: 0.2, 2: 0.3},
)
self.assertTrue(
torch.equal(
candidates,
torch.tensor(
[0.1, 0.2, 0.3], device=self.device, dtype=dtype
).expand(3, 3),
)
)
self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt)

# test OneShotAcquisitionFunction
mock_acq_function = MockOneShotAcquisitionFunction()
Expand Down