# Test the Demons algorithm on experiment 1

#### Author: 
Bruno De Santi, PhD
#### Affiliation:
Multi-modality Medical Imaging Lab (M3I Lab), University of Twente, Enschede, The Netherlands
#### Date:
20/09/2023
#### Paper/Project Title:
Automated three-dimensional image registration for longitudinal photoacoustic imaging (De Santi et al. 2023, JBO)
#### GitHub:
https://github.com/brunodesanti/muvinn-reg
#### License:
[Specify the license, e.g., MIT, GPL, etc.]

### Import libraries

In [None]:
abs_path = r'C:\Users\DeSantiB\muvinn-reg'
import sys
sys.path.append(abs_path)

from utils import processing as proc
from utils import visualizing as vis
from utils import evaluating as eva

import matplotlib.pyplot as plt
import numpy as np
import os 
import torch
import time
import scipy.ndimage as scnd
import SimpleITK as sitk
import monai

# Notebook-wide switches for figure output.
plot_flag = True # if one wants to plot figures
# Saving happens inside the plotting branches, so save_flag only has an effect when plot_flag is True.
save_flag = True # if one wants to save figures

### Demons functions
#### Adapted from tutorial: https://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/66_Registration_Demons.html

In [None]:
def smooth_and_resample(image, shrink_factors, smoothing_sigmas):
    """
    Smooth an image with a recursive Gaussian and resample it onto a coarser grid.

    The output keeps the origin, direction and pixel type of the input. Its spacing is chosen so that
    the first and last samples along every axis stay at the same physical positions as in the input.

    Args:
        image: The SimpleITK image we want to resample.
        shrink_factors: Number(s) greater than one, such that the new image's size is
                        original_size/shrink_factor (rounded to the nearest integer, never below 1).
                        A scalar is applied to every axis.
        smoothing_sigmas: Sigma(s) for Gaussian smoothing, this is in physical units, not pixels.
                          A scalar is applied to every axis.
    Return:
        Image which is a result of smoothing the input and then resampling it using the given sigma(s)
        and shrink factor(s).
    """
    if np.isscalar(shrink_factors):
        shrink_factors = [shrink_factors]*image.GetDimension()
    if np.isscalar(smoothing_sigmas):
        smoothing_sigmas = [smoothing_sigmas]*image.GetDimension()

    smoothed_image = sitk.SmoothingRecursiveGaussian(image, smoothing_sigmas)

    original_spacing = image.GetSpacing()
    original_size = image.GetSize()

    # Round to the nearest integer, but keep at least one sample per axis: a size of 0 is not a valid
    # image grid and would make the spacing computation below meaningless.
    new_size = [max(1, int(sz/float(sf) + 0.5)) for sf, sz in zip(shrink_factors, original_size)]

    # Preserve the physical extent between the first and last sample. When an axis collapses to a
    # single sample there is no extent to preserve and (new_sz-1) would be a division by zero, so
    # that axis gets a spacing equal to the full original width (original size times spacing).
    new_spacing = [((original_sz-1)*original_spc)/(new_sz-1) if new_sz > 1 else original_sz*original_spc
                   for original_sz, original_spc, new_sz in zip(original_size, original_spacing, new_size)]

    # Identity transform: only the sampling grid changes, not the image content's position.
    return sitk.Resample(smoothed_image, new_size, sitk.Transform(),
                         sitk.sitkLinear, image.GetOrigin(),
                         new_spacing, image.GetDirection(), 0.0,
                         image.GetPixelID())


    
