# Fig 5 - post processing

In [135]:
from scores.continuous import isotonic_fit
import plotly.graph_objects as go
from scores.categorical import firm
import xarray as xr
import scipy
import numpy as np
from plotly.subplots import make_subplots
from scores.stats import statistical_tests

In [136]:
def _load_dataarray(paths):
    """Load the DataArray stored across ``paths`` fully into memory.

    The files are opened as one multi-file dataset and the variable named
    ``__xarray_dataarray_variable__`` is computed eagerly. Opening inside a
    ``with`` block closes the underlying file handles as soon as the data is
    in memory, instead of leaving them open for the rest of the session.

    Args:
        paths: list of NetCDF file paths to combine.

    Returns:
        The computed (in-memory) ``xarray.DataArray``.
    """
    with xr.open_mfdataset(paths) as dataset:
        return dataset["__xarray_dataarray_variable__"].compute()


# Training data: the three files covering 2020 through 2023.
fcst_train = _load_dataarray(
    ["data/fcst_2020_2021.nc", "data/fcst_2021_2022.nc", "data/fcst_2022_2023.nc"]
)
obs_train = _load_dataarray(
    ["data/obs_2020_2021.nc", "data/obs_2021_2022.nc", "data/obs_2022_2023.nc"]
)

# Test data: the single 2023-2024 file, kept separate from the training data.
fcst_test = _load_dataarray(["data/fcst_2023_2024.nc"])
obs_test = _load_dataarray(["data/obs_2023_2024.nc"])

In [137]:
def select_district_data(district, lead_day):
    """Slice the train/test forecasts and observations for one district.

    Forecasts are selected on both ``district`` and ``lead_day``; observations
    are selected on ``district`` only.

    Args:
        district: label passed to ``.sel`` on the ``district`` coordinate.
        lead_day: label passed to ``.sel`` on the forecasts' ``lead_day``
            coordinate.

    Returns:
        Tuple ``(fcst_train, obs_train, fcst_test, obs_test)`` of selections
        taken from the module-level arrays of the same names.
    """
    fcst_selector = dict(district=district, lead_day=lead_day)
    obs_selector = dict(district=district)
    return (
        fcst_train.sel(**fcst_selector),
        obs_train.sel(**obs_selector),
        fcst_test.sel(**fcst_selector),
        obs_test.sel(**obs_selector),
    )

In [138]:
def create_recalibrated_fcst(fcst_train1, obs_train1, fcst_test1):
    """Recalibrate test forecasts with an isotonic regression of the median.

    Steps:
    - fit an isotonic regression (quantile functional, level 0.5) on the
      training forecasts and observations
    - apply the fitted regression function to the test forecasts
    - wherever that function returns NaN (presumably test values outside the
      range of the training forecasts -- confirm against the ``scores``
      documentation), fall back to a straight line fitted through the
      isotonic curve
    - bound the fallback so it joins the isotonic curve consistently: values
      above the fitted range are never lower than the curve's maximum, and
      values below the fitted range are never higher than its minimum

    Args:
        fcst_train1: training forecasts for one district / lead day.
        obs_train1: training observations matching ``fcst_train1``.
        fcst_test1: test forecasts to recalibrate (an xarray object).

    Returns:
        A copy of ``fcst_test1`` holding the recalibrated values. Positions
        where ``fcst_test1`` itself is NaN remain NaN.

    Raises:
        ValueError: if the fitted curve is undefined on the whole sampling
            grid, so no fallback line can be fitted.
    """
    iso_dict = isotonic_fit(
        fcst_train1, obs_train1, functional="quantile", quantile_level=0.5
    )
    regression_func = iso_dict["regression_func"]

    recalibratedfcst = fcst_test1.copy()
    recalibratedfcst.values = regression_func(fcst_test1)

    # Sample the fitted curve on a fixed grid and keep only the part where it
    # is defined; this is the data the fallback line is fitted to.
    grid = np.arange(0, 20, 0.01)
    iso = regression_func(grid)
    defined = ~np.isnan(iso)
    if not defined.any():
        raise ValueError(
            "isotonic regression is undefined on the whole sampling grid; "
            "cannot fit the linear fallback"
        )
    grid_defined = grid[defined]
    iso_defined = iso[defined]

    slope, intercept, _, _, _ = scipy.stats.linregress(grid_defined, iso_defined)
    linear = fcst_test1 * slope + intercept

    # Clip the line differently on each side of the fitted range. Using the
    # upper bound for every gap would map forecasts *below* the range onto the
    # top of the isotonic curve.
    above_range = linear.clip(min=iso_defined.max())
    below_range = linear.clip(max=iso_defined.min())
    is_below = fcst_test1 < grid_defined.min()
    fallback = above_range.where(~is_below, below_range)

    return recalibratedfcst.fillna(fallback)

In [139]:
# Recalibrate each (lead day, district) series on its own, then stitch the
# pieces back together: districts first, lead days second.
recal_list = []
for lead_day in np.arange(7):
    per_district = [
        # Only the first three selections are needed; the test observations
        # are not used for recalibration.
        create_recalibrated_fcst(
            *select_district_data(district, lead_day=lead_day)[:3]
        )
        for district in fcst_train.district
    ]
    recal_list.append(xr.concat(per_district, dim="district"))
recalibratedfcsts = xr.concat(recal_list, dim="lead_day")

In [140]:
# FIRM settings reused by the scoring cells that follow.
risk_parameter = 0.5
categorical_thresholds = [1, 3]
threshold_weights = [2, 1]

