# Image Registration Notebook
For 3D CT and CBCT images

Sample Dataset available on: https://drive.google.com/file/d/1tSSTTLbH8j_svnmeALkq4FzwX1gDl6cT/view?usp=drive_link

Credits: https://grand-challenge.org/forums/forum/learn2reg-registration-challenge-449/topic/l2r23-data-release-1486/

In [1]:
!pip install SimpleITK

import SimpleITK as sitk
import os
import matplotlib.pyplot as plt
import numpy as np

Collecting SimpleITK
  Downloading SimpleITK-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.6/52.6 MB[0m [31m18.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.3.0


In [None]:
# Colab only: mount Google Drive at /gdrive so the dataset paths used below resolve.
from google.colab import drive
drive.mount('/gdrive')

**Run the following definitions cell. There is no need to edit it during the assignment.**

In [None]:
# @title
#Run the following cell, no need to edit it, during the assignment
#author Luis de la O

import SimpleITK as sitk
import numpy as np
import os
import math
from skimage.measure import regionprops,label
from skimage import filters


def BlurringFilter(image,kernel_val=3,filter_val=1.0):
    """Smooth a SimpleITK image with a recursive Gaussian filter.

    Parameters
    ----------
    image : sitk.Image
        Image to smooth.
    kernel_val : int, optional
        Accepted for backward compatibility only. The recursive Gaussian filter
        is configured through its sigma, so this value has no effect.
    filter_val : float, optional
        Sigma passed to the filter's ``SetSigma``.

    Returns
    -------
    sitk.Image
        The smoothed image produced by the filter's ``Execute``.
    """
    gaussian_filter = sitk.SmoothingRecursiveGaussianImageFilter()
    gaussian_filter.SetSigma(filter_val)
    return gaussian_filter.Execute(image)

def normalizeSimpleITK(image,tresh_min,tresh_max):
    """Rescale a SimpleITK image to a NumPy array with values in [0, 1].

    Intensities are mapped linearly so that ``tresh_min`` becomes 0 and
    ``tresh_max`` becomes 1; values outside that window are clipped.

    Parameters
    ----------
    image : sitk.Image
        Image to convert with ``sitk.GetArrayFromImage``.
    tresh_min : float
        Intensity mapped to 0 (the notebook uses -1024).
    tresh_max : float
        Intensity mapped to 1 (the notebook uses 600).

    Returns
    -------
    numpy.ndarray
        Normalised array with the same shape as the converted image.

    Raises
    ------
    ValueError
        If ``tresh_min`` equals ``tresh_max`` (the window would have zero width).
    """
    if tresh_max == tresh_min:
        raise ValueError("tresh_min and tresh_max must differ to define an intensity window")

    image_np = sitk.GetArrayFromImage(image)
    # Use the caller's window instead of the previously hard-coded -1024/600.
    image_normalized = (image_np - tresh_min) / (tresh_max - tresh_min)

    return np.clip(image_normalized, 0, 1)


def getCenter(image):
    """Return the centroid and intensity-weighted centroid of the image foreground.

    The foreground is every element above the Otsu threshold of ``image``. That
    binary mask is handed to ``regionprops`` as the label image, with ``image``
    as the intensity image, and the first reported region is used.

    Returns
    -------
    tuple
        ``(centroid, weighted_centroid)`` of that region.
    """
    otsu_level = filters.threshold_otsu(image)
    foreground_mask = (image > otsu_level).astype(int)
    foreground_region = regionprops(foreground_mask, image)[0]
    return foreground_region.centroid, foreground_region.weighted_centroid


# Functions 2D Translation2D Euler2D Similarity2D ScaleTransform2D
# Functions 3D Translation3D Euler3D VersorRigid3D Similarity3D Scale3D
def transformation_fun_select(arg,image):
    """Create the SimpleITK transform that corresponds to a transform name.

    Parameters
    ----------
    arg : str
        One of "Translation2D", "Euler2D", "Similarity2D", "ScaleTransform2D",
        "Translation3D", "Euler3D", "Similarity3D", "VersorRigid3D", "Scale3D".
    image : object
        Unused; kept so existing callers do not need to change.

    Returns
    -------
    sitk.Transform
        A newly constructed transform of the requested kind.

    Raises
    ------
    ValueError
        If ``arg`` is not one of the supported names (previously ``None`` was
        returned silently, which only failed later inside the registration).
    """
    # Factories rather than instances, so only the requested transform is built.
    transform_factories = {
        "Translation2D": lambda: sitk.TranslationTransform(2),
        "Euler2D": lambda: sitk.Euler2DTransform(),
        "Similarity2D": lambda: sitk.Similarity2DTransform(),
        "ScaleTransform2D": lambda: sitk.ScaleTransform(2),
        "Translation3D": lambda: sitk.TranslationTransform(3),
        "Euler3D": lambda: sitk.Euler3DTransform(),
        "Similarity3D": lambda: sitk.Similarity3DTransform(),
        "VersorRigid3D": lambda: sitk.VersorRigid3DTransform(),
        "Scale3D": lambda: sitk.ScaleVersor3DTransform(),
    }
    if arg not in transform_factories:
        supported = ", ".join(transform_factories)
        raise ValueError(f"Unknown transform '{arg}'. Supported transforms: {supported}")
    return transform_factories[arg]()


def tailor_registration(fixed_array,moving_array,transf_spec,center_spec,metric_spec,gradient_spec,shift_sepc,offset="Diff",iterations_spec=300,lr=1,minStep=.00001,gradientT=1e-7,convWinSize=10,convMinVal=1e-7,plot_bool=False):
    """Register ``moving_array`` onto ``fixed_array`` with a configurable SimpleITK pipeline.

    Parameters
    ----------
    fixed_array, moving_array : numpy.ndarray
        Target and moving images. Both are converted with
        ``sitk.GetImageFromArray`` and cast to float32.
    transf_spec : str
        Transform name understood by ``transformation_fun_select``, or
        "Translation2D", which is initialised here from image centroids.
    center_spec : str
        "Geometry" or "Moments"; selects the ``CenteredTransformInitializer``
        mode. Not used when ``transf_spec`` is "Translation2D".
    metric_spec : str
        "Correlation", "MatesMutualInformation" (historical spelling;
        "MattesMutualInformation" is accepted as well) or "MeanSquares".
    gradient_spec : str
        "RegularStepGradientDescent" or "GradientDescent".
    shift_sepc : str
        "PhysicalShift" or "IndexShift"; selects how optimizer scales are estimated.
    offset : str, optional
        Only for "Translation2D": "Diff" (fixed minus moving weighted centroid),
        "Fix" (fixed weighted centroid) or "Mov" (moving weighted centroid).
    iterations_spec : int, optional
        Maximum number of optimizer iterations.
    lr : float, optional
        Optimizer learning rate.
    minStep, gradientT : float, optional
        Minimum step and gradient magnitude tolerance for RegularStepGradientDescent.
    convWinSize, convMinVal : int, float, optional
        Convergence window size and convergence minimum value for GradientDescent.
    plot_bool : bool, optional
        Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(moved_array, evaluationMetric, final_transform)``: the moving image
        resampled onto the fixed grid as an array, the final metric value, and
        the transform returned by the registration.

    Raises
    ------
    ValueError
        If any of the specification strings is not one of the supported values.
    """
    fixed_image = sitk.Cast(sitk.GetImageFromArray(fixed_array), sitk.sitkFloat32)
    moving_image = sitk.Cast(sitk.GetImageFromArray(moving_array), sitk.sitkFloat32)

    #INITIALIZE
    registration_method = sitk.ImageRegistrationMethod()
    if transf_spec != "Translation2D":
        transf_fun = transformation_fun_select(transf_spec,moving_array)
        if transf_fun is None:
            raise ValueError(f"Unknown transform specification: {transf_spec!r}")

        if center_spec == "Geometry":
            center_fun = sitk.CenteredTransformInitializerFilter.GEOMETRY
        elif center_spec == "Moments":
            center_fun = sitk.CenteredTransformInitializerFilter.MOMENTS
        else:
            raise ValueError(f"Unknown center specification: {center_spec!r} (expected 'Geometry' or 'Moments')")

        initial_transform = sitk.CenteredTransformInitializer(fixed_image, moving_image, transf_fun, center_fun)
        registration_method.SetInitialTransform(initial_transform, inPlace=False)
    else:
        # The 2D translation case is seeded from intensity-weighted centroids.
        # NOTE(review): getCenter reports coordinates in array index order; confirm
        # that this matches the axis order SimpleITK expects for the offset.
        if offset == "Diff":
            _, w_center_fix = getCenter(fixed_array)
            _, w_center_mov = getCenter(moving_array)
            centroid_difference = (w_center_fix[0] - w_center_mov[0], w_center_fix[1] - w_center_mov[1])
        elif offset == "Fix":
            _, centroid_difference = getCenter(fixed_array)
        elif offset == "Mov":
            _, centroid_difference = getCenter(moving_array)
        else:
            raise ValueError(f"Unknown offset specification: {offset!r} (expected 'Diff', 'Fix' or 'Mov')")

        translation_transform = sitk.TranslationTransform(2,centroid_difference)
        rigid_transform = sitk.Euler2DTransform()
        rigid_transform.SetTranslation(translation_transform.GetOffset())
        registration_method.SetInitialTransform(rigid_transform, inPlace=False)

    #METRICS
    if metric_spec == "Correlation":
        registration_method.SetMetricAsCorrelation()
    elif metric_spec in ("MatesMutualInformation", "MattesMutualInformation"):
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=100)
    elif metric_spec == "MeanSquares":
        registration_method.SetMetricAsMeanSquares()
    else:
        raise ValueError(f"Unknown metric specification: {metric_spec!r}")
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.2)

    #OPTIMIZER
    if gradient_spec == "RegularStepGradientDescent":
        registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=lr, numberOfIterations=iterations_spec,minStep=minStep, gradientMagnitudeTolerance=gradientT)
    elif gradient_spec == "GradientDescent":
        registration_method.SetOptimizerAsGradientDescent(learningRate=lr, numberOfIterations=iterations_spec,convergenceMinimumValue=convMinVal,convergenceWindowSize=convWinSize)
    else:
        raise ValueError(f"Unknown optimizer specification: {gradient_spec!r}")

    if shift_sepc == "PhysicalShift":
        registration_method.SetOptimizerScalesFromPhysicalShift()
    elif shift_sepc == "IndexShift":
        registration_method.SetOptimizerScalesFromIndexShift()
    else:
        raise ValueError(f"Unknown optimizer scale specification: {shift_sepc!r}")

    #Registration
    final_transform = registration_method.Execute(fixed_image, moving_image)
    evaluationMetric = registration_method.GetMetricValue()
    # Evaluated after Execute; the initial transform was set with inPlace=False,
    # so presumably this still reflects the un-optimised starting point.
    initial_metric_value = registration_method.MetricEvaluate(fixed_image, moving_image)
    print(f"Initial metric value: {initial_metric_value}")
    print(f"Final metric value: {evaluationMetric}")
    print(f"Optimizer's stopping condition: {registration_method.GetOptimizerStopConditionDescription()}")

    # Resample the moving image onto the fixed image grid; voxels outside the
    # moving image are filled with 0.0.
    moved_image = sitk.Resample(moving_image, fixed_image, final_transform, sitk.sitkLinear, 0.0, moving_image.GetPixelID())
    moved_array = sitk.GetArrayFromImage(moved_image)

    return moved_array,evaluationMetric,final_transform

