# Prior Generation Sandbox Notebook


In [None]:
import os
import sys
import matplotlib.pyplot as plt
import SimpleITK as sitk
import glob
project_root = os.path.abspath('..')
sys.path.insert(1, project_root)
from src.groupwise import resample_image

In [None]:
# Config
# `version` selects the sub-folder used under both the registration output
# directory (read from) and the priors output directory (written to).
version = 'interactive'

# Registration outputs: the per-image transform files and the average template.
transformations_dir = os.path.join(project_root, f'data/output/registration/{version}')
path_to_template = os.path.join(project_root, f'data/output/registration/{version}/average.nii.gz')

# Source directory holding the `*img.nii.gz` / `*seg.nii.gz` files.
load_dir = os.path.join(project_root, 'data/input/segmented')

# Destination directory for the prior maps and the copied template.
save_dir = os.path.join(project_root, f'data/output/priors/{version}')

In [None]:
# Load the input images and their segmentations.
# Both file lists are sorted; later cells pair image i with segmentation i by
# index, which presumably relies on the two name patterns sorting in the same
# subject order -- verify against the file naming in `load_dir`.
image_paths = sorted(glob.glob(os.path.join(load_dir, '*img.nii.gz')))
input_images = [sitk.ReadImage(path, sitk.sitkFloat32) for path in image_paths]

# Segmentations are read as Float32 as well, matching the input images.
segmentation_paths = sorted(glob.glob(os.path.join(load_dir, '*seg.nii.gz')))
segmentations = [sitk.ReadImage(path, sitk.sitkFloat32) for path in segmentation_paths]
    

In [None]:
# Load the template average acquired as a result of registration
avg_image = sitk.ReadImage(path_to_template)

# Load the corresponding transformations from input image space to template.
# Exactly one transform per loaded segmentation is needed (they are indexed in
# lockstep when the segmentations are resampled), so the count follows the
# data rather than a fixed subject number.
# Files are addressed by index instead of being globbed because a lexicographic
# sort would place `transformation_10` before `transformation_2`.
transformations = [
    sitk.ReadTransform(os.path.join(transformations_dir, f'transformation_{i}.tfm'))
    for i in range(len(segmentations))
]

In [None]:
# Apply transforms to segmentations
# Each segmentation is resampled with the transform of the same index;
# `avg_image` is passed as the first argument of `resample_image`
# (presumably the reference grid -- confirm against src.groupwise).
# Only the segmentations are resampled here: the resampled intensity images
# are not needed to build the priors.
# `transformations[i]` is indexed explicitly so that a missing transform
# raises IndexError instead of silently dropping a segmentation.
transformed_segmentation = [
    resample_image(avg_image, segmentation, transformations[i])
    for i, segmentation in enumerate(segmentations)
]

In [None]:
# Function to parse the priors
def prior(segmentation, values, thr=0.5):
    """Split a label map into per-tissue membership masks.

    For each tissue the mask is the product of two comparisons, i.e. it is
    set where ``label - thr <= segmentation < label + thr`` (lower bound
    inclusive, upper bound exclusive).

    Args:
        segmentation: Object supporting ``<=``, ``<`` against numbers and
            ``*`` between the comparison results (the notebook passes the
            resampled segmentation images).
        values: Mapping with the keys ``'csf'``, ``'gm'`` and ``'wm'`` giving
            the label value of each tissue.
        thr: Half-width of the accepted band around each label value.

    Returns:
        Tuple ``(csf, gm, wm)`` of masks, in that order.

    Raises:
        KeyError: If ``values`` lacks one of the three tissue keys.
    """
    def _within_band(label):
        # Half-open interval [label - thr, label + thr).
        return (label - thr <= segmentation) * (segmentation < label + thr)

    return tuple(_within_band(values[tissue]) for tissue in ('csf', 'gm', 'wm'))

In [None]:
# Aggregate the segmentations into priors maps for each tissue type
segmentation_values = {'csf': 1, 'gm': 2, 'wm': 3}
size = transformed_segmentation[0].GetSize()
num = len(transformed_segmentation)

# Accumulate in floating point. SimpleITK arithmetic keeps the pixel type of
# the image, so dividing a UInt8 count image by `num` would truncate every
# fractional probability to 0 and leave a binary map instead of a prior.
priors = {}
for key in segmentation_values:
    accumulator = sitk.Image(size, sitk.sitkFloat32)
    # Give the accumulator the same image metadata as the resampled segmentations.
    accumulator.CopyInformation(transformed_segmentation[0])
    priors[key] = accumulator

# Count, per voxel, how many segmentations assign it to each tissue.
for current_segmentation in transformed_segmentation:
    csf, gm, wm = prior(current_segmentation, segmentation_values)
    for key, mask in (('csf', csf), ('gm', gm), ('wm', wm)):
        # Binary image operators require matching pixel types, hence the cast.
        priors[key] = priors[key] + sitk.Cast(mask, sitk.sitkFloat32)

# Turn the counts into per-voxel frequencies in [0, 1].
for key in priors:
    priors[key] = priors[key] / num

In [None]:
# Plot the priors
# One row with one panel per tissue map; `sl` is the index taken along the
# second axis of the array view.
# NOTE(review): for a 3-D image the array view is presumably ordered
# (z, y, x), which would make this a fixed-y slice -- confirm.
nrows = 1
ncols = 3
sl = 70
f, ax = plt.subplots(nrows=nrows, ncols=ncols)
for col, (key, prior_map) in enumerate(priors.items()):
    panel = ax[col]
    slice_view = sitk.GetArrayViewFromImage(prior_map)[:, sl]
    panel.imshow(slice_view, cmap='gray')
    panel.set_title(key.upper())
    panel.axis('off')
plt.show()

In [None]:
# Save the prior maps and template image
# exist_ok=True lets the cell be re-run without failing once the output
# directory already exists; existing files of the same name are overwritten.
os.makedirs(save_dir, exist_ok=True)
for key, val in priors.items():
    sitk.WriteImage(val, os.path.join(save_dir, f'{key}.nii.gz'))

# The template is stored next to the priors under a fixed file name.
sitk.WriteImage(avg_image, os.path.join(save_dir, 'template.nii.gz'))