In [2]:
import sys
import os
from pathlib import Path

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

GNDTRUTH = Path('../data/raw/COMMON_images_masks/')
RAWIMGS = Path('../data/raw/GROUP_images/')
DATA = (RAWIMGS.parent).parent
MASKS = DATA/'masks'

import SimpleITK as sitk

In [3]:
from ipywidgets import interact, IntSlider, fixed
import matplotlib.pyplot as plt
import numpy as np

def show_coronal_slice(arr, slc, mask=None):
    plt.figure(figsize=(5,5))
    plt.imshow(arr[:, slc, :], cmap='gray')
    
    if mask is not None:
        plt.imshow(
            np.ma.masked_where(mask[:, slc, :] == 0, mask[:, slc, :]),
            cmap='hsv',
            vmin=0,
            vmax=5,
            alpha=0.4
        )
    
    plt.axis('off')
    plt.title(f'Slice {slc}')
    plt.show()

def show_interactive(arr, fn, mask=None):
    return interact(
        fn,
        slc=IntSlider(min=0, max=arr.shape[1]-1, step=1, value=arr.shape[1]//2),
        arr=fixed(arr),
        mask=fixed(mask)
    )

def show_coronal_overlay(fixed_img, moving_img, slc):
    plt.figure(figsize=(5,5))
    
    plt.imshow(fixed_img[:, slc, :], cmap='Blues')
    plt.imshow(moving_img[:, slc, :], cmap='Reds', alpha=0.3)
    
    plt.axis('off')
    plt.title(f'Coronal slice {slc}')
    plt.show()

def show_interactive_overlay(fixed_img, moving_img):
    return interact(
        show_coronal_overlay,
        slc=IntSlider(
            min=0,
            max=fixed_img.shape[1]-1,
            step=1,
            value=fixed_img.shape[1]//2
        ),
        fixed_img=fixed(fixed_img),
        moving_img=fixed(moving_img)
    )

In [4]:
# downsample image function
def downsample_image(image, scale_factor=2):
    """Downsample a SimpleITK image."""
    if isinstance(scale_factor, int):
        scale_factor = [scale_factor] * image.GetDimension()
    
    original_spacing = image.GetSpacing()
    original_size = image.GetSize()
    
    new_spacing = [spacing * factor for spacing, factor in zip(original_spacing, scale_factor)]
    new_size = [int(size / factor) for size, factor in zip(original_size, scale_factor)]
    
    resampler = sitk.ResampleImageFilter()
    resampler.SetSize(new_size)
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(sitk.sitkLinear)
    
    return resampler.Execute(image)

In [5]:
#downsample mask function
def downsample_mask(mask, scale_factor=2):
    """Downsample a SimpleITK mask image by a given scale factor."""
    if isinstance(scale_factor, int):
        scale_factor = [scale_factor] * mask.GetDimension()
    
    original_spacing = mask.GetSpacing()
    original_size = mask.GetSize()
    
    new_spacing = [spacing * factor for spacing, factor in zip(original_spacing, scale_factor)]
    new_size = [int(size / factor) for size, factor in zip(original_size, scale_factor)]
    
    resampler = sitk.ResampleImageFilter()
    resampler.SetSize(new_size)
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetOutputOrigin(mask.GetOrigin())
    resampler.SetOutputDirection(mask.GetDirection())
    resampler.SetTransform(sitk.Transform())
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    
    return resampler.Execute(mask)

In [6]:
DOWNSAMPLE_FACTOR = 1  # modify based on your memory capacity

print(f"\nDownsample factor: {DOWNSAMPLE_FACTOR}")
print("(Set to 1 for full resolution, 2 for half, 3 for 1/3, etc.)\n")

# load images
fix_im_full = sitk.ReadImage(GNDTRUTH/'common_40_image.nii.gz')
fix_msk_full = sitk.ReadImage(MASKS/'common_40_regmask_2.nii.gz')
mov_im_full = sitk.ReadImage(RAWIMGS/'g1_54_image.nii.gz')
mov_msk_full = sitk.ReadImage(MASKS/'g1_54_mask.nii.gz')

# downsample images and masks
print("Downsampling fixed image...")
fix_im = downsample_image(fix_im_full, DOWNSAMPLE_FACTOR)
fix_msk = downsample_mask(fix_msk_full, DOWNSAMPLE_FACTOR)

print("\nDownsampling moving image...")
mov_im = downsample_image(mov_im_full, DOWNSAMPLE_FACTOR)
mov_msk = downsample_mask(mov_msk_full, DOWNSAMPLE_FACTOR)

# Get arrays for visualization
fix_im_data = sitk.GetArrayFromImage(fix_im)
fix_msk_data = sitk.GetArrayFromImage(fix_msk)
mov_im_data = sitk.GetArrayFromImage(mov_im)
mov_msk_data = sitk.GetArrayFromImage(mov_msk)

print(f"\nMask labels in fixed: {np.unique(fix_msk_data)}")
print(f"Mask labels in moving: {np.unique(mov_msk_data)}")



Downsample factor: 1
(Set to 1 for full resolution, 2 for half, 3 for 1/3, etc.)

Downsampling fixed image...

Downsampling moving image...

Mask labels in fixed: [0 1 2]
Mask labels in moving: [0 1 2]


In [8]:
def est_lin_transf_with_mask(
    im_ref, im_mov, mask_ref, mask_mov=None,
    use_distance_map=True,
    sampling=0.3,
    bins=50
):
    """
    Two-stage registration using masks:
      1) Rigid (VersorRigid3DTransform)
      2) Affine
    Metric uses either:
      - Distance map of masks (recommended) OR
      - Mattes MI on intensity with fixed/moving masks
    """

    fixed = sitk.Cast(im_ref, sitk.sitkFloat32)
    moving = sitk.Cast(im_mov, sitk.sitkFloat32)

    fixed_mask = sitk.Cast(mask_ref > 0, sitk.sitkUInt8)
    moving_mask = None
    if mask_mov is not None:
        moving_mask = sitk.Cast(mask_mov > 0, sitk.sitkUInt8)

    def _unwrap_last(tfm):
        if isinstance(tfm, sitk.CompositeTransform):
            return tfm.GetNthTransform(tfm.GetNumberOfTransforms()-1)
        return tfm

    def _base_setup(reg, fixed_mask, moving_mask):
        # ---- Metric ----
        if use_distance_map:
            # signed distance map 
            f = sitk.SignedMaurerDistanceMap(fixed_mask, insideIsPositive=True, squaredDistance=False, useImageSpacing=True)
            if moving_mask is None:
                raise ValueError("use_distance_map=True requires mask_mov (moving mask).")
            m = sitk.SignedMaurerDistanceMap(moving_mask, insideIsPositive=True, squaredDistance=False, useImageSpacing=True)
            # MeanSquares
            reg.SetMetricAsMeanSquares()
            return f, m
        else:
            reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=bins)
            reg.SetMetricFixedMask(fixed_mask)
            if moving_mask is not None:
                reg.SetMetricMovingMask(moving_mask)
            return fixed, moving

    def _common(reg):
        reg.SetInterpolator(sitk.sitkLinear)

        # Sampling: REGULAR
        reg.SetMetricSamplingStrategy(reg.REGULAR)
        reg.SetMetricSamplingPercentage(sampling)

        # Multi-resolution
        reg.SetShrinkFactorsPerLevel([4, 2, 1])
        reg.SetSmoothingSigmasPerLevel([2, 1, 0])
        reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # -------------------------
    # Stage 1: Rigid
    # -------------------------
    rigid = sitk.VersorRigid3DTransform()
    rigid_init = sitk.CenteredTransformInitializer(
        fixed, moving, rigid,
        sitk.CenteredTransformInitializerFilter.MOMENTS
    )

    reg1 = sitk.ImageRegistrationMethod()
    fixed_for_reg1, moving_for_reg1 = _base_setup(reg1, fixed_mask, moving_mask)
    _common(reg1)

    # RegularStepGradientDescent optimizer
    reg1.SetOptimizerAsRegularStepGradientDescent(
        learningRate=2.0,
        minStep=1e-3,
        numberOfIterations=200,
        gradientMagnitudeTolerance=1e-6
    )
    reg1.SetOptimizerScalesFromPhysicalShift()
    reg1.SetInitialTransform(rigid_init, inPlace=False)

    print("Executing RIGID registration...")
    rigid_tfm = reg1.Execute(fixed_for_reg1, moving_for_reg1)
    rigid_tfm = _unwrap_last(rigid_tfm)

    print("  Rigid stop:", reg1.GetOptimizerStopConditionDescription())
    print("  Rigid iters:", reg1.GetOptimizerIteration())
    print("  Rigid metric:", reg1.GetMetricValue())

    # -------------------------
    # Stage 2: Affine (init from rigid)
    # -------------------------
    affine = sitk.AffineTransform(3)
    #  initialize affine
    affine.SetMatrix(rigid_tfm.GetMatrix())
    affine.SetTranslation(rigid_tfm.GetTranslation())
    try:
        affine.SetCenter(rigid_tfm.GetCenter())
    except Exception:
        #  rigid transform without center, just skip
        pass

    reg2 = sitk.ImageRegistrationMethod()
    fixed_for_reg2, moving_for_reg2 = _base_setup(reg2, fixed_mask, moving_mask)
    _common(reg2)

    reg2.SetOptimizerAsRegularStepGradientDescent(
        learningRate=1.0,
        minStep=1e-4,
        numberOfIterations=300,
        gradientMagnitudeTolerance=1e-6
    )
    reg2.SetOptimizerScalesFromPhysicalShift()
    reg2.SetInitialTransform(affine, inPlace=False)

    print("Executing AFFINE registration...")
    affine_tfm = reg2.Execute(fixed_for_reg2, moving_for_reg2)
    affine_tfm = _unwrap_last(affine_tfm)

    print("  Affine stop:", reg2.GetOptimizerStopConditionDescription())
    print("  Affine iters:", reg2.GetOptimizerIteration())
    print("  Affine metric:", reg2.GetMetricValue())

    return affine_tfm

