Skip to content

Commit

Permalink
Merge pull request #809 from crcrpar/naive-hyperband
Browse files Browse the repository at this point in the history
Add `HyperbandPruner`
  • Loading branch information
sile committed Jan 15, 2020
2 parents 4089a71 + 7bc4874 commit 0db5508
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/reference/pruners.rst
Expand Up @@ -21,3 +21,7 @@ Pruners
.. autoclass:: SuccessiveHalvingPruner
:members:
:exclude-members: prune

.. autoclass:: HyperbandPruner
:members:
:exclude-members: prune
1 change: 1 addition & 0 deletions optuna/pruners/__init__.py
@@ -1,4 +1,5 @@
from optuna.pruners.base import BasePruner # NOQA
from optuna.pruners.hyperband import HyperbandPruner # NOQA
from optuna.pruners.median import MedianPruner # NOQA
from optuna.pruners.nop import NopPruner # NOQA
from optuna.pruners.percentile import PercentilePruner # NOQA
Expand Down
155 changes: 155 additions & 0 deletions optuna/pruners/hyperband.py
@@ -0,0 +1,155 @@
import optuna
from optuna import logging
from optuna.pruners.base import BasePruner
from optuna.pruners.successive_halving import SuccessiveHalvingPruner
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from typing import List # NOQA

from optuna import structs # NOQA

_logger = logging.get_logger(__name__)


class HyperbandPruner(BasePruner):
"""Pruner using Hyperband.
As SuccessiveHalving (SHA) requires the number of configurations
:math:`n` as its hyperparameter. For a given finite budget :math:`B`,
all the configurations have the resources of :math:`B \\over n` on average.
As you can see, there will be a trade-off of :math:`B` and :math:`B \\over n`.
`Hyperband <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_ attacks this trade-off
by trying different :math:`n` values for a fixed budget.
Note that this implementation does not take as inputs the maximum amount of resource to
a single SHA noted as :math:`R` in the paper.
Args:
min_resource:
A parameter for specifying the minimum resource allocated to a trial noted as :math:`r`
in the paper.
See the details for :class:`~optuna.pruners.SuccessiveHalvingPruner`.
reduction_factor:
A parameter for specifying reduction factor of promotable trials noted as
:math:`\\eta` in the paper. See the details for
:class:`~optuna.pruners.SuccessiveHalvingPruner`.
n_brackets:
The number of :class:`~optuna.pruners.SuccessiveHalvingPruner`\\s (brackets).
min_early_stopping_rate_low:
A parameter for specifying the minimum early-stopping rate.
This parameter is related to a parameter that is referred to as :math:`r` and used in
`Asynchronous SuccessiveHalving paper <http://arxiv.org/abs/1810.05934>`_.
The minimum early stopping rate for :math:`i` th bracket is :math:`i + s`.
"""

def __init__(
self,
min_resource=1,
reduction_factor=3,
n_brackets=4,
min_early_stopping_rate_low=0
):
# type: (int, int, int, int) -> None

self._pruners = [] # type: List[SuccessiveHalvingPruner]
self._reduction_factor = reduction_factor
self._resource_budget = 0
self._n_brackets = n_brackets
self._bracket_resource_budgets = [] # type: List[int]

_logger.debug('Hyperband has {} brackets'.format(self._n_brackets))

for i in range(n_brackets):
bracket_resource_budget = self._calc_bracket_resource_budget(i, n_brackets)
self._resource_budget += bracket_resource_budget
self._bracket_resource_budgets.append(bracket_resource_budget)

# N.B. (crcrpar): `min_early_stopping_rate` has the information of `bracket_index`.
min_early_stopping_rate = min_early_stopping_rate_low + i

_logger.debug(
'{}th bracket has minimum early stopping rate of {}'.format(
i, min_early_stopping_rate))

pruner = SuccessiveHalvingPruner(
min_resource=min_resource,
reduction_factor=reduction_factor,
min_early_stopping_rate=min_early_stopping_rate,
)
self._pruners.append(pruner)

def prune(self, study, trial):
# type: (optuna.study.Study, structs.FrozenTrial) -> bool

i = self._get_bracket_id(study, trial)
_logger.debug('{}th bracket is selected'.format(i))
bracket_study = self._create_bracket_study(study, i)
return self._pruners[i].prune(bracket_study, trial)

# TODO(crcrpar): Improve resource computation/allocation algorithm.
def _calc_bracket_resource_budget(self, pruner_index, n_brackets):
# type: (int, int) -> int

n = self._reduction_factor ** (n_brackets - 1)
return n + (n / 2) * (n_brackets - 1 - pruner_index)

def _get_bracket_id(self, study, trial):
# type: (optuna.study.Study, structs.FrozenTrial) -> int
"""Computes the index of bracket for a trial of ``trial_number``.
The index of a bracket is noted as :math:`s` in
`Hyperband paper <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_.
"""

n = hash('{}_{}'.format(study.study_name, trial.number)) % self._resource_budget
for i in range(self._n_brackets):
n -= self._bracket_resource_budgets[i]
if n < 0:
return i