def multiscale_demons(registration_algorithm,
                      fixed_image, moving_image, initial_transform = None, 
                      shrink_factors=None, smoothing_sigmas=None):
    """
    Run a registration algorithm coarse-to-fine over an image pyramid.

    The full-resolution images always form the base of the pyramid, so the original scale must not be
    listed in the inputs as a level of its own.

    Args:
        registration_algorithm: Any object exposing an
                                Execute(fixed_image, moving_image, displacement_field_image) method.
        fixed_image: Image whose spatial domain is the source of the resulting transformation
                     (points are mapped from here to the moving image's domain).
        moving_image: Image whose spatial domain is the target of the resulting transformation.
        initial_transform: Any SimpleITK transform, used to initialize the displacement field.
                           When omitted, a zero-initialized field is used.
        shrink_factors (list of lists or scalars): Shrink factors relative to the original image's size.
                           A scalar entry shrink_factors[i] is applied to all axes; a list entry applies
                           shrink_factors[i][j] to axis j. Per-axis factors are useful for unbalanced
                           sampling such as a 512x512x8 microscopy image, where one would shrink only
                           x and y and leave z untouched: [[8,8,1],[4,4,1],[2,2,1]].
        smoothing_sigmas (list of lists or scalars): Amount of smoothing applied before resampling the
                           image with the matching shrink factor, in physical (image spacing) units.
    Returns: 
        SimpleITK.DisplacementFieldTransform
    """
    # Build both pyramids. Index 0 is the full-resolution image; the last entry is the level built
    # from the first (shrink_factor, smoothing_sigma) pair, which is where the registration starts.
    fixed_pyramid = [fixed_image]
    moving_pyramid = [moving_image]
    if shrink_factors:
        levels = list(zip(shrink_factors, smoothing_sigmas))
        for factor, sigma in levels[::-1]:
            fixed_pyramid.append(smooth_and_resample(fixed_image, factor, sigma))
            moving_pyramid.append(smooth_and_resample(moving_image, factor, sigma))

    start_fixed = fixed_pyramid[-1]
    start_moving = moving_pyramid[-1]

    # The displacement field lives on the grid of the starting level.
    # Currently, the pixel type is required to be sitkVectorFloat64 because of a constraint imposed by the Demons filters.
    if initial_transform:
        field = sitk.TransformToDisplacementField(
            initial_transform,
            sitk.sitkVectorFloat64,
            start_fixed.GetSize(),
            start_fixed.GetOrigin(),
            start_fixed.GetSpacing(),
            start_fixed.GetDirection(),
        )
    else:
        field = sitk.Image(
            start_fixed.GetWidth(),
            start_fixed.GetHeight(),
            start_fixed.GetDepth(),
            sitk.sitkVectorFloat64,
        )
        field.CopyInformation(start_fixed)

    # First pass on the starting level.
    field = registration_algorithm.Execute(start_fixed, start_moving, field)

    # Walk the remaining levels back towards index 0 (full resolution): resample the current
    # field onto each level's grid and use it as the starting point for the next pass.
    for level in range(len(fixed_pyramid) - 2, -1, -1):
        level_fixed = fixed_pyramid[level]
        level_moving = moving_pyramid[level]
        field = sitk.Resample(field, level_fixed)
        field = registration_algorithm.Execute(level_fixed, level_moving, field)

    return sitk.DisplacementFieldTransform(field)

### Apply Demons on experiment 1 
#### Adapted from tutorial: https://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/66_Registration_Demons.html

In [None]:
# Directory holding the acquisitions, one '<acquisition>.npy' file each
data_path = r'path\to\data'

# Directory where the landmark files are stored
landmarks_path = r'C:\Users\DeSantiB\muvinn-reg\notebooks\mevislab\landmarks'

# Load the fixed acquisition: the .npy file is read with allow_pickle and .item(), and the
# result is indexed by key below
fixed_acq = '01_01'
fixed_path = os.path.sep.join((data_path, fixed_acq + '.npy'))
fixed_data = np.load(fixed_path, allow_pickle=True).item()

# Pad the reconstruction, the depth map and the cup mask by 3 on both sides of every axis
fixed_image, depth_map, cup_mask = (
    proc.pad_rec(fixed_data[key], ((3, 3), (3, 3),(3, 3)))
    for key in ("rec", "depth_map", "cup_mask")
)

# Options handed to proc.processing_vis (reused unchanged for every moving image)
frangi_options_pp = {
    'sigmas': (1.5, 3, 5, 9, 12),
    'alpha': 0.5,
    'beta': 0.5,
    'gamma': 1,
    'bw_flag': True,
}

aim_options_pp = {
    'half_size_win': 5,
    'min_sd': 0.1,
    'weights': (0, 1),
}

# Only the third output of the processing is used in this notebook
fixed_image_pp = proc.processing_vis(fixed_image, frangi_options_pp, aim_options_pp, gpu = 'cuda')[2]

# Adaptive thresholding for vascular segmentation of the fixed image
ti = 0.008 # threshold at cup surface
tf = 0.003 # threshold at maximum depth
tau = 100 # decay rate
fixed_mask = proc.segment_vessels(fixed_image_pp, depth_map, ti = ti, tf = tf, tau = tau)

