Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 9 additions & 22 deletions botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def __init__(
self.constraints = constraints
self.eta = eta
self.register_buffer("ref_point", ref_point)
self.partitioning = partitioning
cell_bounds = self.partitioning.get_hypercell_bounds(ref_point=self.ref_point)
cell_bounds = partitioning.get_hypercell_bounds(ref_point=self.ref_point)
self.register_buffer("cell_lower_bounds", cell_bounds[0])
self.register_buffer("cell_upper_bounds", cell_bounds[1])
self.q = -1
Expand Down Expand Up @@ -231,35 +230,23 @@ def _compute_qehvi(self, samples: Tensor) -> Tensor:
# memory usage.
q_choose_i = self.q_subset_indices[f"q_choose_{i}"]
# this tensor is mc_samples x batch_shape x i x q_choose_i x m
obj_subsets = torch.stack(
[obj.index_select(dim=-2, index=q_choose_i[:, k]) for k in range(i)],
dim=-3,
obj_subsets = obj.index_select(dim=-2, index=q_choose_i.view(-1))
obj_subsets = obj_subsets.view(
obj.shape[:-2] + q_choose_i.shape + obj.shape[-1:]
)
# since all hyperrectangles share one vertex, the opposite vertex of the
# overlap is given by the component-wise minimum.
# take the minimum in each subset
overlap_vertices = obj_subsets.min(dim=-3).values
expanded_shape = (
batch_shape
+ self.cell_upper_bounds.shape[-2:-1]
+ overlap_vertices.shape[-2:]
)
overlap_vertices = obj_subsets.min(dim=-2).values
# add batch-dim to compute area for each segment (pseudo-pareto-vertex)
# this tensor is mc_samples x batch_shape x num_cells x q_choose_i x m
overlap_vertices = overlap_vertices.unsqueeze(-3).expand(
*batch_shape,
self.cell_lower_bounds.shape[-2],
*overlap_vertices.shape[-2:],
)
overlap_vertices = torch.min(
overlap_vertices,
self.cell_upper_bounds.view(view_shape).expand(expanded_shape),
overlap_vertices.unsqueeze(-3), self.cell_upper_bounds.view(view_shape)
)
# substract cell lower bounds, clamp min at zero
lengths_i = overlap_vertices - self.cell_lower_bounds.view(
view_shape
).expand(expanded_shape)
lengths_i = lengths_i.clamp_min(0.0)
lengths_i = (
overlap_vertices - self.cell_lower_bounds.view(view_shape)
).clamp_min(0.0)
# take product over hyperrectangle side lengths to compute area
# sum over all subsets of size i
areas_i = lengths_i.prod(dim=-1).sum(dim=-1)
Expand Down
7 changes: 6 additions & 1 deletion botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,13 @@ def get_acquisition_function(
raise ValueError("`Y` must be specified in kwargs for qEHVI")
ref_point = kwargs["ref_point"]
num_outcomes = len(ref_point)
Y = kwargs.get("Y")
# get feasible points
if constraints is not None:
feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1)
Y = Y[feas]
partitioning = NondominatedPartitioning(
num_outcomes=num_outcomes, Y=kwargs.get("Y")[:, :num_outcomes]
num_outcomes=num_outcomes, Y=Y[:, :num_outcomes]
)
return moo_monte_carlo.qExpectedHypervolumeImprovement(
model=model,
Expand Down