Skip to content

Commit

Permalink
Remove ordered_dict argument from IntersectionSearchSpace (#4846)
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 committed Aug 1, 2023
1 parent 2f2636f commit fa19042
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 37 deletions.
2 changes: 1 addition & 1 deletion optuna/importance/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _get_distributions(study: Study, params: Optional[List[str]]) -> Dict[str, B
_check_evaluate_args(completed_trials, params)

if params is None:
return intersection_search_space(study.get_trials(deepcopy=False), ordered_dict=True)
return intersection_search_space(study.get_trials(deepcopy=False))

# New temporary required to pass mypy. Seems like a bug.
params_not_none = params
Expand Down
2 changes: 1 addition & 1 deletion optuna/integration/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def infer_relative_search_space(
raise RuntimeError("BoTorchSampler cannot handle multiple studies.")

search_space: Dict[str, BaseDistribution] = {}
for name, distribution in self._search_space.calculate(study, ordered_dict=True).items():
for name, distribution in self._search_space.calculate(study).items():
if distribution.single():
# built-in `candidates_func` cannot handle distributions that contain just a
# single value, so we skip them. Note that the parameter values for such
Expand Down
29 changes: 7 additions & 22 deletions optuna/search_space/intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,17 @@ def __init__(self, include_pruned: bool = False) -> None:

self._include_pruned = include_pruned

def calculate(self, study: Study, ordered_dict: bool = False) -> Dict[str, BaseDistribution]:
def calculate(self, study: Study) -> Dict[str, BaseDistribution]:
"""Returns the intersection search space of the :class:`~optuna.study.Study`.
Args:
study:
A study with completed trials. The same study must be passed for one instance
of this class through its lifetime.
ordered_dict:
A boolean flag determining the return type.
If :obj:`False`, the returned object will be a :obj:`dict`.
If :obj:`True`, the returned object will be a :obj:`dict` sorted by keys, i.e.
parameter names.
Returns:
A dictionary containing the parameter names and parameter's distributions.
A dictionary containing the parameter names and parameter's distributions sorted by
parameter names.
"""

if self._study_id is None:
Expand All @@ -108,16 +104,12 @@ def calculate(self, study: Study, ordered_dict: bool = False) -> Dict[str, BaseD
self._cached_trial_number,
)
search_space = self._search_space or {}

if ordered_dict:
search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))

search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))
return copy.deepcopy(search_space)


def intersection_search_space(
trials: list[optuna.trial.FrozenTrial],
ordered_dict: bool = False,
include_pruned: bool = False,
) -> Dict[str, BaseDistribution]:
"""Return the intersection search space of the given trials.
Expand All @@ -136,22 +128,15 @@ def intersection_search_space(
Args:
trials:
A list of trials.
ordered_dict:
A boolean flag determining the return type.
If :obj:`False`, the returned object will be a :obj:`dict`.
If :obj:`True`, the returned object will be a :obj:`dict` sorted by keys, i.e.
parameter names.
include_pruned:
Whether pruned trials should be included in the search space.
Returns:
A dictionary containing the parameter names and parameter's distributions.
A dictionary containing the parameter names and parameter's distributions sorted by
parameter names.
"""

search_space, _ = _calculate(trials, include_pruned)
search_space = search_space or {}

if ordered_dict:
search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))

search_space = dict(sorted(search_space.items(), key=lambda x: x[0]))
return search_space
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def apply(
trials: List[optuna.trial.FrozenTrial],
study_direction: Optional[optuna.study.StudyDirection],
) -> List[optuna.trial.FrozenTrial]:
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)

additional_trials = []
for _ in range(self._n_additional_trials):
Expand Down
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def evaluate(
trials: List[FrozenTrial],
study_direction: StudyDirection,
) -> float:
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)
self._validate_input(trials, search_space)

fit_trials = self.get_preprocessing().apply(trials, study_direction)
Expand Down
2 changes: 1 addition & 1 deletion optuna/terminator/improvement/gp/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _convert_trials_to_tensors(trials: list[FrozenTrial]) -> tuple[torch.Tensor,
- the state is COMPLETE for any trial;
- direction is MINIMIZE for any trial.
"""
search_space = intersection_search_space(trials, ordered_dict=True)
search_space = intersection_search_space(trials)
sorted_params = sorted(search_space.keys())

x = []
Expand Down
12 changes: 2 additions & 10 deletions tests/search_space_tests/test_intersection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,8 @@ def test_intersection_search_space() -> None:
study.get_trials(deepcopy=False)
)

# Returning sorted `dict`.
assert search_space.calculate(study, ordered_dict=True) == dict(
[
("x", IntDistribution(low=0, high=10)),
("y", FloatDistribution(low=-3, high=3)),
]
)
assert search_space.calculate(study, ordered_dict=True) == intersection_search_space(
study.get_trials(deepcopy=False), ordered_dict=True
)
# Returned dict is sorted by parameter names.
assert list(search_space.calculate(study).keys()) == ["x", "y"]

# Second trial (only 'y' parameter is suggested in this trial).
study.optimize(lambda t: t.suggest_float("y", -3, 3), n_trials=1)
Expand Down

0 comments on commit fa19042

Please sign in to comment.