In [None]:
import xarray as xr
from scores.stats import statistical_tests
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scores.processing import aggregate

# Directory holding the per-system potential twCRPS results used below.
RESULTS_PATH = "../results/dsc_results/"

# Station weights for every weighted aggregation in this notebook. Missing (NaN)
# weights are replaced by 0 (presumably so those stations are excluded from the
# weighted means — confirm).
weights = xr.open_dataarray("../data/station_weights/weights_099.nc").fillna(0)

In [None]:
# Plot colours, taken from the Okabe–Ito colour-blind-friendly palette.
# HRRR configurations:
HRRR21_27_COLOUR = "#E69F00"  # orange
HRRR7_9_COLOUR = "#009E73"  # bluish green
HRRR1_COLOUR = "#CC79A7"  # reddish purple
# GraphCast-GFS configurations (two shades of blue):
GRAPH3_COLOUR = "#0072B2"  # blue
GRAPH1_COLOUR = "#56B4E9"  # sky blue

# DSC from recalibrated (potential) twCRPS

In [None]:
# Climatological twCRPS, the reference score that DSC is measured against.
twcrps_clim99 = xr.open_dataarray("../results/clim/clim_twcrps_099.nc")


def _weighted_mean(scores, dims):
    """Return the station-weighted mean of ``scores`` over the dimensions in ``dims``."""
    return aggregate(scores, reduce_dims=list(dims), weights=weights)


def _open_potential_twcrps(system):
    """Open the potential twCRPS results file for ``system`` under ``RESULTS_PATH``."""
    return xr.open_dataarray(f"{RESULTS_PATH}{system}_potential_twcrps.nc")


twcrps_clim_mean99 = _weighted_mean(twcrps_clim99, ("station", "time"))

hrrr_27 = _open_potential_twcrps("hrrr21_27")
hrrr_9 = _open_potential_twcrps("hrrr7_9")
hrrr_1 = _open_potential_twcrps("hrrr1")
graph1 = _open_potential_twcrps("graphcast")
graph3 = _open_potential_twcrps("graphcast3")

_systems = (hrrr_27, hrrr_9, hrrr_1, graph1, graph3)

# DSC per lead time: climatological twCRPS minus each system's potential twCRPS,
# both reduced (weighted) over stations and times.
hrrr_27_mean, hrrr_9_mean, hrrr_1_mean, graph1_mean, graph3_mean = [
    twcrps_clim_mean99 - _weighted_mean(scores, ("station", "time"))
    for scores in _systems
]

# Potential twCRPS reduced over stations only, so the time dimension is kept
# for the per-time Diebold-Mariano comparisons.
(
    hrrr_27_station_mean,
    hrrr_9_station_mean,
    hrrr_1_station_mean,
    graph1_station_mean,
    graph3_station_mean,
) = [_weighted_mean(scores, ("station",)) for scores in _systems]

In [None]:
# Diebold-Mariano tests (with confidence intervals) on per-time differences of
# station-weighted potential twCRPS between GraphCast-GFS and HRRR.
# NOTE(review): each difference is GraphCast minus HRRR *potential twCRPS*. DSC is
# computed in this notebook as climatological twCRPS minus potential twCRPS.
# Assuming the same per-time reference, this difference therefore equals
# DSC(HRRR) - DSC(GraphCast), so positive values favour HRRR. Confirm this sign is
# what the "Difference in DSC" panels are meant to show.
DM_CONFIDENCE_LEVEL = 0.99


def _with_horizon(diff):
    """Return ``diff`` with an integer horizon coordinate ``h`` = 1, 2, ... along lead_time.

    The horizon is derived from the lead_time length rather than a hard-coded list,
    so it stays in sync with the data.

    NOTE(review): h counts lead-time steps (h=1 for the first lead time), matching
    the previous hard-coded [1, ..., 8]. Confirm this is the right horizon for the
    sampling interval of the time dimension.
    """
    n_leads = diff.sizes["lead_time"]
    return diff.assign_coords(h=("lead_time", list(range(1, n_leads + 1))))


def _dm_test(diff):
    """Run the Diebold-Mariano test with each lead_time entry as its own timeseries."""
    return statistical_tests.diebold_mariano(
        diff, "lead_time", "h", confidence_level=DM_CONFIDENCE_LEVEL
    )


# GraphCast-GFS 1x1 vs HRRR 7x9
diff_graph1_hrrr9 = _with_horizon(graph1_station_mean - hrrr_9_station_mean)
dm_graph1_hrrr9 = _dm_test(diff_graph1_hrrr9)
# GraphCast-GFS 3x3 vs HRRR 21x27
diff_graph3_hrrr27 = _with_horizon(graph3_station_mean - hrrr_27_station_mean)
dm_graph3_hrrr27 = _dm_test(diff_graph3_hrrr27)
# GraphCast-GFS 1x1 vs HRRR 1x1
diff_graph1_hrrr1 = _with_horizon(graph1_station_mean - hrrr_1_station_mean)
dm_graph1_hrrr1 = _dm_test(diff_graph1_hrrr1)

