# Test for affine alignment

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os

import matplotlib.pyplot as plt
import SimpleITK as sitk


from wsiitkreg.affine import affine_align_image
from wsiitkreg.util import ordered_images_to_batches, split_by_core_id_and_sort

In [None]:
# Choose and load dataset
dataset = '/home/maximilw/workbench/wsi/exports/frozen_storage/normalized_frozen_storage'

# Keep only the .tif files, group them by core id, then expand every
# file name into a full path inside the dataset directory.
tif_names = [name for name in os.listdir(dataset) if name.endswith('.tif')]
cores = split_by_core_id_and_sort(tif_names)
for core_id in cores:
    cores[core_id] = [os.path.join(dataset, name) for name in cores[core_id]]

# Build the source/target batches for each of the four cores.
core_A = ordered_images_to_batches(cores['A'])
core_B = ordered_images_to_batches(cores['B'])
core_C = ordered_images_to_batches(cores['C'])
core_D = ordered_images_to_batches(cores['D'])

In [None]:
def align_core(core, *, plot=True):
    """Affinely register every batch of a core and optionally plot the results.

    Each batch is a dict holding at least ``'source_image_path'`` and
    ``'target_image_path'``. The batch dicts are updated IN PLACE with the
    registration outputs (``'sitk_transformed_image'``, ``'transform'``,
    ``'fixed_mask'``, ``'moving_mask'``) and with numpy arrays
    ``'source_image'``, ``'target_image'`` and ``'transformed_image'``.
    After the call, ``core`` itself therefore carries the results.

    Args:
        core: Sized sequence of batch dicts, e.g. as produced by
            ``ordered_images_to_batches``.
        plot: If True (the default, matching the previous behaviour), show a
            source / target / affine grid once all batches are aligned.

    Returns:
        List of the (mutated) batch dicts, in the same order as ``core``.
    """
    aligned_core = []
    print('Core size: ', len(core))
    print()
    for i, batch in enumerate(core):
        print('Next batch index: ', i)
        source_path = batch['source_image_path']
        target_path = batch['target_image_path']
        # NOTE(review): the unpacking order (image, transform, fixed_mask,
        # moving_mask) is taken from the original call; confirm against
        # affine_align_image.
        transformed_image, transform, fixed_mask, moving_mask = affine_align_image(source_path, target_path)
        batch['sitk_transformed_image'] = transformed_image
        batch['transform'] = transform
        batch['fixed_mask'] = fixed_mask
        batch['moving_mask'] = moving_mask
        # Keep numpy copies for plotting; inputs are re-read as float32.
        batch['source_image'] = sitk.GetArrayFromImage(sitk.ReadImage(source_path, sitk.sitkFloat32))
        batch['target_image'] = sitk.GetArrayFromImage(sitk.ReadImage(target_path, sitk.sitkFloat32))
        batch['transformed_image'] = sitk.GetArrayFromImage(transformed_image)
        aligned_core.append(batch)

    # plt.subplots cannot build a grid with zero rows, so an empty core is
    # returned without plotting instead of raising.
    if plot and aligned_core:
        plot_alignment(aligned_core)
    return aligned_core

def plot_alignment(core, figsize=(32, 128)):
    """Show source, target and affinely transformed image side by side per batch.

    One figure row is drawn per batch and three columns per row:
    source image, target image and affine result.

    Args:
        core: Sequence of batch dicts that already contain the arrays
            ``'source_image'``, ``'target_image'`` and ``'transformed_image'``
            (as filled in by ``align_core``).
        figsize: Matplotlib figure size in inches. It defaults to the
            previously hard-coded ``(32, 128)``.
    """
    if not core:
        # plt.subplots rejects nrows=0; there is nothing to draw.
        return

    columns = (
        ('source_image', 'Source Image'),
        ('target_image', 'Target Image'),
        ('transformed_image', 'Affine'),
    )
    # squeeze=False keeps axs 2-D even when there is a single batch.
    # Otherwise plt.subplots returns a 1-D array and axs[i, j] raises IndexError.
    _, axs = plt.subplots(len(core), len(columns), figsize=figsize, squeeze=False)

    for row_axes, batch in zip(axs, core):
        for ax, (key, title) in zip(row_axes, columns):
            ax.imshow(batch[key])
            ax.set_title(title)
            ax.axis('off')
    plt.show()

In [None]:
# Register every batch of core A. align_core stores the results in each
# batch dict of core_A and also plots the source/target/affine grid.
align_core(core_A)

In [None]:
# Re-draw core A from the arrays that align_core stored in each batch dict.
# NOTE(review): align_core already calls plot_alignment at its end, so this
# cell repeats the same figure; keep it only to re-plot without re-aligning.
plot_alignment(core_A)