Check Sizes
Check Orientation
Check Pixel Range Values

In [None]:
# Dataset location on the mounted Drive, the case to load, and the axial slice to preview.
root = "/gdrive/My Drive/ImageRegistration/RegistrationPart2/SampleDataset/"
name = "ThoraxCBCT_0000"
sliceNum=90

# Fixed image (file suffix _0000), read as float32 and normalised to [0, 1].
ct_path = f"{root}{name}_0000.nii.gz"
fixed_sitk = sitk.ReadImage(ct_path, sitk.sitkFloat32)
fixed_array = normalizeSimpleITK(fixed_sitk,-1024,600)

# Moving image (file suffix _0001), handled the same way.
cbct_path = f"{root}{name}_0001.nii.gz"
moving_sitk = sitk.ReadImage(cbct_path, sitk.sitkFloat32)
moving_array = normalizeSimpleITK(moving_sitk,-1024,600)

# Show the chosen slice of each volume and their difference side by side.
fixed_slice = fixed_array[sliceNum,:,:]
moving_slice = moving_array[sliceNum,:,:]
preview_panels = [
    ("CT - Fixed Image", fixed_slice),
    ("CBCT - Moving Image", moving_slice),
    ("Difference", fixed_slice-moving_slice),
]
for panel_position, (panel_title, panel_pixels) in enumerate(preview_panels, start=1):
    plt.subplot(1,3,panel_position)
    plt.imshow(panel_pixels,cmap="gray")
    plt.axis('off')
    plt.title(panel_title)
