Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rank plot matplotlib version #4541

Merged
merged 12 commits into from
May 10, 2023

Conversation

cross32768
Copy link
Contributor

@cross32768 cross32768 commented Mar 24, 2023

Motivation

This is the follow-up PR of #4427 .

Description of the changes

Add plot_rank function in optuna.visualization.matplotlib.
It is matplotlib version of plot_rank function in optuna.visualization which was added by #4427 .

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@github-actions github-actions bot added the optuna.visualization Related to the `optuna.visualization` submodule. This is automatically labeled by github-actions. label Mar 24, 2023
@github-actions
Copy link
Contributor

github-actions bot commented Apr 2, 2023

This pull request has not seen any recent activity.

@github-actions github-actions bot added stale Exempt from stale bot labeling. and removed stale Exempt from stale bot labeling. labels Apr 2, 2023
@github-actions
Copy link
Contributor

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Apr 11, 2023
@cross32768
Copy link
Contributor Author

Below is the comparison of plotly version rank_plot and matplotlib version by this PR.

Simple Example

def objective(trial):
    x = trial.suggest_float("x", -1, 1)
    y = trial.suggest_float("y", -1, 1)
    return x + y

study = optuna.create_study()
study.optimize(objective, n_trials=100)

Plotly
スクリーンショット 2023-04-14 17 39 01

matplotlib
スクリーンショット 2023-04-14 17 39 10

@cross32768
Copy link
Contributor Author

When parameter is logscale or categorical

def objective(trial):
    x = trial.suggest_float("x", 1, 1000, log=True)
    y = trial.suggest_categorical("y", ["a", "b", "c"])
    
    if y == "a":
        return 2 * x
    else:
        return x

Plotly
スクリーンショット 2023-04-14 17 43 14

matplotlib
スクリーンショット 2023-04-14 17 43 24

@cross32768
Copy link
Contributor Author

cross32768 commented Apr 14, 2023

When plotting a study with more than 2 parameters

def objective(trial):
    x = trial.suggest_float("x", -1, 1)
    y = trial.suggest_float("y", -1, 1)
    z = trial.suggest_float("z", -1, 1)
    return x * y * z

Plotly
スクリーンショット 2023-04-14 17 46 52

matplotlib
スクリーンショット 2023-04-14 17 47 08

@cross32768 cross32768 removed the stale Exempt from stale bot labeling. label Apr 14, 2023
@cross32768 cross32768 marked this pull request as ready for review April 14, 2023 08:54
@cross32768
Copy link
Contributor Author

MEMO: change of plotly version in recent PR #4602 should be included if PR #4602 is merged

@toshihikoyanase toshihikoyanase added the feature Change that does not break compatibility, but affects the public interfaces. label Apr 20, 2023
@toshihikoyanase
Copy link
Member

@Alnusjaponica Could you join the review, please? If you have any questions, please let me know.

Copy link
Collaborator

@Alnusjaponica Alnusjaponica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made some suggestions, some of which can possibly be addressed in different PRs. Feel free to ignore those minor changes if you don't agree with them.

) -> "Axes":
"""Plot parameter relations as scatter plots with colors indicating ranks of objective value.

Note that, if a parameter contains missing values, a trial with missing values is not plotted.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Note that, if a parameter contains missing values, a trial with missing values is not plotted.
Note that, trials with missing values will not be plotted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your suggestion. It looks good change, but I think to apply this change in the follow-up PR because this change should also apply original plotly-version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"values" usually mean objective values, so maybe it's better to say "trials missing the specified parameters will not be plotted."

optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
target: Optional[Callable[[FrozenTrial], float]] = None,
target_name: str = "Objective Value",
) -> "Axes":
"""Plot parameter relations as scatter plots with colors indicating ranks of objective value.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Plot parameter relations as scatter plots with colors indicating ranks of objective value.
"""Plot parameter relations as scatter plots with colors indicating ranks of objective values.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this change can also be resolved in the follow-up PR to re-write both of plotly version and matplotlib version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More precisely, this should be "ranks of target value".

optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
optuna/visualization/matplotlib/_rank.py Outdated Show resolved Hide resolved
cross32768 and others added 7 commits April 21, 2023 04:44
Co-authored-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
Co-authored-by: Shinichi Hemmi <50256998+Alnusjaponica@users.noreply.github.com>
@cross32768
Copy link
Contributor Author

I added edgecolor to scatter plot as in #4602 .

@cross32768
Copy link
Contributor Author

@Alnusjaponica Thank you for your review. I resolved most of comments and update my code to follow #4602. Could you take another look?

Copy link
Collaborator

@Alnusjaponica Alnusjaponica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for my delayed reply. LGTM!
@contramundum53 Could you take another look?

@Alnusjaponica Alnusjaponica removed their assignment Apr 27, 2023
@cross32768 cross32768 modified the milestone: v3.2.0 Apr 27, 2023
Comment on lines +104 to +113
if n_params == 0:
_, ax = plt.subplots()
ax.set_title(title)
return ax
if n_params == 1 or n_params == 2:
fig, axs = plt.subplots()
axs.set_title(title)
pc = _add_rank_subplot(axs, sub_plot_infos[0][0])
else:
fig, axs = plt.subplots(n_params, n_params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Maybe we don't need these ifs and can directly have plt.subplots(len(sub_plot_infos), len(sub_plot_infos)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these ifs are required by the following reason.

  • if n_params == 0: is required because in this case, sub_plot_infos is [] and sub_plot_infos[0][0] raises error if we remove this if.
  • if n_params == 1 or n_params == 2: is required because plt.subplots() or plt.subplots(1, 1) returns Axes object. On the other hand, plt.subplots(n, n) returns array of Axes object if n >= 2. Thus, ax = axs[x_i, y_i] raises error if we remove this if.

@contramundum53
Copy link
Member

Otherwise, LGTM!

@contramundum53
Copy link
Member

I see, LGTM!

@contramundum53 contramundum53 added this to the v3.2.0 milestone May 9, 2023
@contramundum53 contramundum53 merged commit ff12502 into optuna:master May 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Change that does not break compatibility, but affects the public interfaces. optuna.visualization Related to the `optuna.visualization` submodule. This is automatically labeled by github-actions.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants