Skip to content

Commit

Permalink
Merge pull request #3910 from c-bata/refactor-get-trial-id-from-study…
Browse files Browse the repository at this point in the history
…-id-trial-number

Move default logic of `get_trial_id_from_study_id_trial_number()` method to BaseStorage.
  • Loading branch information
knshnb committed Sep 13, 2022
2 parents 4b060ae + 86f7081 commit de29496
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 40 deletions.
9 changes: 8 additions & 1 deletion optuna/storages/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,14 @@ def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: i
:exc:`KeyError`:
If no trial with the matching ``study_id`` and ``trial_number`` exists.
"""
raise NotImplementedError
trials = self.get_all_trials(study_id, deepcopy=False)
if len(trials) <= trial_number:
raise KeyError(
"No trial with trial number {} exists in study with study_id {}.".format(
trial_number, study_id
)
)
return trials[trial_number]._trial_id

def get_trial_number_from_id(self, trial_id: int) -> int:
"""Read the trial number of a trial.
Expand Down
17 changes: 0 additions & 17 deletions optuna/study/_tell.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,6 @@ def _get_frozen_trial(study: "optuna.Study", trial: Union[trial_module.Trial, in
trial_id = study._storage.get_trial_id_from_study_id_trial_number(
study._study_id, trial_number
)
except NotImplementedError as e:
warnings.warn(
"Study.tell may be slow because the trial was represented by its number but "
f"the storage {study._storage.__class__.__name__} does not implement the "
"method required to map numbers back. Please provide the trial object "
"to avoid performance degradation."
)

trials = study.get_trials(deepcopy=False)

if len(trials) <= trial_number:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
"created."
) from e

trial_id = trials[trial_number]._trial_id
except KeyError as e:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
Expand Down
22 changes: 0 additions & 22 deletions tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,28 +1480,6 @@ def test_tell_duplicate_tell() -> None:
study.tell(trial, 1.0, skip_if_finished=False)


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_tell_storage_not_implemented_trial_number() -> None:
with StorageSupplier("inmemory") as storage:

with patch.object(
storage,
"get_trial_id_from_study_id_trial_number",
side_effect=NotImplementedError,
):
study = create_study(storage=storage)

study.tell(study.ask(), 1.0)

# Storage missing implementation for method required to map trial numbers back to
# trial IDs.
with pytest.warns(UserWarning):
study.tell(study.ask().number, 1.0)

with pytest.raises(ValueError):
study.tell(study.ask().number + 1, 1.0)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueued_trial_datetime_start(storage_mode: str) -> None:

Expand Down

0 comments on commit de29496

Please sign in to comment.