# Fig. 2: FIRM score line plots by lead day

In [35]:
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
import xarray as xr
from plotly.subplots import make_subplots
from scores.categorical import firm
from scores.stats.statistical_tests import diebold_mariano

In [36]:
# Load the forecast and observation DataArrays for the three seasons
# 2020-21, 2021-22 and 2022-23.
# "__xarray_dataarray_variable__" is the name xarray gives an unnamed DataArray
# written with `to_netcdf`, so that is the variable we pull out of each dataset.
# Each multi-file dataset is used as a context manager: `.compute()` loads the
# data into memory, after which the underlying file handles can be released.
SEASONS = ["2020_2021", "2021_2022", "2022_2023"]

with xr.open_mfdataset([f"data/fcst_{season}.nc" for season in SEASONS]) as fcst_ds:
    fcst = fcst_ds["__xarray_dataarray_variable__"].compute()

with xr.open_mfdataset([f"data/obs_{season}.nc" for season in SEASONS]) as obs_ds:
    obs = obs_ds["__xarray_dataarray_variable__"].compute()

In [37]:
# FIRM score of the forecast, keeping only the lead_day dimension.
# Two warning thresholds (1 and 3) carry threshold weights 2 and 1 respectively,
# matched up by position; the risk parameter is 0.5.
# These three settings are reused by the reference and confidence-interval cells.
risk_parameter = 0.5
categorical_thresholds = [1, 3]
threshold_weights = [2, 1]

firm_score = firm(
    fcst, obs, risk_parameter, categorical_thresholds, threshold_weights,
    preserve_dims="lead_day",
    threshold_assignment="upper",
)
firm_score

In [38]:
# Benchmark: a forecast that never issues a warning.
# Multiplying the forecast by zero puts every value below the lowest threshold (1).
# It also keeps the forecast's coordinates, and any NaNs stay NaN, so the reference
# is presumably scored on the same cases as the real forecast.
firm_ref = firm(
    fcst * 0, obs, risk_parameter, categorical_thresholds, threshold_weights,
    preserve_dims="lead_day",
    threshold_assignment="upper",
)
firm_ref

In [39]:
# Confidence intervals for the forecast's improvement over the never-warn reference.
# Keeping valid_utc_date as well as lead_day gives, for each lead day, a daily
# time series of FIRM scores. The Diebold-Mariano test then runs on each series.
firm_score_date_preserved = firm(
    fcst,
    obs,
    risk_parameter,
    categorical_thresholds,
    threshold_weights,
    threshold_assignment="upper",
    preserve_dims=["lead_day", "valid_utc_date"],
)
ref_score_date_preserved = firm(
    fcst * 0,
    obs,
    risk_parameter,
    categorical_thresholds,
    threshold_weights,
    threshold_assignment="upper",
    preserve_dims=["lead_day", "valid_utc_date"],
)

# Reference minus forecast: FIRM is a penalty (lower is better), so positive
# values mean the forecast outperformed the never-warn reference.
diff = ref_score_date_preserved.firm_score - firm_score_date_preserved.firm_score

# Forecast horizon h used by the Diebold-Mariano test: 1, 2, ... for successive
# lead days. It is derived from the number of lead days rather than a hard-coded
# list, so the cell keeps working if the set of lead days changes.
# NOTE(review): assumes lead_day is sorted ascending and starts at day 1 — confirm.
diff = diff.assign_coords(h=("lead_day", np.arange(1, diff.sizes["lead_day"] + 1)))

# Each lead_day entry is treated as a separate time series, with its horizon read from "h".
dm_result = diebold_mariano(diff, "lead_day", "h", confidence_level=0.95)

In [40]:
# Two-panel figure:
#   (a) the forecast's mean FIRM score and its over/underforecast penalty
#       components by lead day, plus the never-warn reference (dashed);
#   (b) mean difference in FIRM score (reference minus forecast) with
#       Diebold-Mariano confidence intervals; values above 0 favour the forecast.
fig = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("<b>(a)</b>", "<b>(b)</b>"),
    vertical_spacing=0.1,
)
# Move the panel labels "(a)"/"(b)" towards the left edge of each subplot.
fig.update_annotations(font_size=12, xshift=-160, xanchor="left")

# Panel (a): the forecast's score and its two penalty components, one line each.
for component, label, colour in [
    ("firm_score", "Mean FIRM score", "#E69F00"),
    ("overforecast_penalty", "Overforecast penalty", "#CC79A7"),
    ("underforecast_penalty", "Underforecast penalty", "#56B4E9"),
]:
    fig.add_trace(
        go.Scatter(
            x=firm_score.lead_day,
            y=firm_score[component],
            name=label,
            line=dict(color=colour),
        ),
        row=1,
        col=1,
    )

fig.add_trace(
    go.Scatter(
        x=firm_ref.lead_day,
        y=firm_ref.firm_score,
        name="No warning reference",
        mode="lines",
        line=dict(color="black", dash="dash"),
    ),
    row=1,
    col=1,
)

# Panel (b): the x values come from dm_result itself so they are guaranteed to
# line up with the plotted means. Error bars are asymmetric and span
# [ci_lower, ci_upper].
fig.add_trace(
    go.Scatter(
        x=dm_result["lead_day"],
        y=dm_result["mean"],
        line=dict(color="black"),
        error_y=dict(
            thickness=1,
            type="data",
            symmetric=False,
            array=dm_result["ci_upper"] - dm_result["mean"],
            arrayminus=dm_result["mean"] - dm_result["ci_lower"],
        ),
        showlegend=False,
    ),
    row=2,
    col=1,
)

# Zero line: an interval lying entirely above it means the forecast beats the
# never-warn reference at the 95% level used when computing dm_result.
fig.add_hline(y=0, row=2, col=1)
fig.update_layout(
    width=400,
    height=600,
    margin=dict(l=0, r=10, b=50, t=20),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
# Panel (a) uses a fixed y-range; widen it if any score exceeds 0.05.
fig.update_yaxes(title_text="Mean FIRM score", range=[0, 0.05], row=1, col=1)
fig.update_yaxes(title_text="Difference in Mean FIRM score", row=2, col=1)
fig.update_xaxes(title_text="Lead day", tickmode="linear", tick0=0, dtick=1)

In [41]:
# Export Fig. 2 for the paper. Plotly's static export needs an image engine (kaleido).
# The output directory is created first, so the export also works on a fresh
# checkout where results/figures/ does not exist yet.
fig2_path = Path("results/figures/fig2.pdf")
fig2_path.parent.mkdir(parents=True, exist_ok=True)
fig.write_image(fig2_path)