In [1]:
import os
import numpy as np
from nibabel.testing import data_path
import nibabel as nib
from pathlib import Path

# Directory holding the masked volumes written by process_irm_data.
# Kept as a str because later cells concatenate it with other strings.
path = str(Path().resolve() / "ADNI_PROCESSED")

def apply_mask(img_n_mmni, img_mask):
    """
        Return the voxel data of an n_mmni image with every voxel outside the mask set to 0.
        param:
            img_n_mmni  : image whose data is masked (any object exposing get_fdata(), e.g. a nibabel image)
            img_mask    : mask image; voxels with a non-zero value are kept
        return:
            a new float array with the same shape as the image data; the image's own
            (possibly cached) data array is left untouched
        raises:
            IndexError (from numpy) if the mask shape does not match the image shape
    """
    # get_fdata() can hand back nibabel's cached array; copy it so masking does not
    # silently modify img_n_mmni for later readers.
    masked = np.array(img_n_mmni.get_fdata(), copy=True)
    keep = img_mask.get_fdata().astype(bool)
    masked[~keep] = 0
    return masked

def process_irm_data(src_dir=None, dst_dir=None):
    """
        Mask every n_mmni image of the ADNI1 directory and save the results in ADNI_PROCESSED.

        For each file whose name starts with "n_mmni" in src_dir, the mask
        "mask_<filename>" from the same directory is applied with apply_mask, and the
        masked volume is written to dst_dir under the original filename.
        param:
            src_dir : directory holding the images and their masks (default: <cwd>/ADNI1)
            dst_dir : output directory, created if missing (default: <cwd>/ADNI_PROCESSED)
    """
    base = Path().resolve()
    src = Path(src_dir) if src_dir is not None else base / "ADNI1"
    dst = Path(dst_dir) if dst_dir is not None else base / "ADNI_PROCESSED"
    dst.mkdir(parents=True, exist_ok=True)  # Create a directory for data processed

    for filename in sorted(os.listdir(src)):
        if not filename.startswith("n_mmni"):
            continue
        img_n_mmni = nib.load(str(src / filename))
        img_mask = nib.load(str(src / ("mask_" + filename)))
        n_mmni_mask = apply_mask(img_n_mmni, img_mask)
        # Reuse the source affine and header: an identity affine would drop the
        # voxel-to-world mapping and voxel sizes of the original scan.
        img = nib.Nifti1Image(n_mmni_mask, img_n_mmni.affine, img_n_mmni.header)
        nib.save(img, str(dst / filename))

# Run the masking step: writes one masked volume per n_mmni file of ADNI1 into ADNI_PROCESSED.
process_irm_data()

In [2]:
def load_processed_data(path, dtype=np.float64):
    """
        Load every NIfTI volume (.nii / .nii.gz) found in `path`.
        param:
            path  : directory to scan (e.g. the ADNI_PROCESSED output directory)
            dtype : float dtype passed to get_fdata; np.float32 halves the memory used
        return:
            list of (filename, voxel array) tuples sorted by filename, or None
            (after printing a message) when `path` is not a directory
    """
    if not os.path.isdir(path):
        print("Can't found directory: " + path)
        return None
    list_x = []
    for filename in sorted(os.listdir(path)):
        # Skip anything that is not a NIfTI file (OS metadata, notes, ...);
        # nib.load cannot read those.
        if not filename.endswith((".nii", ".nii.gz")):
            continue
        img_n_mmni = nib.load(os.path.join(path, filename))
        list_x.append((filename, img_n_mmni.get_fdata(dtype=dtype)))
    return list_x

# TODO: not tested yet; the last run crashed (cause not recorded).
# path = str(Path().resolve())
# path_to_data_proc = path + "\\ADNI_PROCESSED"
# X = load_processed_data(path_to_data_proc)

In [47]:
# Axis names accepted by cut_2D_i / patch_3D, in array-axis order.
_AXES = ("x", "y", "z")


def _crop_along_axis(img_n_mmni, axe, idx_start, idx_end):
    """Return img_n_mmni.slicer[...] keeping [idx_start, idx_end) on `axe` and the full range of the other axes."""
    slices = [slice(None)] * len(_AXES)
    slices[_AXES.index(axe)] = slice(idx_start, idx_end)
    return img_n_mmni.slicer[tuple(slices)]


