In [None]:
import xarray as xr
import os
import numpy as np

# Width (in grid points) of the square neighbourhood taken around each station.
NEIGHBORHOOD = 3
OBS_DATA_PATH = "../data/processed/obs/"
# Must be an f-string: without the ``f`` prefix the directory name contained the
# literal text "{NEIGHBORHOOD}" rather than the neighbourhood size.
RESULT_PATH = f"../data/neighbourhood/graphcast_{NEIGHBORHOOD}/"

In [None]:
# Gridded GraphCast forecast that neighbourhoods are extracted from below.
ds = xr.open_dataset(os.path.join("../data/processed/graphcast", "graphcast.nc"))

In [None]:
# NetCDF file names in each directory, sorted so iteration order is deterministic.
obs_file_list = sorted(name for name in os.listdir(OBS_DATA_PATH) if name.endswith(".nc"))
results_file_list = sorted(name for name in os.listdir(RESULT_PATH) if name.endswith(".nc"))

In [None]:
def create_point_neighborhood(
    fcst, longitude, latitude, neighborhood, remove_incomplete_ens=True, grid_spacing=0.25
):
    """Extract the ``neighborhood`` x ``neighborhood`` grid points around a location.

    The grid points surrounding the grid point nearest to (``longitude``,
    ``latitude``) are stacked into a single ``ens_mem`` dimension, indexed
    ``0 .. neighborhood**2 - 1``, so each neighbouring point can be treated as a
    pseudo-ensemble member.

    Parameters
    ----------
    fcst : xarray.Dataset or xarray.DataArray
        Gridded forecast with ``longitude`` and ``latitude`` coordinates.
        NOTE(review): the selection assumes longitudes stored in [0, 360) and a
        descending ``latitude`` coordinate — confirm against the input file.
    longitude, latitude : float
        Target location in degrees; negative (western) longitudes are accepted.
    neighborhood : int
        Width of the square neighbourhood in grid points; must be a positive odd
        number so the target grid point sits at the centre.
    remove_incomplete_ens : bool, default True
        If True, values are set to NaN wherever any member of the
        neighbourhood is NaN (the members are masked, not dropped).
    grid_spacing : float, default 0.25
        Grid resolution in degrees.

    Returns
    -------
    xarray object or str
        The stacked neighbourhood, or the string ``"edge case"`` when fewer than
        ``neighborhood**2`` grid points fall inside the selection window.

    Raises
    ------
    ValueError
        If ``neighborhood`` is not a positive odd integer.
    """
    if neighborhood < 1 or neighborhood % 2 == 0:
        raise ValueError(f"neighborhood must be a positive odd integer, got {neighborhood}")

    # Snap the location to the nearest grid point.
    latitude_rounded = np.round(latitude / grid_spacing) * grid_spacing
    longitude_rounded = np.round(longitude / grid_spacing) * grid_spacing
    # Convert to the [0, 360) longitude convention; for western (negative)
    # longitudes this equals the former ``360 + longitude``.
    longitude_rounded = longitude_rounded % 360

    # Half-width of the window in degrees, derived from the requested size
    # (0.25 for a 3x3 neighbourhood on a 0.25-degree grid).
    half_width = (neighborhood // 2) * grid_spacing
    neighbor = fcst.sel(
        longitude=slice(longitude_rounded - half_width, longitude_rounded + half_width),
        # North -> south because the latitude coordinate is assumed descending.
        latitude=slice(latitude_rounded + half_width, latitude_rounded - half_width),
    )
    neighbor = neighbor.stack(ens_mem=("longitude", "latitude"))

    if remove_incomplete_ens:
        # Keep values only where every neighbourhood member is present.
        neighbor = neighbor.where(neighbor.notnull().all(dim="ens_mem"))
    # Compare against the requested size (not a global constant); a short
    # selection means part of the window fell outside the grid.
    if neighbor.ens_mem.size < neighborhood**2:
        return "edge case"
    # Replace the (longitude, latitude) stacked index with plain member numbers.
    neighbor = neighbor.drop_vars(["ens_mem", "longitude", "latitude"])
    neighbor = neighbor.assign_coords(ens_mem=np.arange(neighborhood**2))

    return neighbor

In [None]:
# CONUS domain bounds in degrees; stations outside this box are skipped later.
latN = 50.4
latS = 24.25
lonW = -126
lonE = -66

# Collect each station's name and location from its observation file's attributes.
obs_metadata = []
for obs_file_name in obs_file_list:
    # Only the attributes are needed, so close each file as soon as they are read
    # instead of leaving a handle open for every observation file.
    with xr.open_dataarray(os.path.join(OBS_DATA_PATH, obs_file_name)) as obs:
        obs_metadata.append(
            {
                "station": obs.attrs["station"],
                "latitude": obs.attrs["lat"],
                "longitude": obs.attrs["lon"],
            }
        )

In [None]:
# The domain filter does not depend on the forecast time, so select the
# in-domain stations once instead of re-checking them at every time step.
stations_in_domain = [
    obs_meta
    for obs_meta in obs_metadata
    if not (
        obs_meta["longitude"] < lonW
        or obs_meta["longitude"] > lonE
        or obs_meta["latitude"] > latN
        or obs_meta["latitude"] < latS
    )
]

n_times = len(ds.time)
concat_list = []
for t in range(n_times):
    ds_sel = ds.isel(time=t)
    point_list = []
    for obs_meta in stations_in_domain:
        station = obs_meta["station"]
        latitude = obs_meta["latitude"]
        longitude = obs_meta["longitude"]
        ens_point = create_point_neighborhood(ds_sel, longitude, latitude, NEIGHBORHOOD)
        # A string return ("edge case") means the neighbourhood was incomplete.
        if isinstance(ens_point, str):
            continue
        ens_point = ens_point.expand_dims("station")
        ens_point = ens_point.assign_coords({"station": [station]})
        point_list.append(ens_point)

    if not point_list:
        raise ValueError(f"No station with a complete neighbourhood at time index {t}")
    concat_list.append(xr.concat(point_list, dim="station"))
    # Report the fraction completed *after* this step (previously off by one).
    print(f"{(t + 1) / n_times:.1%} complete")
ensemble_point_fcst = xr.concat(concat_list, dim="time")