Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ordered_dict argument from IntersectionSearchSpace #4846

Merged
merged 1 commit into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optuna/importance/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_distributions(study: Study, params: Optional[List[str]]) -> Dict[str, B
_check_evaluate_args(completed_trials, params)

if params is None:
return intersection_search_space(study.get_trials(deepcopy=False), ordered_dict=True)
return intersection_search_space(study.get_trials(deepcopy=False))

# New temporary required to pass mypy. Seems like a bug.
params_not_none = params
Expand Down
2 changes: 1 addition & 1 deletion optuna/integration/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def infer_relative_search_space(
raise RuntimeError("BoTorchSampler cannot handle multiple studies.")

search_space: Dict[str, BaseDistribution] = {}
for name, distribution in self._search_space.calculate(study, ordered_dict=True).items():
for name, distribution in self._search_space.calculate(study).items():
if distribution.single():
# built-in `candidates_func` cannot handle distributions that contain just a
# single value, so we skip them. Note that the parameter values for such
Expand Down
29 changes: 7 additions & 22 deletions optuna/search_space/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,17 @@ def __init__(self, include_pruned: bool = False) -> None:

self._include_pruned = include_pruned

def calculate(self, study: Study, ordered_dict: bool = False) -> Dict[str, BaseDistribution]:
def calculate(self, study: Study) -> Dict[str, BaseDistribution]:
"""Returns the intersection search space of the :class:`~optuna.study.Study`.
Args:
study:
A study with completed trials. The same study must be passed for one instance
of this class through its lifetime.
ordered_dict:
A boolean flag determining the return type.
If :obj:`False`, the returned object will be a :obj:`dict`.
If :obj:`True`, the returned object will be a :obj:`dict` sorted by keys, i.e.
parameter names.
Returns:
A dictionary containing the parameter names and parameter's distributions.
A dictionary containing the parameter names and parameter's distributions sorted by
parameter names.
"""

if self._study_id is None:
Expand All @@ -108,16 +104,12 @@ def calculate(self, study: Study, ordered_dict: bool = False) -> Dict[str, BaseD
self._cached_trial_number,
)
search_space = self._search_space or {}

if ordered_dict:
search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))

search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))
return copy.deepcopy(search_space)


def intersection_search_space(
trials: list[optuna.trial.FrozenTrial],
ordered_dict: bool = False,
include_pruned: bool = False,
) -> Dict[str, BaseDistribution]:
"""Return the intersection search space of the given trials.
Expand All @@ -136,22 +128,15 @@ def intersection_search_space(
Args:
trials:
A list of trials.
ordered_dict:
A boolean flag determining the return type.
If :obj:`False`, the returned object will be a :obj:`dict`.
If :obj:`True`, the returned object will be a :obj:`dict` sorted by keys, i.e.
parameter names.
include_pruned:
Whether pruned trials should be included in the search space.
Returns:
A dictionary containing the parameter names and parameter's distributions.
A dictionary containing the parameter names and parameter's distributions sorted by
parameter names.
"""

search_space, _ = _calculate(trials, include_pruned)
search_space = search_space or {}

if ordered_dict:
search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))

search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))
return search_space
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def apply(
trials: List[optuna.trial.FrozenTrial],
study_direction: Optional[optuna.study.StudyDirection],
) -> List[optuna.trial.FrozenTrial]:
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)

additional_trials = []
for _ in range(self._n_additional_trials):
Expand Down
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def evaluate(
trials: List[FrozenTrial],
study_direction: StudyDirection,
) -> float:
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)
self._validate_input(trials, search_space)

fit_trials = self.get_preprocessing().apply(trials, study_direction)
Expand Down
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/gp/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _convert_trials_to_tensors(trials: list[FrozenTrial]) -> tuple[torch.Tensor,
- the state is COMPLETE for any trial;
- direction is MINIMIZE for any trial.
"""
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)
sorted_params = sorted(search_space.keys())

x = []
Expand Down
12 changes: 2 additions & 10 deletions tests/search_space_tests/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,8 @@ def test_intersection_search_space() -> None:
study.get_trials(deepcopy=False)
)

# Returning sorted `dict`.
assert search_space.calculate(study, ordered_dict=True) == dict(
[
("x", IntDistribution(low=0, high=10)),
("y", FloatDistribution(low=-3, high=3)),
]
)
assert search_space.calculate(study, ordered_dict=True) == intersection_search_space(
study.get_trials(deepcopy=False), ordered_dict=True
)
# Returned dict is sorted by parameter names.
assert list(search_space.calculate(study).keys()) == ["x", "y"]

# Second trial (only 'y' parameter is suggested in this trial).
study.optimize(lambda t: t.suggest_float("y", -3, 3), n_trials=1)
Expand Down
Loading