def cut_2D_i(img_n_mmni, axe, idx):
    """
        Function that returns a 2D cut from the "img" at index "idx", along the axe given in parameter.
        param:
            img_n_mmni : 3D image exposing `.shape` and a nibabel-style `.slicer`
            axe        : "x", "y" or "z"
            idx        : index of the cut, 0 <= idx < size of the image along `axe`
        return:
            the sliced image, which keeps a singleton dimension along `axe`
            (e.g. shape (X, Y, 1) for axe="z"); None, after printing a message,
            when `axe` or `idx` is invalid
    """
    # Validate the axis first: looking up the size of an unknown axis would fail.
    if axe not in _AXES:
        print("Choose a valid value for axe: x, y or z")
        return None
    dim = img_n_mmni.shape[_AXES.index(axe)]
    if not 0 <= idx < dim:
        print("Invalid value for index must be between 0 and " , dim)
        return None
    return _crop_along_axis(img_n_mmni, axe, idx, idx + 1)


def patch_3D(img_n_mmni, axe, idx_start, idx_end):
    """
        Function that returns a 3D patch from the "img" along the axe given in parameter,
        from idx_start (included) to idx_end (excluded).
        param:
            img_n_mmni : 3D image exposing `.shape` and a nibabel-style `.slicer`
            axe        : "x", "y" or "z"
            idx_start  : first index kept
            idx_end    : end index (exclusive); must satisfy 0 <= idx_start < idx_end <= size along `axe`
        return:
            the sliced image; None, after printing a message, when `axe` or the indices are invalid
    """
    if axe not in _AXES:
        print("Choose a valid value for axe: x, y or z")
        return None
    dim = img_n_mmni.shape[_AXES.index(axe)]
    # idx_end is exclusive, so idx_end == dim is valid (the patch includes the last slice).
    if not 0 <= idx_start < idx_end <= dim:
        print("Invalid value for index: values must be between 0 and", dim, "and idx_start must be less than idx_end")
        return None
    return _crop_along_axis(img_n_mmni, axe, idx_start, idx_end)

# Quick check of cut_2D_i on one processed subject: cut along z at index 90 and
# save it for visual inspection.
n_mmni_filename = os.path.join(path, "n_mmni_fADNI_002_S_0295_1.5T_t1w.nii.gz")
img_n_mmni = nib.load(n_mmni_filename)
crop_img = cut_2D_i(img_n_mmni, "z", 90)
assert crop_img is not None, "cut_2D_i rejected the requested cut"
nib.save(crop_img, 'test_image.nii')
# Last expression so the notebook actually displays the shape.
crop_img.shape

In [50]:
def load_X_data(path, axe="z", idx=90):
    """
        Load a 2D cut of every NIfTI volume (.nii / .nii.gz) found in `path`.
        param:
            path : directory to scan (e.g. the ADNI_PROCESSED output directory)
            axe  : axis of the cut passed to cut_2D_i ("x", "y" or "z")
            idx  : index of the cut passed to cut_2D_i
        return:
            list of (filename, cut voxel array) tuples sorted by filename, or None
            (after printing a message) when `path` is not a directory
        raises:
            ValueError if cut_2D_i rejects (axe, idx) for one of the files
    """
    if not os.path.isdir(path):
        print("Can't found directory: " + path)
        return None
    list_x = []
    for filename in sorted(os.listdir(path)):
        # Skip anything that is not a NIfTI file; nib.load cannot read those.
        if not filename.endswith((".nii", ".nii.gz")):
            continue
        img_n_mmni = nib.load(os.path.join(path, filename))
        # To build 3D patches instead, swap in patch_3D(img_n_mmni, axe, start, end).
        cropped_img = cut_2D_i(img_n_mmni, axe, idx)
        if cropped_img is None:
            # cut_2D_i returns None (after printing why) for an out-of-range index.
            raise ValueError(
                "Could not cut " + filename + " along " + repr(axe) + " at index " + repr(idx)
            )
        list_x.append((filename, cropped_img.get_fdata()))
    return list_x

# Dataset: one (filename, 2D cut) pair per file of the processed directory.
X_data = load_X_data(path=path)

# U-Net Neural Network

We define two U-Net models:
* one for 2D inputs, used when each volume is cut into 2D slices;
* one for 3D inputs, used when each volume is split into 3D blocks (patches).

