In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import os

from functools import partial


# Pipeline directories: raw GraphCast forecasts, intermediate per-valid-time
# files indexed by lead time, and the final merged output.
GRAPH_DATA_PATH, GRAPH_LT_DATA_PATH, GRAPH_FINAL_DATA_PATH = (
    "../data/raw/graphcast/",
    "../data/processed/graphcast_lt/",
    "../data/processed/graphcast/",
)

In [2]:
# Sorted raw forecast files; the fourth "_"-separated field of each name is
# its initialisation date (YYYYMMDD), parsed into a DatetimeIndex.
file_list = sorted(
    name for name in os.listdir(GRAPH_DATA_PATH) if name.endswith(".nc")
)
file_list_dates = pd.to_datetime(
    [name.split("_")[3] for name in file_list], format="%Y%m%d"
)

In [None]:
# Get all valid times present in any raw file. The files are stacked along a
# new "init_time" dimension with an outer join so that differing "time"
# coordinates are unioned; the dataset is closed immediately afterwards so
# no file handles stay open for the rest of the notebook.
with xr.open_mfdataset(
    os.path.join(GRAPH_DATA_PATH, "*.nc"),
    concat_dim="init_time",
    combine="nested",
    join="outer",
) as ds:
    time_list = ds.time.values

In [None]:
# For every valid time, combine the (up to) three most recent initialisations
# strictly before it into one file whose "lead_time" axis is
# valid time minus initialisation time.
os.makedirs(GRAPH_LT_DATA_PATH, exist_ok=True)
for time in time_list:
    time = pd.to_datetime(time)
    # file_list_dates comes from a sorted file list, so [-3:] keeps the three
    # latest initialisations that precede `time`.
    init_dates = file_list_dates[file_list_dates < time][-3:]
    if len(init_dates) == 0:
        print(f"No data for {time}")
        continue
    previous_dates = init_dates.strftime("%Y%m%d")
    ncs = [
        os.path.join(GRAPH_DATA_PATH, "GRAP_v100_GFS_" + date + "_00.nc")
        for date in previous_dates
    ]
    out_path = os.path.join(
        GRAPH_LT_DATA_PATH, f"GRAP_v100_GFS_{time.strftime('%Y%m%d_%H')}_00.nc"
    )
    # Outer join: an older initialisation may not reach `time`; its slice is
    # then filled with NaN instead of failing the selection. The context
    # manager releases the source files once this valid time is written.
    with xr.open_mfdataset(
        ncs, concat_dim="init_time", combine="nested", join="outer"
    ) as ds_init:
        ds = ds_init.assign_coords(init_time=init_dates).sel(time=time)
        lead_time = ds.time - ds.init_time
        ds = ds.assign_coords(init_time=lead_time.values).rename(
            {"init_time": "lead_time"}
        )
        ds.to_netcdf(out_path)
    print(f"Saved data for {time}")

In [None]:
# Paths of every per-valid-time lead-time file, in sorted filename order.
file_list = sorted(
    name for name in os.listdir(GRAPH_LT_DATA_PATH) if name.endswith(".nc")
)
ncs = [GRAPH_LT_DATA_PATH + "/" + name for name in file_list]

In [None]:
# Template lead-time axis (6 h to 48 h in 6 h steps) used as the broadcast
# target for each per-valid-time file. The NaN payload is sized from the
# range itself so changing start/end/freq cannot cause a length mismatch.
lead_times = pd.timedelta_range(start="6h", end="2d", freq="6h")
broadcast_array = xr.DataArray(
    np.full(len(lead_times), np.nan),
    dims=["lead_time"],
    coords={"lead_time": lead_times.values},
)


def broadcast(x, broadcast_array):
    """Return ``x`` broadcast against ``broadcast_array``.

    Thin wrapper around ``x.broadcast_like`` so it can be bound with
    ``functools.partial`` and used as an ``open_mfdataset`` preprocess hook.

    Args:
        x: xarray object to broadcast (one opened file).
        broadcast_array: template object whose coordinates ``x`` is
            broadcast against.

    Returns:
        The result of ``x.broadcast_like(broadcast_array)``.
    """
    return x.broadcast_like(broadcast_array)


# Bind the template so `broadcast` matches open_mfdataset's one-argument
# preprocess signature.
partial_func = partial(broadcast, broadcast_array=broadcast_array)

os.makedirs(GRAPH_FINAL_DATA_PATH, exist_ok=True)
# Stack every per-valid-time file along "time" (outer join, since files carry
# different lead-time sets), drop lead times with no data in any variable,
# write the merged file, then release the source files.
with xr.open_mfdataset(
    ncs,
    combine="nested",
    concat_dim="time",
    preprocess=partial_func,
    join="outer",
) as ds_lt:
    ds = ds_lt.dropna("lead_time", how="all")
    ds.to_netcdf(os.path.join(GRAPH_FINAL_DATA_PATH, "graphcast.nc"))