In [None]:
import logging
import math
import os
from pathlib import Path

import numpy as np
import SimpleITK as sitk
#from SimpleITK.utilities import fft_based_translation_initialization

from pytools import HedwigZarrImages
from sitkibex.registration import RegistrationCallbackManager
# Configure the external viewer used by sitk.Show() later in this notebook.
# NOTE(review): this throwaway ImageViewer instance presumably forces SimpleITK to initialize the viewer's
# global defaults (e.g. from SITK_SHOW_* environment variables) before the override below, so the override
# is not reset by a later lazy initialization -- confirm before removing this line.
sitk.ImageViewer()
# Use .nrrd as the file extension for images handed to the viewer.
sitk.ImageViewer.SetGlobalDefaultFileExtension(".nrrd")

In [None]:
# Root directory of the input/output data. Set the HEDWIG_TEST_DATA_DIR environment variable to point at a
# different location instead of editing this machine-specific default.
file_dir = Path(
    os.environ.get(
        "HEDWIG_TEST_DATA_DIR",
        "/Users/blowekamp/scratch/hedwig/TestData/Nanostringfiles/ROI Alignment Images for Brad/",
    )
)

In [None]:

def fft_based_translation_initialization(
    fixed: sitk.Image,
    moving: sitk.Image,
    *,
    required_fraction_of_overlapping_pixels: float = 0.5,
    initial_transform: sitk.Transform = None,
    maked_pixel_value: float = 0.0,
) -> sitk.Transform:
    """Perform fast Fourier transform based normalized correlation to find the translation which maximizes correlation
    between the images.

    If an initial transform is given, or the moving image grid is not congruent with the fixed image (same size,
    origin, spacing and direction), the moving image is resampled onto the grid defined by the fixed image.

    Efficiency can be improved by reducing the resolution of the image or using a projection filter to reduce the
    dimensionality of the inputs.

    :param fixed: A scalar SimpleITK image.
    :param moving: Another scalar SimpleITK image, which will be resampled onto the grid of the fixed image if it is
        not congruent.
    :param required_fraction_of_overlapping_pixels: The required fraction of overlapping pixels between the fixed and
        moving image.
    :param initial_transform: An initial transformation to be applied to the moving image by resampling before the
        FFT registration. The returned transform is a copy of initial_transform (same type) with its translation
        updated, so its type must support GetTranslation/SetTranslation.
    :param maked_pixel_value: (sic -- name kept for compatibility) The value of input pixels to be ignored by the
        correlation. If None, FFTNormalizedCorrelation is used; otherwise MaskedFFTNormalizedCorrelation is used with
        masks of the input pixels that differ from this value. The default of 0.0 therefore selects the masked form.
    :raises RuntimeError: If no peak can be located in the correlation map.
    :return: A TranslationTransform (or a copy of initial_transform) mapping physical points from the fixed to the
        moving image.
    """
    dimension = fixed.GetDimension()
    pixel_type = sitk.sitkFloat32

    # The peak-to-translation conversion below assumes both images share one grid, including the size, so resample
    # whenever anything about the grids differs (or an initial transform must be applied).
    if (
        initial_transform is not None
        or moving.GetSize() != fixed.GetSize()
        or moving.GetSpacing() != fixed.GetSpacing()
        or moving.GetDirection() != fixed.GetDirection()
        or moving.GetOrigin() != fixed.GetOrigin()
    ):
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(fixed)
        if initial_transform is not None:
            resampler.SetTransform(initial_transform)
        moving = resampler.Execute(moving)

    # Smooth both inputs by about one fixed-image pixel before correlating.
    sigma = fixed.GetSpacing()[0]
    fft_fixed = sitk.Cast(sitk.SmoothingRecursiveGaussian(fixed, sigma), pixel_type)
    fft_moving = sitk.Cast(sitk.SmoothingRecursiveGaussian(moving, sigma), pixel_type)

    if maked_pixel_value is None:
        xcorr = sitk.FFTNormalizedCorrelation(
            fft_fixed,
            fft_moving,
            requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels,
        )
    else:
        # Build the masks from the unsmoothed inputs: smoothing spreads small non-zero values into neighbouring
        # background pixels, which would then no longer compare equal to the masked value.
        fixed_mask = sitk.Cast(fixed != maked_pixel_value, pixel_type)
        moving_mask = sitk.Cast(moving != maked_pixel_value, pixel_type)
        xcorr = sitk.MaskedFFTNormalizedCorrelation(
            fft_fixed,
            fft_moving,
            fixed_mask,
            moving_mask,
            requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels,
        )

    # Smooth the correlation map with the same physical sigma as the inputs (the map is expected to share the fixed
    # image's spacing). The filter's default sigma of 1.0 physical unit is negligible on coarse, bin-shrunk grids.
    xcorr = sitk.SmoothingRecursiveGaussian(xcorr, sigma)

    # Pick the regional maximum with the highest mean correlation value.
    peaks = sitk.ConnectedComponent(sitk.RegionalMaxima(xcorr, fullyConnected=True))
    stats = sitk.LabelStatisticsImageFilter()
    stats.Execute(xcorr, peaks)
    labels = sorted(stats.GetLabels(), key=stats.GetMean)
    if not labels:
        raise RuntimeError("FFT correlation produced no peak to initialize the translation from.")

    peak_bb = stats.GetBoundingBox(labels[-1])
    # The bounding box is [min_0, max_0, min_1, max_1, ...]; use its midpoint. The +0.5 half-pixel offset is
    # consistent with the size / 2.0 continuous index used for center_pt below.
    peak_idx = [(lo + hi) / 2.0 + 0.5 for lo, hi in zip(peak_bb[0::2], peak_bb[1::2])]
    peak_pt = xcorr.TransformContinuousIndexToPhysicalPoint(peak_idx)
    center_pt = xcorr.TransformContinuousIndexToPhysicalPoint([size / 2.0 for size in xcorr.GetSize()])
    translation = [c - p for c, p in zip(center_pt, peak_pt)]

    if initial_transform is not None:
        logging.getLogger(__name__).debug("Using initial transform")
        # Map the translation found in the resampled frame through initial_transform and add it to the initial
        # translation of a copy of that transform.
        offset = initial_transform.TransformVector(translation, point=[0.0] * dimension)
        tx_out = sitk.Transform(initial_transform).Downcast()
        tx_out.SetTranslation([a + b for a, b in zip(tx_out.GetTranslation(), offset)])
        return tx_out

    return sitk.TranslationTransform(dimension, translation)


