Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update `TPESampler` to support `HyperbandPruner` #828

Merged
merged 9 commits into from Jan 24, 2020
@@ -24,6 +24,21 @@ class HyperbandPruner(BasePruner):
Note that this implementation does not take as input the maximum amount of resource allocated to
a single SHA, noted as :math:`R` in the paper.
.. note::
If you use ``HyperbandPruner`` with :class:`~optuna.samplers.TPESampler`,
it's recommended to consider setting a larger ``n_trials`` or ``timeout`` to make full use of
the characteristics of :class:`~optuna.samplers.TPESampler`
because :class:`~optuna.samplers.TPESampler` uses some (by default, :math:`10`)
:class:`~optuna.trial.Trial`\\ s for its startup.
As Hyperband runs multiple :class:`~optuna.pruners.SuccessiveHalvingPruner`\\ s and collects
trials based on the current :class:`~optuna.trial.Trial`\\ 's bracket ID, each bracket
needs to observe more than :math:`10` :class:`~optuna.trial.Trial`\\ s
for :class:`~optuna.samplers.TPESampler` to adapt its search space.
Thus, for example, if ``HyperbandPruner`` has :math:`4` pruners in it,
at least :math:`4 \\times 10` trials are consumed for startup.
Args:
min_resource:
A parameter for specifying the minimum resource allocated to a trial noted as :math:`r`
@@ -34,7 +49,7 @@ class HyperbandPruner(BasePruner):
:math:`\\eta` in the paper. See the details for
:class:`~optuna.pruners.SuccessiveHalvingPruner`.
n_brackets:
The number of :class:`~optuna.pruners.SuccessiveHalvingPruner`\\s (brackets).
The number of :class:`~optuna.pruners.SuccessiveHalvingPruner`\\ s (brackets).
min_early_stopping_rate_low:
A parameter for specifying the minimum early-stopping rate.
This parameter is related to a parameter that is referred to as :math:`r` and used in
@@ -4,6 +4,7 @@
import scipy.special

from optuna import distributions
from optuna.pruners import HyperbandPruner
from optuna.samplers import base
from optuna.samplers import random
from optuna.samplers.tpe.parzen_estimator import _ParzenEstimator
@@ -124,7 +125,7 @@ def sample_relative(self, study, trial, search_space):
def sample_independent(self, study, trial, param_name, param_distribution):
# type: (Study, FrozenTrial, str, BaseDistribution) -> Any

values, scores = _get_observation_pairs(study, param_name)
values, scores = _get_observation_pairs(study, param_name, trial)

n = len(values)

@@ -510,8 +511,8 @@ def objective(trial):
}


def _get_observation_pairs(study, param_name):
# type: (Study, str) -> Tuple[List[float], List[Tuple[float, float]]]
def _get_observation_pairs(study, param_name, trial):
# type: (Study, str, FrozenTrial) -> Tuple[List[float], List[Tuple[float, float]]]
"""Get observation pairs from the study.
This function collects observation pairs from the complete or pruned trials of the study.
@@ -534,6 +535,11 @@ def _get_observation_pairs(study, param_name):
if study.direction == StudyDirection.MAXIMIZE:
sign = -1

if isinstance(study.pruner, HyperbandPruner):
# Create `_BracketStudy` to use trials that have the same bracket id.
pruner = study.pruner # type: HyperbandPruner
study = pruner._create_bracket_study(study, pruner._get_bracket_id(study, trial))

values = []
scores = []
for trial in study.get_trials(deepcopy=False):
@@ -191,16 +191,6 @@ def __init__(
self.sampler = sampler or samplers.TPESampler()
self.pruner = pruner or pruners.MedianPruner()

if (
isinstance(self.sampler, samplers.TPESampler) and
isinstance(self.pruner, pruners.HyperbandPruner)):
msg = (
"The algorithm of TPESampler and HyperbandPruner might behave in a different way "
"from the paper of Hyperband because the sampler uses all the trials including "
"ones of brackets other than that of currently running trial")
warnings.warn(msg, UserWarning)
_logger.warning(msg)

self._optimize_lock = threading.Lock()

def __getstate__(self):
@@ -15,13 +15,6 @@
EXPECTED_N_TRIALS_PER_BRACKET = 10


def test_warn_on_TPESampler():
# type: () -> None

with pytest.warns(UserWarning):
optuna.study.create_study(pruner=optuna.pruners.HyperbandPruner())


def test_hyperband_pruner_intermediate_values():
# type: () -> None

@@ -1,3 +1,5 @@
import pytest

import optuna
from optuna.exceptions import TrialPruned
from optuna.samplers import tpe
@@ -7,11 +9,13 @@
from optuna.trial import Trial # NOQA


def test_hyperopt_parameters():
# type: () -> None
@pytest.mark.parametrize('use_hyperband', [False, True])
def test_hyperopt_parameters(use_hyperband):
# type: (bool) -> None

sampler = TPESampler(**TPESampler.hyperopt_parameters())
study = optuna.create_study(sampler=sampler)
study = optuna.create_study(
sampler=sampler, pruner=optuna.pruners.HyperbandPruner() if use_hyperband else None)
study.optimize(lambda t: t.suggest_uniform('x', 10, 20), n_trials=50)


@@ -39,29 +43,30 @@ def objective(trial):
# direction=minimize.
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=5, catch=(RuntimeError,))
study._storage.create_new_trial(study._study_id) # Create a running trial.
trial_number = study._storage.create_new_trial(study._study_id) # Create a running trial.
trial = study._storage.get_trial(trial_number)

assert tpe.sampler._get_observation_pairs(study, 'x') == (
assert tpe.sampler._get_observation_pairs(study, 'x', trial) == (
[5.0, 5.0, 5.0, 5.0],
[
(-float('inf'), 5.0), # COMPLETE
(-7, 2), # PRUNED (with intermediate values)
(-3, float('inf')), # PRUNED (with a NaN intermediate value; it's treated as infinity)
(float('inf'), 0.0) # PRUNED (without intermediate values)
])
assert tpe.sampler._get_observation_pairs(study, 'y') == ([], [])
assert tpe.sampler._get_observation_pairs(study, 'y', trial) == ([], [])

# direction=maximize.
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=4)
study._storage.create_new_trial(study._study_id) # Create a running trial.

assert tpe.sampler._get_observation_pairs(study, 'x') == (
assert tpe.sampler._get_observation_pairs(study, 'x', trial) == (
[5.0, 5.0, 5.0, 5.0],
[
(-float('inf'), -5.0), # COMPLETE
(-7, -2), # PRUNED (with intermediate values)
(-3, float('inf')), # PRUNED (with a NaN intermediate value; it's treated as infinity)
(float('inf'), 0.0) # PRUNED (without intermediate values)
])
assert tpe.sampler._get_observation_pairs(study, 'y') == ([], [])
assert tpe.sampler._get_observation_pairs(study, 'y', trial) == ([], [])
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.