Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeline plot with plotly as backend #4470

Merged
merged 32 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
6a8ce6e
Create _timeline.py
eukaryo Feb 27, 2023
0c3497e
Create test_timeline.py
eukaryo Feb 27, 2023
3348017
Update _timeline.py
eukaryo Feb 27, 2023
728c952
Update _timeline.py
eukaryo Feb 27, 2023
e6e1704
Update test_timeline.py
eukaryo Feb 27, 2023
9bf2216
Update test_timeline.py
eukaryo Feb 27, 2023
0382ee7
Update _timeline.py
eukaryo Feb 27, 2023
a74c5b0
Merge branch 'optuna:master' into plotly-timeline-plot
eukaryo Mar 8, 2023
d4e0720
Update test_timeline.py
eukaryo Mar 8, 2023
beecdca
Update _timeline.py
eukaryo Mar 8, 2023
e29e78d
Update _pareto_front.py
eukaryo Mar 8, 2023
8d590c7
Update _utils.py
eukaryo Mar 8, 2023
6526688
Update test_pareto_front.py
eukaryo Mar 8, 2023
e1a07ab
Update test_utils.py
eukaryo Mar 8, 2023
d7e068b
Update index.rst
eukaryo Mar 15, 2023
57258e5
Update _pareto_front.py
eukaryo Mar 15, 2023
9be9b6f
Update _timeline.py
eukaryo Mar 15, 2023
b707aca
Update test_timeline.py
eukaryo Mar 15, 2023
01d2d5f
Update test_timeline.py
eukaryo Mar 15, 2023
a857cfd
Merge branch 'optuna:master' into plotly-timeline-plot
eukaryo Mar 15, 2023
21b370f
Update __init__.py
eukaryo Mar 15, 2023
71067e6
Update _timeline.py
eukaryo Mar 15, 2023
a602c02
Update test_timeline.py
eukaryo Mar 15, 2023
d40a9f7
Update _timeline.py
eukaryo Mar 15, 2023
4c9a9eb
Update _timeline.py
eukaryo Mar 16, 2023
b165a51
Update _timeline.py
eukaryo Mar 16, 2023
7b6ab50
Update test_timeline.py
eukaryo Mar 16, 2023
0bf8769
Merge branch 'optuna:master' into plotly-timeline-plot
eukaryo Mar 17, 2023
7754454
Update _timeline.py
eukaryo Mar 17, 2023
35df5f6
Update test_timeline.py
eukaryo Mar 17, 2023
19750bd
Merge branch 'optuna:master' into plotly-timeline-plot
eukaryo Mar 17, 2023
60eb620
Update tests/visualization_tests/test_timeline.py
eukaryo Mar 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions optuna/visualization/_pareto_front.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections
import json
from typing import Any
from typing import Callable
from typing import Dict
Expand All @@ -18,6 +17,7 @@
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


