Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move default logic of get_trial_id_from_study_id_trial_number() method to BaseStorage. #3910

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion optuna/storages/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,14 @@ def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: i
:exc:`KeyError`:
If no trial with the matching ``study_id`` and ``trial_number`` exists.
"""
raise NotImplementedError
trials = self.get_all_trials(study_id, deepcopy=False)
if len(trials) <= trial_number:
raise KeyError(
"No trial with trial number {} exists in study with study_id {}.".format(
trial_number, study_id
)
)
return trials[trial_number]._trial_id

def get_trial_number_from_id(self, trial_id: int) -> int:
"""Read the trial number of a trial.
Expand Down
17 changes: 0 additions & 17 deletions optuna/study/_tell.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,6 @@ def _get_frozen_trial(study: "optuna.Study", trial: Union[trial_module.Trial, in
trial_id = study._storage.get_trial_id_from_study_id_trial_number(
study._study_id, trial_number
)
except NotImplementedError as e:
warnings.warn(
"Study.tell may be slow because the trial was represented by its number but "
f"the storage {study._storage.__class__.__name__} does not implement the "
"method required to map numbers back. Please provide the trial object "
"to avoid performance degradation."
)

knshnb marked this conversation as resolved.
Show resolved Hide resolved
trials = study.get_trials(deepcopy=False)

if len(trials) <= trial_number:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
"created."
) from e

trial_id = trials[trial_number]._trial_id
except KeyError as e:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
Expand Down
22 changes: 0 additions & 22 deletions tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,28 +1480,6 @@ def test_tell_duplicate_tell() -> None:
study.tell(trial, 1.0, skip_if_finished=False)


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_tell_storage_not_implemented_trial_number() -> None:
with StorageSupplier("inmemory") as storage:

with patch.object(
storage,
"get_trial_id_from_study_id_trial_number",
side_effect=NotImplementedError,
):
study = create_study(storage=storage)

study.tell(study.ask(), 1.0)

# Storage missing implementation for method required to map trial numbers back to
# trial IDs.
with pytest.warns(UserWarning):
study.tell(study.ask().number, 1.0)

with pytest.raises(ValueError):
study.tell(study.ask().number + 1, 1.0)


@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_enqueued_trial_datetime_start(storage_mode: str) -> None:

Expand Down