In [32]:
import shlex
import subprocess
import time
import warnings

import numpy as np
import xarray
import zarr
from zarr.errors import UnstableSpecificationWarning

In [26]:
# Silence two known zarr-python warnings that are noise for this notebook:
# - consolidated metadata is not part of the Zarr format specification;
# - data types whose on-disk specification zarr flags as unstable.
# `message` is a regular expression matched against the START of the warning
# text, so a plain prefix is sufficient. (A glob-style trailing "*" would be
# read as a regex quantifier on the preceding character, not a wildcard.)
warnings.filterwarnings(
    "ignore",
    message="Consolidated metadata is currently not part in the Zarr format",
    category=UserWarning,
)
warnings.filterwarnings(
    "ignore",
    message="The data type",
    category=UnstableSpecificationWarning,
)

In [52]:
# Earlier runs, kept for provenance:
# source = "../raster/stack.v3.zarr"
# source = "../raster/target.zarr"
# shard_size = 512
# chunk_size = 64

# source = "../raster/target.v2.zarr"
# target = "../raster/target.v3.zarr"
# shard_size = 256
# chunk_size = 32

# source = "../raster/target.v2.zarr"
# target = "../raster/target.v4.zarr"
# shard_size = 128
# chunk_size = 16


source = "../raster/target.v2.zarr"
target = "../raster/target.v5.zarr"
shard_size = 512
chunk_size = 16

# Each shard is tiled by whole inner chunks, so along lat/lon the shard edge
# must be an exact multiple of the chunk edge.
assert shard_size % chunk_size == 0, (shard_size, chunk_size)

# Leaf groups have paths of the form "/<dataset>/<variable>" (exactly two
# slashes); the cells below assume each holds one data array named after
# <variable> plus its coordinate arrays.
groups = [k for k in xarray.open_groups(source) if k.count('/') == 2]

### Create an empty target store with the new chunk and shard layout (using zarr directly)

In [54]:
for group_name in groups:
    print(group_name)
    # Group paths look like "/<dataset>/<variable>"; the data array inside the
    # group is named after <variable>.
    array_name = group_name.split("/")[2]
    group = zarr.open_group(source, mode='r', path=group_name)
    array = group[array_name]
    dim_names = array.metadata.dimension_names
    lon_ix = dim_names.index('lon')
    lat_ix = dim_names.index('lat')

    # Every non-spatial dimension stays whole (a single chunk and a single
    # shard along it); only lat/lon are tiled.
    new_chunks = list(array.shape)
    new_chunks[lon_ix] = chunk_size
    new_chunks[lat_ix] = chunk_size

    new_shards = list(array.shape)
    new_shards[lon_ix] = shard_size
    new_shards[lat_ix] = shard_size

    # mode='w' replaces whatever is already stored under this group path.
    target_group = zarr.open_group(target, mode='w', path=group_name)
    # Only the array metadata is created here; no data values are written for
    # the main array (the zarrs_reencode commands further down copy the data).
    target_group.create_array(
        name=array_name,
        shape=array.shape,
        dtype=array.dtype,
        chunks=new_chunks,
        shards=new_shards,
        compressors=array.compressors,
        fill_value=array.fill_value,
        dimension_names=dim_names,
        # Carry over the source array's user attributes; without this they
        # are silently dropped from the target.
        attributes=dict(array.attrs),
        overwrite=True,
    )
    # Copy coordinate variables in full, keeping their dimension names and
    # attributes so xarray can decode the target store.
    for dim in dim_names:
        if dim not in group:
            # A dimension may legitimately have no coordinate variable.
            continue
        dim_array = group[dim]
        target_group.create_array(
            name=dim,
            data=dim_array,
            dimension_names=dim_array.metadata.dimension_names,
            attributes=dict(dim_array.attrs),
        )

zarr.consolidate_metadata(target)

/aqueduct/depth_coastal
/aqueduct/depth_fluvial
/cdd_miranda/CDD_absolute_mean_change_from_15_to_20
/cdd_miranda/CDD_relative_mean_change_from_15_to_20
/dem/elevation
/dem/slope
/gem_earthquake/pga
/ghsl_buildings/built_surface
/ghsl_pop/population
/iris/ws
/isimip/drought_exposure
/isimip/drought_occurrence
/isimip/extreme_heat_exposure
/isimip/extreme_heat_occurrence
/jrc_flood/depth
/land_cover/lc
/nature/biodiversity_intactness
/nature/forest_landscape_integrity
/nature/organic_carbon
/storm/ws
/traveltime_to_healthcare/travel_time


<Group file://../raster/target.v5.zarr>

In [55]:
from tqdm.auto import tqdm
import itertools

In [None]:
# Print one zarrs_reencode command per leaf array. The commands copy the data
# from `source` into the empty, re-chunked arrays created above. (An earlier
# approach wrote region-by-region via xarray `to_zarr(..., region="auto")`.)
for group_name in groups:
    array_name = group_name.split("/")[-1]
    # Only the dimension order is needed; close the dataset straight away.
    with xarray.open_zarr(source, group=group_name) as ds:
        dims = ds[array_name].dims

    # Tile lat/lon only; every other dimension gets "0".
    # NOTE(review): presumably "0" means "keep full extent" in zarrs_reencode — confirm.
    chunk_sizes = [str(chunk_size) if d in ("lat", "lon") else "0" for d in dims]
    shard_sizes = [str(shard_size) if d in ("lat", "lon") else "0" for d in dims]

    # shlex.join quotes each argument, so the printed command stays correct
    # when pasted into a POSIX shell even if a path contains spaces or
    # shell metacharacters (plain paths print exactly as with " ".join).
    print(shlex.join([
        "zarrs_reencode",
        "-c", ",".join(chunk_sizes),
        "-s", ",".join(shard_sizes),
        "--concurrent-chunks", "32",
        "--direct-io",
        f"{source}{group_name}/{array_name}",
        f"{target}{group_name}/{array_name}",
    ]))


