# Image Registration Pipeline Using SimpleITK and Elastix

## Overview

This Jupyter notebook implements an image registration pipeline using SimpleITK and Elastix, consisting of a rigid registration step optionally followed by a deformable one. The goal is to align pre-RT MRI images and masks to the reference frame of the mid-RT image. In other words, the pre-RT data is the moving image and the mid-RT data is the fixed image.
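The two stages can be written out explicitly. This is our reading of the code below, assuming the ITK/elastix convention that a registration transform maps points from the fixed (mid-RT) physical space into the moving (pre-RT) space. Here $\mathbf{x}$ is a point in mid-RT space, $T_R$ is the rigid Euler transform, and $\mathbf{u}$ is the displacement field that transformix computes on the mid-RT grid:

$$
I_{\text{rigid}}(\mathbf{x}) = I_{\text{pre}}\big(T_R(\mathbf{x})\big), \qquad
I_{\text{reg}}(\mathbf{x}) = I_{\text{rigid}}\big(\mathbf{x} + \mathbf{u}(\mathbf{x})\big) = I_{\text{pre}}\big(T_R(\mathbf{x} + \mathbf{u}(\mathbf{x}))\big)
$$

The registered mask follows the same mapping with nearest-neighbor interpolation. Because the saved `deformable_transform.mha` stores only $\mathbf{u}$, mapping a mid-RT point back to pre-RT space also requires the rigid transform saved in `rigid_transform.tfm`.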

Rigid registration takes ~0.5-1.5 mins per case. Deformable registration takes ~8-12 mins per case.
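These timings give a rough batch estimate. The sketch below assumes that both stages run back to back for every case, as `process_ID` does when `apply_deformable=True`, and that the two figures simply add up. The helper name `estimate_batch_minutes` is hypothetical:

```python
# Rough batch runtime estimate from the per-case timings quoted above.
# Assumption: rigid and deformable times add up per case, and cases run
# sequentially. Real times depend on hardware, image size and thread count.
RIGID_MIN_PER_CASE = (0.5, 1.5)        # minutes per case, rigid stage
DEFORMABLE_MIN_PER_CASE = (8.0, 12.0)  # minutes per case, deformable stage


def estimate_batch_minutes(n_cases, apply_deformable=True):
    """Return (low, high) total minutes for n_cases processed one after another."""
    low, high = RIGID_MIN_PER_CASE
    if apply_deformable:
        low += DEFORMABLE_MIN_PER_CASE[0]
        high += DEFORMABLE_MIN_PER_CASE[1]
    return n_cases * low, n_cases * high
```

For example, `estimate_batch_minutes(100)` returns `(850.0, 1350.0)`, i.e. roughly 14 to 23 hours for 100 cases.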

We use a parameter map from the [Elastix Model Zoo based on Parameter Map 23](https://github.com/SuperElastix/ElastixModelZoo/tree/master/models/Par0023), loaded from `Elastix_parameterset_23.txt` in `main()`.
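Elastix parameter files are plain text, with one `(ParameterName value ...)` entry per line and `//` comments. The notebook reads the file with `sitk.ReadParameterFile`. The hypothetical helper `parse_elastix_parameters` below is a minimal, standard-library-only sketch for inspecting such a file without SimpleElastix. It assumes one entry per line and that `//` never appears inside a quoted value:

```python
import re
import shlex

# Matches one "(Key value value ...)" entry; the value part may be empty.
_ENTRY = re.compile(r"^\((\w+)\s*(.*)\)$")


def parse_elastix_parameters(text):
    """Parse elastix-style parameter text into {key: [values as strings]}.

    Sketch only: "//" is always treated as the start of a comment, so a
    quoted value containing "//" would be cut off.
    """
    params = {}
    for raw_line in text.splitlines():
        line = raw_line.split("//", 1)[0].strip()  # drop // comments
        match = _ENTRY.match(line)
        if match is None:
            continue  # blank or comment-only line
        key, values = match.groups()
        params[key] = shlex.split(values)  # split on whitespace, strip quotes
    return params
```

For example, `sorted(parse_elastix_parameters(open("Elastix_parameterset_23.txt").read()))` lists the parameter names defined in the file.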

To avoid issues, this code should ideally be run from a fresh conda environment with only [SimpleElastix installed](https://pypi.org/project/SimpleITK-SimpleElastix/). We have provided a YML file (enviornment.yml) that you can use to replicate the environment.
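One likely source of such issues (our assumption; the notebook does not say) is having both the plain `SimpleITK` package and `SimpleITK-SimpleElastix` in the same environment. Both provide the `SimpleITK` module, so whichever was installed last may shadow the other. The hypothetical helper below lists which of the two distributions are present, using only `importlib.metadata`:

```python
import re
from importlib import metadata

# Distributions that both provide a top-level "SimpleITK" module,
# with names normalised as in PEP 503 (lowercase, runs of -_. become "-").
SITK_DISTRIBUTIONS = {"simpleitk", "simpleitk-simpleelastix"}


def _normalise(name):
    return re.sub(r"[-_.]+", "-", name).lower()


def installed_sitk_distributions(dist_names=None):
    """Return the sorted SimpleITK-related distribution names that are present.

    If dist_names is None, inspect every distribution installed in the
    current Python environment.
    """
    if dist_names is None:
        dist_names = [dist.metadata.get("Name") for dist in metadata.distributions()]
    found = {_normalise(name) for name in dist_names if name}
    return sorted(found & SITK_DISTRIBUTIONS)
```

If `installed_sitk_distributions()` returns both names, recreating the environment is the safest fix.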

We use an example patient from the training set of HNTS-MRG 2024, [available on Zenodo](https://zenodo.org/records/11199559).
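Judging from `read_images` and `main`, each case is expected under `Example_data/<ID>/` as `preRT/<ID>_preRT_T2.nii.gz`, `preRT/<ID>_preRT_mask.nii.gz` and `midRT/<ID>_midRT_T2.nii.gz`, and only numeric folder names are treated as cases. The hypothetical helper `find_cases` sketches a pre-flight check that reports missing inputs per case before any registration starts:

```python
import os

# File-name patterns taken from read_images(); "{ID}" is the case folder name.
EXPECTED_FILES = (
    os.path.join("preRT", "{ID}_preRT_T2.nii.gz"),
    os.path.join("preRT", "{ID}_preRT_mask.nii.gz"),
    os.path.join("midRT", "{ID}_midRT_T2.nii.gz"),
)


def find_cases(top_folder):
    """Return {case_ID: [missing relative paths]} for numeric sub-folders."""
    report = {}
    for name in sorted(os.listdir(top_folder)):
        case_dir = os.path.join(top_folder, name)
        if not (os.path.isdir(case_dir) and name.isnumeric()):
            continue  # same filter as main(): only numeric folders are cases
        expected = [rel.format(ID=name) for rel in EXPECTED_FILES]
        report[name] = [
            rel for rel in expected
            if not os.path.isfile(os.path.join(case_dir, rel))
        ]
    return report
```

An empty list means all three inputs for that case are present. Note that `main()` iterates in `os.listdir` order, whereas this sketch sorts the case IDs.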

## Library Imports

In [1]:
import os 
import SimpleITK as sitk #SimpleITK-SimpleElastix version
import logging
import time
import traceback

In [2]:
# sanity check to make sure you have Elastix installed properly
try:
    elastixImageFilter = sitk.ElastixImageFilter()
    print("SimpleElastix is correctly installed and integrated with SimpleITK!")
except AttributeError as e:
    print(f"Error: {e}\nIt seems SimpleElastix is not correctly integrated with SimpleITK.")

SimpleElastix is correctly installed and integrated with SimpleITK!


## Main Code
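The functions below resample images with `sitk.sitkLinear` but masks with `sitk.sitkNearestNeighbor`. The following toy 1D sketch illustrates why. It is pure Python, not SimpleITK, and `resample_1d` is a hypothetical name. Like `sitk.Resample`, it walks over the fixed grid and reads the moving signal at `transform(x)`, because registration transforms map fixed coordinates to moving coordinates:

```python
import math


def resample_1d(moving, transform, n_out, mode="nearest"):
    """Resample a 1D moving signal onto a fixed grid of n_out samples.

    For each fixed index i the moving signal is read at transform(i), i.e.
    the transform maps fixed coordinates to moving coordinates.
    Lookups outside the moving signal return the default value 0.
    Linear mode requires len(moving) >= 2.
    """
    out = []
    last = len(moving) - 1
    for i in range(n_out):
        x = transform(i)
        if mode == "nearest":
            j = math.floor(x + 0.5)  # round half up to the closest sample
            out.append(moving[j] if 0 <= j <= last else 0)
        elif mode == "linear":
            if 0 <= x <= last:
                j = min(math.floor(x), last - 1)  # left neighbor
                w = x - j  # weight of the right neighbor
                out.append((1 - w) * moving[j] + w * moving[j + 1])
            else:
                out.append(0)
        else:
            raise ValueError(f"unknown mode: {mode!r}")
    return out
```

Consider the multi-label mask `[0, 0, 1, 1, 2, 2, 0, 0]` shifted by half a sample (`transform=lambda i: i + 0.5`). Nearest-neighbor gives `[0, 1, 1, 2, 2, 0, 0, 0]`, which contains only existing labels. Linear interpolation gives `[0.0, 0.5, 1.0, 1.5, 2.0, 1.0, 0.0, 0]`, which contains fractional values. At the boundary between label 2 and background, it also produces a spurious label 1 that rounding cannot remove.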

In [3]:
def setup_logging(log_file):
    """Set up logging to a file, creating its directory if needed."""
    log_dir = os.path.dirname(log_file)
    if log_dir:  # dirname is '' when log_file has no directory part
        os.makedirs(log_dir, exist_ok=True)
    
    logging.basicConfig(filename=log_file, level=logging.DEBUG, 
                        format='%(asctime)s - %(levelname)s - %(message)s')
    logger = logging.getLogger()
    return logger

def read_images(ID_path, ID):
    """Read fixed, moving, and mask images from the specified paths."""
    preRT_path = os.path.join(ID_path, "preRT")
    midRT_path = os.path.join(ID_path, "midRT")

    preRT_image_path = os.path.join(preRT_path, f"{ID}_preRT_T2.nii.gz")
    preRT_mask_path = os.path.join(preRT_path, f"{ID}_preRT_mask.nii.gz")
    midRT_image_path = os.path.join(midRT_path, f"{ID}_midRT_T2.nii.gz")

    fixed_image = sitk.ReadImage(midRT_image_path)
    moving_image = sitk.ReadImage(preRT_image_path)
    mask_sitk = sitk.ReadImage(preRT_mask_path)

    return fixed_image, moving_image, mask_sitk

def perform_rigid_registration(fixed_image, moving_image):
    """Perform rigid registration between fixed and moving images."""
    initial_transform = sitk.CenteredTransformInitializer(fixed_image, 
                                                          moving_image, 
                                                          sitk.Euler3DTransform(), 
                                                          sitk.CenteredTransformInitializerFilter.GEOMETRY)

    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetOptimizerAsGradientDescent(learningRate=1.0, numberOfIterations=100, 
                                                      convergenceMinimumValue=1e-6, convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetInitialTransform(initial_transform)
    registration_method.SetInterpolator(sitk.sitkLinear)

    final_transform = registration_method.Execute(fixed_image, moving_image)
    moving_image_resampled = sitk.Resample(moving_image, fixed_image, final_transform, 
                                           sitk.sitkLinear, 0.0, moving_image.GetPixelID())
    return final_transform, moving_image_resampled

def apply_rigid_transformation(mask_sitk, fixed_image, final_transform):
    """Apply rigid transformation to the mask image."""
    transformed_mask_sitk_rigid = sitk.Resample(mask_sitk, fixed_image, final_transform, 
                                                sitk.sitkNearestNeighbor, 0.0, mask_sitk.GetPixelID())
    return transformed_mask_sitk_rigid

def perform_deformable_registration(fixed_image, moving_image_resampled, elastix_file, top_folder):
    """Perform deformable registration using Elastix."""
    elastixImageFilter = sitk.ElastixImageFilter()
    p = sitk.ReadParameterFile(elastix_file)
    elastixImageFilter.SetParameterMap(p)
    elastixImageFilter.SetOutputDirectory(top_folder)
    elastixImageFilter.LogToFileOn()
    elastixImageFilter.SetFixedImage(fixed_image)
    elastixImageFilter.SetMovingImage(moving_image_resampled)
    elastixImageFilter.Execute()

    resultImage = elastixImageFilter.GetResultImage()
    transformParameterMap = elastixImageFilter.GetTransformParameterMap()

    return resultImage, transformParameterMap

def apply_deformable_transformation(mask_sitk_rigid, fixed_image, transformParameterMap):
    """Apply deformable transformation to the mask image."""
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.SetMovingImage(mask_sitk_rigid)
    transformixImageFilter.SetTransformParameterMap(transformParameterMap)
    transformixImageFilter.ComputeDeformationFieldOn()
    transformixImageFilter.Execute()
    
    deformationField = transformixImageFilter.GetDeformationField()
    dvf = sitk.Cast(deformationField, sitk.sitkVectorFloat64)
    displacement_transform = sitk.DisplacementFieldTransform(dvf)
    displacement_field = displacement_transform.GetDisplacementField() # copy of the field for saving; the constructor above took ownership of dvf and left it empty

    transformed_mask_sitk_deform = sitk.Resample(mask_sitk_rigid, fixed_image, displacement_transform, 
                                                 sitk.sitkNearestNeighbor, 0.0, mask_sitk_rigid.GetPixelID())

    return transformed_mask_sitk_deform, displacement_field

def process_ID(ID_path, top_folder, elastix_file, apply_deformable=True, save_transform_files=False):
    """Process a single ID folder for image registration and transformation."""
    ID = os.path.split(ID_path)[-1]
    start_time = time.time()
    try:
        midRT_path = os.path.join(ID_path, "midRT")
        logging.info(f'Starting processing {ID_path}')
        fixed_image, moving_image, mask_sitk = read_images(ID_path, ID)
        
        final_transform, moving_image_resampled = perform_rigid_registration(fixed_image, moving_image)
        transformed_mask_sitk_rigid = apply_rigid_transformation(mask_sitk, fixed_image, final_transform)

        if save_transform_files:  # Save rigid transform if desired
            rigid_transform_file = os.path.join(midRT_path, "rigid_transform.tfm")
            sitk.WriteTransform(final_transform, rigid_transform_file)

        if apply_deformable:
            resultImage, transformParameterMap = perform_deformable_registration(fixed_image, moving_image_resampled, elastix_file, top_folder)
            
            transformed_mask_sitk_deform, displacement_field = apply_deformable_transformation(transformed_mask_sitk_rigid, fixed_image, transformParameterMap)
            
            if save_transform_files:  # Save deformable transform if desired
                deformation_field_file = os.path.join(midRT_path, "deformable_transform.mha")
                sitk.WriteImage(displacement_field, deformation_field_file)
            
            sitk.WriteImage(resultImage, os.path.join(midRT_path, f"{ID}_preRT_T2_registered.nii.gz"))
            sitk.WriteImage(transformed_mask_sitk_deform, os.path.join(midRT_path, f"{ID}_preRT_mask_registered.nii.gz"))
        else:
            sitk.WriteImage(moving_image_resampled, os.path.join(midRT_path, f"{ID}_preRT_T2_registered.nii.gz"))
            sitk.WriteImage(transformed_mask_sitk_rigid, os.path.join(midRT_path, f"{ID}_preRT_mask_registered.nii.gz"))
        
        elapsed_time = time.time() - start_time
        logging.info(f'Finished processing {ID_path} in {elapsed_time:.2f} seconds')
        print(f'Finished processing {ID_path} in {elapsed_time:.2f} seconds')

    except Exception as e:
        elapsed_time = time.time() - start_time
        err_msg = f'Error processing {ID_path} in {elapsed_time:.2f} seconds: {str(e)}'
        logging.error(err_msg)
        logging.error("Traceback:", exc_info=True)
        print(err_msg)

def main():
    """Main function to execute the image registration process."""
    top_folder = "Example_data"
    elastix_file = "Elastix_parameterset_23.txt"
    save_transform_files = False # set to True to also save the intermediate transform files (the deformable one is large)
    apply_deformable = True

    logger = setup_logging(os.path.join(top_folder, 'image_registration_log.txt'))
    
    ID_paths = [os.path.join(top_folder, folder) for folder in os.listdir(top_folder) if os.path.isdir(os.path.join(top_folder, folder)) and folder.isnumeric()]
    
    try:
        for ID_path in ID_paths:
            process_ID(ID_path, top_folder, elastix_file, apply_deformable, save_transform_files)
    finally:
        handlers = logger.handlers[:]
        for handler in handlers:
            handler.close()
            logger.removeHandler(handler)

if __name__ == "__main__":
    main()

Installing all components.
InstallingComponents was successful.

ELASTIX version: 5.000
Command line options from ElastixBase:
-fMask    unspecified, so no fixed mask used
-mMask    unspecified, so no moving mask used
-out      Example_data/
-threads  unspecified, so all available threads are used
Command line options from TransformBase:
-t0       unspecified, so no initial transform used
  The default value "3" is used instead.
  The default value "false" is used instead.

Reading images...
Reading images took 0 ms.

Initialization of all components (before registration) took: 2 ms.
Preparation of the image pyramids took: 1371 ms.

Resolution: 0
  The default value "true" is used instead.
  The default value "true" is used instead.
  The default value "true" is used instead.
  The default value "true" is used instead.
Setting the fixed masks took: 0 ms.
Setting the moving masks took: 0 ms.
  The default value "false" is used instead.
  The default value "1" is used instead.
  The defa