# Plot the MIPs of the processed fixed image and optionally write them to disk
if plot_flag:
    plt.figure()
    fig = vis.plot_mips(fixed_image_pp)
    fig.tight_layout()
    if save_flag:
        svg_name = r'pp_image_{}.svg'.format(fixed_acq)
        png_name = r'pp_image_{}.png'.format(fixed_acq)
        fig.savefig(svg_name, dpi = 600)
        fig.savefig(png_name)
# Register every moving acquisition to the fixed one with multiscale Demons, then evaluate.
# For each moving acquisition this block: loads and processes the volume, runs the registration
# on the raw (unprocessed) volumes, applies the resulting transform to the raw volume, the
# processed volume and the vessel mask, and stores metrics, transformed landmarks and timing.


# Define a simple callback which allows us to monitor registration progress.
# It is defined once, before it is registered, and its parameter does not shadow the builtin `filter`.
def iteration_callback(demons):
    print('\r{0}: {1:.2f}'.format(demons.GetElapsedIterations(), demons.GetMetric()), end='')


# The fixed-side SimpleITK images do not depend on the moving acquisition, so convert them once.
# The processed images are multiplied by 1000 before conversion and divided by 1000 after
# resampling (presumably to work in a more convenient numeric range - TODO confirm).
fixed_image_sitk = sitk.GetImageFromArray(fixed_image)
fixed_image_pp_sitk = sitk.GetImageFromArray(1000*fixed_image_pp)
fixed_mask_sitk = sitk.GetImageFromArray(fixed_mask.astype(np.uint8))

