Implement `Hyperband` pruner #785

Closed
wants to merge 6 commits into from

Conversation

@crcrpar
Collaborator

crcrpar commented Dec 11, 2019

UPDATE
The original version has been split into #805 and this PR.


This PR is based on #301.

Intuitively, Hyperband (HB) removes SuccessiveHalving's (SH) dependency on its parameters by internally running multiple SH instances with different configurations.

Design

A Study with HyperbandPruner runs in the same way as with other pruners. The HyperbandPruner class maintains a number of SuccessiveHalvingPruners (= brackets) and selects one pruner for each Trial, so the algorithm differs from the paper to some extent. There are two challenges: 1) different trials of the same Study have to be pruned by different brackets, and 2) when sampling for a new trial, the Sampler can only use the trials of the same bracket.

Major Changes

  • Study collects the list of trials (= friend_trials in the code) and sets it as an attribute of the Trial
    • To filter trials by metadata, Study stores the pruner information on each Trial as a user_attr and uses it as the filter key.
  • Trial passes its friend_trials to study.sampler
  • The Sampler's sampling methods accept the list of trials as an argument
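
The bookkeeping in the list above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in this PR: `ToyTrial`, `label_trial`, and `select_friend_trials` are hypothetical stand-ins for `FrozenTrial` and the logic inside `Study`, the `_bracket{n}` suffix format is invented for the example, and excluding the current trial from its own friends is also an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyTrial:
    """Hypothetical stand-in for FrozenTrial; only the fields used here."""
    number: int
    user_attrs: Dict[str, str] = field(default_factory=dict)


def label_trial(trial: ToyTrial, pruner_name: str, auxiliary: str) -> None:
    # Same idea as the diff: pruner class name plus optional auxiliary data.
    trial.user_attrs["pruner_metadata"] = pruner_name + auxiliary


def select_friend_trials(trial: ToyTrial, all_trials: List[ToyTrial]) -> List[ToyTrial]:
    # Keep only the other trials whose metadata matches the current trial's.
    key = trial.user_attrs.get("pruner_metadata")
    return [
        t for t in all_trials
        if t.number != trial.number and t.user_attrs.get("pruner_metadata") == key
    ]


trials = [ToyTrial(number=i) for i in range(6)]
for t in trials:
    # Assumed layout: two brackets, alternating by trial number.
    label_trial(t, "HyperbandPruner", "_bracket{}".format(t.number % 2))

friends = select_friend_trials(trials[4], trials)
friend_numbers = [t.number for t in friends]
print(friend_numbers)
```

In this toy setup, a sampler that receives `friends` instead of the full trial list only sees the history of the bracket that trial 4 belongs to.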

An alternative design, a new class that manages multiple Studys, is implemented in https://github.com/crcrpar/optuna/tree/dev/study-manager.

@c-bata c-bata mentioned this pull request Dec 11, 2019
2 of 4 tasks complete
@crcrpar crcrpar force-pushed the crcrpar:dev/hyperband branch from 1bbf155 to d97f324 Dec 11, 2019
@codecov-io


codecov-io commented Dec 11, 2019

Codecov Report

Merging #785 into master will decrease coverage by 0.12%.
The diff coverage is 85.54%.


@@            Coverage Diff             @@
##           master     #785      +/-   ##
==========================================
- Coverage   90.15%   90.02%   -0.13%     
==========================================
  Files         106      108       +2     
  Lines        8769     8906     +137     
==========================================
+ Hits         7906     8018     +112     
- Misses        863      888      +25
Impacted Files Coverage Δ
optuna/pruners/__init__.py 100% <100%> (ø) ⬆️
optuna/samplers/tpe/sampler.py 87.54% <100%> (ø) ⬆️
optuna/pruners/percentile.py 95.71% <100%> (+0.25%) ⬆️
tests/test_study.py 97.9% <100%> (ø) ⬆️
optuna/study.py 93.82% <100%> (+0.3%) ⬆️
optuna/integration/cma.py 94.03% <100%> (+0.08%) ⬆️
optuna/integration/skopt.py 88.42% <100%> (ø) ⬆️
optuna/pruners/successive_halving.py 95.23% <100%> (+0.41%) ⬆️
optuna/testing/integration.py 100% <100%> (ø) ⬆️
optuna/integration/lightgbm_tuner/optimize.py 76.03% <100%> (ø) ⬆️
... and 12 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 0389962...4195bea. Read the comment docs.


return ''

def should_filter_trials(self):

@crcrpar

crcrpar Dec 11, 2019

Author Collaborator

This method is called during the Study's construction and tells the study whether its sampler can use all the trials or not.

@@ -166,6 +166,7 @@ def __init__(

self.sampler = sampler or samplers.TPESampler()
self.pruner = pruner or pruners.MedianPruner()
self._should_filter_trials = self.pruner.should_filter_trials()

trial_pruner_metadata = self.pruner.__class__.__name__
pruner_auxiliary_data = self.pruner.get_trial_pruner_auxiliary_data(
self._study_id, trial.number)
if pruner_auxiliary_data:
trial_pruner_metadata += pruner_auxiliary_data
trial.set_user_attr('pruner_metadata', trial_pruner_metadata)
Comment on lines 529 to 572

@crcrpar

crcrpar Dec 11, 2019

Author Collaborator

Every Trial records which pruner is used as a user_attr, so that trials can be filtered by pruner information.

self._friend_trials = None # type: Optional[List[FrozenTrial]]

def _set_friend_trials(self, friend_trials):
# type: (List[FrozenTrial]) -> None

self._friend_trials = friend_trials

def _clear_friend_trials(self):
# type: () -> None

self._friend_trials = None

@property
def friend_trials(self):
# type: () -> Optional[List[FrozenTrial]]

return self._friend_trials
Comment on lines 141 to 167

@crcrpar

crcrpar Dec 11, 2019

Author Collaborator

I don't think the naming of friend_trials is cool 😅

@crcrpar

Collaborator Author

crcrpar commented Dec 11, 2019

I have finished the general implementation and would like feedback, so I am marking this as ready for review.

@crcrpar crcrpar marked this pull request as ready for review Dec 11, 2019
Collaborator Author

crcrpar left a comment

Do I have to avoid the deprecated study._study_id and use study.study_name instead as done in these comments?

@@ -30,3 +31,26 @@ def prune(self, study, trial):
"""

raise NotImplementedError

@abc.abstractmethod
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator

As study_id is deprecated, should this be study_name as follows?

Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):

@abc.abstractmethod
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator

Reflecting the above change.

Suggested change
# type: (int, int) -> str
# type: (str, int) -> str
budget += n / 2
return budget

def get_bracket_id(self, study_id, trial_number):

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator

As mentioned above, study_id is deprecated.

Suggested change
def get_bracket_id(self, study_id, trial_number):
def get_bracket_id(self, study_name, trial_number):
return budget

def get_bracket_id(self, study_id, trial_number):
# type: (int, int) -> int

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator

ditto

Suggested change
# type: (int, int) -> int
# type: (str, int) -> int
# type: (int, int) -> int
"""Computes the id of bracket for a trial of `trial_number`."""

n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator

ditto

Suggested change
n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
n = hash('{}_{}'.format(study_name, trial_number)) % self._resource_badget
@@ -151,3 +151,13 @@ def prune(self, study, trial):
if direction == structs.StudyDirection.MAXIMIZE:
return best_intermediate_result < p
return best_intermediate_result > p

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator
Suggested change
# type: (int, int) -> str
# type: (str, int) -> str
direction = study.direction
if not self._is_promotable(rung, value, all_trials, direction):
return True

rung += 1

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator
Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):
direction = study.direction
if not self._is_promotable(rung, value, all_trials, direction):
return True

rung += 1

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator
Suggested change
# type: (int, int) -> str
# type: (str, int) -> str
@@ -12,6 +12,16 @@ def prune(self, study, trial):

return self.is_pruning

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator
Suggested change
def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
def get_trial_pruner_auxiliary_data(self, study_name, trial_number):
@@ -12,6 +12,16 @@ def prune(self, study, trial):

return self.is_pruning

def get_trial_pruner_auxiliary_data(self, study_id, trial_number):
# type: (int, int) -> str

@crcrpar

crcrpar Dec 12, 2019

Author Collaborator
Suggested change
# type: (int, int) -> str
# type: (str, int) -> str
Collaborator

c-bata left a comment

Good job @crcrpar!
My code review is still in progress (actually, I still don't understand why this PR changes the sampler interface, or what friend_trials means). For now, I have left some minor comments.

search_space, # type: Dict[str, BaseDistribution]
trials=None # type: Optional[List[FrozenTrial]]
):
# type: (...) -> Dict[str, float]

if len(study.trials) > 1:
raise RuntimeError("`FirstTrialOnlyRandomSampler` only works on the first trial.")

return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space)

@c-bata

c-bata Dec 12, 2019

Collaborator

It looks like the trials argument should be propagated.

Suggested change
return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space)
return super(FirstTrialOnlyRandomSampler, self).sample_relative(study, trial, search_space, trials=trials)

The same applies to the sample_independent method (L80).
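
To illustrate why the missing keyword matters, here is a small self-contained sketch. The classes are hypothetical toys, not Optuna's samplers: `BaseToySampler` simply returns the history it was allowed to see, with a hard-coded full history standing in for `study.trials`.

```python
from typing import List, Optional


class BaseToySampler:
    """Hypothetical sampler that reports which trial numbers it could use."""

    def sample_relative(self, search_space: dict,
                        trials: Optional[List[int]] = None) -> List[int]:
        # With trials=None the sampler falls back to the full (toy) history.
        full_history = [0, 1, 2, 3]
        return full_history if trials is None else trials


class DroppingSampler(BaseToySampler):
    def sample_relative(self, search_space, trials=None):
        # `trials` is accepted but not forwarded, so the filter is silently lost.
        return super().sample_relative(search_space)


class ForwardingSampler(BaseToySampler):
    def sample_relative(self, search_space, trials=None):
        # Forwarding the argument keeps the caller's filtered list intact.
        return super().sample_relative(search_space, trials=trials)


dropped = DroppingSampler().sample_relative({}, trials=[1, 3])
forwarded = ForwardingSampler().sample_relative({}, trials=[1, 3])
print(dropped, forwarded)
```

The subclass that drops the argument raises no error; it just samples from every trial, which is why this kind of omission is easy to miss in review.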

@crcrpar

crcrpar Dec 18, 2019

Author Collaborator

Good catch, thank you!

optuna/pruners/hyperband.py (outdated)
`Hyperband paper <http://www.jmlr.org/papers/volume18/16-558/16-558.pdf>`_.
"""

n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget

@c-bata

c-bata Dec 12, 2019

Collaborator

I'm not confident, but it looks like trial_number % self._resource_budget is enough, right?

Suggested change
n = hash('{}_{}'.format(study_id, trial_number)) % self._resource_badget
n = trial_number % self._resource_budget
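
One consideration that may be relevant to this suggestion: Python's built-in `hash` on strings is salted per process unless `PYTHONHASHSEED` is set, so a value derived from `hash('{}_{}'.format(...))` is not guaranteed to be the same in different worker processes. The sketch below shows two deterministic alternatives. The function names are hypothetical, and `zlib.crc32` is shown only as one possible stable checksum, not as what this PR uses.

```python
import zlib


def slot_by_modulo(trial_number: int, resource_budget: int) -> int:
    # The reviewer's suggestion: depends only on the trial number.
    return trial_number % resource_budget


def slot_by_crc32(study_name: str, trial_number: int, resource_budget: int) -> int:
    # Hypothetical alternative: a stable checksum of the same key string,
    # which gives the same value in every process.
    key = "{}_{}".format(study_name, trial_number).encode("utf-8")
    return zlib.crc32(key) % resource_budget


modulo_values = [slot_by_modulo(n, 4) for n in range(6)]
crc_values = [slot_by_crc32("toy-study", n, 4) for n in range(6)]
print(modulo_values)
```

The modulo version cycles through the slots in order, while the checksum version also mixes in the study name; either way, every process computes the same slot for a given trial.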
@crcrpar

Collaborator Author

crcrpar commented Dec 13, 2019

@c-bata

Thank you for your review!
It is no problem that the review is not complete yet.
I really appreciate your response. 😀

why this PR includes the change of sampler interface, and what friend_trials means

I think your question is equivalent to asking why there is a new argument, trials: List[FrozenTrial].
A sampler, especially the TPE sampler, must only reflect the trials that share the same SuccessiveHalvingPruner of HyperbandPruner as the trial that is currently being initialized.
To realize this, I added a pruner_metadata attribute to Trial via set_user_attr inside Study. I also made Study collect the appropriate trials and set them tentatively as friend_trials on the trial, to make trial selection easier.

I hope this helps you.

@c-bata

Collaborator

c-bata commented Dec 13, 2019

Thank you! I think I understand now.

  1. With SuccessiveHalvingPruner, there is no problem using the last intermediate score of pruned trials.
  2. But that does not seem to hold for HyperbandPruner.
  3. So you labeled trials with pruner_metadata (a string representation of bracket_id in HyperbandPruner) via the get_trial_pruner_auxiliary_data() method.
  4. friend_trials returns the trials that have the same pruner_metadata (bracket_id).
  5. So you pass friend_trials to the samplers.
@c-bata c-bata mentioned this pull request Dec 13, 2019
@crcrpar crcrpar mentioned this pull request Dec 18, 2019
crcrpar added 2 commits Dec 18, 2019
update other samplers
update tests
@crcrpar

Collaborator Author

crcrpar commented Dec 18, 2019

This PR consists of two major changes:

  • a new trials argument for the samplers' sample methods
  • Hyperband

Both changes are not trivial I think, thus I'd like to separate this into two PRs.

@sile

Member

sile commented Dec 18, 2019

Both changes are not trivial I think, thus I'd like to separate this into two PRs.

Nice idea!

@crcrpar crcrpar force-pushed the crcrpar:dev/hyperband branch from d3f749a to 4195bea Dec 18, 2019
@crcrpar crcrpar mentioned this pull request Dec 18, 2019
@sile

Member

sile commented Dec 27, 2019

I think that this PR was taken over by #809. Could we close this? > @crcrpar

@crcrpar

Collaborator Author

crcrpar commented Dec 27, 2019

I think that this PR was taken over by #809. Could we close this? > @crcrpar

Thank you for the reminder. Of course.

@crcrpar crcrpar closed this Dec 27, 2019
@crcrpar crcrpar deleted the crcrpar:dev/hyperband branch Jan 22, 2020