diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..5ee325b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +.n5 filter=lfs diff=lfs merge=lfs -text +.tif filter=lfs diff=lfs merge=lfs -text +.n5/** filter=lfs diff=lfs merge=lfs -text +examples/data filter=lfs diff=lfs merge=lfs -text +examples/data/** filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 0a19790..30dc145 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,14 @@ cython_debug/ # PyPI configuration file .pypirc + +# test results +examples/data/test/ + +# huge data +*.tif +*.n5/ +*.png +*.ome.zarr/ + +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 2083036..9cfd649 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# matchmaker +# 💞 Matchmaker Tool for segmentation-based deformable registration and object matching @@ -7,10 +7,8 @@ Tool for segmentation-based deformable registration and object matching ### Input - **Fixed image**: 3D instance segmentation in `n5` + resolution - - **Moving image**: 3D instance segmentation `n5` + resolution -- **Registration parameters yaml**: parameters of the `matchmaker` pipeline Expected image shape: ZYX @@ -19,17 +17,46 @@ Expected image shape: ZYX - **QC plots** - **Files with all transforms** - **Table of correspondence between instances in moving and fixed instance segmentations** +- **Logging file: `registration.log`** +- **Optional: Mobie project saved at `{output_dir}/mobie_project/`** ### Registration steps -0. Create point cloud -1. PCA pre-alignment ---> **rigid transform matrix** -2. Manual input: should the image be flipped? -3. Rigid pre-alignment with Elastix OR with rigid CPD ---> **rigid transform matrix** -4. Coherent point drift -5. Matching points with mixed integer programming -6. Deformable registration with Elastix with distance between keypoints in loss and rigidity penalty ---> **B-spline coefficients** +**1. PCA pre-alignment**: alignment of fixed and moving image to the PCs \ + `prealignment.py --fixed_path ... --fixed_key ... --moving_path ... --moving_key ... --output_dir ... --mobie_export --dataset_name ...` \ + +Outputs: +- prealigned images: `{file_name}_prealigned.n5` +- transformation matrixes + - `{file_name}_fixed_T_prealignment.txt` + - `{file_name}_moving_T_prealignment.txt` (maybe final one is `moving_T_prealignment.txt`, couldn't figure this out) +- plots: + - slice per dimension before pre-alignment: `{file_name}_fixed.png`, `{file_name}_moving.png` + - slice per dimension after pre-alignment: `{file_name}_fixed_prealigned.png`, `{file_name}_moving_prealigned.png` + - overlay of slice per dimension after pre-alignment: `overlay_prealignment.png` + - intensity profiles per axis and volume: `fixed_intensity_profile_{axis}.png`, `moving_intensity_profile_{axis}.png` + +**2. Rigid pre-alignment with Elastix** \ +`apply_rigid_elastix.py --fixed_path ... --fixed_key ... --moving_path ... --moving_key ... --output_dir ... --mobie_export --dataset_name ...` \ + +Outputs: +- rigid alinged moving image: `{file_name}_rigid_aligned.n5` +- rigid transformation matrix (Elastix outputs): `result.0.mhd`, `result.0.raw`, `TransformParameters.0.txt` +- logging file: `elastix_log_rigid.log` +- plots: + - `intersample_segm_overlay_before_alignment.png` + - `intersample_segm_rigid_alignment_semantic.png` + + +**3. Coherent point drift** + +**4. Matching points with mixed integer programming** + +**5. Deformable registration with Elastix** \ +with distance between keypoints in loss and rigidity penalty ---> **B-spline coefficients** + + ## Apply registration to other images diff --git a/environment.yml b/environment.yml index b5abc49..cc8a81f 100644 --- a/environment.yml +++ b/environment.yml @@ -16,6 +16,7 @@ dependencies: - optuna - proxsuite - cvxpy + - transforms3d - ipykernel - seaborn @@ -29,5 +30,6 @@ dependencies: - pyscipopt - click - pyyaml + - itk diff --git a/examples/data/platy1_muscles_stardist_fixed.tif b/examples/data/platy1_muscles_stardist_fixed.tif new file mode 100644 index 0000000..02e6673 --- /dev/null +++ b/examples/data/platy1_muscles_stardist_fixed.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf565c9e84d198c58dc92d432a7b4a4ea2b7519e62b4d8f81a8d1ee85ad83ee9 +size 13164762 diff --git a/examples/data/platy1_muscles_stardist_moving.tif b/examples/data/platy1_muscles_stardist_moving.tif new file mode 100644 index 0000000..a611bd2 --- /dev/null +++ b/examples/data/platy1_muscles_stardist_moving.tif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea6991f9a3d168a8dfd9ec0e8b83d38f73de778f8d1bc83a206d251f0ad54038 +size 60386230 diff --git a/examples/data/rotation_matrix.txt b/examples/data/rotation_matrix.txt new file mode 100644 index 0000000..207754e --- /dev/null +++ b/examples/data/rotation_matrix.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dcdd4056e70e24809604b56e3c8efb544a04faeb8b2432b6a8feeb71e484d6b9 +size 231 diff --git a/examples/data/transformation_matrix.txt b/examples/data/transformation_matrix.txt new file mode 100644 index 0000000..a4ae13a --- /dev/null +++ b/examples/data/transformation_matrix.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d5e9bd5d72c68d4b8fa4d68b37382c836e24bdf87e0893222cdb564c8bf9ab3e +size 406 diff --git a/examples/deform_test_data.py b/examples/deform_test_data.py new file mode 100644 index 0000000..262ea2c --- /dev/null +++ b/examples/deform_test_data.py @@ -0,0 +1,90 @@ +import numpy as np +import tifffile as tif +import transforms3d as tf3d + +from matchmaker.transform_utils import ( + downscale_seg, + get_transformation_matrix, + rotate_img, + crop_to_bbox, +) +from matchmaker.vis import plot_three_slices, plot_overlay +from matchmaker.n5_utils import write_volume + + +def remove_instances(seg, prob=0.05): + # Get all unique instance IDs, excluding background (assumed to be 0) + instance_ids = np.unique(seg) + instance_ids = instance_ids[instance_ids != 0] + + # Randomly select 10% of the instance IDs + num_to_remove = int(len(instance_ids) * prob) + print(f"Number of instances to remove: {num_to_remove}") + selected_ids = np.random.choice(instance_ids, size=num_to_remove, replace=False) + + # Create a mask for the selected IDs and set them to 0 + mask = np.isin(seg, selected_ids) + seg[mask] = 0 + + print(f"Number of instances left: {len(np.unique(seg))}") + return seg + + +def main(): + seg = tif.imread("data/platy1_muscles_stardist.tif") + + # downsample image + factor = 4 + seg = downscale_seg(seg, factor) + seg_fixed = crop_to_bbox(seg) + print("Cropped shape", seg_fixed.shape) + + # save downsampled, fixed image + attributes = {"resolution": [1, 1, 1]} + write_volume( + f="./data/platy1_muscles_stardist_fixed.n5", + arr=seg_fixed, + key="seg", + chunks=(128, 512, 512), + attrs=attributes, + ) + + # save also as tiff + tif.imwrite("./data/platy1_muscles_stardist_fixed.tif", seg_fixed) + + # rotate image + center = np.array(seg_fixed.shape) // 2 + rotation = tf3d.euler.euler2mat( + *[np.deg2rad(155), np.deg2rad(30), np.deg2rad(65)], axes="szyx" + ) + + T, new_shape = get_transformation_matrix( + seg_fixed, center, rotation, save_path="./data/transformation_matrix.txt" + ) + seg_moving = rotate_img(seg_fixed, T, output_shape=new_shape) + + # randomly remove instances + probability = 0.05 + seg_moving = remove_instances(seg_moving, prob=probability) + + # save moving image + attributes = {"resolution": [1, 1, 1]} + write_volume( + f="./data/platy1_muscles_stardist_moving.n5", + arr=seg_moving, + key="seg", + chunks=(128, 512, 512), + attrs=attributes, + ) + + # save also as tiff + tif.imwrite("./data/platy1_muscles_stardist_moving.tif", seg_moving) + + # visualize + plot_three_slices(seg_moving, save_path="./data/plots/seg_moving.png") + plot_three_slices(seg_fixed, save_path="./data/plots/seg_fixed.png") + plot_overlay(seg_fixed, seg_moving, save_path="./data/plots/seg_overlay.png") + + +if __name__ == "__main__": + main() diff --git a/examples/mobie_dataset_clean_up.py b/examples/mobie_dataset_clean_up.py new file mode 100644 index 0000000..75ba3b6 --- /dev/null +++ b/examples/mobie_dataset_clean_up.py @@ -0,0 +1,35 @@ +import json + +# clean up dataset.json +input_path = "./data/test/mobie_project/platy1_muscles_stardist/dataset.json" + +# Define the key you want to delete +key_to_delete = "platy1_muscles_stardist_moving_prealigned_rigid_aligned" + +# Load the JSON file +with open(input_path, "r") as file: + dataset_dict = json.load(file) + + +# Recursively delete all entries with the specified key +def delete_key_recursively(obj, key): + if isinstance(obj, dict): + # Remove the key if it exists in the current dictionary + obj.pop(key, None) + # Recurse into the dictionary values + for value in obj.values(): + delete_key_recursively(value, key) + elif isinstance(obj, list): + # Recurse into each item in the list + for item in obj: + delete_key_recursively(item, key) + + +# Call the recursive function on the loaded JSON data +delete_key_recursively(dataset_dict, key_to_delete) + +# Save the modified JSON back to a file +with open(input_path, "w") as file: + json.dump(dataset_dict, file, indent=4) + +print(f"Entries with key '{key_to_delete}' have been deleted.") diff --git a/examples/registration_test.py b/examples/registration_test.py new file mode 100644 index 0000000..fc31f18 --- /dev/null +++ b/examples/registration_test.py @@ -0,0 +1,59 @@ +import os +import z5py +import numpy as np + +from matchmaker.prealignment import prealign_sample +from matchmaker.transform_utils import rotate_img +from matchmaker.vis import plot_three_slices, plot_overlay + + +def main(): + output_dir = "./data/test" + if not os.path.exists(f"{output_dir}/plots"): + os.makedirs(f"{output_dir}/plots") + + ####################### + # prealign moving image + moving_input = "./data/platy1_muscles_stardist_moving.n5" + with z5py.File(moving_input, "r") as f: + seg_moving = f["seg"][:] + + plot_three_slices(seg_moving, save_path=f"{output_dir}/plots/moving.png") + + seg_moving_prealigned = prealign_sample(seg_moving, file_name="moving", save_path=output_dir) + + with z5py.File(f"{output_dir}/moving_prealigned.n5", "w") as f: + f.create_dataset("seg", data=seg_moving_prealigned, compression="gzip") + + ####################### + # prealign fixed image + fixed_input = "./data/platy1_muscles_stardist_fixed.n5" + with z5py.File(fixed_input, "r") as f: + seg_fixed = f["seg"][:] + + plot_three_slices(seg_fixed, save_path=f"{output_dir}/plots/fixed.png") + + seg_fixed_prealigned = prealign_sample(seg_fixed, file_name="fixed", save_path=output_dir) + + with z5py.File(f"{output_dir}/fixed_prealigned.n5", "w") as f: + f.create_dataset("seg", data=seg_fixed_prealigned, compression="gzip") + + ####################### + # test backtransform + T = np.loadtxt(f"{output_dir}/fixed_T_prealignment.txt") + seg_fixed_inv = rotate_img( + seg_fixed_prealigned, np.linalg.inv(T), output_shape=seg_fixed.shape + ) + + plot_three_slices(seg_fixed_inv, save_path=f"{output_dir}/plots/fixed_prealigned_inv.png") + + # plot overlay of prealigned volumes + plot_overlay( + seg_fixed_prealigned, + seg_moving_prealigned, + save_path=f"{output_dir}/plots/overlay_prealignment.png", + ) + + +if __name__ == "__main__": + main() diff --git a/examples/reverse_deformation.py b/examples/reverse_deformation.py new file mode 100644 index 0000000..38dc87b --- /dev/null +++ b/examples/reverse_deformation.py @@ -0,0 +1,31 @@ +import numpy as np +from matchmaker.vis import plot_three_slices, plot_overlay +from matchmaker.transform_utils import rotate_img +from matchmaker.n5_utils import read_volume + + +def main(): + ''' + Reverse the deformation of a sample by applying the inverse transformation matrix. + ''' + seg_moving = read_volume( + f="./data/platy1_muscles_stardist_moving.n5", + key="seg", + ) + + # compare with original image + seg_fixed = read_volume( + f="./data/platy1_muscles_stardist_fixed.n5", + key="seg", + ) + + T = np.loadtxt("./data/transformation_matrix.txt") + seg_moving_reverse = rotate_img(seg_moving, np.linalg.inv(T), output_shape=seg_fixed.shape) + + plot_three_slices(seg_moving_reverse) + plot_three_slices(seg_fixed) + plot_overlay(seg_fixed, seg_moving_reverse) + + +if __name__ == "__main__": + main() diff --git a/examples/view_data.py b/examples/view_data.py new file mode 100644 index 0000000..2a77a6b --- /dev/null +++ b/examples/view_data.py @@ -0,0 +1,11 @@ +from matchmaker.n5_utils import read_volume +import napari + +seg_fixed = read_volume("./data/test/platy1_muscles_stardist_fixed_prealigned.n5", key="seg") +seg_moving = read_volume("./data/test/platy1_muscles_stardist_moving_prealigned.n5", key="seg") +# seg_moving = seg_moving[:, :, ::-1] + +v = napari.Viewer() +v.add_labels(seg_fixed) +v.add_labels(seg_moving) +napari.run() \ No newline at end of file diff --git a/matchmaker/align_rigid_elastix.py b/matchmaker/align_rigid_elastix.py index 320bf46..69f03fb 100644 --- a/matchmaker/align_rigid_elastix.py +++ b/matchmaker/align_rigid_elastix.py @@ -1,134 +1,183 @@ -from pathlib import Path import numpy as np -import argparse -from aicsimageio import readers - -from platy_reg.n5_utils import read_volume, write_volume, get_attrs -from platy_reg.vis import plot_overlay - -from aicsimageio.aics_image import AICSImage +from matchmaker.n5_utils import read_volume, write_volume, get_attrs +from matchmaker.vis import plot_overlay import logging import sys - -from platy_reg.preprocessing import percentile_norm -from platy_reg.elastix_utils import * - +from matchmaker import elastix_utils +from matchmaker.mobie_export import export_to_mobie import itk -from skimage.filters import gaussian +import click +import os + -def elastix_segm_rigid_alignment(fixed_img_np, fixed_resolution, moving_img_np, moving_resolution, log_dir): +def elastix_segm_rigid_alignment( + fixed_img_np, fixed_resolution, moving_img_np, moving_resolution, output_dir +): """ Run rigid alignment of the ventral and dorsal datasets using elastix. """ - - logging.info(f"Do rigid transform of unnormalized images") + + logging.info("Do rigid transform of unnormalized images") logging.info(f"Fixed resolution {fixed_resolution}") logging.info(f"Moving resolution {moving_resolution}") - logging.info(f"Do rigid transform of unnormalized images") - + fixed_img_semantic_np = (fixed_img_np > 0).astype(np.float32) moving_img_semantic_np = (moving_img_np > 0).astype(np.float32) - - fixed_img = itk_scalar_img(fixed_img_semantic_np, fixed_resolution) - moving_img = itk_scalar_img(moving_img_semantic_np, moving_resolution) - - logging.info(f"Fixed image") + + fixed_img = elastix_utils.itk_scalar_img(fixed_img_semantic_np, fixed_resolution) + moving_img = elastix_utils.itk_scalar_img(moving_img_semantic_np, moving_resolution) + + logging.info("Fixed image") logging.info(f"{fixed_img}") - logging.info(f"Moving image") + logging.info("Moving image") logging.info(f"{moving_img}") - - # plot_overlay(fixed_img_np, moving_img_np, log_dir / "intersample_segm_overlay_before_alignment.png") - - plot_overlay(itk_to_np_order(itk.GetArrayFromImage(fixed_img)), itk_to_np_order(itk.GetArrayFromImage(moving_img)), log_dir / "intersample_segm_overlay_before_alignment.png") - - parameter_map_paths = ["pipeline_steps/inter_sample_registration/ParameterMap_segm_rigid_registration_corr.txt"] - logging.info(f"Run rigid registration") - result_image, result_transform_parameters = run_registration(fixed_img, moving_img, parameter_map_paths, str(log_dir), log_name="elastix_log_rigid.log", set_threads=True) - # serialize_parameter_object(result_transform_parameters, "ParameterMap_rigid_transform", log_dir) - - logging.info(f"Result image shape {result_image.shape}") - result_img_np = itk_to_np_order(itk.GetArrayFromImage(result_image)) - plot_overlay(itk_to_np_order(itk.GetArrayFromImage(fixed_img)), result_img_np, log_dir / f"intersample_segm_rigid_alignment_semantic.png") + parameter_map_paths = [ + "../ParameterMap_segm_rigid_registration_corr.txt" + ] + logging.info("Run rigid registration with elastix") + result_image, result_transform_parameters = elastix_utils.run_registration( + fixed_img, + moving_img, + parameter_map_paths, + output_dir, + log_name="elastix_log_rigid.log", + set_threads=True, + ) - logging.info(f"Apply transform to all channels") - result_img_np = apply_transform_chanwise(result_transform_parameters, moving_img_np, moving_resolution) + logging.info(f"Result image shape {result_image.shape}") + result_img_np = elastix_utils.itk_to_np_order(itk.GetArrayFromImage(result_image)) + plot_overlay( + elastix_utils.itk_to_np_order(itk.GetArrayFromImage(fixed_img)), + result_img_np, + f"{output_dir}/plots/overlay_after_rigid_alignment.png", + ) + # NOTE: difference between result_img_np before and after applying transform? + logging.info("Apply transform to all channels") + result_img_np = elastix_utils.apply_transform_chanwise( + result_transform_parameters, moving_img_np, moving_resolution + ) logging.info(f"Result image shape {result_img_np.shape}") - + return result_img_np -def main(): - parser = argparse.ArgumentParser( - description="""Align rigidly moving segmentation volume to fixed volume. - """ +def run_rigid_alignment( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + mobie_export, + dataset_name, +): + """ + Perform rigid alignment of a moving image to a fixed image using Elastix. + + This function reads the fixed and moving images from the specified paths, + performs a rigid alignment using Elastix, and saves the aligned moving image + to the output directory. If the MoBIE export flag is set, the aligned image + is also exported to a MoBIE project. + + Args: + fixed_path (str): Path to the fixed image .n5 file. + fixed_key (str): Key to the fixed image data in the .n5 file. + moving_path (str): Path to the moving image .n5 file. + moving_key (str): Key to the moving image data in the .n5 file. + output_dir (str): Directory where the aligned image should be saved. + mobie_export (bool): Flag indicating whether to export to a MoBIE project. + dataset_name (str): Name of the dataset for the MoBIE export. + + Returns: + np.ndarray: The rigidly aligned moving image. + """ + + if not os.path.exists(f"{output_dir}/plots"): + os.makedirs(f"{output_dir}/plots") + + logging.info("Start rigid alignment") + + logging.info("Read image file") + fixed_img_np = read_volume(fixed_path, fixed_key) + + if fixed_path == moving_path: + logging.info("Same n5 for fixed and moving image, not doing registration") + moving_img_np = fixed_img_np + + else: + moving_img_np = read_volume(moving_path, moving_key) + + fixed_img_np = fixed_img_np.astype(np.float32) + moving_img_np = moving_img_np.astype(np.float32) + + logging.info("Compute rigid alignment ...") + + moving_img_np = elastix_segm_rigid_alignment( + fixed_img_np=fixed_img_np, + fixed_resolution=get_attrs(fixed_path, fixed_key)["resolution"], + moving_img_np=moving_img_np, + moving_resolution=get_attrs(moving_path, moving_key)["resolution"], + output_dir=output_dir, ) + moving_img_np = moving_img_np.astype(np.uint16) + + logging.info("Save rigid aligned moving image") + attributes = dict(get_attrs(moving_path, moving_key)) + file_name = os.path.splitext(os.path.basename(moving_path))[0] + file_name = file_name.removesuffix("_prealigned") + + write_volume( + f=f"{output_dir}/{file_name}_rigid_aligned.n5", + arr=moving_img_np, + key=moving_key, + attrs=attributes, + ) - parser.add_argument("fixed_n5", type=str, help="Path of the output n5") - parser.add_argument("fixed_key", type=str, help="Key of ventral dataset") - parser.add_argument("moving_n5", type=str, help="Path of the output n5") - parser.add_argument("moving_key", type=str, help="Key of ventral dataset") - parser.add_argument("output_n5", type=str, help="Path of the output n5") - parser.add_argument("output_key", type=str, help="Key to write aligned ventral dataset to in the output n5") - parser.add_argument("log_dir", type=str, help="Directory to store diagnostic plots etc") - parser.add_argument("log_path", type=str, help="Path to log file") - args = parser.parse_args() - - log_dir = Path(args.log_dir) - log_dir.mkdir(exist_ok=True) - - logging.basicConfig(level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[logging.FileHandler(args.log_path, mode="w"), - logging.StreamHandler(sys.stdout)], - datefmt='%Y-%m-%d %H:%M:%S') - - logging.info("Read image file") - fixed_img_np = read_volume(args.fixed_n5, args.fixed_key) - - # pad_z = 50 - # fixed_img_np = np.pad(fixed_img_np, pad_width=((pad_z, pad_z), (0, 0), (0, 0))) - - # fixed_mask = gaussian(fixed_img_np, sigma=[6, 12, 12]) - # fixed_mask_bin = fixed_mask > np.quantile(fixed_mask, 0.7) - # fixed_img_np = fixed_img_np * fixed_mask_bin - - - if args.fixed_n5 == args.moving_n5: - logging.info("Same n5 for fixed and moving image, not doing registration") - moving_img_np = fixed_img_np - - - else: - moving_img_np = read_volume(args.moving_n5, args.moving_key) - # moving_mask = gaussian(moving_img_np, sigma=[6, 12, 12]) - # moving_mask_bin = moving_mask > np.quantile(moving_mask, 0.7) - # moving_img_np = moving_img_np * moving_mask_bin - - fixed_img_np = fixed_img_np.astype(np.float32) - moving_img_np = moving_img_np.astype(np.float32) - - logging.info("Start registration") - - # Dirty hack for the sample size mismatch - # For no apparent reason the size of the light samples seems to be ~20% larger than the EM - moving_resolution = get_attrs(args.moving_n5, args.moving_key)["resolution"] - moving_resolution = [r * 0.9 for r in moving_resolution] - - moving_img_np = elastix_segm_rigid_alignment(fixed_img_np, - get_attrs(args.fixed_n5, args.fixed_key)["resolution"], - moving_img_np, - moving_resolution, - log_dir) - - logging.info("Write results") - attributes = dict(get_attrs(args.moving_n5, args.moving_key)) - attributes["resolution"] = get_attrs(args.fixed_n5, args.fixed_key)["resolution"] - - write_volume(args.output_n5, moving_img_np, args.output_key, chunks=(128, 512, 512), attrs=attributes) - - -if __name__=="__main__": + # export rigid alignment to mobie + if mobie_export: + logging.info("Export rigid aligned moving image to MoBIE") + export_to_mobie( + input_path=f"{output_dir}/{file_name}_rigid_aligned.n5", + input_key=moving_key, + output_dir=output_dir, + dataset_name=dataset_name, + segmentation_name=f"{file_name}_rigid_aligned", + menu_name="moving" + ) + + +@click.command() +@click.option("-fi", "--fixed_path", required=True, help="Fixed input .n5 file") +@click.option("-fk", "--fixed_key", required=True, help="Fixed input key") +@click.option("-mi", "--moving_path", required=True, help="Moving input .n5 file") +@click.option("-mk", "--moving_key", required=True, help="Moving input key") +@click.option("-o", "--output_dir", required=True, help="Output directory") +@click.option("-m", "--mobie_export", required=False, is_flag=True, help="MoBIE export") +def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, mobie_export): + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler(f"{output_dir}/rigid_alignment.log", mode="w"), + logging.StreamHandler(sys.stdout), + ], + datefmt="%Y-%m-%d %H:%M:%S", + ) + + run_rigid_alignment( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + mobie_export, + dataset_name="platy1_muscles_stardist", + ) + + +if __name__ == "__main__": main() - - + +# python align_rigid_elastix.py -fi ../examples/data/test/platy1_muscles_stardist_fixed_prealigned.n5 -fk seg +# -mi ../examples/data/test/platy1_muscles_stardist_moving_prealigned.n5 -mk seg -o ../examples/data/test -m diff --git a/matchmaker/apply.py b/matchmaker/apply.py index e69de29..f4a22ab 100644 --- a/matchmaker/apply.py +++ b/matchmaker/apply.py @@ -0,0 +1 @@ +# apply transformations \ No newline at end of file diff --git a/matchmaker/compute_registration.py b/matchmaker/compute_registration.py new file mode 100644 index 0000000..ae52bc8 --- /dev/null +++ b/matchmaker/compute_registration.py @@ -0,0 +1,92 @@ +import os +import sys +import click +import logging +from matchmaker.mobie_export import create_mobie_project +from matchmaker.prealignment import run_prealignment +from matchmaker.align_rigid_elastix import run_rigid_alignment + + +@click.command() +@click.option("-fi", "--fixed_path", required=True, help="Fixed input .n5 file") +@click.option("-fk", "--fixed_key", required=True, help="Fixed input key") +@click.option("-mi", "--moving_path", required=True, help="Moving input .n5 file") +@click.option("-mk", "--moving_key", required=True, help="Moving input key") +@click.option("-o", "--output_dir", required=True, help="Output directory") +@click.option("-m", "--mobie_export", required=False, is_flag=True, help="MoBIE export") +def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, mobie_export): + """ + Main function to perform registration of moving image to fixed image. + + This function orchestrates the sequence of steps required to register a moving image to a fixed image. + It handles the creation of necessary directories, configures logging, and invokes prealignment and + rigid alignment functions. Optionally, it can create a MoBIE project for visualization. + + Args: + fixed_path (str): Path to the fixed input .n5 file. + fixed_key (str): Key to the fixed image data in the .n5 file. + moving_path (str): Path to the moving input .n5 file. + moving_key (str): Key to the moving image data in the .n5 file. + output_dir (str): Directory where the results should be saved. + mobie_export (bool): Flag indicating whether to export results to a MoBIE project. + """ + + if not os.path.exists(f"{output_dir}/plots"): + os.makedirs(f"{output_dir}/plots") + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler(f"{output_dir}/registration.log", mode="w"), + logging.StreamHandler(sys.stdout), + ], + datefmt="%Y-%m-%d %H:%M:%S", + ) + + dataset_name = "platy1_muscles_stardist" + if mobie_export: + create_mobie_project( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + dataset_name + ) + + run_prealignment( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + mobie_export, + dataset_name, + ) + + fixed_prealigned_path = f"{output_dir}/{os.path.splitext(os.path.basename(fixed_path))[0]}_prealigned.n5" + moving_prealigned_path = f"{output_dir}/{os.path.splitext(os.path.basename(moving_path))[0]}_prealigned.n5" + + run_rigid_alignment( + fixed_prealigned_path, + fixed_key, + moving_prealigned_path, + moving_key, + output_dir, + mobie_export, + dataset_name, + ) + + # fixed_rigid_aligned_path = f"{output_dir}/{os.path.splitext(os.path.basename(fixed_path))[0]}_rigid_aligned.n5" + # moving_rigid_aligned_path = f"{output_dir}/{os.path.splitext(os.path.basename(moving_path))[0]}_rigid_aligned.n5" + + # run_cpd(...) + + +if __name__ == "__main__": + main() + + +# python compute_registration.py -fi ../examples/data/platy1_muscles_stardist_fixed.n5 -fk seg +# -mi ../examples/data/platy1_muscles_stardist_moving.n5 -mk seg -o ../examples/data/test -m \ No newline at end of file diff --git a/matchmaker/data.py b/matchmaker/data.py index 8ce95a1..a9739f7 100644 --- a/matchmaker/data.py +++ b/matchmaker/data.py @@ -4,11 +4,12 @@ def create_point_cloud(segm): - props = regionprops_table(segm, properties=("label", 'centroid')) + props = regionprops_table(segm, properties=("label", "centroid")) props = pd.DataFrame(props) segm_labels = props["label"].to_numpy() centroid_columns = [col for col in props.columns if col.startswith("centroid")] pos = props[centroid_columns].to_numpy().astype(np.float32) + # TODO: center, maybe w/o labels return pos, segm_labels @@ -22,7 +23,7 @@ def write_pcd(pcd_df, pcd_path): pcd_df.to_csv(pcd_path, index=False) -def pcd_np_to_df(pos:np.array, segm_labels=None, reg_gt=None, matching=None): +def pcd_np_to_df(pos: np.array, segm_labels=None, reg_gt=None, matching=None): # print("pos len", len(pos)) pcd_df = pd.DataFrame() N = pos.shape[0] @@ -32,11 +33,15 @@ def pcd_np_to_df(pos:np.array, segm_labels=None, reg_gt=None, matching=None): pcd_df[f"coord_{col}"] = pos[:, col] if segm_labels is not None: - assert len(segm_labels) == N, print(f"Labels of shape {segm_labels.shape} don't correspond to coordinates of shape {pos.shape}") + assert len(segm_labels) == N, print( + f"Labels of shape {segm_labels.shape} don't correspond to coordinates of shape {pos.shape}" + ) pcd_df["segm_labels"] = segm_labels if reg_gt is not None: - assert len(reg_gt) == N, print(f"Labels of shape {reg_gt.shape} don't correspond to coordinates of shape {pos.shape}") + assert len(reg_gt) == N, print( + f"Labels of shape {reg_gt.shape} don't correspond to coordinates of shape {pos.shape}" + ) pcd_df["reg_gt"] = reg_gt pcd_df.set_index("order_idx", drop=False) @@ -67,12 +72,12 @@ def index_pairs_to_matching(): def write_index_pairs(pairs, pairs_path): - with open(pairs_path,"w+") as f: + with open(pairs_path, "w+") as f: for pair in pairs: - f.write(str(pair[0]) + "," + str(pair[1]) + '\n') + f.write(str(pair[0]) + "," + str(pair[1]) + "\n") def read_index_pairs(pairs_path): - with open(pairs_path,"r") as f: - pairs = [(int(l.split(",")[0]), int(l.split(",")[1])) for l in f.readlines()] + with open(pairs_path, "r") as f: + pairs = [(int(loc.split(",")[0]), int(loc.split(",")[1])) for loc in f.readlines()] return pairs diff --git a/matchmaker/elastix_utils.py b/matchmaker/elastix_utils.py index af791dd..55c2b4c 100644 --- a/matchmaker/elastix_utils.py +++ b/matchmaker/elastix_utils.py @@ -1,19 +1,7 @@ from pathlib import Path import numpy as np -import argparse -from aicsimageio import readers - -from platy_reg.n5_utils import read_volume, write_volume, get_attrs -from platy_reg.vis import plot_overlay - -from aicsimageio.aics_image import AICSImage import logging -import sys - -import re - import itk -from platy_reg.preprocessing import percentile_norm def initial_alignment(ventral_img_np, dorsal_img_np): @@ -33,12 +21,12 @@ def itk_scalar_img(img: np.array, resolution, ch=0): ch: _description_. Defaults to 0. Returns: - Scalar ITK Image object for one of the channels. + Scalar ITK Image object for one of the channels. """ logging.info(f"Converting image of type {img.dtype} to ITK object") if img.ndim == 4: img = np.moveaxis(img[ch, :, :, :], 0, -1) - + else: img = np.moveaxis(img[:, :, :], 0, -1) @@ -57,7 +45,7 @@ def itk_to_np_order(img: np.array): img = np.swapaxes(img, 0, 3) img = np.swapaxes(img, 1, 2) return np.flip(img) - + def np_to_itk_order(img: np.array): # In numpy: CZYX @@ -68,26 +56,31 @@ def np_to_itk_order(img: np.array): img = np.swapaxes(img, 0, 3) img = np.swapaxes(img, 1, 2) return img - + def create_parameter_object(parameter_map_paths): parameter_object = itk.ParameterObject.New() for parameter_map_path in parameter_map_paths: parameter_object.AddParameterFile(parameter_map_path) print(parameter_object) - + return parameter_object -def run_registration(fixed_img, moving_img, parameter_map_paths, log_dir, log_name="elastix.log", set_threads=False): +def run_registration( + fixed_img, + moving_img, + parameter_map_paths, + log_dir, + log_name="elastix.log", + set_threads=False, +): logging.info("Start creating parameter object") parameter_object = create_parameter_object(parameter_map_paths) # Load Elastix Image Filter Object elastix_object = itk.ElastixRegistrationMethod.New(fixed_img, moving_img) logging.info(elastix_object) logging.info("Created registration object") - # elastix_object.SetFixedImage(fixed_image) - # elastix_object.SetMovingImage(moving_image) elastix_object.SetParameterObject(parameter_object) if set_threads: elastix_object.SetNumberOfThreads(32) @@ -96,17 +89,26 @@ def run_registration(fixed_img, moving_img, parameter_map_paths, log_dir, log_na elastix_object.SetOutputDirectory(log_dir) elastix_object.SetLogFileName(log_name) logging.info("Set parameter map") - + logging.info("Start registration") elastix_object.UpdateLargestPossibleRegion() - + result_image = elastix_object.GetOutput() result_transform_parameters = elastix_object.GetTransformParameterObject() - + return result_image, result_transform_parameters -def run_pointset_registration(fixed_img, moving_img, parameter_map_paths, fixed_pointset, moving_pointset, log_dir, log_name="elastix.log", set_threads=False): +def run_pointset_registration( + fixed_img, + moving_img, + parameter_map_paths, + fixed_pointset, + moving_pointset, + log_dir, + log_name="elastix.log", + set_threads=False, +): logging.info("Start creating parameter object") parameter_object = create_parameter_object(parameter_map_paths) # Load Elastix Image Filter Object @@ -120,17 +122,17 @@ def run_pointset_registration(fixed_img, moving_img, parameter_map_paths, fixed_ elastix_object.SetLogToFile(True) elastix_object.SetOutputDirectory(log_dir) elastix_object.SetLogFileName(log_name) - + logging.info("Created registration object") logging.info(elastix_object) logging.info("Set parameter map") - + logging.info("Start registration") elastix_object.UpdateLargestPossibleRegion() - + result_image = elastix_object.GetOutput() result_transform_parameters = elastix_object.GetTransformParameterObject() - + return result_image, result_transform_parameters @@ -138,9 +140,11 @@ def serialize_parameter_object(parameter_object, prefix, write_dir): write_dir = Path(write_dir) for index in range(parameter_object.GetNumberOfParameterMaps()): parameter_map = parameter_object.GetParameterMap(index) - parameter_object.WriteParameterFile(parameter_map, write_dir / f"{prefix}.{index}.txt") - - + parameter_object.WriteParameterFile( + parameter_map, write_dir / f"{prefix}_{index}.txt" + ) + + def deserialize_parameter_object(prefix, cur_dir=Path("./")): parameter_files = sorted(list(cur_dir.glob(f"{prefix}*.txt"))) parameter_files = [str(fname) for fname in parameter_files] @@ -154,7 +158,7 @@ def create_transformix_object(transform_parameter_object): transformix_filter = itk.TransformixFilter[ImageType].New() transformix_filter.SetTransformParameterObject(transform_parameter_object) return transformix_filter - + def apply_transform(transformix_filter, moving_img): transformix_filter.SetMovingImage(moving_img) @@ -164,10 +168,11 @@ def apply_transform(transformix_filter, moving_img): output_img_np = itk_to_np_order(output_img_np) return output_img_np + def apply_transform_chanwise(transform_parameter_object, moving_img_np, resolution): - transformix_filter = create_transformix_object(transform_parameter_object) + transformix_filter = create_transformix_object(transform_parameter_object) result_img = [] - + if moving_img_np.ndim == 4: for chan in range(moving_img_np.shape[0]): moving_img = itk_scalar_img(moving_img_np, resolution, chan) @@ -177,5 +182,5 @@ def apply_transform_chanwise(transform_parameter_object, moving_img_np, resoluti else: moving_img = itk_scalar_img(moving_img_np, resolution, 0) result_img = apply_transform(transformix_filter, moving_img) - + return result_img diff --git a/matchmaker/mobie_export.py b/matchmaker/mobie_export.py new file mode 100644 index 0000000..f8ba736 --- /dev/null +++ b/matchmaker/mobie_export.py @@ -0,0 +1,138 @@ +import os +import json +import logging +import mobie +from matchmaker.n5_utils import get_attrs + + +def update_default_view(dataset_json_path, new_segmentation_name): + """ + Update the name and sources for the segmentation in the 'default' view + in a MoBIE dataset.json file. + + Args: + dataset_json_path (str): Path to the dataset.json file. + new_segmentation_name (str): New segmentation name to set in the default view. + """ + if not os.path.exists(dataset_json_path): + raise FileNotFoundError(f"Could not find: {dataset_json_path}") + + with open(dataset_json_path, "r") as f: + data = json.load(f) + + views = data.get("views", {}) + default_view = views.get("default", {}) + + source_displays = default_view.get("sourceDisplays", []) + for display in source_displays: + if "segmentationDisplay" in display: + display["segmentationDisplay"]["name"] = new_segmentation_name + display["segmentationDisplay"]["sources"] = [new_segmentation_name] + + # Save the updated dataset.json + with open(dataset_json_path, "w") as f: + json.dump(data, f, indent=2) + + logging.info(f"Updated default view to use segmentation: '{new_segmentation_name}'") + + +def create_mobie_project( + fixed_input_path, + fixed_key, + moving_input_path, + moving_key, + output_dir, + dataset_name +): + """ + Create initial MoBIE project with the fixed and moving images. + + Args: + fixed_input_path (str): Path to the fixed image n5 file. + fixed_key (str): Key to the fixed image data in the n5 file. + moving_input_path (str): Path to the moving image n5 file. + moving_key (str): Key to the moving image data in the n5 file. + output_dir (str): Directory where the MoBIE project should be saved. + """ + + # export fixed image to MoBIE + fixed_file_name = os.path.splitext(os.path.basename(fixed_input_path))[0] + export_to_mobie( + fixed_input_path, + fixed_key, + output_dir, + dataset_name, + segmentation_name=f"{fixed_file_name}_original", + menu_name="fixed", + ) + + # export moving image to MoBIE + moving_file_name = os.path.splitext(os.path.basename(moving_input_path))[0] + export_to_mobie( + moving_input_path, + moving_key, + output_dir, + dataset_name, + segmentation_name=f"{moving_file_name}_original", + menu_name="moving", + ) + logging.info("Created initial MoBIE project with fixed and moving images.") + + +def export_to_mobie(input_path, input_key, output_dir, dataset_name, segmentation_name, menu_name): + """ + Export segmentation from n5 file to MoBIE project. + + Args: + input_path (str): Path to the n5 file containing the segmentation. + input_key (str): Key to the segmentation data in the n5 file. + output_dir (str): Directory where the MoBIE project should be saved. + dataset_name (str): Name of the MoBIE dataset. + segmentation_name (str): Name of the segmentation in the MoBIE project. + menu_name (str): Name of the menu in the MoBIE project. + """ + if not os.path.exists(f"{output_dir}/mobie_project"): + os.makedirs(f"{output_dir}/mobie_project") + + # Set parameters for MOBIE + mobie_folder = f"{output_dir}/mobie_project" + resolution = get_attrs(input_path, input_key)["resolution"] + chunks = (64, 64, 64) + scale_factors = 4 * [[2, 2, 2]] + + mobie.add_segmentation( + input_path=input_path, + input_key=input_key, + root=mobie_folder, + dataset_name=dataset_name, + segmentation_name=segmentation_name, + resolution=resolution, + scale_factors=scale_factors, + chunks=chunks, + menu_name=menu_name, + file_format="ome.zarr", + is_default_dataset=True + ) + logging.info(f"Added segmentation: {segmentation_name}") + + +def main(): + input_path = "../examples/CLI_test/platy1_muscles_stardist_fixed_prealigned.n5" + input_key = "seg" + output_dir = "../examples/data/test" + + file_name = os.path.splitext(os.path.basename(input_path))[0] + + export_to_mobie( + input_path, + input_key, + output_dir, + dataset_name="platy1_muscles_stardist", + segmentation_name=f"{file_name}_prealigned", + menu_name="fixed", + ) + print(f"MoBIE project created at {output_dir}/mobie_project") + + +if __name__ == "__main__": + main() diff --git a/matchmaker/n5_utils.py b/matchmaker/n5_utils.py new file mode 100644 index 0000000..aea3fe4 --- /dev/null +++ b/matchmaker/n5_utils.py @@ -0,0 +1,66 @@ +import z5py +from pathlib import PurePath +import numpy as np + + +def print_key_tree(f: z5py.File): + print(f"Key structure of z5 file {f.filename}") + f.visititems(lambda name, obj: print(name)) + + +def read_volume( + f: z5py.File, key: str, roi: np.lib.index_tricks.IndexExpression = np.s_[:] +): + if isinstance(f, (str, PurePath)): + f = z5py.File(f, "r") + + try: + ds = f[key] + except KeyError: + print(f"No key {key} in file {f.filename}") + print_key_tree(f) + return None + + ds.n_threads = 8 + vol = ds[roi] + + return vol + + +def get_attrs(f: z5py.File, key: str): + if isinstance(f, (str, PurePath)): + f = z5py.File(f, "a") + + try: + ds = f[key] + except KeyError: + print(f"No key {key} in file {f.filename}") + print_key_tree(f) + return None + + return ds.attrs + + +def write_volume(f, arr: np.array, key, chunks=(1, 512, 512), attrs=None): + shape = arr.shape + compression = "gzip" + dtype = arr.dtype + + if isinstance(f, (str, PurePath)): + f = z5py.File(f, "a") + + if key not in f.keys(): + ds = f.create_dataset( + key, shape=shape, compression=compression, chunks=chunks, dtype=dtype + ) + else: + ds = f[key] + + ds.n_threads = 8 + ds[:] = arr + + print(f"Dataset {key} written to {f.filename}") + + if attrs is not None: + for key in attrs.keys(): + ds.attrs[key] = attrs[key] diff --git a/matchmaker/prealignment.py b/matchmaker/prealignment.py old mode 100644 new mode 100755 index 40b8804..81ddef9 --- a/matchmaker/prealignment.py +++ b/matchmaker/prealignment.py @@ -1,164 +1,382 @@ -from pathlib import Path import numpy as np import matplotlib.pyplot as plt -from platy_reg.vis import plot_overlay, plot_three_slices import logging -from platy_reg.preprocessing import convert_to_point_cloud -from scipy.ndimage import rotate -from skimage.filters import gaussian +import os +import click +import sys +from matchmaker.data import create_point_cloud +from matchmaker.mobie_export import export_to_mobie, update_default_view +from matchmaker.transform_utils import get_transformation_matrix, rotate_img +from matchmaker.n5_utils import read_volume, get_attrs, write_volume +from matchmaker.vis import plot_three_slices, plot_overlay -def get_SVD_transform(img, plot_path=None, percentile_trsh=90): + +def get_SVD_transform(img, save_path=None): """Convert image to point cloud by thresholding, then run SVD on resulting point cloud. Args: img: _description_ plot_path: _description_. Defaults to None. - percentile_trsh: _description_. Defaults to 90. Returns: Variance matrix and principal axes matrix. """ - trsh = np.percentile(img, percentile_trsh) - logging.info(f"Threshold used for converting to point cloud: {trsh}") - X = convert_to_point_cloud(img, trsh) - gc = X.mean(axis=0) + + pos, _ = create_point_cloud(img) + gc = pos.mean(axis=0) gc = np.array(img.shape) // 2 - X_c = X - gc - logging.info(f"Point cloud shape {X.shape}") + pos_c = pos - gc + logging.info(f"Point cloud shape {pos.shape}") logging.info(f"Point cloud center {gc}") - - if X.shape[1] == 3: - random_subset = np.random.choice(X.shape[0], 10000, replace=False) - X_subset = X_c[random_subset, 1:] - else: - X_subset = X_c[::100] - logging.info("Subset point cloud") - logging.info(f"Subset point cloud shape {X_subset.shape}") - - logging.info("Run SVD") - U, S, Vt = np.linalg.svd(X_subset) - # logging.info("U") - # logging.info(U) + + logging.info("Run SVD ...") + U, S, Vt = np.linalg.svd(pos_c, full_matrices=False) + logging.info("U") + logging.info(str(U)) logging.info("S") logging.info(str(S)) logging.info("Vt") logging.info(str(Vt)) - + logging.info("Rotate point cloud") - vr = X_subset @ Vt.T - + vr = pos_c @ Vt.T + plt.figure(figsize=(10, 5)) - plt.subplot(1,2,1) - plt.title('original vertices') - plt.scatter(X_subset[:, 0], X_subset[:, 1], alpha=0.2) - plt.subplot(1,2,2) - plt.title('rotated vertices') + plt.subplot(1, 2, 1) + plt.title("Original Vertices") + plt.scatter(pos_c[:, 0], pos_c[:, 1], alpha=0.2) + plt.subplot(1, 2, 2) + plt.title("Rotated Vertices") plt.scatter(vr[:, 0], vr[:, 1], alpha=0.1) - if plot_path: - plt.savefig(plot_path, dpi=300) - + if save_path: + plt.savefig(save_path, dpi=300) + return gc, Vt - - -def orient_head(img, plot_path=None): - """Euristic to orient all samples "head up": calculate sum intensity profile along the Y axis, if max is closer to 0 then do nothing, else rotate 180 degrees. + + +def orient_axis(img, axis, save_path=None): + """ + Plot the sum intensity along the given axis and return True if the maximum is closer to the + upper boundary than the lower boundary, False otherwise. Args: - img: DAPI volume - + img: 3D image + axis: Axis to sum along + save_path: Path to save the plot to. If None, show the plot instead. + Returns: - True if rotation is needed. + True if the maximum is closer to the upper boundary than the lower boundary, False otherwise. """ assert img.ndim == 3, f"Input image should have 3 dimensions, has {img.ndim}" - int_profile = np.sum(img, axis=(0, 2)) - + if axis == 0: + int_profile = np.sum(img, axis=(1, 2)) # profile along z + if axis == 1: + int_profile = np.sum(img, axis=(0, 2)) # profile along y + if axis == 2: + int_profile = np.sum(img, axis=(0, 1)) # profile along x + plt.figure() plt.plot(int_profile) - plt.xlabel("Coordinate") - plt.ylabel("Sum intensity along Y axis") - plt.savefig(plot_path, dpi=300) - + plt.xlabel(f"Axis {axis} Coordinate") + plt.ylabel(f"Sum intensity along axis = {axis}") + if save_path is not None: + plt.savefig(save_path, dpi=300) + else: + plt.show() + max_pos = int_profile.argmax() - logging.info(f"Max position is {max_pos}, dimension shape is {img.shape[1]}") - if max_pos < img.shape[1] // 2: - logging.info("Correct head orientation") + logging.info(f"Max position is {max_pos}, dimension shape is {img.shape[axis]}") + if max_pos < img.shape[axis] // 2: return False else: - logging.info("Rotate 180 degree to align head position") return True - -def orient_sample(img, dapi_chan, plot_path, dorsal=False): - """Preliminary orientation of the samples with body axis along Y, head closer to 0. + +def prealign_sample(img, file_name, save_path): + """ + Pre-align a sample segmentation with its principal components. Args: - img: input volume - dapi_chan: channel to use for registration - dorsal: if True, rotate around Y to align with ventral in Z direction + img: Segmentation volume + file_name: Name of the sample + save_path: Folder to save results Returns: - oriented image + Pre-aligned segmentation volume. """ - - # Rotate around Y if the volume is dorsal - plot_path = Path(plot_path) - - if dorsal: - logging.info(f"Rotated dorsal sample to align Z") - img = img[:, ::-1, :, ::-1] - - - # Rotate in XY plane, because samples can be oriented randomly, not only along X or Y - max_proj = np.max(img, axis=1)[dapi_chan, ...] - plt.figure() - plt.imshow(max_proj, cmap="Reds") - plt.savefig(plot_path / "max_proj_input.png") - - logging.info(f"Smooth image with sigma={2}") - img_smoothed = gaussian(img[dapi_chan, ::10, ::10, ::10], sigma=3) - gc, Vt = get_SVD_transform(img_smoothed, plot_path / "max_proj_point_cloud_random_angle.png", percentile_trsh=90) - rot_angle = 90 - np.degrees(np.arctan2(Vt[0, 1], Vt[0, 0])) - logging.info(f"Rotation angle to correct for random angle is {rot_angle}") - img = rotate(img, rot_angle, axes=[-2, -1], order=3, mode="constant") - logging.info(f"Rotated the input volume around Z by {rot_angle} degrees") - - plt.figure() - plt.imshow(np.max(img, axis=1)[dapi_chan, ...], alpha=0.5, cmap="Blues") - plt.savefig(plot_path / "max_proj_rotated_random_angle.png", dpi=300) - - - # Check again if the sample is along X or along Y - # Make max projection and determine the direction of principal axes - max_proj = np.max(img, axis=1)[dapi_chan, ...] - img_smoothed = gaussian(img[dapi_chan, ::10, ::10, ::10], sigma=2) - gc, Vt = get_SVD_transform(img_smoothed, plot_path / "max_proj_point_cloud.png") - - rot_angle = np.arccos(Vt[0, 0]) - logging.info(f"Rotation angle is {rot_angle}") - if np.abs(rot_angle - np.pi / 2) < np.pi / 4: - logging.info("Rotate 90 degrees") - plt.figure() - plt.imshow(max_proj, cmap="Reds") - img_rot = np.rot90(img, axes=(2, 3)) - - elif np.abs(rot_angle) < np.pi / 4: - logging.info("Rotation is already correct") - img_rot = img - else: - logging.info(f"90 degree rotation can't be determined with first principal axis angle of {rot_angle * 180/ np.pi}") - img_rot = img - - # Check if head is oriented correctly, else rotate 180 degrees - - if orient_head(img_rot[dapi_chan, ...], plot_path / "sum_intensity_profile.png"): - img_reg = np.rot90(img_rot, k=2, axes=(2, 3)) + gc, Vt = get_SVD_transform(img) + T, new_shape = get_transformation_matrix(img, gc, Vt, save_path=f"{save_path}/{file_name}_T_prealignment.txt") + img_rotated = rotate_img(img, T, output_shape=new_shape) + + if np.linalg.det(Vt.T) < 0: + logging.warning("V includes a reflection (mirroring)") + R_3x3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) + logging.info("Mirror back...") + img_center = 0.5 * (np.array(img_rotated.shape)-1) + offset = img_center - R_3x3 @ img_center + R = np.eye(4) + R[:3, :3] = R_3x3 + R[:3, 3] = offset + img_rotated = rotate_img(img_rotated, R, output_shape=img_rotated.shape) + + # update transformation matrix + T = T @ R + np.savetxt(f"{save_path}/{file_name}_T_prealignment.txt", T) + + return img_rotated, T + + +def run_prealignment( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + mobie_export, + dataset_name +): + """ + Run prealignment of a fixed and moving 3D image volume. + + This function reads two volumetric datasets (a fixed and a moving image), + performs prealignment to roughly register them into a common space, and + saves diagnostic plots, transformation matrices, and prealigned volumes. + It also checks axis orientation consistency between the two images and + applies corrective rotations if necessary. Optionally, the results can be + exported into a MoBIE project for interactive visualization. + + Steps performed: + 1. Load fixed and moving volumes. + 2. Plot reference slices and overlays before alignment. + 3. Apply prealignment to both volumes. + 4. Check and correct axis orientations if required. + 5. Save prealigned volumes and transformation matrices. + 6. Generate plots before and after prealignment. + 7. Optionally export results to a MoBIE project. + + Parameters + ---------- + fixed_path : str + Path to the fixed volume file (e.g. N5, OME-Zarr). + fixed_key : str + Dataset key inside the fixed volume file. + moving_path : str + Path to the moving volume file (e.g. N5, OME-Zarr). + moving_key : str + Dataset key inside the moving volume file. + output_dir : str + Directory where outputs (plots, volumes, transformations) will be saved. + mobie_export : bool + If True, export prealigned images to a MoBIE project. + dataset_name : str + Name of the MoBIE dataset (used only if `mobie_export=True`). + + Outputs + ------- + - Plots of slices and overlays before and after prealignment, saved in + ``{output_dir}/plots/``. + - Prealigned fixed and moving volumes saved as N5 containers in + ``{output_dir}/``. + - Transformation matrix for the moving image saved as a text file. + - (Optional) Exported MoBIE project with updated views. + + Notes + ----- + - The function assumes the input volumes are large 3D datasets. + - Axis orientation is checked via intensity profile analysis; axes may be + flipped by 180° if misaligned. + - The MoBIE export modifies the `dataset.json` to set the prealigned fixed + volume as the default view. + """ + if not os.path.exists(f"{output_dir}/plots"): + os.makedirs(f"{output_dir}/plots") + + logging.info("Start prealignment") + logging.info("Start prealignment of fixed image ...") + fixed_img = read_volume(fixed_path, fixed_key) + fixed_file_name = os.path.splitext(os.path.basename(fixed_path))[0] + + plot_three_slices( + fixed_img, + save_path=f"{output_dir}/plots/{fixed_file_name}_fixed.png" + ) + + fixed_prealigned, _ = prealign_sample( + fixed_img, + fixed_file_name, + output_dir, + ) + + logging.info("Start prealignment of moving image ...") + moving_img = read_volume(moving_path, moving_key) + moving_file_name = os.path.splitext(os.path.basename(moving_path))[0] + + plot_three_slices( + moving_img, + save_path=f"{output_dir}/plots/{moving_file_name}_moving.png" + ) + + plot_overlay( + fixed_img, + moving_img, + save_path=f"{output_dir}/plots/overlay_before.png", + ) + + moving_prealigned, T_moving = prealign_sample( + moving_img, + moving_file_name, + output_dir, + ) + + # check orientation (if moving fits to fixed) + logging.info("Check axis orientation ...") + R_3x3 = np.eye(3) + change_orientation = False + for axis in range(3): + rotate_axis_fixed = orient_axis( + fixed_prealigned, + axis=axis, + save_path=f"{output_dir}/plots/fixed_prealigned_intensity_profile_{axis}.png", + ) + rotate_axis_moving = orient_axis( + moving_prealigned, + axis=axis, + save_path=f"{output_dir}/plots/moving_prealigned_intensity_profile_{axis}.png", + ) + + if rotate_axis_fixed or rotate_axis_moving: + print(f"Rotate axis {axis} 180 degrees to align...") + change_orientation = True + R_3x3[axis, axis] = -1 + else: + print(f"Correct orientation in axis {axis}.") + + if not change_orientation: + logging.info("Correct orientation.") else: - img_reg = img_rot - - logging.info(f"Final image shape is {img_reg.shape}") - plt.figure() - plt.imshow(np.max(img_reg, axis=1)[dapi_chan, ...], alpha=0.5, cmap="Blues") - plt.savefig(plot_path / "max_proj_rotated_final.png", dpi=300) - - return img_reg \ No newline at end of file + logging.info("Rotate moving image to align orientation...") + logging.info(str(R_3x3)) + img_center = 0.5 * (np.array(moving_prealigned.shape)-1) + offset = img_center - R_3x3 @ img_center + R = np.eye(4) + R[:3, :3] = R_3x3 + R[:3, 3] = offset + moving_prealigned = rotate_img(moving_prealigned, R, output_shape=moving_prealigned.shape) + + # update transformation matrix + T_moving = T_moving @ R + np.savetxt(f"{output_dir}/{moving_file_name}_T_prealignment.txt", T_moving) + + logging.info("Prealignment done.") + + logging.info("Save prealigned fixed image ...") + attributes = dict(get_attrs(fixed_path, fixed_key)) + write_volume( + f=f"{output_dir}/{fixed_file_name}_prealigned.n5", + arr=fixed_prealigned, + key=fixed_key, + attrs=attributes + ) + + logging.info("Save prealigned moving image ...") + attributes = dict(get_attrs(moving_path, moving_key)) + write_volume( + f=f"{output_dir}/{moving_file_name}_prealigned.n5", + arr=moving_prealigned, + key=moving_key, + attrs=attributes + ) + + plot_three_slices( + fixed_prealigned, + save_path=f"{output_dir}/plots/{fixed_file_name}_prealigned.png" + ) + + plot_three_slices( + moving_prealigned, + save_path=f"{output_dir}/plots/{moving_file_name}_prealigned.png" + ) + + plot_overlay( + fixed_prealigned, + moving_prealigned, + save_path=f"{output_dir}/plots/overlay_after_prealignment.png", + ) + + if mobie_export: + logging.info("Export prealigned images to MoBIE ...") + export_to_mobie( + input_path=f"{output_dir}/{fixed_file_name}_prealigned.n5", + input_key=fixed_key, + output_dir=output_dir, + dataset_name=dataset_name, + segmentation_name=f"{fixed_file_name}_prealigned", + menu_name="fixed" + ) + logging.info("Change default dataset to prealigned fixed image.") + update_default_view( + dataset_json_path=f"{output_dir}/mobie_project/{dataset_name}/dataset.json", + new_segmentation_name=f"{fixed_file_name}_prealigned" + ) + + export_to_mobie( + input_path=f"{output_dir}/{moving_file_name}_prealigned.n5", + input_key=moving_key, + output_dir=output_dir, + dataset_name=dataset_name, + segmentation_name=f"{moving_file_name}_prealigned", + menu_name="moving" + ) + logging.info("MoBIE export done.") + + +@click.command() +@click.option("-fi", "--fixed_path", required=True, help="Fixed input .n5 file") +@click.option("-fk", "--fixed_key", required=True, help="Fixed input key") +@click.option("-mi", "--moving_path", required=True, help="Moving input .n5 file") +@click.option("-mk", "--moving_key", required=True, help="Moving input key") +@click.option("-o", "--output_dir", required=True, help="Output directory") +@click.option("-m", "--mobie_export", required=False, is_flag=True, help="MoBIE export") +def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, mobie_export): + """ + Perform prealignment of moving image to fixed image. + + This function orchestrates the sequence of steps required to prealign a moving image to a fixed image. + It handles the creation of necessary directories, configures logging, and invokes prealignment functions. + Optionally, it can create a MoBIE project for visualization. + + Args: + fixed_path (str): Path to the fixed input .n5 file. + fixed_key (str): Key to the fixed image data in the .n5 file. + moving_path (str): Path to the moving input .n5 file. + moving_key (str): Key to the moving image data in the .n5 file. + output_dir (str): Directory where the results should be saved. + mobie_export (bool): Flag indicating whether to export results to a MoBIE project. + + Returns: + None + """ + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[ + logging.FileHandler(f"{output_dir}/prealignment.log", mode="w"), + logging.StreamHandler(sys.stdout), + ], + datefmt="%Y-%m-%d %H:%M:%S", + ) + + run_prealignment( + fixed_path, + fixed_key, + moving_path, + moving_key, + output_dir, + mobie_export, + dataset_name="platy1_muscles_stardist", + ) + + +if __name__ == "__main__": + main() diff --git a/matchmaker/preprocessing.py b/matchmaker/preprocessing.py new file mode 100644 index 0000000..5f51629 --- /dev/null +++ b/matchmaker/preprocessing.py @@ -0,0 +1,31 @@ +import numpy as np + + +def zero_mean_unit_variance(img): + return (img - np.mean(img)) / np.std(img) + + +def percentile_norm(img, pmin, pmax, eps=1e-10, channelwise=False): + """ + Percentile normalization. Image is supposed to be C(Z)YX + + Args: + img: _description_ + pmin: in percents + pmax: in percents + eps: _description_. Defaults to 1e-10. + + Returns: + _description_ + """ + if channelwise: + norm_img = np.zeros_like(img) + for chan in range(norm_img.shape[0]): + pmin_val = np.percentile(img[chan, ...], pmin) + pmax_val = np.percentile(img[chan, ...], pmax) + norm_img[chan, ...] = (img[chan, ...] - pmin_val) / (pmax_val - pmin_val + eps) + return norm_img + else: + pmin = np.percentile(img, pmin) + pmax = np.percentile(img, pmax) + return (img - pmin) / (pmax - pmin + eps) diff --git a/matchmaker/transform_utils.py b/matchmaker/transform_utils.py new file mode 100644 index 0000000..f05878f --- /dev/null +++ b/matchmaker/transform_utils.py @@ -0,0 +1,147 @@ +import numpy as np +from scipy.ndimage import affine_transform +import transforms3d as tf3d +from elf.wrapper.resized_volume import ResizedVolume + + +def downscale_seg(seg, factor): + new_shape = np.array(seg.shape) // factor + downsampled_seg = ResizedVolume(seg, shape=new_shape)[:] + print("Downsampled shape", downsampled_seg.shape) + + return downsampled_seg + + +def pad_img(img): + shape = img.shape + diagonal = int(np.ceil(np.linalg.norm(shape))) + print("Diagonal", diagonal) + + # Calculate how much padding is needed on each axis + pad_widths = [] + for dim in shape: + total_pad = diagonal - dim + before = total_pad // 2 + after = total_pad - before + pad_widths.append((before, after)) + + # Apply zero padding + padded = np.pad(img, pad_width=pad_widths, mode='constant', constant_values=0) + print("New shape after padding:", padded.shape) + + return padded + + +def crop_to_bbox(img): + """Crops a 3D volume to the minimal bounding box around all nonzero voxels.""" + # Find where the volume is nonzero (i.e., contains instances) + nonzero = np.argwhere(img) + + # Compute bounding box from min to max index along each axis + z_min, y_min, x_min = nonzero.min(axis=0) + z_max, y_max, x_max = nonzero.max(axis=0) + 1 # +1 to include the max index + + # Crop the volume + cropped = img[z_min:z_max, y_min:y_max, x_min:x_max] + return cropped + + +def get_rotated_shape(img, rotation_matrix): + # Step 1: Define the 8 corners of the original volume + dz, dy, dx = img.shape + + corners = np.array([ + [0, 0, 0], + [0, 0, dx], + [0, dy, 0], + [0, dy, dx], + [dz, 0, 0], + [dz, 0, dx], + [dz, dy, 0], + [dz, dy, dx] + ]) + + # Step 2: Apply the rotation matrix + rotated_corners = corners @ rotation_matrix[0:3, 0:3] # Apply rotation + + # Step 3: Find min and max of the rotated corners + min_coords = rotated_corners.min(axis=0) + max_coords = rotated_corners.max(axis=0) + + # Step 4: Calculate the new shape + new_shape = (np.ceil(max_coords - min_coords).astype(int)) + + return new_shape + + +def get_translation_matrix(translation): + M = np.identity(4) + M[0:3, 3] = translation + return M + + +def get_rotation(angles): + R = tf3d.euler.euler2mat(*angles, axes='szyx') + return R + + +def get_rotation_matrix(R): + M = np.identity(4) + M[0:3, 0:3] = R + return M + + +def get_transformation_matrix(img, gc, Vt, save_path=None): + # 1. center image on origin + center_to_origin = get_translation_matrix(gc) + # 2. rotate image + rot = get_rotation_matrix(Vt.T) + # 3. get new shape + new_shape = get_rotated_shape(img, rot) + # 4. center image on new shape + new_shape_center = np.array(new_shape) // 2 + center_to_new_shape = get_translation_matrix(-new_shape_center) + # 5. combine all transforms: get transformation matrix + T = center_to_origin @ rot @ center_to_new_shape + + if save_path is not None: + np.savetxt(save_path, T) + + return T, new_shape + + +def rotate_img(img, rotation_matrix, output_shape=None, offset=None): + """ + Rotate an image using a given rotation matrix. + + Parameters + ---------- + img : array + The 3D image to be rotated. + rotation_matrix : array + A 3x3 or 4x4 rotation matrix. + output_shape : tuple, optional + The desired output shape of the rotated image. If not given, the output shape + will be determined from the rotation matrix. + offset : tuple, optional + The offset to apply to the rotated image to ensure it fits within the new + bounding box. If not given, the offset will be determined from the rotation + matrix. + + Returns + ------- + rotated_img : array + The rotated image with the desired output shape (if given). + """ + rotated_img = affine_transform( + img, + matrix=rotation_matrix, + output_shape=output_shape, # new shape after rotation + offset=offset, # offset to ensure the image fits within the new bounding box + order=0, # interpolation (use 0 for discrete/label data) + mode='constant', # fill mode + cval=0.0 # fill value (if constant mode) + ) + + print(f"Shape after rotation: {rotated_img.shape}") + return rotated_img diff --git a/matchmaker/vis.py b/matchmaker/vis.py index af78530..d5b6a3f 100644 --- a/matchmaker/vis.py +++ b/matchmaker/vis.py @@ -1,10 +1,20 @@ import matplotlib.pyplot as plt import numpy as np -from platy_reg.preprocessing import percentile_norm +from matchmaker.preprocessing import percentile_norm -def plot_three_slices(img, save_path, x_pos=None, y_pos=None, z_pos=None, cmap="Greys_r", max_pos=False, alpha=False): - """Plot slices of a 3D image along each axis. +def plot_three_slices( + img, + save_path=None, + x_pos=None, + y_pos=None, + z_pos=None, + cmap="Greys_r", + max_pos=False, + alpha=False, +): + """ + Plot slices of a 3D image along each axis. Args: img: _description_ @@ -22,34 +32,34 @@ def plot_three_slices(img, save_path, x_pos=None, y_pos=None, z_pos=None, cmap=" y_pos = int(img.shape[1] // 2) if z_pos is None: z_pos = int(img.shape[0] // 2) - + if max_pos: z_pos, y_pos, x_pos = np.unravel_index(np.argmax(img), img.shape) - + if alpha: alpha = (img > 0).astype(np.float32) else: alpha = np.ones_like(img) plt.figure(figsize=(15, 5)) - plt.subplot(1,3,1) - plt.title(f'z slice at {z_pos}') + plt.subplot(1, 3, 1) + plt.title(f"z slice at {z_pos}") plt.imshow(img[z_pos, :, :], cmap=cmap, alpha=alpha[z_pos, :, :]) - plt.subplot(1,3,2) - plt.title(f'y slice at {y_pos}') + plt.subplot(1, 3, 2) + plt.title(f"y slice at {y_pos}") plt.imshow(img[:, y_pos, :], cmap=cmap, alpha=alpha[:, y_pos, :]) - plt.subplot(1,3,3) - plt.title(f'x slice at {x_pos}') + plt.subplot(1, 3, 3) + plt.title(f"x slice at {x_pos}") plt.imshow(img[:, :, x_pos], cmap=cmap, alpha=alpha[:, :, x_pos]) if save_path is None: plt.show() else: plt.savefig(save_path, dpi=300) plt.close() - - - + + def plot_overlay(img1, img2, save_path=None, x_pos=None, y_pos=None, z_pos=None): - """Plot slices of two 3D images along each axis. + """ + Plot slices of two 3D images along each axis. Args: img1: _description_target_shape = (25, 22, 29) @@ -61,29 +71,29 @@ def plot_overlay(img1, img2, save_path=None, x_pos=None, y_pos=None, z_pos=None) """ assert img1.ndim == 3 assert img2.ndim == 3 - + if x_pos is None: x_pos = min(int(img1.shape[2] // 2), int(img2.shape[2] // 2)) if y_pos is None: y_pos = min(int(img1.shape[1] // 2), int(img2.shape[1] // 2)) if z_pos is None: z_pos = min(int(img1.shape[0] // 2), int(img2.shape[0] // 2)) - + plt.figure(figsize=(30, 10), dpi=300) - plt.subplot(1,3,1) - plt.title(f'z slice at {z_pos}') + plt.subplot(1, 3, 1) + plt.title(f"z slice at {z_pos}") img1_alpha = (percentile_norm(img1, 0, 100) > 0) * 0.5 img2_alpha = (percentile_norm(img2, 0, 100) > 0) * 0.5 plt.imshow(img1[z_pos, :, :], cmap="Reds", alpha=img1_alpha[z_pos, :, :]) plt.imshow(img2[z_pos, :, :], cmap="Blues", alpha=img2_alpha[z_pos, :, :]) - - plt.subplot(1,3,2) - plt.title(f'y slice at {y_pos}') + + plt.subplot(1, 3, 2) + plt.title(f"y slice at {y_pos}") plt.imshow(img1[:, y_pos, :], cmap="Reds", alpha=img1_alpha[:, y_pos, :]) plt.imshow(img2[:, y_pos, :], cmap="Blues", alpha=img2_alpha[:, y_pos, :]) - - plt.subplot(1,3,3) - plt.title(f'x slice at {x_pos}') + + plt.subplot(1, 3, 3) + plt.title(f"x slice at {x_pos}") plt.imshow(img1[:, :, x_pos], cmap="Reds", alpha=img1_alpha[:, :, x_pos]) plt.imshow(img2[:, :, x_pos], cmap="Blues", alpha=img2_alpha[:, :, x_pos]) @@ -91,4 +101,4 @@ def plot_overlay(img1, img2, save_path=None, x_pos=None, y_pos=None, z_pos=None) plt.show() else: plt.savefig(save_path, dpi=300) - plt.close() \ No newline at end of file + plt.close()