Skip to content

Commit

Permalink
ENH add plotly backend to plot_murphy_diagram (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
lorentzenchr committed Feb 27, 2024
1 parent 39bfe14 commit c03d022
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 81 deletions.
58 changes: 58 additions & 0 deletions src/model_diagnostics/_utils/plot_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys

import matplotlib as mpl


def get_plotly_color(i):
try:
sys.modules["plotly"]
# Sometimes, those turn out to be the same as matplotlib default.
# colors = plotly.colors.DEFAULT_PLOTLY_COLORS
# Those are the plotly color default color palette in hex.
import plotly.express as px

colors = px.colors.qualitative.Plotly
return colors[i % len(colors)]
except KeyError:
return False


def get_xlabel(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_xlabel()
else:
# ax = plotly figure
return ax.layout.xaxis.title.text


def get_ylabel(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_ylabel()
else:
# ax = plotly figure
return ax.layout.yaxis.title.text


def get_title(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_title()
else:
# ax = plotly figure
return ax.layout.title.text


def get_legend_list(ax):
if isinstance(ax, mpl.axes.Axes):
return [t.get_text() for t in ax.get_legend().get_texts()]
else:
# ax = plotly figure
return [d.name for d in ax.data if d.showlegend is None or d.showlegend]


def is_plotly_figure(x):
"""Return True if the x is a plotly figure."""
try:
plotly = sys.modules["plotly"]
except KeyError:
return False
return isinstance(x, plotly.graph_objects.Figure)
41 changes: 9 additions & 32 deletions src/model_diagnostics/calibration/plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import sys
import warnings
from functools import partial
from typing import Optional
Expand All @@ -22,33 +21,11 @@
length_of_second_dimension,
)
from model_diagnostics._utils.isotonic import IsotonicRegression
from model_diagnostics._utils.plot_helper import get_plotly_color, is_plotly_figure

from .identification import compute_bias


def _is_plotly_figure(x):
"""Return True if the x is a plotly figure."""
try:
plotly = sys.modules["plotly"]
except KeyError:
return False
return isinstance(x, plotly.graph_objects.Figure)


def _get_plotly_color(i):
try:
sys.modules["plotly"]
# Sometimes, those turn out to be the same as matplotlib default.
# colors = plotly.colors.DEFAULT_PLOTLY_COLORS
# Those are the plotly color default color palette in hex.
import plotly.express as px

colors = px.colors.qualitative.Plotly
return colors[i % len(colors)]
except KeyError:
return False


def plot_reliability_diagram(
y_obs: npt.ArrayLike,
y_pred: npt.ArrayLike,
Expand Down Expand Up @@ -151,7 +128,7 @@ def plot_reliability_diagram(
fig = ax = go.Figure()
elif isinstance(ax, mpl.axes.Axes):
plot_backend = "matplotlib"
elif _is_plotly_figure(ax):
elif is_plotly_figure(ax):
import plotly.graph_objects as go

plot_backend = "plotly"
Expand Down Expand Up @@ -267,7 +244,7 @@ def iso_statistic(y_obs, y_pred, weights=None, x_values=None):
ax.fill_between(iso.X_thresholds_, lower, upper, alpha=0.1)
else:
# plotly has not equivalent of fill_between and needs a bit more coding
color = _get_plotly_color(i)
color = get_plotly_color(i)
fig.add_scatter(
x=np.r_[iso.X_thresholds_, iso.X_thresholds_[::-1]],
y=np.r_[lower, upper[::-1]],
Expand Down Expand Up @@ -295,7 +272,7 @@ def iso_statistic(y_obs, y_pred, weights=None, x_values=None):
x=iso.X_thresholds_,
y=y_plot,
mode="lines",
line={"color": _get_plotly_color(i)},
line={"color": get_plotly_color(i)},
name=label,
)

Expand Down Expand Up @@ -441,7 +418,7 @@ def plot_bias(
fig = ax = go.Figure()
elif isinstance(ax, mpl.axes.Axes):
plot_backend = "matplotlib"
elif _is_plotly_figure(ax):
elif is_plotly_figure(ax):
import plotly.graph_objects as go

plot_backend = "plotly"
Expand Down Expand Up @@ -566,7 +543,7 @@ def plot_bias(
"width": 4,
"visible": True,
},
marker={"color": _get_plotly_color(i)},
marker={"color": get_plotly_color(i)},
mode="markers",
name=label,
)
Expand All @@ -585,7 +562,7 @@ def plot_bias(
# plotly has not equivalent of fill_between and needs a bit more
# coding
# FIXME: polars >= 0.20.0 use df_i[::-1, feature_name]
color = _get_plotly_color(i)
color = get_plotly_color(i)
fig.add_scatter(
x=pl.concat([df_i[feature_name], df_i[feature_name][::-1]]),
y=pl.concat([lower, upper[::-1]]),
Expand All @@ -611,7 +588,7 @@ def plot_bias(
y=df_i["bias_mean"],
marker_symbol="circle",
mode="lines+markers",
line={"color": _get_plotly_color(i)},
line={"color": get_plotly_color(i)},
name=label,
)

Expand Down Expand Up @@ -658,7 +635,7 @@ def plot_bias(
"width": 4,
"visible": True,
},
marker={"color": _get_plotly_color(i), "symbol": "diamond"},
marker={"color": get_plotly_color(i), "symbol": "diamond"},
mode="markers",
showlegend=False,
)
Expand Down
34 changes: 6 additions & 28 deletions src/model_diagnostics/calibration/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from sklearn.model_selection import train_test_split

from model_diagnostics import polars_version
from model_diagnostics._utils.plot_helper import (
get_legend_list,
get_title,
get_xlabel,
get_ylabel,
)
from model_diagnostics._utils.test_helper import (
SkipContainer,
pa_array,
Expand All @@ -19,34 +25,6 @@
from model_diagnostics.calibration import plot_bias, plot_reliability_diagram


def get_xlabel(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_xlabel()
else:
return ax.layout.xaxis.title.text


def get_ylabel(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_ylabel()
else:
return ax.layout.yaxis.title.text


def get_title(ax):
if isinstance(ax, mpl.axes.Axes):
return ax.get_title()
else:
return ax.layout.title.text


def get_legend_list(ax):
if isinstance(ax, mpl.axes.Axes):
return [t.get_text() for t in ax.get_legend().get_texts()]
else:
return [d.name for d in ax.data if d.showlegend is None or d.showlegend]


@pytest.mark.parametrize(
("param", "value", "msg"),
[
Expand Down
61 changes: 51 additions & 10 deletions src/model_diagnostics/scoring/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_sorted_array_names,
length_of_second_dimension,
)
from model_diagnostics._utils.plot_helper import get_plotly_color, is_plotly_figure

from .scoring import ElementaryScore

Expand All @@ -25,6 +26,7 @@ def plot_murphy_diagram(
functional: str = "mean",
level: float = 0.5,
ax: Optional[mpl.axes.Axes] = None,
plot_backend: str = "matplotlib",
):
r"""Plot a Murphy diagram.
Expand Down Expand Up @@ -61,6 +63,11 @@ def plot_murphy_diagram(
`level=0.5` and `functional="quantile"` gives the median.
ax : matplotlib.axes.Axes
Axes object to draw the plot onto, otherwise uses the current Axes.
plot_backend: str
The plotting backend to use when `ax = None`. Options are:
- "matplotlib"
- "plotly"
Returns
-------
Expand All @@ -79,8 +86,30 @@ def plot_murphy_diagram(
Representations, and Forecast Rankings".
[arxiv:1503.08195](https://arxiv.org/abs/1503.08195).
"""
if plot_backend not in ("matplotlib", "plotly"):
msg = f"The plot_backend must be matplotlib or plotly, got {plot_backend}."
raise ValueError(msg)

if ax is None:
ax = plt.gca()
if plot_backend == "matplotlib":
ax = plt.gca()
else:
import plotly.graph_objects as go

fig = ax = go.Figure()
elif isinstance(ax, mpl.axes.Axes):
plot_backend = "matplotlib"
elif is_plotly_figure(ax):
import plotly.graph_objects as go

plot_backend = "plotly"
fig = ax
else:
msg = (
"The ax argument must be None, a matplotlib Axes or a plotly Figure, "
f"got {type(ax)}."
)
raise ValueError(msg)

if (n_cols := length_of_second_dimension(y_obs)) > 0:
if n_cols == 1:
Expand Down Expand Up @@ -120,19 +149,31 @@ def elementary_score(y_obs, y_pred, weights, eta):
for eta in etas
]
label = pred_names[i] if n_pred >= 2 else None
ax.plot(etas, y_plot, label=label)
if plot_backend == "matplotlib":
ax.plot(etas, y_plot, label=label)
else:
fig.add_scatter(
x=etas,
y=y_plot,
mode="lines",
line={"color": get_plotly_color(i)},
name=label,
)

xlabel = "eta"
ylabel = "score"
title = "Murphy Diagram"
ax.set(xlabel="eta", ylabel="score")
if n_pred <= 1 and len(pred_names[0]) > 0:
title = title + " " + pred_names[0]

if n_pred >= 2:
if plot_backend == "matplotlib":
if n_pred >= 2:
ax.legend()
ax.set_title(title)
ax.legend()
ax.set(xlabel=xlabel, ylabel=ylabel)
else:
y_pred_i = y_pred if n_pred == 0 else get_second_dimension(y_pred, i)
if len(pred_names[0]) > 0:
ax.set_title(title + " " + pred_names[0])
else:
ax.set_title(title)
if n_pred <= 1:
fig.update_layout(showlegend=False)
fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, title=title)

return ax

0 comments on commit c03d022

Please sign in to comment.