In [None]:
# !pip install itk-elastix
# !pip install itk
# !pip install itkwidgets

https://github.com/InsightSoftwareConsortium/ITKElastix

In [None]:
import numpy as np
import nibabel as nib
import itk
from itkwidgets import compare, checkerboard, view
import matplotlib.pyplot as plt
%matplotlib inline
import os
from tqdm import tqdm

## Registering training set

In [None]:
# Registration settings: one elastix parameter object per stage, each filled
# from its own parameter file in the working directory.

parameter_object_affine = itk.ParameterObject.New()
parameter_object_elastix = itk.ParameterObject.New()

# 'affine.txt' drives the first stage, 'elastic.txt' the second one.
parameter_object_affine.AddParameterFile('affine.txt')
parameter_object_elastix.AddParameterFile('elastic.txt')

In [None]:
# Reference (fixed) volume that the training volumes are registered to.
# NOTE(review): the averaging cell further down reads '1000.nii.gz' from this
# same folder - confirm which of the two file names actually exists.
reference_path = 'training-set/training-images/1000.nii'
fixed_image = itk.imread(reference_path, itk.F)

In [None]:
# Register every training volume to the reference volume in two stages
# (affine, then deformable) and keep both transforms so that the label maps
# can be warped with them later on.

# os.listdir returns names in arbitrary order. Sorting makes the run
# reproducible and keeps the positional pairing between these transforms and
# the label files (done in the label cell) stable across file systems.
imset = sorted(os.listdir('training-set/training-images'))

# Create the output folder up front so the first write cannot fail on a
# missing directory after a full registration has already been computed.
output_dir = 'training-set/registered1000'
os.makedirs(output_dir, exist_ok=True)

transform_list_affine = []
transform_list_elastix = []

# imset[0] is skipped as the reference volume.
# Assumes the reference file sorts first in the folder - TODO confirm.
for image in tqdm(imset[1:]):
    moving_image = itk.imread('training-set/training-images/' + image, itk.F)

    # Stage 1: affine alignment of the moving volume to the reference.
    result_image_affine, result_transform_parameters_affine = itk.elastix_registration_method(
        fixed_image, moving_image,
        parameter_object=parameter_object_affine,
        log_to_console=False)
    transform_list_affine.append(result_transform_parameters_affine)

    # Stage 2: deformable registration, starting from the affine result image.
    result_image, result_transform_parameters_elastix = itk.elastix_registration_method(
        fixed_image, result_image_affine,
        parameter_object=parameter_object_elastix,
        log_to_console=False)
    transform_list_elastix.append(result_transform_parameters_elastix)

    itk.imwrite(result_image, output_dir + '/' + image)

   

## Creating intensity volume atlas

In [None]:
# File names of the registered training volumes.
registered_images_dir = 'training-set/registered1000'
imset_reg = os.listdir(registered_images_dir)

In [None]:
# Voxel-wise mean intensity over the reference volume and all registered
# volumes, rounded to the nearest integer.
reference_array = itk.array_from_image(
    itk.imread('training-set/training-images/1000.nii.gz', itk.F)
)

# Accumulate the registered volumes on top of the reference volume.
reg = reference_array
for image in tqdm(imset_reg):
    registered = itk.imread('training-set/registered1000/' + image, itk.F)
    reg = reg + itk.array_from_image(registered)

# The reference counts as one volume in addition to the registered ones.
n_volumes = len(imset_reg) + 1
reg = np.round(reg / n_volumes).astype(int)


In [None]:
# Some metadata change is required because of the itk wrappers: every metadata
# entry of the reference image is copied onto the image built from the
# averaged array before it is written out.
# fixed_image must still be the training reference at this point; the test-set
# loops further down rebind that name.
meta_data = dict(fixed_image)
reg1 = itk.image_from_array(reg)
for field, field_value in meta_data.items():
    reg1[field] = field_value

# Cast to float pixels, then save the template.
reg1 = reg1.astype(itk.F)
itk.imwrite(reg1, 'generated_template.nii.gz')

