Skip to content

Commit

Permalink
Merge pull request #4470 from eukaryo/plotly-timeline-plot
Browse files Browse the repository at this point in the history
Add timeline plot with plotly as backend
  • Loading branch information
HideakiImamura committed Mar 23, 2023
2 parents ec5687c + 60eb620 commit b410c00
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 138 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/visualization/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The :mod:`~optuna.visualization` module provides utility functions for plotting
optuna.visualization.plot_param_importances
optuna.visualization.plot_pareto_front
optuna.visualization.plot_slice
optuna.visualization.plot_timeline
optuna.visualization.is_available

.. note::
Expand Down
2 changes: 2 additions & 0 deletions optuna/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from optuna.visualization._param_importances import plot_param_importances
from optuna.visualization._pareto_front import plot_pareto_front
from optuna.visualization._slice import plot_slice
from optuna.visualization._timeline import plot_timeline
from optuna.visualization._utils import is_available


Expand All @@ -21,4 +22,5 @@
"plot_param_importances",
"plot_pareto_front",
"plot_slice",
"plot_timeline",
]
26 changes: 1 addition & 25 deletions optuna/visualization/_pareto_front.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import collections
import json
from typing import Any
from typing import Callable
from typing import Dict
Expand All @@ -18,6 +17,7 @@
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


