Skip to content

Commit

Permalink
Merge pull request #4570 from contramundum53/backport-4462
Browse files Browse the repository at this point in the history
Backport #4462 for v3.1.1
  • Loading branch information
HideakiImamura committed Apr 3, 2023
2 parents 7f4da50 + f3292ab commit 970479b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
5 changes: 2 additions & 3 deletions optuna/samplers/_tpe/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,9 @@ def _sample_relative(
indices_below, indices_above = _split_observation_pairs(scores, self._gamma(n), violations)
# `None` items are intentionally converted to `nan` and then filtered out.
# For `nan` conversion, the dtype must be float.
# `None` items appear only when `group=True`. We just use the first parameter because the
# masks are the same for all parameters in one group.
# `None` items appear when `group=True` or `constant_liar=True`.
config_values = {k: np.asarray(v, dtype=float) for k, v in values.items()}
param_mask = ~np.isnan(list(config_values.values())[0])
param_mask = np.all(~np.isnan(list(config_values.values())), axis=0)
param_mask_below, param_mask_above = param_mask[indices_below], param_mask[indices_above]
below = {k: v[indices_below[param_mask_below]] for k, v in config_values.items()}
above = {k: v[indices_above[param_mask_above]] for k, v in config_values.items()}
Expand Down
31 changes: 31 additions & 0 deletions tests/samplers_tests/tpe_tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,3 +1157,34 @@ def test_constant_liar_observation_pairs(direction: str) -> None:
def test_constant_liar_experimental_warning() -> None:
with pytest.warns(optuna.exceptions.ExperimentalWarning):
_ = TPESampler(constant_liar=True)


@pytest.mark.parametrize("multivariate", [True, False])
def test_constant_liar_with_running_trial(multivariate: bool) -> None:
with warnings.catch_warnings():
warnings.simplefilter("ignore", optuna.exceptions.ExperimentalWarning)
sampler = TPESampler(multivariate=multivariate, constant_liar=True, n_startup_trials=0)

study = optuna.create_study(sampler=sampler)

# Add a complete trial.
trial0 = study.ask()
trial0.suggest_int("x", 0, 10)
trial0.suggest_float("y", 0, 10)
trial0.suggest_categorical("z", [0, 1, 2])
study.tell(trial0, 0)

# Add running trials.
trial1 = study.ask()
trial1.suggest_int("x", 0, 10)
trial2 = study.ask()
trial2.suggest_float("y", 0, 10)
trial3 = study.ask()
trial3.suggest_categorical("z", [0, 1, 2])

# Test suggestion with running trials.
trial = study.ask()
trial.suggest_int("x", 0, 10)
trial.suggest_float("y", 0, 10)
trial.suggest_categorical("z", [0, 1, 2])
study.tell(trial, 0)

0 comments on commit 970479b

Please sign in to comment.