In [10]:
import logging
import time

import dask
from distributed import wait
import xarray_multiscale
from ome_zarr.io import parse_url
from ome_zarr.writer import write_multiscale
from ome_zarr.format import CurrentFormat
import zarr

from aind_data_transfer.transformations import ome_zarr
from aind_data_transfer.util import chunk_utils

import zarr_io

In [None]:
# Install the hdf5 plugin locally; after that there should be no issues.
# Push a PR to the repo with this change.
# -> ujson
# -> hdf5plugin
# -> kerchunk

# Add `from utils import ensure_shape_5d, expand_chunks, guess_chunks` to transformations/ome_zarr.py

In [14]:
# NOTE: the dask.optimize option is known; deliberately skipped here.
# NOTE: the dask.config option is known; deliberately skipped here.

import numpy as np
import dask.array as da

# Timestamped log lines (minute resolution); this notebook logs at INFO level.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M",
    format="%(asctime)s %(message)s",
)
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)

# Primary input, loaded through the project-local zarr_io helper
# (arguments are presumably bucket name and object path -- confirm).
image = zarr_io.open_zarr_gcs('sofima-test-bucket', 'output_level_0_debug.zarr')
class SyncAdapter:
  """Makes it possible to use a TensorStore object as a numpy array.

  Indexing the adapter returns an in-memory ``numpy.ndarray`` built from the
  wrapped object's selection. Any attribute that is not defined on the adapter
  itself is forwarded to the wrapped object.

  Args:
    tstore: the wrapped array-like object. It must support ``[]`` indexing and
      expose ``shape`` and ``ndim``.
  """

  def __init__(self, tstore):
    self.tstore = tstore

  def __getitem__(self, ind):
    """Returns the selection ``ind`` of the wrapped object as a numpy array."""
    return np.array(self.tstore[ind])

  def __getattr__(self, attr):
    """Forwards attributes not found on the adapter to the wrapped object.

    Raises:
      AttributeError: if ``tstore`` has not been set on this instance yet, or
        if the wrapped object does not have ``attr``.
    """
    # __getattr__ only runs after normal lookup fails. copy/pickle probe a
    # freshly created, still-empty instance (e.g. for ``__setstate__``) before
    # ``tstore`` is restored; without this guard ``self.tstore`` would re-enter
    # __getattr__ endlessly and surface as RecursionError, not AttributeError.
    if attr == "tstore":
      raise AttributeError(attr)
    return getattr(self.tstore, attr)

  @property
  def shape(self):
    """Shape of the wrapped object."""
    return self.tstore.shape

  @property
  def ndim(self):
    """Number of dimensions of the wrapped object."""
    return self.tstore.ndim

  @property
  def dtype(self):
    """Element type reported to consumers; fixed to ``"uint16"``."""
    # NOTE(review): hard-coded rather than read from the wrapped object --
    # confirm every store passed in really holds uint16 data.
    return "uint16"

# Output Paths
# The OME-Zarr group `output_name` is created under the store rooted at `output_path`.
output_path = 'gs://sofima-test-bucket/'
output_name = 'fused.zarr'

# Other Input Parameters
scale_factor = 2  # per-level downsampling factor applied to each of the last three axes
voxel_sizes = (0.176, 0.298, 0.298)  # NOTE(review): axis order and units not visible here -- confirm
n_lvls = 5  # number of pyramid levels kept (level 0 is the full-resolution array)

# Actual Processing
# Wrap the store so dask can index it like a numpy array.
dask_image = dask.array.from_array(SyncAdapter(image))

# This is the optimized chunksize, used below as the on-disk chunking.
chunks = chunk_utils.expand_chunks(
    chunks=dask_image.chunksize,
    data_shape=dask_image.shape,
    target_size=64,  # Same as Cameron's code
    itemsize=dask_image.itemsize,
)
chunks = chunk_utils.ensure_shape_5d(chunks)

# Leave the two leading axes untouched; downsample only the last three.
scale_axis = (1, 1, scale_factor, scale_factor, scale_factor)
pyramid = xarray_multiscale.multiscale(
    dask_image,
    xarray_multiscale.reducers.windowed_mean,
    scale_axis,  # scale factors
    preserve_dtype=True,
    chunks="auto",  # can also try "preserve", which is the default
)[:n_lvls]

# multiscale() yields xarray DataArrays; unwrap the underlying dask arrays.
# Everything downstream must use these, not the DataArray wrappers.
pyramid_data = [arr.data for arr in pyramid]
print(f'{pyramid_data=}')

axes_5d = ome_zarr._get_axes_5d()
transforms, chunk_opts = ome_zarr._compute_scales(
    len(pyramid_data),
    (scale_factor,) * 3,
    voxel_sizes,
    chunks,  # Can optimize, or simply use dask default chunking.
    pyramid_data[0].shape,  # origin optional-- especially for single fused image
)

loader = CurrentFormat()
store = loader.init_store(output_path, mode='w')

root_group = zarr.group(store=store)
group = root_group.create_group(output_name, overwrite=True)

# Actual Jobs
LOGGER.info("Starting write...")
t0 = time.time()
# Pass the dask arrays (pyramid_data), not the DataArrays: write_multiscale
# only takes its lazy compute=False path for dask arrays. Anything else is
# written eagerly, which would load each level into memory and leave `jobs` empty.
jobs = write_multiscale(
    pyramid_data,
    group=group,
    fmt=CurrentFormat(),
    axes=axes_5d,
    coordinate_transformations=transforms,
    storage_options=chunk_opts,
    name=None,
    compute=False,
)
if jobs:
    LOGGER.info("Computing dask arrays...")
    arrs = dask.persist(*jobs)
    # NOTE(review): distributed.wait needs an active distributed Client; none is
    # created in the visible cells -- confirm one exists before running this.
    wait(arrs)
write_time = time.time() - t0

LOGGER.info(
    f"Finished writing tile {output_name}.\n"
    f"Took {write_time}s."
)

pyramid_data=[dask.array<array, shape=(1, 1, 14172, 3468, 2304), dtype=uint16, chunksize=(1, 1, 406, 406, 406), chunktype=numpy.ndarray>, dask.array<rechunk-merge, shape=(1, 1, 7086, 1734, 1152), dtype=uint16, chunksize=(1, 1, 406, 406, 406), chunktype=numpy.ndarray>, dask.array<rechunk-merge, shape=(1, 1, 3543, 867, 576), dtype=uint16, chunksize=(1, 1, 306, 306, 306), chunktype=numpy.ndarray>, dask.array<rechunk-merge, shape=(1, 1, 1771, 433, 288), dtype=uint16, chunksize=(1, 1, 510, 433, 288), chunktype=numpy.ndarray>, dask.array<rechunk-merge, shape=(1, 1, 885, 216, 144), dtype=uint16, chunksize=(1, 1, 885, 216, 144), chunktype=numpy.ndarray>]


2023-06-27 21:33 Starting write...


In [None]:
# If necessary, a custom reader could be implemented (probably not needed).

# Minimal Reader
# class ZarrReader(io_utils.DataReader):
#   def __init__(self):
#     pass

#   def as_dask_array(self, chunks: Any = None) -> Array:
#     return super().as_dask_array(chunks)
  
#   def get_shape(self):
#     pass

#   def get_chunks(self):
#     pass

#   def get_itemsize(self):
#     pass