Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hotfix] Fix the bug of matplotlib's plot_rank function #5133

Merged
merged 5 commits into from
Dec 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 18 additions & 6 deletions optuna/visualization/_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@
from optuna.visualization._utils import _check_plot_args
from optuna.visualization._utils import _is_log_scale
from optuna.visualization._utils import _is_numerical
from optuna.visualization.matplotlib._matplotlib_imports import _imports as matplotlib_imports


if _imports.is_successful():
plotly_is_available = _imports.is_successful()
if plotly_is_available:
from optuna.visualization._plotly_imports import go
from optuna.visualization._plotly_imports import make_subplots
from optuna.visualization._plotly_imports import plotly
from optuna.visualization._plotly_imports import Scatter
if matplotlib_imports.is_successful():
# TODO(c-bata): Refactor to remove matplotlib and plotly dependencies in `_get_rank_info()`.
# See https://github.com/optuna/optuna/pull/5133#discussion_r1414761672 for the discussion.
from optuna.visualization.matplotlib._matplotlib_imports import plt as matplotlib_plt

_logger = get_logger(__name__)

Expand Down Expand Up @@ -212,7 +218,7 @@
if constraints is not None and any([x > 0.0 for x in constraints]):
infeasible_trial_ids.append(i)

colors[infeasible_trial_ids] = plotly.colors.hex_to_rgb("#cccccc")
colors[infeasible_trial_ids] = (204, 204, 204) # equal to "#CCCCCC"

filtered_ids = [
i
Expand Down Expand Up @@ -422,7 +428,13 @@

def _convert_color_idxs_to_scaled_rgb_colors(color_idxs: np.ndarray) -> np.ndarray:
colormap = "RdYlBu_r"
# sample_colorscale requires plotly >= 5.0.0.
labeled_colors = plotly.colors.sample_colorscale(colormap, color_idxs)
scaled_rgb_colors = np.array([plotly.colors.unlabel_rgb(cl) for cl in labeled_colors])
return scaled_rgb_colors
if plotly_is_available:
# sample_colorscale requires plotly >= 5.0.0.
labeled_colors = plotly.colors.sample_colorscale(colormap, color_idxs)
scaled_rgb_colors = np.array([plotly.colors.unlabel_rgb(cl) for cl in labeled_colors])
return scaled_rgb_colors
else:
cmap = matplotlib_plt.get_cmap(colormap)
colors = cmap(color_idxs)[:, :3] # Drop alpha values.
rgb_colors = np.asarray(colors * 255, dtype=int)
return rgb_colors

Check warning on line 440 in optuna/visualization/_rank.py

View check run for this annotation

Codecov / codecov/patch

optuna/visualization/_rank.py#L437-L440

Added lines #L437 - L440 were not covered by tests