# Test Elastix on experiment 1

#### Author: 
Bruno De Santi, PhD
#### Affiliation:
Multi-modality Medical Imaging Lab (M3I Lab), University of Twente, Enschede, The Netherlands
#### Date:
20/09/2023
#### Paper/Project Title:
Automated three-dimensional image registration for longitudinal photoacoustic imaging (De Santi et al. 2023, JBO)
#### GitHub:
https://github.com/brunodesanti/muvinn-reg
#### License:
[Specify the license, e.g., MIT, GPL, etc.]

### Import libraries

In [None]:
# Root of the MUVINN-reg repository, added to sys.path so `utils` can be imported.
# Machine-specific path: adjust it to the location of your local clone.
abs_path = r'C:\Users\DeSantiB\surfdrive\PAMMOTH Study\PAMMOTH image registration\MUVINN-reg'
import sys
# Guard against duplicate sys.path entries when the notebook cell is re-executed
if abs_path not in sys.path:
    sys.path.append(abs_path)

from utils import processing as proc
from utils import visualizing as vis
from utils import evaluating as eva

import matplotlib.pyplot as plt
import numpy as np
import os 
import torch
import time
import scipy.ndimage as scnd
import SimpleITK as sitk
import monai

# Output switches for the experiment below:
#   plot_flag -> draw the MIP / overlay figures
#   save_flag -> also write those figures to disk (only consulted when plot_flag is set)
plot_flag = save_flag = True

### Apply Elastix on experiment 1 
#### Adapted from the SimpleElastix tutorials: https://simpleelastix.readthedocs.io/

In [None]:
# Register acquisitions 02_01..07_01 onto 01_01 with Elastix (rigid + B-spline),
# then save MIP figures, before/after similarity metrics, transformed landmarks
# and execution times to the working directory.

# Data path
data_path = r'C:\Users\DeSantiB\surfdrive\PAMMOTH Study\PAMMOTH image registration\MUVINN-reg\notebooks\data'

# Path to directory where landmarks are stored
landmarks_path = r'C:\Users\DeSantiB\surfdrive\PAMMOTH Study\PAMMOTH image registration\MUVINN-reg\notebooks\mevislab\landmarks'

# Padding (voxels before/after along each of the three array axes) applied to every volume
pad_width = ((3, 3), (3, 3), (3, 3))

# Vesselness (Frangi) and AIM options passed to proc.processing_vis
frangi_options_pp = dict(sigmas=(1.5, 3, 5, 9, 12), alpha=0.5, beta=0.5, gamma=1, bw_flag=True)
aim_options_pp = dict(half_size_win=5, min_sd=0.1, weights=(0, 1))

# Adaptive thresholding parameters for proc.segment_vessels
ti = 0.008 # threshold at cup surface
tf = 0.003 # threshold at maximum depth
tau = 100 # decay rate


def _load_acquisition(acq):
    """Load '<acq>.npy' from data_path; return (stored dict, padded reconstruction).

    NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    """
    path = os.path.join(data_path, acq + '.npy')
    # The file stores a dict inside a 0-d object array; .item() unwraps it
    data = np.load(path, allow_pickle=True).item()
    return data, proc.pad_rec(data["rec"], pad_width)


def _preprocess(image):
    """Return the third output of proc.processing_vis (the processed image used below)."""
    # gpu='cuda' presumably requires a CUDA-capable device -- confirm in utils.processing
    _, _, image_pp = proc.processing_vis(image, frangi_options_pp, aim_options_pp, gpu='cuda')
    return image_pp


def _finish_figure(fig, stem):
    """Apply tight_layout to *fig* and, if save_flag is set, save '<stem>.svg' (600 dpi) and '<stem>.png'."""
    fig.tight_layout()
    if save_flag:
        fig.savefig(stem + '.svg', dpi=600)
        fig.savefig(stem + '.png')


