Skip to content

Commit

Permalink
Merge pull request #5424 from nabenabe0928/enhance/speed-up-wfg
Browse files Browse the repository at this point in the history
Speed up `WFG` by NumPy vectorization
  • Loading branch information
not522 committed May 8, 2024
2 parents bca90dd + 8e76a47 commit 8762a2f
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 96 deletions.
2 changes: 0 additions & 2 deletions optuna/_hypervolume/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from optuna._hypervolume.base import BaseHypervolume
from optuna._hypervolume.hssp import _solve_hssp
from optuna._hypervolume.utils import _compute_2d
from optuna._hypervolume.utils import _compute_2points_volume
from optuna._hypervolume.wfg import WFG


__all__ = [
"BaseHypervolume",
"_compute_2d",
"_compute_2points_volume",
"_solve_hssp",
"WFG",
]
13 changes: 0 additions & 13 deletions optuna/_hypervolume/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import numpy as np


def _compute_2points_volume(point1: np.ndarray, point2: np.ndarray) -> float:
"""Compute the hypervolume of the hypercube, whose diagonal endpoints are given 2 points.
Args:
point1:
The first endpoint of the hypercube's diagonal.
point2:
The second endpoint of the hypercube's diagonal.
"""

return float(np.abs(np.prod(point1 - point2)))


def _compute_2d(solution_set: np.ndarray, reference_point: np.ndarray) -> float:
"""Compute the hypervolume for the two-dimensional space.
Expand Down
94 changes: 25 additions & 69 deletions optuna/_hypervolume/wfg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from __future__ import annotations

import numpy as np

from optuna._hypervolume import _compute_2d
from optuna._hypervolume import _compute_2points_volume
from optuna._hypervolume import BaseHypervolume
from optuna.study._multi_objective import _is_pareto_front


class WFG(BaseHypervolume):
Expand All @@ -17,78 +17,34 @@ class WFG(BaseHypervolume):
"""

def __init__(self) -> None:
self._reference_point: Optional[np.ndarray] = None
self._reference_point: np.ndarray | None = None

def _compute(self, solution_set: np.ndarray, reference_point: np.ndarray) -> float:
self._reference_point = reference_point.astype(np.float64)
return self._compute_rec(solution_set.astype(np.float64))

def _compute_rec(self, solution_set: np.ndarray) -> float:
assert self._reference_point is not None
n_points = solution_set.shape[0]

if self._reference_point.shape[0] == 2:
return _compute_2d(solution_set, self._reference_point)

if n_points == 1:
return _compute_2points_volume(solution_set[0], self._reference_point)
elif n_points == 2:
volume = 0.0
volume += _compute_2points_volume(solution_set[0], self._reference_point)
volume += _compute_2points_volume(solution_set[1], self._reference_point)
intersection = self._reference_point - np.maximum(solution_set[0], solution_set[1])
volume -= np.prod(intersection)

return volume

solution_set = solution_set[solution_set[:, 0].argsort()]
return self._compute_hv(solution_set[solution_set[:, 0].argsort()].astype(np.float64))

# n_points >= 3
volume = 0.0
for i in range(n_points):
volume += self._compute_exclusive_hv(solution_set[i], solution_set[i + 1 :])
return volume

def _compute_exclusive_hv(self, point: np.ndarray, solution_set: np.ndarray) -> float:
def _compute_hv(self, sorted_sols: np.ndarray) -> float:
assert self._reference_point is not None
volume = _compute_2points_volume(point, self._reference_point)
limited_solution_set = self._limit(point, solution_set)
n_points_of_s = limited_solution_set.shape[0]
if n_points_of_s == 1:
volume -= _compute_2points_volume(limited_solution_set[0], self._reference_point)
elif n_points_of_s > 1:
volume -= self._compute_rec(limited_solution_set)
return volume

@staticmethod
def _limit(point: np.ndarray, solution_set: np.ndarray) -> np.ndarray:
"""Limit the points in the solution set for the given point.
Let `S := solution set`, `p := point` and `d := dim(p)`.
The returned solution set `S'` is
`S' = Pareto({s' | for all i in [d], exists s in S, s'_i = max(s_i, p_i)})`,
where `Pareto(T) = the points in T which are Pareto optimal`.
"""
n_points_of_s = solution_set.shape[0]

limited_solution_set = np.maximum(solution_set, point)

# Return almost Pareto optimal points for computational efficiency.
# If the points in the solution set are completely sorted along all coordinates,
# the following procedures return the complete Pareto optimal points.
# For the computational efficiency, we do not completely sort the points,
# but just sort the points according to its 0-th dimension.
if n_points_of_s <= 1:
return limited_solution_set
else:
# Assume limited_solution_set is sorted by its 0th dimension.
# Therefore, we can simply scan the limited solution set from left to right.
returned_limited_solution_set = [limited_solution_set[0]]
left = 0
right = 1
while right < n_points_of_s:
if (limited_solution_set[left] > limited_solution_set[right]).any():
left = right
returned_limited_solution_set.append(limited_solution_set[left])
right += 1
return np.asarray(returned_limited_solution_set)
inclusive_hvs = np.prod(self._reference_point - sorted_sols, axis=-1)
if inclusive_hvs.shape[0] == 1:
return float(inclusive_hvs[0])
elif inclusive_hvs.shape[0] == 2:
# S(A v B) = S(A) + S(B) - S(A ^ B).
intersec = np.prod(self._reference_point - np.maximum(sorted_sols[0], sorted_sols[1]))
return np.sum(inclusive_hvs) - intersec

limited_sols_array = np.maximum(sorted_sols[:, np.newaxis], sorted_sols)
return sum(
self._compute_exclusive_hv(limited_sols_array[i, i + 1 :], inclusive_hv)
for i, inclusive_hv in enumerate(inclusive_hvs)
)

def _compute_exclusive_hv(self, limited_sols: np.ndarray, inclusive_hv: float) -> float:
if limited_sols.shape[0] == 0:
return inclusive_hv

on_front = _is_pareto_front(limited_sols, assume_unique_lexsorted=False)
return inclusive_hv - self._compute_hv(limited_sols[on_front])
1 change: 0 additions & 1 deletion optuna/study/_multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def _calculate_nondomination_rank(
# It ensures that trials[j] will not dominate trials[i] for i < j.
# np.unique does lexsort.
unique_lexsorted_loss_values, order_inv = np.unique(loss_values, return_inverse=True, axis=0)

n_unique = unique_lexsorted_loss_values.shape[0]
# Clip n_below.
n_below = min(n_below or len(unique_lexsorted_loss_values), len(unique_lexsorted_loss_values))
Expand Down
11 changes: 0 additions & 11 deletions tests/hypervolume_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,6 @@
import optuna


def test_compute_2points_volume() -> None:
p1 = np.ones(10)
p2 = np.zeros(10)
assert 1 == optuna._hypervolume._compute_2points_volume(p1, p2)
assert 1 == optuna._hypervolume._compute_2points_volume(p2, p1)

p1 = np.ones(10) * 2
p2 = np.ones(10)
assert 1 == optuna._hypervolume._compute_2points_volume(p1, p2)


def test_compute_2d() -> None:
for n in range(2, 30):
r = n * np.ones(2)
Expand Down

0 comments on commit 8762a2f

Please sign in to comment.