In [66]:
import xarray as xr
import os
import numpy as np
import pandas as pd
import dask

from dask.distributed import Client
client = Client()


NEIGHBORHOOD = 3
GRAPH_LT_DATA_PATH = "/g/data/wa46/user/nl5316/tw_spatial/graphcast_combined/"
OBS_DATA_PATH = "/g/data/wa46/user/nl5316/tw_spatial/obs/"
RESULT_PATH = f"/g/data/wa46/user/nl5316/tw_spatial/graphcast_neighborhood/{NEIGHBORHOOD}/"

Perhaps you already have a cluster running?
Hosting the HTTP server on port 40889 instead


In [6]:
ds = xr.open_dataset(
    "/g/data/wa46/user/nl5316/tw_spatial/graphcast_combined/graphcast.nc", 
)

In [7]:
obs_file_list = sorted([f for f in os.listdir(OBS_DATA_PATH) if f.endswith(".nc")])

In [18]:
def create_neighourhood_ensemble(fcst, neighborhood, remove_incomplete_ens=True):
    ensemble = fcst.rolling(
        dim=dict(longitude=neighborhood, latitude=neighborhood), center=True
    ).construct(longitude="i", latitude="j")
    ensemble = ensemble.stack(ens_mem=("i", "j"))
    if remove_incomplete_ens:
        ensemble = ensemble.where(~np.isnan(ensemble).any(dim="ens_mem"), np.nan)
    # Clean up coordinates for crps calculation
    ensemble = ensemble.drop_vars(["ens_mem", "i", "j"])
    ensemble = ensemble.assign_coords(ens_mem=np.arange(neighborhood**2))
    return ensemble

In [None]:
# HRRR domain
latN = 50.4
latS = 24.25
lonW = -123.8
lonE = -71.2

ds_sel = ds.isel(time=10)


concat_list = []
for t in range(len(ds.time)):
# for t in range(10):
    ds_sel = ds.isel(time=t)
    # Create neighborhood ensemble
    ens = create_neighourhood_ensemble(ds_sel, NEIGHBORHOOD)

    station_list = []
    point_list = []
    for i in range(len(obs_file_list)):
        obs_file_name = obs_file_list[i]
        obs = xr.open_dataarray(f"{OBS_DATA_PATH}{obs_file_name}")
        station = obs.attrs["station"]
        latitude = obs.attrs["lat"]
        longitude = obs.attrs["lon"]
        if longitude < lonW or longitude > lonE or latitude > latN or latitude < latS:
            continue
        ens_point = ens.sel(latitude=latitude, longitude=longitude, method="nearest")
        ens_point = ens_point.expand_dims("station")
        ens_point = ens_point.assign_coords({"station": [station]})
        ens_point = ens_point.drop(["latitude", "longitude"])
        point_list.append(ens_point)
        station_list.append(station)
    concat_list.append(xr.concat(point_list, dim="station"))
    print(f"{t/len(ds.time)} complete")
ensemble_point_fcst = xr.concat(concat_list, dim="time")

0.0 complete
0.00025687130747495504 complete
0.0005137426149499101 complete
0.0007706139224248652 complete
0.0010274852298998202 complete
0.0012843565373747753 complete
0.0015412278448497304 complete
0.0017980991523246853 complete
0.0020549704597996403 complete


In [67]:
ensemble_point_fcst