In [9]:
def apply_transform(im_mov, transform, im_ref, is_mask=False):
    """apply transform"""
    interpolator = sitk.sitkNearestNeighbor if is_mask else sitk.sitkLinear
    return sitk.Resample(
        im_mov,
        im_ref,
        transform,
        interpolator,
        0,
        im_mov.GetPixelID()
    )

In [10]:
def dice_coefficient(mask1, mask2):
    """compute Dice coefficient"""
    arr1 = sitk.GetArrayFromImage(mask1) > 0
    arr2 = sitk.GetArrayFromImage(mask2) > 0
    
    intersection = np.sum(arr1 & arr2)
    union = np.sum(arr1) + np.sum(arr2)
    
    if union == 0:
        return 0.0
    
    return 2.0 * intersection / union


def evaluate_registration(fix_msk, mov_msk, stage_name=""):
    """evaluate registration using Dice coefficient"""
    dice = dice_coefficient(fix_msk, mov_msk)
    
    fix_arr = sitk.GetArrayFromImage(fix_msk) > 0
    mov_arr = sitk.GetArrayFromImage(mov_msk) > 0
    
    overlap = np.sum(fix_arr & mov_arr)
    fix_vol = np.sum(fix_arr)
    mov_vol = np.sum(mov_arr)
    
    print(f"\n{stage_name}")
    print(f"  Dice Coefficient: {dice:.4f}")
    print(f"  Fixed volume: {fix_vol} voxels")
    print(f"  Moving volume: {mov_vol} voxels")
    print(f"  Overlap: {overlap} voxels ({overlap/fix_vol*100:.1f}% of fixed)")
    
    return dice

