# Generate temperature figure (fig 6)
In this notebook we:
- Produce the Black Summer MSE/TWMSE plot (fig 6)
- Produce a map of the stations used
- Calculate confidence intervals for the difference in scores between the automated and official forecasts (Diebold-Mariano test)
- Evaluate the NT temperature forecasts. We don't show this figure in the paper, but we discuss the results.

**Note:** At the time of writing this paper, the Jive implementation of threshold-weighted MSE hadn't been migrated to `scores`. We have produced a cut-down version (`threshold_weighted_score.py`), which is used to generate the figure here.

In [None]:
import numpy as np
import pandas as pd
import xarray as xr

from scores.continuous import mse
from scores.processing import broadcast_and_match_nan
from scores.stats.statistical_tests import diebold_mariano

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

from threshold_weighted_score import threshold_weighted_squared_error

## First - analyse performance for Black Summer

### Get data for Black Summer

In [None]:
# Black Summer maximum temperature data (file names cover 2019-12-01 to 2020-02-29)

# Observations
obs_bs = xr.open_dataarray("data/temperature/obs_MaxT_20191201-20200229.nc")

# Official forecasts (afternoon 00Z issue)
official_bs = xr.open_dataarray("data/temperature/Official_MaxT_00_20191201-20200229.nc")

# Automated guidance (PtOCF) that was available to meteorologists for the afternoon (00Z)
# official forecast issue during the 2019-2020 summer.
# NOTE(review): this was previously described as the 18Z PtOCF, but the file name
# carries `_12_` - confirm which base time is correct.
autofcst_bs = xr.open_dataarray("data/temperature/PtOCF_MaxT_12_20191201-20200229.nc")

# Match missing data between the three datasets so that all scores are computed
# on the same set of cases
official_bs, autofcst_bs, obs_bs = broadcast_and_match_nan(official_bs, autofcst_bs, obs_bs)

### Calculate scores

In [None]:
# Threshold-weighted scores for Black Summer

# Load the per-station thresholds and rename the dimension to match the forecast/obs data
# (presumably the 0.97 climatological quantile of max temperature, going by the file
# name - TODO confirm)
station_thresholds = pd.read_json("data/climate/max_t_0.97.json", typ="series")
thresholds = station_thresholds.to_xarray().rename({"index": "station_number"})

# Keep only the stations that appear in both the thresholds and the forecasts/obs
common_stations = set(thresholds.station_number.values).intersection(
    obs_bs.station_number.values
)
station_list = list(common_stations)
obs_bs, official_bs, autofcst_bs, thresholds = [
    data.sel(station_number=station_list)
    for data in (obs_bs, official_bs, autofcst_bs, thresholds)
]

# Upper end of the weighting interval: +inf for each station, i.e. the weight is one
# for everything above the threshold.
# NOTE: built by multiplication, so a zero threshold would give NaN and a negative
# threshold would give -inf.
inf_threshold = thresholds * np.inf


def _bs_tw_mse_by_lead_day(fcst):
    """Threshold-weighted squared error of `fcst` against `obs_bs`, indexed by lead day."""
    return threshold_weighted_squared_error(
        fcst,
        obs_bs,
        interval_where_one=(thresholds, inf_threshold),
        dims=["lead_day"],
    )


# Threshold-weighted MSE for each forecast source
official_bs_tw = _bs_tw_mse_by_lead_day(official_bs)
autofcst_bs_tw = _bs_tw_mse_by_lead_day(autofcst_bs)

# Ordinary (unweighted) MSE for each forecast source, one value per lead day
official_bs_mse = mse(official_bs, obs_bs, preserve_dims="lead_day")
autofcst_bs_mse = mse(autofcst_bs, obs_bs, preserve_dims="lead_day")

### Produce plot

In [None]:
# Line colours shared by both panels so each forecast source looks the same throughout
official_line_colour = "rgba(230,159,0,1)"
autofcst_line_colour = "rgba(86,180,233,1)"

figure = make_subplots(rows=1, cols=2, subplot_titles=("<b>(a)</b>", "<b>(b)</b>"))
# Move the panel labels to the left-hand side of their subplots
figure.update_annotations(font_size=12, xshift=-160, xanchor="left")

figure.update_layout(
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    height=300,
    width=800,
    margin=go.layout.Margin(l=20, r=20, b=20, t=20),  # left, right, bottom, top
)

# One entry per panel:
# (column, official scores, automated scores, y-axis title, y position of the note)
panels = (
    # Left - Black Summer MSE
    (1, official_bs_mse, autofcst_bs_mse, "MSE (°C<sup>2</sup>)", 2.5),
    # Right - Black Summer threshold-weighted MSE
    (2, official_bs_tw, autofcst_bs_tw, "Threshold weighted MSE (°C<sup>2</sup>)", 0.4),
)