def _transform_landmarks(points, field):
    """Displace each landmark by the deformation-field vector stored at its voxel.

    The voxel of landmark p is found by truncating its coordinates to int and
    indexing field[p[1], p[0], p[2]]; that vector is added to p component-wise.
    The result is float so sub-voxel displacements are kept even when *points*
    is integer-typed.

    Raises:
        ValueError: if a landmark's voxel lies outside the field.
    """
    # NOTE(review): sitk.GetArrayFromImage returns the field indexed [z, y, x] with
    # vector components ordered (x, y, z). Indexing with [p[1], p[0], p[2]] and adding
    # the vector unpermuted assumes the landmark coordinate order matches both
    # conventions (and the 3-voxel padding) -- confirm against eva.load_landmarks.
    points = np.asarray(points)
    lookup = points.astype(int)[:, [1, 0, 2]]  # astype(int) truncates like int()
    if np.any(lookup < 0) or np.any(lookup >= np.asarray(field.shape[:3])):
        raise ValueError('landmark outside the deformation field of shape {}'.format(field.shape[:3]))
    return points.astype(float) + field[tuple(lookup.T)]


# Load, pad and process the fixed image
fixed_acq = '01_01'
fixed_data, fixed_image = _load_acquisition(fixed_acq)

# Pad depth map and cup mask
depth_map = proc.pad_rec(fixed_data["depth_map"], pad_width)
cup_mask = proc.pad_rec(fixed_data["cup_mask"], pad_width)

fixed_image_pp = _preprocess(fixed_image)

# Adaptive thresholding for vascular segmentation of fixed image
fixed_mask = proc.segment_vessels(fixed_image_pp, depth_map, ti = ti, tf = tf, tau = tau)

# Plot and save fixed image MIPs
if plot_flag:
    plt.figure()
    _finish_figure(vis.plot_mips(fixed_image_pp), 'pp_image_{}'.format(fixed_acq))

# Elastix inputs that do not depend on the moving acquisition are prepared once.
# Processed images are scaled by 1000 for Elastix and rescaled after registration.
fixed_image_pp_sitk = sitk.GetImageFromArray(1000 * fixed_image_pp)
# Cup mask: registration only considers regions inside it (used as fixed AND moving mask)
cup_mask_sitk = sitk.GetImageFromArray(cup_mask.astype(np.uint8))
parameter_maps = sitk.VectorOfParameterMap()
parameter_maps.append(sitk.ReadParameterFile(r"parameters/rigid.txt"))    # rigid stage
parameter_maps.append(sitk.ReadParameterFile(r"parameters/bspline.txt"))  # nonrigid stage

