In [None]:
from scores.processing import broadcast_and_match_nan
import xarray as xr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Input locations, relative to the notebook's working directory. The trailing
# slash matters: file names are appended to these strings directly.
# (Plain literals -- the original f-prefixes had no placeholders to fill.)
HRRR_PATH1 = "../data/neighbourhood/hrrr_1_1/"
GRAPH_PATH = "../data/neighbourhood/graphcast_1/"
OBS_DATA_PATH = "../data/processed/obs/"

In [None]:
# Observations: rename the "valid(UTC)" axis to "time" and keep only the
# `precip` variable.
obs = xr.open_dataset(OBS_DATA_PATH).rename({"valid(UTC)": "time"}).precip

# GraphCast forecasts: open every NetCDF file under GRAPH_PATH as a single
# dataset and keep the `apcp` variable.
graphcast = xr.open_mfdataset(f"{GRAPH_PATH}*.nc").apcp
# Load into memory, scale by 1000 (convert to mm), then floor negative
# values at zero.
graphcast = (graphcast.compute() * 1000).clip(min=0)

In [None]:
# Forecast base times to load: every 6 hours across the evaluation period
# (pd.date_range includes both end points).
start_date = pd.to_datetime("2022-01-01")
end_date = pd.to_datetime("2024-09-01")
time_range = pd.date_range(start=start_date, end=end_date, freq="6h")

# One DataArray per successfully opened HRRR file, each with a leading
# length-1 "time" dimension so they can be concatenated afterwards.
hrrr_results = []
for time in time_range:
    # Files are named hrrr_<YYYYMMDD>_<HH>_00.nc with zero-padded fields.
    stamp = time.strftime("%Y%m%d_%H")
    hrrr_file = f"{HRRR_PATH1}hrrr_{stamp}_00.nc"
    try:
        # Keep only lead times from 6 hours up to 2 days.
        hrrr_ds = xr.open_dataset(hrrr_file).sel(
            lead_time=slice(pd.Timedelta("6h"), pd.Timedelta("2D"))
        )
    except (OSError, ValueError):
        # Best effort: a missing or unreadable file is reported and skipped.
        # Catching specific exceptions (rather than a bare `except:`) keeps
        # KeyboardInterrupt working so this long loop can still be stopped.
        print(f"No data for {time}")
        continue
    hrrr_results.append(hrrr_ds.APCP_6hr_acc_fcst.expand_dims("time"))

In [None]:
# Stack the per-file HRRR arrays (each already has a length-1 "time"
# dimension) into a single array along "time".
hrrr = xr.concat(objs=hrrr_results, dim="time")

In [None]:
# 6-hour lead time: select that slice from both models, then pass the two
# forecasts and the observations through broadcast_and_match_nan so all three
# are compared on a common footing.
# NOTE(review): presumably this aligns the arrays and mirrors NaNs across
# them -- confirm against the `scores` documentation.
lead_time = pd.Timedelta(hours=6)
hrrr6, graphcast6, obs6 = broadcast_and_match_nan(
    hrrr.sel(lead_time=lead_time),
    graphcast.sel(lead_time=lead_time),
    obs,
)

In [None]:
# Upper-tail quantile levels: 10000 evenly spaced values from 0.9 to 1
# (inclusive).
quants = np.linspace(start=0.9, stop=1, num=10000)

# Evaluate the same levels for each 6-hour field; no `dim` is passed to
# `quantile`.
hrrr_quantiles6, graphcast_quantiles6, obs_quantiles6 = (
    field.quantile(quants) for field in (hrrr6, graphcast6, obs6)
)

In [None]:
# 30-hour lead time: select that slice from both models, then pass the two
# forecasts and the observations through broadcast_and_match_nan so all three
# are compared on a common footing.
# NOTE(review): presumably this aligns the arrays and mirrors NaNs across
# them -- confirm against the `scores` documentation.
lead_time = pd.Timedelta(hours=30)
hrrr30, graphcast30, obs30 = broadcast_and_match_nan(
    hrrr.sel(lead_time=lead_time),
    graphcast.sel(lead_time=lead_time),
    obs,
)

In [None]:
# Upper-tail quantile levels: 10000 evenly spaced values from 0.9 to 1
# (inclusive).
quants = np.linspace(start=0.9, stop=1, num=10000)

# Evaluate the same levels for each 30-hour field; no `dim` is passed to
# `quantile`.
hrrr_quantiles30, graphcast_quantiles30, obs_quantiles30 = (
    field.quantile(quants) for field in (hrrr30, graphcast30, obs30)
)

In [None]:
# Trace colours (hex values from the Okabe-Ito colour-blind-safe palette).
# Only HRRR1_COLOUR and GRAPH1_COLOUR are used in this cell.
HRRR1_COLOUR = "#CC79A7"
HRRR7_9_COLOUR = "#009E73"
HRRR21_27_COLOUR = "#E69F00"
GRAPH1_COLOUR = "#56B4E9"
GRAPH3_COLOUR = "#0072B2"

# Two side-by-side Q-Q panels: forecast quantiles against observed quantiles.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("(a) 6 hour lead time", "(b) 30 hour lead time"),
)

# One entry per panel: (subplot column, observed, HRRR, GraphCast quantiles).
panels = (
    (1, obs_quantiles6, hrrr_quantiles6, graphcast_quantiles6),
    (2, obs_quantiles30, hrrr_quantiles30, graphcast_quantiles30),
)
for col, obs_q, hrrr_q, graphcast_q in panels:
    # Legend entries come from the first panel only; the second panel's
    # traces are explicitly hidden so each model is listed once.
    legend_flag = None if col == 1 else False
    for label, model_q, colour in (
        ("HRRR 1X1", hrrr_q, HRRR1_COLOUR),
        ("GRAPHCAST 1X1", graphcast_q, GRAPH1_COLOUR),
    ):
        fig.add_trace(
            go.Scatter(
                x=obs_q,
                y=model_q,
                showlegend=legend_flag,
                mode="markers",
                name=label,
                line={"color": colour},
            ),
            row=1,
            col=col,
        )
    # Black 1:1 reference line from 0 to 240.
    fig.add_trace(
        go.Scatter(
            x=[0, 240],
            y=[0, 240],
            showlegend=False,
            mode="lines",
            line={"color": "black"},
        ),
        row=1,
        col=col,
    )

fig.update_layout(
    width=700,
    height=300,
    margin={"l": 0, "r": 30, "t": 40, "b": 0},
    legend={"x": 0.818, "y": 0.99, "xanchor": "right", "yanchor": "top"},
)
# Both panels label the x-axis; only the left panel labels the y-axis.
for col in (1, 2):
    fig.update_xaxes(title_text="Observed (mm)", row=1, col=col)
fig.update_yaxes(title_text="Forecast (mm)", row=1, col=1)

fig.write_image("../paper_figs/q-q_plot.pdf")
fig.show()