for col, official_scores, autofcst_scores, y_title, note_y in panels:
    for panel_scores, colour, label in (
        (official_scores, official_line_colour, "Official"),
        (autofcst_scores, autofcst_line_colour, "Automated"),
    ):
        figure.add_trace(
            go.Scatter(
                x=panel_scores.lead_day,
                y=panel_scores.values,
                line=dict(color=colour),
                name=label,
                # Only the left panel contributes legend entries, to avoid duplicates
                showlegend=(col == 1),
            ),
            row=1,
            col=col,
        )

    figure.update_xaxes(
        title_text="Lead day", row=1, col=col, tickmode="linear", tick0=0, dtick=1
    )
    figure.update_yaxes(title_text=y_title, row=1, col=col)
    figure.add_annotation(
        x=4, y=note_y, text="Lower scores better", showarrow=False, row=1, col=col
    )

# Display the finished figure as the cell output
figure

In [None]:
# Export the Black Summer MSE / threshold-weighted MSE figure to PDF
figure.write_image("results/figures/temperature.pdf")

### Generate map of stations

In [None]:
# Station metadata, restricted to the stations used in the Black Summer evaluation
station_metadata = pd.read_csv("data/aws_metadata/station_data.csv")
df = station_metadata[
    station_metadata["station_number"].isin(official_bs.station_number.values)
]

# One red marker per station
fig = px.scatter_geo(df, lat="LATITUDE", lon="LONGITUDE", color_discrete_sequence=["red"])
fig.update_traces(marker={"size": 4})

# Grid lines every 5 degrees on both axes
gridline_style = dict(showgrid=True, gridcolor="gray", gridwidth=0.5, dtick=5)

fig.update_geos(
    resolution=50,
    # Map window: longitude 110 to 155, latitude -45 to -10
    lonaxis_range=[110, 155],
    lataxis_range=[-45, -10],
    showcoastlines=True,
    showland=True,
    showocean=True,
    oceancolor="rgb(144, 195, 245)",
    showcountries=True,
    showframe=True,
    lonaxis=dict(gridline_style),
    lataxis=dict(gridline_style),
)

fig.update_layout(
    title="d)",  # panel label of this map
    height=350,
    width=400,
    # No outer margin except room at the top for the title
    margin=go.layout.Margin(l=0, r=0, b=0, t=40),
)
fig.show()

In [None]:
# Export the station map to PDF
fig.write_image("results/station_maps/d_aus_temperature.pdf")

### Confidence intervals

In [None]:
# Black Summer - Diebold-Mariano test on the MSE difference

# MSE kept per lead day AND per valid date, giving one time series per lead day.
# NOTE: this rebinds `autofcst_bs_mse` / `official_bs_mse`, which previously held the
# per-lead-day scores used by the plotting cell.
autofcst_bs_mse, official_bs_mse = [
    mse(fcst, obs_bs, preserve_dims=["lead_day", "valid_15z_date"])
    for fcst in (autofcst_bs, official_bs)
]

# Automated minus Official: positive values mean Official had the lower (better) MSE.
# Each lead day is given an `h` value for the test; the hard-coded 2..8 assumes exactly
# seven lead days (presumably h = lead day + 1 - TODO confirm).
diff = (autofcst_bs_mse - official_bs_mse).assign_coords(
    h=("lead_day", list(range(2, 9)))
)

dm_result = diebold_mariano(diff, "lead_day", "h")
dm_result

In [None]:
# Black Summer - Diebold-Mariano test on the threshold-weighted MSE difference

# Threshold-weighted squared error kept per lead day AND per valid date.
# NOTE: this rebinds `official_bs_tw` / `autofcst_bs_tw`, which previously held the
# per-lead-day scores used by the plotting cell.
official_bs_tw, autofcst_bs_tw = [
    threshold_weighted_squared_error(
        fcst,
        obs_bs,
        interval_where_one=(thresholds, inf_threshold),
        dims=["lead_day", "valid_15z_date"],
    )
    for fcst in (official_bs, autofcst_bs)
]

# Automated minus Official: positive values mean Official had the lower (better) score.
# Each lead day is given an `h` value for the test; the hard-coded 2..8 assumes exactly
# seven lead days (presumably h = lead day + 1 - TODO confirm).
diff = (autofcst_bs_tw - official_bs_tw).assign_coords(
    h=("lead_day", list(range(2, 9)))
)

dm_result = diebold_mariano(diff, "lead_day", "h")
dm_result

# NT example

In [None]:
# ---- NT 2016 data (file names cover 2016-05-01 to 2016-09-30) ----

obs_nt_2016 = xr.open_dataarray("data/temperature/obs_MaxT_20160501-20160930_NT.nc")
official_nt_2016 = xr.open_dataarray("data/temperature/Official_MaxT_00_20160501-20160930_NT.nc")
# Automated guidance (PtOCF) that was available to meteorologists for the afternoon (00Z)
# official forecast issue during 2016.
# NOTE(review): this was previously described as the 18Z PtOCF, but the file name
# carries `_12_` - confirm which base time is correct.
autofcst_nt_2016 = xr.open_dataarray("data/temperature/PtOCF_MaxT_12_20160501-20160930_NT.nc")

