# SimpleITK

The following short demo for this session was prepared using examples from the SimpleITK Notebooks:

http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/

Read more about SimpleITK (and ITK in general):

http://www.simpleitk.org/



## Setting up the environment for SimpleITK

In this session, we will use only a couple of images (3D head CT and MR volumes) to explain basic concepts in image registration. If you would like to run the additional Jupyter notebooks (covering more advanced topics in image registration or image segmentation), it is strongly advised to download all of the data before you run them.

If you run into problems setting up the environment, see:
http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/00_Setup.html

In [None]:
import importlib

# check that all packages are installed (see requirements.txt file)
required_packages = {'jupyter', 
                     'numpy',
                     'matplotlib',
                     'ipywidgets',
                     'scipy',
                     'pandas',
                     'SimpleITK'
                    }

for package in sorted(required_packages):
    try:
        importlib.import_module(package)
        print(package, 'is installed')
    except ImportError:
        print(package, 'is missing')

If you need to install SimpleITK, run:

__pip install SimpleITK__

If you run into any issues, see:
https://simpleitk.readthedocs.io/en/master/Documentation/docs/source/installation.html

In [None]:
#pip install SimpleITK

# or

#conda install -c https://conda.anaconda.org/simpleitk SimpleITK

## Check SimpleITK version and download data:


In [None]:
import SimpleITK as sitk

print(sitk.Version())


In [None]:
import sys, os
# make the Utilities folder (which provides downloaddata.py) importable
download_script_location = os.path.abspath(os.path.join('..','Utilities'))
if download_script_location not in sys.path:
    sys.path.append(download_script_location)

# we will not use all of the data in this session, so the bulk download is commented out
#from downloaddata import fetch_data_all
#fetch_data_all(os.path.join('..','Data'), os.path.join('..','Data','manifest.json'))

## Let's check one of the example images:

We read the image named "training_001_ct.mha", check its properties (size, origin, spacing, pixel type), and then visualise it.

For further reading on image basics in SimpleITK (and ITK), see:

http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/03_Image_Details.html

In [None]:
import matplotlib.pyplot as plt

from downloaddata import fetch_data as fdata

fixed_image = sitk.ReadImage(fdata("training_001_ct.mha"), sitk.sitkFloat32)

print("Size of fixed image: ", fixed_image.GetSize())
print("Origin of fixed image: ", fixed_image.GetOrigin())
print("Spacing of fixed image: ", fixed_image.GetSpacing())
print("Pixel type of fixed image: ", fixed_image.GetPixelIDTypeAsString())

# fixed_image is 3D, so we need to select the slice we want to visualise.
# The NumPy array view is indexed [z, y, x], so the slice number comes first.
slice_number = 14

plt.imshow(sitk.GetArrayViewFromImage(fixed_image)[slice_number,:,:])
plt.axis('off');




### Grey colormap?
Medical images are usually displayed using a grey colormap:

In [None]:
plt.imshow(sitk.GetArrayViewFromImage(fixed_image)[slice_number,:,:],cmap=plt.cm.Greys_r)
plt.axis('off');

### Visualising 3D data:
Since this data set is 3D (as are many common medical imaging modalities), we can browse through it with an interactive slider. This is why we import:

__from ipywidgets import interact, fixed__

For more on __interact__ and __fixed__, see:

https://ipywidgets.readthedocs.io/en/latest/examples/Using%20Interact.html


In [None]:
def display_3d_image(slice_range, img_view, img_name='a nice image'):
    # Create a single figure of the specified size.
    plt.figure(figsize=(5,5))
    
    # Draw the selected slice (img_view is a NumPy view indexed [z, y, x]).
    plt.imshow(img_view[slice_range,:,:],cmap=plt.cm.Greys_r)
    plt.title(img_name)
    plt.axis('off')
    
    plt.show()
    
from ipywidgets import interact, fixed

interact(display_3d_image, slice_range=(0,fixed_image.GetSize()[2]-1), 
         img_view=fixed(sitk.GetArrayViewFromImage(fixed_image)),img_name=fixed('head CT') );


## Adding a transformation to the image
In this step, we create an artificial moving image by applying a transformation (here, a translation) to the fixed image.

In [None]:
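dimension = 3
offset = (10, 5, 1)  # translation in physical units (e.g. mm); any sequence of length `dimension`
translation = sitk.TranslationTransform(dimension, offset)
print(translation)

# Resample maps every output point p to the input point p + offset,
# so the content of moving_image appears shifted by -offset.
reference_image = fixed_image
interpolator = sitk.sitkLinear
default_value = 100.0  # intensity given to output pixels that map outside fixed_image
moving_image = sitk.Resample(fixed_image, reference_image, translation, interpolator, default_value)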

interact(display_3d_image, slice_range=(0,moving_image.GetSize()[2]-1), 
         img_view=fixed(sitk.GetArrayViewFromImage(moving_image)),img_name=fixed('shifted head CT - moving image') );

## A better visualisation?
To see the artificial shift more clearly, we can visualise the difference between the fixed and moving images:

In [None]:
def display_two_3Dimages_difference(slice_range, fixed, moving):
    # SimpleITK images are indexed [x, y, z], so this extracts slice z = slice_range.
    img = fixed[:,:,slice_range] - moving[:,:,slice_range]
    plt.imshow(sitk.GetArrayViewFromImage(img),cmap=plt.cm.Greys_r)
    plt.axis('off')
    plt.show()
    


interact(display_two_3Dimages_difference, slice_range=(0,fixed_image.GetSize()[2]-1), 
         fixed = fixed(fixed_image), moving=fixed(moving_image));


## Run registration!


In [None]:
from IPython.display import clear_output
# Callback invoked when the StartEvent happens, sets up our new data.
def start_plot():
    global metric_values, multires_iterations
    
    metric_values = []
    multires_iterations = []

# Callback invoked when the EndEvent happens, do cleanup of data and figure.
def end_plot():
    global metric_values, multires_iterations
    
    del metric_values
    del multires_iterations
    # Close figure, we don't want to get a duplicate of the plot later on.
    plt.close()

# Callback invoked when the IterationEvent happens, update our data and display new figure.
def plot_values(registration_method):
    global metric_values, multires_iterations
    
    metric_values.append(registration_method.GetMetricValue())                                       
    # Clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)
    # Plot the similarity metric values
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number',fontsize=12)
    plt.ylabel('Metric Value',fontsize=12)
    plt.show()
    
# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the 
# metric_values list. 
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))

    

            
initial_transform = sitk.TranslationTransform(dimension)
# try different transformation:
# initial_transform = sitk.CenteredTransformInitializer(fixed_image, moving_image, sitk.Euler3DTransform(), sitk.CenteredTransformInitializerFilter.GEOMETRY)


registration_method = sitk.ImageRegistrationMethod()

# Similarity metric settings.
registration_method.SetMetricAsMeanSquares()


# try something else e.g. Correlation 
# registration_method.SetMetricAsCorrelation()
# or Mattes Mutual Information (this can be tricky)
#registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
#registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
#registration_method.SetMetricSamplingPercentage(0.01)


registration_method.SetInterpolator(sitk.sitkLinear)
# try different interpolation:
# sitk.sitkNearestNeighbor 
# sitk.sitkBSpline

# Optimizer settings.
registration_method.SetOptimizerAsGradientDescent(learningRate=1, numberOfIterations=25, 
                                                  convergenceMinimumValue=1e-6, convergenceWindowSize=2)
registration_method.SetOptimizerScalesFromPhysicalShift()

# Setup for the multi-resolution framework.            
registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[4,2,0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Don't optimize in-place, we would possibly like to run this cell multiple times.
registration_method.SetInitialTransform(initial_transform, inPlace=False)

# Connect all of the observers so that we can perform plotting during registration.
registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method))

final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                               sitk.Cast(moving_image, sitk.sitkFloat32))


moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform, 
                                 sitk.sitkLinear, 0.0, moving_image.GetPixelID())

interact(display_two_3Dimages_difference, slice_range=(0,fixed_image.GetSize()[2]-1), 
         fixed = fixed(fixed_image), moving=fixed(moving_resampled));

print(final_transform)

## Multimodal registration:
So far, we have registered the fixed image (head CT) to a translated copy of itself in order to learn the general image registration framework. Now we can read another image, acquired for example with a different modality (MRI), and try to align it to the fixed (reference) image.

In [None]:
## read other image
new_moving_image = sitk.ReadImage(fdata("training_001_mr_T1.mha"), sitk.sitkFloat32) 

print("Size of moving image: ",   new_moving_image.GetSize())
print("Origin of moving image: ", new_moving_image.GetOrigin())
print("Spacing of moving image: ",new_moving_image.GetSpacing())
print("Pixel type of moving image: ",new_moving_image.GetPixelIDTypeAsString())

interact(display_3d_image, slice_range=(0,new_moving_image.GetSize()[2]-1), 
         img_view=fixed(sitk.GetArrayViewFromImage(new_moving_image)),img_name=fixed('head MRI') );



We can blend these two images to see how they overlap.

(Why not look at the difference image, as in the previous example?)

In [None]:
## register new image to the other one
## view results
## conclude

def display_two_3Dimages_with_alpha(slice_range, alpha, fixed, moving):
    img = (1.0 - alpha)*fixed[:,:,slice_range] + alpha*moving[:,:,slice_range] 
    plt.imshow(sitk.GetArrayViewFromImage(img),cmap=plt.cm.Greys_r);
    plt.axis('off')
    plt.show()
    
    
multmodal_initial_transform = sitk.CenteredTransformInitializer(fixed_image, 
                                                      new_moving_image, 
                                                      sitk.Euler3DTransform(), 
                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY)

new_moving_resampled = sitk.Resample(new_moving_image, fixed_image, multmodal_initial_transform, sitk.sitkLinear, 0.0, new_moving_image.GetPixelID())

interact(display_two_3Dimages_with_alpha, slice_range=(0,fixed_image.GetSize()[2]-1), 
         alpha=(0.0,1.0,0.05), fixed = fixed(fixed_image), moving=fixed(new_moving_resampled));

## Setting up multimodal image registration:

Again, we need to set up all the elements of the image registration. You can copy your "framework" from the previous exercise or use the one from (http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/60_Registration_Introduction.html).

Experiment with different setups and parameters
(for further reading, see: http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/62_Registration_Tuning.html).

Any comments?

In [None]:
multimodal_registration_method = sitk.ImageRegistrationMethod()

# Similarity metric settings.
multimodal_registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
multimodal_registration_method.SetMetricSamplingStrategy(multimodal_registration_method.RANDOM)
multimodal_registration_method.SetMetricSamplingPercentage(0.01)

multimodal_registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer settings.
multimodal_registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, convergenceMinimumValue=1e-6, convergenceWindowSize=10)
multimodal_registration_method.SetOptimizerScalesFromPhysicalShift()

# Setup for the multi-resolution framework.            
multimodal_registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])
multimodal_registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])
multimodal_registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Don't optimize in-place, we would possibly like to run this cell multiple times.
multimodal_registration_method.SetInitialTransform(multmodal_initial_transform, inPlace=False)

# Connect all of the observers so that we can perform plotting during registration.
multimodal_registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
multimodal_registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
multimodal_registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations) 
multimodal_registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(multimodal_registration_method))

multimodal_final_transform = multimodal_registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32), 
                                               sitk.Cast(new_moving_image, sitk.sitkFloat32))

In [None]:
final_moving_resampled = sitk.Resample(new_moving_image, fixed_image, multimodal_final_transform, sitk.sitkLinear, 0.0, new_moving_image.GetPixelID())

interact(display_two_3Dimages_with_alpha, slice_range=(0,fixed_image.GetSize()[2]-1), 
         alpha=(0.0,1.0,0.05), fixed = fixed(fixed_image), moving=fixed(final_moving_resampled));