Skip to content

Commit

Permalink
Merge pull request #741 from adrtsc/main
Browse files Browse the repository at this point in the history
Addition of chi2_shift as a Method to calculate Shifts to the "Calculate Registration (image-based)" Task
  • Loading branch information
jluethi committed Jul 4, 2024
2 parents f306ab5 + 3511871 commit 674dcac
Show file tree
Hide file tree
Showing 8 changed files with 419 additions and 20 deletions.
18 changes: 18 additions & 0 deletions fractal_tasks_core/__FRACTAL_MANIFEST__.json
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,16 @@
"type": "string",
"description": "Wavelength that will be used for image-based registration; e.g. `A01_C01` for Yokogawa, `C01` for MD."
},
"method": {
"default": "phase_cross_correlation",
"allOf": [
{
"$ref": "#/definitions/RegistrationMethod"
}
],
"title": "Method",
"description": "Method to use for image registration. The available methods are `phase_cross_correlation` (scikit-image package, works for 2D & 3D) and \"chi2_shift\" (image_registration package, only works for 2D images)."
},
"roi_table": {
"title": "Roi Table",
"default": "FOV_ROI_table",
Expand Down Expand Up @@ -1063,6 +1073,14 @@
"required": [
"reference_zarr_url"
]
},
"RegistrationMethod": {
"title": "RegistrationMethod",
"description": "An enumeration.",
"enum": [
"phase_cross_correlation",
"chi2_shift"
]
}
}
},
Expand Down
50 changes: 50 additions & 0 deletions fractal_tasks_core/tasks/_registration_utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import copy

import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
from image_registration import chi2_shift

from fractal_tasks_core.ngff.zarr_utils import load_NgffWellMeta
from fractal_tasks_core.tasks._zarr_utils import _split_well_path_image_path
Expand Down Expand Up @@ -235,3 +237,51 @@ def apply_registration_to_single_ROI_table(
+ float(min_df.loc[roi, "translation_x"])
)
return roi_table


def chi2_shift_out(img_ref, img_cycle_x) -> list[np.ndarray]:
"""
Helper function to get the output of chi2_shift into the same format as
phase_cross_correlation. Calculates the shift between two images using
the chi2_shift method.
Args:
img_ref (np.ndarray): First image.
img_cycle_x (np.ndarray): Second image.
Returns:
List containing numpy array of shift in y and x direction.
"""
x, y, a, b = chi2_shift(np.squeeze(img_ref), np.squeeze(img_cycle_x))

"""
Running into issues when using direct float output for fractal.
When rounding to integer and using integer dtype, it typically works
but for some reasons fails when run over a whole 384 well plate (but
the well where it fails works fine when run alone). For now, rounding
to integer, but still using float64 dtype (like the scikit-image
phase cross correlation function) seems to be the safest option.
"""
shifts = np.array([-np.round(y), -np.round(x)], dtype="float64")
# return as a list to adhere to the phase_cross_correlation output format
return [shifts]