In [None]:

_logger = logging.getLogger(__name__)


def _configure_metric_gradient_filters(method, use_gradient_filter):
    """Turn the metric's fixed and moving image gradient filters on or off together."""
    if use_gradient_filter:
        method.MetricUseFixedImageGradientFilterOn()
        method.MetricUseMovingImageGradientFilterOn()
    else:
        method.MetricUseMovingImageGradientFilterOff()
        method.MetricUseFixedImageGradientFilterOff()


def _apply_metric_masks(method, fixed_image_mask, moving_image_mask):
    """Attach the optional fixed/moving masks to the similarity metric of ``method``."""
    if fixed_image_mask:
        _logger.info("Setting fixed mask")
        method.SetMetricFixedMask(fixed_image_mask)
    if moving_image_mask:
        _logger.info("Setting moving mask")
        method.SetMetricMovingMask(moving_image_mask)


def register_2d_affine(
    fixed_image,
    moving_image,
    sigma_base=1.0,
    initial_transform=None,
    fixed_image_mask=None,
    moving_image_mask=None,
    number_of_samples_per_parameter=5000,
    *,
    do_affine=True,
    verbose=True,
):
    """Register ``moving_image`` to ``fixed_image`` with a rigid stage followed by an optional affine stage.

    Stage 1 optimizes ``initial_transform`` (or, if None, a geometry-centered Euler2DTransform) by gradient descent
    on a correlation metric over three smoothing levels at full resolution. The result is promoted to an
    AffineTransform, which stage 2 refines over a 4/2/1 shrink-factor pyramid when ``do_affine`` is True.

    :param fixed_image: 2D SimpleITK image defining the reference space.
    :param moving_image: 2D SimpleITK image to align to ``fixed_image``.
    :param sigma_base: Multiplier for the per-level smoothing sigmas, which are expressed in multiples of the fixed
        image's first spacing component.
    :param initial_transform: Transform to start stage 1 from. Its optimized result must provide
        GetMatrix/GetTranslation/GetCenter (e.g. Euler2D or Similarity2D). If None, a centered Euler2DTransform is used.
    :param fixed_image_mask: Optional mask restricting metric sampling in the fixed image (used by both stages).
    :param moving_image_mask: Optional mask restricting metric sampling in the moving image (used by both stages).
    :param number_of_samples_per_parameter: Target number of metric samples per transform parameter; the derived
        sampling percentage is capped at 10%.
    :param do_affine: If False, skip the affine optimization and return the promoted stage-1 result.
    :param verbose: Forwarded to the registration progress callbacks.
    :return: An ``sitk.AffineTransform`` mapping points from the fixed to the moving image space.
    """
    use_gradient_filter = False
    rigid_scale_factors = [8, 4, 2]
    affine_scale_factors = [4, 2, 1]
    fixed_spacing = fixed_image.GetSpacing()[0]

    _logger.info("Initializing projected registration...")
    _logger.info("Sigma Base: {0}".format(sigma_base))
    _logger.info("Initial Transform: {0}".format(initial_transform))

    # Unless the caller supplied a starting transform, align the geometric centers with an Euler transform.
    if initial_transform is None:
        initial_rigid = sitk.CenteredTransformInitializer(
            fixed_image, moving_image, sitk.Euler2DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY
        )
    else:
        initial_rigid = initial_transform

    #
    # Stage 1: multi-scale (smoothing only) rigid registration
    #
    R = sitk.ImageRegistrationMethod()
    R.SetMetricAsCorrelation()

    # Due to the sparse sampling the gradient on the whole image
    # is not needed, and it is more time efficient not to compute.
    _configure_metric_gradient_filters(R, use_gradient_filter)

    R.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=500,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
        maximumStepSizeInPhysicalUnits=2.0,
    )
    R.SetOptimizerScalesFromIndexShift()

    # We don't need more samples for larger images, so base the number of samples on the number of parameters.
    rigid_sampling_percentage = (
        len(initial_rigid.GetParameters()) * number_of_samples_per_parameter / fixed_image.GetNumberOfPixels()
    )
    R.SetMetricSamplingPercentagePerLevel(
        [min(0.10, rigid_sampling_percentage) for _ in rigid_scale_factors], seed=1
    )
    R.SetMetricSamplingStrategy(R.REGULAR)
    # Every level runs at full resolution; only the smoothing changes between levels.
    R.SetShrinkFactorsPerLevel([1 for _ in rigid_scale_factors])
    R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    R.SetSmoothingSigmasPerLevel([4.0 * sigma_base * f * fixed_spacing for f in rigid_scale_factors])

    R.SetInitialTransform(initial_rigid)
    R.SetInterpolator(sitk.sitkLinear)
    _apply_metric_masks(R, fixed_image_mask, moving_image_mask)

    R_callbacks = RegistrationCallbackManager(R)
    R_callbacks.add_command_callbacks(print_position=True, verbose=verbose)

    rigid_result_2d = R.Execute(fixed_image, moving_image)

    # Promote the stage-1 result to a 2D affine transform with the same matrix, translation and center.
    affine = sitk.AffineTransform(2)
    affine.SetMatrix(rigid_result_2d.GetMatrix())
    affine.SetTranslation(rigid_result_2d.GetTranslation())
    affine.SetCenter(rigid_result_2d.GetCenter())

    #
    # Stage 2: multi-scale affine registration
    #
    R2 = sitk.ImageRegistrationMethod()
    R2.SetMetricAsCorrelation()
    _configure_metric_gradient_filters(R2, use_gradient_filter)
    R2.SetOptimizerAsGradientDescentLineSearch(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
        lineSearchLowerLimit=0,
        lineSearchUpperLimit=2.0,
        lineSearchMaximumIterations=5,
        maximumStepSizeInPhysicalUnits=2,
    )
    R2.SetOptimizerScalesFromIndexShift()

    affine_sampling_percentage = (
        len(affine.GetParameters()) * number_of_samples_per_parameter / fixed_image.GetNumberOfPixels()
    )
    R2.SetMetricSamplingPercentagePerLevel(
        [min(0.10, affine_sampling_percentage * f) for f in affine_scale_factors], seed=1
    )
    R2.SetMetricSamplingStrategy(R2.RANDOM)
    R2.SetShrinkFactorsPerLevel(list(affine_scale_factors))
    R2.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    R2.SetSmoothingSigmasPerLevel([sigma_base * f * fixed_spacing for f in affine_scale_factors])

    R2.SetInitialTransform(affine)
    R2.SetInterpolator(sitk.sitkLinear)
    # The caller's masks apply to this stage too, so masked-out regions stay excluded during the affine refinement.
    _apply_metric_masks(R2, fixed_image_mask, moving_image_mask)

    R2_callbacks = RegistrationCallbackManager(R2)
    R2_callbacks.add_command_callbacks(print_position=True, verbose=verbose)

    if do_affine:
        affine_result = R2.Execute(fixed_image, moving_image)
    else:
        affine_result = affine

    # Do explicit casting
    return sitk.AffineTransform(affine_result)
