In [None]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.stats import norm

In [None]:
# Illustration of the threshold-weighted CRPS (twCRPS) for a Gaussian forecast
# N(mu, sigma) and a single observation `obs`, with weighting threshold t.
# Panels: (a) forecast vs. observed CDF with the region z >= t shaded,
# (b) weighting function w(z) = 1{z >= t}, (c) chaining function v(z) = max(z, t).
mu, sigma = 0, 1
obs = 0.0
# Single source of truth for t. It used to be repeated as a literal 0.5 in the
# mask, both function panels and every tick definition, so editing one copy
# silently desynchronised the panels.
threshold = 0.5
y = np.linspace(-3, 3, 500)

f_y = norm.cdf(y, loc=mu, scale=sigma)  # forecast CDF F(z)
h_y = np.where(y >= obs, 1, 0)  # observed CDF: step function 1{obs <= z}

mask = y >= threshold  # grid points where the weighting function equals 1
w_y = np.where(mask, 1, 0)  # weighting function w(z) = 1{z >= t}
v_y = np.maximum(y, threshold)  # chaining function v(z) = max(z, t)

fig = make_subplots(
    rows=3,
    cols=1,
    subplot_titles=(
        "a) Graphical illustration of twCRPS",
        "b) Weighting function",
        "c) Chaining function",
    ),
)

# Top panel: observed (step) CDF and forecast CDF.
fig.add_trace(
    go.Scatter(
        x=y,
        y=h_y,
        mode="lines",
        name="Observed CDF",
        line=dict(color="#E69F00", dash="dash"),
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=y,
        y=f_y,
        mode="lines",
        name="Forecast CDF",
        line=dict(color="#56B4E9"),
    ),
    row=1,
    col=1,
)

# Shade between F(z) and the observed step over z >= t. The closed polygon
# walks forwards along F and back along the step, then is filled with "toself".
x_shade = np.concatenate([y[mask], y[mask][::-1]])
y_shade = np.concatenate([f_y[mask], h_y[mask][::-1]])

fig.add_trace(
    go.Scatter(
        x=x_shade,
        y=y_shade,
        fill="toself",
        fillcolor="rgba(0, 158, 115, 0.5)",
        line=dict(color="rgba(255,255,255,0)"),  # invisible outline: fill only
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Middle panel: weighting function.
fig.add_trace(
    go.Scatter(
        x=y,
        y=w_y,
        mode="lines",
        name="Weighting Function",
        line=dict(color="black"),
        showlegend=False,
    ),
    row=2,
    col=1,
)
# Bottom panel: chaining function.
fig.add_trace(
    go.Scatter(
        x=y,
        y=v_y,
        mode="lines",
        name="Chaining Function",
        line=dict(color="black"),
        showlegend=False,
    ),
    row=3,
    col=1,
)

fig.update_layout(
    legend=dict(x=0, y=1, xanchor="left", yanchor="top"),
    template="plotly_white",
    height=600,
    width=500,
    margin=dict(l=20, r=20, t=20, b=20),
)

# Every panel gets the same x-axis treatment: one tick labelled "t" at the
# threshold and a shared title. Only the y-axis title differs per row.
y_titles = ("Probability of<br>non-exceedance", "w(z)", "v(z)")
for row, y_title in enumerate(y_titles, start=1):
    fig.update_xaxes(
        tickvals=[threshold],
        ticktext=["t"],
        title_text="Threshold (z)",
        row=row,
        col=1,
    )
    fig.update_yaxes(title_text=y_title, row=row, col=1)
fig.show()

In [None]:
# Export the figure for the paper. The image format is derived from the file
# suffix so the two cannot disagree. Previously format="pdf" was passed with a
# ".svg" filename, which wrote PDF bytes into a file named as SVG.
fig_path = Path("../paper_figs/twCRPS_illustration.svg")
fig_path.parent.mkdir(parents=True, exist_ok=True)  # don't fail on a fresh checkout
fig.write_image(fig_path, format=fig_path.suffix.lstrip("."))