# Generate figures for SW WA wind biases (fig 5)
In this notebook we:
1. evaluate whether meteorologists are improving on the automated wind speed forecasts over southwestern WA (fig 5),
2. produce a station map plot, and
3. calculate whether the difference in errors is statistically significant.

In [1]:
import numpy as np
import pandas as pd
import xarray as xr

from scores.continuous import mse
from scores.processing import broadcast_and_match_nan
from scores.stats.statistical_tests import diebold_mariano

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

In [2]:
# Input files: official forecasts, automated guidance and observations.
official_path = "../data/sw_wa_wind/Official_WindMag_00_20230601-20230830_sw_wa.nc"
# Note that the 12Z AutoFcst was the automated guidance available to meteorologists
# for the afternoon (00Z) official forecast issue.
autofcst_path = "../data/sw_wa_wind/AutoFcst_WindMag_12_20230601-20230830_sw_wa.nc"
obs_path = "../data/sw_wa_wind/obs_WindMag_20230601-20230830sw_wa.nc"

official = xr.open_dataarray(official_path)
autofcst = xr.open_dataarray(autofcst_path)
obs = xr.open_dataarray(obs_path)

# Match missing data between datasets so all three are scored on the same points
official, autofcst, obs = broadcast_and_match_nan(official, autofcst, obs)

### Calculate MSE and multiplicative bias

In [49]:
# Dimensions that are pooled when averaging for the multiplicative bias.
pooled_dims = ["valid_start", "station_number"]

# Mean squared error of each forecast source, one value per lead day.
official_mse = mse(official, obs, preserve_dims="lead_day")
autofcst_mse = mse(autofcst, obs, preserve_dims="lead_day")

# Multiplicative bias: mean forecast divided by mean observation.
mean_obs = obs.mean(pooled_dims)
official_bias = official.mean(pooled_dims) / mean_obs
autofcst_bias = autofcst.mean(pooled_dims) / mean_obs

In [50]:
official_line_colour = "rgba(230,159,0,1)"
autofcst_line_colour = "rgba(86,180,233,1)"

# One entry per forecast source:
# (trace name, line colour, MSE by lead day, multiplicative bias by lead day)
forecast_series = (
    ("Official", official_line_colour, official_mse, official_bias),
    ("Automated", autofcst_line_colour, autofcst_mse, autofcst_bias),
)

figure = make_subplots(
    rows=2, cols=1, subplot_titles=("<b>(a)</b>", "<b>(b)</b>"), vertical_spacing=0.15
)
# Called before any add_annotation so that only the "(a)"/"(b)" subplot titles
# are restyled and moved to the left.
figure.update_annotations(font_size=12, xshift=-160, xanchor="left")

# Upper subfig - MSE (these traces provide the legend entries)
for label, colour, mse_by_lead, _ in forecast_series:
    figure.add_trace(
        go.Scatter(
            x=mse_by_lead.lead_day,
            y=mse_by_lead.values,
            line=dict(color=colour),
            name=label,
        ),
        row=1,
        col=1,
    )

# Lower subfig - multiplicative bias
# Dashed reference line at a multiplicative bias of 1. It was previously named
# "Automated" (copy-paste slip), which mislabelled it when hovering.
figure.add_trace(
    go.Scatter(
        x=[1, 7],
        y=[1, 1],
        line=dict(color="black", dash="dash"),
        mode="lines",
        name="No bias",
        showlegend=False,
    ),
    row=2,
    col=1,
)
# Legend entries are hidden here because the upper panel already supplies them.
for label, colour, _, bias_by_lead in forecast_series:
    figure.add_trace(
        go.Scatter(
            x=bias_by_lead.lead_day,
            y=bias_by_lead.values,
            line=dict(color=colour),
            name=label,
            showlegend=False,
        ),
        row=2,
        col=1,
    )

# Reading guides placed inside the panels
figure.add_annotation(
    x=4, y=12, text="Lower scores better", showarrow=False, row=1, col=1
)
figure.add_annotation(
    x=4, y=1.05, text="Overforecast bias ↑", showarrow=False, row=2, col=1
)
figure.add_annotation(
    x=4, y=0.9, text="Underforecast bias ↓", showarrow=False, row=2, col=1
)

figure.update_layout(
    legend=dict(x=0.01, y=0.99),
    height=600,
    width=400,
    margin=go.layout.Margin(l=20, r=20, b=20, t=20),  # left, right, bottom, top
)

# Both panels share the same x-axis styling: one tick per lead day.
for row in (1, 2):
    figure.update_xaxes(
        title_text="Lead day", row=row, col=1, tickmode="linear", tick0=0, dtick=1
    )
figure.update_yaxes(title_text="MSE (kt<sup>2</sup>)", row=1, col=1)
# Keep this call last: it evaluates to the figure, which the notebook then displays.
figure.update_yaxes(title_text="Multiplicative bias", row=2, col=1)

In [51]:
# Save the two-panel MSE / multiplicative bias figure as an SVG.
bias_figure_path = "../figures/results/sw_wa_wind_bias.svg"
figure.write_image(bias_figure_path)

### Generate map of stations

In [4]:
# Station metadata, restricted to the stations present in the forecast data.
df = pd.read_csv("../data/aws_metadata/station_data.csv")
station_numbers = official.station_number.values
df = df.loc[df["station_number"].isin(station_numbers)]

# Grid-line styling shared by the longitude and latitude axes.
grid_style = dict(showgrid=True, gridcolor="gray", gridwidth=0.5, dtick=5)

fig = px.scatter_geo(
    df, lat="LATITUDE", lon="LONGITUDE", color_discrete_sequence=["red"]
)
fig.update_geos(
    resolution=50,
    showcoastlines=True,
    showland=True,
    showocean=True,
    oceancolor="rgb(144, 195, 245)",
    showcountries=True,
    showframe=True,
    lonaxis=dict(range=[110, 155], **grid_style),
    lataxis=dict(range=[-45, -10], **grid_style),
)

# Small markers and no margins so the map fills the 400x350 canvas.
fig.update_traces(marker=dict(size=4))
fig.update_layout(
    height=350,
    width=400,
    margin=go.layout.Margin(l=0, r=0, b=0, t=0),  # left, right, bottom, top
)
fig.show()

In [5]:
# Save the station map as an SVG.
station_map_path = "../figures/station_maps/sw_wa_stations.svg"
fig.write_image(station_map_path)

### Check statistical significance

In [52]:
# Recompute MSE keeping valid_start as well as lead_day, so each lead day has a
# series of scores for the significance test. This rebinds the per-lead-day
# official_mse / autofcst_mse used for the figure.
dm_dims = ["lead_day", "valid_start"]
official_mse = mse(official, obs, preserve_dims=dm_dims)
autofcst_mse = mse(autofcst, obs, preserve_dims=dm_dims)

In [53]:
# Difference in MSE between Official and the automated forecast (AutoFcst)
# Value of the "h" coordinate attached to each of the seven lead days.
horizons = list(range(2, 9))
diff = (official_mse - autofcst_mse).assign_coords(h=("lead_day", horizons))

# Diebold-Mariano test with one series per lead_day, reading h from the "h" coordinate.
dm_result = diebold_mariano(diff, "lead_day", "h")
dm_result