Commit

Merge pull request #4986 from xadrianzetx/param-importances-metric-names
Allow user-defined objective names in hyperparameter importance plots
Alnusjaponica committed Oct 12, 2023
2 parents 11b9589 + 0c58dbe commit 3895ff8
Showing 3 changed files with 88 additions and 43 deletions.
68 changes: 46 additions & 22 deletions optuna/visualization/_param_importances.py
@@ -68,6 +68,48 @@ def _get_importances_info(
    )


+def _get_importances_infos(
+    study: Study,
+    evaluator: BaseImportanceEvaluator | None,
+    params: list[str] | None,
+    target: Callable[[FrozenTrial], float] | None,
+    target_name: str,
+) -> tuple[_ImportancesInfo, ...]:
+    metric_names = study.metric_names
+    if target or not study._is_multi_objective():
+        target_name = metric_names[0] if metric_names is not None and not target else target_name
+        importances_infos: tuple[_ImportancesInfo, ...] = (
+            _get_importances_info(
+                study,
+                evaluator,
+                params,
+                target=target,
+                target_name=target_name,
+            ),
+        )
+
+    else:
+        n_objectives = len(study.directions)
+        target_names = (
+            metric_names
+            if metric_names is not None
+            else (f"{target_name} {objective_id}" for objective_id in range(n_objectives))
+        )
+
+        importances_infos = tuple(
+            _get_importances_info(
+                study,
+                evaluator,
+                params,
+                target=lambda t: t.values[objective_id],
+                target_name=target_name,
+            )
+            for objective_id, target_name in enumerate(target_names)
+        )
+
+    return importances_infos
+
+
def plot_param_importances(
    study: Study,
    evaluator: BaseImportanceEvaluator | None = None,
@@ -137,34 +179,16 @@ def objective(trial):
                importance of the first objective, use ``target=lambda t: t.values[0]`` for the
                target parameter.
        target_name:
-            Target's name to display on the legend.
+            Target's name to display on the legend. Names set via
+            :meth:`~optuna.study.Study.set_metric_names` will be used if ``target`` is :obj:`None`,
+            overriding this argument.
    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """

    _imports.check()

-    if target or not study._is_multi_objective():
-        importances_infos: tuple[_ImportancesInfo, ...] = (
-            _get_importances_info(
-                study, evaluator, params, target=target, target_name=target_name
-            ),
-        )
-
-    else:
-        n_objectives = len(study.directions)
-        importances_infos = tuple(
-            _get_importances_info(
-                study,
-                evaluator,
-                params,
-                target=lambda t: t.values[objective_id],
-                target_name=f"{target_name} {objective_id}",
-            )
-            for objective_id in range(n_objectives)
-        )
-
+    importances_infos = _get_importances_infos(study, evaluator, params, target, target_name)
    return _get_importances_plot(importances_infos, study)


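Below is a minimal usage sketch of the behavior this file now supports (the objective function and the metric names "Foo"/"Bar" are illustrative, not part of the diff): once metric names are set on the study and ``target`` is left as :obj:`None`, each objective's importances are labeled with its metric name instead of the generic "Objective Value <i>".

# Illustrative sketch only; the exact importance values depend on the evaluator.
import optuna


def objective(trial: optuna.Trial) -> tuple[float, float]:
    x = trial.suggest_float("x", -10.0, 10.0)
    y = trial.suggest_int("y", 0, 5)
    return x**2 + y, (x - 2.0) ** 2


study = optuna.create_study(directions=["minimize", "minimize"])
study.optimize(objective, n_trials=30)

# With metric names set and target=None, the legend reads "Foo" and "Bar"
# rather than "Objective Value 0" and "Objective Value 1".
study.set_metric_names(["Foo", "Bar"])
fig = optuna.visualization.plot_param_importances(study)
fig.show()
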
26 changes: 5 additions & 21 deletions optuna/visualization/matplotlib/_param_importances.py
@@ -9,7 +9,7 @@
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
-from optuna.visualization._param_importances import _get_importances_info
+from optuna.visualization._param_importances import _get_importances_infos
from optuna.visualization._param_importances import _ImportancesInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports

@@ -86,32 +86,16 @@ def objective(trial):
                importance of the first objective, use ``target=lambda t: t.values[0]`` for the
                target parameter.
        target_name:
-            Target's name to display on the axis label.
+            Target's name to display on the axis label. Names set via
+            :meth:`~optuna.study.Study.set_metric_names` will be used if ``target`` is :obj:`None`,
+            overriding this argument.
    Returns:
        A :class:`matplotlib.axes.Axes` object.
    """

    _imports.check()

-    if target or not study._is_multi_objective():
-        importances_infos: tuple[_ImportancesInfo, ...] = (
-            _get_importances_info(study, evaluator, params, target, target_name),
-        )
-
-    else:
-        n_objectives = len(study.directions)
-        importances_infos = tuple(
-            _get_importances_info(
-                study,
-                evaluator,
-                params,
-                target=lambda t: t.values[objective_id],
-                target_name=f"{target_name} {objective_id}",
-            )
-            for objective_id in range(n_objectives)
-        )
-
+    importances_infos = _get_importances_infos(study, evaluator, params, target, target_name)
    return _get_importances_plot(importances_infos)


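The matplotlib backend now goes through the same helper, so the naming rule is identical. A short sketch, assuming the ``study`` with metric names "Foo"/"Bar" from the previous example: with ``target=None`` the metric names label the plot, while an explicit ``target`` falls back to ``target_name``.

from optuna.visualization.matplotlib import plot_param_importances

# target is None, so the names set via study.set_metric_names() are used.
axes = plot_param_importances(study)

# An explicit target overrides the metric names; target_name labels the plot.
ax = plot_param_importances(study, target=lambda t: t.values[1], target_name="Bar only")
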
37 changes: 37 additions & 0 deletions tests/visualization_tests/test_param_importances.py
@@ -19,6 +19,7 @@
from optuna.trial import Trial
from optuna.visualization import plot_param_importances as plotly_plot_param_importances
from optuna.visualization._param_importances import _get_importances_info
+from optuna.visualization._param_importances import _get_importances_infos
from optuna.visualization._param_importances import _ImportancesInfo
from optuna.visualization._plotly_imports import go
from optuna.visualization.matplotlib import plot_param_importances as plt_plot_param_importances
@@ -139,6 +140,42 @@ def test_get_param_importances_info_empty(
    )


+@pytest.mark.parametrize(
+    "specific_create_study,objective_names",
+    [(create_study, ["Foo"]), (_create_multiobjective_study, ["Foo", "Bar"])],
+)
+def test_get_param_importances_infos_custom_objective_names(
+    specific_create_study: Callable[[], Study], objective_names: list[str]
+) -> None:
+    study = specific_create_study()
+    study.set_metric_names(objective_names)
+
+    infos = _get_importances_infos(
+        study, evaluator=None, params=["param_a"], target=None, target_name="Objective Value"
+    )
+    assert len(infos) == len(study.directions)
+    assert all(info.target_name == expected for info, expected in zip(infos, objective_names))
+
+
+@pytest.mark.parametrize(
+    "specific_create_study,objective_names",
+    [
+        (create_study, ["Objective Value"]),
+        (_create_multiobjective_study, ["Objective Value 0", "Objective Value 1"]),
+    ],
+)
+def test_get_param_importances_infos_default_objective_names(
+    specific_create_study: Callable[[], Study], objective_names: list[str]
+) -> None:
+    study = specific_create_study()
+
+    infos = _get_importances_infos(
+        study, evaluator=None, params=["param_a"], target=None, target_name="Objective Value"
+    )
+    assert len(infos) == len(study.directions)
+    assert all(info.target_name == expected for info, expected in zip(infos, objective_names))
+
+
def test_switch_label_when_param_insignificant() -> None:
    def _objective(trial: Trial) -> int:
        x = trial.suggest_int("x", 0, 2)
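The parametrized tests above reference a _create_multiobjective_study helper that is defined elsewhere in this test module and is not part of the diff. A hypothetical sketch of such a helper, only to make the tests readable in isolation (the real fixture may differ):

# Hypothetical helper, not from this commit: builds a small two-objective study
# whose trials suggest "param_a" so the importance evaluator has data to work with.
def _create_multiobjective_study() -> Study:
    study = create_study(directions=["minimize", "minimize"])
    study.optimize(
        lambda trial: (
            trial.suggest_float("param_a", 0.0, 1.0),
            trial.suggest_float("param_b", 0.0, 1.0),
        ),
        n_trials=10,
    )
    return study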
