# Segmentation Sandbox Notebook

In [None]:
import numpy as np
from tqdm import tqdm
from IPython.display import clear_output
from time import sleep
import os
import sys
import matplotlib.pyplot as plt
import SimpleITK as sitk
import glob
project_root = os.path.abspath('..')
sys.path.insert(1, project_root)
from src.gmm import GMM

In [None]:
# Configuration
version = 'interactive'  # name of the prior set (sub-folder of segmentation_priors)
load_dir, prior_dir = (
    # directory holding the unsegmented input volumes
    os.path.join(project_root, 'data/input/unsegmented'),
    # directory holding one sub-folder of priors per case for the chosen version
    os.path.join(project_root, f'data/output/segmentation_priors/{version}'),
)

In [None]:
# Load images (every *.nii.gz in load_dir, in sorted filename order, as float32 arrays)
unsegmented_arrays = [
    sitk.GetArrayFromImage(sitk.ReadImage(file_name, sitk.sitkFloat32))
    for file_name in sorted(glob.glob(os.path.join(load_dir, '*.nii.gz')))
]
# Fail early with a clear message instead of an IndexError further down the notebook.
if not unsegmented_arrays:
    raise FileNotFoundError(f'No *.nii.gz images found in {load_dir}')

# Load the priors transformed into the space of unsegmented images (see priors_for_unsegmented.py).
# NOTE(review): prior folder <index> is assumed to belong to the index-th image in sorted
# filename order — confirm this matches how priors_for_unsegmented.py numbers its outputs.
tissue_keys = ('csf', 'gm', 'wm')
priors = {key: [] for key in tissue_keys}
num_segmentations = len(unsegmented_arrays)
for index in range(num_segmentations):
    for key in tissue_keys:
        file_name = os.path.join(prior_dir, str(index), f'{key}.nii.gz')
        priors[key].append(sitk.GetArrayFromImage(sitk.ReadImage(file_name, sitk.sitkFloat32)))

# Background prior: the probability mass not assigned to any of the three tissue priors.
priors['bg'] = [
    1 - (csf + gm + wm)
    for csf, gm, wm in zip(priors['csf'], priors['gm'], priors['wm'])
]

In [None]:
def display3D(images, titles):
    """Show 2D images side by side in one row of grayscale panels.

    Args:
        images: Sequence of 2D arrays, one per panel.
        titles: Sequence of panel titles (shown upper-cased); must have at
            least as many entries as ``images``.
    """
    ncols = len(images)

    # squeeze=False keeps `axes` a 2D array even when ncols == 1; without it
    # plt.subplots returns a bare Axes object and indexing it raises.
    _, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(5 * ncols, 5), squeeze=False)
    for col, axis in enumerate(axes[0]):
        axis.imshow(images[col], cmap='gray')
        axis.axis('off')
        axis.set_title(titles[col].upper())
    plt.show()
    
def display_wrap(image, title, slices):
    """Display three orthogonal 2D cuts of a 3D volume in one row.

    Args:
        image: 3D array to cut.
        title: Prefix used for every panel title.
        slices: Indices of the cuts; ``slices[0]`` is taken along the last
            axis, ``slices[1]`` along the first, ``slices[2]`` along the second.
    """
    # Each cut is flipped and transposed for display orientation.
    # Note: the panels are single-slice cuts even though they are labelled "Projection".
    views = [
        image[::-1, ::-1, slices[0]].T,
        image[slices[1], :, ::-1].T,
        image[::-1, slices[2], ::-1].T,
    ]
    labels = [f'{title}: Projection {number}' for number in (1, 2, 3)]
    display3D(images=views, titles=labels)

In [None]:
# Pick one case and show it together with its tissue and background priors.
index = 0
image = unsegmented_arrays[index]
prior = {key: value[index] for key, value in priors.items()}

# Slice positions used for display; assumes the volume extends past these indices
# along each axis — TODO confirm for other cases.
slices = [99, 53, 50]

display_wrap(image, f'Unsegmented Case {index}', slices)

for key, value in prior.items():
    display_wrap(value, key.upper(), slices)