## Applying registration parameters to labels

In [None]:
# Warp each training label map with the affine and deformable transforms that
# were estimated for the training volume at the same list position.

# Create the output folder up front so the first write cannot fail on a
# missing directory.
output_dir = 'training-set/registered1000_labels'
os.makedirs(output_dir, exist_ok=True)

# Transforms are looked up by position, and os.listdir returns names in
# arbitrary order, so sort to get a stable, reproducible pairing.
# Assumes label files sort in the same order as their images and that the
# reference label sorts first - TODO confirm against the data set naming.
maskset = sorted(os.listdir('training-set/training-labels'))
moving_masks = maskset[1:]

if not (len(moving_masks) == len(transform_list_affine) == len(transform_list_elastix)):
    raise ValueError(
        'Number of label files (%d) does not match the number of stored '
        'transforms (affine: %d, deformable: %d)'
        % (len(moving_masks), len(transform_list_affine), len(transform_list_elastix))
    )


def _use_nearest_neighbour(transform_parameters):
    """Set interpolation order 0 on every parameter map of ``transform_parameters``.

    The transforms were estimated on intensity images; label values are
    categorical and must not be blended when resampled, otherwise the
    ``im == j`` tests used to build the probability atlas miss those voxels.
    The parameter object is modified in place.
    """
    for map_index in range(transform_parameters.GetNumberOfParameterMaps()):
        transform_parameters.SetParameter(map_index, 'FinalBSplineInterpolationOrder', '0')


# tqdm wraps the list (not the enumerate object) so the bar knows the total.
for i, image in enumerate(tqdm(moving_masks)):
    _use_nearest_neighbour(transform_list_affine[i])
    _use_nearest_neighbour(transform_list_elastix[i])

    moving_image_transformix = itk.imread('training-set/training-labels/' + image, itk.F)
    result_image_transformix_affine = itk.transformix_filter(
        moving_image_transformix, transform_list_affine[i])
    result_image_transformix = itk.transformix_filter(
        result_image_transformix_affine, transform_list_elastix[i])
    itk.imwrite(result_image_transformix, output_dir + '/' + image)
    

## Creating label probability atlas

In [None]:
# File names of the registered (warped) training label maps.
registered_labels_dir = 'training-set/registered1000_labels'
maskset_reg = os.listdir(registered_labels_dir)

In [None]:
# Per-voxel label frequency: for each label value 0..3, the fraction of label
# maps (reference map plus every registered map) that carry this value.
atlas_im = itk.array_from_image(
    itk.imread('training-set/training-labels/1000_3C.nii.gz', itk.F)
)

s = atlas_im.shape
atlas = np.zeros((4, s[0], s[1], s[2]))

# The reference label map contributes the first vote for every voxel.
for j in range(4):
    atlas[j] += (atlas_im == j)

# Add one vote per registered label map.
for image in tqdm(maskset_reg):
    im = itk.array_from_image(
        itk.imread('training-set/registered1000_labels/' + image, itk.F)
    )
    for j in range(4):
        atlas[j] += (im == j)

# Divide the vote counts by the number of label maps (reference included).
n_label_maps = len(maskset_reg) + 1
atlas = atlas / n_label_maps

itk.imwrite(itk.GetImageFromArray(atlas), 'generated_atlas.nii.gz')

## Registering testing images to MNITemplateAtlas

In [None]:
# File names of the test volumes to be segmented.
test_images_dir = 'test-set/testing-images'
test_folder = os.listdir(test_images_dir)

In [None]:
# Propagate the MNI label atlas onto every test image: register the MNI
# template to the test image (affine, then deformable) and apply both
# transforms to each of the four atlas channels.
moving_image = itk.imread('MNITemplateAtlas/template.nii.gz', itk.F)
moving_image_transformix = itk.imread('MNITemplateAtlas/atlas.nii.gz', itk.F)
# Assumes the first array axis indexes the label channel - TODO confirm.
moving_image_transformix = itk.array_from_image(moving_image_transformix)
meta_data = dict(moving_image)

