In [None]:
# =============================
# Standard library
# =============================
import os
import sys
from collections import defaultdict
from pathlib import Path

# =============================
# Third-party
# =============================
import cv2
import ants
import itk
import flow_vis
import neurite as ne
import numpy as np
import matplotlib.pyplot as plt
import pystrum.pynd.ndutils as nd
from tqdm import tqdm

from itk import (
    ParameterObject,
    elastix_registration_method,
    transformix_filter,
)

In [None]:
# =============================
# Project path setup
# =============================
# NOTE(review): assumes the notebook runs from a directory one level below the
# project source root (the root that contains the `models` package imported
# next) — confirm if the notebook is moved.
PROJECT_SRC = Path.cwd().parent

# Guard so re-running this cell does not append duplicate sys.path entries.
if str(PROJECT_SRC) not in sys.path:
    sys.path.append(str(PROJECT_SRC))

from models.model_reg_gn import SpatialTransformer

In [None]:
# =============================
# Utility
# =============================
class suppress_stdout:
    """Context manager that discards everything written via ``sys.stdout``.

    Only Python-level writes through ``sys.stdout`` are silenced; output a C
    extension writes straight to file descriptor 1 is not affected.
    ``__enter__`` returns the manager itself, so ``with suppress_stdout() as s:``
    also works.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the stream we opened,
        # even if code inside the block rebinds sys.stdout.
        self._devnull = open(os.devnull, "w")
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so the original stream is back even if close() raises.
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Falsy return: exceptions raised inside the block still propagate.
        return None


## Random inputs

In [None]:
# Two independent uniform-noise volumes in [0, 1), drawn from the seeded
# legacy global RNG so every re-run produces the same pair.
np.random.seed(0)

H, W, D = 64, 64, 32

# The generator is consumed left to right: im1_in is drawn first, then im2_in,
# which matters because both come from the same seeded stream.
im1_in, im2_in = (np.random.rand(H, W, D).astype(np.float32) for _ in range(2))

## LDDMM

In [None]:
# =============================
# LDDMM registration
# =============================
# NOTE(review): consider moving this import into the top imports cell.
import torch_lddmm

# Solver settings passed to torch_lddmm.LDDMM, kept apart from the image
# inputs so they are easy to tune. `outdir` is relative to the current working
# directory (presumably where torch_lddmm writes its files — confirm).
LDDMM_KWARGS = dict(
    outdir="../",
    do_affine=1,
    do_lddmm=1,
    a=1,
    p=2,
    niter=500,
    epsilon=0.3,
    epsilonL=1e-7,
    epsilonT=2e-5,
    sigma=4.0,
    sigmaR=10.0,
    nt=3,
    im_norm_ms=1,
    energy_fraction=0.02,
    update_epsilon=1,
    optimizer="gdr",
)

# NOTE(review): here im1_in is the template and im2_in the target, while the
# ANTs and Elastix cells use im1_in as the *fixed* image. If "template" means
# the image being deformed, the resulting flows point in opposite directions —
# confirm before comparing them.
with suppress_stdout():  # silences the solver's Python-level console output
    lddmm = torch_lddmm.LDDMM(template=im1_in, target=im2_in, **LDDMM_KWARGS)
    lddmm.run()

In [None]:
# =============================
# LDDMM outputs
# =============================
# NOTE(review): assumes computeThisDisplacement() yields exactly three
# per-axis displacement arrays that numpy can stack (e.g. CPU arrays/tensors,
# not CUDA tensors) — confirm.
displacement = lddmm.computeThisDisplacement()
phi0, phi1, phi2 = displacement
deformed_template = lddmm.outputDeformedTemplate()

# Components along a new trailing axis; expected (H, W, D, 3) if each phi has
# the input volume's shape.
flo_lddmm = np.stack([phi0, phi1, phi2], axis=-1)
print("LDDMM flow shape:", flo_lddmm.shape)


## ANTs

In [None]:
# =============================
# ANTs SyN registration
# =============================
# ANTs warps `moving` onto `fixed`. im1_in is the fixed image and im2_in the
# moving one, the same roles the Elastix cell uses.
fixed = ants.from_numpy(im1_in)
moving = ants.from_numpy(im2_in)

# Seed for ANTs' random metric sampling, so re-runs are reproducible. Per the
# ANTsPy docs, bit-exact results also need ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS=1.
ANTS_SEED = 0

reg = ants.registration(
    fixed=fixed,
    moving=moving,
    type_of_transform="SyN",
    syn_metric="mattes",
    reg_iterations=(250, 250, 250, 250, 250),  # one entry per resolution level
    flow_sigma=3,
    random_seed=ANTS_SEED,
)

# For SyN, fwdtransforms lists the warp (displacement) field first, followed by
# the affine transform (per ANTsPy docs).
flo_syn = ants.image_read(reg["fwdtransforms"][0]).numpy()
print("SyN flow shape:", flo_syn.shape)


## Elastix

In [None]:
# =============================
# Elastix parameter file
# =============================
param_file = "multi_contrast_elastix_params.txt"

# B-spline grid spacing at each resolution =
# FinalGridSpacingInPhysicalUnits * GridSpacingSchedule entry, i.e.
# 4 * (8, 4, 2, 1) = 32, 16, 8, 4 from coarse to fine.
# FinalGridSpacingInPhysicalUnits takes one value per image dimension (or a
# single value for all), NOT one value per resolution.
# NewSamplesEveryIteration re-draws the random coordinates each iteration,
# which the stochastic optimizer relies on.
param_text = """
(FixedInternalImagePixelType "float")
(MovingInternalImagePixelType "float")
(FixedImageDimension 3)
(MovingImageDimension 3)
(UseDirectionCosines "true")

