# Registration Sandbox Notebook
This notebook is a sandbox for experimenting with different parametrisations of the registration process and displaying the results.
 

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import SimpleITK as sitk
import glob
# Make the repository root (the parent of the current working directory)
# importable, so that the `src` package can be imported below.
project_root = os.path.abspath(os.pardir)
sys.path.insert(1, project_root)
from src.groupwise import groupwise_registration

In [None]:
# --- Configuration -----------------------------------------------------------
# Run name; results are written to data/output/registration/<version>.
version = 'interactive'

# Per-stage settings handed to groupwise_registration:
#   learn_rate  - optimiser step size
#   min_step    - early-stop threshold on the SGD step
#   max_iter    - maximum number of SGD iterations
#   pyramid_lvl - resolution level(s); a list gives several levels in order
#   cpn         - (non-linear only) mesh size of the B-spline grid
rigid_param = dict(learn_rate=1, min_step=0.01, max_iter=50, pyramid_lvl=1)
affine_param = dict(learn_rate=.05, min_step=0.001, max_iter=50, pyramid_lvl=1)
nonlin_param = dict(cpn=5, learn_rate=1, min_step=0.1, max_iter=10, pyramid_lvl=[4, 2, 1])

# Stage schedule, in list order: 3 rigid, then 2 affine, then 1 non-linear iteration.
iter_types = 3 * ['Rigid'] + 2 * ['Affine'] + ['NonLinear']

save_folder = os.path.join(project_root, f'data/output/registration/{version}')

# False: register from scratch.
# True: start from the average image and transformations saved by a previous run.
resume = False

In [None]:
# Read the images that come with segmentations: every file under
# data/input/segmented/ whose name ends in "img.nii.gz", as float32.
# Sorting makes the image order (and hence the index i used for the saved
# transformation_{i}.tfm files) reproducible between runs.
input_pattern = os.path.join(project_root, 'data/input/segmented/*img.nii.gz')
input_files = sorted(glob.glob(input_pattern))
if not input_files:
    # Fail here with a clear message instead of deep inside the registration.
    raise FileNotFoundError(f'No input images match {input_pattern}')
input_images = [sitk.ReadImage(file_name, sitk.sitkFloat32) for file_name in input_files]
n_images = len(input_images)

In [None]:
if not resume:
    # Perform registration from scratch: no starting template, no initial transforms.
    avg_image = None
    init_transf = None
else:
    # Resume from a previous run stored in save_folder (the same folder results
    # are written to): reuse its final average image as the starting template
    # and its per-image transformations as initial transforms.
    avg_image = [sitk.ReadImage(os.path.join(save_folder, 'average.nii.gz'))]
    init_transf = [
        sitk.ReadTransform(os.path.join(save_folder, f'transformation_{i}.tfm'))
        for i in range(n_images)
    ]

In [None]:
# Run the groupwise registration over all input images following the stage
# schedule in iter_types. avg_image / init_transf are None for a fresh run,
# or the state loaded from a previous run when resuming.
# Outputs: `trans` (transformations, written out one file per entry) and
# `averages` (list of average images; the last one is used as the final template).
(trans, averages) = groupwise_registration(
    input_images, iter_types, rigid_param, affine_param, nonlin_param,
    init_transformations=init_transf,
    average_images=avg_image,
)


In [None]:
# Persist the results so that a later run can resume from them.
# exist_ok=True: the folder already exists when this cell is re-run, or when
# resuming from a previous run, and a plain os.makedirs would raise FileExistsError.
os.makedirs(save_folder, exist_ok=True)

# Save every intermediate average image, indexed by its position in `averages`.
for i, img in enumerate(averages):
    sitk.WriteImage(img, os.path.join(save_folder, f'average_{i}.nii.gz'))

# Save the final average image as the main reference (this file is what a resumed run loads).
sitk.WriteImage(averages[-1], os.path.join(save_folder, 'average.nii.gz'))

# Save one transformation per entry of `trans` (presumably in the same order as input_images).
for i, transform in enumerate(trans):
    sitk.WriteTransform(transform, os.path.join(save_folder, f'transformation_{i}.tfm'))

# Analyse the results

In [None]:
# Show every input image as three orthogonal slices:
# one row per array axis (index taken from `slices`), one column per image.
slices = [99, 53, 50]
ncols = len(input_images)
nrows = len(slices)
f, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*3, nrows*3), squeeze=False)
for col, image in enumerate(input_images):
    volume = sitk.GetArrayViewFromImage(image)
    # Fix one index along array axis 0, 1 and 2 respectively.
    views = (volume[slices[0]], volume[:, slices[1]], volume[:, :, slices[2]])
    for row, view in enumerate(views):
        ax[row, col].imshow(view, cmap='gray')
        ax[row, col].axis('off')
f.suptitle('Input Images, 3 Projections View')
plt.show()

In [None]:
# Show all intermediate averages: one column per entry of `averages`,
# one row per orthogonal slice (indices from `slices`), each with its own colorbar.

slices = [99, 53, 50]
ncols = len(averages)
nrows = len(slices)
# Shared intensity window; None lets matplotlib autoscale each panel.
vmin, vmax = None, None

f, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5*ncols, 5*nrows), squeeze=False)
for col, average in enumerate(averages):
    volume = sitk.GetArrayViewFromImage(average)
    ax[0, col].set_title(f'Iteration: {col}')
    panels = (volume[slices[0]], volume[:, slices[1]], volume[:, :, slices[2]])
    for row, panel in enumerate(panels):
        panel_ax = ax[row, col]
        mappable = panel_ax.imshow(panel, cmap='gray', vmax=vmax, vmin=vmin)
        panel_ax.axis('off')
        # Thin inset just to the right of the image holds its colorbar.
        f.colorbar(mappable, cax=panel_ax.inset_axes([1.04, 0.2, 0.02, 0.6]))
f.suptitle('Average Templates, 3 Projections View')
plt.show()

In [None]:
# Visualise how much the average template changes between consecutive iterations:
# column k shows |averages[k] - averages[k + 1]| as three orthogonal slices.
slices = [99, 53, 50]
ncols = len(averages) - 1  # one column per consecutive pair of averages
nrows = len(slices)
if ncols < 1:
    # plt.subplots cannot build a grid with zero columns.
    print('Fewer than two average images available; nothing to compare.')
else:
    # squeeze=False keeps `ax` 2-D, so ax[row, col] also works with a single column.
    f, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5*ncols, 5*nrows), squeeze=False)
    for col, (avg_prev, avg_next) in enumerate(zip(averages[:-1], averages[1:])):
        image_0 = sitk.GetArrayViewFromImage(avg_prev)
        image_1 = sitk.GetArrayViewFromImage(avg_next)
        # NOTE(review): assumes a float (or signed) pixel type; an unsigned type
        # would wrap around on subtraction before abs() is applied - confirm.
        d = abs(image_0 - image_1)
        ax[0, col].set_title(f'Iteration: {col} vs {col + 1}')

        im = [
            ax[0, col].imshow(d[slices[0]], cmap='gray'),
            ax[1, col].imshow(d[:, slices[1]], cmap='gray'),
            ax[2, col].imshow(d[:, :, slices[2]], cmap='gray'),
        ]

        for row in range(nrows):
            ax[row, col].axis('off')
            # Per-panel colorbar in a thin inset just right of the image.
            cax = ax[row, col].inset_axes([1.04, 0.2, 0.02, 0.6])
            f.colorbar(im[row], cax=cax)
    f.suptitle('Differences Between Templates Across Iterations, 3 Projections View')
    plt.show()