# Emit DEBUG-and-above records: a root handler via basicConfig, plus the notebook logger's own threshold.
logging.basicConfig(level="DEBUG")
_logger.setLevel("DEBUG")

In [None]:
def equalize_intensity(img:sitk.Image, *, number_of_histogram_levels=256, number_of_match_points=12, threshold_at_mean_intensity=True):
    """Histogram-equalize ``img`` by matching its histogram to a uniform intensity ramp.

    The reference image is a ramp of the integers ``0 .. number_of_histogram_levels - 1``, one of each, so the
    output intensities are spread approximately uniformly over that range.

    :param img: Scalar SimpleITK image of any dimension.
    :param number_of_histogram_levels: Number of histogram bins, and the number of values in the reference ramp.
    :param number_of_match_points: Number of quantile points matched between the histograms.
    :param threshold_at_mean_intensity: Forwarded to HistogramMatchingImageFilter.SetThresholdAtMeanIntensity.
    :return: The histogram-matched image.
    """
    histogram_equalization = sitk.HistogramMatchingImageFilter()
    histogram_equalization.SetNumberOfHistogramLevels(number_of_histogram_levels)
    histogram_equalization.SetNumberOfMatchPoints(number_of_match_points)
    histogram_equalization.SetThresholdAtMeanIntensity(threshold_at_mean_intensity)

    # The reference must have the same dimension and pixel type as img, so lay the 1-D ramp out along the last
    # numpy axis of an array with img.GetDimension() axes (all other axes of length 1).
    ramp_shape = (number_of_histogram_levels,) + (1,) * (img.GetDimension() - 1)
    ramp_np = np.arange(number_of_histogram_levels).reshape(ramp_shape)
    ramp = sitk.Cast(sitk.GetImageFromArray(ramp_np), img.GetPixelID())

    return histogram_equalization.Execute(img, ramp)

