In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import regionmask
from distributed import Client

In [None]:
# Start a dask.distributed client (with no arguments it creates a local cluster);
# leaving `client` as the final expression displays its repr in the notebook.
client = Client()
client

In [None]:
# ntime: number of time steps to load/process; chunk_size: x/y (lat/lon) chunk size
ntime, chunk_size = 366, 48

## Generate example data split by AR6 region

In [None]:
# Open the ERA5 daily store for 2000, keep the first `ntime` time steps, and
# label each grid cell with its ar6 land region.
ds = xr.open_zarr(
    'az://training/ERA5_daily/2000/',
    storage_options=dict(account_name='cmip6downscaling'),
)
ds = ds.head(time=ntime)
mask = regionmask.defined_regions.ar6.land.mask(ds)

In [None]:
# Split the dataset into one DataArray per ar6 region. Grouping by the 2-D
# mask stacks lat/lon into a single `stacked_lat_lon` dimension; unstacking
# turns each group back into a lat/lon grid.
pieces = {key: group.unstack('stacked_lat_lon') for key, group in ds['tasmax'].groupby(mask)}

# Save pieces to az://scratch/regions/{key}.zarr; the merge cells below read
# them back from there. Off by default because it overwrites the remote
# stores -- set to True to (re)generate them.
save_pieces = False
if save_pieces:
    for key, group in pieces.items():
        (
            group.to_dataset(name='tasmax')
            .chunk({'time': -1, 'lat': chunk_size, 'lon': chunk_size})
            .to_zarr(f'az://scratch/regions/{key}.zarr', mode='w')
        )

In [None]:
# Next:
# - given a dictionary of xarray datasets from the ar6 regions, merge into a single dataset
# - things to consider:
#   - memory use
#   - overlapping bounds
#   - wrapped coordinates (for example, region `1` will not plot due to unsorted coordinates)

## Generate a template for merged output

In [None]:
# Global 0.25-degree grid with latitude descending, over `ntime` days from 2000-01-01.
lon = np.arange(0, 360, 0.25)
lat = np.flip(np.arange(-90, 90.25, 0.25))
time = pd.date_range("2000-01-01", periods=ntime)

# All-NaN tasmax placeholder describing the merged output.
# np.full allocates the complete (time, lat, lon) array in memory.
template = xr.Dataset(
    {"tasmax": (('time', 'lat', 'lon'), np.full((len(time), len(lat), len(lon)), np.nan))},
    coords={
        "lat": lat,
        "lon": lon,
        "time": time,
    },
)
# `.chunk` returns a new object rather than modifying in place, so assign it;
# otherwise the displayed dataset is chunked while `template` is not.
template = template.chunk({'lon': chunk_size, 'lat': chunk_size, 'time': -1})
template

In [None]:
# Rasterise the ar6 land regions onto the template's lat/lon grid
# (this replaces the mask built from `ds` earlier).
ar6_land = regionmask.defined_regions.ar6.land
mask = ar6_land.mask(template)

## Merge using xarray.DataArray.map_blocks

In [None]:
def merge_block(mask):
    """
    Merge the saved ar6 region pieces that overlap one block of the region mask.

    Meant to be called through ``DataArray.map_blocks`` on the chunked mask.

    Parameters
    ----------
    mask : xarray.DataArray
        One lat/lon block of the ar6 region mask. NaN cells belong to no
        region; every other value is a region key used to name a saved piece
        at ``az://scratch/regions/{key}.zarr``.

    Returns
    -------
    xarray.Dataset
        Loaded ``tasmax`` dataset reindexed onto the block's coordinates, with
        latitude sorted descending. If the block contains no region, the
        (all-NaN) mask is broadcast along the global ``template`` time axis.
    """
    region_ids = pd.unique(mask.values.ravel())
    region_ids = region_ids[~np.isnan(region_ids)]

    if not region_ids.size:
        # Every mask value in this block is NaN: nothing to merge.
        blank = mask.expand_dims(time=template.coords['time'])
        return blank.to_dataset(name='tasmax').load()

    def _clip_region(ind):
        # Keep only this region's cells from its saved piece, with sorted coords.
        piece = xr.open_zarr(f'az://scratch/regions/{ind}.zarr')
        return piece.where(mask.isin(ind)).sortby(["lon", "lat"])

    combined = xr.merge(_clip_region(ind) for ind in region_ids)
    combined = combined.reindex_like(mask).sortby("lat", ascending=False)
    return combined.load()

