diff --git a/README.md b/README.md index 67c14ca..f87d759 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Optional: If you want to hide all warnings during tests, append: This test generates deformed test data, runs the Snakemake workflow, and compares the final outputs against pre-computed reference results. -The reference data will be downloaded automatically from this repo's release if available. If the download fails (e.g. the repository is private), manually download the reference data from [here](https://github.com/kreshuklab/matchmaker/releases/tag/test_data-v0.2) and place it under: +The reference data will be downloaded automatically from this repo's release if available. If the download fails (e.g. the repository is private), manually download the reference data from [here](https://github.com/kreshuklab/matchmaker/releases/tag/test_data-v1.0) and place it under: ``` examples/data/test_data/ @@ -80,8 +80,8 @@ mm.n5-utils.read_volume(...) ## Registration ### Input -- **Fixed image**: 3D instance segmentation in `n5` + resolution -- **Moving image**: 3D instance segmentation `n5` + resolution +- **Fixed image**: 3D instance segmentation in `n5` + resolution +- **Moving image**: 3D instance segmentation `n5` + resolution Expected image shape: ZYX @@ -100,7 +100,7 @@ Expected image shape: ZYX **1. PCA pre-alignment**: alignment of fixed and moving image to the PCs \ `prealignment.py --fixed_path ... --fixed_key ... --moving_path ... --moving_key ... --output_dir ... --mobie_export --dataset_name ...` \ -Outputs: +Outputs: - prealigned images: `{file_name}_prealigned.n5` - transformation matrixes - `{file_name}_fixed_T_prealignment.txt` @@ -114,7 +114,7 @@ Outputs: **2. Rigid pre-alignment with Elastix** \ `apply_rigid_elastix.py --fixed_path ... --fixed_key ... --moving_path ... --moving_key ... --output_dir ... --mobie_export --dataset_name ...` \ -Outputs: +Outputs: - rigid alinged moving image: `{file_name}_rigid_aligned.n5` - rigid transformation matrix (Elastix outputs): `result.0.mhd`, `result.0.raw`, `TransformParameters.0.txt` - logging file: `elastix_log_rigid.log` @@ -143,4 +143,4 @@ with distance between keypoints in loss and rigidity penalty ---> **B-spline coe - **Moving image resampled to match fixed image** - only rigid - `n5` - **Moving image resampled to match fixed image** - deformable - `n5` -Expected image shape: (C)ZYX \ No newline at end of file +Expected image shape: (C)ZYX diff --git a/examples/deform_test_data.py b/examples/deform_test_data.py index 178958c..75f9fd8 100644 --- a/examples/deform_test_data.py +++ b/examples/deform_test_data.py @@ -8,7 +8,7 @@ from matchmaker.utils import (get_transformation_matrix, rotate_img, write_volume, plot_three_slices, plot_overlay, grid_sample3d, load_config, - crop_to_bbox) + crop_to_bbox, resample_volume) def remove_instances(seg, prob=0.05, seed=None): @@ -32,33 +32,38 @@ def remove_instances(seg, prob=0.05, seed=None): return seg -def save_volume(path, array, key="seg", chunks=(128, 512, 512), attributes={"resolution":[1,1,1]}, +def save_volume(path, array, key="seg", chunks=(128, 512, 512), resolution=[1,1,1], save_tif=True): assert path.endswith(".n5") - write_volume(f=path, arr=array, key=key, chunks=chunks, attrs=attributes,) + write_volume(f=path, arr=array, key=key, chunks=chunks, attrs={"resolution":resolution,},) if save_tif: tif.imwrite(path.replace(".n5", ".tif"), array) -def rigid_deform(fixed, angles): +def rigid_deform(fixed, angles, voxel_spacing=None): if not isinstance(angles, (list, tuple)): raise TypeError() assert len(angles) == 3 + iso_spacing = np.asarray([1, 1, 1], dtype=np.float32) center = np.array(fixed.shape) // 2 rotation = tf3d.euler.euler2mat( *[np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2])], axes="szyx" ) - T, new_shape = get_transformation_matrix(fixed, center, rotation) + T, new_shape = get_transformation_matrix(fixed, center, rotation, iso_spacing) moving = rotate_img(fixed, T, output_shape=new_shape) + + if not np.array_equal(voxel_spacing, iso_spacing): + moving = resample_volume(moving, iso_spacing, voxel_spacing) + return moving -def elastic_deform(volume, alpha=(1.,1.,1.), sigma=None, spacing=16, mode="nearest", +def elastic_deform(volume, alpha=(1.,1.,1.), sigma=None, grid_spacing=16, mode="nearest", align_corners=False, seed=None,): """ Apply elastic deformation to a 3D volume. @@ -67,7 +72,7 @@ def elastic_deform(volume, alpha=(1.,1.,1.), sigma=None, spacing=16, mode="neare volume (np.ndarray): Input volume of shape (D, H, W). alpha (tuple[float, float, float]): Displacement amplitude scaling factors. sigma (float): Gaussian smoothing std. - spacing (int | tuple[int, int, int]): Control point spacing (in voxels). + grid_spacing (int | tuple[int, int, int]): Control point spacing (in voxels). mode (str): Interpolation mode, "nearest" or "trilinear". align_corners (bool): Grid sampling alignment flag. seed (int | None): Random seed. @@ -91,15 +96,15 @@ def elastic_deform(volume, alpha=(1.,1.,1.), sigma=None, spacing=16, mode="neare else: raise ValueError - if isinstance(spacing, int): - spacing = [spacing] * 3 - elif isinstance(spacing, (list, tuple)): - assert len(spacing) == 3 + if isinstance(grid_spacing, int): + grid_spacing = [grid_spacing] * 3 + elif isinstance(grid_spacing, (list, tuple)): + assert len(grid_spacing) == 3 else: raise ValueError if sigma is None: - sigma = (spacing[0] / 2, spacing[1] / 2, spacing[2] / 2) + sigma = (grid_spacing[0] / 2, grid_spacing[1] / 2, grid_spacing[2] / 2) elif isinstance(sigma, (int, float)): sigma = (sigma, sigma, sigma) elif isinstance(sigma, (list, tuple)): @@ -108,13 +113,13 @@ def elastic_deform(volume, alpha=(1.,1.,1.), sigma=None, spacing=16, mode="neare else: raise ValueError - shape = (int(np.ceil(D/spacing[0])), int(np.ceil(H/spacing[1])), int(np.ceil(W/spacing[2]))) + shape = (int(np.ceil(D/grid_spacing[0])), int(np.ceil(H/grid_spacing[1])), int(np.ceil(W/grid_spacing[2]))) disp = np.random.randn(*shape, 3).astype(np.float32) disp = gaussian(disp, sigma=(*sigma, 0), mode="constant", preserve_range=True,) disp *= alpha - if (spacing[0] > 1) or (spacing[1] > 1) or (spacing[2] > 1): + if (grid_spacing[0] > 1) or (grid_spacing[1] > 1) or (grid_spacing[2] > 1): zoom_factors = (D / shape[0], H / shape[1], W / shape[2], 1.,) disp = zoom(disp, zoom_factors, order=1) @@ -132,42 +137,50 @@ def normalize_axis(d): return sampled.astype(volume.dtype) -def deform_test_data(cfg_path="", config=None, enable_elastic=False, alpha=0.9, sigma=2, spacing=16, - rotate_angles_fixed=[20,345,30], rotate_angles_moving=[155,30,65], - remove_p=0.05, seed=42, visualize=True): +def deform_test_data(cfg_path="", config=None, enable_aniso=False, enable_elastic=False, + alpha=0.9, sigma=2, grid_spacing=16, rotate_angles_fixed=[20,345,30], + rotate_angles_moving=[155,30,65], remove_p=0.05, seed=42, visualize=True): if config is None: config = load_config(cfg_path) data_dir = Path(config["fixed_image"]["path"]).parent data_dir.mkdir(parents=True, exist_ok=True) + fixed_spacing = [config["fixed_image"]["z_res"], config["fixed_image"]["y_res"], config["fixed_image"]["x_res"]] + moving_spacing = [config["moving_image"]["z_res"], config["moving_image"]["y_res"], config["moving_image"]["x_res"]] + fixed_spacing = np.asarray(fixed_spacing, dtype=np.float32) + moving_spacing = np.asarray(moving_spacing, dtype=np.float32) + + if enable_aniso: + assert not np.all(moving_spacing == 1) + seg_fixed = tif.imread(config["fixed_image"]["source_path"]) + seg_moving = seg_fixed.copy() - seg_fixed = rigid_deform(seg_fixed, angles=rotate_angles_fixed) + seg_fixed = rigid_deform(seg_fixed, rotate_angles_fixed, fixed_spacing) seg_fixed = crop_to_bbox(seg_fixed) print("Fixed volume shape", seg_fixed.shape) - seg_moving = seg_fixed.copy() if enable_elastic: - seg_moving = elastic_deform(seg_moving, alpha=alpha, sigma=sigma, spacing=spacing, seed=seed,) + seg_moving = elastic_deform(seg_moving, alpha=alpha, sigma=sigma, grid_spacing=grid_spacing, seed=seed,) - seg_moving = rigid_deform(seg_moving, angles=rotate_angles_moving) + seg_moving = rigid_deform(seg_moving, rotate_angles_moving, moving_spacing) seg_moving = remove_instances(seg_moving, prob=remove_p, seed=seed) seg_moving = crop_to_bbox(seg_moving) print("Moving volume shape", seg_moving.shape) - save_volume(config["fixed_image"]["path"].replace(".tif", ".n5"), seg_fixed) - save_volume(config["moving_image"]["path"].replace(".tif", ".n5"), seg_moving) + save_volume(config["fixed_image"]["path"].replace(".tif", ".n5"), seg_fixed, resolution=fixed_spacing.tolist()) + save_volume(config["moving_image"]["path"].replace(".tif", ".n5"), seg_moving, resolution=moving_spacing.tolist()) if visualize: plot_dir = data_dir / "plots" plot_dir.mkdir(parents=True, exist_ok=True) - plot_three_slices(seg_fixed, save_path=plot_dir/"seg_fixed.png") + iso_name = "aniso_" if enable_aniso else "" + elastic_name = "elastic" if enable_elastic else "rigid" - moving_name = "seg_elastic" if enable_elastic else "seg_rigid" - plot_three_slices(seg_moving, save_path=plot_dir/f"{moving_name}.png") - overlay_name = "elastic_overlay" if enable_elastic else "rigid_overlay" - plot_overlay(seg_fixed, seg_moving, save_path=plot_dir/f"{overlay_name}.png") + plot_three_slices(seg_fixed, save_path=plot_dir/f"seg_{iso_name}fixed.png") + plot_three_slices(seg_moving, save_path=plot_dir/f"seg_{iso_name}{elastic_name}.png") + plot_overlay(seg_fixed, seg_moving, save_path=plot_dir/f"{iso_name}{elastic_name}_overlay.png") if __name__ == "__main__": diff --git a/examples/register_config_test_aniso_rigid.yaml b/examples/register_config_test_aniso_rigid.yaml new file mode 100644 index 0000000..65f70d3 --- /dev/null +++ b/examples/register_config_test_aniso_rigid.yaml @@ -0,0 +1,36 @@ +fixed_image: + source_path: "examples/data/platy1_muscles_stardist_fixed.tif" + path: "examples/data/deformed_data/platy1_muscles_stardist_aniso_fixed.tif" + output_name: "fixed_image" + x_res: 1 + y_res: 1 + z_res: 1 + +moving_image: + path: "examples/data/deformed_data/platy1_muscles_stardist_aniso_rigid.tif" + output_name: "moving_image" + x_res: 2 + y_res: 1 + z_res: 1 + ref_path: "examples/data/test_data/moving_aniso_rigid_aligned.tif" + ref_url: "https://github.com/kreshuklab/matchmaker/releases/download/test_data-v0.3/moving_aniso_rigid_aligned.tif" + +log_dir: "data/test_rigid_registration" +final_transform_path: "data/test_rigid_registration/final_transform.json" + +prealignment: + axis_orientation: "auto" + +coherent_point_drift: + w: 0.00001 + lmd: 0.1 + beta: 100 + maxiter: 10 # Low for debugging purpose, should be 100-150 + +ILP: + min_neighbours: 10 + max_dist: 30 + +mobie_export: True +semantic_seg: True +mobie_dataset_name: "platy1_muscles_stardist" diff --git a/examples/register_config_test_rigid.yaml b/examples/register_config_test_rigid.yaml index 4129ae2..dbf329a 100644 --- a/examples/register_config_test_rigid.yaml +++ b/examples/register_config_test_rigid.yaml @@ -7,17 +7,20 @@ fixed_image: z_res: 1 moving_image: - path: "examples/data/deformed_data/platy1_muscles_stardist_moving.tif" + path: "examples/data/deformed_data/platy1_muscles_stardist_rigid.tif" output_name: "moving_image" x_res: 1 y_res: 1 z_res: 1 - ref_path: "examples/data/test_data/moving_rigid_aligned.tif" - ref_url: "https://github.com/kreshuklab/matchmaker/releases/download/test_data-v0.2/moving_rigid_aligned.tif" + ref_path: "examples/data/test_data/moving_pointset_aligned.tif" + ref_url: "https://github.com/kreshuklab/matchmaker/releases/download/test_data-v1.0/moving_pointset_aligned.tif" log_dir: "data/test_rigid_registration" final_transform_path: "data/test_rigid_registration/final_transform.json" +prealignment: + axis_orientation: "auto" + coherent_point_drift: w: 0.00001 lmd: 0.1 diff --git a/examples/register_config_test_rigid_apply_transform.yaml b/examples/register_config_test_rigid_apply_transform.yaml new file mode 100644 index 0000000..9f2310e --- /dev/null +++ b/examples/register_config_test_rigid_apply_transform.yaml @@ -0,0 +1,27 @@ +fixed_image: + input_path: "examples/data/deformed_data/platy1_muscles_stardist_fixed_rotated.tif" + input_key: "input" + +moving_images: + - + input_path: "data/test_rigid_registration/moving_image.n5" + input_key: "input" + output_path: "data/test_apply_transform/moving_image.n5" + output_key: "pointset_alignment_transform" + x_res: 1 + y_res: 1 + z_res: 1 + interpolation_order: 0 + - + input_path: "examples/data/deformed_data/platy1_muscles_stardist_rigid.tif" + input_key: "input" + output_path: "data/test_apply_transform/moving_image.tif" + output_key: "pointset_alignment_transform" + x_res: 1 + y_res: 1 + z_res: 1 + interpolation_order: 0 + +log_dir: "data/test_apply_transform" +parameter_map_path: "data/test_rigid_registration/elastix_deformable_pointset_registration/TransformParameters.2.txt" +prealignment_transform_path: "data/test_rigid_registration/svd_prealignment/svd_prealignment_transform.json" diff --git a/matchmaker/apply_transform.py b/matchmaker/apply_transform.py new file mode 100644 index 0000000..0d7790c --- /dev/null +++ b/matchmaker/apply_transform.py @@ -0,0 +1,163 @@ +import itk +import json +import click +import logging +import numpy as np +import tifffile as tiff +from pathlib import Path + +from utils import (setup_logging, read_volume, write_volume, get_attrs, rotate_img, + read_transform_dict, apply_transform_chanwise, plot_three_slices, + plot_overlay,) + + +def load_data(path, key=None): + if path.endswith(".n5"): + assert key + data = read_volume(path, key) + elif path.endswith((".tif", ".tiff")): + data = tiff.imread(path) + if data.ndim == 2: + data = data[None, ...] + else: + raise NotImplementedError + + return data.astype(np.float32) + + +def save_data(data, output_path, output_key=None, **kwargs): + logging.info("Write results") + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + if output_path.endswith((".tif", ".tiff")): + tiff.imwrite(output_path, data) + elif output_path.endswith(".n5"): + assert output_key + write_volume(output_path, data, output_key, **kwargs) + else: + raise NotImplementedError + + +def apply_transform(moving_img, moving_resolution, parameter_object, interpolation_order, + T_fixed=None, output_shape=None): + if T_fixed is not None: + assert output_shape is not None + + logging.info(f"Set interpolation order {interpolation_order}") + parameter_object.SetParameter("FinalBSplineInterpolationOrder", str(interpolation_order)) + + warped = apply_transform_chanwise(parameter_object, moving_img, moving_resolution) + logging.info(f"transformed image shape {warped.shape}") + + warped = np.squeeze(warped) + + if T_fixed is not None: + logging.info("Rotate moving image using prealignment transform") + warp_prealigned = rotate_img(warped, T_fixed, output_shape=output_shape) + else: + warp_prealigned = None + + return warped, warp_prealigned + + +@click.command() +@click.option("-mp", "--moving_paths", required=True, multiple=True, help="Paths to moving inputs") +@click.option("-mk", "--moving_keys", required=True, multiple=True, help="Keys of moving inputs") +@click.option("-mr", "--moving_resolutions", required=True, multiple=True, help="Resolutions of moving inputs") +@click.option("-op", "--output_paths", required=True, multiple=True, help="Paths to save warped images",) +@click.option("-ok", "--output_keys", required=True, multiple=True, help="Keys of moving outputs",) +@click.option("-io", "--interpolation_orders", required=True, multiple=True, help="Orders of interpolation",) +@click.option("-ld", "--log_dir", required=True, help="Log directory") +@click.option("-pm", "--parameter_map_path", required=True, help="Path to the parameter map",) +@click.option("-pt", "--prealignment_transform_path", default=None, help="Prealignment transform path",) +@click.option("-fi", "--fixed_path", default=None, help="Fixed input .n5 file") +@click.option("-fk", "--fixed_key", default=None, help="Fixed input key") +@click.option("-vb", "--verbose", is_flag=True, default=False, help="Show verbose logs") +def apply_transforms(moving_paths, moving_keys, moving_resolutions, output_paths, output_keys, + interpolation_orders, log_dir, parameter_map_path, prealignment_transform_path, + fixed_path, fixed_key, verbose): + assert len(moving_paths) == len(moving_keys) == len(moving_resolutions) == len(output_paths) == len(output_keys) == len(interpolation_orders) + log_dir = Path(log_dir) + log_dir.mkdir(exist_ok=True) + + setup_logging(log_dir, "apply_transform.log") + + logging.info(f"Read parameter file {parameter_map_path}") + parameter_object = itk.ParameterObject.New() + parameter_object.ReadParameterFile(parameter_map_path) + if verbose: + logging.info(parameter_object) + + if prealignment_transform_path: + logging.info("Read prealignment transform") + prealignment_transform = read_transform_dict(prealignment_transform_path)["fixed_prealignment"] + T_fixed, output_shape = prealignment_transform["matrix"], prealignment_transform["output_shape"] + else: + logging.info("Process without prealignment transform") + T_fixed, output_shape = None, None + + if fixed_path: + logging.info("Read fixed image") + fixed_img = load_data(fixed_path, fixed_key) + if T_fixed is not None: + logging.info("Rotate fixed image using prealignment transform") + fixed_prealigned = rotate_img(fixed_img, T_fixed, output_shape=output_shape) + + for i, (moving_path, moving_key) in enumerate(zip(moving_paths, moving_keys)): + print("\n") + logging.info(f"[{i+1}/{len(moving_paths)}] Start processing moving image: {moving_path}") + moving_resolution = json.loads(moving_resolutions[i]) + output_path, output_key = output_paths[i], output_keys[i] + interpolation_order = interpolation_orders[i] + + moving_name = Path(moving_path).stem + logging.info("Read moving image") + moving_img = load_data(moving_path, moving_key) + if moving_path.endswith(".n5"): + if list(moving_resolution) != get_attrs(moving_path, moving_key)["resolution"]: + raise ValueError("Moving resolution from config is different from n5 file") + + if moving_img.ndim == 3: + moving_img = moving_img[None, ...] + chunks = (128, 512, 512) + else: + chunks = (1, 128, 512, 512) + + logging.info("Start transformation") + warped, warp_prealigned = apply_transform(moving_img, + moving_resolution, + parameter_object, + interpolation_order, + T_fixed=T_fixed, + output_shape=output_shape) + + resolution = [float(res) for res in parameter_object.GetParameter(0, "Spacing")] + + save_attrs = {} + if moving_path.endswith(".n5"): + attributes = dict(get_attrs(moving_path, moving_key)) + attributes["resolution"] = resolution + save_attrs["chunks"] = chunks + save_attrs["attrs"] = attributes + + logging.info("Plot warped image") + plot_three_slices(warped, save_path=log_dir / f"{moving_name}_warped.png") + if fixed_path: + logging.info("Plot overlay image") + plot_overlay(fixed_img, warped, log_dir / f"{moving_name}_warped_overlay.png",) + + if T_fixed is not None: + logging.info("Plot warped moving image after pre-alignment") + plot_three_slices(warp_prealigned, save_path=log_dir / f"{moving_name}_warp_prealigned.png") + + if fixed_path: + logging.info("Plot overlay image after pre-alignment") + plot_overlay(fixed_prealigned, warp_prealigned, log_dir / f"{moving_name}_warp_prealigned_overlay.png",) + + save_data(warp_prealigned, output_path, output_key=output_key, **save_attrs) + + else: + save_data(warped, output_path, output_key=output_key, **save_attrs) + + +if __name__ == "__main__": + apply_transforms() diff --git a/matchmaker/prealignment.py b/matchmaker/prealignment.py index 8479563..0d5f5c4 100755 --- a/matchmaker/prealignment.py +++ b/matchmaker/prealignment.py @@ -7,10 +7,11 @@ from matchmaker.data import create_point_cloud from matchmaker.utils import (get_transformation_matrix, rotate_img, read_volume, get_attrs, write_volume, - write_transform_dict, plot_three_slices, plot_overlay, setup_logging, get_axis_orient_matrix) + write_transform_dict, plot_three_slices, plot_overlay, setup_logging, + get_axis_orient_matrix, resample_volume, transform_axes_vis) -def get_SVD_transform(img, save_path=None): +def get_SVD_transform(img, spacing, save_path=None): """Convert image to point cloud by thresholding, then run SVD on resulting point cloud. Args: @@ -22,6 +23,9 @@ def get_SVD_transform(img, save_path=None): """ pos, _ = create_point_cloud(img) + + pos *= spacing + gc = pos.mean(axis=0) pos_c = pos - gc logging.info(f"Point cloud shape {pos.shape}") @@ -36,17 +40,19 @@ def get_SVD_transform(img, save_path=None): logging.info("Vt") logging.info(str(Vt)) - logging.info("Rotate point cloud") - vr = pos_c @ Vt.T + gc /= spacing - plt.figure(figsize=(10, 5)) - plt.subplot(1, 2, 1) - plt.title("Original Vertices") - plt.scatter(pos_c[:, 0], pos_c[:, 1], alpha=0.2) - plt.subplot(1, 2, 2) - plt.title("Rotated Vertices") - plt.scatter(vr[:, 0], vr[:, 1], alpha=0.1) if save_path: + logging.info("Rotate point cloud") + vr = pos_c @ Vt.T + + plt.figure(figsize=(10, 5)) + plt.subplot(1, 2, 1) + plt.title("Original Vertices") + plt.scatter(pos_c[:, 0], pos_c[:, 1], alpha=0.2) + plt.subplot(1, 2, 2) + plt.title("Rotated Vertices") + plt.scatter(vr[:, 0], vr[:, 1], alpha=0.1) plt.savefig(save_path, dpi=300) return gc, Vt @@ -184,8 +190,7 @@ def generate_rotation_overlays(fixed_prealigned, moving_prealigned, output_dir): ) - -def prealign_samples(fixed_img, moving_img): +def prealign_samples(fixed_img, moving_img, fixed_spacing, moving_spacing, new_spacing): """ Pre-align two segmentation volumes using PCA. @@ -196,13 +201,17 @@ def prealign_samples(fixed_img, moving_img): Returns: fixed_rot, moving_rot, T_fixed, T_moving """ - gc_fixed, Vt_fixed = get_SVD_transform(fixed_img) - gc_moving, Vt_moving = get_SVD_transform(moving_img) - - T_fixed, fixed_shape = get_transformation_matrix(fixed_img, gc_fixed, Vt_fixed, - img_ref=moving_img, Vt_ref=Vt_moving) - T_moving, moving_shape = get_transformation_matrix(moving_img, gc_moving, Vt_moving, - img_ref=fixed_img, Vt_ref=Vt_fixed) + gc_fixed, Vt_fixed = get_SVD_transform(fixed_img, fixed_spacing) + gc_moving, Vt_moving = get_SVD_transform(moving_img, moving_spacing) + + T_fixed, fixed_shape = get_transformation_matrix(fixed_img, gc_fixed, Vt_fixed, fixed_spacing, + img_ref=moving_img, Vt_ref=Vt_moving, + spacing_ref=moving_spacing, + spacing_out=new_spacing) + T_moving, moving_shape = get_transformation_matrix(moving_img, gc_moving, Vt_moving, moving_spacing, + img_ref=fixed_img, Vt_ref=Vt_fixed, + spacing_ref=fixed_spacing, + spacing_out=new_spacing) assert np.array_equal(fixed_shape, moving_shape) fixed_rot = rotate_img(fixed_img, T_fixed, output_shape=fixed_shape) @@ -236,6 +245,9 @@ def mirror_img(img, T): def run_prealignment( fixed_img, moving_img, + fixed_spacing, + moving_spacing, + new_spacing, output_dir, axis_orientation ): @@ -294,7 +306,7 @@ def run_prealignment( logging.info("Start prealignment of fixed and moving images ...") - prealigned_results = prealign_samples(fixed_img, moving_img) + prealigned_results = prealign_samples(fixed_img, moving_img, fixed_spacing, moving_spacing, new_spacing) fixed_prealigned, T_fixed, gc_fixed, Vt_fixed, fixed_shape = prealigned_results["fixed"] moving_prealigned, T_moving, gc_moving, Vt_moving, _ = prealigned_results["moving"] @@ -327,10 +339,10 @@ def run_prealignment( fixed_prealigned, moving_prealigned, save_path=f"{output_dir}/plots/overlay_after_prealignment_before_axis_orient.png", - gc1=T_fixed[:3,:3].T @ (gc_fixed - T_fixed[:3, 3]), - Vt1=Vt_fixed @ T_fixed[:3, :3], - gc2=T_moving[:3,:3].T @ (gc_moving - T_moving[:3,3]), - Vt2=Vt_moving @ T_moving[:3, :3], + gc1=(np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3], + Vt1=transform_axes_vis(Vt_fixed, T_fixed), + gc2=(np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3], + Vt2=transform_axes_vis(Vt_moving, T_moving), ) # check orientation (if moving fits to fixed) @@ -390,25 +402,25 @@ def run_prealignment( plot_three_slices( fixed_prealigned, save_path=f"{output_dir}/plots/fixed_prealigned.png", - gc = T_fixed[:3,:3].T @ (gc_fixed - T_fixed[:3, 3]), - Vt = Vt_fixed @ T_fixed[:3, :3], + gc = (np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3], + Vt = transform_axes_vis(Vt_fixed, T_fixed), ) plot_three_slices( moving_prealigned, save_path=f"{output_dir}/plots/moving_prealigned.png", - gc = T_moving[:3,:3].T @ (gc_moving - T_moving[:3,3]), - Vt = Vt_moving @ T_moving[:3, :3], + gc = (np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3], + Vt = transform_axes_vis(Vt_moving, T_moving), ) plot_overlay( fixed_prealigned, moving_prealigned, save_path=f"{output_dir}/plots/overlay_after_prealignment.png", - gc1=T_fixed[:3,:3].T @ (gc_fixed - T_fixed[:3, 3]), - Vt1=Vt_fixed @ T_fixed[:3, :3], - gc2=T_moving[:3,:3].T @ (gc_moving - T_moving[:3,3]), - Vt2=Vt_moving @ T_moving[:3, :3], + gc1=(np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3], + Vt1=transform_axes_vis(Vt_fixed, T_fixed), + gc2=(np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3], + Vt2=transform_axes_vis(Vt_moving, T_moving), ) return fixed_prealigned, moving_prealigned, prealignment_transform @@ -417,14 +429,17 @@ def run_prealignment( @click.command() @click.option("-fi", "--fixed_path", required=True, help="Fixed input .n5 file") @click.option("-fk", "--fixed_key", required=True, help="Fixed input key") +@click.option("-fs", "--fixed_spacing", nargs=3, required=True, help="Fixed input spacing") @click.option("-mi", "--moving_path", required=True, help="Moving input .n5 file") @click.option("-mk", "--moving_key", required=True, help="Moving input key") +@click.option("-ms", "--moving_spacing", nargs=3, required=True, help="Moving input spacing") @click.option("-o", "--output_dir", required=True, help="Output directory") @click.option("-ok", "--output_key", required=True, help="Output key (same in both n5)") @click.option("-trans", "--output_transform_path", required=True, help="Path to write the final transform") @click.option("-axis_orientation", "--axis_orientation", required=True, help="How to find the correct orientation along the principal axes") @click.option("-tif", "--save_tif", is_flag=True, help="Whether to save tif or not") -def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, output_key, output_transform_path, axis_orientation, save_tif=False): +def main(fixed_path, fixed_key, fixed_spacing, moving_path, moving_key, moving_spacing, + output_dir, output_key, output_transform_path, axis_orientation, save_tif=False): """ Perform prealignment of moving image to fixed image. @@ -444,23 +459,33 @@ def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, output_key, """ setup_logging(output_dir, "prealignment.log") + fixed_spacing = np.asarray(fixed_spacing, dtype=np.float32) + moving_spacing = np.asarray(moving_spacing, dtype=np.float32) + new_spacing = np.full_like(fixed_spacing, fixed_spacing.min()) + logging.info("Reading fixed image") fixed_img = read_volume(fixed_path, fixed_key) - logging.info(f"Fixed image shape: {fixed_img.shape}, dtype {fixed_img.dtype}") + logging.info(f"Fixed image shape: {fixed_img.shape}, dtype {fixed_img.dtype}, resolution {fixed_spacing}") logging.info("Reading moving image") moving_img = read_volume(moving_path, moving_key) - logging.info(f"Moving image shape: {moving_img.shape}, dtype {moving_img.dtype}") + logging.info(f"Moving image shape: {moving_img.shape}, dtype {moving_img.dtype}, resolution {moving_spacing}") fixed_prealigned, moving_prealigned, prealignment_transform = run_prealignment( fixed_img, moving_img, + fixed_spacing, + moving_spacing, + new_spacing, output_dir, axis_orientation ) logging.info("Save prealigned fixed image") fixed_attributes = dict(get_attrs(fixed_path, fixed_key)) + if not np.array_equal(fixed_spacing, new_spacing): + assert new_spacing is not None + fixed_attributes["resolution"] = new_spacing.tolist() write_volume( f=fixed_path, arr=fixed_prealigned, @@ -470,6 +495,8 @@ def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, output_key, logging.info("Save prealigned moving image") moving_attributes = dict(get_attrs(moving_path, moving_key)) + if not np.array_equal(moving_spacing, new_spacing): + moving_attributes["resolution"] = new_spacing.tolist() write_volume( f=moving_path, arr=moving_prealigned, diff --git a/matchmaker/utils/transform_utils.py b/matchmaker/utils/transform_utils.py index 7ef4efc..4b2446e 100644 --- a/matchmaker/utils/transform_utils.py +++ b/matchmaker/utils/transform_utils.py @@ -1,9 +1,8 @@ import json import numpy as np import transforms3d as tf3d -from scipy.ndimage import affine_transform +from scipy.ndimage import affine_transform, zoom from elf.wrapper.resized_volume import ResizedVolume -import logging def write_transform_dict(transform_dict, json_path): @@ -80,7 +79,15 @@ def crop_to_bbox(img): return cropped -def get_rotated_shape(img, rotation_matrix): +def resample_volume(img, old_spacing, new_spacing, order=0): + scale = np.asarray(old_spacing) / np.asarray(new_spacing) + return zoom(img, scale, order=order) + + +def get_rotated_shape(img, rotation_matrix, spacing, spacing_out=None): + if spacing_out is None: + spacing_out = spacing + # Step 1: Define the 8 corners of the original volume dz, dy, dx = img.shape @@ -93,11 +100,15 @@ def get_rotated_shape(img, rotation_matrix): [dz, 0, dx], [dz, dy, 0], [dz, dy, dx] - ]) + ], dtype=np.float32) + + corners *= spacing # Step 2: Apply the rotation matrix rotated_corners = corners @ rotation_matrix[0:3, 0:3] # Apply rotation + rotated_corners /= spacing_out + # Step 3: Find min and max of the rotated corners min_coords = rotated_corners.min(axis=0) max_coords = rotated_corners.max(axis=0) @@ -125,17 +136,26 @@ def get_rotation_matrix(R): return M -def get_transformation_matrix(img, gc, Vt, img_ref=None, Vt_ref=None): +def get_transformation_matrix(img, gc, Vt, spacing, img_ref=None, Vt_ref=None, + spacing_ref=None, spacing_out=None): + if spacing_out is None: + spacing_out = spacing + spacing_out_ref = spacing_ref + else: + spacing_out_ref = spacing_out + # 1. center image on origin center_to_origin = get_translation_matrix(gc) # 2. rotate image rot = get_rotation_matrix(Vt.T) + rot_in = rot.copy() + # 3. get new shape - new_shape, min_coords, max_coords = get_rotated_shape(img, rot) + new_shape, min_coords, max_coords = get_rotated_shape(img, rot_in, spacing, spacing_out) if img_ref is not None: assert Vt_ref is not None rot_ref = get_rotation_matrix(Vt_ref.T) - _, min_coords_ref, max_coords_ref = get_rotated_shape(img_ref, rot_ref) + _, min_coords_ref, max_coords_ref = get_rotated_shape(img_ref, rot_ref, spacing_ref, spacing_out_ref) union_min = np.minimum(min_coords, min_coords_ref) union_max = np.maximum(max_coords, max_coords_ref) new_shape = np.ceil(union_max - union_min).astype(int) @@ -143,6 +163,12 @@ def get_transformation_matrix(img, gc, Vt, img_ref=None, Vt_ref=None): new_shape_center = np.array(new_shape) // 2 center_to_new_shape = get_translation_matrix(-new_shape_center) # 5. combine all transforms: get transformation matrix + # Because affine_transform performs inverse transform, here we need to be + # careful about the inverse matrix and the order, e.g. (AB)^-1 = B^-1A^-1 + S_in_inv = np.diag(1./spacing) + S_out = np.diag(spacing_out) + rot[:3, :3] = S_in_inv @ rot[:3, :3] @ S_out + T = center_to_origin @ rot @ center_to_new_shape return T, new_shape diff --git a/matchmaker/utils/vis.py b/matchmaker/utils/vis.py index 6d1710b..68a17b3 100644 --- a/matchmaker/utils/vis.py +++ b/matchmaker/utils/vis.py @@ -321,3 +321,15 @@ def plot_matching_qc( plt.savefig(fig_name, dpi=300) plt.close() + + +def transform_axes_vis(Vt, T): + """ + Transform PCA axes into the target space for visualization. + Extracts the pure rotation from the affine transform T (removing scaling) + and applies it to the PCA axes Vt. The resulting axes are normalized. + """ + U, _, Vt_svd = np.linalg.svd(T[:3, :3]) + R = U @ Vt_svd + Vt_vis = Vt @ R + return Vt_vis / np.linalg.norm(Vt_vis, axis=1, keepdims=True) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 922a473..2fe9fd0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,6 @@ import sys import subprocess +import yaml import shutil import numpy as np import tifffile as tiff @@ -11,7 +12,9 @@ compute_centroid_distances,) -def run_pipline(config_path, snakefile, cores=8, test_dir="tmp_pytest"): +def run_pipline(registration_config_path, registration_snakefile, transform_config_path, + transform_snakefile, cores=8, test_dir="tmp_pytest", enable_aniso=False, + enable_elastic=False): test_dir = Path(test_dir) final_transform_path = test_dir / "final_transform.json" @@ -19,54 +22,70 @@ def run_pipline(config_path, snakefile, cores=8, test_dir="tmp_pytest"): shutil.rmtree(test_dir) test_dir.mkdir(parents=True) - config = load_config(config_path) + registration_config = load_config(registration_config_path) + transform_config = load_config(transform_config_path) # Generate deformed data - deform_test_data(config=config) + deform_test_data(config=registration_config, enable_aniso=enable_aniso, enable_elastic=enable_elastic) - # Run snakemake + # Run registration snakemake workflow result = subprocess.run( [ "snakemake", - "--snakefile", snakefile, - "--configfile", config_path, + "--snakefile", registration_snakefile, + "--configfile", registration_config_path, "--config", f"log_dir={test_dir}", f"final_transform_path={final_transform_path}", "--cores", str(cores), ], check=True, ) - assert result.returncode == 0, result.stderr + + log_dir = transform_config["log_dir"] + for moving_img in transform_config["moving_images"]: + moving_img["input_path"] = moving_img["input_path"].replace("data/test_rigid_registration", str(test_dir)) + moving_img["output_path"] = moving_img["output_path"].replace(log_dir, str(test_dir)) + transform_config["log_dir"] = str(test_dir) + transform_config["final_transform_path"] = str(final_transform_path) + transform_config["parameter_map_path"] = transform_config["parameter_map_path"].replace("data/test_rigid_registration", str(test_dir)) + transform_config["prealignment_transform_path"] = transform_config["prealignment_transform_path"].replace("data/test_rigid_registration", str(test_dir)) + + tmp_config_path = test_dir / "apply_transform.yaml" + with open(tmp_config_path, "w") as f: + yaml.dump(transform_config, f) + + # Run apply_transform snakemake workflow + result = subprocess.run( + [ + "snakemake", + "--snakefile", transform_snakefile, + "--configfile", tmp_config_path, + "--cores", str(cores), + ], + check=True, + ) # Compare results - output_key = config["keys"]["rigid_alignment"] - moving_path = test_dir / f"{config['moving_image']['output_name']}.n5" + moving_path = test_dir / f"{registration_config['moving_image']['output_name']}.n5" + + result_img = read_volume(moving_path, "pointset_alignment") + warped_img = read_volume(moving_path, "pointset_alignment_transform") - test_img = read_volume(moving_path, output_key) + assert_arrays_equal(result_img, warped_img) - ref_path = config["moving_image"]["ref_path"] + ref_path = registration_config["moving_image"]["ref_path"] if not Path(ref_path).exists(): - ref_url = config["moving_image"]["ref_url"] + ref_url = registration_config["moving_image"]["ref_url"] download_file(ref_path, ref_url) ref_img = tiff.imread(ref_path) - if sys.platform.startswith("linux"): - assert_arrays_equal(test_img, ref_img) - else: - # NOTE: - # scipy.ndimage.affine_transform is not bitwise deterministic across platforms. - # In practice, small voxel-level differences may occur between operating systems. - # These differences are geometrically negligible (≤ 1 voxel shift), so we - # validate structural consistency instead of strict array equality. - no_new_id, _ = check_no_new_ids(test_img, ref_img) - assert no_new_id - - centroid_distances = compute_centroid_distances(test_img, ref_img, exclude_id=0) - assert np.max(centroid_distances) < 1 + assert_arrays_equal(result_img, ref_img) + print("✅ compare_results finished.") shutil.rmtree(test_dir) def test_workflow(): - run_pipline("examples/register_config_test_rigid.yaml", "workflows/registration.smk") + run_pipline("examples/register_config_test_rigid.yaml", "workflows/registration.smk", + "examples/register_config_test_rigid_apply_transform.yaml", "workflows/apply_transform.smk") diff --git a/workflows/apply_transform.smk b/workflows/apply_transform.smk new file mode 100644 index 0000000..b974847 --- /dev/null +++ b/workflows/apply_transform.smk @@ -0,0 +1,73 @@ +import json + +configfile: "examples/register_config_test_rigid_apply_transform.yaml" + +moving_images = config["moving_images"] +moving_paths = [item["input_path"] for item in moving_images] +moving_keys = [item["input_key"] for item in moving_images] +moving_resolutions = [[item["z_res"], item["y_res"], item["x_res"]] for item in moving_images] +moving_resolutions = [json.dumps(r, separators=(',', ':')) for r in moving_resolutions] +output_paths = [item["output_path"] for item in moving_images] +output_keys = [item["output_key"] for item in moving_images] +interpolation_orders = [item["interpolation_order"] for item in moving_images] + +log_dir = config["log_dir"] +parameter_map_path = config["parameter_map_path"] +prealignment_transform_path = config["prealignment_transform_path"] + +fixed_path = config["fixed_image"]["input_path"] +fixed_key = config["fixed_image"]["input_key"] + + +def expand_flag(flag, values): + return " ".join(f"{flag} {v}" for v in values) + + +def get_all_opts(d): + opts = [] + for k, v in d.items(): + if v: + opts.extend([f"--{k}", str(v)]) + return opts + + +rule all: + input: + output_paths, + + +rule apply_transform: + input: + parameter_map_path = parameter_map_path, + output: + [directory(f"{path}/{key}") if path.endswith(".n5") else path for path, key in zip(output_paths, output_keys)], + params: + opts = lambda w: get_all_opts({ + "prealignment_transform_path": prealignment_transform_path, + "fixed_path": fixed_path, + "fixed_key": fixed_key, + }), + + moving_paths = expand_flag("--moving_paths", moving_paths), + moving_keys = expand_flag("--moving_keys", moving_keys), + moving_resolutions = expand_flag("--moving_resolutions", moving_resolutions), + output_paths = expand_flag("--output_paths", output_paths), + output_keys = expand_flag("--output_keys", output_keys), + interpolation_orders = expand_flag("--interpolation_orders", interpolation_orders), + log_dir = log_dir, + + log: f"{log_dir}/matchmaker.log" + conda: "matchmaker_env" + shell: + """ + python matchmaker/apply_transform.py \ + {params.opts} \ + {params.moving_paths} \ + {params.moving_keys} \ + {params.moving_resolutions} \ + {params.output_paths} \ + {params.output_keys} \ + {params.interpolation_orders} \ + --log_dir {params.log_dir} \ + --parameter_map_path {input.parameter_map_path} + """ diff --git a/workflows/registration.smk b/workflows/registration.smk index 3387a32..36ad138 100644 --- a/workflows/registration.smk +++ b/workflows/registration.smk @@ -17,6 +17,8 @@ moving_input_key = config["moving_image"]["input_key"] if "input_key" in config[ moving_n5_path = f"{config['log_dir']}/{moving_name}.n5" log_dir = config["log_dir"] final_transform = config["final_transform_path"] +fixed_spacing = [config["fixed_image"]["z_res"], config["fixed_image"]["y_res"], config["fixed_image"]["x_res"]] +moving_spacing = [config["moving_image"]["z_res"], config["moving_image"]["y_res"], config["moving_image"]["x_res"]] print(fixed_name, fixed_n5_path, moving_name, moving_n5_path) @@ -116,7 +118,7 @@ rule SVD_prealignment: log: f"{log_dir}/matchmaker.log" conda: "matchmaker_env" shell: - f"python matchmaker/prealignment.py --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.input_key}} --moving_path {{input.moving_image_n5}} --moving_key {{params.input_key}} --output_dir {log_dir}/{{params.output_key}} --output_key {{params.output_key}} --output_transform_path {{output.output_transform}} --axis_orientation {{params.axis_orientation}};" + f"python matchmaker/prealignment.py --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.input_key}} --fixed_spacing {{fixed_spacing}} --moving_path {{input.moving_image_n5}} --moving_key {{params.input_key}} --moving_spacing {{moving_spacing}} --output_dir {log_dir}/{{params.output_key}} --output_key {{params.output_key}} --output_transform_path {{output.output_transform}} --axis_orientation {{params.axis_orientation}};" """ @@ -291,4 +293,4 @@ rule add_elastix_deformable_pointset_to_mobie: f"python matchmaker/mobie_export.py --input_path {{input.moving_image_n5}} --input_key {{params.input_key}} --input_type {moving_type} {'--semantic_seg' if semantic_seg else ''} --dataset_name {dataset_name} --output_dir {log_dir};" f"touch {{output.moving_check}};" f"rm -rf ./tmp_{dataset_name}_{moving_name}_{{params.input_key}};" - f"rm -rf ./tmp_{dataset_name}_{moving_name}_{{params.input_key}}_binary;" \ No newline at end of file + f"rm -rf ./tmp_{dataset_name}_{moving_name}_{{params.input_key}}_binary;"