assert False, 'This line should be unreachable.'

def _create_bracket_study(self, study, bracket_index):
# type: (optuna.study.Study, int) -> optuna.study.Study

# This class is assumed to be passed to
# `SuccessiveHalvingPruner.prune` in which `get_trials`,
# `direction`, and `storage` are used.
# But for safety, prohibit the other attributes explicitly.
class _BracketStudy(optuna.study.Study):

_VALID_ATTRS = (
'get_trials', 'direction', '_storage', '_study_id',
'pruner', 'study_name', '_bracket_id', 'sampler'
)

def __init__(self, study, bracket_id):
# type: (optuna.study.Study, int) -> None

super().__init__(
study_name=study.study_name,
storage=study._storage,
sampler=study.sampler,
pruner=study.pruner
)
self._bracket_id = bracket_id

def get_trials(self, deepcopy=True):
# type: (bool) -> List[structs.FrozenTrial]

trials = super().get_trials(deepcopy=deepcopy)
pruner = self.pruner
assert isinstance(pruner, HyperbandPruner)
return [
t for t in trials
if pruner._get_bracket_id(self, t) == self._bracket_id
]

def __getattribute__(self, attr_name): # type: ignore
if attr_name not in _BracketStudy._VALID_ATTRS:
raise AttributeError(
"_BracketStudy does not have attribute of '{}'".format(attr_name))
else:
return object.__getattribute__(self, attr_name)

return _BracketStudy(study, bracket_index)
10 changes: 10 additions & 0 deletions optuna/study.py
Expand Up @@ -191,6 +191,16 @@ def __init__(
self.sampler = sampler or samplers.TPESampler()
self.pruner = pruner or pruners.MedianPruner()

if (
isinstance(self.sampler, samplers.TPESampler) and
isinstance(self.pruner, pruners.HyperbandPruner)):
msg = (
"The algorithm of TPESampler and HyperbandPruner might behave in a different way "
"from the paper of Hyperband because the sampler uses all the trials including "
"ones of brackets other than that of currently running trial")
warnings.warn(msg, UserWarning)
_logger.warning(msg)

self._optimize_lock = threading.Lock()

def __getstate__(self):
Expand Down
84 changes: 84 additions & 0 deletions tests/pruners_tests/test_hyperband.py
@@ -0,0 +1,84 @@
import pytest

import optuna
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from optuna.trial import Trial # NOQA

MIN_RESOURCE = 1
REDUCTION_FACTOR = 2
N_BRACKETS = 4
EARLY_STOPPING_RATE_LOW = 0
EARLY_STOPPING_RATE_HIGH = 3
N_REPORTS = 10
EXPECTED_N_TRIALS_PER_BRACKET = 10


def test_warn_on_TPESampler():
# type: () -> None

with pytest.warns(UserWarning):
optuna.study.create_study(pruner=optuna.pruners.HyperbandPruner())


def test_hyperband_pruner_intermediate_values():
# type: () -> None

pruner = optuna.pruners.HyperbandPruner(
min_resource=MIN_RESOURCE,
reduction_factor=REDUCTION_FACTOR,
n_brackets=N_BRACKETS
)

study = optuna.study.create_study(sampler=optuna.samplers.RandomSampler(), pruner=pruner)

def objective(trial):
# type: (Trial) -> float

for i in range(N_REPORTS):
trial.report(i)

return 1.0

study.optimize(objective, n_trials=N_BRACKETS * EXPECTED_N_TRIALS_PER_BRACKET)

trials = study.trials
assert len(trials) == N_BRACKETS * EXPECTED_N_TRIALS_PER_BRACKET


def test_bracket_study():
# type: () -> None

pruner = optuna.pruners.HyperbandPruner(
min_resource=MIN_RESOURCE,
reduction_factor=REDUCTION_FACTOR,
n_brackets=N_BRACKETS
)
study = optuna.study.create_study(sampler=optuna.samplers.RandomSampler(), pruner=pruner)
bracket_study = pruner._create_bracket_study(study, 0)

with pytest.raises(AttributeError):
bracket_study.optimize(lambda *args: 1.0)

for attr in ('set_user_attr', 'set_system_attr'):
with pytest.raises(AttributeError):
getattr(bracket_study, attr)('abc', 100)

for attr in ('user_attrs', 'system_attrs'):
with pytest.raises(AttributeError):
getattr(bracket_study, attr)

with pytest.raises(AttributeError):
bracket_study.trials_dataframe()

bracket_study.get_trials()
bracket_study.direction
bracket_study._storage
bracket_study._study_id
bracket_study.pruner
bracket_study.study_name
# As `_BracketStudy` is defined inside `HyperbandPruner`,
# we cannot do `assert isinstance(bracket_study, _BracketStudy)`.
# This is why the below line is ignored by mypy checks.
bracket_study._bracket_id # type: ignore

0 comments on commit 0db5508

Please sign in to comment.