In [None]:
# Mask the image and tissue priors to eliminate background: keep only voxels where
# the non-background prior mass (1 - bg) exceeds 1e-3.
# Computed once and parenthesised so the precedence of `-` over `>` is explicit.
brain_mask = (1 - prior['bg']) > 1e-3
masked_img = image * brain_mask
masked_prior = {key: prior[key] * brain_mask for key in ['csf', 'gm', 'wm']}

In [None]:
# Stack the masked tissue priors into one array, class axis last (order: csf, gm, wm).
masked_priors = np.stack([masked_prior[tissue] for tissue in ('csf', 'gm', 'wm')], axis=-1)

In [None]:
# Re-display the case after background masking, at the same slice positions.
slices = [99, 53, 50]
display_wrap(masked_img, f'Masked Case {index}', slices)
for key in masked_prior:
    display_wrap(masked_prior[key], key.upper(), slices)

In [None]:
# Shared GMM hyper-parameters; the experiment cells below overwrite 'prior' and 'mrf'.
gmm_parameters = dict(
    n_components=3,
    max_iter=10,
    tol=1e-3,
    prior=None,
    mrf=None,
    verbose=True,
)

In [None]:
# Simple GMM model: no atlas prior, no MRF term.
gmm_parameters.update(prior=False, mrf=False)

# The dict keys match GMM's keyword arguments one-to-one.
model = GMM(**gmm_parameters)

# scores are plotted below as per-iteration NLL and s is displayed as the segmentation;
# p is presumably the per-class posterior — confirm against src.gmm.
scores, p, s = model.fit_predict(masked_img)
print('Segmentation results')
print(f'Mean values per class: {model.means[0]}')
print(f'Variance values per class: {model.variances[0]}')
print(f'Model weights: {model.weights[0]}')

In [None]:
# GMM with priors: the stacked tissue priors are passed to fit_predict; no MRF term.
gmm_parameters.update(prior=True, mrf=False)

model_prior = GMM(**gmm_parameters)
scores_p, p_p, s_p = model_prior.fit_predict(masked_img, masked_priors)

print('GMM with prior info')
print(f'Mean values per class: {model_prior.means[0]}')
print(f'Variance values per class: {model_prior.variances[0]}')
print(f'Model weights: {model_prior.weights[0]}')

In [None]:
# GMM with MRF: no atlas prior; mrf=0.1 (presumably the MRF weight — confirm in src.gmm).
gmm_parameters.update(prior=False, mrf=0.1)

model_mrf = GMM(**gmm_parameters)

scores_mrf, p_mrf, s_mrf = model_mrf.fit_predict(masked_img)
print('GMM with MRF')
print(f'Mean values per class: {model_mrf.means[0]}')
print(f'Variance values per class: {model_mrf.variances[0]}')
print(f'Model weights: {model_mrf.weights[0]}')

In [None]:
# GMM with MRF and prior: stacked tissue priors plus mrf=0.1.
gmm_parameters.update(prior=True, mrf=0.1)

model_mrf_prior = GMM(**gmm_parameters)
scores_mrf_prior, p_mrf_prior, s_mrf_prior = model_mrf_prior.fit_predict(masked_img, masked_priors)
print('GMM with prior and MRF')
print(f'Mean values per class: {model_mrf_prior.means[0]}')
print(f'Variance values per class: {model_mrf_prior.variances[0]}')
print(f'Model weights: {model_mrf_prior.weights[0]}')

### Plot convergence of the models and segmentation results

In [None]:
# Compare the per-iteration NLL returned by fit_predict for the four model variants.
plt.title('NLL convergence')
for curve, label in (
    (scores, 'GMM'),
    (scores_p, 'GMM, prior'),
    (scores_mrf, 'GMM, MRF'),
    (scores_mrf_prior, 'GMM, MRF, prior'),
):
    plt.plot(curve, label=label)
plt.grid()
plt.legend()
plt.xlabel('Iteration')
plt.ylabel('NLL values')
plt.show()

In [None]:
# Show the masked input and every model's segmentation at identical slice positions.
seg_slices = [99, 99, 84]
display_wrap(masked_img, 'Image', seg_slices)
display_wrap(s, 'GMM segmentation', seg_slices)
display_wrap(s_p, 'GMM+prior segmentation', seg_slices)
display_wrap(s_mrf, 'GMM+MRF segmentation', seg_slices)
display_wrap(s_mrf_prior, 'GMM+MRF+prior segmentation', seg_slices)