In [None]:
import numpy as np
import xarray as xr
import regionmask

In [None]:
# Open the ERA5 daily Zarr store and keep only the first 10 entries along `time`
era5_store = 'az://training/ERA5_daily/2000/'
azure_options = {'account_name': 'cmip6downscaling'}
ds = xr.open_zarr(era5_store, storage_options=azure_options).head(time=10)
ds

In [None]:
# Plot the predefined AR6 land regions provided by regionmask
ar6_land_regions = regionmask.defined_regions.ar6.land
ar6_land_regions.plot()

In [None]:
# Build a mask of the AR6 land regions for the coordinates of `ds`
land_regions = regionmask.defined_regions.ar6.land
mask = land_regions.mask(ds)

In [None]:
# Plot the AR6 land-region mask that was computed for `ds`
mask.plot()

In [None]:
# Split `tasmax` into one piece per mask value: group by the mask, then
# unstack the `stacked_lat_lon` dimension on every group.
pieces = {
    region_id: region_da.unstack('stacked_lat_lon')
    for region_id, region_da in ds['tasmax'].groupby(mask)
}

# Plot the first time step of a sample region
pieces[0].isel(time=0).plot()

In [None]:
# Next steps:
# - Given a dictionary of xarray datasets from the AR6 regions, merge them into a single dataset.
# - Things to consider:
#   - memory use
#   - overlapping bounds
#   - wrapped coordinates (for example, region `1` will not plot due to unsorted coordinates)

In [None]:
# Template dataset: same structure as the `tasmax` subset of `ds`, filled with NaN
tasmax_only = ds[['tasmax']]
expected_ds = xr.full_like(tasmax_only, fill_value=np.nan)

In [None]:
# Sort every regional piece by lat/lon, merge the pieces together, then
# left-join the merged result onto the NaN-filled template.
sorted_pieces = [piece.sortby(["lat", "lon"]) for piece in pieces.values()]
merged_pieces = xr.merge(sorted_pieces)
combined_ds = xr.merge([expected_ds, merged_pieces], join="left")

In [None]:
# The reconstructed tasmax must be identical to the original tasmax
# restricted to cells where `mask >= 0`.
masked_original = ds['tasmax'].where(mask >= 0)
reconstructed = combined_ds['tasmax']
xr.testing.assert_identical(masked_original, reconstructed)