plt.tight_layout()
plt.show()

print("Normalized CT Values: ",fixed_array.max(),fixed_array.min())
print("Normalized CBCT Values: ",moving_array.max(),moving_array.min())

In [None]:
# For a faster registration, we can blur the images if necessary.
fixedBlurred_image = BlurringFilter(fixed_sitk,kernel_val=15,filter_val=2)
movingBlurred_image = BlurringFilter(moving_sitk,kernel_val=15,filter_val=2)
# Normalise the filtered images.
fixed_array_blurred = normalizeSimpleITK(fixedBlurred_image,-1024,600)
moving_array_blurred = normalizeSimpleITK(movingBlurred_image,-1024,600)

# Normalise the unfiltered images for comparison.
fixed_array_unfiltered = normalizeSimpleITK(fixed_sitk,-1024,600)
moving_array_unfiltered = normalizeSimpleITK(moving_sitk,-1024,600)

# Visualise both images and notice the difference between them.
# The moving shape is now reported from the moving array (it previously repeated the fixed shape).
print("Image dimensions: Fixed Image: ",fixed_array_unfiltered.shape," Moving Image:", moving_array_unfiltered.shape)

comparison_slice = 140  # axial slice shown in every panel
comparison_panels = [
    ("Fixed Unfiltered", fixed_array_unfiltered[comparison_slice,:,:]),
    ("Fixed Filtered", fixed_array_blurred[comparison_slice,:,:]),
    ("Moving Unfiltered", moving_array_unfiltered[comparison_slice,:,:]),
    ("Moving Filtered", moving_array_blurred[comparison_slice,:,:]),
    ("Unfiltered Difference", fixed_array_unfiltered[comparison_slice,:,:]-moving_array_unfiltered[comparison_slice,:,:]),
    ("Filtered Difference", fixed_array_blurred[comparison_slice,:,:]-moving_array_blurred[comparison_slice,:,:]),
]
for panel_position, (panel_title, panel_pixels) in enumerate(comparison_panels, start=1):
    plt.subplot(2,3,panel_position)
    plt.imshow(panel_pixels)
    plt.axis('off')
    plt.title(panel_title)
