In [None]:
from pathlib import Path
import shutil
import xarray as xr
import numpy as np
import dask.array as da

# Print the versions of xarray and its dependencies so the run can be reproduced.
xr.show_versions()

In [None]:
# Location of the Zarr store that the cells below delete, write and re-open.
zarr_file_path = Path("../generated/file.zarr")

In [None]:

chunk_size = 5
shape = (50, 32, 1000)

# Dataset of ones, split into Dask chunks of `chunk_size` along the first dimension.
ones_dataset = xr.Dataset({"data": xr.ones_like(xr.DataArray(np.empty(shape)))})
ones_dataset = ones_dataset.chunk({"dim_0": chunk_size})

# NOTE(review): chunk_indices is not referenced by any visible cell; kept in case
# another cell relies on it.
chunk_indices = np.arange(len(ones_dataset.chunks["dim_0"]))

# Label every position along dim_0 with the index of the chunk that holds it.
# Floor division always yields exactly sizes["dim_0"] labels, so a short trailing
# chunk is labelled too when the dimension length is not a multiple of chunk_size
# (repeating whole-chunk IDs would come up short and the assignment would fail).
chunk_ids = np.arange(ones_dataset.sizes["dim_0"]) // chunk_size
chunk_ids_dask_array = da.from_array(chunk_ids, chunks=(chunk_size,))
# Append the chunk IDs Dask array as a new variable to the existing dataset
ones_dataset["chunk_id"] = (("dim_0",), chunk_ids_dask_array)

ones_dataset

In [None]:
# Materialize the per-position chunk IDs as a NumPy array for inspection.
ones_dataset["chunk_id"].values

In [None]:
# Experiment: pre-compute the Dask-backed chunk IDs before mapping
# ones_dataset["chunk_id"] = ones_dataset["chunk_id"].compute()
# ones_dataset["chunk_id"]

In [None]:
# Zero-filled counterpart of the ones dataset: same shape, single "data" variable.
zeros_dataset = xr.Dataset({"data": xr.DataArray(np.zeros(shape))})
zeros_dataset

In [None]:
# Start from a clean slate: drop any store left behind by a previous run.
if zarr_file_path.exists():
    shutil.rmtree(zarr_file_path)

# Initialise the store from the zero-filled dataset, then open it again.
# NOTE(review): zeros_dataset is not Dask-chunked here, so compute=False
# presumably does not defer the data write -- confirm.
zeros_dataset.to_zarr(store=zarr_file_path, compute=False)
zarr_data = xr.open_zarr(store=zarr_file_path)

def process_chunk(chunk_dataset: xr.Dataset) -> xr.Dataset:
    """Report the ``dim_0`` index range covered by one block and pass it through.

    Intended as the function handed to ``map_blocks``: it reads the block's ID
    from its ``chunk_id`` variable and derives the slice of ``dim_0`` the block
    occupies in the full dataset. The region write to the Zarr store is
    currently disabled, so the only side effect is the printed index range.

    Parameters
    ----------
    chunk_dataset : xr.Dataset
        One block of the mapped dataset. Must contain a non-empty ``chunk_id``
        variable and a ``dim_0`` dimension.

    Returns
    -------
    xr.Dataset
        ``chunk_dataset`` itself, unmodified.
    """
    chunk_id = int(chunk_dataset["chunk_id"][0])
    # Only consumed by the disabled region write below.
    chunk_dataset_to_store = chunk_dataset.drop_vars("chunk_id")

    start_index = chunk_id * chunk_size
    # Use the block's actual length rather than assuming a full chunk, so a
    # short trailing block gets a region that matches the data it carries.
    end_index = start_index + chunk_dataset.sizes["dim_0"]

    print(start_index, end_index)

    # chunk_dataset_to_store.to_zarr(
    #     zarr_file_path, region={"dim_0": slice(start_index, end_index)}
    # )

    return chunk_dataset


# ones_dataset.map_blocks(process_chunk, template=ones_dataset).compute()
# Apply process_chunk block by block; the input dataset doubles as the template.
mapped = xr.map_blocks(process_chunk, ones_dataset, template=ones_dataset)
mapped["data"]

In [None]:
# The store already exists at this point (it was initialised from the zero-filled
# dataset), and to_zarr refuses to write into an existing store by default.
# mode="w" replaces the store with the mapped result so the cell survives a
# Restart & Run All.
mapped.to_zarr(zarr_file_path, mode="w")

In [None]:
# Re-open the store that was just written
# zarr_data = xr.open_zarr(zarr_file_path, chunks={"dim_0": chunk_size})
zarr_data = xr.open_zarr(zarr_file_path)

# Report every stored variable that does not match the in-memory reference
for var_name in zarr_data.variables:
    stored, reference = zarr_data[var_name], ones_dataset[var_name]
    try:
        xr.testing.assert_equal(stored, reference)
    except AssertionError:
        # Summarize the mismatch by comparing the totals of both arrays
        print(f"Differences in {var_name}:")
        expected = reference.sum().compute().item()
        actual = stored.sum().compute().item()
        print(f"{expected=}", f"{actual=}", sep="\n")

In [None]:
# Display the "data" variable as read back from the store.
zarr_data["data"]