# Cycle through moving images
moving_acqs = ('02_01', '03_01', '04_01', '05_01', '06_01', '07_01')
for moving_acq in moving_acqs:

    # Load, pad and process moving image
    _, moving_image = _load_acquisition(moving_acq)
    moving_image_pp = _preprocess(moving_image)

    # Adaptive thresholding for vascular segmentation of moving image.
    # NOTE(review): the fixed acquisition's depth map is reused here (and the fixed cup
    # mask is the moving mask below), presumably because all acquisitions share the
    # cup geometry -- confirm, or use the moving acquisition's own maps.
    moving_mask = proc.segment_vessels(moving_image_pp, depth_map, ti = ti, tf = tf, tau = tau)

    # Plot and save moving image MIPs and the RGB overlay before co-registration
    if plot_flag:
        plt.figure()
        _finish_figure(vis.plot_mips(moving_image_pp), 'pp_image_{}'.format(moving_acq))
        plt.figure()
        _finish_figure(vis.plot_aligned_mips(moving_image_pp, fixed_image_pp, alpha = 0.5),
                       'overlay_pp_{}_{}'.format(fixed_acq, moving_acq))

    # Convert moving data to SimpleITK (numpy axis k becomes image dimension 2 - k)
    moving_image_sitk = sitk.GetImageFromArray(moving_image)
    moving_image_pp_sitk = sitk.GetImageFromArray(1000 * moving_image_pp)
    moving_mask_sitk = sitk.GetImageFromArray(moving_mask.astype(np.uint8))

    # Rigid + B-spline registration of the processed images inside the cup mask
    elastix_filter = sitk.ElastixImageFilter()
    elastix_filter.SetParameterMap(parameter_maps)
    elastix_filter.SetFixedImage(fixed_image_pp_sitk)
    elastix_filter.SetMovingImage(moving_image_pp_sitk)
    elastix_filter.SetFixedMask(cup_mask_sitk)
    elastix_filter.SetMovingMask(cup_mask_sitk)

    # perf_counter is monotonic, so the duration is unaffected by system clock changes
    start_time = time.perf_counter()
    elastix_filter.Execute()
    execution_time = time.perf_counter() - start_time

    # Transformed processed moving image (undo the x1000 scaling)
    transformed_image_pp = sitk.GetArrayFromImage(elastix_filter.GetResultImage()) / 1000

    # Apply the estimated transform to the original (unprocessed) moving image
    transform_parameters = elastix_filter.GetTransformParameterMap()
    transformix_filter = sitk.TransformixImageFilter()
    transformix_filter.SetTransformParameterMap(transform_parameters)
    transformix_filter.SetMovingImage(moving_image_sitk)
    transformix_filter.Execute()
    transformed_image = sitk.GetArrayFromImage(transformix_filter.GetResultImage())

    # Transform the moving mask with nearest-neighbour interpolation and also request
    # the deformation field. The loop modifies the maps held in transform_parameters.
    transform_parameters_mask = []
    for parameter_map in transform_parameters:
        parameter_map["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
        transform_parameters_mask.append(parameter_map)
    transformix_filter.SetTransformParameterMap(transform_parameters_mask)
    transformix_filter.SetMovingImage(moving_mask_sitk)
    transformix_filter.ComputeDeformationFieldOn()
    transformix_filter.Execute()
    transformed_mask = sitk.GetArrayFromImage(transformix_filter.GetResultImage()).astype(bool)

    # Extract deformation field
    deformation_field = sitk.GetArrayFromImage(transformix_filter.GetDeformationField())

    # Plot and save MIPs after co-registration
    if plot_flag:
        plt.figure()
        _finish_figure(vis.plot_aligned_mips(transformed_image_pp, fixed_image_pp),
                       'elastix_overlay_{}_{}'.format(fixed_acq, moving_acq))

    # Quantitative evaluation before co-registration
    images = {'fixed': fixed_image, 'moving': moving_image}
    masks = {'fixed': fixed_mask, 'moving': moving_mask}
    fixed_landmarks, moving_landmarks = eva.load_landmarks(landmarks_path, fixed_acq, moving_acq)
    landmarks = {'reg': fixed_landmarks, 'gt': moving_landmarks}

    metrics_before = eva.similarity(images, masks, landmarks)
    np.save('metrics_before_{}_{}.npy'.format(fixed_acq, moving_acq), metrics_before)

    # Quantitative evaluation after co-registration
    images = {'fixed': fixed_image, 'moving': transformed_image}
    masks = {'fixed': fixed_mask, 'moving': transformed_mask}
    reg_landmarks = _transform_landmarks(fixed_landmarks, deformation_field)
    landmarks = {'reg': reg_landmarks, 'gt': moving_landmarks}

    # Export transformed landmarks as txt: '[(c2 c1 c0)  #1,(c2 c1 c0)  #2,...]'
    entries = [
        '(' + str(lm[2]) + ' ' + str(lm[1]) + ' ' + str(lm[0]) + ')' + '  #{},'.format(k)
        for k, lm in enumerate(reg_landmarks, start=1)
    ]
    with open('reg_points_{}_{}.txt'.format(fixed_acq, moving_acq), 'w') as f:
        f.write('[' + ''.join(entries) + ']')

    metrics_after = eva.similarity(images, masks, landmarks)

    # Save metrics after co-registration and the Elastix execution time
    np.save('metrics_after_{}_{}.npy'.format(fixed_acq, moving_acq), metrics_after)
    np.save('execution_time_{}_{}.npy'.format(fixed_acq, moving_acq), execution_time)