In [None]:
from scores.probability import brier_score_for_ensemble
import xarray as xr
import numpy as np
import pandas as pd

# --- Forecast inputs ---------------------------------------------------------
# HRRR directories; the scoring loop opens one file per valid time from each.
# NOTE(review): the directory suffixes (1_1, 7_9, 21_27) look like neighbourhood
# sizes — confirm against the preprocessing step that wrote them.
HRRR_PATH1 = "../data/neighbourhood/hrrr_1_1/"
HRRR_PATH9 = "../data/neighbourhood/hrrr_7_9/"
HRRR_PATH27 = "../data/neighbourhood/hrrr_21_27/"
# GraphCast directories; every *.nc file inside is loaded up front.
GRAPH_PATH = "../data/neighbourhood/graphcast_1/"
GRAPH_PATH3 = "../data/neighbourhood/graphcast_3/"

# --- Observations and outputs ------------------------------------------------
OBS_DATA_PATH = "../data/processed/obs/"
# Root under which the scoring loop writes one NetCDF file per model and time.
RESULTS_PATH = "../results/brier/"

In [None]:
# Observations: rename "valid(UTC)" to "time" (the scoring loop selects on
# `time`) and keep only the `precip` variable. The dataset is intentionally
# left open because the loop reads from it one time step at a time.
obs = xr.open_dataset(OBS_DATA_PATH).rename({"valid(UTC)": "time"}).precip


def _load_graphcast_precip(directory):
    """Load the GraphCast ``apcp`` field from every ``*.nc`` file in ``directory``.

    The field is computed into memory, multiplied by 1000 and clipped so that
    no value is below zero. The source files are closed before returning,
    which is safe because the data has already been materialised.

    Parameters
    ----------
    directory : str
        Directory path ending in a path separator; ``*.nc`` is appended to it.

    Returns
    -------
    xarray.DataArray
        In-memory ``apcp`` values, scaled by 1000 and clipped at 0.
    """
    with xr.open_mfdataset(f"{directory}*.nc") as ds:
        # convert to mm (presumably from metres — TODO confirm source unit)
        precip = ds.apcp.compute() * 1000
    return precip.clip(min=0)


graphcast = _load_graphcast_precip(GRAPH_PATH)

graphcast3 = _load_graphcast_precip(GRAPH_PATH3)

In [None]:
# Evaluation window: one entry every 6 hours from 2022-01-01 00:00 through
# 2024-09-02 00:00, both endpoints included.
start_date, end_date = (
    pd.to_datetime(day) for day in ("2022-01-01", "2024-09-02")
)
time_range = pd.date_range(start_date, end_date, freq="6h")

In [None]:
# Event thresholds: the integers 1 through 100 inclusive.
# NOTE(review): presumably in mm, matching the precipitation fields — confirm.
THRESHOLDS = np.arange(start=1, stop=101)

# The same values as a 1-D DataArray indexed by its own "threshold" coordinate.
THRESHOLD_XR = xr.DataArray(
    data=THRESHOLDS,
    coords={"threshold": THRESHOLDS},
    dims=["threshold"],
)

In [None]:
# Accumulators kept from the original cell. Nothing in this cell appends to
# them; they are retained in case other cells rely on the names existing.
graph1_results = []
graph3_results = []
hrrr1_results = []
hrrr9_results = []
hrrr27_results = []

# Dimensions reduced when checking whether a forecast holds any valid value.
_COUNT_DIMS = ["station", "ens_mem"]

# For each forecast (key), the other forecasts it is masked against, in the
# exact order the masks are applied. Dict order is the order targets are
# processed; later targets therefore see already-masked earlier ones.
_MASK_ORDER = {
    "hrrr1": ("graph1", "hrrr9", "hrrr27", "graph3"),
    "hrrr9": ("graph1", "hrrr1", "graph3", "hrrr27"),
    "hrrr27": ("graph1", "hrrr1", "graph3", "hrrr9"),
    "graph1": ("hrrr1", "hrrr9", "hrrr27", "graph3"),
    "graph3": ("hrrr1", "hrrr9", "hrrr27", "graph1"),
}

# Forecast label -> sub-directory of RESULTS_PATH its scores are written to.
# Dict order is the order in which scores are computed and written.
_OUTPUT_SUBDIRS = {
    "hrrr1": "hrrr1",
    "hrrr9": "hrrr7_9",
    "hrrr27": "hrrr21_27",
    "graph1": "graphcast1",
    "graph3": "graphcast3",
}


def _load_hrrr_precip(directory, file_name):
    """Read ``APCP_6hr_acc_fcst`` from one HRRR file into memory.

    The file is closed before returning, so no handle is left open per loop
    iteration (or leaked when a sibling file turns out to be missing).

    Raises
    ------
    OSError
        If the file is missing or cannot be read (``FileNotFoundError`` is a
        subclass of ``OSError``).
    """
    with xr.open_dataset(f"{directory}{file_name}") as ds:
        return ds.APCP_6hr_acc_fcst.load()


for time in time_range:
    # Raises if the observations have no entry for this time (unchanged).
    ob = obs.sel(time=time)

    # Zero-padded stamps, e.g. "2022010106" and "hrrr_20220101_06_00.nc".
    stamp = time.strftime("%Y%m%d%H")
    hrrr_file = time.strftime("hrrr_%Y%m%d_%H_00.nc")

    # Only a missing/unreadable file means "no data"; anything else (including
    # a kernel interrupt) must propagate instead of being reported as a gap.
    try:
        hrrr1 = _load_hrrr_precip(HRRR_PATH1, hrrr_file)
        hrrr9 = _load_hrrr_precip(HRRR_PATH9, hrrr_file)
        hrrr27 = _load_hrrr_precip(HRRR_PATH27, hrrr_file)
    except OSError:
        print(f"No data for HRRR {stamp}")
        continue

    # A time label absent from the GraphCast index surfaces as KeyError.
    try:
        graph = graphcast.sel(time=time)
        graph3 = graphcast3.sel(time=time)
    except KeyError:
        print(f"No data for GRAPHCAST {stamp}")
        continue

    # `graph` and `hrrr1` are given an "ens_mem" axis here; per the original
    # note, drop these two expand_dims calls if their neighborhood is > 1.
    fields = {
        "hrrr1": hrrr1.expand_dims("ens_mem"),
        "hrrr9": hrrr9,
        "hrrr27": hrrr27,
        "graph1": graph.expand_dims("ens_mem"),
        "graph3": graph3,
    }

    # Match NaNs: blank each forecast wherever another forecast has no valid
    # value across _COUNT_DIMS.
    # NOTE(review): counting over "station" as well as "ens_mem" collapses the
    # station axis, so if those are the only dimensions this is all-or-nothing
    # per time step rather than per station — confirm that is intended.
    for target, others in _MASK_ORDER.items():
        for other in others:
            fields[target] = fields[target].where(
                fields[other].count(_COUNT_DIMS) > 0
            )

    # Score every forecast against the same observation slice and write one
    # NetCDF file per forecast and valid time.
    for label, subdir in _OUTPUT_SUBDIRS.items():
        score = brier_score_for_ensemble(
            fields[label],
            ob,
            "ens_mem",
            event_thresholds=THRESHOLDS,
            preserve_dims="all",
        )
        score = score.expand_dims("time")
        score.to_netcdf(f"{RESULTS_PATH}{subdir}/{stamp}.nc")

    print(f"Calculated Brier score for {stamp}")