# Create the output folder up front so the first write cannot fail on a
# missing directory after a full registration has already been computed.
output_dir = 'test-set/mni'
os.makedirs(output_dir, exist_ok=True)

# The per-label atlas images do not depend on the test image, so build them
# once: each channel gets the template's metadata and float pixels.
label_images = []
for label in range(4):
    im = itk.image_from_array(moving_image_transformix[label])
    for key, value in meta_data.items():
        im[key] = value
    label_images.append(im.astype(itk.F))

for image in tqdm(test_folder):

    fixed_image = itk.imread('test-set/testing-images/' + image, itk.F)

    # Stage 1: affine alignment of the template to the test image.
    result_image_affine, result_transform_parameters_affine = itk.elastix_registration_method(
        fixed_image, moving_image,
        parameter_object=parameter_object_affine,
        log_to_console=False)

    # Stage 2: deformable registration, starting from the affine result image.
    result_image, result_transform_parameters_elastix = itk.elastix_registration_method(
        fixed_image, result_image_affine,
        parameter_object=parameter_object_elastix,
        log_to_console=False)

    s = fixed_image.shape
    # atlas_affine holds the affine-only result; only atlas_reg is written out.
    atlas_affine = np.empty((4, s[0], s[1], s[2]))
    atlas_reg = np.empty((4, s[0], s[1], s[2]))

    for label, im in enumerate(label_images):
        result_image_transformix_affine = itk.transformix_filter(
            im, result_transform_parameters_affine)
        atlas_affine[label] = result_image_transformix_affine
        atlas_reg[label] = itk.transformix_filter(
            result_image_transformix_affine, result_transform_parameters_elastix)

    itk.imwrite(itk.GetImageFromArray(atlas_reg), output_dir + '/' + image)

## Registering testing images to generated atlas

In [None]:
# Propagate the generated label atlas onto every test image: register the
# generated template to the test image (affine, then deformable) and apply
# both transforms to each of the four atlas channels.
moving_image = itk.imread('generated_template.nii.gz', itk.F)
moving_image_transformix = itk.imread('generated_atlas.nii.gz', itk.F)
# Assumes the first array axis indexes the label channel - TODO confirm.
moving_image_transformix = itk.array_from_image(moving_image_transformix)
meta_data = dict(moving_image)

# Create the output folder up front so the first write cannot fail on a
# missing directory after a full registration has already been computed.
output_dir = 'test-set/generated'
os.makedirs(output_dir, exist_ok=True)

# The per-label atlas images do not depend on the test image, so build them
# once: each channel gets the template's metadata and float pixels.
label_images = []
for label in range(4):
    im = itk.image_from_array(moving_image_transformix[label])
    for key, value in meta_data.items():
        im[key] = value
    label_images.append(im.astype(itk.F))

for image in tqdm(test_folder):

    fixed_image = itk.imread('test-set/testing-images/' + image, itk.F)

    # Stage 1: affine alignment of the template to the test image.
    result_image_affine, result_transform_parameters_affine = itk.elastix_registration_method(
        fixed_image, moving_image,
        parameter_object=parameter_object_affine,
        log_to_console=False)

    # Stage 2: deformable registration, starting from the affine result image.
    result_image, result_transform_parameters_elastix = itk.elastix_registration_method(
        fixed_image, result_image_affine,
        parameter_object=parameter_object_elastix,
        log_to_console=False)

    s = fixed_image.shape
    # atlas_affine holds the affine-only result; only atlas_reg is written out.
    atlas_affine = np.empty((4, s[0], s[1], s[2]))
    atlas_reg = np.empty((4, s[0], s[1], s[2]))

    for label, im in enumerate(label_images):
        result_image_transformix_affine = itk.transformix_filter(
            im, result_transform_parameters_affine)
        atlas_affine[label] = result_image_transformix_affine
        atlas_reg[label] = itk.transformix_filter(
            result_image_transformix_affine, result_transform_parameters_elastix)

    itk.imwrite(itk.GetImageFromArray(atlas_reg), output_dir + '/' + image)