diff --git a/CHANGELOG.md b/CHANGELOG.md index 5905abba0..c4810eeaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,14 @@ **Note**: Numbers like (\#123) point to closed Pull Requests on the fractal-tasks-core repository. -# Unreleased +# 0.13.0 * Tasks: * New task and helper functions: - * Introduce `import_ome_zarr` task (\#557). + * Introduce `import_ome_zarr` task (\#557, \#579). * Introduce `get_single_image_ROI` and `get_image_grid_ROIs` (\#557). * Introduce `detect_ome_ngff_type` (\#557). - * Make `maximum_intensity_projection` task not depend on ROI tables (\#557). + * Introduce `update_omero_channels` (\#579). + * Make `maximum_intensity_projection` independent from ROI tables (\#557). * Make Cellpose task work when `input_ROI_table` is empty (\#566). * Fix bug of missing attributes in ROI-table Zarr group (\#573). * Dependencies: diff --git a/fractal_tasks_core/__FRACTAL_MANIFEST__.json b/fractal_tasks_core/__FRACTAL_MANIFEST__.json index 7877cf068..6e14d1a40 100644 --- a/fractal_tasks_core/__FRACTAL_MANIFEST__.json +++ b/fractal_tasks_core/__FRACTAL_MANIFEST__.json @@ -1245,6 +1245,12 @@ "type": "integer", "description": "X shape of the ROI grid in `grid_ROI_table`." }, + "update_omero_metadata": { + "title": "Update Omero Metadata", + "default": true, + "type": "boolean", + "description": "Whether to update Omero-channels metadata, to make them Fractal-compatible." + }, "overwrite": { "title": "Overwrite", "default": false, diff --git a/fractal_tasks_core/lib_channels.py b/fractal_tasks_core/lib_channels.py index 811830f8c..4b79487a7 100644 --- a/fractal_tasks_core/lib_channels.py +++ b/fractal_tasks_core/lib_channels.py @@ -12,6 +12,8 @@ Helper functions to address channels via OME-NGFF/OMERO metadata. 
""" import logging +from copy import deepcopy +from typing import Any from typing import Optional from typing import Union @@ -339,3 +341,135 @@ def define_omero_channels( ] return new_channels_dictionaries + + +def _get_new_unique_value( + value: str, + existing_values: list[str], +) -> str: + """ + Produce a string value that is not present in a given list + + Append `_1`, `_2`, ... to a given string, if needed, until finding a value + which is not already present in `existing_values`. + + Args: + value: The first guess for the new value + existing_values: The list of existing values + + Returns: + A string value which is not present in `existing_values` + """ + counter = 1 + new_value = value + while new_value in existing_values: + new_value = f"{value}-{counter}" + counter += 1 + return new_value + + +def update_omero_channels( + old_channels: list[dict[str, Any]] +) -> list[dict[str, Any]]: + """ + Make an existing list of Omero channels Fractal-compatible + + The output channels all have keys `label`, `wavelength_id` and `color`; + the `wavelength_id` values are unique across the channel list. + + See https://ngff.openmicroscopy.org/0.4/index.html#omero-md for the + definition of NGFF Omero metadata. 
+ + Args: + old_channels: Existing list of Omero-channel dictionaries + + Returns: + New list of Fractal-compatible Omero-channel dictionaries + """ + new_channels = deepcopy(old_channels) + existing_wavelength_ids: list[str] = [] + handled_channels = [] + + default_colors = ["00FFFF", "FF00FF", "FFFF00"] + + def _get_next_color() -> str: + try: + return default_colors.pop(0) + except IndexError: + return "808080" + + # Channels that contain the key "wavelength_id" + for ind, old_channel in enumerate(old_channels): + if "wavelength_id" in old_channel.keys(): + handled_channels.append(ind) + existing_wavelength_ids.append(old_channel["wavelength_id"]) + new_channel = old_channel.copy() + try: + label = old_channel["label"] + except KeyError: + label = str(ind + 1) + new_channel["label"] = label + if "color" not in old_channel: + new_channel["color"] = _get_next_color() + new_channels[ind] = new_channel + + # Channels that contain the key "label" but do not contain the key + # "wavelength_id" + for ind, old_channel in enumerate(old_channels): + if ind in handled_channels: + continue + if "label" not in old_channel.keys(): + continue + handled_channels.append(ind) + label = old_channel["label"] + wavelength_id = _get_new_unique_value( + label, + existing_wavelength_ids, + ) + existing_wavelength_ids.append(wavelength_id) + new_channel = old_channel.copy() + new_channel["wavelength_id"] = wavelength_id + if "color" not in old_channel: + new_channel["color"] = _get_next_color() + new_channels[ind] = new_channel + + # Channels that do not contain the key "label" nor the key "wavelength_id" + # NOTE: these channels must be treated last, as they have lower priority + # w.r.t. 
existing "wavelength_id" or "label" values + for ind, old_channel in enumerate(old_channels): + if ind in handled_channels: + continue + label = str(ind + 1) + wavelength_id = _get_new_unique_value( + label, + existing_wavelength_ids, + ) + existing_wavelength_ids.append(wavelength_id) + new_channel = old_channel.copy() + new_channel["label"] = label + new_channel["wavelength_id"] = wavelength_id + if "color" not in old_channel: + new_channel["color"] = _get_next_color() + new_channels[ind] = new_channel + + # Log old/new values of label, wavelength_id and color + for ind, old_channel in enumerate(old_channels): + label = old_channel.get("label") + color = old_channel.get("color") + wavelength_id = old_channel.get("wavelength_id") + old_attributes = ( + f"Old attributes: {label=}, {wavelength_id=}, {color=}" + ) + label = new_channels[ind]["label"] + wavelength_id = new_channels[ind]["wavelength_id"] + color = new_channels[ind]["color"] + new_attributes = ( + f"New attributes: {label=}, {wavelength_id=}, {color=}" + ) + logging.info( + "Omero channel update:\n" + f" {old_attributes}\n" + f" {new_attributes}" + ) + + return new_channels diff --git a/fractal_tasks_core/tasks/import_ome_zarr.py b/fractal_tasks_core/tasks/import_ome_zarr.py index 73ee05121..033ffffd3 100644 --- a/fractal_tasks_core/tasks/import_ome_zarr.py +++ b/fractal_tasks_core/tasks/import_ome_zarr.py @@ -21,6 +21,7 @@ import zarr from pydantic.decorator import validate_arguments +from fractal_tasks_core.lib_channels import update_omero_channels from fractal_tasks_core.lib_ngff import detect_ome_ngff_type from fractal_tasks_core.lib_ngff import NgffImageMeta from fractal_tasks_core.lib_regions_of_interest import get_image_grid_ROIs @@ -34,6 +35,8 @@ def _process_single_image( image_path: str, add_image_ROI_table: bool, add_grid_ROI_table: bool, + update_omero_metadata: bool, + *, grid_YX_shape: Optional[tuple[int, int]] = None, overwrite: bool = False, ) -> None: @@ -43,7 +46,8 @@ def 
_process_single_image( This task: 1. Validates OME-NGFF image metadata, via `NgffImageMeta`; - 2. Optionally generates and writes two ROI tables. + 2. Optionally generates and writes two ROI tables; + 3. Optionally updates OME-NGFF omero metadata. Args: image_path: Absolute path to the image Zarr group. @@ -51,6 +55,8 @@ def _process_single_image( (argument propagated from `import_ome_zarr`). add_grid_ROI_table: Whether to add a `grid_ROI_table` table (argument propagated from `import_ome_zarr`). + update_omero_metadata: Whether to update Omero-channels metadata + (argument propagated from `import_ome_zarr`). grid_YX_shape: YX shape of the ROI grid (it must be not `None`, if `add_grid_ROI_table=True`. """ @@ -100,6 +106,51 @@ def _process_single_image( logger=logger, ) + # Update Omero-channels metadata + if update_omero_metadata: + # Extract number of channels from zarr array + try: + channel_axis_index = image_meta.axes_names.index("c") + except ValueError: + logger.error(f"Existing axes: {image_meta.axes_names}") + msg = ( + "OME-Zarrs with no channel axis are not currently " + "supported in fractal-tasks-core. Upcoming flexibility " + "improvements are tracked in https://github.com/" + "fractal-analytics-platform/fractal-tasks-core/issues/150." 
+ ) + logger.error(msg) + raise NotImplementedError(msg) + logger.info(f"Existing axes: {image_meta.axes_names}") + logger.info(f"Channel-axis index: {channel_axis_index}") + num_channels_zarr = array.shape[channel_axis_index] + logger.info( + f"{num_channels_zarr} channel(s) found in Zarr array " + f"at {image_path}/{dataset_subpath}" + ) + # Update or create omero channels metadata + old_omero = image_group.attrs.get("omero", {}) + old_channels = old_omero.get("channels", []) + if len(old_channels) > 0: + logger.info( + f"{len(old_channels)} channel(s) found in NGFF omero metadata" + ) + if len(old_channels) != num_channels_zarr: + error_msg = ( + "Channels-number mismatch: Number of channels in the " + f"zarr array ({num_channels_zarr}) differs from number " + "of channels listed in NGFF omero metadata " + f"({len(old_channels)})." + ) + logging.error(error_msg) + raise ValueError(error_msg) + else: + old_channels = [{} for ind in range(num_channels_zarr)] + new_channels = update_omero_channels(old_channels) + new_omero = old_omero.copy() + new_omero["channels"] = new_channels + image_group.attrs.update(omero=new_omero) + @validate_arguments def import_ome_zarr( @@ -112,6 +163,7 @@ def import_ome_zarr( add_grid_ROI_table: bool = True, grid_y_shape: int = 2, grid_x_shape: int = 2, + update_omero_metadata: bool = True, overwrite: bool = False, ) -> dict[str, Any]: """ @@ -141,6 +193,8 @@ def import_ome_zarr( image, with the image split into a rectangular grid of ROIs. grid_y_shape: Y shape of the ROI grid in `grid_ROI_table`. grid_x_shape: X shape of the ROI grid in `grid_ROI_table`. + update_omero_metadata: Whether to update Omero-channels metadata, to + make them Fractal-compatible. overwrite: Whether new ROI tables (added when `add_image_ROI_table` and/or `add_grid_ROI_table` are `True`) can overwite existing ones. 
""" @@ -174,14 +228,15 @@ def import_ome_zarr( f"{zarr_path}/{well_path}/{image_path}", add_image_ROI_table, add_grid_ROI_table, - grid_YX_shape, + update_omero_metadata, + grid_YX_shape=grid_YX_shape, overwrite=overwrite, ) elif ngff_type == "well": zarrurls["well"].append(zarr_name) logger.warning( "Only OME-Zarr for plates are fully supported in Fractal; " - "e.g. the current one ({ngff_type=}) cannot be " + f"e.g. the current one ({ngff_type=}) cannot be " "processed via the `maximum_intensity_projection` task." ) for image in root_group.attrs["well"]["images"]: @@ -191,21 +246,23 @@ def import_ome_zarr( f"{zarr_path}/{image_path}", add_image_ROI_table, add_grid_ROI_table, - grid_YX_shape, + update_omero_metadata, + grid_YX_shape=grid_YX_shape, overwrite=overwrite, ) elif ngff_type == "image": zarrurls["image"].append(zarr_name) logger.warning( "Only OME-Zarr for plates are fully supported in Fractal; " - "e.g. the current one ({ngff_type=}) cannot be " + f"e.g. the current one ({ngff_type=}) cannot be " "processed via the `maximum_intensity_projection` task." ) _process_single_image( zarr_path, add_image_ROI_table, add_grid_ROI_table, - grid_YX_shape, + update_omero_metadata, + grid_YX_shape=grid_YX_shape, overwrite=overwrite, ) diff --git a/tests/_zenodo_ome_zarrs.py b/tests/_zenodo_ome_zarrs.py index 607f5f1c9..f0976879f 100644 --- a/tests/_zenodo_ome_zarrs.py +++ b/tests/_zenodo_ome_zarrs.py @@ -12,11 +12,13 @@ Zurich. 
""" import json +import logging import shutil from pathlib import Path from typing import Any import dask.array as da +import zarr from devtools import debug @@ -25,6 +27,7 @@ def prepare_3D_zarr( zenodo_zarr: list[str], zenodo_zarr_metadata: list[dict[str, Any]], remove_tables: bool = False, + remove_omero: bool = False, ): zenodo_zarr_3D, zenodo_zarr_2D = zenodo_zarr[:] metadata_3D, metadata_2D = zenodo_zarr_metadata[:] @@ -35,6 +38,16 @@ def prepare_3D_zarr( shutil.rmtree( str(Path(zarr_path) / Path(zenodo_zarr_3D).name / "B/03/0/tables") ) + logging.warning("Removing ROI tables attributes 3D Zenodo zarr") + if remove_omero: + image_group = zarr.open_group( + str(Path(zarr_path) / Path(zenodo_zarr_3D).name / "B/03/0"), + mode="r+", + ) + image_attrs = image_group.attrs.asdict() + image_attrs.pop("omero") + image_group.attrs.put(image_attrs) + logging.warning("Removing omero attributes from 3D Zenodo zarr") metadata = metadata_3D.copy() return metadata diff --git a/tests/tasks/test_import_ome_zarr.py b/tests/tasks/test_import_ome_zarr.py index a2589e2dc..06684e766 100644 --- a/tests/tasks/test_import_ome_zarr.py +++ b/tests/tasks/test_import_ome_zarr.py @@ -2,7 +2,9 @@ import zarr from devtools import debug +import fractal_tasks_core.tasks # noqa from .._zenodo_ome_zarrs import prepare_3D_zarr +from fractal_tasks_core.lib_input_models import Channel from fractal_tasks_core.tasks.copy_ome_zarr import copy_ome_zarr from fractal_tasks_core.tasks.import_ome_zarr import import_ome_zarr from fractal_tasks_core.tasks.maximum_intensity_projection import ( @@ -100,12 +102,19 @@ def test_import_ome_zarr_well(tmp_path, zenodo_zarr, zenodo_zarr_metadata): _check_ROI_tables(f"{root_path}/{zarr_name}/0") -def test_import_ome_zarr_image(tmp_path, zenodo_zarr, zenodo_zarr_metadata): +@pytest.mark.parametrize("reset_omero", [True, False]) +def test_import_ome_zarr_image( + tmp_path, zenodo_zarr, zenodo_zarr_metadata, reset_omero +): # Prepare an on-disk OME-Zarr at the plate 
level root_path = tmp_path prepare_3D_zarr( - root_path, zenodo_zarr, zenodo_zarr_metadata, remove_tables=True + root_path, + zenodo_zarr, + zenodo_zarr_metadata, + remove_tables=True, + remove_omero=reset_omero, ) zarr_name = "plate.zarr/B/03/0" @@ -129,15 +138,70 @@ def test_import_ome_zarr_image(tmp_path, zenodo_zarr, zenodo_zarr_metadata): # Check that table were created _check_ROI_tables(f"{root_path}/{zarr_name}") + # Check that omero attributes were filled correctly + g = zarr.open_group(str(root_path / zarr_name), mode="r") + debug(g.attrs["omero"]["channels"]) + if reset_omero: + EXPECTED_CHANNELS = [ + dict(label="1", wavelength_id="1", color="00FFFF") + ] + assert g.attrs["omero"]["channels"] == EXPECTED_CHANNELS + else: + EXPECTED_LABEL = "DAPI" + EXPECTED_WAVELENGTH_ID = "A01_C01" + assert g.attrs["omero"]["channels"][0]["label"] == EXPECTED_LABEL + assert ( + g.attrs["omero"]["channels"][0]["wavelength_id"] + == EXPECTED_WAVELENGTH_ID + ) + + +def test_import_ome_zarr_image_wrong_channels( + tmp_path, zenodo_zarr, zenodo_zarr_metadata +): + # Prepare an on-disk OME-Zarr at the plate level + root_path = tmp_path + prepare_3D_zarr( + root_path, + zenodo_zarr, + zenodo_zarr_metadata, + remove_tables=True, + remove_omero=True, + ) + zarr_name = "plate.zarr/B/03/0" + # Modify NGFF omero metadata, adding two channels (even if the Zarr array + # has only one) + g = zarr.open_group(str(root_path / zarr_name), mode="r+") + new_omero = dict( + channels=[ + dict(color="asd"), + dict(color="asd"), + ] + ) + g.attrs.update(omero=new_omero) + # Run import_ome_zarr and catch the error + with pytest.raises(ValueError) as e: + import_ome_zarr( + input_paths=[str(root_path)], + zarr_name=zarr_name, + output_path="null", + metadata={}, + ) + debug(e.value) + assert "Channels-number mismatch" in str(e.value) + @pytest.mark.skip -def test_import_ome_zarr_image_BIA(tmp_path): +def test_import_ome_zarr_image_BIA(tmp_path, monkeypatch): """ This test imports one of the BIA 
OME-Zarr listed in https://www.ebi.ac.uk/biostudies/bioimages/studies/S-BIAD843. It is currently marked as "skip", to avoid incurring into download-rate limits. + + Also note that any further processing of the imported Zarr will fail + because we don't support time data, see fractal-tasks-core issue #169. """ from ftplib import FTP @@ -193,3 +257,46 @@ def test_import_ome_zarr_image_BIA(tmp_path): image_ROI_table[:, "len_x_micrometer"].X[0, 0], EXPECTED_X_LENGTH, ) + + g = zarr.open(f"{root_path}/{zarr_name}", mode="r") + omero_channels = g.attrs["omero"]["channels"] + debug(omero_channels) + assert len(omero_channels) == 1 + omero_channel = omero_channels[0] + assert omero_channel["label"] == "Channel 0" + assert omero_channel["wavelength_id"] == "Channel 0" + + # Part 2: run Cellpose on the imported OME-Zarr. + + from fractal_tasks_core.cellpose_segmentation import cellpose_segmentation + from .test_workflows_cellpose_segmentation import ( + patched_cellpose_core_use_gpu, + patched_segment_ROI, + ) + + monkeypatch.setattr( + "fractal_tasks_core.tasks.cellpose_segmentation.cellpose.core.use_gpu", + patched_cellpose_core_use_gpu, + ) + + monkeypatch.setattr( + "fractal_tasks_core.tasks.cellpose_segmentation.segment_ROI", + patched_segment_ROI, + ) + + # Per-FOV labeling + for component in metadata["image"]: + cellpose_segmentation( + input_paths=[str(root_path)], + output_path=str(root_path), + input_ROI_table="grid_ROI_table", + metadata=metadata, + component=component, + channel=Channel(wavelength_id="Channel 0"), + level=0, + relabeling=True, + diameter_level0=80.0, + augment=True, + net_avg=True, + min_size=30, + ) diff --git a/tests/test_unit_channels.py b/tests/test_unit_channels.py index 9ae34bdf0..f3c9141c2 100644 --- a/tests/test_unit_channels.py +++ b/tests/test_unit_channels.py @@ -11,6 +11,7 @@ from fractal_tasks_core.lib_channels import define_omero_channels from fractal_tasks_core.lib_channels import get_channel_from_list from 
fractal_tasks_core.lib_channels import OmeroChannel +from fractal_tasks_core.lib_channels import update_omero_channels def test_check_unique_wavelength_ids(): @@ -196,3 +197,51 @@ def test_color_validator(): for c in invalid_colors: with pytest.raises(ValueError): OmeroChannel(wavelength_id="A01_C01", color=c) + + +@pytest.mark.parametrize( + "old_channels", + [ + [{}, {}, {}], + [{}, {}, {}, {}], + [{}, {}, {}, {}, {}], + [{"label": "A"}, {"label": "B"}, {"label": "C"}], + [{"label": "A"}, {}, {"label": "C"}], + [{"label": "A"}, {"label": "A"}, {"label": "A"}], + [{"label": "1"}, {"label": "1"}, {"label": "1"}], + [{}, {"label": "1"}, {}, {"label": "1"}], + [ + {"wavelength_id": "1"}, + {"label": "1"}, + {}, + {"label": "1", "wavelength_id": "3"}, + ], + [{"color": "FFFFFF"}, {}, {}, {}, {}, {}], + ], +) +def test_update_omero_channels(old_channels): + + # Update partial metadata + print() + print(f"OLD: {old_channels}") + new_channels = update_omero_channels(old_channels) + print(f"NEW: {new_channels}") + + # Validate new channels as `OmeroChannel` objects, and check that they + # have unique `wavelength_id` values + check_unique_wavelength_ids( + [OmeroChannel(**channel) for channel in new_channels] + ) + + # Check that colors are as expected + old_colors = [channel.get("color") for channel in old_channels] + new_colors = [channel["color"] for channel in new_channels] + if set(old_colors) == {None}: + full_colors_list = ["00FFFF", "FF00FF", "FFFF00"] + ["808080"] * 20 + EXPECTED_COLORS = full_colors_list[: len(new_colors)] + debug(EXPECTED_COLORS) + debug(new_channels) + # Note: we compare sets, because the list order of `new_colors` depends + # on other factors (namely whether each channel has the `wavelength_id` + # and/or `label` attributes) + assert set(EXPECTED_COLORS) == set(new_colors)