Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeline plot with matplotlib as backend #4538

Merged
Merged
1 change: 1 addition & 0 deletions docs/source/reference/visualization/matplotlib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ optuna.visualization.matplotlib
optuna.visualization.matplotlib.plot_param_importances
optuna.visualization.matplotlib.plot_pareto_front
optuna.visualization.matplotlib.plot_slice
optuna.visualization.matplotlib.plot_timeline
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notes] The reference page was rendered as expected:
image

optuna.visualization.matplotlib.is_available
2 changes: 2 additions & 0 deletions optuna/visualization/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from optuna.visualization.matplotlib._param_importances import plot_param_importances
from optuna.visualization.matplotlib._pareto_front import plot_pareto_front
from optuna.visualization.matplotlib._slice import plot_slice
from optuna.visualization.matplotlib._timeline import plot_timeline
from optuna.visualization.matplotlib._utils import is_available


Expand All @@ -19,4 +20,5 @@
"plot_param_importances",
"plot_pareto_front",
"plot_slice",
"plot_timeline",
]
6 changes: 6 additions & 0 deletions optuna/visualization/matplotlib/_matplotlib_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
# TODO(ytknzw): Add specific imports.
import matplotlib
from matplotlib import __version__ as matplotlib_version
from matplotlib import dates
from matplotlib import patches
from matplotlib import pyplot as plt
from matplotlib import ticker
from matplotlib.axes._axes import Axes
from matplotlib.collections import LineCollection
from matplotlib.collections import PathCollection
Expand All @@ -27,6 +30,9 @@

__all__ = [
"_imports",
"dates",
"ticker",
"patches",
eukaryo marked this conversation as resolved.
Show resolved Hide resolved
"matplotlib",
"matplotlib_version",
"plt",
Expand Down
119 changes: 119 additions & 0 deletions optuna/visualization/matplotlib/_timeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from typing import Any

from optuna._experimental import experimental_func
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._timeline import _get_timeline_info
from optuna.visualization._timeline import _TimelineInfo
from optuna.visualization.matplotlib._matplotlib_imports import _imports


if _imports.is_successful():
from optuna.visualization.matplotlib._matplotlib_imports import Axes
from optuna.visualization.matplotlib._matplotlib_imports import dates
from optuna.visualization.matplotlib._matplotlib_imports import patches
from optuna.visualization.matplotlib._matplotlib_imports import plt
from optuna.visualization.matplotlib._matplotlib_imports import ticker


class _DateFormatter_Millisecond(ticker.Formatter):
eukaryo marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self) -> None:
pass

def __call__(self, x: float, pos: Any = 0) -> str:
return dates.num2date(x).strftime("%H:%M:%S.%f")[:-3]
eukaryo marked this conversation as resolved.
Show resolved Hide resolved


@experimental_func("3.2.0")
def plot_timeline(study: Study) -> "Axes":
"""Plot the timeline of a study.

.. seealso::
Please refer to :func:`optuna.visualization.plot_timeline` for an example.

Example:

The following code snippet shows how to plot the timeline of a study.

.. plot::

import time

import optuna


def objective(trial):
x = trial.suggest_float("x", 0, 1)
time.sleep(x * 0.1)
if x > 0.8:
raise ValueError()
if x > 0.4:
raise optuna.TrialPruned()
return x ** 2


study = optuna.create_study(direction="minimize")
study.optimize(
objective, n_trials=50, n_jobs=2, catch=(ValueError,)
)

optuna.visualization.matplotlib.plot_timeline(study)

Args:
study:
A :class:`~optuna.study.Study` object whose trials are plotted with
their lifetime.

Returns:
A :class:`matplotlib.axes.Axes` object.
"""
_imports.check()
info = _get_timeline_info(study)
return _get_timeline_plot(info)


def _get_timeline_plot(info: _TimelineInfo) -> "Axes":
_cm = {
TrialState.COMPLETE: "tab:blue",
TrialState.FAIL: "tab:red",
TrialState.PRUNED: "tab:orange",
TrialState.RUNNING: "tab:green",
TrialState.WAITING: "tab:gray",
}

# Set up the graph style.
plt.style.use("ggplot") # Use ggplot style sheet for similar outputs to plotly.
fig, ax = plt.subplots()
ax.set_title("Timeline Plot")
ax.set_xlabel("Datetime")
ax.set_ylabel("Trial")

if len(info.bars) == 0:
return ax

ax.barh(
y=[b.number for b in info.bars],
width=[b.complete - b.start for b in info.bars],
left=[b.start for b in info.bars],
color=[_cm[b.state] for b in info.bars],
)

# There are 5 types of TrialState in total.
# However, the legend depicts only types present in the arguments.
legend_handles = []
for state, color in _cm.items():
if len([b for b in info.bars if b.state == state]) > 0:
legend_handles.append(patches.Patch(color=color, label=state.name))
ax.legend(handles=legend_handles, loc="upper left", bbox_to_anchor=(1.05, 1.0))
fig.tight_layout()

assert len(info.bars) > 0
start_time = min([b.start for b in info.bars])
complete_time = max([b.complete for b in info.bars])
margin = (complete_time - start_time) * 0.05

ax.set_xlim(right=complete_time + margin, left=start_time - margin)
ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
ax.xaxis.set_major_formatter(_DateFormatter_Millisecond())
plt.gcf().autofmt_xdate()
return ax
23 changes: 23 additions & 0 deletions tests/visualization_tests/matplotlib_tests/test_timeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from io import BytesIO

import pytest

from optuna.trial import TrialState
from optuna.visualization.matplotlib._timeline import plot_timeline
from tests.visualization_tests.test_timeline import _create_study


@pytest.mark.parametrize(
"trial_states_list",
[
[],
[TrialState.COMPLETE, TrialState.PRUNED, TrialState.FAIL],
[TrialState.FAIL, TrialState.PRUNED, TrialState.COMPLETE],
],
)
def test_get_timeline_plot(trial_states_list: list[TrialState]) -> None:
study = _create_study(trial_states_list)
fig = plot_timeline(study)
fig.get_figure().savefig(BytesIO())