Skip to content

Commit

Permalink
Get values after _get_observation_pairs (#4742)
Browse files Browse the repository at this point in the history
* Get values after _get_observation_pairs

* Fix tests for multiobjective optimization with constant_liar
  • Loading branch information
not522 committed Jun 26, 2023
1 parent a81f5f9 commit b00caef
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 101 deletions.
87 changes: 40 additions & 47 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import math
from typing import Any
from typing import Callable
from typing import Container
from typing import Dict
from typing import List
from typing import Optional
Expand Down Expand Up @@ -398,32 +399,51 @@ def sample_independent(

return self._sample(study, trial, {param_name: param_distribution})[param_name]

def _get_internal_repr(
    self, trials: list[FrozenTrial], search_space: dict[str, BaseDistribution]
) -> dict[str, np.ndarray]:
    """Collect the internal representations of parameters per search-space name.

    Only trials whose ``params`` contain *every* name of ``search_space`` contribute;
    a trial missing at least one of the names is skipped entirely, so all returned
    arrays have the same length.

    Args:
        trials:
            Trials to read parameter values from.
        search_space:
            Mapping whose keys are the parameter names to extract.

    Returns:
        A dictionary mapping each parameter name of ``search_space`` to an array of
        the values returned by ``to_internal_repr`` of the trial's own distribution
        for that parameter, in the order of ``trials``.
    """
    names = list(search_space)
    columns: dict[str, list[float]] = {name: [] for name in names}

    for trial in trials:
        # A trial lacking any of the requested parameters is dropped as a whole.
        if any(name not in trial.params for name in names):
            continue
        for name in names:
            raw_value = trial.params[name]
            # The conversion uses the distribution recorded in the trial itself.
            columns[name].append(trial.distributions[name].to_internal_repr(raw_value))

    return {name: np.asarray(column) for name, column in columns.items()}

def _sample(
self, study: Study, trial: FrozenTrial, search_space: Dict[str, BaseDistribution]
) -> Dict[str, Any]:
param_names = list(search_space.keys())
values, scores, violations = _get_observation_pairs(
if self._constant_liar and not study._is_multi_objective():
states = [TrialState.COMPLETE, TrialState.PRUNED, TrialState.RUNNING]
else:
states = [TrialState.COMPLETE, TrialState.PRUNED]
use_cache = not self._constant_liar
trials = study._get_trials(deepcopy=False, states=states, use_cache=use_cache)

scores, violations = _get_observation_pairs(
study,
param_names,
self._constant_liar,
trials,
self._constraints_func is not None,
)

n = sum(s < float("inf") for s, v in scores) # Ignore running trials.

# We divide data into below and above.
indices_below, indices_above = _split_observation_pairs(scores, self._gamma(n), violations)
# `None` items are intentionally converted to `nan` and then filtered out.
# For `nan` conversion, the dtype must be float.
# `None` items appear when `group=True` or `constant_liar=True`.
config_values = {k: np.asarray(v, dtype=float) for k, v in values.items()}
param_mask = np.all(~np.isnan(list(config_values.values())), axis=0)
param_mask_below, param_mask_above = param_mask[indices_below], param_mask[indices_above]
below = {k: v[indices_below[param_mask_below]] for k, v in config_values.items()}
above = {k: v[indices_above[param_mask_above]] for k, v in config_values.items()}

below_trials = np.asarray(trials, dtype=object)[indices_below].tolist()
above_trials = np.asarray(trials, dtype=object)[indices_above].tolist()
below = self._get_internal_repr(below_trials, search_space)
above = self._get_internal_repr(above_trials, search_space)

# We then sample by maximizing log likelihood ratio.
if study._is_multi_objective():
param_mask_below = []
for trial in below_trials:
param_mask_below.append(
all((param_name in trial.params) for param_name in search_space)
)
weights_below = _calculate_weights_below_for_multi_objective(
scores, indices_below, violations
)[param_mask_below]
Expand Down Expand Up @@ -540,14 +560,9 @@ def _calculate_nondomination_rank(loss_vals: np.ndarray) -> np.ndarray:

def _get_observation_pairs(
study: Study,
param_names: List[str],
constant_liar: bool = False, # TODO(hvy): Remove default value and fix unit tests.
trials: list[FrozenTrial],
constraints_enabled: bool = False,
) -> Tuple[
Dict[str, List[Optional[float]]],
List[Tuple[float, List[float]]],
Optional[List[float]],
]:
) -> tuple[list[tuple[float, list[float]]], list[float] | None]:
"""Get observation pairs from the study.
This function collects observation pairs from the complete or pruned trials of the study.
Expand Down Expand Up @@ -577,24 +592,15 @@ def _get_observation_pairs(
else:
signs.append(-1)

states: Container[TrialState]
if constant_liar:
states = (TrialState.COMPLETE, TrialState.PRUNED, TrialState.RUNNING)
else:
states = (TrialState.COMPLETE, TrialState.PRUNED)

scores = []
values: Dict[str, List[Optional[float]]] = {param_name: [] for param_name in param_names}
violations: Optional[List[float]] = [] if constraints_enabled else None
for trial in study._get_trials(deepcopy=False, states=states, use_cache=not constant_liar):
for trial in trials:
# We extract score from the trial.
if trial.state is TrialState.COMPLETE:
if trial.values is None:
continue
assert trial.values is not None
score = (-float("inf"), [sign * v for sign, v in zip(signs, trial.values)])
elif trial.state is TrialState.PRUNED:
if study._is_multi_objective():
continue
assert not study._is_multi_objective()

if len(trial.intermediate_values) > 0:
step, intermediate_value = max(trial.intermediate_values.items())
Expand All @@ -605,25 +611,12 @@ def _get_observation_pairs(
else:
score = (1, [0.0])
elif trial.state is TrialState.RUNNING:
if study._is_multi_objective():
continue

assert constant_liar
assert not study._is_multi_objective()
score = (float("inf"), [signs[0] * float("inf")])
else:
assert False
scores.append(score)

# We extract param_value from the trial.
for param_name in param_names:
param_value: Optional[float]
if param_name in trial.params:
distribution = trial.distributions[param_name]
param_value = distribution.to_internal_repr(trial.params[param_name])
else:
param_value = None
values[param_name].append(param_value)

if constraints_enabled:
assert violations is not None
if trial.state != TrialState.RUNNING:
Expand All @@ -641,7 +634,7 @@ def _get_observation_pairs(
else:
violations.append(float("inf"))

return values, scores, violations
return scores, violations


def _split_observation_pairs(
Expand Down
30 changes: 6 additions & 24 deletions tests/samplers_tests/tpe_tests/test_multi_objective_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,23 +339,8 @@ def objective(trial: optuna.trial.Trial) -> Tuple[float, float]:
)
)

assert _tpe.sampler._get_observation_pairs(study, ["x"], constant_liar) == (
{"x": [int_value, int_value]},
[(-float("inf"), [objective_value, -objective_value]) for _ in range(2)],
None,
)
assert _tpe.sampler._get_observation_pairs(study, ["y"], constant_liar) == (
{"y": [0, 0]},
[(-float("inf"), [objective_value, -objective_value]) for _ in range(2)],
None,
)
assert _tpe.sampler._get_observation_pairs(study, ["x", "y"], constant_liar) == (
{"x": [int_value, int_value], "y": [0, 0]},
[(-float("inf"), [objective_value, -objective_value]) for _ in range(2)],
None,
)
assert _tpe.sampler._get_observation_pairs(study, ["z"], constant_liar) == (
{"z": [None, None]},
states = [optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED]
assert _tpe.sampler._get_observation_pairs(study, study.get_trials(states=states)) == (
[(-float("inf"), [objective_value, -objective_value]) for _ in range(2)],
None,
)
Expand All @@ -373,13 +358,10 @@ def objective(trial: optuna.trial.Trial) -> Tuple[float, float]:
study.optimize(objective, n_trials=5)

violations = [max(0, constraint_value) for _ in range(5)]
assert _tpe.sampler._get_observation_pairs(study, ["x"], constraints_enabled=True) == (
{"x": [5.0, 5.0, 5.0, 5.0, 5.0]},
[(-float("inf"), [5.0, -5.0]) for _ in range(5)],
violations,
)
assert _tpe.sampler._get_observation_pairs(study, ["y"], constraints_enabled=True) == (
{"y": [None, None, None, None, None]},
states = (optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED)
assert _tpe.sampler._get_observation_pairs(
study, study.get_trials(states=states), constraints_enabled=True
) == (
[(-float("inf"), [5.0, -5.0]) for _ in range(5)],
violations,
)
Expand Down
44 changes: 14 additions & 30 deletions tests/samplers_tests/tpe_tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,24 +764,10 @@ def objective(trial: Trial) -> float:
(-3, [float("inf")]), # PRUNED (with a NaN intermediate value; it's treated as infinity)
(1, [sign * 0.0]), # PRUNED (without intermediate values)
]
states = (optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED)
assert _tpe.sampler._get_observation_pairs(
study, ["x"], constraints_enabled=constraints_enabled
study, study.get_trials(states=states), constraints_enabled=constraints_enabled
) == (
{"x": [5.0, 5.0, 5.0, 5.0]},
scores,
expected_violations,
)
assert _tpe.sampler._get_observation_pairs(
study, ["y"], constraints_enabled=constraints_enabled
) == (
{"y": [None, None, None, None]},
scores,
expected_violations,
)
assert _tpe.sampler._get_observation_pairs(
study, ["z"], constraints_enabled=constraints_enabled
) == (
{"z": [0, 0, 0, 0]}, # The internal representation of 'None' for z is 0
scores,
expected_violations,
)
Expand Down Expand Up @@ -823,10 +809,10 @@ def objective(trial: Trial) -> float:
study.optimize(objective, n_trials=5, catch=(RuntimeError,))

sign = 1 if direction == "minimize" else -1
states = (optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED)
assert _tpe.sampler._get_observation_pairs(
study, ["x", "y"], constraints_enabled=constraints_enabled
study, study.get_trials(states=states), constraints_enabled=constraints_enabled
) == (
{"x": [5.0, 5.0, 5.0, 5.0], "y": [6.0, 6.0, 6.0, 6.0]},
[
(-float("inf"), [sign * 11.0]), # COMPLETE
(-7, [sign * 2]), # PRUNED (with intermediate values)
Expand Down Expand Up @@ -953,13 +939,6 @@ def test_split_order(direction: str, constant_liar: bool, constraints: bool) ->
)
)

values, scores, violations = _tpe.sampler._get_observation_pairs(
study,
["x"],
constant_liar,
constraints,
)

if constant_liar:
states = [
optuna.trial.TrialState.COMPLETE,
Expand All @@ -969,7 +948,12 @@ def test_split_order(direction: str, constant_liar: bool, constraints: bool) ->
else:
states = [optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED]

assert len(values["x"]) == len(study.get_trials(states=states))
scores, violations = _tpe.sampler._get_observation_pairs(
study,
study.get_trials(states=states),
constraints,
)

assert len(scores) == len(study.get_trials(states=states))
if constraints:
assert violations is not None
Expand Down Expand Up @@ -1140,14 +1124,14 @@ def test_constant_liar_observation_pairs(direction: str) -> None:
# and `-float("inf")` during maximization.
expected_values = [(float("inf"), [float("inf") * (-1 if direction == "maximize" else 1)])]

assert _tpe.sampler._get_observation_pairs(study, ["x"], constant_liar=False) == (
{"x": []},
states = [optuna.trial.TrialState.COMPLETE, optuna.trial.TrialState.PRUNED]
assert _tpe.sampler._get_observation_pairs(study, study.get_trials(states=states)) == (
[],
None,
)

assert _tpe.sampler._get_observation_pairs(study, ["x"], constant_liar=True) == (
{"x": [2]},
states.append(optuna.trial.TrialState.RUNNING)
assert _tpe.sampler._get_observation_pairs(study, study.get_trials(states=states)) == (
expected_values,
None,
)
Expand Down

0 comments on commit b00caef

Please sign in to comment.