Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HyperbandPruner #809

Merged
merged 25 commits into from Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/reference/pruners.rst
Expand Up @@ -21,3 +21,7 @@ Pruners
.. autoclass:: SuccessiveHalvingPruner
    :members:
    :exclude-members: prune

.. autoclass:: HyperbandPruner
    :members:
    :exclude-members: prune
1 change: 1 addition & 0 deletions optuna/pruners/__init__.py
@@ -1,4 +1,5 @@
from optuna.pruners.base import BasePruner # NOQA
from optuna.pruners.hyperband import HyperbandPruner # NOQA
from optuna.pruners.median import MedianPruner # NOQA
from optuna.pruners.nop import NopPruner # NOQA
from optuna.pruners.percentile import PercentilePruner # NOQA
Expand Down
155 changes: 155 additions & 0 deletions optuna/pruners/hyperband.py
@@ -0,0 +1,155 @@
import binascii

import optuna
from optuna import logging
from optuna.pruners.base import BasePruner
from optuna.pruners.successive_halving import SuccessiveHalvingPruner
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from typing import List # NOQA

from optuna import structs # NOQA

# Module-level logger; the pruner below emits debug messages through it.
_logger = logging.get_logger(__name__)


class HyperbandPruner(BasePruner):
    """Pruner using Hyperband.

    SuccessiveHalving (SHA) requires the number of configurations
    :math:`n` as its hyperparameter. For a given finite budget :math:`B`,
    all the configurations have the resources of :math:`B \\over n` on average.
    As you can see, there is a trade-off between :math:`n` and :math:`B \\over n`.
    `Hyperband <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_ attacks this trade-off
    by trying different :math:`n` values for a fixed budget.
    Note that this implementation does not take as inputs the maximum amount of resource to
    a single SHA noted as :math:`R` in the paper.

    Internally, one :class:`~optuna.pruners.SuccessiveHalvingPruner` is created per bracket.
    Every trial is assigned to exactly one bracket from a checksum of the study name and the
    trial number, and the pruning decision is delegated to that bracket's pruner, which only
    sees the trials belonging to the same bracket.

    Args:
        min_resource:
            A parameter for specifying the minimum resource allocated to a trial noted as :math:`r`
            in the paper.
            See the details for :class:`~optuna.pruners.SuccessiveHalvingPruner`.
        reduction_factor:
            A parameter for specifying reduction factor of promotable trials noted as
            :math:`\\eta` in the paper. See the details for
            :class:`~optuna.pruners.SuccessiveHalvingPruner`.
        n_brackets:
            The number of :class:`~optuna.pruners.SuccessiveHalvingPruner`\\s (brackets).
        min_early_stopping_rate_low:
            A parameter for specifying the minimum early-stopping rate.
            This parameter is related to a parameter that is referred to as :math:`r` and used in
            `Asynchronous SuccessiveHalving paper <http://arxiv.org/abs/1810.05934>`_.
            The minimum early stopping rate for the :math:`i` th bracket is
            ``min_early_stopping_rate_low`` + :math:`i`.
    """

    def __init__(
            self,
            min_resource=1,
            reduction_factor=3,
            n_brackets=4,
            min_early_stopping_rate_low=0
    ):
        # type: (int, int, int, int) -> None

        self._pruners = []  # type: List[SuccessiveHalvingPruner]
        self._reduction_factor = reduction_factor
        # Sum of all bracket budgets; used as the modulus when mapping a trial to a bracket.
        self._resource_budget = 0
        self._n_brackets = n_brackets
        self._bracket_resource_budgets = []  # type: List[int]

        _logger.debug('Hyperband has {} brackets'.format(self._n_brackets))

        for i in range(n_brackets):
            bracket_resource_budget = self._calc_bracket_resource_budget(i, n_brackets)
            self._resource_budget += bracket_resource_budget
            self._bracket_resource_budgets.append(bracket_resource_budget)

            # N.B. (crcrpar): `min_early_stopping_rate` has the information of `bracket_index`.
            min_early_stopping_rate = min_early_stopping_rate_low + i

            _logger.debug(
                '{}th bracket has minimum early stopping rate of {}'.format(
                    i, min_early_stopping_rate))

            pruner = SuccessiveHalvingPruner(
                min_resource=min_resource,
                reduction_factor=reduction_factor,
                min_early_stopping_rate=min_early_stopping_rate,
            )
            self._pruners.append(pruner)

    def prune(self, study, trial):
        # type: (optuna.study.Study, structs.FrozenTrial) -> bool
        """Delegate the pruning decision to the pruner of the bracket that owns ``trial``.

        The bracket's pruner is given a restricted view of ``study`` whose ``get_trials``
        only returns the trials assigned to the same bracket.
        """

        i = self._get_bracket_id(study, trial)
        _logger.debug('{}th bracket is selected'.format(i))
        bracket_study = self._create_bracket_study(study, i)
        return self._pruners[i].prune(bracket_study, trial)

    # TODO(crcrpar): Improve resource computation/allocation algorithm.
    def _calc_bracket_resource_budget(self, pruner_index, n_brackets):
        # type: (int, int) -> int
        """Return the relative budget (selection weight) of the ``pruner_index`` th bracket.

        The weight shrinks linearly with the bracket index, so lower-index brackets
        receive a larger share of the trials.
        """

        n = self._reduction_factor ** (n_brackets - 1)
        # NOTE: `/` is true division, so the returned value is a float even though the
        # type comment says `int`. It is kept as is so the bracket weights do not change.
        return n + (n / 2) * (n_brackets - 1 - pruner_index)

    def _get_bracket_id(self, study, trial):
        # type: (optuna.study.Study, structs.FrozenTrial) -> int
        """Computes the index of bracket for a trial of ``trial_number``.

        The index of a bracket is noted as :math:`s` in
        `Hyperband paper <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_.

        The mapping depends only on the study name and the trial number, so every
        process working on the same study assigns a given trial to the same bracket.
        """

        # The built-in `hash()` of a `str` is salted per interpreter process, so it would
        # map the same trial to different brackets in different workers or after a restart.
        # CRC32 is a stable checksum and always returns a non-negative integer in Python 3.
        key = '{}_{}'.format(study.study_name, trial.number)
        n = binascii.crc32(key.encode()) % self._resource_budget
        # Walk the cumulative budgets: the first bracket whose range contains `n` wins.
        for i in range(self._n_brackets):
            n -= self._bracket_resource_budgets[i]
            if n < 0:
                return i

        assert False, 'This line should be unreachable.'

    def _create_bracket_study(self, study, bracket_index):
        # type: (optuna.study.Study, int) -> optuna.study.Study
        """Build a restricted view of ``study`` exposing only the trials of one bracket."""

        # This class is assumed to be passed to
        # `SuccessiveHalvingPruner.prune` in which `get_trials`,
        # `direction`, and `storage` are used.
        # But for safety, prohibit the other attributes explicitly.
        class _BracketStudy(optuna.study.Study):

            # Names that may be read from an instance; any other lookup raises AttributeError.
            _VALID_ATTRS = (
                'get_trials', 'direction', '_storage', '_study_id',
                'pruner', 'study_name', '_bracket_id', 'sampler'
            )

            def __init__(self, study, bracket_id):
                # type: (optuna.study.Study, int) -> None

                super().__init__(
                    study_name=study.study_name,
                    storage=study._storage,
                    sampler=study.sampler,
                    pruner=study.pruner
                )
                self._bracket_id = bracket_id

            def get_trials(self, deepcopy=True):
                # type: (bool) -> List[structs.FrozenTrial]
                """Return only the trials that are assigned to this bracket."""

                trials = super().get_trials(deepcopy=deepcopy)
                pruner = self.pruner
                assert isinstance(pruner, HyperbandPruner)
                return [
                    t for t in trials
                    if pruner._get_bracket_id(self, t) == self._bracket_id
                ]

            def __getattribute__(self, attr_name):  # type: ignore
                # Whitelist-based access control over every instance attribute lookup.
                if attr_name not in _BracketStudy._VALID_ATTRS:
                    raise AttributeError(
                        "_BracketStudy does not have attribute of '{}'".format(attr_name))
                else:
                    return object.__getattribute__(self, attr_name)

        return _BracketStudy(study, bracket_index)
