diff --git a/botorch/acquisition/multi_objective/analytic.py b/botorch/acquisition/multi_objective/analytic.py index 4a85a86f32..e55d439e70 100644 --- a/botorch/acquisition/multi_objective/analytic.py +++ b/botorch/acquisition/multi_objective/analytic.py @@ -31,7 +31,9 @@ ) from botorch.exceptions.errors import UnsupportedError from botorch.models.model import Model -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.transforms import t_batch_mode_transform from torch import Tensor from torch.distributions import Normal @@ -139,7 +141,7 @@ def __init__( super().__init__(model=model, objective=objective) self.register_buffer("ref_point", ref_point) self.partitioning = partitioning - cell_bounds = self.partitioning.get_hypercell_bounds(ref_point=self.ref_point) + cell_bounds = self.partitioning.get_hypercell_bounds() self.register_buffer("cell_lower_bounds", cell_bounds[0]) self.register_buffer("cell_upper_bounds", cell_bounds[1]) # create indexing tensor of shape `2^m x m` diff --git a/botorch/acquisition/multi_objective/monte_carlo.py b/botorch/acquisition/multi_objective/monte_carlo.py index 90c8abea77..0a9b8d5d65 100644 --- a/botorch/acquisition/multi_objective/monte_carlo.py +++ b/botorch/acquisition/multi_objective/monte_carlo.py @@ -31,7 +31,9 @@ from botorch.exceptions.errors import UnsupportedError from botorch.models.model import Model from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.objective import apply_constraints_nonnegative_soft from botorch.utils.torch import BufferDict from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform @@ -150,7 +152,7 @@ def __init__( self.constraints = constraints self.eta = eta self.register_buffer("ref_point", ref_point) - cell_bounds = partitioning.get_hypercell_bounds(ref_point=self.ref_point) + cell_bounds = partitioning.get_hypercell_bounds() self.register_buffer("cell_lower_bounds", cell_bounds[0]) self.register_buffer("cell_upper_bounds", cell_bounds[1]) self.q = -1 diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 36b6fbb8dc..86fede0e38 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -25,7 +25,9 @@ from botorch.exceptions.warnings import SamplingWarning from botorch.models.model import Model from botorch.sampling.samplers import IIDNormalSampler, MCSampler, SobolQMCNormalSampler -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.transforms import squeeze_last_dim from torch import Tensor from torch.quasirandom import SobolEngine @@ -122,19 +124,23 @@ def get_acquisition_function( ) elif acquisition_function_name == "qEHVI": # pyre-fixme [16]: `Model` has no attribute `train_targets` - if "ref_point" not in kwargs: + try: + ref_point = kwargs["ref_point"] + except KeyError: raise ValueError("`ref_point` must be specified in kwargs for qEHVI") - if "Y" not in kwargs: + try: + Y = kwargs["Y"] + except KeyError: raise ValueError("`Y` must be specified in kwargs for qEHVI") - ref_point = kwargs["ref_point"] - Y = kwargs.get("Y") # get feasible points if constraints is not None: feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1) Y = Y[feas] obj = objective(Y) partitioning = NondominatedPartitioning( - num_outcomes=obj.shape[-1], Y=obj, alpha=kwargs.get("alpha", 0.0) + ref_point=torch.as_tensor(ref_point, dtype=Y.dtype, device=Y.device), + Y=obj, + alpha=kwargs.get("alpha", 0.0), ) return moo_monte_carlo.qExpectedHypervolumeImprovement( model=model, diff --git a/botorch/utils/multi_objective/__init__.py b/botorch/utils/multi_objective/__init__.py index 95c4ff0cfa..36fc11800d 100644 --- a/botorch/utils/multi_objective/__init__.py +++ b/botorch/utils/multi_objective/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning from botorch.utils.multi_objective.hypervolume import Hypervolume from botorch.utils.multi_objective.pareto import is_non_dominated from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization @@ -14,5 +13,4 @@ "get_chebyshev_scalarization", "is_non_dominated", "Hypervolume", - "NondominatedPartitioning", ] diff --git a/botorch/utils/multi_objective/box_decomposition.py b/botorch/utils/multi_objective/box_decomposition.py index a530f77891..e566f0c69e 100644 --- a/botorch/utils/multi_objective/box_decomposition.py +++ b/botorch/utils/multi_objective/box_decomposition.py @@ -4,457 +4,22 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -r"""Algorithms for partitioning the non-dominated space into rectangles. - -References - -.. [Couckuyt2012] - I. Couckuyt, D. Deschrijver and T. Dhaene, "Towards Efficient - Multiobjective Optimization: Multiobjective statistical criterions," - 2012 IEEE Congress on Evolutionary Computation, Brisbane, QLD, 2012, - pp. 1-8. - +r""" +DEPRECATED - Box decomposition algorithms. +Use the botorch.utils.multi_objective.box_decompositions instead. """ -from __future__ import annotations - -from typing import Optional - -import torch -from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError -from botorch.utils.multi_objective.pareto import is_non_dominated -from torch import Tensor -from torch.nn import Module - - -class NondominatedPartitioning(Module): - r"""A class for partitioning the non-dominated space into hyper-cells. - - Note: this assumes maximization. Internally, it multiplies by -1 and performs - the decomposition under minimization. TODO: use maximization internally as well. - - Note: it is only feasible to use this algorithm to compute an exact - decomposition of the non-dominated space for `m<5` objectives (alpha=0.0). - - The alpha parameter can be increased to obtain an approximate partitioning - faster. The `alpha` is a fraction of the total hypervolume encapsuling the - entire Pareto set. When a hypercell's volume divided by the total hypervolume - is less than `alpha`, we discard the hypercell. See Figure 2 in - [Couckuyt2012]_ for a visual representation. - - This PyTorch implementation of the binary partitioning algorithm ([Couckuyt2012]_) - is adapted from numpy/tensorflow implementation at: - https://github.com/GPflow/GPflowOpt/blob/master/gpflowopt/pareto.py. - - TODO: replace this with a more efficient decomposition. E.g. - https://link.springer.com/content/pdf/10.1007/s10898-019-00798-7.pdf - """ - - def __init__( - self, - num_outcomes: int, - Y: Optional[Tensor] = None, - alpha: float = 0.0, - eps: Optional[float] = None, - ) -> None: - """Initialize NondominatedPartitioning. - - Args: - num_outcomes: The number of outcomes - Y: A `(batch_shape) x n x m`-dim tensor - alpha: a thresold fraction of total volume used in an approximate - decomposition. - eps: a small value for numerical stability - """ - super().__init__() - self.alpha = alpha - self.num_outcomes = num_outcomes - self._eps = eps - if Y is not None: - self.update(Y=Y) - - @property - def eps(self) -> float: - if self._eps is not None: - return self._eps - try: - return 1e-6 if self._pareto_Y.dtype == torch.float else 1e-8 - except AttributeError: - return 1e-6 - - @property - def pareto_Y(self) -> Tensor: - r"""This returns the non-dominated set. - - Note: in the batch case, this Pareto set is padded by repeating a - Pareto point so that all batches have the same size Pareto set. - - Note: Internally, we store the negative Pareto set (minimization). - - Returns: - A `(batch_shape) x max_n_pareto x m`-dim tensor of outcomes. - """ - if not hasattr(self, "_pareto_Y"): - raise BotorchError("pareto_Y has not been initialized") - return -self._pareto_Y - - def _update_pareto_Y(self) -> bool: - r"""Update the non-dominated front.""" - # is_non_dominated assumes maximization - pareto_mask = is_non_dominated(-self.Y) - - if len(self.batch_shape) > 0: - # Note: in the batch case, the Pareto frontier is padded by repeating - # a Pareto point. This ensures that the padded box-decomposition has - # the same number of points, which enables fast batch operations. - max_n_pareto = pareto_mask.sum(dim=-1).max().item() - pareto_Y = torch.empty( - *self.batch_shape, - max_n_pareto, - self.Y.shape[-1], - dtype=self.Y.dtype, - device=self.Y.device, - ) - for i in range(self.Y.shape[0]): - pareto_i = self.Y[i, pareto_mask[i]] - n_pareto = pareto_i.shape[0] - pareto_Y[i, :n_pareto] = pareto_i - # pad pareto_Y, so that all batches have the same size Pareto set - pareto_Y[i, n_pareto:] = pareto_i[-1] - # sort by first objective - new_pareto_Y = pareto_Y.gather( - index=torch.argsort(pareto_Y[..., :1], dim=-2).expand(pareto_Y.shape), - dim=-2, - ) - else: - pareto_Y = self.Y[pareto_mask] - # sort by first objective - new_pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])] - - if not hasattr(self, "_pareto_Y") or not torch.equal( - new_pareto_Y, self._pareto_Y - ): - self.register_buffer("_pareto_Y", new_pareto_Y) - return True - return False - - def update(self, Y: Tensor) -> None: - r"""Update non-dominated front and decomposition. - - Args: - Y: A `(batch_shape) x n x m`-dim tensor of outcomes. - """ - self.batch_shape = Y.shape[:-2] - if len(self.batch_shape) > 1: - raise NotImplementedError( - f"{type(self).__name__} only supports a single " - f"batch dimension, but got {len(self.batch_shape)} " - "batch dimensions." - ) - # multiply by -1, since internally we minimize. - self.Y = -Y - is_new_pareto = self._update_pareto_Y() - # Update decomposition if the Pareto front changed - if is_new_pareto: - if self.num_outcomes > 2: - self.binary_partition_non_dominated_space() - else: - self.partition_non_dominated_space_2d() - - def binary_partition_non_dominated_space(self): - r"""Partition the non-dominated space into disjoint hypercells. - - This method works for an arbitrary number of outcomes, but is - less efficient than `partition_non_dominated_space_2d` for the - 2-outcome case. - """ - if len(self.batch_shape) > 0: - raise NotImplementedError( - f"{type(self).__name__} only supports a batched box " - f"decompositions in the 2-objective setting." - ) - # Extend Pareto front with the ideal and anti-ideal point - ideal_point = self._pareto_Y.min(dim=0, keepdim=True).values - 1 - anti_ideal_point = self._pareto_Y.max(dim=0, keepdim=True).values + 1 - - aug_pareto_Y = torch.cat([ideal_point, self._pareto_Y, anti_ideal_point], dim=0) - # The binary parititoning algorithm uses indices the augmented Pareto front. - aug_pareto_Y_idcs = self._get_augmented_pareto_front_indices() - - # Initialize one cell over entire pareto front - cell = torch.zeros(2, self.num_outcomes, dtype=torch.long, device=self.Y.device) - cell[1] = aug_pareto_Y_idcs.shape[0] - 1 - stack = [cell] - total_volume = (anti_ideal_point - ideal_point).prod() - - # hypercells contains the indices of the (augmented) Pareto front - # that specify that bounds of the each hypercell. - # It is a `2 x num_cells x num_outcomes`-dim tensor - self.register_buffer( - "hypercells", - torch.empty( - 2, 0, self.num_outcomes, dtype=torch.long, device=self.Y.device - ), - ) - outcome_idxr = torch.arange( - self.num_outcomes, dtype=torch.long, device=self.Y.device - ) - - # Use binary partitioning - while len(stack) > 0: - cell = stack.pop() - cell_bounds_pareto_idcs = aug_pareto_Y_idcs[cell, outcome_idxr] - cell_bounds_pareto_values = aug_pareto_Y[ - cell_bounds_pareto_idcs, outcome_idxr - ] - # Check cell bounds - # - if cell upper bound is better than Pareto front on all outcomes: - # - accept the cell - # - elif cell lower bound is better than Pareto front on all outcomes: - # - this means the cell overlaps the Pareto front. Divide the cell along - # - its longest edge. - if ( - ((cell_bounds_pareto_values[1] - self.eps) < self._pareto_Y) - .any(dim=1) - .all() - ): - # Cell is entirely non-dominated - self.hypercells = torch.cat( - [self.hypercells, cell_bounds_pareto_idcs.unsqueeze(1)], dim=1 - ) - elif ( - ((cell_bounds_pareto_values[0] + self.eps) < self._pareto_Y) - .any(dim=1) - .all() - ): - # The cell overlaps the pareto front - # compute the distance (in integer indices) - idx_dist = cell[1] - cell[0] - any_not_adjacent = (idx_dist > 1).any() - cell_volume = ( - (cell_bounds_pareto_values[1] - cell_bounds_pareto_values[0]) - .prod(dim=-1) - .item() - ) - - # Only divide a cell when it is not composed of adjacent indices - # and the fraction of total volume is above the approximation - # threshold fraction - if ( - any_not_adjacent - and ((cell_volume / total_volume) > self.alpha).all() - ): - # Divide the test cell over its largest dimension - # largest (by index length) - length, longest_dim = torch.max(idx_dist, dim=0) - length = length.item() - longest_dim = longest_dim.item() - - new_length1 = int(round(length / 2.0)) - new_length2 = length - new_length1 - - # Store divided cells - # cell 1: subtract new_length1 from the upper bound of the cell - # cell 2: add new_length2 to the lower bound of the cell - for bound_idx, length_delta in ( - (1, -new_length1), - (0, new_length2), - ): - new_cell = cell.clone() - new_cell[bound_idx, longest_dim] += length_delta - stack.append(new_cell) - - def partition_non_dominated_space_2d(self) -> None: - r"""Partition the non-dominated space into disjoint hypercells. - - This direct method works for `m=2` outcomes. - """ - if self.num_outcomes != 2: - raise BotorchTensorDimensionError( - "partition_non_dominated_space_2d requires 2 outputs, " - f"but num_outcomes={self.num_outcomes}" - ) - pf_ext_idx = self._get_augmented_pareto_front_indices() - n_pf_plus_1 = self._pareto_Y.shape[-2] + 1 - view_shape = torch.Size([1] * len(self.batch_shape) + [n_pf_plus_1]) - expand_shape = self.batch_shape + torch.Size([n_pf_plus_1]) - range_pf_plus1 = torch.arange( - n_pf_plus_1, dtype=torch.long, device=self._pareto_Y.device - ) - range_pf_plus1_expanded = range_pf_plus1.view(view_shape).expand(expand_shape) - - lower = torch.stack( - [range_pf_plus1_expanded, torch.zeros_like(range_pf_plus1_expanded)], dim=-1 - ) - upper = torch.stack( - [1 + range_pf_plus1_expanded, pf_ext_idx[..., -range_pf_plus1 - 1, -1]], - dim=-1, - ) - # 2 x batch_shape x n_cells x 2 - self.register_buffer("hypercells", torch.stack([lower, upper], dim=0)) - - def _get_augmented_pareto_front_indices(self) -> Tensor: - r"""Get indices of augmented Pareto front.""" - pf_idx = torch.argsort(self._pareto_Y, dim=-2) - return torch.cat( - [ - torch.zeros( - *self.batch_shape, - 1, - self.num_outcomes, - dtype=torch.long, - device=self.Y.device, - ), - # Add 1 because index zero is used for the ideal point - pf_idx + 1, - torch.full( - torch.Size( - [ - *self.batch_shape, - 1, - self.num_outcomes, - ] - ), - self._pareto_Y.shape[-2] + 1, - dtype=torch.long, - device=self.Y.device, - ), - ], - dim=-2, - ) - - def _expand_ref_point(self, ref_point: Tensor) -> Tensor: - r"""Expand reference point to the proper batch_shape.""" - if ref_point.shape[:-1] != self.batch_shape: - if ref_point.ndim > 1: - raise BotorchTensorDimensionError( - "Expected ref_point to be a `batch_shape x m` or `m`-dim tensor, " - f"but got {ref_point.shape}." - ) - ref_point = ref_point.view( - *(1 for _ in self.batch_shape), ref_point.shape[-1] - ).expand(self.batch_shape + ref_point.shape[-1:]) - return ref_point - - def get_hypercell_bounds(self, ref_point: Tensor) -> Tensor: - r"""Get the bounds of each hypercell in the decomposition. - - Args: - ref_point: A `(batch_shape) x m`-dim tensor containing the reference point. - - Returns: - A `2 x (batch_shape) x num_cells x num_outcomes`-dim tensor containing the - lower and upper vertices bounding each hypercell. - """ - ref_point = self._expand_ref_point(ref_point=ref_point) - aug_pareto_Y = torch.cat( - [ - # -inf is the lower bound of the non-dominated space - torch.full( - torch.Size( - [ - *self.batch_shape, - 1, - self.num_outcomes, - ] - ), - float("-inf"), - dtype=self._pareto_Y.dtype, - device=self._pareto_Y.device, - ), - self._pareto_Y, - # note: internally, this class minimizes, so use negative here - -(ref_point.unsqueeze(-2)), - ], - dim=-2, - ) - minimization_cell_bounds = self._get_hypercell_bounds(aug_pareto_Y=aug_pareto_Y) - # swap upper and lower bounds and multiply by -1 - return -minimization_cell_bounds.flip(0) - - def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor: - r"""Get the bounds of each hypercell in the decomposition. - - Args: - aug_pareto_Y: A `n_pareto + 2 x m`-dim tensor containing - the augmented Pareto front. - - Returns: - A `2 x (batch_shape) x num_cells x num_outcomes`-dim tensor containing the - lower and upper vertices bounding each hypercell. - """ - num_cells = self.hypercells.shape[-2] - cells_times_outcomes = num_cells * self.num_outcomes - outcome_idxr = ( - torch.arange(self.num_outcomes, dtype=torch.long, device=self.Y.device) - .repeat(num_cells) - .view( - *(1 for _ in self.hypercells.shape[:-2]), - cells_times_outcomes, - ) - .expand(*self.hypercells.shape[:-2], cells_times_outcomes) - ) - - # this tensor is 2 x (num_cells * num_outcomes) x 2 - # the batch dim corresponds to lower/upper bound - cell_bounds_idxr = torch.stack( - [ - self.hypercells.view(*self.hypercells.shape[:-2], -1), - outcome_idxr, - ], - dim=-1, - ).view(2, -1, 2) - if len(self.batch_shape) > 0: - # TODO: support multiple batch dimensions here - batch_idxr = ( - torch.arange( - self.batch_shape[0], dtype=torch.long, device=self.Y.device - ) - .unsqueeze(1) - .expand(-1, cells_times_outcomes) - .reshape(1, -1, 1) - .expand(2, -1, 1) - ) - cell_bounds_idxr = torch.cat([batch_idxr, cell_bounds_idxr], dim=-1) - - cell_bounds_values = aug_pareto_Y[ - cell_bounds_idxr.chunk(cell_bounds_idxr.shape[-1], dim=-1) - ] - view_shape = (2, *self.batch_shape, num_cells, self.num_outcomes) - return cell_bounds_values.view(view_shape) - - def compute_hypervolume(self, ref_point: Tensor) -> Tensor: - r"""Compute the hypervolume for the given reference point. - - Note: This assumes minimization. - - This method computes the hypervolume of the non-dominated space - and computes the difference between the hypervolume between the - ideal point and hypervolume of the non-dominated space. +import warnings - Note there are much more efficient alternatives for computing - hypervolume when m > 2 (which do not require partitioning the - non-dominated space). Given such a partitioning, this method - is quite fast. +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( # noqa F401 + NondominatedPartitioning, +) - Args: - ref_point: A `(batch_shape) x m`-dim tensor containing the reference point. - Returns: - `(batch_shape)`-dim tensor containing the dominated hypervolume. - """ - ref_point = self._expand_ref_point(ref_point=ref_point) - # internally we minimize - ref_point = -ref_point.unsqueeze(-2) - if (self._pareto_Y >= ref_point).any(): - raise ValueError( - "The reference point must be greater than all pareto_Y values." - ) - ideal_point = self._pareto_Y.min(dim=-2, keepdim=True).values - aug_pareto_Y = torch.cat([ideal_point, self._pareto_Y, ref_point], dim=-2) - cell_bounds_values = self._get_hypercell_bounds(aug_pareto_Y=aug_pareto_Y) - total_volume = (ref_point - ideal_point).squeeze(-2).prod(dim=-1) - non_dom_volume = ( - (cell_bounds_values[1] - cell_bounds_values[0]).prod(dim=-1).sum(dim=-1) - ) - return total_volume - non_dom_volume +warnings.warn( + "The botorch.utils.multi_objective.box_decomposition module has " + "been renamed to botorch.utils.multi_objective.box_decompositions. " + "botorch.utils.multi_objective.box_decomposition will be removed in " + "the next release.", + DeprecationWarning, +) diff --git a/botorch/utils/multi_objective/box_decompositions/__init__.py b/botorch/utils/multi_objective/box_decompositions/__init__.py new file mode 100644 index 0000000000..c22400fde6 --- /dev/null +++ b/botorch/utils/multi_objective/box_decompositions/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) + +__all__ = [ + "NondominatedPartitioning", +] diff --git a/botorch/utils/multi_objective/box_decompositions/box_decomposition.py b/botorch/utils/multi_objective/box_decompositions/box_decomposition.py new file mode 100644 index 0000000000..3c44f7047d --- /dev/null +++ b/botorch/utils/multi_objective/box_decompositions/box_decomposition.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Box decomposition algorithms.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +import torch +from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError +from botorch.utils.multi_objective.pareto import is_non_dominated +from torch import Tensor +from torch.nn import Module + + +class BoxDecomposition(Module, ABC): + r"""An abstract class for box decompositions. + + Note: Internally, we store the negative reference point (minimization). + """ + + def __init__( + self, ref_point: Tensor, sort: bool, Y: Optional[Tensor] = None + ) -> None: + """Initialize BoxDecomposition. + + Args: + ref_point: A `m`-dim tensor containing the reference point. + sort: A boolean indicating whether to sort the Pareto frontier. + Y: A `(batch_shape) x n x m`-dim tensor of outcomes. + """ + super().__init__() + self.register_buffer("_neg_ref_point", -ref_point) + self.register_buffer("sort", torch.tensor(sort, dtype=torch.bool)) + self.num_outcomes = ref_point.shape[-1] + if Y is not None: + self.update(Y=Y) + + @property + def pareto_Y(self) -> Tensor: + r"""This returns the non-dominated set. + + Returns: + A `n_pareto x m`-dim tensor of outcomes. + """ + try: + return -self._neg_pareto_Y + except AttributeError: + raise BotorchError("pareto_Y has not been initialized") + + @property + def ref_point(self) -> Tensor: + r"""Get the reference point. + + Returns: + A `m`-dim tensor of outcomes. + """ + return -self._neg_ref_point + + @property + def Y(self) -> Tensor: + r"""Get the raw outcomes. + + Returns: + A `n x m`-dim tensor of outcomes. + """ + return -self._neg_Y + + def _update_pareto_Y(self) -> bool: + r"""Update the non-dominated front. + + Returns: + A boolean indicating whether the Pareto frontier has changed. + """ + # is_non_dominated assumes maximization + if self._neg_Y.shape[-2] == 0: + pareto_Y = self._neg_Y + else: + pareto_mask = is_non_dominated(self.Y) + if len(self.batch_shape) > 0: + # Note: in the batch case, the Pareto frontier is padded by repeating + # a Pareto point. This ensures that the padded box-decomposition has + # the same number of points, which enables fast batch operations. + max_n_pareto = pareto_mask.sum(dim=-1).max().item() + pareto_Y = torch.empty( + *self.batch_shape, + max_n_pareto, + self._neg_Y.shape[-1], + dtype=self._neg_Y.dtype, + device=self._neg_Y.device, + ) + for i in range(self._neg_Y.shape[0]): + pareto_i = self._neg_Y[i, pareto_mask[i]] + n_pareto = pareto_i.shape[0] + pareto_Y[i, :n_pareto] = pareto_i + # pad pareto_Y, so that all batches have the same size Pareto set + pareto_Y[i, n_pareto:] = pareto_i[-1] + if self.sort: + # sort by first objective + pareto_Y = pareto_Y.gather( + index=torch.argsort(pareto_Y[..., :1], dim=-2).expand( + pareto_Y.shape + ), + dim=-2, + ) + else: + pareto_Y = self._neg_Y[pareto_mask] + if self.sort: + # sort by first objective + pareto_Y = pareto_Y[torch.argsort(pareto_Y[:, 0])] + + if not hasattr(self, "_neg_pareto_Y") or not torch.equal( + pareto_Y, self._neg_pareto_Y + ): + self.register_buffer("_neg_pareto_Y", pareto_Y) + return True + return False + + def partition_space(self) -> None: + r"""Compute box decomposition.""" + try: + self.partition_space_2d() + except BotorchTensorDimensionError: + self._partition_space() + + @abstractmethod + def partition_space_2d(self) -> None: + r"""Compute box decomposition for 2 objectives.""" + pass # pragma: no cover + + @abstractmethod + def get_hypercell_bounds(self) -> Tensor: + r"""Get the bounds of each hypercell in the decomposition. + + Returns: + A `2 x num_cells x num_outcomes`-dim tensor containing the + lower and upper vertices bounding each hypercell. + """ + pass # pragma: no cover + + def update(self, Y: Tensor) -> None: + r"""Update non-dominated front and decomposition. + + Args: + Y: A `(batch_shape) x n x m`-dim tensor of outcomes. + """ + self.batch_shape = Y.shape[:-2] + if len(self.batch_shape) > 1: + raise NotImplementedError( + f"{type(self).__name__} only supports a single " + f"batch dimension, but got {len(self.batch_shape)} " + "batch dimensions." + ) + elif len(self.batch_shape) > 0 and self.num_outcomes > 2: + raise NotImplementedError( + f"{type(self).__name__} only supports a batched box " + f"decompositions in the 2-objective setting." + ) + # multiply by -1, since internally we minimize. + self._neg_Y = -Y + is_new_pareto = self._update_pareto_Y() + # Update decomposition if the Pareto front changed + if is_new_pareto: + self.partition_space() diff --git a/botorch/utils/multi_objective/box_decompositions/non_dominated.py b/botorch/utils/multi_objective/box_decompositions/non_dominated.py new file mode 100644 index 0000000000..1900336025 --- /dev/null +++ b/botorch/utils/multi_objective/box_decompositions/non_dominated.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Algorithms for partitioning the non-dominated space into rectangles. + +References + +.. [Couckuyt2012] + I. Couckuyt, D. Deschrijver and T. Dhaene, "Towards Efficient + Multiobjective Optimization: Multiobjective statistical criterions," + 2012 IEEE Congress on Evolutionary Computation, Brisbane, QLD, 2012, + pp. 1-8. + +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from botorch.exceptions.errors import BotorchTensorDimensionError +from botorch.utils.multi_objective.box_decompositions.box_decomposition import ( + BoxDecomposition, +) +from botorch.utils.multi_objective.box_decompositions.utils import _expand_ref_point +from torch import Tensor + + +class NondominatedPartitioning(BoxDecomposition): + r"""A class for partitioning the non-dominated space into hyper-cells. + + Note: this assumes maximization. Internally, it multiplies by -1 and performs + the decomposition under minimization. TODO: use maximization internally as well. + + Note: it is only feasible to use this algorithm to compute an exact + decomposition of the non-dominated space for `m<5` objectives (alpha=0.0). + + The alpha parameter can be increased to obtain an approximate partitioning + faster. The `alpha` is a fraction of the total hypervolume encapsuling the + entire Pareto set. When a hypercell's volume divided by the total hypervolume + is less than `alpha`, we discard the hypercell. See Figure 2 in + [Couckuyt2012]_ for a visual representation. + + This PyTorch implementation of the binary partitioning algorithm ([Couckuyt2012]_) + is adapted from numpy/tensorflow implementation at: + https://github.com/GPflow/GPflowOpt/blob/master/gpflowopt/pareto.py. + + TODO: replace this with a more efficient decomposition. E.g. + https://link.springer.com/content/pdf/10.1007/s10898-019-00798-7.pdf + """ + + def __init__( + self, + ref_point: Tensor, + Y: Optional[Tensor] = None, + alpha: float = 0.0, + eps: Optional[float] = None, + ) -> None: + """Initialize NondominatedPartitioning. + + Args: + ref_point: A `m`-dim tensor containing the reference point. + Y: A `(batch_shape) x n x m`-dim tensor. + alpha: A thresold fraction of total volume used in an approximate + decomposition. + eps: A small value for numerical stability. + """ + self._eps = eps + self.alpha = alpha + super().__init__(ref_point=ref_point, sort=True, Y=Y) + + @property + def eps(self) -> float: + if self._eps is not None: + return self._eps + try: + return 1e-6 if self._neg_pareto_Y.dtype == torch.float else 1e-8 + except AttributeError: + return 1e-6 + + def _partition_space(self) -> None: + r"""Partition the non-dominated space into disjoint hypercells. + + This method works for an arbitrary number of outcomes, but is + less efficient than `partition_non_dominated_space_2d` for the + 2-outcome case. + """ + # Extend Pareto front with the ideal and anti-ideal point + ideal_point = self._neg_pareto_Y.min(dim=0, keepdim=True).values - 1 + anti_ideal_point = self._neg_pareto_Y.max(dim=0, keepdim=True).values + 1 + + aug_pareto_Y = torch.cat( + [ideal_point, self._neg_pareto_Y, anti_ideal_point], dim=0 + ) + # The binary parititoning algorithm uses indices the augmented Pareto front. + aug_pareto_Y_idcs = self._get_augmented_pareto_front_indices() + + # Initialize one cell over entire pareto front + cell = torch.zeros( + 2, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device + ) + cell[1] = aug_pareto_Y_idcs.shape[0] - 1 + stack = [cell] + total_volume = (anti_ideal_point - ideal_point).prod() + + # hypercells contains the indices of the (augmented) Pareto front + # that specify that bounds of the each hypercell. + # It is a `2 x num_cells x num_outcomes`-dim tensor + self.register_buffer( + "hypercells", + torch.empty( + 2, 0, self.num_outcomes, dtype=torch.long, device=self._neg_Y.device + ), + ) + outcome_idxr = torch.arange( + self.num_outcomes, dtype=torch.long, device=self._neg_Y.device + ) + + # Use binary partitioning + while len(stack) > 0: + cell = stack.pop() + cell_bounds_pareto_idcs = aug_pareto_Y_idcs[cell, outcome_idxr] + cell_bounds_pareto_values = aug_pareto_Y[ + cell_bounds_pareto_idcs, outcome_idxr + ] + # Check cell bounds + # - if cell upper bound is better than Pareto front on all outcomes: + # - accept the cell + # - elif cell lower bound is better than Pareto front on all outcomes: + # - this means the cell overlaps the Pareto front. Divide the cell along + # - its longest edge. + if ( + ((cell_bounds_pareto_values[1] - self.eps) < self._neg_pareto_Y) + .any(dim=1) + .all() + ): + # Cell is entirely non-dominated + self.hypercells = torch.cat( + [self.hypercells, cell_bounds_pareto_idcs.unsqueeze(1)], dim=1 + ) + elif ( + ((cell_bounds_pareto_values[0] + self.eps) < self._neg_pareto_Y) + .any(dim=1) + .all() + ): + # The cell overlaps the pareto front + # compute the distance (in integer indices) + idx_dist = cell[1] - cell[0] + any_not_adjacent = (idx_dist > 1).any() + cell_volume = ( + (cell_bounds_pareto_values[1] - cell_bounds_pareto_values[0]) + .prod(dim=-1) + .item() + ) + + # Only divide a cell when it is not composed of adjacent indices + # and the fraction of total volume is above the approximation + # threshold fraction + if ( + any_not_adjacent + and ((cell_volume / total_volume) > self.alpha).all() + ): + # Divide the test cell over its largest dimension + # largest (by index length) + length, longest_dim = torch.max(idx_dist, dim=0) + length = length.item() + longest_dim = longest_dim.item() + + new_length1 = int(round(length / 2.0)) + new_length2 = length - new_length1 + + # Store divided cells + # cell 1: subtract new_length1 from the upper bound of the cell + # cell 2: add new_length2 to the lower bound of the cell + for bound_idx, length_delta in ( + (1, -new_length1), + (0, new_length2), + ): + new_cell = cell.clone() + new_cell[bound_idx, longest_dim] += length_delta + stack.append(new_cell) + + def partition_space_2d(self) -> None: + r"""Partition the non-dominated space into disjoint hypercells. + + This direct method works for `m=2` outcomes. + """ + if self.num_outcomes != 2: + raise BotorchTensorDimensionError( + "partition_non_dominated_space_2d requires 2 outputs, " + f"but num_outcomes={self.num_outcomes}" + ) + pf_ext_idx = self._get_augmented_pareto_front_indices() + n_pf_plus_1 = self._neg_pareto_Y.shape[-2] + 1 + view_shape = torch.Size([1] * len(self.batch_shape) + [n_pf_plus_1]) + expand_shape = self.batch_shape + torch.Size([n_pf_plus_1]) + range_pf_plus1 = torch.arange( + n_pf_plus_1, dtype=torch.long, device=self._neg_pareto_Y.device + ) + range_pf_plus1_expanded = range_pf_plus1.view(view_shape).expand(expand_shape) + + lower = torch.stack( + [range_pf_plus1_expanded, torch.zeros_like(range_pf_plus1_expanded)], dim=-1 + ) + upper = torch.stack( + [1 + range_pf_plus1_expanded, pf_ext_idx[..., -range_pf_plus1 - 1, -1]], + dim=-1, + ) + # 2 x batch_shape x n_cells x 2 + self.register_buffer("hypercells", torch.stack([lower, upper], dim=0)) + + def _get_augmented_pareto_front_indices(self) -> Tensor: + r"""Get indices of augmented Pareto front.""" + pf_idx = torch.argsort(self._neg_pareto_Y, dim=-2) + return torch.cat( + [ + torch.zeros( + *self.batch_shape, + 1, + self.num_outcomes, + dtype=torch.long, + device=self._neg_Y.device, + ), + # Add 1 because index zero is used for the ideal point + pf_idx + 1, + torch.full( + torch.Size( + [ + *self.batch_shape, + 1, + self.num_outcomes, + ] + ), + self._neg_pareto_Y.shape[-2] + 1, + dtype=torch.long, + device=self._neg_Y.device, + ), + ], + dim=-2, + ) + + def get_hypercell_bounds(self) -> Tensor: + r"""Get the bounds of each hypercell in the decomposition. + + Args: + ref_point: A `(batch_shape) x m`-dim tensor containing the reference point. + + Returns: + A `2 x num_cells x num_outcomes`-dim tensor containing the + lower and upper vertices bounding each hypercell. + """ + ref_point = _expand_ref_point( + ref_point=self.ref_point, batch_shape=self.batch_shape + ) + aug_pareto_Y = torch.cat( + [ + # -inf is the lower bound of the non-dominated space + torch.full( + torch.Size( + [ + *self.batch_shape, + 1, + self.num_outcomes, + ] + ), + float("-inf"), + dtype=self._neg_pareto_Y.dtype, + device=self._neg_pareto_Y.device, + ), + self._neg_pareto_Y, + # note: internally, this class minimizes, so use negative here + -(ref_point.unsqueeze(-2)), + ], + dim=-2, + ) + minimization_cell_bounds = self._get_hypercell_bounds(aug_pareto_Y=aug_pareto_Y) + # swap upper and lower bounds and multiply by -1 + return -minimization_cell_bounds.flip(0) + + def _get_hypercell_bounds(self, aug_pareto_Y: Tensor) -> Tensor: + r"""Get the bounds of each hypercell in the decomposition. + + Args: + aug_pareto_Y: A `n_pareto + 2 x m`-dim tensor containing + the augmented Pareto front. + + Returns: + A `2 x (batch_shape) x num_cells x num_outcomes`-dim tensor containing the + lower and upper vertices bounding each hypercell. + """ + num_cells = self.hypercells.shape[-2] + cells_times_outcomes = num_cells * self.num_outcomes + outcome_idxr = ( + torch.arange(self.num_outcomes, dtype=torch.long, device=self._neg_Y.device) + .repeat(num_cells) + .view( + *(1 for _ in self.hypercells.shape[:-2]), + cells_times_outcomes, + ) + .expand(*self.hypercells.shape[:-2], cells_times_outcomes) + ) + + # this tensor is 2 x (num_cells * num_outcomes) x 2 + # the batch dim corresponds to lower/upper bound + cell_bounds_idxr = torch.stack( + [ + self.hypercells.view(*self.hypercells.shape[:-2], -1), + outcome_idxr, + ], + dim=-1, + ).view(2, -1, 2) + if len(self.batch_shape) > 0: + # TODO: support multiple batch dimensions here + batch_idxr = ( + torch.arange( + self.batch_shape[0], dtype=torch.long, device=self._neg_Y.device + ) + .unsqueeze(1) + .expand(-1, cells_times_outcomes) + .reshape(1, -1, 1) + .expand(2, -1, 1) + ) + cell_bounds_idxr = torch.cat([batch_idxr, cell_bounds_idxr], dim=-1) + + cell_bounds_values = aug_pareto_Y[ + cell_bounds_idxr.chunk(cell_bounds_idxr.shape[-1], dim=-1) + ] + view_shape = (2, *self.batch_shape, num_cells, self.num_outcomes) + return cell_bounds_values.view(view_shape) + + def compute_hypervolume(self) -> Tensor: + r"""Compute the hypervolume for the given reference point. + + Note: This assumes minimization. + + This method computes the hypervolume of the non-dominated space + and computes the difference between the hypervolume between the + ideal point and hypervolume of the non-dominated space. + + Note there are much more efficient alternatives for computing + hypervolume when m > 2 (which do not require partitioning the + non-dominated space). Given such a partitioning, this method + is quite fast. + + Returns: + `(batch_shape)`-dim tensor containing the dominated hypervolume. + """ + ref_point = _expand_ref_point( + ref_point=self.ref_point, batch_shape=self.batch_shape + ) + # internally we minimize + ref_point = -ref_point.unsqueeze(-2) + if (self._neg_pareto_Y >= ref_point).any(): + raise ValueError( + "The reference point must be greater than all pareto_Y values." + ) + ideal_point = self._neg_pareto_Y.min(dim=-2, keepdim=True).values + aug_pareto_Y = torch.cat([ideal_point, self._neg_pareto_Y, ref_point], dim=-2) + cell_bounds_values = self._get_hypercell_bounds(aug_pareto_Y=aug_pareto_Y) + total_volume = (ref_point - ideal_point).squeeze(-2).prod(dim=-1) + non_dom_volume = ( + (cell_bounds_values[1] - cell_bounds_values[0]).prod(dim=-1).sum(dim=-1) + ) + return total_volume - non_dom_volume diff --git a/botorch/utils/multi_objective/box_decompositions/utils.py b/botorch/utils/multi_objective/box_decompositions/utils.py new file mode 100644 index 0000000000..f9e1ae458a --- /dev/null +++ b/botorch/utils/multi_objective/box_decompositions/utils.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from botorch.exceptions.errors import BotorchTensorDimensionError +from torch import Size, Tensor + + +def _expand_ref_point(ref_point: Tensor, batch_shape: Size) -> Tensor: + r"""Expand reference point to the proper batch_shape. + + Args: + ref_point: A `(batch_shape) x m`-dim tensor containing the reference + point. + batch_shape: The batch shape. + + Returns: + A `batch_shape x m`-dim tensor containing the expanded reference point + """ + if ref_point.shape[:-1] != batch_shape: + if ref_point.ndim > 1: + raise BotorchTensorDimensionError( + "Expected ref_point to be a `batch_shape x m` or `m`-dim tensor, " + f"but got {ref_point.shape}." + ) + ref_point = ref_point.view( + *(1 for _ in batch_shape), ref_point.shape[-1] + ).expand(batch_shape + ref_point.shape[-1:]) + return ref_point diff --git a/botorch/utils/multi_objective/pareto.py b/botorch/utils/multi_objective/pareto.py index 734a06fe8b..c10fa9a210 100644 --- a/botorch/utils/multi_objective/pareto.py +++ b/botorch/utils/multi_objective/pareto.py @@ -6,23 +6,33 @@ from __future__ import annotations +import torch from torch import Tensor -def is_non_dominated(Y: Tensor) -> Tensor: +def is_non_dominated(Y: Tensor, deduplicate: bool = True) -> Tensor: r"""Computes the non-dominated front. Note: this assumes maximization. Args: - Y: a `(batch_shape) x n x m`-dim tensor of outcomes. + Y: A `(batch_shape) x n x m`-dim tensor of outcomes. + deduplicate: A boolean indicating whether to only return + unique points on the pareto frontier. Returns: A `(batch_shape) x n`-dim boolean tensor indicating whether each point is non-dominated. """ - expanded_shape = Y.shape[:-2] + Y.shape[-2:-1] + Y.shape[-2:] - Y1 = Y.unsqueeze(-3).expand(expanded_shape) - Y2 = Y.unsqueeze(-2).expand(expanded_shape) + Y1 = Y.unsqueeze(-3) + Y2 = Y.unsqueeze(-2) dominates = (Y1 >= Y2).all(dim=-1) & (Y1 > Y2).any(dim=-1) - return ~(dominates.any(dim=-1)) + nd_mask = ~(dominates.any(dim=-1)) + if deduplicate: + # remove duplicates + # find index of first occurrence of each unique element + indices = (Y1 == Y2).all(dim=-1).long().argmax(dim=-1) + keep = torch.zeros_like(nd_mask) + keep.scatter_(dim=-1, index=indices, value=1.0) + return nd_mask & keep + return nd_mask diff --git a/sphinx/source/utils.rst b/sphinx/source/utils.rst index 626fc8f2aa..27c5129a60 100644 --- a/sphinx/source/utils.rst +++ b/sphinx/source/utils.rst @@ -55,8 +55,18 @@ Feasible Volume Multi-Objective Utilities ------------------------------------------- -Box Decompositions +Abstract Box Decompositions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.multi_objective.box_decompositions.box_decomposition + :members: + +Box Decomposition Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.multi_objective.box_decompositions.utils + :members: + +Box Decompositions [DEPRECATED - use botorch..utils.multi_objective.box_decompositions] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.utils.multi_objective.box_decomposition :members: @@ -65,6 +75,11 @@ Hypervolume .. automodule:: botorch.utils.multi_objective.hypervolume :members: +Non-dominated Partitionings +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.multi_objective.box_decompositions.non_dominated + :members: + Pareto ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.utils.multi_objective.pareto diff --git a/test/acquisition/multi_objective/test_analytic.py b/test/acquisition/multi_objective/test_analytic.py index e85401e103..1c123a2c8f 100644 --- a/test/acquisition/multi_objective/test_analytic.py +++ b/test/acquisition/multi_objective/test_analytic.py @@ -1,6 +1,8 @@ #!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch @@ -14,7 +16,9 @@ IdentityMCMultiOutputObjective, ) from botorch.exceptions.errors import BotorchError, UnsupportedError -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior @@ -66,7 +70,9 @@ def test_expected_hypervolume_improvement(self): pareto_Y = torch.tensor( [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs ) - partitioning = NondominatedPartitioning(num_outcomes=2) + partitioning = NondominatedPartitioning( + ref_point=torch.tensor(ref_point, **tkwargs) + ) # the event shape is `b x q x m` = 1 x 1 x 1 mean = torch.zeros(1, 1, 2, **tkwargs) variance = torch.zeros(1, 1, 2, **tkwargs) diff --git a/test/acquisition/multi_objective/test_monte_carlo.py b/test/acquisition/multi_objective/test_monte_carlo.py index 2a16013f99..824c069b1f 100644 --- a/test/acquisition/multi_objective/test_monte_carlo.py +++ b/test/acquisition/multi_objective/test_monte_carlo.py @@ -20,7 +20,9 @@ from botorch.exceptions.errors import BotorchError, UnsupportedError from botorch.exceptions.warnings import BotorchWarning from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior @@ -72,12 +74,13 @@ class TestQExpectedHypervolumeImprovement(BotorchTestCase): def test_q_expected_hypervolume_improvement(self): tkwargs = {"device": self.device} for dtype in (torch.float, torch.double): - ref_point = [0.0, 0.0] tkwargs["dtype"] = dtype + ref_point = [0.0, 0.0] + t_ref_point = torch.tensor(ref_point, **tkwargs) pareto_Y = torch.tensor( [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs ) - partitioning = NondominatedPartitioning(num_outcomes=2) + partitioning = NondominatedPartitioning(ref_point=t_ref_point) # the event shape is `b x q x m` = 1 x 1 x 2 samples = torch.zeros(1, 1, 2, **tkwargs) mm = MockModel(MockPosterior(samples=samples)) @@ -337,10 +340,12 @@ def test_q_expected_hypervolume_improvement(self): [[4.0, 2.0, 3.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0], [1.0, 3.0, 4.0]], **tkwargs, ) - partitioning = NondominatedPartitioning(num_outcomes=3, Y=pareto_Y) + ref_point = [-1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) samples = torch.tensor([[1.0, 2.0, 6.0]], **tkwargs).unsqueeze(0) mm = MockModel(MockPosterior(samples=samples)) - ref_point = [-1.0] * 3 + acqf = qExpectedHypervolumeImprovement( model=mm, ref_point=ref_point, @@ -353,6 +358,8 @@ def test_q_expected_hypervolume_improvement(self): # change reference point ref_point = [0.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) acqf = qExpectedHypervolumeImprovement( model=mm, ref_point=ref_point, @@ -364,6 +371,8 @@ def test_q_expected_hypervolume_improvement(self): # test m = 3, no contribution ref_point = [1.0] * 3 + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) acqf = qExpectedHypervolumeImprovement( model=mm, ref_point=ref_point, @@ -382,7 +391,8 @@ def test_q_expected_hypervolume_improvement(self): ).unsqueeze(0) mm = MockModel(MockPosterior(samples=samples)) ref_point = [-1.0] * 3 - partitioning = NondominatedPartitioning(num_outcomes=3, Y=pareto_Y) + t_ref_point = torch.tensor(ref_point, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=t_ref_point, Y=pareto_Y) acqf = qExpectedHypervolumeImprovement( model=mm, ref_point=ref_point, @@ -397,10 +407,11 @@ def test_constrained_q_expected_hypervolume_improvement(self): for dtype in (torch.float, torch.double): tkwargs = {"device": self.device, "dtype": dtype} ref_point = [0.0, 0.0] + t_ref_point = torch.tensor(ref_point, **tkwargs) pareto_Y = torch.tensor( [[4.0, 5.0], [5.0, 5.0], [8.5, 3.5], [8.5, 3.0], [9.0, 1.0]], **tkwargs ) - partitioning = NondominatedPartitioning(num_outcomes=2) + partitioning = NondominatedPartitioning(ref_point=t_ref_point) partitioning.update(Y=pareto_Y) # test q=1 diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index 07a8f037f5..f1c505bb91 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -28,7 +28,9 @@ from botorch.exceptions.errors import UnsupportedError from botorch.exceptions.warnings import SamplingWarning from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior from torch import Tensor @@ -390,7 +392,7 @@ def test_GetQEHVI(self, mock_acqf): ) _, kwargs = mock_acqf.call_args partitioning = kwargs["partitioning"] - self.assertEqual(partitioning._pareto_Y.shape[0], 0) + self.assertEqual(partitioning.pareto_Y.shape[0], 0) def test_GetUnknownAcquisitionFunction(self): with self.assertRaises(NotImplementedError): diff --git a/test/utils/multi_objective/box_decompositions/__init__.py b/test/utils/multi_objective/box_decompositions/__init__.py new file mode 100644 index 0000000000..734a1eb4e2 --- /dev/null +++ b/test/utils/multi_objective/box_decompositions/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/utils/multi_objective/box_decompositions/test_box_decomposition.py b/test/utils/multi_objective/box_decompositions/test_box_decomposition.py new file mode 100644 index 0000000000..bb48079886 --- /dev/null +++ b/test/utils/multi_objective/box_decompositions/test_box_decomposition.py @@ -0,0 +1,126 @@ +#! /usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from itertools import product +from unittest import mock + +import torch +from botorch.exceptions.errors import BotorchError +from botorch.utils.multi_objective.box_decompositions.box_decomposition import ( + BoxDecomposition, +) +from botorch.utils.testing import BotorchTestCase + + +class DummyBoxDecomposition(BoxDecomposition): + def partition_space_2d(self): + pass + + def _partition_space(self): + pass + + def compute_hypervolume(self): + pass + + def get_hypercell_bounds(self): + pass + + +class TestBoxDecomposition(BotorchTestCase): + def test_box_decomposition(self): + with self.assertRaises(TypeError): + BoxDecomposition() + ref_point_raw = torch.zeros(3, device=self.device) + Y_raw = torch.tensor( + [ + [1.0, 2.0, 0.0], + [1.0, 1.0, 0.0], + [2.0, 0.5, 0.0], + ], + device=self.device, + ) + pareto_Y_raw = torch.tensor( + [ + [1.0, 2.0, 0.0], + [2.0, 0.5, 0.0], + ], + device=self.device, + ) + for dtype, m, sort in product( + (torch.float, torch.double), (2, 3), (True, False) + ): + with mock.patch.object( + DummyBoxDecomposition, + "partition_space_2d" if m == 2 else "partition_space", + ) as mock_partition_space: + + ref_point = ref_point_raw[:m].to(dtype=dtype) + Y = Y_raw[:, :m].to(dtype=dtype) + pareto_Y = pareto_Y_raw[:, :m].to(dtype=dtype) + bd = DummyBoxDecomposition(ref_point=ref_point, sort=sort) + + # test pareto_Y before it is initialized + with self.assertRaises(BotorchError): + bd.pareto_Y + bd = DummyBoxDecomposition(ref_point=ref_point, sort=sort, Y=Y) + + mock_partition_space.assert_called_once() + # test attributes + expected_pareto_Y = ( + pareto_Y[torch.argsort(-pareto_Y[:, 0])] if sort else pareto_Y + ) + self.assertTrue(torch.equal(bd.pareto_Y, expected_pareto_Y)) + self.assertTrue(torch.equal(bd.Y, Y)) + self.assertTrue(torch.equal(bd._neg_Y, -Y)) + self.assertTrue(torch.equal(bd._neg_pareto_Y, -expected_pareto_Y)) + self.assertTrue(torch.equal(bd.ref_point, ref_point)) + self.assertTrue(torch.equal(bd._neg_ref_point, -ref_point)) + self.assertEqual(bd.num_outcomes, m) + + # test empty Y + bd = DummyBoxDecomposition(ref_point=ref_point, sort=sort, Y=Y[:0]) + self.assertTrue(torch.equal(bd.pareto_Y, expected_pareto_Y[:0])) + + # test batch mode + if m == 2: + batch_Y = torch.stack([Y, Y + 1], dim=0) + bd = DummyBoxDecomposition( + ref_point=ref_point, sort=sort, Y=batch_Y + ) + batch_expected_pareto_Y = torch.stack( + [expected_pareto_Y, expected_pareto_Y + 1], dim=0 + ) + self.assertTrue(torch.equal(bd.pareto_Y, batch_expected_pareto_Y)) + self.assertTrue(torch.equal(bd.Y, batch_Y)) + self.assertTrue(torch.equal(bd.ref_point, ref_point)) + # test batch ref point + batch_ref_point = torch.stack([ref_point, ref_point + 1], dim=0) + bd = DummyBoxDecomposition( + ref_point=batch_ref_point, sort=sort, Y=batch_Y + ) + self.assertTrue(torch.equal(bd.ref_point, batch_ref_point)) + # test multiple batch dims + with self.assertRaises(NotImplementedError): + DummyBoxDecomposition( + ref_point=ref_point, + sort=sort, + Y=batch_Y.unsqueeze(0), + ) + # test empty Y + bd = DummyBoxDecomposition( + ref_point=ref_point, sort=sort, Y=batch_Y[:, :0] + ) + self.assertTrue( + torch.equal(bd.pareto_Y, batch_expected_pareto_Y[:, :0]) + ) + + else: + with self.assertRaises(NotImplementedError): + DummyBoxDecomposition( + ref_point=ref_point, sort=sort, Y=Y.unsqueeze(0) + ) diff --git a/test/utils/multi_objective/box_decompositions/test_non_dominated.py b/test/utils/multi_objective/box_decompositions/test_non_dominated.py new file mode 100644 index 0000000000..d1d4cd4211 --- /dev/null +++ b/test/utils/multi_objective/box_decompositions/test_non_dominated.py @@ -0,0 +1,200 @@ +#! /usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import torch +from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError +from botorch.utils.multi_objective.box_decompositions.non_dominated import ( + NondominatedPartitioning, +) +from botorch.utils.testing import BotorchTestCase + + +class TestNonDominatedPartitioning(BotorchTestCase): + def test_non_dominated_partitioning(self): + tkwargs = {"device": self.device} + for dtype in (torch.float, torch.double): + tkwargs["dtype"] = dtype + ref_point = torch.zeros(2, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=ref_point) + # assert error is raised if pareto_Y has not been computed + with self.assertRaises(BotorchError): + partitioning.pareto_Y + # test eps + # no pareto_Y + self.assertEqual(partitioning.eps, 1e-6) + partitioning = NondominatedPartitioning(ref_point=ref_point, eps=1.0) + # eps set + self.assertEqual(partitioning.eps, 1.0) + # set pareto_Y + partitioning = NondominatedPartitioning(ref_point=ref_point) + Y = torch.zeros(1, 2, **tkwargs) + partitioning.update(Y=Y) + self.assertEqual(partitioning.eps, 1e-6 if dtype == torch.float else 1e-8) + + # test _update_pareto_Y + partitioning._neg_Y = -Y + self.assertFalse(partitioning._update_pareto_Y()) + + # test m=2 + arange = torch.arange(3, 9, **tkwargs) + pareto_Y = torch.stack([arange, 11 - arange], dim=-1) + Y = torch.cat( + [ + pareto_Y, + torch.tensor( + [[8.0, 2.0], [7.0, 1.0]], **tkwargs + ), # add some non-pareto elements + ], + dim=0, + ) + partitioning = NondominatedPartitioning(ref_point=ref_point, Y=Y) + sorting = torch.argsort(pareto_Y[:, 0], descending=True) + self.assertTrue(torch.equal(pareto_Y[sorting], partitioning.pareto_Y)) + inf = float("inf") + expected_cell_bounds = torch.tensor( + [ + [ + [8.0, 0.0], + [7.0, 3.0], + [6.0, 4.0], + [5.0, 5.0], + [4.0, 6.0], + [3.0, 7.0], + [0.0, 8.0], + ], + [ + [inf, inf], + [8.0, inf], + [7.0, inf], + [6.0, inf], + [5.0, inf], + [4.0, inf], + [3.0, inf], + ], + ], + **tkwargs, + ) + cell_bounds = partitioning.get_hypercell_bounds() + self.assertTrue(torch.equal(cell_bounds, expected_cell_bounds)) + # test compute hypervolume + hv = partitioning.compute_hypervolume() + self.assertEqual(hv.item(), 49.0) + # test error when reference is not worse than all pareto_Y + partitioning = NondominatedPartitioning( + ref_point=pareto_Y.max(dim=0).values, Y=Y + ) + with self.assertRaises(ValueError): + partitioning.compute_hypervolume() + + # test batched, m=2 case + Y = torch.rand(3, 10, 2, **tkwargs) + partitioning = NondominatedPartitioning(ref_point=ref_point, Y=Y) + cell_bounds = partitioning.get_hypercell_bounds() + partitionings = [] + for i in range(Y.shape[0]): + partitioning_i = NondominatedPartitioning(ref_point=ref_point, Y=Y[i]) + partitionings.append(partitioning_i) + # check pareto_Y + pareto_set1 = {tuple(x) for x in partitioning_i.pareto_Y.tolist()} + pareto_set2 = {tuple(x) for x in partitioning.pareto_Y[i].tolist()} + self.assertEqual(pareto_set1, pareto_set2) + expected_cell_bounds_i = partitioning_i.get_hypercell_bounds() + # remove padding + no_padding_cell_bounds_i = cell_bounds[:, i][ + :, ((cell_bounds[1, i] - cell_bounds[0, i]) != 0).all(dim=-1) + ] + self.assertTrue( + torch.equal(expected_cell_bounds_i, no_padding_cell_bounds_i) + ) + + # test improper Y shape (too many batch dims) + with self.assertRaises(NotImplementedError): + NondominatedPartitioning(ref_point=ref_point, Y=Y.unsqueeze(0)) + + # test batched compute_hypervolume, m=2 + hvs = partitioning.compute_hypervolume() + hvs_non_batch = torch.stack( + [ + partitioning_i.compute_hypervolume() + for partitioning_i in partitionings + ], + dim=0, + ) + self.assertTrue(torch.allclose(hvs, hvs_non_batch)) + + # test batched m>2 + ref_point = torch.zeros(3, **tkwargs) + with self.assertRaises(NotImplementedError): + NondominatedPartitioning( + ref_point=ref_point, Y=torch.cat([Y, Y[..., :1]], dim=-1) + ) + + # test error with partition_space_2d for m=3 + partitioning = NondominatedPartitioning( + ref_point=ref_point, Y=torch.zeros(1, 3, **tkwargs) + ) + with self.assertRaises(BotorchTensorDimensionError): + partitioning.partition_space_2d() + # test m=3 + pareto_Y = torch.tensor( + [[1.0, 6.0, 8.0], [2.0, 4.0, 10.0], [3.0, 5.0, 7.0]], **tkwargs + ) + ref_point = torch.tensor([-1.0, -2.0, -3.0], **tkwargs) + partitioning = NondominatedPartitioning(ref_point=ref_point, Y=pareto_Y) + sorting = torch.argsort(pareto_Y[:, 0], descending=True) + self.assertTrue(torch.equal(pareto_Y[sorting], partitioning.pareto_Y)) + + expected_cell_bounds = torch.tensor( + [ + [ + [1.0, 4.0, 7.0], + [-1.0, -2.0, 10.0], + [-1.0, 4.0, 8.0], + [1.0, -2.0, 10.0], + [1.0, 4.0, 8.0], + [-1.0, 6.0, -3.0], + [1.0, 5.0, -3.0], + [-1.0, 5.0, 8.0], + [2.0, -2.0, 7.0], + [2.0, 4.0, 7.0], + [3.0, -2.0, -3.0], + [2.0, -2.0, 8.0], + [2.0, 5.0, -3.0], + ], + [ + [2.0, 5.0, 8.0], + [1.0, 4.0, inf], + [1.0, 5.0, inf], + [2.0, 4.0, inf], + [2.0, 5.0, inf], + [1.0, inf, 8.0], + [2.0, inf, 8.0], + [2.0, inf, inf], + [3.0, 4.0, 8.0], + [3.0, 5.0, 8.0], + [inf, 5.0, 8.0], + [inf, 5.0, inf], + [inf, inf, inf], + ], + ], + **tkwargs, + ) + cell_bounds = partitioning.get_hypercell_bounds() + # cell bounds can have different order + num_matches = ( + (cell_bounds.unsqueeze(0) == expected_cell_bounds.unsqueeze(1)) + .all(dim=-1) + .any(dim=0) + .sum() + ) + self.assertTrue(num_matches, 9) + # test compute hypervolume + hv = partitioning.compute_hypervolume() + self.assertEqual(hv.item(), 358.0) + + # TODO: test approximate decomposition diff --git a/test/utils/multi_objective/box_decompositions/test_utils.py b/test/utils/multi_objective/box_decompositions/test_utils.py new file mode 100644 index 0000000000..b22d23b0a0 --- /dev/null +++ b/test/utils/multi_objective/box_decompositions/test_utils.py @@ -0,0 +1,37 @@ +#! /usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import torch +from botorch.exceptions.errors import BotorchTensorDimensionError +from botorch.utils.multi_objective.box_decompositions.utils import _expand_ref_point +from botorch.utils.testing import BotorchTestCase + + +class TestExpandRefPoint(BotorchTestCase): + def test_expand_ref_point(self): + ref_point = torch.tensor([1.0, 2.0], device=self.device) + for dtype in (torch.float, torch.double): + ref_point = ref_point.to(dtype=dtype) + # test non-batch + self.assertTrue( + torch.equal( + _expand_ref_point(ref_point, batch_shape=torch.Size([])), + ref_point, + ) + ) + self.assertTrue( + torch.equal( + _expand_ref_point(ref_point, batch_shape=torch.Size([3])), + ref_point.unsqueeze(0).expand(3, -1), + ) + ) + # test ref point with wrong shape batch_shape + with self.assertRaises(BotorchTensorDimensionError): + _expand_ref_point(ref_point.unsqueeze(0), batch_shape=torch.Size([])) + with self.assertRaises(BotorchTensorDimensionError): + _expand_ref_point(ref_point.unsqueeze(0).expand(3, -1), torch.Size([2])) diff --git a/test/utils/multi_objective/test_box_decomposition.py b/test/utils/multi_objective/test_box_decomposition.py index 074012f58a..50cfb14bce 100644 --- a/test/utils/multi_objective/test_box_decomposition.py +++ b/test/utils/multi_objective/test_box_decomposition.py @@ -1,4 +1,4 @@ -#! /usr/bin/env python3 +#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the @@ -6,198 +6,24 @@ from __future__ import annotations -import torch -from botorch.exceptions.errors import BotorchError, BotorchTensorDimensionError -from botorch.utils.multi_objective.box_decomposition import NondominatedPartitioning -from botorch.utils.testing import BotorchTestCase - - -class TestNonDominatedPartitioning(BotorchTestCase): - def test_non_dominated_partitioning(self): - tkwargs = {"device": self.device} - for dtype in (torch.float, torch.double): - tkwargs["dtype"] = dtype - partitioning = NondominatedPartitioning(num_outcomes=2) - # assert error is raised if pareto_Y has not been computed - with self.assertRaises(BotorchError): - partitioning.pareto_Y - # test eps - # no pareto_Y - self.assertEqual(partitioning.eps, 1e-6) - partitioning = NondominatedPartitioning(num_outcomes=2, eps=1.0) - # eps set - self.assertEqual(partitioning.eps, 1.0) - # set pareto_Y - partitioning = NondominatedPartitioning(num_outcomes=2) - Y = torch.zeros(1, 2, **tkwargs) - partitioning.update(Y=Y) - self.assertEqual(partitioning.eps, 1e-6 if dtype == torch.float else 1e-8) - - # test _update_pareto_Y - partitioning.Y = -Y - self.assertFalse(partitioning._update_pareto_Y()) - - # test m=2 - arange = torch.arange(3, 9, **tkwargs) - pareto_Y = torch.stack([arange, 11 - arange], dim=-1) - Y = torch.cat( - [ - pareto_Y, - torch.tensor( - [[8.0, 2.0], [7.0, 1.0]], **tkwargs - ), # add some non-pareto elements - ], - dim=0, - ) - partitioning = NondominatedPartitioning(num_outcomes=2, Y=Y) - sorting = torch.argsort(pareto_Y[:, 0], descending=True) - self.assertTrue(torch.equal(pareto_Y[sorting], partitioning.pareto_Y)) - ref_point = torch.zeros(2, **tkwargs) - inf = float("inf") - expected_cell_bounds = torch.tensor( - [ - [ - [8.0, 0.0], - [7.0, 3.0], - [6.0, 4.0], - [5.0, 5.0], - [4.0, 6.0], - [3.0, 7.0], - [0.0, 8.0], - ], - [ - [inf, inf], - [8.0, inf], - [7.0, inf], - [6.0, inf], - [5.0, inf], - [4.0, inf], - [3.0, inf], - ], - ], - **tkwargs, - ) - cell_bounds = partitioning.get_hypercell_bounds(ref_point) - self.assertTrue(torch.equal(cell_bounds, expected_cell_bounds)) - # test compute hypervolume - hv = partitioning.compute_hypervolume(ref_point) - self.assertEqual(hv.item(), 49.0) - # test error when reference is not worse than all pareto_Y - with self.assertRaises(ValueError): - partitioning.compute_hypervolume(pareto_Y.max(dim=0).values) - - # test batched, m=2 case - Y = torch.rand(3, 10, 2, **tkwargs) - partitioning = NondominatedPartitioning(num_outcomes=2, Y=Y) - cell_bounds = partitioning.get_hypercell_bounds(ref_point) - partitionings = [] - for i in range(Y.shape[0]): - partitioning_i = NondominatedPartitioning(num_outcomes=2, Y=Y[i]) - partitionings.append(partitioning_i) - # check pareto_Y - pareto_set1 = {tuple(x) for x in partitioning_i.pareto_Y.tolist()} - pareto_set2 = {tuple(x) for x in partitioning.pareto_Y[i].tolist()} - self.assertEqual(pareto_set1, pareto_set2) - expected_cell_bounds_i = partitioning_i.get_hypercell_bounds(ref_point) - # remove padding - no_padding_cell_bounds_i = cell_bounds[:, i][ - :, ((cell_bounds[1, i] - cell_bounds[0, i]) != 0).all(dim=-1) - ] - self.assertTrue( - torch.equal(expected_cell_bounds_i, no_padding_cell_bounds_i) - ) +import warnings - # test batch ref point - cell_bounds2 = partitioning.get_hypercell_bounds( - ref_point.unsqueeze(0).expand(3, 2) - ) - self.assertTrue(torch.equal(cell_bounds, cell_bounds2)) - - # test improper batch shape - with self.assertRaises(BotorchTensorDimensionError): - partitioning.get_hypercell_bounds(ref_point.unsqueeze(0).expand(4, 2)) +from botorch import settings +from botorch.utils.testing import BotorchTestCase - # test improper Y shape (too many batch dims) - with self.assertRaises(NotImplementedError): - NondominatedPartitioning(num_outcomes=2, Y=Y.unsqueeze(0)) - # test batched compute_hypervolume, m=2 - hvs = partitioning.compute_hypervolume(ref_point) - hvs_non_batch = torch.stack( - [ - partitioning_i.compute_hypervolume(ref_point) - for partitioning_i in partitionings - ], - dim=0, +class TestBoxDecompositionDeprecation(BotorchTestCase): + def test_deprecation(self): + with warnings.catch_warnings(record=True) as ws, settings.debug(True): + from botorch.utils.multi_objective.box_decomposition import ( # noqa: F401 + NondominatedPartitioning, ) - self.assertTrue(torch.allclose(hvs, hvs_non_batch)) - # test batched m>2 - with self.assertRaises(NotImplementedError): - NondominatedPartitioning( - num_outcomes=3, Y=torch.cat([Y, Y[..., :1]], dim=-1) + self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws)) + self.assertTrue( + any( + "The botorch.utils.multi_objective.box_decomposition module has " + in str(w.message) + for w in ws ) - - # test error with partition_non_dominated_space_2d for m=3 - partitioning = NondominatedPartitioning( - num_outcomes=3, Y=torch.zeros(1, 3, **tkwargs) ) - with self.assertRaises(BotorchTensorDimensionError): - partitioning.partition_non_dominated_space_2d() - # test m=3 - pareto_Y = torch.tensor( - [[1.0, 6.0, 8.0], [2.0, 4.0, 10.0], [3.0, 5.0, 7.0]], **tkwargs - ) - partitioning = NondominatedPartitioning(num_outcomes=3, Y=pareto_Y) - sorting = torch.argsort(pareto_Y[:, 0], descending=True) - self.assertTrue(torch.equal(pareto_Y[sorting], partitioning.pareto_Y)) - ref_point = torch.tensor([-1.0, -2.0, -3.0], **tkwargs) - expected_cell_bounds = torch.tensor( - [ - [ - [1.0, 4.0, 7.0], - [-1.0, -2.0, 10.0], - [-1.0, 4.0, 8.0], - [1.0, -2.0, 10.0], - [1.0, 4.0, 8.0], - [-1.0, 6.0, -3.0], - [1.0, 5.0, -3.0], - [-1.0, 5.0, 8.0], - [2.0, -2.0, 7.0], - [2.0, 4.0, 7.0], - [3.0, -2.0, -3.0], - [2.0, -2.0, 8.0], - [2.0, 5.0, -3.0], - ], - [ - [2.0, 5.0, 8.0], - [1.0, 4.0, inf], - [1.0, 5.0, inf], - [2.0, 4.0, inf], - [2.0, 5.0, inf], - [1.0, inf, 8.0], - [2.0, inf, 8.0], - [2.0, inf, inf], - [3.0, 4.0, 8.0], - [3.0, 5.0, 8.0], - [inf, 5.0, 8.0], - [inf, 5.0, inf], - [inf, inf, inf], - ], - ], - **tkwargs, - ) - cell_bounds = partitioning.get_hypercell_bounds(ref_point) - # cell bounds can have different order - num_matches = ( - (cell_bounds.unsqueeze(0) == expected_cell_bounds.unsqueeze(1)) - .all(dim=-1) - .any(dim=0) - .sum() - ) - self.assertTrue(num_matches, 9) - # test compute hypervolume - hv = partitioning.compute_hypervolume(ref_point) - self.assertEqual(hv.item(), 358.0) - - # TODO: test approximate decomposition diff --git a/test/utils/multi_objective/test_pareto.py b/test/utils/multi_objective/test_pareto.py index 3c64a33711..4d0ad178f3 100644 --- a/test/utils/multi_objective/test_pareto.py +++ b/test/utils/multi_objective/test_pareto.py @@ -22,6 +22,7 @@ def test_is_non_dominated(self) -> None: [4.0, 5.0], [5.0, 5.0], [8.5, 3.5], + [8.5, 3.5], [8.5, 3.0], [9.0, 1.0], ] @@ -36,6 +37,7 @@ def test_is_non_dominated(self) -> None: [2.0, 4.0, 1.0], [3.0, 5.0, 1.0], [2.0, 4.0, 2.0], + [2.0, 4.0, 2.0], [1.0, 3.0, 4.0], [1.0, 2.0, 4.0], [1.0, 2.0, 6.0], @@ -67,15 +69,39 @@ def test_is_non_dominated(self) -> None: # test 2d nondom_Y = Y[is_non_dominated(Y)] self.assertTrue(torch.equal(expected_nondom_Y, nondom_Y)) + # test deduplicate=False + expected_nondom_Y_no_dedup = torch.cat( + [expected_nondom_Y, expected_nondom_Y[-1:]], dim=0 + ) + nondom_Y = Y[is_non_dominated(Y, deduplicate=False)] + self.assertTrue(torch.equal(expected_nondom_Y_no_dedup, nondom_Y)) + # test batch batch_Y = torch.stack([Y, Yb], dim=0) nondom_mask = is_non_dominated(batch_Y) self.assertTrue(torch.equal(batch_Y[0][nondom_mask[0]], expected_nondom_Y)) self.assertTrue(torch.equal(batch_Y[1][nondom_mask[1]], expected_nondom_Yb)) + # test deduplicate=False + expected_nondom_Yb_no_dedup = torch.cat( + [expected_nondom_Yb[:-1], expected_nondom_Yb[-2:]], dim=0 + ) + nondom_mask = is_non_dominated(batch_Y, deduplicate=False) + self.assertTrue( + torch.equal(batch_Y[0][nondom_mask[0]], expected_nondom_Y_no_dedup) + ) + self.assertTrue( + torch.equal(batch_Y[1][nondom_mask[1]], expected_nondom_Yb_no_dedup) + ) # test 3d nondom_Y3 = Y3[is_non_dominated(Y3)] self.assertTrue(torch.equal(expected_nondom_Y3, nondom_Y3)) + # test deduplicate=False + expected_nondom_Y3_no_dedup = torch.cat( + [expected_nondom_Y3[:3], expected_nondom_Y3[2:]], dim=0 + ) + nondom_Y3 = Y3[is_non_dominated(Y3, deduplicate=False)] + self.assertTrue(torch.equal(expected_nondom_Y3_no_dedup, nondom_Y3)) # test batch batch_Y3 = torch.stack([Y3, Y3b], dim=0) nondom_mask3 = is_non_dominated(batch_Y3) @@ -85,3 +111,14 @@ def test_is_non_dominated(self) -> None: self.assertTrue( torch.equal(batch_Y3[1][nondom_mask3[1]], expected_nondom_Y3b) ) + # test deduplicate=False + nondom_mask3 = is_non_dominated(batch_Y3, deduplicate=False) + self.assertTrue( + torch.equal(batch_Y3[0][nondom_mask3[0]], expected_nondom_Y3_no_dedup) + ) + expected_nondom_Y3b_no_dedup = torch.cat( + [expected_nondom_Y3b[:2], expected_nondom_Y3b[1:]], dim=0 + ) + self.assertTrue( + torch.equal(batch_Y3[1][nondom_mask3[1]], expected_nondom_Y3b_no_dedup) + )