In [None]:
# Chunk the mask and template with the same lat/lon chunk size so each mask
# block lines up with one template block, then build the merged dataset lazily
# with map_blocks and write it to zarr (to_zarr computes by default).
path = 'az://scratch/merged_regions.zarr'
mask = mask.chunk({'lon': chunk_size, 'lat': chunk_size})
template = template.chunk({'lon': chunk_size, 'lat': chunk_size, 'time': -1})
result = mask.map_blocks(merge_block, template=template)
result.to_zarr(path, mode='w')

In [None]:
# Plot example of the result: first time step of the map_blocks output read back from zarr
data = xr.open_zarr(path)
data['tasmax'].isel({'time': 0}).plot()

## Merge by looping over blocks manually

In [None]:
# Chunk the mask and template. `.chunk` returns a new object, so assign it;
# the template's dask chunks are expected to set the store's zarr chunking,
# so each block written by `merge_block_to_zarr` maps onto whole chunks.
mask = mask.chunk({'lon': chunk_size, 'lat': chunk_size})
template = template.chunk({'lon': chunk_size, 'lat': chunk_size, 'time': -1})
# Create a zarr group: with compute=False the dask-backed tasmax is not
# computed here; its data is filled in block by block below.
path = 'az://scratch/merged_regions_slow.zarr'
template.to_zarr(path, compute=False, mode="w")

In [None]:
def merge_block_to_zarr(mask, path, *, xslice, yslice):
    """
    Merge the ar6 region pieces overlapping one mask block and write the
    result into the matching region of an existing zarr store.

    Parameters
    ----------
    mask : xarray.DataArray
        One lat/lon block of the ar6 region mask; NaN cells belong to no region.
    path : str
        Target zarr store, expected to already exist (region writes only
        update it).
    xslice, yslice : slice
        Integer positions of this block along ``lon`` and ``lat`` in the store.

    Returns
    -------
    None
        Blocks that contain no region are skipped; nothing is written for them.
    """
    region_ids = pd.unique(mask.values.ravel())
    region_ids = region_ids[~np.isnan(region_ids)]
    if region_ids.size == 0:
        return

    clipped = (
        xr.open_zarr(f'az://scratch/regions/{ind}.zarr').where(mask.isin(ind)).sortby(["lon", "lat"])
        for ind in region_ids
    )
    merged = xr.merge(clipped).reindex_like(mask).sortby("lat", ascending=False)
    # Materialise before writing: passing compute=True to to_zarr() did not
    # behave as expected here.
    merged = merged.compute()

    all_times = slice(0, merged.sizes['time'])
    merged.to_zarr(path, region={'lat': yslice, 'lon': xslice, 'time': all_times})

In [None]:
# Iterate over chunks and merge the pieces within each chunk; the last chunk
# along each axis is truncated to the array edge.
n_lon = mask.sizes['lon']
n_lat = mask.sizes['lat']
for ilon in range(0, n_lon, chunk_size):
    xslice = slice(ilon, min(ilon + chunk_size, n_lon))
    for ilat in range(0, n_lat, chunk_size):
        yslice = slice(ilat, min(ilat + chunk_size, n_lat))
        block = mask.isel(lon=xslice, lat=yslice)
        merge_block_to_zarr(block, path, xslice=xslice, yslice=yslice)

In [None]:
# Plot example of the result: first time step of the manually written store
data = xr.open_zarr(path)
data['tasmax'].isel({'time': 0}).plot()