In [None]:
# Lead-time labels for the x-axis, one per lead_time entry in the results.
x = ["6h", "12h", "18h", "24h", "30h", "36h", "42h", "48h"]
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=(
        "(a) DSC",
        "(b) Difference between GraphCast-GFS 1x1<br>and HRRR 1x1",
        "(c) Difference between GraphCast-GFS 1x1<br>and HRRR 7x9",
        "(d) Difference between GraphCast-GFS 3x3<br>and HRRR 21x27",
    ),
)


def _check_lead_times(values, what, labels):
    """Fail fast if ``values`` does not have exactly one entry per lead-time label."""
    assert len(values) == len(labels), (
        f"{what}: expected {len(labels)} lead times, got {len(values)}"
    )


def _add_dm_panel(fig, dm, labels, row, col):
    """Draw a Diebold-Mariano result as its mean difference with CI error bars.

    Args:
        fig: subplot figure to draw on.
        dm: result of ``statistical_tests.diebold_mariano``. Its ``mean``,
            ``ci_lower`` and ``ci_upper`` variables are read.
        labels: x-axis lead-time labels.
        row, col: 1-based subplot position.
    """
    mean = dm["mean"]
    _check_lead_times(mean, f"DM panel (row={row}, col={col})", labels)
    fig.add_trace(
        go.Scatter(
            x=labels,
            y=mean,
            line=dict(color="black"),
            error_y=dict(
                thickness=1,
                type="data",
                symmetric=False,
                # Plotly error-bar values are lengths relative to y, so the absolute
                # CI bounds are converted into distances from the mean.
                array=dm["ci_upper"] - mean,
                arrayminus=mean - dm["ci_lower"],
            ),
            showlegend=False,
        ),
        row=row,
        col=col,
    )
    # Reference line at zero difference between the two systems.
    fig.add_hline(y=0, row=row, col=col)


# (a) DSC per lead time for each system. The list order sets the legend order.
# HRRR systems use solid lines and GraphCast-GFS systems use dashed lines.
for dsc, trace_name, colour, dash in [
    (hrrr_27_mean, "HRRR 21x27", HRRR21_27_COLOUR, None),
    (hrrr_9_mean, "HRRR 7x9", HRRR7_9_COLOUR, None),
    (hrrr_1_mean, "HRRR 1x1", HRRR1_COLOUR, None),
    (graph3_mean, "GraphCast-GFS 3x3", GRAPH3_COLOUR, "dash"),
    (graph1_mean, "GraphCast-GFS 1x1", GRAPH1_COLOUR, "dash"),
]:
    _check_lead_times(dsc, trace_name, x)
    line = dict(color=colour) if dash is None else dict(dash=dash, color=colour)
    fig.add_trace(
        go.Scatter(x=x, y=dsc, mode="lines+markers", name=trace_name, line=line),
        row=1,
        col=1,
    )

_add_dm_panel(fig, dm_graph1_hrrr1, x, row=1, col=2)  # (b) GraphCast 1x1 vs HRRR 1x1
_add_dm_panel(fig, dm_graph1_hrrr9, x, row=2, col=1)  # (c) GraphCast 1x1 vs HRRR 7x9
_add_dm_panel(fig, dm_graph3_hrrr27, x, row=2, col=2)  # (d) GraphCast 3x3 vs HRRR 21x27

fig.update_layout(
    width=800,
    height=600,
    margin=dict(l=0, r=30, t=40, b=0),
    # Horizontal, boxed legend centred below the plotting area.
    legend=dict(
        orientation="h",
        x=0.5,
        y=-0.1,
        xanchor="center",
        yanchor="top",
        bordercolor="black",
        borderwidth=1,
    ),
)

# Every subplot shares the same x-axis title.
fig.update_xaxes(title_text="Lead time")
fig.update_yaxes(title_text="DSC", row=1, col=1)
fig.update_yaxes(title_text="Difference in DSC", row=2, col=1)

# Right-column panels (b) and (d) (axes x2/y2 and x4/y4) get their y-axis label as
# a rotated annotation placed to the left of the axis domain.
for axis_id in ("2", "4"):
    fig.add_annotation(
        text="Difference in DSC",
        xref=f"x{axis_id} domain",
        yref=f"y{axis_id} domain",
        x=-0.2,
        y=0.5,
        showarrow=False,
        textangle=-90,
        font=dict(size=14),
    )
fig.show()

In [None]:
# Export the figure as a PDF for the paper.
# NOTE(review): plotly static image export needs an export engine (e.g. kaleido)
# installed, and ../paper_figs presumably must already exist — confirm.
fig.write_image("../paper_figs/dsc.pdf")