if _imports.is_successful():
Expand Down Expand Up @@ -365,15 +365,6 @@ def _get_non_pareto_front_trials(
return non_pareto_trials


def _make_json_compatible(value: Any) -> Any:
try:
json.dumps(value)
return value
except TypeError:
# The value can't be converted to JSON directly, so return a string representation.
return str(value)


def _make_scatter_object(
n_targets: int,
axis_order: Sequence[int],
Expand Down Expand Up @@ -416,21 +407,6 @@ def _make_scatter_object(
assert False, "Must not reach here"


def _make_hovertext(trial: FrozenTrial) -> str:
user_attrs = {key: _make_json_compatible(value) for key, value in trial.user_attrs.items()}
user_attrs_dict = {"user_attrs": user_attrs} if user_attrs else {}
text = json.dumps(
{
"number": trial.number,
"values": trial.values,
"params": trial.params,
**user_attrs_dict,
},
indent=2,
)
return text.replace("\n", "<br>")


def _make_marker(
trials: Sequence[FrozenTrial],
include_dominated_trials: bool,
Expand Down
143 changes: 143 additions & 0 deletions optuna/visualization/_timeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from __future__ import annotations

import datetime
from typing import NamedTuple

from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.study import Study
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
from optuna.visualization._utils import _make_hovertext


if _imports.is_successful():
from optuna.visualization._plotly_imports import go

_logger = get_logger(__name__)


class _TimelineBarInfo(NamedTuple):
number: int
start: datetime.datetime
complete: datetime.datetime
state: TrialState
hovertext: str


class _TimelineInfo(NamedTuple):
bars: list[_TimelineBarInfo]


@experimental_func("3.2.0")
def plot_timeline(study: Study) -> "go.Figure":
"""Plot the timeline of a study.
Example:
The following code snippet shows how to plot the timeline of a study.
Timeline plot can visualize trials with overlapping execution time
(e.g., in distributed environments).
.. plotly::
import time
import optuna
def objective(trial):
x = trial.suggest_float("x", 0, 1)
time.sleep(x * 0.1)
if x > 0.8:
raise ValueError()
if x > 0.4:
raise optuna.TrialPruned()
return x ** 2
study = optuna.create_study(direction="minimize")
study.optimize(
objective, n_trials=50, n_jobs=2, catch=(ValueError,)
)
fig = optuna.visualization.plot_timeline(study)
fig.show()
Args:
study:
A :class:`~optuna.study.Study` object whose trials are plotted with
their lifetime.
Returns:
A :class:`plotly.graph_objs.Figure` object.
"""
_imports.check()
info = _get_timeline_info(study)
return _get_timeline_plot(info)


def _get_timeline_info(study: Study) -> _TimelineInfo:
bars = []
for t in study.get_trials():
date_complete = t.datetime_complete or datetime.datetime.now()
date_start = t.datetime_start or date_complete
if date_complete < date_start:
_logger.warning(
(
f"The start and end times for Trial {t.number} seem to be reversed. "
f"The start time is {date_start} and the end time is {date_complete}."
)
)
bars.append(
_TimelineBarInfo(
number=t.number,
start=date_start,
complete=date_complete,
state=t.state,
hovertext=_make_hovertext(t),
)
)

if len(bars) == 0:
_logger.warning("Your study does not have any trials.")

return _TimelineInfo(bars)


def _get_timeline_plot(info: _TimelineInfo) -> "go.Figure":
_cm = {
"COMPLETE": "blue",
"FAIL": "red",
"PRUNED": "orange",
"RUNNING": "green",
"WAITING": "gray",
}

fig = go.Figure()
for s in sorted(TrialState, key=lambda x: x.name):
bars = [b for b in info.bars if b.state == s]
if len(bars) == 0:
continue
fig.add_trace(
go.Bar(
name=s.name,
x=[(b.complete - b.start).total_seconds() * 1000 for b in bars],
y=[b.number for b in bars],
base=[b.start.isoformat() for b in bars],
text=[b.hovertext for b in bars],
hovertemplate="%{text}<extra>" + s.name + "</extra>",
orientation="h",
marker=dict(color=_cm[s.name]),
textposition="none", # Avoid drawing hovertext in a bar.
)
)
fig.update_xaxes(type="date")
fig.update_layout(
go.Layout(
title="Timeline Plot",
xaxis={"title": "Datetime"},
yaxis={"title": "Trial"},
)
)
return fig
25 changes: 25 additions & 0 deletions optuna/visualization/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any
from typing import Callable
from typing import cast
Expand Down Expand Up @@ -179,3 +180,27 @@ def _target(t: FrozenTrial) -> float:

def _is_reverse_scale(study: Study, target: Optional[Callable[[FrozenTrial], float]]) -> bool:
return target is not None or study.direction == StudyDirection.MINIMIZE


def _make_json_compatible(value: Any) -> Any:
try:
json.dumps(value)
return value
except TypeError:
# The value can't be converted to JSON directly, so return a string representation.
return str(value)


def _make_hovertext(trial: FrozenTrial) -> str:
user_attrs = {key: _make_json_compatible(value) for key, value in trial.user_attrs.items()}
user_attrs_dict = {"user_attrs": user_attrs} if user_attrs else {}
text = json.dumps(
{
"number": trial.number,
"values": trial.values,
"params": trial.params,
**user_attrs_dict,
},
indent=2,
)
return text.replace("\n", "<br>")
113 changes: 0 additions & 113 deletions tests/visualization_tests/test_pareto_front.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import datetime
from io import BytesIO
from textwrap import dedent
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Sequence
import warnings

import numpy as np
import pytest

import optuna
Expand All @@ -17,11 +14,9 @@
from optuna.distributions import FloatDistribution
from optuna.study.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization import plot_pareto_front
import optuna.visualization._pareto_front
from optuna.visualization._pareto_front import _get_pareto_front_info
from optuna.visualization._pareto_front import _make_hovertext
from optuna.visualization._pareto_front import _ParetoFrontInfo
from optuna.visualization._plotly_imports import go
from optuna.visualization._utils import COLOR_SCALE
Expand Down Expand Up @@ -319,114 +314,6 @@ def test_get_pareto_front_plot(
plt.savefig(BytesIO())


def test_make_hovertext() -> None:
trial_no_user_attrs = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_no_user_attrs)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
}
}
"""
)
.strip()
.replace("\n", "<br>")
)

trial_user_attrs_valid_json = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={"a": 42, "b": 3.14},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_user_attrs_valid_json)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
},
"user_attrs": {
"a": 42,
"b": 3.14
}
}
"""
)
.strip()
.replace("\n", "<br>")
)

trial_user_attrs_invalid_json = FrozenTrial(
number=0,
trial_id=0,
state=TrialState.COMPLETE,
value=0.2,
datetime_start=datetime.datetime.now(),
datetime_complete=datetime.datetime.now(),
params={"x": 10},
distributions={"x": FloatDistribution(5, 12)},
user_attrs={"a": 42, "b": 3.14, "c": np.zeros(1), "d": np.nan},
system_attrs={},
intermediate_values={},
)
assert (
_make_hovertext(trial_user_attrs_invalid_json)
== dedent(
"""
{
"number": 0,
"values": [
0.2
],
"params": {
"x": 10
},
"user_attrs": {
"a": 42,
"b": 3.14,
"c": "[0.]",
"d": NaN
}
}
"""
)
.strip()
.replace("\n", "<br>")
)


@pytest.mark.parametrize("direction", ["minimize", "maximize"])
def test_color_map(direction: str) -> None:
study = create_study(directions=[direction, direction])
Expand Down
Loading

0 comments on commit b410c00

Please sign in to comment.