In [1]:
import zarr
import numpy as np
import multiprocessing
import warnings
from typing import Tuple
import logging

In [2]:
# Suppress warnings whose message starts with the text below
# (presumably emitted by zarr when reading array metadata — TODO confirm source).
warnings.filterwarnings(
    "ignore", "Found an empty list of filters in the array metadata document."
)

# Let the "rechunker" logger emit DEBUG-level records; basicConfig() attaches
# a default handler to the root logger so those records are actually shown.
logger = logging.getLogger("rechunker")
logger.setLevel(logging.DEBUG)
logging.basicConfig()

In [3]:
def open_reference_filesystem(zarr_path, storage_options=None):
    """Open a reference file set as a zarr group.

    Parameters
    ----------
    zarr_path : str
        Path to the reference file; paths ending in ".parq" get extra
        storage options via ``add_parquet_options_if_needed``.
    storage_options : dict, optional
        Options forwarded to ``zarr.open_group``. Defaults to an empty dict
        created per call.

    Returns
    -------
    Whatever ``zarr.open_group`` returns for the reference URL.
    """
    if storage_options is None:
        # Build a fresh dict on every call rather than sharing one mutable
        # default object between calls and with downstream libraries.
        storage_options = {}
    storage_options = add_parquet_options_if_needed(zarr_path, storage_options)
    # NOTE(review): "reference:://" looks like fsspec's chained-URL form
    # ("reference::" followed by "//<path>") rather than "reference://" —
    # confirm this is intentional before changing it.
    return zarr.open_group(f"reference:://{zarr_path}", storage_options=storage_options)


def add_parquet_options_if_needed(zarr_path, storage_options):
    """Return storage options suitable for opening ``zarr_path``.

    For paths that do not end in ".parq" the options are returned unchanged
    (including ``None``). For ".parq" paths a new dict is returned that:

    * keeps every option supplied by the caller,
    * adds ``"remote_protocol": "file"`` only when the caller did not set one,
    * always sets ``"lazy": True`` (overriding any caller value).

    The caller's dict is never modified.

    Parameters
    ----------
    zarr_path : str
        Path of the reference file.
    storage_options : dict or None
        Caller-supplied options.

    Returns
    -------
    dict or None
        The (possibly extended) storage options.
    """
    if not zarr_path.endswith(".parq"):
        return storage_options

    # Merge instead of replacing: previously a missing "remote_protocol"
    # caused all other caller-supplied options to be discarded.
    merged = {"remote_protocol": "file"}
    merged.update(storage_options or {})
    merged["lazy"] = True
    return merged


def gen_array(zarr_out, template, chunks):
    """Require an array in ``zarr_out`` modelled on ``template``.

    The array takes its name and shape from ``template`` and its chunking
    from ``chunks``; dtype is fixed to "float32" and the fill value to NaN.

    Parameters
    ----------
    zarr_out
        Object providing ``require_array`` (a zarr group in this notebook).
    template
        Object providing ``name`` and ``shape``.
    chunks : tuple of int
        Chunk shape for the array.

    Returns
    -------
    The result of ``zarr_out.require_array``.
    """
    # NOTE(review): the meaning of the positional arguments "v2" and "/"
    # depends on the installed zarr version — confirm against its API docs.
    key_encoding = zarr.core.chunk_key_encodings.DefaultChunkKeyEncoding("v2", "/")
    array_spec = dict(
        name=template.name,
        chunk_key_encoding=key_encoding,
        shape=template.shape,
        chunks=chunks,
        dtype="float32",
        fill_value=np.nan,
    )
    return zarr_out.require_array(**array_spec)


def create_outfile(filename, zarr_format):
    """Open ``filename`` with ``zarr.open`` in mode "w" using ``zarr_format``.

    NOTE(review): in zarr, mode "w" is expected to replace an existing store
    at ``filename`` — confirm before pointing this at data you want to keep.

    Parameters
    ----------
    filename : str
        Location of the store.
    zarr_format : int
        Zarr format version passed through to ``zarr.open``.

    Returns
    -------
    The object returned by ``zarr.open``.
    """
    open_kwargs = {"mode": "w", "zarr_format": zarr_format}
    return zarr.open(filename, **open_kwargs)


def simple_remap(var_name, big_slice, source, destination, chunksize):
    """Copy ``big_slice`` of ``source`` into ``destination[var_name]`` piecewise.

    ``big_slice`` is printed, then split by ``iter_slices`` into regions of at
    most ``chunksize`` per axis; each region is read from ``source`` and
    written to the same region of the destination variable.

    Parameters
    ----------
    var_name : str
        Key of the target variable inside ``destination``.
    big_slice : sequence of slice
        Region to copy, one slice per axis.
    source
        Indexable source array (indexed directly, not by ``var_name``).
    destination
        Mapping-like container holding the target variable.
    chunksize : sequence of int
        Step per axis used to split ``big_slice``.
    """
    print(big_slice)
    target = destination[var_name]
    for region in iter_slices(big_slice, chunksize):
        # iter_slices yields tuples, so indexing with the tuple itself is the
        # same as unpacking it inside the subscript.
        target[region] = source[region]


