Allow user-defined objective names in hyperparameter importance plots #4986

Merged
68 changes: 46 additions & 22 deletions optuna/visualization/_param_importances.py
@@ -68,6 +68,48 @@ def _get_importances_info(
)


def _get_importances_infos(
    study: Study,
    evaluator: BaseImportanceEvaluator | None,
    params: list[str] | None,
    target: Callable[[FrozenTrial], float] | None,
    target_name: str,
) -> tuple[_ImportancesInfo, ...]:
    metric_names = study.metric_names
    if target or not study._is_multi_objective():
        target_name = metric_names[0] if metric_names is not None and not target else target_name
        importances_infos: tuple[_ImportancesInfo, ...] = (
            _get_importances_info(
                study,
                evaluator,
                params,
                target=target,
                target_name=target_name,
            ),
        )

    else:
        n_objectives = len(study.directions)
        target_names = (
            metric_names
            if metric_names is not None
            else (f"{target_name} {objective_id}" for objective_id in range(n_objectives))
        )

        importances_infos = tuple(
            _get_importances_info(
                study,
                evaluator,
                params,
                target=lambda t: t.values[objective_id],
                target_name=target_name,
            )
            for objective_id, target_name in enumerate(target_names)
        )

    return importances_infos
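Not part of this diff — a rough sketch of how the new helper's fallback naming behaves for a multi-objective study when no metric names have been set (the f-string branch above). The toy objective and trial count are illustrative assumptions:

import optuna
from optuna.visualization._param_importances import _get_importances_infos


def objective(trial: optuna.Trial):
    x = trial.suggest_float("x", -5, 5)
    y = trial.suggest_float("y", -5, 5)
    return x**2, (y - 1) ** 2


study = optuna.create_study(directions=["minimize", "minimize"])
study.optimize(objective, n_trials=30)

infos = _get_importances_infos(
    study, evaluator=None, params=None, target=None, target_name="Objective Value"
)
# study.metric_names is unset here, so each objective gets the numbered default name.
print([info.target_name for info in infos])  # ["Objective Value 0", "Objective Value 1"]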


def plot_param_importances(
    study: Study,
    evaluator: BaseImportanceEvaluator | None = None,
@@ -137,34 +179,16 @@ def objective(trial):
            importance of the first objective, use ``target=lambda t: t.values[0]`` for the
            target parameter.
        target_name:
            Target's name to display on the legend.
            Target's name to display on the legend. Names set via
            :meth:`~optuna.study.Study.set_metric_names` will be used if ``target`` is :obj:`None`,
            overriding this argument.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.
    """

    _imports.check()

    if target or not study._is_multi_objective():
        importances_infos: tuple[_ImportancesInfo, ...] = (
            _get_importances_info(
                study, evaluator, params, target=target, target_name=target_name
            ),
        )

    else:
        n_objectives = len(study.directions)
        importances_infos = tuple(
            _get_importances_info(
                study,
                evaluator,
                params,
                target=lambda t: t.values[objective_id],
                target_name=f"{target_name} {objective_id}",
            )
            for objective_id in range(n_objectives)
        )

    importances_infos = _get_importances_infos(study, evaluator, params, target, target_name)
    return _get_importances_plot(importances_infos, study)
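For context (not part of this PR's diff), a minimal usage sketch of the behavior the updated docstring describes: names registered with Study.set_metric_names label the legend when no custom target is passed. The objective below is a made-up stand-in, and an Optuna version providing set_metric_names is assumed:

import optuna
from optuna.visualization import plot_param_importances


def objective(trial: optuna.Trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    depth = trial.suggest_int("depth", 1, 10)
    return lr * depth, depth / lr  # toy stand-ins for accuracy and latency


study = optuna.create_study(directions=["maximize", "minimize"])
study.set_metric_names(["accuracy", "latency"])
study.optimize(objective, n_trials=30)

fig = plot_param_importances(study)  # legend reads "accuracy" / "latency"
fig.show()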


26 changes: 5 additions & 21 deletions optuna/visualization/matplotlib/_param_importances.py
@@ -9,7 +9,7 @@
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.visualization._param_importances import _get_importances_info
from optuna.visualization._param_importances import _get_importances_infos
from optuna.visualization._param_importances import _ImportancesInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports

@@ -86,32 +86,16 @@ def objective(trial):
            importance of the first objective, use ``target=lambda t: t.values[0]`` for the
            target parameter.
        target_name:
            Target's name to display on the axis label.
            Target's name to display on the axis label. Names set via
            :meth:`~optuna.study.Study.set_metric_names` will be used if ``target`` is :obj:`None`,
            overriding this argument.

    Returns:
        A :class:`matplotlib.axes.Axes` object.
    """

    _imports.check()

    if target or not study._is_multi_objective():
        importances_infos: tuple[_ImportancesInfo, ...] = (
            _get_importances_info(study, evaluator, params, target, target_name),
        )

    else:
        n_objectives = len(study.directions)
        importances_infos = tuple(
            _get_importances_info(
                study,
                evaluator,
                params,
                target=lambda t: t.values[objective_id],
                target_name=f"{target_name} {objective_id}",
            )
            for objective_id in range(n_objectives)
        )

    importances_infos = _get_importances_infos(study, evaluator, params, target, target_name)
    return _get_importances_plot(importances_infos)
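Again outside the diff, a sketch of the matplotlib variant after this change: since it now routes through the same shared helper, a name set on a single-objective study replaces the default axis label. The study and objective details here are assumptions:

import optuna
from optuna.visualization.matplotlib import plot_param_importances


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10, 10)
    y = trial.suggest_float("y", -10, 10)
    return (x - 2) ** 2 + y


study = optuna.create_study(direction="minimize")
study.set_metric_names(["validation loss"])
study.optimize(objective, n_trials=30)

ax = plot_param_importances(study)  # axis label shows "validation loss" instead of "Objective Value"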


19 changes: 19 additions & 0 deletions tests/visualization_tests/test_param_importances.py
@@ -19,6 +19,7 @@
from optuna.trial import Trial
from optuna.visualization import plot_param_importances as plotly_plot_param_importances
from optuna.visualization._param_importances import _get_importances_info
from optuna.visualization._param_importances import _get_importances_infos
from optuna.visualization._param_importances import _ImportancesInfo
from optuna.visualization._plotly_imports import go
from optuna.visualization.matplotlib import plot_param_importances as plt_plot_param_importances
@@ -139,6 +140,24 @@ def test_get_param_importances_info_empty(
)


@pytest.mark.parametrize(
    "specific_create_study,objective_names",
    [(create_study, ["Foo"]), (_create_multiobjective_study, ["Foo", "Bar"])],
)
def test_get_param_importances_infos_custom_objective_names(
    specific_create_study: Callable[[], Study], objective_names: list[str]
) -> None:
    study = specific_create_study()
    study.set_metric_names(objective_names)
Member: How about testing the case of not setting the metric names?

Collaborator (Author): Added 👍
    n_objectives = len(study.directions)

    infos = _get_importances_infos(
        study, evaluator=None, params=["param_a"], target=None, target_name="Objective Value"
    )
    assert len(infos) == n_objectives
    assert all(info.target_name == expected for info, expected in zip(infos, objective_names))
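The follow-up test mentioned in the reply above is not shown in this commit. A hedged sketch of what the default-name case (no set_metric_names call) might look like, reusing the helpers already in this file; the test name and parametrization are hypothetical:

@pytest.mark.parametrize(
    "specific_create_study,default_names",
    [
        (create_study, ["Objective Value"]),
        (_create_multiobjective_study, ["Objective Value 0", "Objective Value 1"]),
    ],
)
def test_get_param_importances_infos_default_objective_names(  # hypothetical name
    specific_create_study: Callable[[], Study], default_names: list[str]
) -> None:
    study = specific_create_study()  # no metric names set on this study

    infos = _get_importances_infos(
        study, evaluator=None, params=["param_a"], target=None, target_name="Objective Value"
    )
    assert len(infos) == len(study.directions)
    assert all(info.target_name == expected for info, expected in zip(infos, default_names))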


def test_switch_label_when_param_insignificant() -> None:
    def _objective(trial: Trial) -> int:
        x = trial.suggest_int("x", 0, 2)