In [None]:
import dask
import dask.array as dsa
import numpy as np
import pandas as pd
import regionmask
import xarray as xr

In [None]:
# --- Run configuration ---
ntime = 366      # daily time steps to load/process (all of 2000, a leap year)
chunk_size = 48  # lat/lon chunk edge length, in grid cells

## Generate example ar6 regions

In [None]:
# Load the first `ntime` days of the 2000 ERA5 daily store and label each grid cell with
# its AR6 land-region id (NaN where a cell belongs to no region).
ds = xr.open_zarr(
    'az://training/ERA5_daily/2000/',
    storage_options=dict(account_name='cmip6downscaling'),
)
ds = ds.head(time=ntime)
mask = regionmask.defined_regions.ar6.land.mask(ds)

In [None]:
# split dataset into regions

# Split `tasmax` into one lazily-evaluated DataArray per AR6 land region.
# `groupby(mask)` stacks lat/lon into `stacked_lat_lon`; unstacking restores separate
# lat/lon dims for each region. Keys are the (float) mask values, e.g. 0.0, 1.0, ...
pieces = {
    key: group.unstack('stacked_lat_lon')
    for key, group in ds['tasmax'].groupby(mask)
}

# The merge step further down reads these pieces back from
# `az://scratch/regions/{key}.zarr`. Set SAVE_PIECES = True to (re)write them.
# Time is kept in a single chunk; lat/lon are chunked by `chunk_size`.
SAVE_PIECES = False
if SAVE_PIECES:
    for key, group in pieces.items():
        (
            group.to_dataset(name='tasmax')
            .chunk({'time': -1, 'lat': chunk_size, 'lon': chunk_size})
            .to_zarr(f'az://scratch/regions/{key}.zarr', mode='w')
        )

In [None]:
# Next:
# - given a dictionary of per-region xarray datasets (one per AR6 land region), merge them into a single dataset
# - things to consider:
#   - memory use
#   - overlapping bounds
#   - wrapped coordinates (for example, region `1` will not plot because its coordinates are unsorted)

## Generate a template for merged output

In [None]:
# Output grid:
# - 0.25-degree global lat/lon, with lat descending from 90 to -90 and lon from 0 to 359.75
# - `ntime` daily time steps starting 2000-01-01
lon = np.arange(0, 360, 0.25)
lat = np.flip(np.arange(-90, 90.25, 0.25))
time = pd.date_range("2000-01-01", periods=ntime)

# Build the all-NaN `tasmax` template lazily with dask.
# - The full (time, lat, lon) float32 array (~1.5 GB for 366 days) is never allocated in memory.
# - Chunking is the whole time axis by `chunk_size` x `chunk_size` in lat/lon.
# - The chunks are set at creation: `Dataset.chunk` returns a new object, so calling it
#   without reassigning (as before) left `template` numpy-backed and unchunked.
# - Because the data is dask-backed, a later `to_zarr(..., compute=False)` can defer
#   writing the NaN data instead of writing it eagerly.
template = xr.Dataset(
    {
        "tasmax": (
            ('time', 'lat', 'lon'),
            dsa.full(
                (len(time), len(lat), len(lon)),
                fill_value=np.nan,
                dtype=np.single,
                chunks=(-1, chunk_size, chunk_size),
            ),
        )
    },
    coords={
        "lat": lat,
        "lon": lon,
        "time": time,
    },
)
template

In [None]:
# Create a mask with ar6 regions, replacing nan with 46 for 8-bit representation
# Rasterise the AR6 land regions onto the template grid.
# - Cells outside every region are NaN; replace them with the sentinel 46, one past the
#   largest land-region id (45), so the mask fits in an 8-bit integer.
# - Then tile it into `chunk_size` x `chunk_size` lat/lon chunks.
mask = (
    regionmask.defined_regions.ar6.land.mask(template)
    .fillna(46)
    .astype(np.byte)
    .chunk({'lon': chunk_size, 'lat': chunk_size})
)