def double_remap(
    var_name,
    slice_to_process,
    in_store,
    temp_store,
    out_store,
):
    """Copy a region input -> temp -> output, then reset the temp region to NaN.

    Stage 1 copies ``slice_to_process`` from the input variable to the temp
    variable in steps of the per-axis maximum of their chunk shapes. Stage 2
    copies the same region from the temp variable to the output variable in
    steps of the per-axis maximum of the temp and output chunk shapes.
    Finally the whole region of the temp variable is overwritten with NaN
    (presumably to release the scratch data — TODO confirm).

    Parameters
    ----------
    var_name : str
        Key of the variable in all three stores.
    slice_to_process : sequence of slice
        Region to transfer, one slice per axis.
    in_store, temp_store, out_store
        Containers indexable by ``var_name`` whose items expose ``chunks``
        and support slice-based reads/writes.
    """
    source, scratch, target = (
        store[var_name] for store in (in_store, temp_store, out_store)
    )

    # Both step shapes are computed up front, before any data is moved.
    stages = (
        (source, scratch, shape_union(source.chunks, scratch.chunks)),
        (scratch, target, shape_union(scratch.chunks, target.chunks)),
    )

    for origin, dest, step in stages:
        for region in iter_slices(slice_to_process, step):
            # region is a tuple of slices, usable directly as an index.
            dest[region] = origin[region]

    scratch[tuple(slice_to_process)] = np.nan


def iter_slices(to_go, chunks, current=()):
    """Yield tuples of slices that tile a multi-dimensional region.

    The region described by ``to_go`` is split along each axis into pieces of
    at most the corresponding ``chunks`` entry. Tuples are produced with the
    last axis varying fastest; the final piece on an axis is clipped to the
    slice's ``stop``.

    Parameters
    ----------
    to_go : sequence of slice
        Remaining axes to split. ``start`` may be ``None`` (treated as 0);
        ``stop`` must be an integer. ``step`` is ignored.
    chunks : sequence of int
        Step per axis, indexed by absolute axis position (i.e. including the
        axes already fixed in ``current``).
    current : sequence of slice, optional
        Slices already chosen for leading axes; used by the recursion.

    Yields
    ------
    tuple of slice
        ``current`` followed by one slice for each axis in ``to_go``.
    """
    # Immutable default plus a defensive copy: no state is shared between calls.
    prefix = tuple(current)
    step = chunks[len(prefix)]
    head = to_go[0]
    # slice(n) / slice(None, n) carry start=None, which range() rejects.
    start = 0 if head.start is None else head.start
    stop = head.stop
    for lower in range(start, stop, step):
        piece = prefix + (slice(lower, min(lower + step, stop)),)
        if len(to_go) == 1:
            yield piece
        else:
            yield from iter_slices(to_go[1:], chunks, current=piece)


def shape_union(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the element-wise maximum of two shapes of equal rank.

    Parameters
    ----------
    shape_a, shape_b : tuple of int
        Shapes to combine; both must have the same length.

    Returns
    -------
    tuple of int
        For every axis, the larger of the two extents.

    Raises
    ------
    ValueError
        If the two shapes have different lengths.
    """
    rank_a, rank_b = len(shape_a), len(shape_b)
    if rank_a == rank_b:
        return tuple(map(max, shape_a, shape_b))
    raise ValueError(
        f"Cannot compute union of shapes with different rank: "
        f"{rank_a} vs {rank_b}"
    )

In [4]:
# Key of the variable that is looked up in the input, temp and output stores.
var_name = "cc"

# Locations of the intermediate (temp) store and the final output store.
temp_file = "/fastdata/k20200/k202134/fasterpool/test2.zarr"
out_file = "/work/bm1235/k202134/IFS_2km.zarr"

In [5]:
# Open the parquet reference file as a zarr group.
in_store = open_reference_filesystem("/work/bm1235/k202134/2D_hourly_healpix2048.parq")

invar = in_store[var_name]
# Source chunk length along the last axis; reused by the driver cell as both
# the extent and the step of the last dimension.
ncells = invar.chunks[-1]
# Last expression of the cell: displays the source chunk shape.
invar.chunks

(1, 1, 50331648)

In [6]:
# Create the temp store (zarr format 3) and an array modelled on invar with
# chunks (1, 1, 4**10). create_outfile opens with mode="w", so re-running
# this cell is presumably destructive for an existing store — TODO confirm.
gen_array(create_outfile(temp_file, zarr_format=3), invar, (1, 1, 4**10))

<Array file:///fastdata/k20200/k202134/fasterpool/test2.zarr/cc shape=(10201, 25, 50331648) dtype=float32>

In [7]:
# Create the output store (zarr format 2) and an array modelled on invar with
# chunks (24, 5, 4**8). create_outfile opens with mode="w", so re-running
# this cell is presumably destructive for an existing store — TODO confirm.
gen_array(create_outfile(out_file, zarr_format=2), invar, (24, 5, 4**8))

<Array file:///work/bm1235/k202134/IFS_2km.zarr/cc shape=(10201, 25, 50331648) dtype=float32>

In [8]:
%%time

# Re-open both stores in "r+" mode for the copy
# (presumably read/write without recreating them — TODO confirm).
temp_store = zarr.open(temp_file, mode="r+")
out_store = zarr.open(
    out_file,
    mode="r+",
)
# Run double_remap on a pool of 64 worker processes, one task per region.
# Regions tile indices 0..96 of the first axis in steps of 24, 0..25 of the
# second axis in steps of 5, and the full ncells range of the last axis in a
# single step.
with multiprocessing.Pool(64) as pool:
    pool.starmap(
        double_remap,
        [
            (var_name, my_slice, in_store, temp_store, out_store)
            for my_slice in iter_slices(
                (slice(0, 96), slice(0, 25), slice(0, ncells)), (24, 5, ncells)
            )
        ],
    )

CPU times: user 72 ms, sys: 183 ms, total: 255 ms
Wall time: 5min 38s