# FIRM score of the recalibrated forecasts, keeping only the lead_day dimension.
mean_firm_recal = firm(
    fcst=recalibratedfcsts,
    obs=obs_test,
    risk_parameter=risk_parameter,
    categorical_thresholds=categorical_thresholds,
    threshold_weights=threshold_weights,
    preserve_dims="lead_day",
)
mean_firm_recal

In [141]:
# FIRM score of the raw (uncalibrated) test forecasts, keeping only lead_day.
mean_firm_uncal = firm(
    fcst=fcst_test,
    obs=obs_test,
    risk_parameter=risk_parameter,
    categorical_thresholds=categorical_thresholds,
    threshold_weights=threshold_weights,
    preserve_dims="lead_day",
)
mean_firm_uncal

In [142]:
# Reference forecast: zero wherever the test forecast has a value. Multiplying
# by zero keeps the original coordinates and leaves NaNs as NaN.
zero_fcst = fcst_test * 0
mean_firm_ref = firm(
    fcst=zero_fcst,
    obs=obs_test,
    risk_parameter=risk_parameter,
    categorical_thresholds=categorical_thresholds,
    threshold_weights=threshold_weights,
    preserve_dims="lead_day",
)
mean_firm_ref

# Statistical tests

In [143]:
# Score the raw and the recalibrated forecasts with identical settings, this
# time keeping the valid date as well as the lead day so that a time series of
# scores is available for each lead day.
per_date_dims = ["lead_day", "valid_utc_date"]
firm_uncal, firm_recal = (
    firm(
        candidate_fcst,
        obs_test,
        risk_parameter,
        categorical_thresholds,
        threshold_weights,
        preserve_dims=per_date_dims,
    )
    for candidate_fcst in (fcst_test, recalibratedfcsts)
)

In [144]:
# Difference of the complete FIRM results (raw minus recalibrated).
# NOTE(review): `diff` is reassigned in the following cell from `firm_score`
# alone, so this assignment looks redundant -- confirm before removing.
diff = firm_uncal - firm_recal

In [145]:
# Per-date score differential (raw minus recalibrated): positive wherever the
# recalibrated forecast has the smaller FIRM score.
score_diff = firm_uncal.firm_score - firm_recal.firm_score

# Relabel lead_day (shifted by 2) and attach a horizon coordinate `h` along
# lead_day, which the Diebold-Mariano routine is told to read.
diff = score_diff.assign_coords(lead_day=score_diff.lead_day + 2).assign_coords(
    h=("lead_day", list(range(1, 8)))
)

dm_result = statistical_tests.diebold_mariano(diff, "lead_day", "h")
dm_result

# Create subplots

In [146]:
# Two stacked panels: (a) mean FIRM score per lead day for each forecast
# source, (b) the mean score difference with its confidence interval.
fig = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("<b>(a)</b>", "<b>(b)</b>"),
    vertical_spacing=0.13,
)
# The panel titles are annotations; restyle and reposition them.
fig.update_annotations(font_size=12, xshift=-170, xanchor="left")

# Panel (a): one line per forecast source, described as (scores, styling).
panel_a_series = (
    (
        mean_firm_uncal,
        dict(name="Raw district forecasts", line=dict(color="#E69F00")),
    ),
    (
        mean_firm_recal,
        dict(
            name="Calibrated district forecasts",
            line=dict(color="#CC79A7", dash="dot"),
        ),
    ),
    (
        mean_firm_ref,
        dict(
            name="No warning reference",
            mode="lines",
            line=dict(color="black", dash="dash"),
        ),
    ),
)
for mean_scores, trace_style in panel_a_series:
    fig.add_trace(
        go.Scatter(x=mean_scores.lead_day, y=mean_scores.firm_score, **trace_style),
        row=1,
        col=1,
    )

# Panel (b): mean difference with asymmetric error bars built from the
# confidence bounds, plus a zero line for reference.
dm_mean = dm_result["mean"]
ci_error_bars = dict(
    thickness=1,
    type="data",
    symmetric=False,
    array=dm_result["ci_upper"] - dm_mean,
    arrayminus=dm_mean - dm_result["ci_lower"],
)
fig.add_trace(
    go.Scatter(
        # x comes from the mean-score result so both panels share lead-day labels.
        x=mean_firm_ref.lead_day,
        y=dm_mean,
        line=dict(color="black"),
        error_y=ci_error_bars,
        showlegend=False,
    ),
    row=2,
    col=1,
)
fig.add_hline(y=0, row=2, col=1)

# Overall size, margins (left, right, bottom, top) and legend placement.
fig.update_layout(
    width=400,
    height=600,
    margin=go.layout.Margin(l=0, r=10, b=50, t=20),
    legend=dict(yanchor="bottom", y=0.57, xanchor="right", x=0.99),
)

# Axis titles: a distinct y title per panel, the same x title on both.
y_axis_titles = ("Mean FIRM score", "Mean FIRM difference")
for panel_row, y_axis_title in enumerate(y_axis_titles, start=1):
    fig.update_yaxes(title_text=y_axis_title, row=panel_row, col=1)
    fig.update_xaxes(title_text="Lead Day", row=panel_row, col=1, dtick=1)
fig.show()

In [147]:
# Export the finished figure once per output format.
for figure_path in (
    "results/figures/fig_5_firm_recal.pdf",
    "results/figures/fig_5_firm_recal.svg",
):
    fig.write_image(figure_path)