zarrs_reencode -c 16,16,0,0,0,0,0 -s 512,512,0,0,0,0,0 --concurrent-chunks 32 ../raster/target.v2.zarr/aqueduct/depth_coastal/depth_coastal ../raster/target.v5.zarr/aqueduct/depth_coastal/depth_coastal
zarrs_reencode -c 16,16,0,0,0,0,0 -s 512,512,0,0,0,0,0 --concurrent-chunks 32 ../raster/target.v2.zarr/aqueduct/depth_fluvial/depth_fluvial ../raster/target.v5.zarr/aqueduct/depth_fluvial/depth_fluvial
zarrs_reencode -c 16,16,0 -s 512,512,0 --concurrent-chunks 32 ../raster/target.v2.zarr/cdd_miranda/CDD_absolute_mean_change_from_15_to_20/CDD_absolute_mean_change_from_15_to_20 ../raster/target.v5.zarr/cdd_miranda/CDD_absolute_mean_change_from_15_to_20/CDD_absolute_mean_change_from_15_to_20
zarrs_reencode -c 16,16,0 -s 512,512,0 --concurrent-chunks 32 ../raster/target.v2.zarr/cdd_miranda/CDD_relative_mean_change_from_15_to_20/CDD_relative_mean_change_from_15_to_20 ../raster/target.v5.zarr/cdd_miranda/CDD_relative_mean_change_from_15_to_20/CDD_relative_mean_change_from_15_to_20
zarrs_reenco

In [None]:
# Stores to compare. Older layouts, currently disabled:
#   "../raster/stack.zarr"
#   "../raster/stack.v3.zarr"
stores = [
    f"../raster/{name}"
    for name in ("target.zarr", "target.v2.zarr", "target.v3.zarr")
]

# Rewrite each store's consolidated metadata up front so the reads below
# never see stale consolidated metadata.
for store in stores:
    zarr.consolidate_metadata(store)

In [None]:
def check(stores, group, subgroup, array):
    """Time opening one array in each store and reading one point from it.

    For every store in ``stores``, opens ``<group>/<subgroup>`` read-only,
    takes the member ``array``, and reads the value(s) at the midpoint of its
    first two dimensions. It then prints the shape, ``a.chunks``, the number of
    values read, and the elapsed seconds. The elapsed time covers opening and
    reading.

    Args:
        stores: iterable of stores accepted by ``zarr.open_group`` (paths here).
        group: top-level group name, e.g. ``'aqueduct'``.
        subgroup: subgroup name under ``group``.
        array: array name inside the subgroup.
    """
    for store in stores:
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump with system clock adjustments.
        start = time.perf_counter()
        a = zarr.open_group(store, mode='r', path=f"{group}/{subgroup}")[array]
        chunks = a.chunks
        shape = a.shape
        # Indexes only the first two dimensions, so for N-d arrays this reads
        # every value along the remaining dimensions at that point.
        data = a[shape[0] // 2, shape[1] // 2]
        ndata = data.size
        elapsed = time.perf_counter() - start

        print(
            group,
            subgroup,
            store,
            shape,
            chunks,
            ndata)
        print("   ", elapsed)

# Benchmark one representative array across all stores. Other candidates:
#   check(stores, 'storm', 'ws', 'ws')
#   check(stores, 'iris', 'ws', 'ws')
#   check(stores, 'aqueduct', 'depth_fluvial', 'depth_fluvial')
#   check(stores, 'traveltime_to_healthcare', 'travel_time', 'travel_time')
check(stores, group='aqueduct', subgroup='depth_coastal', array='depth_coastal')

In [24]:
# List the "/<group>/<subgroup>" paths of the source store.
# - Uses `source` explicitly, not a `store` variable leaked from an earlier loop.
# - Avoids rebinding `groups`, the path list the rechunk cells iterate.
# - mode='r' guarantees nothing is created if the path is wrong
#   (open_group's default mode is 'a').
for key, top_group in zarr.open_group(source, mode='r').groups():
    for subgroup in top_group.group_keys():
        print(f"/{key}/{subgroup}")

/aqueduct/depth_coastal
/aqueduct/depth_fluvial
/cdd_miranda/CDD_absolute_mean_change_from_15_to_20
/cdd_miranda/CDD_relative_mean_change_from_15_to_20
/dem/elevation
/dem/slope
/gem_earthquake/pga
/ghsl_buildings/built_surface
/ghsl_pop/population
/iris/ws
/isimip/drought_exposure
/isimip/drought_occurrence
/isimip/extreme_heat_exposure
/isimip/extreme_heat_occurrence
/jrc_flood/depth
/land_cover/lc
/nature/biodiversity_intactness
/nature/forest_landscape_integrity
/nature/organic_carbon
/storm/ws
/traveltime_to_healthcare/travel_time
