Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move default logic of get_trial_id_from_study_id_trial_number() method to BaseStorage. #3910

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion optuna/storages/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,14 @@ def get_trial_id_from_study_id_trial_number(self, study_id: int, trial_number: i
:exc:`KeyError`:
If no trial with the matching ``study_id`` and ``trial_number`` exists.
"""
raise NotImplementedError
trials = self.get_all_trials(study_id, deepcopy=False)
if len(trials) <= trial_number:
raise KeyError(
"No trial with trial number {} exists in study with study_id {}.".format(
trial_number, study_id
)
)
return trials[trial_number]._trial_id

def get_trial_number_from_id(self, trial_id: int) -> int:
"""Read the trial number of a trial.
Expand Down
17 changes: 0 additions & 17 deletions optuna/study/_tell.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,23 +123,6 @@ def _tell_with_warning(
trial_id = study._storage.get_trial_id_from_study_id_trial_number(
study._study_id, trial_number
)
except NotImplementedError as e:
warnings.warn(
"Study.tell may be slow because the trial was represented by its number but "
f"the storage {study._storage.__class__.__name__} does not implement the "
"method required to map numbers back. Please provide the trial object "
"to avoid performance degradation."
)

knshnb marked this conversation as resolved.
Show resolved Hide resolved
trials = study.get_trials(deepcopy=False)

if len(trials) <= trial_number:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
"created."
) from e

trial_id = trials[trial_number]._trial_id
except KeyError as e:
raise ValueError(
f"Cannot tell for trial with number {trial_number} since it has not been "
Expand Down
22 changes: 6 additions & 16 deletions tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,22 +1485,12 @@ def test_tell_duplicate_tell() -> None:
def test_tell_storage_not_implemented_trial_number() -> None:
with StorageSupplier("inmemory") as storage:

with patch.object(
storage,
"get_trial_id_from_study_id_trial_number",
side_effect=NotImplementedError,
):
study = create_study(storage=storage)

study.tell(study.ask(), 1.0)

# Storage missing implementation for method required to map trial numbers back to
# trial IDs.
with pytest.warns(UserWarning):
study.tell(study.ask().number, 1.0)

with pytest.raises(ValueError):
study.tell(study.ask().number + 1, 1.0)
study = create_study(storage=storage)
study.tell(study.ask(), 1.0)
study.tell(study.ask().number, 1.0)

with pytest.raises(ValueError):
study.tell(study.ask().number + 1, 1.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this, InMemoryStorage.get_trial_id_from_study_id_trial_number is tested instead of BaseStorage.get_trial_id_from_study_id_trial_number.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot think of a way to easily test InMemoryStorage.get_trial_id_from_study_id_trial_number (except for mocking all the abstract methods of InMemoryStorage), but do you come up with a good way...?

Copy link
Member Author

@c-bata c-bata Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! It is a difficult problem how to test concrete methods in an abstract base class. I thought it would be better to just remove this test case since this test case should not be written in test_study.py.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that get_trial_id_from_study_id_trial_number() is enough tested at:

@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
def test_get_trial_id_from_study_id_trial_number(storage_mode: str) -> None:
with StorageSupplier(storage_mode) as storage:
with pytest.raises(KeyError): # Matching study does not exist.
storage.get_trial_id_from_study_id_trial_number(study_id=0, trial_number=0)
study_id = storage.create_new_study()
with pytest.raises(KeyError): # Matching trial does not exist.
storage.get_trial_id_from_study_id_trial_number(study_id, trial_number=0)
trial_id = storage.create_new_trial(study_id)
assert trial_id == storage.get_trial_id_from_study_id_trial_number(
study_id, trial_number=0
)
# Trial IDs are globally unique within a storage but numbers are only unique within a
# study. Create a second study within the same storage.
study_id = storage.create_new_study()
trial_id = storage.create_new_trial(study_id)
assert trial_id == storage.get_trial_id_from_study_id_trial_number(
study_id, trial_number=0
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that this test should not be located in test_study.py.
Still, I feel the default logic of BaseStorage.get_trial_id_from_study_id_trial_number should be tested somewhere (not tested in test_get_trial_id_from_study_id_trial_number because all the concrete class override the method). We should probably discuss how to test default implementations in BaseStorage as another issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, let me merge this PR first. If you have any opinions, please leave a comment!



@pytest.mark.parametrize("storage_mode", STORAGE_MODES)
Expand Down