plt.tight_layout()
plt.show()

# A simple subtraction gives a rough measure of how much of the image is out of place:
print("% out of place, Filtered: ",100*np.sum(abs(moving_array_blurred-fixed_array_blurred))/np.sum(moving_array_blurred+fixed_array_blurred))
print("% out of place, Unfiltered: ",100*np.sum(abs(moving_array_unfiltered-fixed_array_unfiltered))/np.sum(moving_array_unfiltered+fixed_array_unfiltered))

### Possible Transform values are:  
-- **3D Transforms:** "Similarity3D", "VersorRigid3D", "Scale3D", "Euler3D", "Translation3D".  
-- **Center Spec:** Geometry . Moments.  
-- **Metrics:** Correlation . MatesMutualInformation. MeanSquares.  
-- **Optimizer:** GradientDescent . RegularStepGradientDescent.  
-- **Shift:** IndexShift . PhysicalShift

### Other configurations editable with current defaults:  
**For TranslationTransform:** offset="Diff" or "Fix" or "Mov".
"Diff" is the difference between the centroids of the fixed (target) and moving images. "Fix" is the centroid of the fixed image. "Mov" is the centroid of the moving image.  

**For Optimizers:**  
iterations_spec=300, maximum number of iterations for all optimizers.  
lr=1, learning rate for all optimizers.  
minStep=.00001, minimum step for optimizer: Regular Step Gradient Descent.  
gradientT=1e-7, GradientMagnitudeTolerance for optimizer: Regular Step Gradient Descent.  
convWinSize=10, convergence window size for optimizer: Gradient Descent.  
convMinVal=1e-7, convergence minimum value for optimizer: Gradient Descent.  

