# Register dataset

- All patient images are registered using the series 'Pelvis_t2_spc_rst_tra_p2_iso' as the fixed image. Patients who do not have this series are skipped.

- Two methods used, one for MRI and another for PET. 

- Full-body scans are skipped; only pelvis scans are used.

- Estimated time for registering dataset: 15 min

In [None]:
import os
import pydicom
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interactive, widgets, Layout, HBox, VBox, Label, interact, fixed
from IPython.display import display, clear_output
import itk
import SimpleITK as sitk
from concurrent.futures import ProcessPoolExecutor

### Method 1: Simple Center Transform

Use the CenteredTransformInitializer to align the centers of the two volumes and set the center of rotation to the center of the fixed image. This method is used between MRI images, as they are very similar to each other, although they vary in resolution.

In [None]:
def apply_centered_transform(fixed_image, moving_image):
    """Resample ``moving_image`` onto the grid of ``fixed_image`` after centering.

    No optimization is run. The only transform applied is the
    ``Euler3DTransform`` returned by ``sitk.CenteredTransformInitializer`` in
    ``GEOMETRY`` mode.

    Args:
        fixed_image: SimpleITK image that defines the output grid (size,
            origin, spacing and direction).
        moving_image: SimpleITK image to be resampled.

    Returns:
        The resampled moving image. It has the geometry of ``fixed_image`` and
        the pixel type of ``moving_image``; the default pixel value passed to
        the resampler is 0.0.
    """
    centering = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    # Build an empty image that carries only the geometry of the fixed image;
    # it serves as the target grid for the resampler.
    grid_size = fixed_image.GetSize()
    target_grid = sitk.Image(grid_size[0], grid_size[1], grid_size[2], fixed_image.GetPixelID())
    target_grid.SetOrigin(fixed_image.GetOrigin())
    grid_spacing = fixed_image.GetSpacing()
    target_grid.SetSpacing((grid_spacing[0], grid_spacing[1], grid_spacing[2]))
    target_grid.SetDirection(fixed_image.GetDirection())

    return sitk.Resample(
        moving_image,
        target_grid,
        centering,
        sitk.sitkHammingWindowedSinc,
        0.0,
        moving_image.GetPixelID(),
    )



### Method 2: Example registration from ITK notebook examples

A more sophisticated method, used for the PET pelvis images because they differ substantially from the MRI images. This method does not work correctly with series such as MRI ADC because they have a very small depth compared to the fixed image series (Pelvis_t2_spc_rst_tra_p2_iso).

