
Cache the latest result of HSSP for speedup of MOTPE #5454

Closed. Wants to merge 11 commits.
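This change caches the latest hypervolume subset selection problem (HSSP) result in the study's system attributes and reuses it when the sampler asks for the same selection again, which speeds up multi-objective TPE (MOTPE). For context, here is a minimal end-to-end sketch of the code path being sped up; the objective function and trial count are illustrative only, not part of this PR:

    import optuna

    def objective(trial: optuna.Trial) -> tuple[float, float]:
        x = trial.suggest_float("x", 0.0, 1.0)
        y = trial.suggest_float("y", 0.0, 1.0)
        return x, (1.0 - x) * y  # two competing objectives

    # Multi-objective TPE: splitting trials internally solves HSSP per suggestion.
    study = optuna.create_study(
        directions=["minimize", "minimize"],
        sampler=optuna.samplers.TPESampler(seed=0),
    )
    study.optimize(objective, n_trials=100)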
optuna/_hypervolume/__init__.py (2 additions, 0 deletions)
@@ -1,5 +1,6 @@
 from optuna._hypervolume.base import BaseHypervolume
 from optuna._hypervolume.hssp import _solve_hssp
+from optuna._hypervolume.hssp import _solve_hssp_with_cache
 from optuna._hypervolume.utils import _compute_2d
 from optuna._hypervolume.wfg import WFG

@@ -8,5 +9,6 @@
"BaseHypervolume",
"_compute_2d",
"_solve_hssp",
"_solve_hssp_with_cache",
"WFG",
]
optuna/_hypervolume/hssp.py (35 additions, 0 deletions)
@@ -107,3 +107,38 @@ def _solve_hssp(
         rank_i_unique_loss_vals, indices_of_unique_loss_vals, subset_size, reference_point
     )
     return rank_i_indices[subset_indices_of_unique_loss_vals]
+
+
+def _solve_hssp_with_cache(
+    study: optuna.Study,
+    rank_i_loss_vals: np.ndarray,
+    rank_i_indices: np.ndarray,
+    subset_size: int,
+    reference_point: np.ndarray,
+) -> np.ndarray:
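+    """Solve HSSP, reusing the cached result when all arguments match the last call."""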
+    hssp_cache = study._storage.get_study_system_attrs(study._study_id).get("hssp_cache", {})
+    cached_subset_size = hssp_cache.get("subset_size")
+    cached_ref_point = np.array(hssp_cache.get("reference_point", np.zeros_like(reference_point)))
+    cached_indices = np.array(hssp_cache.get("rank_i_indices", []))
+    cached_loss_vals = np.array(hssp_cache.get("rank_i_loss_vals", []))
+
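+    # Cache hit only when every argument matches the previous call exactly:
+    # same subset size, reference point, trial indices, and loss values.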
+    if (
+        subset_size == cached_subset_size
+        and "selected_indices" in hssp_cache
+        and np.allclose(cached_ref_point, reference_point)
+        and np.array_equal(cached_indices, rank_i_indices)
+        and cached_loss_vals.shape == rank_i_loss_vals.shape
+        and np.allclose(cached_loss_vals, rank_i_loss_vals)
+    ):
+        return np.asarray(hssp_cache["selected_indices"])
+
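+    # Cache miss: record the new arguments, solve HSSP, and persist the result.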
+    hssp_cache = {
+        "rank_i_loss_vals": rank_i_loss_vals.tolist(),
+        "rank_i_indices": rank_i_indices.tolist(),
+        "subset_size": subset_size,
+        "reference_point": reference_point.tolist(),
+    }
+    selected_indices = _solve_hssp(rank_i_loss_vals, rank_i_indices, subset_size, reference_point)
+    hssp_cache["selected_indices"] = selected_indices.tolist()
+    study._storage.set_study_system_attr(study._study_id, key="hssp_cache", value=hssp_cache)
+    return selected_indices
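
The cache round-trips through the study's system attributes, which storages persist as JSON; that is why every array above is stored via .tolist() and rebuilt with np.array() on retrieval. A minimal sketch of that round trip, using the same internal storage API as the diff (for illustration only, not part of this PR):

    import numpy as np
    import optuna

    study = optuna.create_study()
    cache = {"reference_point": np.ones(2).tolist(), "subset_size": 2}
    study._storage.set_study_system_attr(study._study_id, key="hssp_cache", value=cache)
    restored = study._storage.get_study_system_attrs(study._study_id)["hssp_cache"]
    assert np.allclose(np.array(restored["reference_point"]), np.ones(2))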
optuna/_hypervolume/utils.py (2 additions, 0 deletions)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np


optuna/samplers/_tpe/sampler.py (4 additions, 2 deletions)
@@ -12,7 +12,7 @@

 from optuna._experimental import warn_experimental_argument
 from optuna._hypervolume import WFG
-from optuna._hypervolume.hssp import _solve_hssp
+from optuna._hypervolume.hssp import _solve_hssp_with_cache
 from optuna.distributions import BaseDistribution
 from optuna.distributions import CategoricalChoiceType
 from optuna.logging import get_logger
@@ -688,7 +688,9 @@ def _split_complete_trials_multi_objective(
         worst_point = np.max(rank_i_lvals, axis=0)
         reference_point = np.maximum(1.1 * worst_point, 0.9 * worst_point)
         reference_point[reference_point == 0] = EPS
-        selected_indices = _solve_hssp(rank_i_lvals, rank_i_indices, subset_size, reference_point)
+        selected_indices = _solve_hssp_with_cache(
+            study, rank_i_lvals, rank_i_indices, subset_size, reference_point
+        )
         indices_below[last_idx:] = selected_indices

         below_trials = []
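A note on the reference point construction in this hunk: np.maximum(1.1 * w, 0.9 * w) takes, per coordinate, whichever scaling is larger, so positive coordinates of the worst point grow by 10% while negative ones shrink in magnitude by 10%. Either way each coordinate moves toward larger loss, so the reference point lies beyond every observed loss vector, and the EPS substitution keeps exact zeros strictly positive. A small worked example (the EPS value here is an illustrative stand-in for the sampler's constant):

    import numpy as np

    EPS = 1e-12  # illustrative stand-in for the sampler's EPS constant

    worst_point = np.array([2.0, -3.0, 0.0])
    reference_point = np.maximum(1.1 * worst_point, 0.9 * worst_point)
    # per coordinate: 2.0 -> 2.2 (1.1x wins), -3.0 -> -2.7 (0.9x wins), 0.0 -> 0.0
    reference_point[reference_point == 0] = EPS
    # reference_point is now [2.2, -2.7, 1e-12]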
tests/hypervolume_tests/test_hssp.py (58 additions, 0 deletions)
@@ -59,3 +59,61 @@ def test_solve_hssp_duplicated_infinite_loss() -> None:
         rank_i_loss_vals=test_case, rank_i_indices=np.arange(4), subset_size=2, reference_point=r
     )
     assert (0 not in res) or (1 not in res)
+
+
+def _solve_hssp_and_check_cache(
+    study: optuna.Study,
+    pareto_sols: np.ndarray,
+    pareto_indices: np.ndarray,
+    subset_size: int,
+    ref_point: np.ndarray,
+) -> np.ndarray:
+    selected_indices = optuna._hypervolume._solve_hssp_with_cache(
+        study, pareto_sols, pareto_indices, subset_size=subset_size, reference_point=ref_point
+    )
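+    # After the call, the stored cache must mirror exactly the arguments just used.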
+    hssp_cache = study._storage.get_study_system_attrs(study._study_id)["hssp_cache"]
+    assert np.allclose(hssp_cache["rank_i_loss_vals"], pareto_sols)
+    assert np.array_equal(hssp_cache["rank_i_indices"], pareto_indices)
+    assert hssp_cache["subset_size"] == subset_size
+    assert np.allclose(hssp_cache["reference_point"], ref_point)
+    return selected_indices
+
+
+@pytest.mark.parametrize(
+    "storage",
+    (optuna.storages.RDBStorage("sqlite:///:memory:"), optuna.storages.InMemoryStorage()),
+)
+def test_solve_hssp_with_cache(storage: optuna.storages.BaseStorage) -> None:
[Review comment from a Member] At first glance, it seems that this test includes too many preparations and too much logic. How about pre-determining the arguments of _solve_hssp_and_check_cache? (One possible refactoring along these lines is sketched after the test below.)

+    n_trials = 100
+    n_objectives = 3
+    study = optuna.create_study(directions=["minimize"] * n_objectives, storage=storage)
+    rng = np.random.RandomState(42)
+    loss_vals = rng.random((n_trials, n_objectives))
+    indices = np.arange(n_trials)
+    on_front = optuna.study._multi_objective._is_pareto_front(
+        loss_vals, assume_unique_lexsorted=False
+    )
+    pareto_sols = loss_vals[on_front]
+    pareto_indices = indices[on_front]
+    ref_point = np.ones(n_objectives, dtype=float)
+    is_rank2 = optuna.study._multi_objective._is_pareto_front(
+        loss_vals[~on_front], assume_unique_lexsorted=False
+    )
+    rank_2_loss_vals = loss_vals[~on_front][is_rank2]
+    rank_2_indices = indices[~on_front][is_rank2]
+    subset_size = min(rank_2_indices.size, pareto_indices.size) // 2
+    selected_indices_list = []
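+    # Call twice with identical arguments; the second call must hit the cache.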
+    for _ in range(2):
+        selected_indices_for_pareto = _solve_hssp_and_check_cache(
+            study, pareto_sols, pareto_indices, subset_size, ref_point
+        )
+        selected_indices_list.append(selected_indices_for_pareto.copy())
+
+    selected_indices_for_rank_2 = _solve_hssp_and_check_cache(
+        study, rank_2_loss_vals, rank_2_indices, subset_size, ref_point
+    )
+    selected_indices_list.append(selected_indices_for_rank_2.copy())
+    # The two calls with identical arguments must return identical results.
+    assert np.all(np.sort(selected_indices_list[0]) == np.sort(selected_indices_list[1]))
+    # The rank-2 call must overwrite the stale cache, so its result must differ.
+    assert np.all(np.sort(selected_indices_list[0]) != np.sort(selected_indices_list[-1]))
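
One possible refactoring along the lines of the review comment above: pre-compute the test inputs once at module scope so the test body only exercises the caching behavior. A sketch under that reading (the names and this variant of the test are hypothetical, not part of the PR):

    _RNG = np.random.RandomState(42)
    _LOSS_VALS = _RNG.random((100, 3))
    _ON_FRONT = optuna.study._multi_objective._is_pareto_front(
        _LOSS_VALS, assume_unique_lexsorted=False
    )
    _PARETO_SOLS = _LOSS_VALS[_ON_FRONT]
    _PARETO_INDICES = np.arange(100)[_ON_FRONT]
    _REF_POINT = np.ones(3, dtype=float)
    _SUBSET_SIZE = _PARETO_INDICES.size // 2

    @pytest.mark.parametrize(
        "storage",
        (optuna.storages.RDBStorage("sqlite:///:memory:"), optuna.storages.InMemoryStorage()),
    )
    def test_solve_hssp_with_cache_precomputed(storage: optuna.storages.BaseStorage) -> None:
        study = optuna.create_study(directions=["minimize"] * 3, storage=storage)
        first = _solve_hssp_and_check_cache(
            study, _PARETO_SOLS, _PARETO_INDICES, _SUBSET_SIZE, _REF_POINT
        )
        # Identical arguments: the second call must be served from the cache.
        second = _solve_hssp_and_check_cache(
            study, _PARETO_SOLS, _PARETO_INDICES, _SUBSET_SIZE, _REF_POINT
        )
        assert np.array_equal(np.sort(first), np.sort(second))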