if _imports.is_successful():
Expand Down Expand Up @@ -365,15 +365,6 @@ def _get_non_pareto_front_trials(
return non_pareto_trials


def _make_json_compatible(value: Any) -> Any:
try:
json.dumps(value)
return value
except TypeError:
# The value can't be converted to JSON directly, so return a string representation.
return str(value)


def _make_scatter_object(
n_targets: int,
axis_order: Sequence[int],
Expand Down Expand Up @@ -413,22 +404,7 @@ def _make_scatter_object(
showlegend=False,
)
else:
assert False, "Must not reach here"


def _make_hovertext(trial: FrozenTrial) -> str:
user_attrs = {key: _make_json_compatible(value) for key, value in trial.user_attrs.items()}
user_attrs_dict = {"user_attrs": user_attrs} if user_attrs else {}
text = json.dumps(
{
"number": trial.number,
"values": trial.values,
"params": trial.params,
**user_attrs_dict,
},
indent=2,
)
return text.replace("\n", "<br>")
raise AssertionError("Must not reach here")
toshihikoyanase marked this conversation as resolved.
Show resolved Hide resolved


def _make_marker(
Expand Down
131 changes: 131 additions & 0 deletions optuna/visualization/_timeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import datetime
from typing import List
from typing import NamedTuple

from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


if _imports.is_successful():
from optuna.visualization._plotly_imports import go

_logger = get_logger(__name__)


class _TimelineBarInfo(NamedTuple):
number: int
start: datetime.datetime
end: datetime.datetime
eukaryo marked this conversation as resolved.
Show resolved Hide resolved
state: TrialState
hovertext: str


class _TimelineInfo(NamedTuple):
bars: List[_TimelineBarInfo]


@experimental_func("3.2.0")
def plot_timeline(study: Study) -> "go.Figure":
"""Plot the timeline of a study.
toshihikoyanase marked this conversation as resolved.
Show resolved Hide resolved

Example:

The following code snippet shows how to plot the timeline of a study.

.. plotly::

import optuna


def objective(trial):
x = trial.suggest_float("x", 0, 1)
time.sleep(x * 0.1)
if x > 0.8:
raise ValueError()
if x > 0.4:
raise optuna.TrialPruned()
return x ** 2


study = optuna.create_study(direction="minimize")
study.optimize(
objective, n_trials=50, n_jobs=2, catch=(ValueError,)
)
eukaryo marked this conversation as resolved.
Show resolved Hide resolved

fig = optuna.visualization.plot_timeline(study)
fig.show()

Args:
study:
A :class:`~optuna.study.Study` object whose trials are plotted with
their lifetime.

Returns:
A :class:`plotly.graph_objs.Figure` object.
"""
_imports.check()
info = _get_timeline_info(study)
return _get_timeline_plot(info)


def _get_timeline_info(study: Study) -> _TimelineInfo:
bars = []
for t in study.get_trials():
date_end = t.datetime_complete or datetime.datetime.now()
date_start = t.datetime_start or date_end
bars.append(
_TimelineBarInfo(
number=t.number,
start=date_start,
end=date_end,
state=t.state,
hovertext=_make_hovertext(t),
)
)

if len(bars) == 0:
_logger.warning("Your study does not have any trials.")

return _TimelineInfo(bars)


def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
_cm = {
"COMPLETE": "blue",
"FAIL": "red",
"PRUNED": "orange",
"RUNNING": "green",
"WAITING": "gray",
}

fig = go.Figure()
for s in sorted(TrialState, key=lambda x: x.name):
bars = [b for b in info.bars if b.state == s]
if len(bars) == 0:
continue
fig.add_trace(
go.Bar(
name=s.name,
x=[(b.end - b.start).total_seconds() * 1000 for b in bars],
y=[b.number for b in bars],
base=[b.start.isoformat() for b in bars],
text=[b.hovertext for b in bars],
hovertemplate="%{text}<extra>" + s.name + "</extra>",
orientation="h",
marker=dict(color=_cm[s.name]),
textposition="none", # avoid drawing hovertext in a bar
toshihikoyanase marked this conversation as resolved.
Show resolved Hide resolved
)
)
fig.update_xaxes(type="date")
fig.update_layout(
go.Layout(
title="Timeline Plot",
xaxis={"title": "Datetime"},
yaxis={"title": "Trial"},
)
)
return fig
25 changes: 25 additions & 0 deletions optuna/visualization/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any
from typing import Callable
from typing import cast
Expand Down Expand Up @@ -179,3 +180,27 @@ def _target(t: FrozenTrial) -> float:

def _is_reverse_scale(study: Study, target: Optional[Callable[[FrozenTrial], float]]) -> bool:
return target is not None or study.direction == StudyDirection.MINIMIZE


def _make_json_compatible(value: Any) -> Any:
try:
json.dumps(value)
return value
except TypeError:
# The value can't be converted to JSON directly, so return a string representation.
return str(value)


def _make_hovertext(trial: FrozenTrial) -> str:
toshihikoyanase marked this conversation as resolved.
Show resolved Hide resolved
user_attrs = {key: _make_json_compatible(value) for key, value in trial.user_attrs.items()}
user_attrs_dict = {"user_attrs": user_attrs} if user_attrs else {}
text = json.dumps(
{
"number": trial.number,
"values": trial.values,
"params": trial.params,
**user_attrs_dict,
},
indent=2,
)
return text.replace("\n", "<br>")
113 changes: 0 additions & 113 deletions tests/visualization_tests/test_pareto_front.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import datetime
from io import BytesIO
from textwrap import dedent
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Sequence
import warnings

import numpy as np
import pytest

import optuna
Expand All @@ -17,11 +14,9 @@
from optuna.distributions import FloatDistribution
from optuna.study.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization import plot_pareto_front
import optuna.visualization._pareto_front
from optuna.visualization._pareto_front import _get_pareto_front_info
from optuna.visualization._pareto_front import _make_hovertext
from optuna.visualization._pareto_front import _ParetoFrontInfo
from optuna.visualization._plotly_imports import go
from optuna.visualization._utils import COLOR_SCALE
Expand Down Expand Up @@ -319,114 +314,6 @@ def test_get_pareto_front_plot(
plt.savefig(BytesIO())


def test_make_hovertext() -> None:
trial_no_user_attrs = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_no_user_attrs)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
}
}
"""
)
.strip()
.replace("\n", "<br>")
)

trial_user_attrs_valid_json = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={"a": 42, "b": 3.14},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_user_attrs_valid_json)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
},
"user_attrs": {
"a": 42,
"b": 3.14
}
}
"""
)
.strip()
.replace("\n", "<br>")
)

trial_user_attrs_invalid_json = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={"a": 42, "b": 3.14, "c": np.zeros(1), "d": np.nan},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_user_attrs_invalid_json)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
},
"user_attrs": {
"a": 42,
"b": 3.14,
"c": "[0.]",
"d": NaN
}
}
"""
)
.strip()
.replace("\n", "<br>")
)


@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_color_map(direction: str) -> None:
study = create_study(directions=[direction, direction])
Expand Down
Loading