def prep_image_for_registration(img: sitk.Image) -> sitk.Image:
    """Return a scalar float32 version of ``img`` for use as a registration input.

    Multi-component images (e.g. RGB) are reduced to their per-pixel vector magnitude after casting the
    components to float32; single-component images are simply cast to float32.
    """
    if img.GetNumberOfComponentsPerPixel() <= 1:
        return sitk.Cast(img, sitk.sitkFloat32)
    return sitk.VectorMagnitude(sitk.Cast(img, sitk.sitkVectorFloat32))


In [None]:
# Select the slide / ROI pair to register and where to write the resulting transform.
S_number = 4
S_fixed_number = 1  # alternative: S_number
moving_index = 1
fixed_image_path = file_dir / f"IA_P2_S{S_fixed_number}.ome.zarr"
moving_image_path = file_dir / f"IA_P2_S{S_number}.zarr"
output_transform_path = f"IA_P2_S{S_number}_{moving_index}_to_roi.txt"

# Target size passed to extract_2d for both images (presumably pixels per side -- confirm against pytools).
resolution = 16384

# Open each zarr image once and reuse it for the shape report and the 2D extraction.
fixed_zarr_image = HedwigZarrImages(fixed_image_path)[0]
moving_zarr_image = HedwigZarrImages(moving_image_path)[moving_index]
print(f"original fixed size: {fixed_zarr_image.shape}")
print(f"original moving size: {moving_zarr_image.shape}")
fixed = fixed_zarr_image.extract_2d(resolution, resolution)
roi = moving_zarr_image.extract_2d(resolution, resolution)

preped_fixed = prep_image_for_registration(fixed)
preped_roi = prep_image_for_registration(roi)

# Start from a geometry-centered similarity transform rotated by pi/2 radians (90 degrees).
# NOTE(review): presumably the ROI images are acquired rotated by 90 degrees relative to the slide -- confirm.
initial_tx = sitk.CenteredTransformInitializer(
    preped_fixed, preped_roi, sitk.Similarity2DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY
)
initial_tx = initial_tx.Downcast()
initial_tx.SetAngle(np.pi / 2.0)

# Bin a resolution-sized image down to about 512 pixels for the FFT initialization; BinShrink needs a factor >= 1.
bin_factor = max(1, math.floor(resolution / 512))
fft_tx = fft_based_translation_initialization(
    sitk.BinShrink(preped_fixed, [bin_factor, bin_factor]),
    sitk.BinShrink(preped_roi, [bin_factor, bin_factor]),
    initial_transform=initial_tx,
    required_fraction_of_overlapping_pixels=0.2,
)
print(fft_tx)

# Refine at full resolution, restricting the metric to non-zero (non-background) pixels.
result_tx = register_2d_affine(
    preped_fixed,
    preped_roi,
    initial_transform=fft_tx,
    fixed_image_mask=(preped_fixed > 0.0),
    moving_image_mask=(preped_roi > 0.0),
    number_of_samples_per_parameter=10000,
    do_affine=True,
)
print(result_tx)

# Pass a plain string path to WriteTransform.
sitk.WriteTransform(result_tx, str(file_dir / output_transform_path))

In [None]:
# Visual QA: resample the registered ROI onto the fixed grid and show both as a two-component overlay.
# The prepared images are already scalar float32. This works for scalar inputs as well as RGB inputs,
# and it guarantees both components share the pixel type that sitk.Compose requires.
reg_img = sitk.Resample(preped_roi, referenceImage=preped_fixed, transform=result_tx)
sitk.Show(sitk.Compose(reg_img, preped_fixed))