4 changes: 4 additions & 0 deletions botorch/utils/multi_objective/box_decompositions/__init__.py
@@ -5,6 +5,9 @@
# LICENSE file in the root directory of this source tree.


from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import ( # noqa E501
BoxDecompositionList,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
@@ -16,6 +19,7 @@

__all__ = [
"compute_non_dominated_hypercell_bounds_2d",
"BoxDecompositionList",
"FastNondominatedPartitioning",
"NondominatedPartitioning",
]
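A minimal import sketch (not part of the diff) showing how the new re-export in this __init__.py can be used:

    # BoxDecompositionList is now importable directly from the subpackage
    from botorch.utils.multi_objective.box_decompositions import BoxDecompositionList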
127 changes: 127 additions & 0 deletions botorch/utils/multi_objective/box_decompositions/box_decomposition_list.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Box decomposition container."""

from __future__ import annotations

from typing import List, Union

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.box_decomposition import (
BoxDecomposition,
)
from torch import Tensor
from torch.nn import Module, ModuleList


class BoxDecompositionList(Module):
r"""A list of box decompositions."""

def __init__(self, *box_decompositions: BoxDecomposition) -> None:
r"""Initialize the box decomposition list.

Args:
*box_decompositions: A variable number of box decompositions.

Example:
>>> bd1 = FastNondominatedPartitioning(ref_point, Y=Y1)
>>> bd2 = FastNondominatedPartitioning(ref_point, Y=Y2)
>>> bd = BoxDecompositionList(bd1, bd2)
"""
super().__init__()
self.box_decompositions = ModuleList(box_decompositions)

@property
def pareto_Y(self) -> List[Tensor]:
r"""This returns the non-dominated set.

Note: Internally, we store the negative pareto set (minimization).

Returns:
A list where the ith element is the `n_pareto_i x m`-dim tensor
of pareto optimal outcomes for each box_decomposition `i`.
"""
return [p.pareto_Y for p in self.box_decompositions]

@property
def ref_point(self) -> Tensor:
r"""Get the reference point.

Note: Internally, we store the negative reference point (minimization).

Returns:
A `n_box_decompositions x m`-dim tensor of reference points.
"""
return torch.stack([p.ref_point for p in self.box_decompositions], dim=0)

def get_hypercell_bounds(self) -> Tensor:
r"""Get the bounds of each hypercell in the decomposition.

Returns:
A `2 x n_box_decompositions x num_cells x num_outcomes`-dim tensor
containing the lower and upper vertices bounding each hypercell.
"""
bounds_list = []
max_num_cells = 0
for p in self.box_decompositions:
bounds = p.get_hypercell_bounds()
max_num_cells = max(max_num_cells, bounds.shape[-2])
bounds_list.append(bounds)
# pad the decomposition with empty cells so that all
# decompositions have the same number of cells
for i, bounds in enumerate(bounds_list):
num_missing = max_num_cells - bounds.shape[-2]
if num_missing > 0:
padding = torch.zeros(
2,
num_missing,
bounds.shape[-1],
dtype=bounds.dtype,
device=bounds.device,
)
bounds_list[i] = torch.cat(
[
bounds,
padding,
],
dim=-2,
)

return torch.stack(bounds_list, dim=-3)

def update(self, Y: Union[List[Tensor], Tensor]) -> None:
r"""Update the partitioning.

Args:
Y: A `n_box_decompositions x n x num_outcomes`-dim tensor or a list
where the ith element contains the new points for
box_decomposition `i`.
"""
if (
torch.is_tensor(Y)
and (Y.ndim != 3 or Y.shape[0] != len(self.box_decompositions))
) or (isinstance(Y, List) and len(Y) != len(self.box_decompositions)):
raise BotorchTensorDimensionError(
"BoxDecompositionList.update requires either a batched tensor Y, "
"with one batch per box decomposition or a list of tensors with "
"one element per box decomposition."
)
for i, p in enumerate(self.box_decompositions):
p.update(Y[i])

def compute_hypervolume(self) -> Tensor:
r"""Compute hypervolume that is dominated by the Pareto Froniter.

Returns:
A `(batch_shape)`-dim tensor containing the hypervolume dominated by
each Pareto frontier.
"""
return torch.stack(
[p.compute_hypervolume() for p in self.box_decompositions], dim=0
)
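A short usage sketch for the new container (not part of the diff; ref_point, Y1, and Y2 are illustrative tensors), following the class and method docstrings above:

    import torch
    from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
        BoxDecompositionList,
    )
    from botorch.utils.multi_objective.box_decompositions.non_dominated import (
        FastNondominatedPartitioning,
    )

    # two maximization objectives with the reference point at the origin
    ref_point = torch.zeros(2)
    Y1 = torch.tensor([[1.0, 2.0], [2.0, 0.5]])
    Y2 = torch.tensor([[3.0, 1.0]])

    bd = BoxDecompositionList(
        FastNondominatedPartitioning(ref_point=ref_point, Y=Y1),
        FastNondominatedPartitioning(ref_point=ref_point, Y=Y2),
    )
    pareto_fronts = bd.pareto_Y    # list with one Pareto set per decomposition
    hv = bd.compute_hypervolume()  # 1-dim tensor, one hypervolume per decomposition
    # update with a batched tensor: one batch of new points per decomposition
    bd.update(torch.tensor([[[2.5, 2.5]], [[1.0, 4.0]]]))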
5 changes: 5 additions & 0 deletions sphinx/source/utils.rst
@@ -65,6 +65,11 @@ Abstract Box Decompositions
.. automodule:: botorch.utils.multi_objective.box_decompositions.box_decomposition
:members:

Box Decomposition List
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.multi_objective.box_decompositions.box_decomposition_list
:members:

Box Decomposition Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.multi_objective.box_decompositions.utils
89 changes: 89 additions & 0 deletions test/utils/multi_objective/box_decompositions/test_box_decomposition_list.py
@@ -0,0 +1,89 @@
#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from itertools import product

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
BoxDecompositionList,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
)
from botorch.utils.testing import BotorchTestCase


class TestBoxDecompositionList(BotorchTestCase):
def test_box_decomposition_list(self):
ref_point_raw = torch.zeros(3, device=self.device)
pareto_Y_raw = torch.tensor(
[
[1.0, 2.0, 1.0],
[2.0, 0.5, 1.0],
],
device=self.device,
)
for m, dtype in product((2, 3), (torch.float, torch.double)):
ref_point = ref_point_raw[:m].to(dtype=dtype)
pareto_Y = pareto_Y_raw[:, :m].to(dtype=dtype)
pareto_Y_list = [pareto_Y[:0, :m], pareto_Y[:, :m]]
bds = [
FastNondominatedPartitioning(ref_point=ref_point, Y=Y)
for Y in pareto_Y_list
]
bd = BoxDecompositionList(*bds)
# test pareto Y
bd_pareto_Y_list = bd.pareto_Y
pareto_Y1 = pareto_Y_list[1]
expected_pareto_Y1 = (
pareto_Y1[torch.argsort(-pareto_Y1[:, 0])] if m == 2 else pareto_Y1
)
self.assertTrue(torch.equal(bd_pareto_Y_list[0], pareto_Y_list[0]))
self.assertTrue(torch.equal(bd_pareto_Y_list[1], expected_pareto_Y1))
# test ref_point
self.assertTrue(
torch.equal(bd.ref_point, ref_point.unsqueeze(0).expand(2, -1))
)
# test get_hypercell_bounds
cell_bounds = bd.get_hypercell_bounds()
expected_cell_bounds1 = bds[1].get_hypercell_bounds()
self.assertTrue(torch.equal(cell_bounds[:, 1], expected_cell_bounds1))
# the first Pareto set in the list is empty, so the cell bounds should
# contain one cell that spans the entire area above the ref_point and
# then empty padding cells, bounded from above and below by the ref
# point.
expected_cell_bounds0 = torch.zeros_like(expected_cell_bounds1)
# set the upper bound for the first cell to be inf
expected_cell_bounds0[1, 0, :] = float("inf")
self.assertTrue(torch.equal(cell_bounds[:, 0], expected_cell_bounds0))
# test compute_hypervolume
expected_hv = torch.stack([b.compute_hypervolume() for b in bds], dim=0)
hv = bd.compute_hypervolume()
self.assertTrue(torch.equal(expected_hv, hv))

# test update with batched tensor
new_Y = torch.empty(2, 1, m, dtype=dtype, device=self.device)
new_Y[0] = 1
new_Y[1] = 3
bd.update(new_Y)
bd_pareto_Y_list = bd.pareto_Y
self.assertTrue(torch.equal(bd_pareto_Y_list[0], new_Y[0]))
self.assertTrue(torch.equal(bd_pareto_Y_list[1], new_Y[1]))

# test update with list
bd = BoxDecompositionList(*bds)
bd.update([new_Y[0], new_Y[1]])
bd_pareto_Y_list = bd.pareto_Y
self.assertTrue(torch.equal(bd_pareto_Y_list[0], new_Y[0]))
self.assertTrue(torch.equal(bd_pareto_Y_list[1], new_Y[1]))

# test update with wrong shape
bd = BoxDecompositionList(*bds)
with self.assertRaises(BotorchTensorDimensionError):
bd.update(new_Y.unsqueeze(0))
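For intuition on the zero-padding that this test exercises, a short sketch (continuing the hypothetical bd from the usage sketch above):

    cell_bounds = bd.get_hypercell_bounds()
    # shape: 2 x n_box_decompositions x max_num_cells x num_outcomes, where dim 0
    # holds the lower and upper cell vertices; decompositions with fewer cells
    # than max_num_cells are padded with zero-volume cells at the origin, which
    # is why the empty decomposition above is compared against an all-zero
    # tensor whose single real cell has an inf upper bound
    lower, upper = cell_bounds[0], cell_bounds[1]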