Skip to content
This repository has been archived by the owner on Jul 21, 2021. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar committed Jan 15, 2020
1 parent 68da9a3 commit af0b4c4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
7 changes: 5 additions & 2 deletions optuna/pruners/hyperband.py
Expand Up @@ -122,7 +122,10 @@ def _create_bracket_study(self, study, bracket_index):
# But for safety, prohibit the other attributes explicitly.
class _BracketStudy(Study):

_VALID_ATTRS = ('get_trials', 'direction', '_storage')
_VALID_ATTRS = (
'get_trials', 'direction', '_storage', '_study_id',
'pruner', 'study_name', '_bracket_id'
)

def __init__(self, study, bracket_id):
# type: (Study, int) -> None
Expand Down Expand Up @@ -150,6 +153,6 @@ def __getattribute__(self, attr_name): # type: ignore
if attr_name not in _BracketStudy._VALID_ATTRS:
raise NotImplementedError
else:
return getattr(self, attr_name)
return object.__getattribute__(self, attr_name)

return _BracketStudy(study, bracket_index)
24 changes: 23 additions & 1 deletion tests/pruners_tests/test_hyperband.py
Expand Up @@ -51,5 +51,27 @@ def test_bracket_study():
study = optuna.study.create_study(pruner=pruner)
bracket_study = pruner._create_bracket_study(study, 0)

with pytest.raises(Exception):
with pytest.raises(NotImplementedError):
bracket_study.optimize(lambda *args: 1.0)

for attr in ('set_user_attr', 'set_system_attr'):
with pytest.raises(NotImplementedError):
getattr(bracket_study, attr)('abc', 100)

for attr in ('user_attrs', 'system_attrs'):
with pytest.raises(NotImplementedError):
getattr(bracket_study, attr)

with pytest.raises(Exception):
bracket_study.trials_dataframe()

bracket_study.get_trials()
bracket_study.direction
bracket_study._storage
bracket_study._study_id
bracket_study.pruner
bracket_study.study_name
# As `_BracketStudy` is defined inside `HyperbandPruner`,
# we cannot do `assert isinstance(bracket_study, _BracketStudy)`.
# This is why the below line is ignored by mypy checks.
bracket_study._bracket_id # type: ignore

0 comments on commit af0b4c4

Please sign in to comment.