## Merge using manual blocks

In [None]:
path = 'az://scratch/merged_regions_slow.zarr'
# Initialise the output store from the template.
# - mode="w" replaces any existing store at `path`.
# - The data is filled in block by block further down via `region=` writes (mode "r+").
# - The zarr chunks are set explicitly to (whole time axis, chunk_size, chunk_size). They
#   follow `chunk_size` rather than a literal, so they stay aligned with the merge blocks,
#   whose lat/lon extents are multiples of `chunk_size`.
template.to_zarr(
    path,
    compute=False,
    mode="w",
    encoding={"tasmax": {"chunks": [-1, chunk_size, chunk_size]}},
)

In [None]:
@dask.delayed
def merge_block_to_zarr(mask, path, *, xslice, yslice):
    """
    Merge the AR6 region pieces present in one spatial block and write them to zarr.

    Parameters
    ----------
    mask : xarray.DataArray
        Integer region mask for this block. Values <= 45 are treated as region ids;
        anything larger (the notebook fills "no region" with 46) is ignored.
    path : str
        Target zarr store, already initialised with the full output layout.
    xslice, yslice : slice
        Integer positions of this block along ``lon`` and ``lat`` in the target store.

    Returns
    -------
    The value returned by ``Dataset.to_zarr`` when the block contains at least one
    region id; otherwise ``None``, and nothing is written.
    """
    region_ids = pd.unique(mask.values.ravel())
    region_ids = region_ids[region_ids <= 45]
    if region_ids.size == 0:
        # No region cells in this block: nothing to write.
        return None

    per_region = []
    for region_id in region_ids:
        # The pieces were saved under float mask keys, hence the ".0" in the store name.
        piece = xr.open_zarr(f'az://scratch/regions/{region_id}.0.zarr')
        # Keep only this region's cells inside the block; sort because some pieces
        # have unsorted coordinates.
        piece = piece.where(mask.isin(region_id), drop=True)
        per_region.append(piece.sortby(["lon", "lat"]))

    # Put the merged pieces back onto the block's own lat/lon grid, lat descending.
    merged = xr.merge(per_region).reindex_like(mask).sortby("lat", ascending=False)
    merged = merged.compute()

    target_region = {
        'lat': yslice,
        'lon': xslice,
        'time': slice(0, merged.sizes['time']),
    }
    return merged.to_zarr(path, region=target_region, mode="r+")

In [None]:
%%time
# Build one delayed merge task per spatial block.
# - Each block spans `n_chunk_per_block` chunks in lat and lon.
# - Every block therefore starts on a multiple of `chunk_size` and ends either on one
#   or at the array edge.
# - Assuming the store's lat/lon chunks equal `chunk_size`, each region write covers
#   whole zarr chunks.
n_chunk_per_block = 3
block_size = chunk_size * n_chunk_per_block
total = []
for ilon in range(0, mask.sizes['lon'], block_size):
    # Clamp the final block along each axis to the array edge.
    xslice = slice(ilon, min(ilon + block_size, mask.sizes['lon']))
    for ilat in range(0, mask.sizes['lat'], block_size):
        yslice = slice(ilat, min(ilat + block_size, mask.sizes['lat']))
        total.append(
            merge_block_to_zarr(
                mask.isel(lon=xslice, lat=yslice), path, xslice=xslice, yslice=yslice
            )
        )

In [None]:
%%time
# Execute every block task sequentially in the calling thread, presumably to bound
# memory use and avoid concurrent writes to the store (TODO confirm intent).
result = dask.compute(
    *total,
    scheduler='single-threaded',
)

In [None]:
# Plot example of the result
# Re-open the merged store and plot the first time step as a quick visual check.
data = xr.open_zarr(path)
first_step = data['tasmax'].isel(time=0)
first_step.plot()
data