# Cycle through moving images
moving_acqs = ('02_01','03_01','04_01','05_01','06_01','07_01')
for moving_acq in moving_acqs:

    # Load and pad moving image.
    # .item() unwraps the single Python object stored in the file; it is indexed by key below.
    # NOTE: allow_pickle=True unpickles the file content, so only load trusted data files.
    moving_path = os.path.join(data_path, moving_acq + '.npy')
    moving_data = np.load(moving_path, allow_pickle = True).item()
    moving_image = proc.pad_rec(moving_data["rec"], ((3, 3), (3, 3),(3, 3)))

    # Process moving image (same options as the fixed image)
    _, _ , moving_image_pp = proc.processing_vis(moving_image, frangi_options_pp, aim_options_pp, gpu = 'cuda')

    # Adaptive thresholding for vascular segmentation of moving image.
    # NOTE(review): this reuses the depth map loaded with the fixed acquisition - confirm this is intended.
    moving_mask = proc.segment_vessels(moving_image_pp, depth_map, ti = ti, tf = tf, tau = tau)

    # Plot and save moving image MIPs
    if plot_flag:
        plt.figure()
        fig = vis.plot_mips(moving_image_pp)
        fig.tight_layout()
        if save_flag:
            fig.savefig(r'pp_image_{}.svg'.format(moving_acq), dpi=600)
            fig.savefig(r'pp_image_{}.png'.format(moving_acq))

        # Plot and save RGB overlay before co-registration
        plt.figure()
        fig = vis.plot_aligned_mips(moving_image_pp, fixed_image_pp, alpha = 0.5)
        fig.tight_layout()
        if save_flag:
            fig.savefig(r'overlay_pp_{}_{}.svg'.format(fixed_acq, moving_acq), dpi=600)
            fig.savefig(r'overlay_pp_{}_{}.png'.format(fixed_acq, moving_acq))

    # Prepare moving data for Demons, converting to SITK
    moving_image_sitk = sitk.GetImageFromArray(moving_image)
    moving_image_pp_sitk = sitk.GetImageFromArray(1000*moving_image_pp)
    moving_mask_sitk = sitk.GetImageFromArray(moving_mask.astype(np.uint8))

    # Select a Demons filter and configure it.
    demons_filter = sitk.FastSymmetricForcesDemonsRegistrationFilter()
    demons_filter.SetNumberOfIterations(500)

    # Regularization (update field - viscous, total field - elastic).
    demons_filter.SetSmoothDisplacementField(True)
    demons_filter.SetStandardDeviations(1.0)

    # Add our simple callback to the registration filter.
    demons_filter.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback(demons_filter))

    # Run registration on the raw volumes and time it
    start_time = time.time()
    tx = multiscale_demons(registration_algorithm=demons_filter,
                           fixed_image = fixed_image_sitk,
                           moving_image = moving_image_sitk,
                           shrink_factors = [6,4,2,1],
                           smoothing_sigmas = [12,8,4,2])
    execution_time = time.time() - start_time

    # Transform original moving image
    transformed_image_sitk = sitk.Resample(moving_image_sitk,
                                           fixed_image_sitk,
                                           tx,
                                           sitk.sitkLinear,
                                           0.0,
                                           moving_image_sitk.GetPixelID())
    # Transform processed moving image.
    # The output pixel type is taken from the processed image itself, so that the result is
    # never cast to the raw volume's type when the two differ.
    transformed_image_pp_sitk = sitk.Resample(moving_image_pp_sitk,
                                              fixed_image_pp_sitk,
                                              tx,
                                              sitk.sitkLinear,
                                              0.0,
                                              moving_image_pp_sitk.GetPixelID())
    # Transform mask of moving image (nearest neighbour keeps the labels binary)
    transformed_mask_sitk = sitk.Resample(moving_mask_sitk,
                                          fixed_image_sitk,
                                          tx,
                                          sitk.sitkNearestNeighbor,
                                          0.0,
                                          moving_mask_sitk.GetPixelID())

    # Back to numpy (undoing the x1000 scaling of the processed image)
    transformed_image = sitk.GetArrayFromImage(transformed_image_sitk)
    transformed_image_pp = sitk.GetArrayFromImage(transformed_image_pp_sitk)/1000
    transformed_mask = sitk.GetArrayFromImage(transformed_mask_sitk).astype(bool)

    # Extract displacement field
    deformation_field = sitk.GetArrayFromImage(tx.GetDisplacementField())

    # Plot and save RGB overlays after co-registration
    if plot_flag:
        plt.figure()
        fig = vis.plot_aligned_mips(transformed_image_pp, fixed_image_pp)
        fig.tight_layout()
        if save_flag:
            fig.savefig(r'demons_overlay_{}_{}.svg'.format(fixed_acq, moving_acq), dpi = 600)
            fig.savefig(r'demons_overlay_{}_{}.png'.format(fixed_acq, moving_acq))

    # Quantitative evaluation

    # Before co-registration
    images = {'fixed': fixed_image, 'moving': moving_image}
    masks = {'fixed': fixed_mask, 'moving': moving_mask}

    fixed_landmarks, moving_landmarks = eva.load_landmarks(landmarks_path, fixed_acq, moving_acq)
    landmarks = {'reg': fixed_landmarks, 'gt': moving_landmarks}

    metrics_before = eva.similarity(images, masks, landmarks)

    # Save metrics before co-registration
    np.save(r'metrics_before_{}_{}.npy'.format(fixed_acq, moving_acq), metrics_before)

    # After co-registration
    images = {'fixed': fixed_image, 'moving': transformed_image}
    masks = {'fixed': fixed_mask, 'moving': transformed_mask}

    # Transform landmarks: add to each fixed landmark the displacement sampled at its
    # (truncated) voxel position.
    # NOTE(review): the index order [1], [0], [2] and the component order of the sampled vector
    # depend on the landmark coordinate convention, which is not visible here - confirm.
    # NOTE(review): the result is written into a copy of fixed_landmarks, so it inherits that
    # array's dtype; if it is an integer array the displacements are truncated - confirm.
    reg_landmarks = fixed_landmarks.copy()
    for i, fixed_landmark in enumerate(fixed_landmarks):
        displacement = deformation_field[int(fixed_landmark[1]), int(fixed_landmark[0]), int(fixed_landmark[2])]
        reg_landmarks[i, :] = fixed_landmark + displacement

    landmarks = {'reg': reg_landmarks, 'gt': moving_landmarks}

    # Export transformed landmarks as txt: "[(c2 c1 c0)  #1,(c2 c1 c0)  #2,...]",
    # i.e. each landmark's components in reversed order followed by its 1-based index.
    str_landmarks = '[' + ''.join(
        '({} {} {})  #{},'.format(str(landmark[2]), str(landmark[1]), str(landmark[0]), idx_landmark)
        for idx_landmark, landmark in enumerate(reg_landmarks, start=1)
    ) + ']'
    with open('reg_points_{}_{}.txt'.format(fixed_acq, moving_acq), 'w') as f:
        f.write(str_landmarks)

    metrics_after = eva.similarity(images, masks, landmarks)

    # Save metrics after co-registration
    np.save(r'metrics_after_{}_{}.npy'.format(fixed_acq, moving_acq), metrics_after)

    # Save execution time
    np.save(r'execution_time_{}_{}.npy'.format(fixed_acq, moving_acq), execution_time)