Skip to content

Commit

Permalink
Merge pull request #388 from joshua-gould/delayed
Browse files: browse the repository at this point in the history
Write metadata using a dask delayed task that depends on the array write(s)
  • Loading branch information
joshmoore committed Jul 11, 2024
2 parents b2dda99 + 6461a20 commit 26d6413
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
27 changes: 21 additions & 6 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import dask
import dask.array as da
import numpy as np
import zarr
from dask.graph_manipulation import bind

from .axes import Axes
from .format import CurrentFormat, Format
Expand Down Expand Up @@ -232,7 +234,6 @@ def write_multiscale(
msg = """The 'chunks' argument is deprecated and will be removed in version 0.5.
Please use the 'storage_options' argument instead."""
warnings.warn(msg, DeprecationWarning)

datasets: List[dict] = []
for path, data in enumerate(pyramid):
options = _resolve_storage_options(storage_options, path)
Expand Down Expand Up @@ -280,7 +281,15 @@ def write_multiscale(
for dataset, transform in zip(datasets, coordinate_transformations):
dataset["coordinateTransformations"] = transform

write_multiscales_metadata(group, datasets, fmt, axes, name, **metadata)
if len(dask_delayed) > 0 and not compute:
write_multiscales_metadata_delayed = dask.delayed(write_multiscales_metadata)
return dask_delayed + [
bind(write_multiscales_metadata_delayed, dask_delayed)(
group, datasets, fmt, axes, name, **metadata
)
]
else:
write_multiscales_metadata(group, datasets, fmt, axes, name, **metadata)

return dask_delayed

Expand Down Expand Up @@ -629,10 +638,16 @@ def _write_dask_image(
if coordinate_transformations is not None:
for dataset, transform in zip(datasets, coordinate_transformations):
dataset["coordinateTransformations"] = transform

write_multiscales_metadata(group, datasets, fmt, axes, name, **metadata)

return delayed
if not compute:
write_multiscales_metadata_delayed = dask.delayed(write_multiscales_metadata)
return delayed + [
bind(write_multiscales_metadata_delayed, delayed)(
group, datasets, fmt, axes, name, **metadata
)
]
else:
write_multiscales_metadata(group, datasets, fmt, axes, name, **metadata)
return delayed


def write_label_metadata(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_write_image_dask(self, read_from_zarr, compute):

assert not compute == len(dask_delayed_jobs)

if not len(dask_delayed_jobs):
if not compute:
# can be configured to use a Local or Slurm cluster
# before persisting the jobs
dask_delayed_jobs = persist(*dask_delayed_jobs)
Expand Down

0 comments on commit 26d6413

Please sign in to comment.