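This notebook contains the code of the most-similar-atlas (SIM) approach. Every training image is registered (elastix) to each validation image. The training labels are then propagated into the validation space (transformix), and the propagated labels are combined into a probabilistic atlas.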

In [1]:
# Import libraries
from pathlib import Path
import pandas as pd
import SimpleITK as sitk
import matplotlib.pyplot as plt
import pickle
import numpy as np 
import os
from tqdm import tqdm

from metadata import ImageDataset, patient

In [2]:
# Paths are resolved from the kernel's working directory. The repository root is
# assumed to be its parent, i.e. the notebook is run from `notebooks/`.
notebook_path = Path.cwd()  # classmethod: no need to build a throwaway Path() first
repo_path = notebook_path.parent
print(f'The current directory is: {notebook_path}')

The current directory is: d:\VS_Projects\MISA_FINAL_PROJECT\notebooks


In [4]:
# Instantiate the training and validation datasets (the set name selects the split).
# ImageDataset also exposes .len, .IDs, .im_paths() and .labels_paths();
# patient objects offer .show('im') / .show('labels') for quick visual checks.
img_training_data = ImageDataset('Training')
im_val_data = ImageDataset('Validation')

# Pick the first patient of each split as a quick check that the loaders work.
# `train_id` rather than `id` so the built-in id() is not shadowed.
train_id = img_training_data.IDs[0]
id_val = im_val_data.IDs[0]
print(f'selected id: {train_id}')

# A patient is built from its ID and the ImageDataset it belongs to.
pat = patient(train_id, img_training_data)
pat_val = patient(id_val, im_val_data)

# Image and label volumes of the selected patients.
im = pat.im()
labels = pat.labels()

im_val = pat_val.im()
labels_val = pat_val.labels()

selected id: 01


In [5]:
def image_registration(fixedImage, movingImage, parameterMap=None):
    """Register ``movingImage`` onto ``fixedImage`` with elastix.

    Args:
        fixedImage (sitk.Image): fixed (template) image; defines the output space.
        movingImage (sitk.Image): moving image that will be transformed.
        parameterMap (optional): elastix parameter map, or sequence of maps, to
            use for the registration. When None (default), the
            ElastixImageFilter's own default maps are used — presumably
            translation -> affine -> B-spline; confirm for your SimpleITK version.

    Returns:
        tuple: ``(resultImage, transformParameterMap)``.

        * ``resultImage`` is the moving image resampled into the fixed space
          with the interpolator the registration was run with.
        * ``transformParameterMap`` holds one map per registration stage. In
          each map, ``ResampleInterpolator`` is switched to nearest neighbour,
          so the maps can be passed to Transformix to propagate label images
          without interpolating between label values.
    """
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixedImage)
    elastixImageFilter.SetMovingImage(movingImage)
    if parameterMap is not None:
        elastixImageFilter.SetParameterMap(parameterMap)

    elastixImageFilter.Execute()
    resultImage = elastixImageFilter.GetResultImage()

    transformParameterMap = elastixImageFilter.GetTransformParameterMap()
    # Patch every stage rather than indexing [0], [1], [2]. The number of maps
    # equals the number of registration stages, which depends on the parameter
    # map in use. Hard-coded indices would raise IndexError with fewer stages
    # and leave extra stages interpolating the labels.
    for stageMap in transformParameterMap:
        stageMap["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    return resultImage, transformParameterMap

### Registration and label propagation

In [6]:
# Output folder for the propagated label maps. Missing parents are created, an
# existing folder is left as is, and there is no check-then-create race.
(repo_path / "data" / "registered_labels_all").mkdir(parents=True, exist_ok=True)

In [8]:
# Register every training image to every validation image and propagate the
# training labels into each validation space. Every propagated label map is
# pickled to `path_to_save_labels` as trans_labels_<train>_to_<val>.p.
# Pairs that already have a file are skipped, so an interrupted run can be
# resumed; the recorded run needed ~14 min per training image.
OVERWRITE = False  # set True to recompute pairs that already have a saved file

path_to_save_labels = repo_path / "data" / "registered_labels_all"
path_to_save_labels.mkdir(parents=True, exist_ok=True)

# Kept (empty) for the cell below that pickles it. The propagated volumes are
# persisted per pair on disk instead of being accumulated in memory.
list_labels_all = []

# The fixed (validation) images do not depend on the training image, so
# preprocess them once instead of once per training image.
fixed_images = {
    val_id: patient(val_id, im_val_data).im(preprocess=True)
    for val_id in im_val_data.IDs
}

for train_id in tqdm(img_training_data.IDs, desc='training images'):
    out_files = {
        val_id: path_to_save_labels / f'trans_labels_{train_id}_to_{val_id}.p'
        for val_id in fixed_images
    }
    pending = {val_id: out_file for val_id, out_file in out_files.items()
               if OVERWRITE or not out_file.exists()}
    if not pending:
        continue  # everything for this training image is already on disk

    pat = patient(train_id, img_training_data)
    moving_image = pat.im(preprocess=True)
    labels = pat.labels()

    for val_id, out_file in pending.items():
        # tqdm.write keeps the progress bar intact, unlike print
        tqdm.write(f'Registering training image {train_id} to validation image {val_id}...')

        _, transformParameterMap = image_registration(fixed_images[val_id], moving_image)

        # Re-apply the found transform to the label map; image_registration has
        # already switched every stage to nearest-neighbour resampling.
        transformixImageFilter = sitk.TransformixImageFilter()
        transformixImageFilter.SetTransformParameterMap(transformParameterMap)
        transformixImageFilter.SetMovingImage(labels)
        transformixImageFilter.Execute()
        labels_registered = transformixImageFilter.GetResultImage()

        with open(out_file, 'wb') as handle:
            pickle.dump(labels_registered, handle, pickle.HIGHEST_PROTOCOL)

  0%|          | 0/10 [00:00<?, ?it/s]

Registering training image 01 to validation image 11...
Registering training image 01 to validation image 12...
Registering training image 01 to validation image 13...
Registering training image 01 to validation image 14...
Registering training image 01 to validation image 17...


 10%|█         | 1/10 [13:38<2:02:42, 818.04s/it]

Registering training image 03 to validation image 11...
Registering training image 03 to validation image 12...
Registering training image 03 to validation image 13...
Registering training image 03 to validation image 14...
Registering training image 03 to validation image 17...


 20%|██        | 2/10 [27:27<1:49:56, 824.51s/it]

Registering training image 04 to validation image 11...
Registering training image 04 to validation image 12...
Registering training image 04 to validation image 13...
Registering training image 04 to validation image 14...
Registering training image 04 to validation image 17...


 30%|███       | 3/10 [41:27<1:37:01, 831.58s/it]

Registering training image 05 to validation image 11...
Registering training image 05 to validation image 12...
Registering training image 05 to validation image 13...
Registering training image 05 to validation image 14...
Registering training image 05 to validation image 17...


 40%|████      | 4/10 [54:37<1:21:32, 815.35s/it]

Registering training image 06 to validation image 11...
Registering training image 06 to validation image 12...
Registering training image 06 to validation image 13...
Registering training image 06 to validation image 14...
Registering training image 06 to validation image 17...


 50%|█████     | 5/10 [1:08:58<1:09:18, 831.75s/it]

Registering training image 07 to validation image 11...
Registering training image 07 to validation image 12...
Registering training image 07 to validation image 13...
Registering training image 07 to validation image 14...
Registering training image 07 to validation image 17...


 60%|██████    | 6/10 [1:24:57<58:19, 874.95s/it]  

Registering training image 08 to validation image 11...
Registering training image 08 to validation image 12...
Registering training image 08 to validation image 13...
Registering training image 08 to validation image 14...
Registering training image 08 to validation image 17...


 70%|███████   | 7/10 [1:39:27<43:40, 873.34s/it]

Registering training image 09 to validation image 11...
Registering training image 09 to validation image 12...
Registering training image 09 to validation image 13...
Registering training image 09 to validation image 14...
Registering training image 09 to validation image 17...


 80%|████████  | 8/10 [1:53:24<28:43, 861.93s/it]

Registering training image 16 to validation image 11...
Registering training image 16 to validation image 12...
Registering training image 16 to validation image 13...


In [None]:
# Pickle `list_labels_all` into the current working directory.
# NOTE(review): the registration loop above does not append anything to
# `list_labels_all`, so as written this stores an empty list.
with Path('all_trans_labels.p').open('wb') as out_handle:
    pickle.dump(list_labels_all, out_handle, protocol=pickle.HIGHEST_PROTOCOL)

### Build a probabilistic atlas from the propagated labels

In [None]:
# Probabilistic atlas in the space of ONE validation image. For each tissue
# label k, each voxel holds the fraction of propagated training label maps that
# assign k to it. Only maps registered to the same fixed image share a voxel
# grid, so maps propagated to different validation images are not averaged.
#
# The maps are read from the per-pair pickles written by the registration loop,
# not from all_trans_labels.p. That file stores `list_labels_all`, which the loop
# does not append to, so indexing it raised IndexError.
ATLAS_VAL_ID = im_val_data.IDs[0]  # validation image whose space the atlas lives in
N_TISSUES = 3  # tissue labels 1..N_TISSUES are counted; other values are ignored

label_dir = repo_path / "data" / "registered_labels_all"
label_files = [
    label_dir / f'trans_labels_{train_id}_to_{ATLAS_VAL_ID}.p'
    for train_id in img_training_data.IDs
]
label_files = [path for path in label_files if path.exists()]  # tolerate an interrupted run
assert label_files, f'No propagated labels found for validation image {ATLAS_VAL_ID} in {label_dir}'

atlas = None
for label_file in label_files:
    # pickle.load can execute arbitrary code: only load files this notebook wrote.
    with open(label_file, 'rb') as handle:
        label = sitk.GetArrayFromImage(pickle.load(handle))  # convert once per file
    if atlas is None:
        atlas = np.zeros((N_TISSUES,) + label.shape, dtype=np.float32)
    for k in range(1, N_TISSUES + 1):
        # Exact comparison is intended: labels were resampled with nearest neighbour.
        atlas[k - 1] += (label == k)
atlas /= len(label_files)  # counts -> per-voxel probabilities

In [None]:
# Show one slice of each tissue probability map. The slice is taken along the
# first array axis (z in SimpleITK's GetArrayFromImage ordering). A fixed
# [0, 1] grey range keeps the panels comparable; by default imshow would
# rescale each panel to its own min/max.
ATLAS_SLICE = 100  # TODO: consider a volume-dependent choice, e.g. atlas.shape[1] // 2

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(20, 7))
for k, ax in enumerate(axs, start=1):
    shown = ax.imshow(atlas[k - 1][ATLAS_SLICE], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f'Probability of tissue {k}')
    ax.axis('off')
fig.colorbar(shown, ax=axs, shrink=0.6, label='probability')
plt.show()