In [None]:
# Registration configuration: edit these values to try other set-ups.
transf_spec = "Scale3D" #"Similarity3D" "VersorRigid3D" "Scale3D" "Euler3D" "Translation3D"
center_spec = "Geometry"  # Geometry or Moments
metric_spec = "Correlation"  # "Correlation" or "MatesMutualInformation" or "MeanSquares"
optimizer = "RegularStepGradientDescent"  # "GradientDescent" or "RegularStepGradientDescent"
shift_sepc = "PhysicalShift"  # "PhysicalShift" or "IndexShift"

# Register the blurred, normalised volumes.
fixed_array=fixed_array_blurred
moving_array = moving_array_blurred

#Both Optimizers
iterations_spec=300
lr=1
#RegularStep Gradient
minStep=0.0001
gradientT=1e-8
#Gradient Descent
# The two values were swapped: int(1e-18) truncates to a window size of 0, and a
# convergence minimum value of 1 is not a small tolerance. The tolerance keeps
# the 1e-18 value; the window uses tailor_registration's default of 10.
convMinVal = 1e-18
convWinSize = 10

#Offset (only used by the Translation2D transform)
offset = "Diff" #Fix, Mov, Diff


moved_array, evaluationMetric, final_transform = tailor_registration(fixed_array, moving_array, transf_spec,
                                                                         center_spec, metric_spec, optimizer,
                                                                         shift_sepc, offset, iterations_spec=iterations_spec,
                                                                         lr=lr,minStep=minStep, gradientT=gradientT,
                                                                         convMinVal=convMinVal,convWinSize = convWinSize)

In [None]:
# Visual check of the registration on a single axial slice.
print("Image dimensions: Image: ",fixed_array.shape," Image 2:", moved_array.shape)
target_slice = fixed_array[150,:,:]
moving_slice_before = moving_array[150,:,:]
moving_slice_after = moved_array[150,:,:]
result_panels = [
    ("Target", target_slice),
    ("Moving", moving_slice_before),
    ("Before Registration", target_slice-moving_slice_before),
    ("After Registration", target_slice-moving_slice_after),
]
for panel_position, (panel_title, panel_pixels) in enumerate(result_panels, start=1):
    plt.subplot(2,2,panel_position)
    plt.imshow(panel_pixels)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

NOT BLURRED.

Initial metric value: -0.45744104476700714.

Final metric value: -0.7248766683264387.

BLURRED.

Initial metric value: -0.5067674241674248.

Final metric value: -0.7981090539856124.