In [None]:
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Conv3D, MaxPooling3D, UpSampling3D , Input
from tensorflow.keras.models import Model

def _build_unet(inputs, input_size, depth, ndim, num_classes=2, base_filters=64):
    """
        Shared U-Net builder used by create_Unet_model2D (ndim=2) and create_Unet_model3D (ndim=3).
        param:
            inputs       : existing Keras input tensor, or None to create Input(shape=input_size)
            input_size   : input shape without the batch axis; with Keras' default
                           channels_last format, (H, W, C) in 2D or (D, H, W, C) in 3D.
                           Ignored when `inputs` is given.
            depth        : number of resolution levels (depth - 1 poolings); every spatial
                           size must be divisible by 2 ** (depth - 1) for the skips to line up
            ndim         : 2 or 3
            num_classes  : number of channels of the final 1x1 convolution
            base_filters : filters at the first level, doubled at each level down
        return:
            an uncompiled keras Model producing per-pixel/voxel scores (no final activation)
        raises:
            ValueError for an unsupported ndim or depth < 1
    """
    from tensorflow.keras.layers import Concatenate

    if ndim == 2:
        conv, pool, upsample = Conv2D, MaxPooling2D, UpSampling2D
    elif ndim == 3:
        conv, pool, upsample = Conv3D, MaxPooling3D, UpSampling3D
    else:
        raise ValueError("ndim must be 2 or 3, got " + repr(ndim))
    if depth < 1:
        raise ValueError("depth must be >= 1, got " + repr(depth))

    # Conv3D/MaxPooling3D need 3-tuples, Conv2D/MaxPooling2D need 2-tuples.
    kernel3 = (3,) * ndim
    kernel2 = (2,) * ndim
    kernel1 = (1,) * ndim

    def conv_block(tensor, filters):
        # Two 3x3 convolutions + ReLU; padding="same" keeps the spatial size so the
        # skip connections can be concatenated without cropping.
        tensor = conv(filters, kernel3, padding="same", activation="relu")(tensor)
        return conv(filters, kernel3, padding="same", activation="relu")(tensor)

    if inputs is None:
        inputs = Input(shape=input_size)

    x = inputs
    skips = []
    filters = base_filters
    # Contracting path: conv block, keep its output for the skip connection, downsample.
    for _ in range(depth - 1):
        x = conv_block(x, filters)
        skips.append(x)
        x = pool(pool_size=kernel2)(x)
        filters *= 2

    # Bottleneck: deepest level, no pooling.
    x = conv_block(x, filters)

    # Expanding path: upsample, 2x2 "up-convolution" halving the channels,
    # concatenate with the matching skip, then a conv block.
    for skip in reversed(skips):
        filters //= 2  # integer division: Keras requires an int filter count
        x = upsample(size=kernel2)(x)
        x = conv(filters, kernel2, padding="same")(x)
        x = Concatenate()([skip, x])
        x = conv_block(x, filters)

    outputs = conv(num_classes, kernel1)(x)
    return Model(inputs, outputs)


def create_Unet_model2D(inputs, input_size, depth=5, num_classes=2, base_filters=64):
    """
        Build a 2D U-Net, for inputs made of 2D cuts of the volumes.
        param:
            inputs       : existing Keras input tensor, or None to create one from input_size
            input_size   : input shape without the batch axis, e.g. (H, W, C)
            depth        : number of resolution levels; H and W must be divisible by 2 ** (depth - 1)
            num_classes  : output channels of the final 1x1 convolution (default 2)
            base_filters : filters at the first level (default 64)
        return:
            an uncompiled keras Model
    """
    return _build_unet(inputs, input_size, depth, 2, num_classes, base_filters)


def create_Unet_model3D(inputs, input_size, depth=5, num_classes=2, base_filters=64):
    """
        Build a 3D U-Net, for inputs made of 3D patches of the volumes.
        param:
            inputs       : existing Keras input tensor, or None to create one from input_size
            input_size   : input shape without the batch axis, e.g. (D, H, W, C)
            depth        : number of resolution levels; D, H and W must be divisible by 2 ** (depth - 1)
            num_classes  : output channels of the final 1x1x1 convolution (default 2)
            base_filters : filters at the first level (default 64)
        return:
            an uncompiled keras Model
    """
    return _build_unet(inputs, input_size, depth, 3, num_classes, base_filters)