In [11]:
def visualize_registration(fixed_img, fixed_mask, moving_img, moving_mask, title=""):
    """visualize registration results with interactive coronal slice viewer"""
    
    def show_slice(fixed_img, fixed_mask, moving_img, moving_mask, slc):
        fig, axes = plt.subplots(2, 2, figsize=(12, 12))
        
        # Image overlay
        axes[0, 0].imshow(fixed_img[:, slc, :], cmap='gray')
        axes[0, 0].set_title('Fixed Image')
        axes[0, 0].axis('off')
        
        axes[0, 1].imshow(fixed_img[:, slc, :], cmap='Blues', alpha=0.7)
        axes[0, 1].imshow(moving_img[:, slc, :], cmap='Reds', alpha=0.3)
        axes[0, 1].set_title(f'{title} - Image Overlay')
        axes[0, 1].axis('off')
        
        # Mask overlay
        axes[1, 0].imshow(fixed_img[:, slc, :], cmap='gray')
        axes[1, 0].imshow(np.ma.masked_where(fixed_mask[:, slc, :] == 0, 
                                             fixed_mask[:, slc, :]),
                         cmap='Greens', alpha=0.5)
        axes[1, 0].set_title('Fixed Mask')
        axes[1, 0].axis('off')
        
        axes[1, 1].imshow(fixed_img[:, slc, :], cmap='gray')
        axes[1, 1].imshow(np.ma.masked_where(fixed_mask[:, slc, :] == 0, 
                                             fixed_mask[:, slc, :]),
                         cmap='Greens', alpha=0.4)
        axes[1, 1].imshow(np.ma.masked_where(moving_mask[:, slc, :] == 0, 
                                             moving_mask[:, slc, :]),
                         cmap='Reds', alpha=0.4)
        axes[1, 1].set_title(f'{title} - Masks (Green=Fixed, Red=Moving)')
        axes[1, 1].axis('off')
        
        plt.tight_layout()
        plt.show()
    
    return interact(
        show_slice,
        fixed_img=fixed(fixed_img),
        fixed_mask=fixed(fixed_mask),
        moving_img=fixed(moving_img),
        moving_mask=fixed(moving_mask),
        slc=IntSlider(min=0, max=fixed_img.shape[1]-1, step=1, value=fixed_img.shape[1]//2)
    )

In [12]:
# execute linear registration
lin_tfm = est_lin_transf_with_mask(fix_im, mov_im, fix_msk, mov_msk)

# apply transform
mov_im_linear = apply_transform(mov_im, lin_tfm, fix_im, is_mask=False)
mov_msk_linear = apply_transform(mov_msk, lin_tfm, fix_im, is_mask=True)

mov_im_linear_data = sitk.GetArrayFromImage(mov_im_linear)
mov_msk_linear_data = sitk.GetArrayFromImage(mov_msk_linear)

# evaluate
print("\n" + "-"*60)
print("LINEAR REGISTRATION QUALITY")
print("-"*60)

# Baseline 
print("\nCalculating baseline (center-aligned)...")

# create baseline
baseline_tfm = sitk.CenteredTransformInitializer(
    fix_im,
    mov_im,
    sitk.Euler3DTransform(),  
    sitk.CenteredTransformInitializerFilter.GEOMETRY
)

mov_msk_baseline = apply_transform(mov_msk, baseline_tfm, fix_im, is_mask=True)

dice_before = evaluate_registration(fix_msk, mov_msk_baseline, "Baseline (center-aligned):")
dice_linear = evaluate_registration(fix_msk, mov_msk_linear, "After linear registration:")

improvement = (dice_linear - dice_before) / dice_before * 100 if dice_before > 0 else float('inf')
print(f"\nImprovement: {improvement:.1f}%")

if dice_linear > 0.75:
    print("\n✓ Excellent! Linear registration achieved Dice > 0.75")
elif dice_linear > 0.65:
    print("\n✓ Good! Linear registration achieved decent alignment")
else:
    print("\n⚠️  Linear registration Dice is low. May need further investigation.")


Executing RIGID registration...
  Rigid stop: RegularStepGradientDescentOptimizerv4: Step too small after 52 iterations. Current step (0.000976562) is less than minimum step (0.001).
  Rigid iters: 53
  Rigid metric: 8.550283299810918
Executing AFFINE registration...
  Affine stop: RegularStepGradientDescentOptimizerv4: Step too small after 35 iterations. Current step (6.10352e-05) is less than minimum step (0.0001).
  Affine iters: 36
  Affine metric: 7.572338962442114

------------------------------------------------------------
LINEAR REGISTRATION QUALITY
------------------------------------------------------------

Calculating baseline (center-aligned)...

Baseline (center-aligned):
  Dice Coefficient: 0.0377
  Fixed volume: 1197012 voxels
  Moving volume: 1142011 voxels
  Overlap: 44053 voxels (3.7% of fixed)

After linear registration:
  Dice Coefficient: 0.7512
  Fixed volume: 1197012 voxels
  Moving volume: 1178245 voxels
  Overlap: 892183 voxels (74.5% of fixed)

Improvement: 

In [13]:
visualize_registration(fix_im_data, fix_msk_data, 
                      mov_im_linear_data, mov_msk_linear_data,
                      title="Linear Registration")

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.visualize_registration.<locals>.show_slice(fixed_img, fixed_mask, moving_img, moving_mask, slc)>

In [14]:
output_dir = DATA / 'registered'
output_dir.mkdir(exist_ok=True)

sitk.WriteTransform(lin_tfm, str(output_dir / 'g1_54_to_common_40_linear.tfm'))
sitk.WriteImage(mov_im_linear, str(output_dir / 'g1_54_linear_registered_image.nii.gz'))
sitk.WriteImage(mov_msk_linear, str(output_dir / 'g1_54_linear_registered_mask.nii.gz'))

In [15]:
def est_nl_transf_bspline_after_linear(im_ref, im_mov_linear, mask_ref, grid_spacing=50,
                                      sampling=0.10, bins=50, max_iter=100):
    """
    B-spline non-linear registration (based on linear-resampled moving)
    """

    fixed  = sitk.Cast(im_ref, sitk.sitkFloat32)
    moving = sitk.Cast(im_mov_linear, sitk.sitkFloat32)

    fixed_mask = sitk.Cast(mask_ref > 0, sitk.sitkUInt8)

    # --- standard meshSize estimate ：physical_size / grid_spacing ---
    size = fixed.GetSize()
    sp   = fixed.GetSpacing()
    physical_size = [size[i] * sp[i] for i in range(3)]

    mesh_size = [max(1, int(round(physical_size[i] / grid_spacing))) for i in range(3)]
    # BSplineTransformInitializer
    bspline = sitk.BSplineTransformInitializer(fixed, transformDomainMeshSize=mesh_size, order=3)

    reg = sitk.ImageRegistrationMethod()
    reg.SetInitialTransform(bspline, inPlace=False)

    # --- metric：MI  ---
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=bins)
    reg.SetMetricSamplingStrategy(reg.REGULAR)        
    reg.SetMetricSamplingPercentage(sampling)
    reg.SetMetricFixedMask(fixed_mask)

    # --- optimizer：LBFGSB---
    reg.SetOptimizerAsLBFGSB(
        gradientConvergenceTolerance=1e-5,
        numberOfIterations=max_iter,
        maximumNumberOfCorrections=5,
        maximumNumberOfFunctionEvaluations=2000
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    reg.SetInterpolator(sitk.sitkLinear)

    # --- multi-resolution ---
    reg.SetShrinkFactorsPerLevel([4, 2, 1])
    reg.SetSmoothingSigmasPerLevel([2, 1, 0])   
    reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    print(f"[BSpline] grid_spacing={grid_spacing}mm, physical_size={physical_size}, mesh_size={mesh_size}")
    print("Executing B-spline registration...")
    tfm = reg.Execute(fixed, moving)

    print("✓ B-spline done.")
    print("  Stop:", reg.GetOptimizerStopConditionDescription())
    print("  Iter:", reg.GetOptimizerIteration())
    print("  Metric:", reg.GetMetricValue())

    return tfm


In [16]:
GRID_SPACING = 50  

nl_tfm = est_nl_transf_bspline_after_linear(
    im_ref=fix_im,
    im_mov_linear=mov_im_linear,   
    mask_ref=fix_msk,
    grid_spacing=GRID_SPACING,
    sampling=0.10,                 
    max_iter=100
)

print("\nApplying NON-LINEAR (B-spline) transform...")

mov_im_nonlinear = apply_transform(
    mov_im_linear,   # After linear registration
    nl_tfm,
    fix_im,
    is_mask=False
)

mov_msk_nonlinear = apply_transform(
    mov_msk_linear,  
    nl_tfm,
    fix_im,
    is_mask=True
)

mov_im_nonlinear_data = sitk.GetArrayFromImage(mov_im_nonlinear)
mov_msk_nonlinear_data = sitk.GetArrayFromImage(mov_msk_nonlinear)

[BSpline] grid_spacing=50mm, physical_size=[340.0, 340.0, 228.7965087890625], mesh_size=[7, 7, 5]
Executing B-spline registration...
✓ B-spline done.
  Stop: LBFGSBOptimizerv4: User requested
  Iter: 100
  Metric: -0.29546738182783383

Applying NON-LINEAR (B-spline) transform...


In [18]:

dice_nonlinear = evaluate_registration(
    fix_msk,
    mov_msk_nonlinear,
    "After non-linear (B-spline) registration:"
)

improvement_nl = (dice_nonlinear - dice_linear) / dice_linear * 100
print(f"\nNon-linear improvement over linear: {improvement_nl:.2f}%")


After non-linear (B-spline) registration:
  Dice Coefficient: 0.8871
  Fixed volume: 1197012 voxels
  Moving volume: 1119999 voxels
  Overlap: 1027658 voxels (85.9% of fixed)

Non-linear improvement over linear: 18.08%


In [17]:
visualize_registration(fix_im_data, fix_msk_data,
                          mov_im_nonlinear_data, mov_msk_nonlinear_data,
                          title="B-spline Registration")



interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.visualize_registration.<locals>.show_slice(fixed_img, fixed_mask, moving_img, moving_mask, slc)>

In [19]:
# compare all stages
def compare_all_stages(fixed_img, fixed_mask, 
                        linear_img, linear_mask,
                        nonlinear_img, nonlinear_mask, slc):
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
    # Before (identity)
    axes[0].imshow(fixed_img[:, slc, :], cmap='gray')
    axes[0].set_title(f'Fixed Image\nDice: {dice_before:.3f}')
    axes[0].axis('off')
        
    # Linear
    axes[1].imshow(fixed_img[:, slc, :], cmap='gray')
    axes[1].imshow(np.ma.masked_where(fixed_mask[:, slc, :] == 0, 
                                        fixed_mask[:, slc, :]),
                    cmap='Greens', alpha=0.3)
    axes[1].imshow(np.ma.masked_where(linear_mask[:, slc, :] == 0, 
                                        linear_mask[:, slc, :]),
                    cmap='Reds', alpha=0.3)
    axes[1].set_title(f'After Linear\nDice: {dice_linear:.3f}')
    axes[1].axis('off')
        
    # B-spline
    axes[2].imshow(fixed_img[:, slc, :], cmap='gray')
    axes[2].imshow(np.ma.masked_where(fixed_mask[:, slc, :] == 0, 
                                        fixed_mask[:, slc, :]),
                    cmap='Greens', alpha=0.3)
    axes[2].imshow(np.ma.masked_where(nonlinear_mask[:, slc, :] == 0, 
                                        nonlinear_mask[:, slc, :]),
                    cmap='Reds', alpha=0.3)
    axes[2].set_title(f'After B-spline\nDice: {dice_nonlinear:.3f}')
    axes[2].axis('off')
        
    plt.suptitle('Registration Comparison (Green=Fixed, Red=Moving)', fontsize=14)
    plt.tight_layout()
    plt.show()


In [20]:
interact(
    compare_all_stages,
    fixed_img=fixed(fix_im_data),
    fixed_mask=fixed(fix_msk_data),
    linear_img=fixed(mov_im_linear_data),
    linear_mask=fixed(mov_msk_linear_data),
    nonlinear_img=fixed(mov_im_nonlinear_data),
    nonlinear_mask=fixed(mov_msk_nonlinear_data),
    slc=IntSlider(min=0, max=fix_im_data.shape[1]-1, step=1, value=fix_im_data.shape[1]//2)
)

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.compare_all_stages(fixed_img, fixed_mask, linear_img, linear_mask, nonlinear_img, nonlinear_mask, slc)>

In [21]:
# save non-linear registration results
sitk.WriteTransform(nl_tfm, str(output_dir / 'g1_54_to_common_40_bspline.tfm'))
sitk.WriteImage(mov_im_nonlinear, str(output_dir / 'g1_54_bspline_registered_image.nii.gz'))
sitk.WriteImage(mov_msk_nonlinear, str(output_dir / 'g1_54_bspline_registered_mask.nii.gz'))