# Match missing data between datasets
official_nt_2016, autofcst_nt_2016, obs_nt_2016 = broadcast_and_match_nan(
    official_nt_2016, autofcst_nt_2016, obs_nt_2016
)

# Sample climatological forecast: the mean of the observations over all valid dates.
# `obs * 0 + 1` is 1 where an observation exists and NaN where it is missing, so the
# product broadcasts the mean back over the dates and stays NaN wherever the obs are NaN.
valid_mask_2016 = obs_nt_2016 * 0 + 1
nt_2016_clim_fcst = obs_nt_2016.mean("valid_15z_date") * valid_mask_2016

# ---- NT 2020 data (file names cover 2020-05-01 to 2020-09-30) ----

obs_nt_2020 = xr.open_dataarray("data/temperature/obs_MaxT_20200501-20200930_NT.nc")
official_nt_2020 = xr.open_dataarray("data/temperature/Official_MaxT_00_20200501-20200930_NT.nc")
# The 12Z AutoFcst was the automated guidance available to meteorologists for the
# afternoon (00Z) official forecast issue during the 2020 dry season
autofcst_nt_2020 = xr.open_dataarray("data/temperature/AutoFcst_MaxT_12_20200501-20200930_NT.nc")

# Match missing data between datasets
official_nt_2020, autofcst_nt_2020, obs_nt_2020 = broadcast_and_match_nan(
    official_nt_2020, autofcst_nt_2020, obs_nt_2020
)

# Sample climatological forecast, built the same way as for 2016
valid_mask_2020 = obs_nt_2020 * 0 + 1
nt_2020_clim_fcst = obs_nt_2020.mean("valid_15z_date") * valid_mask_2020

In [None]:
# MSE skill score: 1 - MSE(forecast) / MSE(reference), where the reference is the
# sample climatological forecast for the SAME season. Positive values mean the
# forecast beats climatology.

# 2016 NT MSE (one value per lead day) and skill relative to the 2016 climatology
autofcst_2016_mse = mse(autofcst_nt_2016, obs_nt_2016, preserve_dims="lead_day")
official_2016_mse = mse(official_nt_2016, obs_nt_2016, preserve_dims="lead_day")
clim_2016_mse = mse(nt_2016_clim_fcst, obs_nt_2016, preserve_dims="lead_day")
autofcst_2016_ss = 1 - (autofcst_2016_mse / clim_2016_mse)
official_2016_ss = 1 - (official_2016_mse / clim_2016_mse)

# 2020 NT MSE (one value per lead day) and skill relative to the 2020 climatology.
# The reference must be built from the 2020 climatological forecast and 2020
# observations; using the 2016 data here would score 2020 against the wrong season.
autofcst_2020_mse = mse(autofcst_nt_2020, obs_nt_2020, preserve_dims="lead_day")
official_2020_mse = mse(official_nt_2020, obs_nt_2020, preserve_dims="lead_day")
clim_2020_mse = mse(nt_2020_clim_fcst, obs_nt_2020, preserve_dims="lead_day")
autofcst_2020_ss = 1 - (autofcst_2020_mse / clim_2020_mse)
official_2020_ss = 1 - (official_2020_mse / clim_2020_mse)

In [None]:
# Line colours shared by both panels so each forecast source looks the same throughout
official_line_colour = "rgba(230,159,0,1)"
autofcst_line_colour = "rgba(86,180,233,1)"

figure = make_subplots(rows=1, cols=2, subplot_titles=("<b>(a)</b>", "<b>(b)</b>"))
# Move the panel labels to the left-hand side of their subplots
figure.update_annotations(font_size=12, xshift=-160, xanchor="left")

# (column, official skill score, automated skill score)
panels = (
    # Left - NT MSE skill score 2016
    (1, official_2016_ss, autofcst_2016_ss),
    # Right - NT MSE skill score 2020
    (2, official_2020_ss, autofcst_2020_ss),
)

for col, official_ss, autofcst_ss in panels:
    # Legend entries come from the left panel only; the right panel hides its own
    legend_kwargs = {} if col == 1 else {"showlegend": False}

    for skill_score, colour, label in (
        (official_ss, official_line_colour, "Official"),
        (autofcst_ss, autofcst_line_colour, "Automated"),
    ):
        figure.add_trace(
            go.Scatter(
                x=skill_score.lead_day,
                y=skill_score.values,
                line=dict(color=colour),
                name=label,
                **legend_kwargs,
            ),
            row=1,
            col=col,
        )

    figure.update_xaxes(
        title_text="Lead day", row=1, col=col, tickmode="linear", tick0=0, dtick=1
    )
    figure.update_yaxes(title_text="MSE skill score", row=1, col=col)

# Legend anchored to the top-right corner of the plotting area.
# update_layout returns the figure, so this also displays it as the cell output.
figure.update_layout(
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
    height=300,
    width=800,
    margin=go.layout.Margin(l=20, r=20, b=20, t=20),  # left, right, bottom, top
)