def is_3D(dask_array: da.array) -> bool:
"""
Check if a dask array is 3D.
Treats singelton Z dimensions as 2D images.
(1, 2000, 2000) => False
(10, 2000, 2000) => True
Args:
dask_array: Input array to be checked
Returns:
bool on whether the array is 3D
"""
if len(dask_array.shape) == 3 and dask_array.shape[0] > 1:
return True
else:
return False
46 changes: 31 additions & 15 deletions fractal_tasks_core/tasks/calculate_registration_image_based.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Calculates translation for image-based registration
"""
import logging
from enum import Enum

import anndata as ad
import dask.array as da
Expand All @@ -36,14 +37,27 @@
from fractal_tasks_core.tasks._registration_utils import (
calculate_physical_shifts,
)
from fractal_tasks_core.tasks._registration_utils import chi2_shift_out
from fractal_tasks_core.tasks._registration_utils import (
get_ROI_table_with_translation,
)
from fractal_tasks_core.tasks._registration_utils import is_3D
from fractal_tasks_core.tasks.io_models import InitArgsRegistration

logger = logging.getLogger(__name__)


class RegistrationMethod(Enum):
PHASE_CROSS_CORRELATION = "phase_cross_correlation"
CHI2_SHIFT = "chi2_shift"

def register(self, img_ref, img_acq_x):
if self == RegistrationMethod.PHASE_CROSS_CORRELATION:
return phase_cross_correlation(img_ref, img_acq_x)
elif self == RegistrationMethod.CHI2_SHIFT:
return chi2_shift_out(img_ref, img_acq_x)


@validate_arguments
def calculate_registration_image_based(
*,
Expand All @@ -52,6 +66,7 @@ def calculate_registration_image_based(
init_args: InitArgsRegistration,
# Core parameters
wavelength_id: str,
method: RegistrationMethod = "phase_cross_correlation",
roi_table: str = "FOV_ROI_table",
level: int = 2,
) -> None:
Expand All @@ -73,6 +88,10 @@ def calculate_registration_image_based(
(standard argument for Fractal tasks, managed by Fractal server).
wavelength_id: Wavelength that will be used for image-based
registration; e.g. `A01_C01` for Yokogawa, `C01` for MD.
method: Method to use for image registration. The available methods
are `phase_cross_correlation` (scikit-image package, works for 2D
& 3D) and "chi2_shift" (image_registration package, only works for
2D images).
roi_table: Name of the ROI table over which the task loops to
calculate the registration. Examples: `FOV_ROI_table` => loop over
the field of views, `well_ROI_table` => process the whole well as
Expand Down Expand Up @@ -115,6 +134,16 @@ def calculate_registration_image_based(
channel_index_align
]

# Check if data is 3D (as not all registration methods work in 3D)
# TODO: Abstract this check into a higher-level Zarr loading class
if is_3D(data_reference_zyx):
if method == "chi2_shift":
raise ValueError(
"The `chi2_shift` registration method has not been "
"implemented for 3D images and the input image had a shape of "
f"{data_reference_zyx.shape}."
)

# Read ROIs
ROI_table_ref = ad.read_zarr(
f"{init_args.reference_zarr_url}/tables/{roi_table}"
Expand Down Expand Up @@ -211,30 +240,17 @@ def calculate_registration_image_based(
##############
# Calculate the transformation
##############
# Basic version (no padding, no internal binning)
if img_ref.shape != img_acq_x.shape:
raise NotImplementedError(
"This registration is not implemented for ROIs with "
"different shapes between acquisitions."
)
shifts = phase_cross_correlation(
np.squeeze(img_ref), np.squeeze(img_acq_x)
)[0]

# Registration based on scmultiplex, image-based
# shifts, _, _ = calculate_shift(np.squeeze(img_ref),
# np.squeeze(img_acq_x), bin=binning, binarize=False)

# TODO: Make this work on label images
# (=> different loading) etc.
shifts = method.register(np.squeeze(img_ref), np.squeeze(img_acq_x))[0]

##############
# Storing the calculated transformation ###
# Store the calculated transformation ###
##############
# Store the shift in ROI table
# TODO: Store in OME-NGFF transformations: Check SpatialData approach,
# per ROI storage?

# Adapt ROIs for the given ROI table:
ROI_name = ROI_table_ref.obs.index[i_ROI]
new_shifts[ROI_name] = calculate_physical_shifts(
Expand Down
1 change: 1 addition & 0 deletions fractal_tasks_core/tasks/find_registration_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def find_registration_consensus(
new_roi_table,
shifted_rois[acq_zarr_url],
table_attrs=roi_tables_attrs[acq_zarr_url],
overwrite=True,
)

# TODO: Optionally apply registration to other tables as well?
Expand Down
121 changes: 118 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ napari-skimage-regionprops = { version = "^0.8.1", optional = true }
napari-tools-menu = { version = "^0.1.19", optional = true }
cellpose = { version = "~2.2", optional = true }
torch = { version = "<=2.0.0", optional = true }
image_registration = { version = ">=0.2.9", optional = true }

[tool.poetry.extras]
fractal-tasks = ["Pillow", "imageio-ffmpeg", "scikit-image", "llvmlite", "napari-segment-blobs-and-things-with-membranes", "napari-workflows", "napari-skimage-regionprops", "napari-tools-menu", "cellpose", "torch"]
fractal-tasks = ["Pillow", "imageio-ffmpeg", "scikit-image", "llvmlite", "napari-segment-blobs-and-things-with-membranes", "napari-workflows", "napari-skimage-regionprops", "napari-tools-menu", "cellpose", "torch", "image_registration"]

[tool.poetry.group.dev]
optional = true
Expand Down
Loading

0 comments on commit 674dcac

Please sign in to comment.