In [None]:
import os
import numpy as np
import pandas as pd
import xarray as xr

def extend_data(ds, dim, n, spacing=0.055):
    """
    Extend a dataset along one dimension by mirroring its trailing slices.

    The last ``n`` slices along ``dim`` are put into reversed order (i.e. reflected
    at the upper edge of the axis), get new coordinate values that continue the axis
    beyond its last value in steps of ``spacing`` and are appended to ``ds``.

    :param ds: xarray object providing a coordinate named ``dim``
    :param dim: name of the dimension to extend
    :param n: number of slices to append; ``0`` returns ``ds`` unchanged
    :param spacing: coordinate increment between appended slices
                    (defaults to 0.055, the value that used to be hard-coded)
    :return: ``ds`` concatenated with the ``n`` mirrored slices along ``dim``
    :raises ValueError: if ``n`` is negative or larger than the length of ``dim``
    """
    n = int(n)
    if n < 0:
        raise ValueError(f"Cannot extend dimension '{dim}' by a negative number of points (n={n}).")
    if n == 0:
        # nothing to append; avoids building an empty extension block
        return ds

    n_avail = len(ds[dim])
    if n > n_avail:
        raise ValueError(f"Cannot mirror {n} points along '{dim}': only {n_avail} points available.")

    # pick the last n slices ...
    ds_ext = ds.isel({dim: range(-n, 0)})
    # ... and reverse their order so that the data is mirrored at the upper edge
    ds_ext = ds_ext.reindex({dim: ds_ext[dim][::-1]})
    # new coordinate values continue the axis beyond its last value
    ds_ext[dim] = np.arange(1, n + 1) * spacing + ds[dim][-1].values

    return xr.concat([ds, ds_ext], dim=dim)

In [None]:
# ---- script parameters ----
# directory containing the input netCDF file
datadir = (
    "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/"
    "downscaling_benchmark_dataset/preprocessed_era5_crea6/"
    "t2m_mlevels/netcdf_data/all_files/train/"
)
# directory receiving the daily output files
outdir = (
    "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5/"
    "downscaling_destine_aq/"
)
# target number of grid points as (rlat, rlon); a falsy value disables the extension
nyx_tar = (144, 144)
# name of the input file inside datadir
fname = "downscaling_tier2_test.nc"

In [None]:
# Create the output directory if necessary (exist_ok=True: no error if it already exists)
os.makedirs(outdir, exist_ok=True)

# Read data: build the full path of the input file and open it as an xarray dataset
datafile = os.path.join(datadir, fname)
print(f"Read data file '{datafile}'...")
ds = xr.open_dataset(datafile)

In [None]:
# temporarily modify the time stamps (the original time coordinate is kept in times_save)
times_save = ds["time"]

# offset-trick: shifting all time stamps by -1h lets the 24 time steps from 1 to 24 UTC
# fall onto the same calendar day, so that the groupby below collects them in one group
ds["time"] = pd.to_datetime(ds["time"]) + pd.DateOffset(hours=-1)

# extend data if desired
if nyx_tar:
    nyx_in = (len(ds["rlat"]), len(ds["rlon"]))
    dnyx = np.array(nyx_tar) - np.array(nyx_in)

    # the data can only be padded, not cropped
    if (dnyx < 0).any():
        raise ValueError(f"Input grid {nyx_in} exceeds target grid {tuple(nyx_tar)}; cannot extend data.")

    # pad only where points are missing, i.e. a dimension which already matches
    # the target size (e.g. when this cell is run a second time) is left untouched
    if dnyx[0] > 0:
        ds = extend_data(ds, "rlat", int(dnyx[0]))
    if dnyx[1] > 0:
        ds = extend_data(ds, "rlon", int(dnyx[1]))
    

In [None]:
# Group data by day of year for convenient iterating over the dataset (one group per day).
# NOTE(review): "time.dayofyear" does not distinguish years, i.e. if the input file covers
# more than one year, identical days of different years presumably end up in the same group
# (and would then not pass the 24-time-steps check of the export loop) -- TODO confirm.
ds_grouped = ds.groupby("time.dayofyear")

In [None]:
from tqdm import tqdm

# Write one netCDF file per day, skipping days without the full 24 time steps.
# Split the input file name once, so that the date can be inserted right in front of the
# extension. Unlike fname.replace(".nc", ...), this touches the name exactly once and still
# yields distinct file names per day if fname has no ".nc" in it.
fname_root, fname_ext = os.path.splitext(fname)

for dd, ds_day in ds_grouped:
    # first (still shifted) time stamp of the group defines the date of the daily file
    date_now = pd.to_datetime(ds_day["time"][0].values)

    if len(ds_day["time"]) != 24:
        print(f"Skipping {date_now.strftime('%Y-%m-%d')} due to incomplete data.")
        continue

    # undo offsetting, i.e. restore the original time stamps before writing
    ds_day["time"] = pd.to_datetime(ds_day["time"]) - pd.DateOffset(hours=-1)

    date_str = date_now.strftime("%Y%m%d")
    fname_out = os.path.join(outdir, f"{fname_root}_{date_str}{fname_ext}")
    print(f"Save data for {date_now.strftime('%Y-%m-%d')} to '{fname_out}'...")

    ds_day.to_netcdf(fname_out)
    