10 changes: 10 additions & 0 deletions optuna/study.py
Expand Up @@ -191,6 +191,16 @@ def __init__(
self.sampler = sampler or samplers.TPESampler()
self.pruner = pruner or pruners.MedianPruner()

if (
isinstance(self.sampler, samplers.TPESampler) and
isinstance(self.pruner, pruners.HyperbandPruner)):
msg = (
"The algorithm of TPESampler and HyperbandPruner might behave in a different way "
"from the paper of Hyperband because the sampler uses all the trials including "
"ones of brackets other than that of currently running trial")
warnings.warn(msg, UserWarning)
crcrpar marked this conversation as resolved.
Show resolved Hide resolved
_logger.warning(msg)

self._optimize_lock = threading.Lock()

def __getstate__(self):
Expand Down
84 changes: 84 additions & 0 deletions tests/pruners_tests/test_hyperband.py
@@ -0,0 +1,84 @@
import pytest

import optuna
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from optuna.trial import Trial # NOQA

# Arguments passed to `HyperbandPruner` in the tests below.
MIN_RESOURCE = 1
REDUCTION_FACTOR = 2
N_BRACKETS = 4
# NOTE(review): the two early-stopping-rate constants are not referenced by the tests
# visible here -- confirm whether they are still needed.
EARLY_STOPPING_RATE_LOW = 0
EARLY_STOPPING_RATE_HIGH = 3
# Number of intermediate values each trial reports in the objective.
N_REPORTS = 10
# Multiplied by `N_BRACKETS` to get the total number of trials to run.
EXPECTED_N_TRIALS_PER_BRACKET = 10


def test_warn_on_TPESampler():
    # type: () -> None
    """Creating a study with a HyperbandPruner and no explicit sampler must warn."""

    with pytest.warns(UserWarning):
        hyperband = optuna.pruners.HyperbandPruner()
        optuna.study.create_study(pruner=hyperband)


def test_hyperband_pruner_intermediate_values():
    # type: () -> None
    """Run a study whose trials report intermediate values and check every trial is stored.

    The objective only reports values; it never asks the trial whether it should be pruned.
    """

    hyperband = optuna.pruners.HyperbandPruner(
        min_resource=MIN_RESOURCE,
        reduction_factor=REDUCTION_FACTOR,
        n_brackets=N_BRACKETS
    )
    random_sampler = optuna.samplers.RandomSampler()
    study = optuna.study.create_study(sampler=random_sampler, pruner=hyperband)

    def objective(trial):
        # type: (Trial) -> float

        for value in range(N_REPORTS):
            trial.report(value)

        return 1.0

    study.optimize(objective, n_trials=N_BRACKETS * EXPECTED_N_TRIALS_PER_BRACKET)

    finished = study.trials
    assert len(finished) == N_BRACKETS * EXPECTED_N_TRIALS_PER_BRACKET


def test_bracket_study():
    # type: () -> None
    """Check which attributes of the bracket-restricted study are blocked and which work."""

    hyperband = optuna.pruners.HyperbandPruner(
        min_resource=MIN_RESOURCE,
        reduction_factor=REDUCTION_FACTOR,
        n_brackets=N_BRACKETS
    )
    random_sampler = optuna.samplers.RandomSampler()
    study = optuna.study.create_study(sampler=random_sampler, pruner=hyperband)
    bracket_study = hyperband._create_bracket_study(study, 0)

    # Each of these accesses must be rejected with AttributeError.
    forbidden_accesses = (
        lambda: bracket_study.optimize(lambda *args: 1.0),
        lambda: bracket_study.set_user_attr('abc', 100),
        lambda: bracket_study.set_system_attr('abc', 100),
        lambda: bracket_study.user_attrs,
        lambda: bracket_study.system_attrs,
        lambda: bracket_study.trials_dataframe(),
    )
    for access in forbidden_accesses:
        with pytest.raises(AttributeError):
            access()

    # These are the attributes the bracket study is expected to expose.
    bracket_study.get_trials()
    # As `_BracketStudy` is defined inside `HyperbandPruner`,
    # we cannot do `assert isinstance(bracket_study, _BracketStudy)`,
    # so the attributes are simply looked up by name.
    for allowed in ('direction', '_storage', '_study_id', 'pruner', 'study_name', '_bracket_id'):
        getattr(bracket_study, allowed)