(Registration "MultiResolutionRegistration")
(Interpolator "BSplineInterpolator")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")

(Optimizer "AdaptiveStochasticGradientDescent")
(Transform "BSplineTransform")
(Metric "NormalizedMutualInformation")

(FinalGridSpacingInPhysicalUnits 4)
(GridSpacingSchedule 8 4 2 1)
(NumberOfResolutions 4)
(MaximumNumberOfIterations 1000)

(ImageSampler "RandomCoordinate")
(NumberOfSpatialSamples 4096)
(NewSamplesEveryIteration "true")

(DefaultPixelValue 0)
(WriteResultImage "true")
"""

with open(param_file, "w") as f:
    f.write(param_text)

print(f"Saved Elastix parameters to {param_file}")


In [None]:
# =============================
# Elastix registration
# =============================
# itk.GetImageFromArray maps numpy axes (a0, a1, a2) onto ITK axes (z, y, x),
# with unit spacing and identity direction by default.
fixed_image = itk.GetImageFromArray(im1_in.astype(np.float32))
moving_image = itk.GetImageFromArray(im2_in.astype(np.float32))

# ParameterObject is imported directly at the top; use it like the other
# directly-imported Elastix helpers.
parameter_object = ParameterObject.New()
parameter_object.AddParameterFile(param_file)

result_image, result_transform_parameters = elastix_registration_method(
    fixed_image,
    moving_image,
    parameter_object=parameter_object,
    log_to_console=True,
)

deformation_field_itk = itk.transformix_deformation_field(
    fixed_image,
    result_transform_parameters,
    log_to_console=True,
)

# Array axes follow the numpy input (a0, a1, a2), but the trailing vector
# components are in ITK (x, y, z) order: component 0 is the displacement
# along numpy axis a2.
deformation_field = itk.GetArrayFromImage(deformation_field_itk)

# Reverse the components so component k is the displacement along numpy axis k.
# The result is made contiguous because [..., ::-1] is a negative-stride view.
# NOTE(review): confirm flo_syn / flo_lddmm use the same component convention
# before comparing them.
flo_elastix = np.ascontiguousarray(deformation_field[..., ::-1])
print("Elastix flow shape:", flo_elastix.shape)