In [None]:
def register_image(fixed_image, moving_image):
    """Rigidly register ``moving_image`` to ``fixed_image`` and resample it.

    The registration starts from a geometric centering of the two volumes and
    then optimizes an ``Euler3DTransform`` with gradient descent on the Mattes
    mutual information metric, using a three-level multi-resolution scheme.

    Args:
        fixed_image: SimpleITK image used as registration target and output grid.
        moving_image: SimpleITK image to be aligned with ``fixed_image``.

    Returns:
        The moving image resampled with the optimized transform. It has the
        geometry of ``fixed_image`` and the pixel type of ``moving_image``;
        the default pixel value passed to the resampler is 0.0.
    """
    starting_transform = sitk.CenteredTransformInitializer(
        fixed_image,
        moving_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    registration = sitk.ImageRegistrationMethod()

    # Similarity metric: Mattes mutual information evaluated on a random
    # sample of 1% of the voxels.
    registration.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration.SetMetricSamplingStrategy(registration.RANDOM)
    registration.SetMetricSamplingPercentage(0.01)

    registration.SetInterpolator(sitk.sitkHammingWindowedSinc)

    # Optimizer: gradient descent with scales estimated from physical shift.
    registration.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    registration.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid: three levels, coarse to fine.
    registration.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # inPlace=False leaves ``starting_transform`` untouched, so the cell can be
    # re-run without the previous result leaking into the next run.
    registration.SetInitialTransform(starting_transform, inPlace=False)

    fixed_as_float = sitk.Cast(fixed_image, sitk.sitkFloat32)
    moving_as_float = sitk.Cast(moving_image, sitk.sitkFloat32)
    optimized_transform = registration.Execute(fixed_as_float, moving_as_float)

    # Empty image carrying only the geometry of the fixed image; it is the
    # grid the moving image is resampled onto.
    grid_size = fixed_image.GetSize()
    target_grid = sitk.Image(grid_size[0], grid_size[1], grid_size[2], fixed_image.GetPixelID())
    target_grid.SetOrigin(fixed_image.GetOrigin())
    grid_spacing = fixed_image.GetSpacing()
    target_grid.SetSpacing((grid_spacing[0], grid_spacing[1], grid_spacing[2]))
    target_grid.SetDirection(fixed_image.GetDirection())

    return sitk.Resample(
        moving_image,
        target_grid,
        optimized_transform,
        sitk.sitkHammingWindowedSinc,
        0.0,
        moving_image.GetPixelID(),
    )

### Register all images

In [None]:
# Series that may serve as the fixed image, tried in order.
fixed_image_descriptions = [
    'Pelvis_t2_spc_rst_tra_p2_iso',  # patients that do not have this series are skipped
]


def find_fixed_image(patient_images_path):
    """Load the first available fixed-image series of a patient.

    Each entry of ``fixed_image_descriptions`` is tried in order by reading
    ``<patient_images_path>/<description>.mha`` as a float32 image.

    Args:
        patient_images_path: Folder holding the patient's ``.mha`` files.

    Returns:
        ``(image, description)`` for the first series that could be read, or
        ``(None, None)`` when none could. Any error raised while building the
        path or reading the file is treated as "series not available".
    """
    for series_name in fixed_image_descriptions:
        try:
            loaded = sitk.ReadImage(
                os.path.join(patient_images_path, series_name + '.mha'),
                sitk.sitkFloat32,
            )
        except Exception:
            # Best effort: fall through to the next candidate series.
            continue
        else:
            return loaded, series_name
    return None, None

In [None]:
def process_image(fixed_image, moving_image, output_path, method):
    """Register ``moving_image`` to ``fixed_image`` and write the result to disk.

    Args:
        fixed_image: SimpleITK image used as the registration target.
        moving_image: SimpleITK image to be registered.
        output_path: File path the registered image is written to.
        method: ``1`` to use ``apply_centered_transform`` (centering only),
            ``2`` to use ``register_image`` (optimized registration).

    Raises:
        ValueError: If ``method`` is neither 1 nor 2. ``ValueError`` is a
            subclass of ``Exception``, so callers that caught the previous
            generic ``Exception`` keep working.
    """
    if method == 1:
        registered_image = apply_centered_transform(fixed_image, moving_image)
    elif method == 2:
        registered_image = register_image(fixed_image, moving_image)
    else:
        raise ValueError('Method must be 1 or 2')
    sitk.WriteImage(registered_image, output_path)

In [None]:
# Folder whose entries are listed to obtain the patient folder names.
root_dir = '../../../data/ProstateData/BREST patients/'

# Input (.mha images) and output (registered images) folders share one
# storage location. For a run relative to the notebook, use
# './mha_corrected_images' and './mha_registered_images' instead.
_storage_root = '/local_ssd/practical_wise24/prostate_cancer'
mha_images_path = _storage_root + '/mha_corrected_images'
registered_dataset_path = _storage_root + '/mha_registered_images'

# Series descriptions, one per entry. This list is not referenced by the
# other cells visible in this notebook section.
# NOTE(review): presumably the full set of series present in the dataset - confirm.
all_descriptions = [
    # MRI series
    'Pelvis_t2_spc_rst_tra_p2_iso',
    'Pelvis_t2_haste_fs_db_tra_p2_320',
    # PET series
    '*MRAC_PET_mlaa_siemens_4BP TK_AC Images',
    '*Pelvis_MRAC_PET_mlaa_siemens_Becken_1BP_15min_LM_AC Images',
    '*Pelvis_MRAC_PET_siemens_Becken_1BP_15min_LM_AC Images',
    '*MRAC_PET_siemens_4BP TK_AC Images',
    # MRI series
    'Pelvis_ep2d_diff_tra_ADC',
    'Pelvis_ep2d_diff_tra',
    'Pelvis_t1_tse_cor_p2',
]

In [None]:
# Register every patient listed in ``root_dir``.
#
# For each patient the fixed image is loaded once, every other series is
# submitted to a worker process (method 2 for pelvis PET, method 1 for the
# rest), and the fixed image itself is copied to the output folder once all
# of the patient's tasks have finished.
with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
    for patient_folder in os.listdir(root_dir):
        write_file_path = os.path.join(registered_dataset_path, patient_folder)
        # Decide whether to skip BEFORE reading anything: loading the fixed
        # volume for a patient that is already registered is wasted I/O.
        if os.path.exists(write_file_path):
            continue
        # NOTE(review): the output folder is created even when the patient is
        # skipped below for lacking a fixed image (original behaviour kept).
        os.makedirs(write_file_path)

        files_path = os.path.join(mha_images_path, patient_folder)
        fixed_image, description_used = find_fixed_image(files_path)
        if fixed_image is None:
            continue

        tasks = []
        for mha_file_name in os.listdir(files_path):
            # splitext strips only the extension, so a series name that itself
            # contains a dot is not truncated.
            description = os.path.splitext(mha_file_name)[0]
            if 'PET' in description:
                if 'Pelvis' not in description:
                    continue  # full-body PET scans are skipped
                method = 2
            elif description == description_used:
                continue  # the fixed image is not registered to itself
            else:
                method = 1

            moving_image = sitk.ReadImage(os.path.join(files_path, mha_file_name), sitk.sitkFloat32)
            output_path = os.path.join(write_file_path, mha_file_name)
            tasks.append(executor.submit(process_image, fixed_image, moving_image, output_path, method))

        # Wait for all of this patient's tasks; result() re-raises any
        # exception that occurred in the worker.
        for task in tasks:
            task.result()

        sitk.WriteImage(fixed_image, os.path.join(write_file_path, description_used + '.mha'))

### Visualize Results

In [None]:
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    """Show one slice of the fixed and one of the moving volume side by side.

    Args:
        fixed_image_z: Index along the first array axis of ``fixed_npa``.
        moving_image_z: Index along the first array axis of ``moving_npa``.
        fixed_npa: Array of the fixed image, indexed as ``[z, :, :]``.
        moving_npa: Array of the moving image, indexed as ``[z, :, :]``.
    """
    # Create a figure with two subplots and the specified size.
    plt.subplots(1, 2, figsize=(10, 8))

    # (volume, slice index, title) for the left and the right panel. The slice
    # is taken inside the loop so each volume is indexed only when drawn.
    panels = (
        (fixed_npa, fixed_image_z, 'fixed image'),
        (moving_npa, moving_image_z, 'moving image'),
    )
    for position, (volume, slice_index, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(volume[slice_index, :, :], cmap=plt.cm.Greys_r)
        plt.title(caption)
        plt.axis('off')

    plt.show()

In [None]:
# Load one registered patient for visual inspection. The files are read from
# ``registered_dataset_path`` so that this cell always shows the output of the
# registration cell, whichever output location is configured.
example_patient_dir = os.path.join(registered_dataset_path, 'BREST_001')
fixed_image = sitk.ReadImage(os.path.join(example_patient_dir, 'Pelvis_t2_spc_rst_tra_p2_iso.mha'), sitk.sitkFloat32)
moving_image = sitk.ReadImage(os.path.join(example_patient_dir, 'Pelvis_ep2d_diff_tra_ADC.mha'), sitk.sitkFloat32)

In [None]:
# Interactive viewer: one slider per volume selects the slice passed to
# ``display_images``; the arrays are wrapped in ``fixed`` so they are handed
# over unchanged. The trailing ';' suppresses the notebook output of the
# returned object.
interact(
    display_images,
    fixed_image_z=(0, fixed_image.GetSize()[2] - 1),
    moving_image_z=(0, moving_image.GetSize()[2] - 1),
    fixed_npa=fixed(sitk.GetArrayViewFromImage(fixed_image)),
    moving_npa=fixed(sitk.GetArrayViewFromImage(moving_image)),
);