## Import Libraries

In [None]:
import os
import pickle
import glob
import random
import nibabel as nib
import numpy as np
import tensorflow as tf
import keras
from matplotlib import pyplot as plt
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose
from tensorflow.keras.layers import Activation, BatchNormalization, PReLU, Dropout
from tensorflow.keras.optimizers import Adamax
from tensorflow.keras.metrics import Accuracy
from tensorflow.keras.layers import concatenate
from tensorflow.keras.callbacks import ModelCheckpoint
from scipy.ndimage import binary_dilation
from sklearn.model_selection import KFold

## Functions

### Preprocessing data

In [None]:
def slice_volume(volume, view):
    """
    Split a 4-D volume into its 2-D (plus channel) slices for one view.

    Args:
        volume (np.ndarray): 4-D array whose last axis holds the channels.
        view (str): 'sagittal' slices along axis 0, 'coronal' along axis 1,
            'axial' along axis 2.
    Returns:
        list of np.ndarray: ``volume[i]``, ``volume[:, i]`` or
        ``volume[:, :, i]`` for every index along the view's axis, in order.
    Raises:
        ValueError: If ``view`` is not one of the three supported names.
    """
    view_axis = {'sagittal': 0, 'coronal': 1, 'axial': 2}
    if view not in view_axis:
        raise ValueError(f"Unknown view type: {view}")
    # moveaxis returns a view, so each slice is a view into `volume` (no copy),
    # exactly like the original a[:, :, i, :] style indexing.
    return list(np.moveaxis(volume, view_axis[view], 0))


def prepare_data(images, image_type, view):
    """
    Load NIfTI volumes, crop them and split them into 2-D slices for one view.

    Args:
        images (list of str): Paths to NIfTI files: peaks images extracted from
            DWI data, or binary masks from the tract segmentation.
        image_type (str): 'peaks' keeps the first 9 channels of a 4-D peaks
            image; any other value treats the file as a 3-D mask and appends a
            singleton channel axis.
        view (str): 'coronal', 'axial' or 'sagittal'.
    Returns:
        list of np.ndarray: All slices of all volumes, subject by subject and,
        within a subject, by increasing slice index. When the input fills the
        [:144, 15:159, :144] crop, each slice is (144, 144, channels).
    Raises:
        ValueError: If ``view`` is not recognised (checked before any file is read).
    """
    # Fail fast on a typo instead of after loading the first (large) volume.
    if view not in ('axial', 'coronal', 'sagittal'):
        raise ValueError(f"Unknown view type: {view}")

    images_sliced = []
    for f in images:
        data = nib.load(f).get_fdata()
        if image_type == 'peaks':
            volume = np.nan_to_num(data[:144, 15:159, :144, :9])
        else:
            # Masks are 3-D; add a channel axis so both types slice the same way.
            volume = np.expand_dims(np.nan_to_num(data[:144, 15:159, :144]), -1)
        # Slice count comes from the cropped shape rather than a fixed 144.
        images_sliced.extend(slice_volume(volume, view))
    return images_sliced

def generate_dataset(data, peaks_sliced, mask_sliced, batch_size):
    """
    Wrap sliced arrays in a batched ``tf.data.Dataset``.

    Args:
        data (str): 'train' yields (peaks, mask) pairs; any other value yields
            the peaks slices alone (inference).
        peaks_sliced (np.ndarray): Sliced peaks images.
        mask_sliced (np.ndarray): Sliced binary masks; only used when
            ``data == 'train'``.
        batch_size (int): Number of slices per batch.
    Returns:
        tf.data.Dataset: The batched dataset.
    """
    if data != 'train':
        tensors = peaks_sliced
    else:
        tensors = (peaks_sliced, mask_sliced)
    return tf.data.Dataset.from_tensor_slices(tensors).batch(batch_size)

### Model UNET

In [None]:
def unet_model(input_shape=(144, 144, 9)):
    """
    Build the 2-D U-Net used for tract segmentation.

    The network has these parts:
    - Encoder: four stages with 64, 128, 256 and 512 filters. Each stage is two
      3x3 ReLU convs, each followed by batch norm, then 2x2 max-pooling. Dropout
      0.4 follows the last pooling.
    - Bottleneck: two 1024-filter conv + batch-norm pairs.
    - Decoder: four stages, each a 2x2 stride-2 transposed conv, concatenation
      with the matching encoder output, then two plain 3x3 ReLU convs.
    - Output: a 1x1 sigmoid conv.

    Args:
        input_shape (tuple): (height, width, channels) of one input slice.
    Returns:
        keras Model mapping a slice to a single-channel sigmoid map.
    """
    # Channels-last, so the skip concatenations along axis 3 join feature maps.
    keras.backend.set_image_data_format("channels_last")

    def encoder_conv(tensor, filters):
        # 3x3 ReLU conv with He-normal init, then batch normalisation.
        tensor = Conv2D(filters, (3, 3), activation='relu', padding='same',
                        kernel_initializer='he_normal')(tensor)
        return BatchNormalization()(tensor)

    def decoder_conv(tensor, filters):
        # Plain 3x3 ReLU conv (default init, no batch norm) as in the decoder.
        return Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

    inputs = Input(shape=input_shape)

    # Contracting path: remember each stage's output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = encoder_conv(encoder_conv(x, filters), filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.4)(x)

    # Bottom layer
    x = encoder_conv(encoder_conv(x, 1024), 1024)

    # Expansive path: consume the skips deepest-first.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip], axis=3)
        x = decoder_conv(decoder_conv(x, filters), filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])


# Build once with the default slice shape just to print the architecture.
unet = unet_model(input_shape=(144, 144, 9))
unet.summary()

### Training

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """
    Soft Dice coefficient, computed per sample and averaged over the batch.

    Each sample is the index along axis 0. Each sample contributes equally to
    the mean, whatever the size of its mask.

    Args:
        y_true: Ground-truth masks of shape (batch, ...), e.g.
            (batch, H, W, 1) during training.
        y_pred: Predicted probabilities with the same shape as ``y_true``.
        smooth (float): Added to numerator and denominator. It keeps the score
            defined, and equal to 1, when both masks of a sample are empty.
    Returns:
        Scalar tensor: the mean Dice score over the batch.
    """
    # NIfTI masks are float64 while the network outputs float32; TensorFlow
    # will not multiply mismatched dtypes, so align y_true to y_pred.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    # Reduce over every axis except the batch axis, whatever the input rank
    # (the previous hard-coded [1, 2, 3] only worked for rank-4 tensors).
    axes = tf.range(1, tf.rank(y_pred))
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    union = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    return tf.reduce_mean((2. * intersection + smooth) / (union + smooth))

def dice_coef_loss(y_true, y_pred):
    """Return the negated Dice coefficient (lower is better), usable as a Keras loss."""
    score = dice_coef(y_true, y_pred)
    return -score

In [None]:
def train_unet(peaks_images, mask_images, val_peaks_images, val_masks_images, save_dir, batch_size=56, epochs=130, learning_rate=0.002, shuffle=False, verbose=1):
    """
    Train a U-Net model on the slices of one view.

    Args:
        peaks_images (np.ndarray): Training peaks slices stacked along axis 0
            (as produced by ``prepare_data(..., 'peaks', view)``).
        mask_images (np.ndarray): Training mask slices, paired with
            ``peaks_images`` by position.
        val_peaks_images (np.ndarray): Validation peaks slices.
        val_masks_images (np.ndarray): Validation mask slices.
        save_dir (str): Path where the model with the lowest validation loss
            is saved.
        batch_size (int): Slices per batch, for training and validation.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the Adamax optimizer.
        shuffle (bool): If True, training slices are reshuffled every epoch.
        verbose (int): Verbosity mode. 0 = silent, 1 = progress bar,
            2 = one line per epoch.

    Returns:
        The History object returned by ``Model.fit``.
    """
    # Prepare the training and validation datasets
    train_dataset = generate_dataset('train', peaks_images, mask_images, batch_size=batch_size)
    if shuffle:
        # Keras ignores fit(shuffle=...) for tf.data inputs, so shuffle here.
        # Shuffling individual slices (not whole batches) requires unbatching.
        # The buffer spans the whole training set for a uniform shuffle, which
        # keeps a second copy of the slices in memory.
        train_dataset = (train_dataset.unbatch()
                         .shuffle(buffer_size=len(peaks_images), reshuffle_each_iteration=True)
                         .batch(batch_size))
    val_dataset = generate_dataset('train', val_peaks_images, val_masks_images, batch_size=batch_size)

    # Define the U-Net model
    unet = unet_model()

    # Keep only the checkpoint with the lowest validation loss (previously
    # save_best_only=False overwrote it every epoch, leaving the LAST epoch).
    checkpoint_callback = ModelCheckpoint(filepath=save_dir,
                                          save_best_only=True,
                                          save_weights_only=False,
                                          monitor='val_loss',
                                          mode='min',
                                          verbose=verbose)

    # Compile the model
    unet.compile(optimizer=Adamax(learning_rate=learning_rate),
                 loss='binary_crossentropy',
                 metrics=[dice_coef])

    # Train the model
    history = unet.fit(train_dataset,
                       epochs=epochs,
                       validation_data=val_dataset,
                       callbacks=[checkpoint_callback],
                       verbose=verbose)
    return history

### Prediction on new data

#### Load models, predict and save

In [None]:
def reassemble_volumes(predictions, view, num_subjects):
    """
    Stack per-slice predictions back into one volume per subject.

    ``predictions`` must list the slices in the order ``prepare_data`` produces
    them: subject by subject, and within a subject by increasing slice index
    along the view's axis.

    Args:
        predictions (np.ndarray): Shape (num_subjects * n_slices, A, B, C).
        view (str): 'coronal', 'axial' or 'sagittal'.
        num_subjects (int): Number of subjects; must divide the slice count.
    Returns:
        list of np.ndarray: One 4-D array per subject, with the slice axis put
        back in place (axis 0 sagittal, 1 coronal, 2 axial).
    Raises:
        ValueError: For an unknown view, or if the slices do not split evenly
            (raised by ``np.split``).
    """
    if view == 'coronal':
        pred, axis = np.transpose(predictions, (1, 0, 2, 3)), 1
    elif view == 'axial':
        pred, axis = np.transpose(predictions, (1, 2, 0, 3)), 2
    elif view == 'sagittal':
        pred, axis = np.asarray(predictions), 0
    else:
        raise ValueError(f"Unknown view type: {view}")
    return np.split(pred, num_subjects, axis=axis)


def predict_new_data(peaks_images, path_model, view, num_subjects=None):
    """
    Predict new data using a trained model.

    Args:
        peaks_images (list of str): Paths to the peaks images, one per subject.
        path_model (str): Path to the trained model.
        view (str): View the model was trained on: 'coronal', 'axial' or
            'sagittal'.
        num_subjects (int, optional): Number of subjects. Defaults to
            ``len(peaks_images)``.

    Returns:
        list of np.ndarray: One predicted volume per subject.

    Raises:
        ValueError: If ``view`` is unknown (checked before the model is loaded).
    """
    if view not in ('coronal', 'axial', 'sagittal'):
        raise ValueError(f"Unknown view type: {view}")
    if num_subjects is None:
        num_subjects = len(peaks_images)

    model = keras.models.load_model(path_model, custom_objects={'dice_coef': dice_coef})

    peaks_sliced = prepare_data(peaks_images, 'peaks', view)
    dataset = tf.data.Dataset.from_tensor_slices(peaks_sliced).batch(1)
    predictions = model.predict(dataset)

    return reassemble_volumes(predictions, view, num_subjects)

In [None]:
def resize_data(data):
    """
    Pad cropped prediction volumes back to the original grid.

    ``prepare_data`` crops volumes to ``[:144, 15:159, :144]``. This function
    undoes that crop with zero padding:
    - axis 0: one zero plane appended at the end;
    - axis 1: 15 zero planes before and 15 after;
    - axis 2: one zero plane appended at the end.
    A (144, 144, 144) input therefore becomes (145, 174, 145). Any axes after
    the third are left untouched.

    NOTE(review): an earlier docstring said 145x175x145. If the source grid is
    175 along axis 1, the trailing pad there should be 16; confirm against the
    reference headers.

    Args:
        data (iterable of np.ndarray): Volumes with at least 3 dimensions.
    Returns:
        list of np.ndarray: The padded volumes, dtype preserved.
    """
    data_resized = []
    for volume in data:
        pad_width = [(0, 1), (15, 15), (0, 1)] + [(0, 0)] * (volume.ndim - 3)
        data_resized.append(np.pad(volume, pad_width, mode='constant'))
    return data_resized

In [None]:
def save_predictions(coronal_predictions, axial_predictions, sagittal_predictions, flip=True, mean=False, reference_files=None, save_path='.', format='nii', threshold_factor=0.5, resize=None, ids=None):
    """
    Fuse the three per-view predictions, binarise them and save one mask per subject.

    Args:
        coronal_predictions (list of np.ndarray): Per-subject volumes from the coronal model.
        axial_predictions (list of np.ndarray): Per-subject volumes from the axial model.
        sagittal_predictions (list of np.ndarray): Per-subject volumes from the sagittal model.
        flip (bool, optional): Flip each view along axis 1 of the stacked
            (subject, ...) array, i.e. each volume's first spatial axis.
            Defaults to True.
        mean (bool, optional): Average the three views, which are stacked on
            axis 4, into one volume per subject. Defaults to False.
        reference_files (list of str, optional): One NIfTI file per subject.
            The affine and header are copied from it. Required when
            format='nii'.
        save_path (str, optional): Output directory, created if missing.
            Defaults to '.'.
        format (str, optional): 'nii' or 'npy'. Defaults to 'nii'.
        threshold_factor (float, optional): Voxels above
            ``threshold_factor * max(volume)`` become foreground. The maximum
            is taken per subject. Defaults to 0.5.
        resize (bool, optional): Pad the volumes back to the original grid with
            ``resize_data``. None (the default) resizes exactly when ``mean``
            is True, matching the previous behaviour.
        ids (list of str, optional): Output file stems. Defaults to 'pred_01',
            'pred_02', ..., one per subject.

    Returns:
        list of np.ndarray: uint8 masks (0/1), one per subject.

    Raises:
        ValueError: For an unknown format, too few ids, or missing/too few
            reference files when format='nii'.
    """
    # Validate everything cheap before the heavy array work.
    if format not in ('nii', 'npy'):
        raise ValueError(f"Unknown format: {format!r} (expected 'nii' or 'npy')")
    n_subjects = len(coronal_predictions)
    if ids is None:
        # Same names as the old hard-coded list, but never runs out.
        ids = [f'pred_{i:02d}' for i in range(1, n_subjects + 1)]
    if len(ids) < n_subjects:
        raise ValueError(f'Got {len(ids)} ids for {n_subjects} subjects.')
    if format == 'nii' and (reference_files is None or len(reference_files) < n_subjects):
        raise ValueError('Saving as NIfTI needs one reference file per subject.')
    if resize is None:
        resize = mean

    if flip:
        coronal_predictions = np.flip(coronal_predictions, axis=1)
        axial_predictions = np.flip(axial_predictions, axis=1)
        sagittal_predictions = np.flip(sagittal_predictions, axis=1)

    # Stack the three views on a trailing axis: (subject, ..., views).
    result = np.concatenate((coronal_predictions, axial_predictions, sagittal_predictions), axis=4)

    if mean:
        result = np.mean(result, axis=4)
    if resize:
        result = resize_data(result)

    # Create masks from predictions
    masks_pred = []
    for im in result:
        threshold = np.max(im) * threshold_factor
        pred = binary_dilation(im > threshold, iterations=1)
        # binary_dilation returns bool; store 0/1 uint8 for NIfTI/npy output.
        masks_pred.append(pred.astype(np.uint8))

    os.makedirs(save_path, exist_ok=True)
    if format == 'nii':
        for mask, ref_file, name in zip(masks_pred, reference_files, ids):
            ref_img = nib.load(ref_file)
            nii = nib.Nifti1Image(mask, ref_img.affine, ref_img.header)
            nib.save(nii, os.path.join(save_path, f'{name}.nii'))
    else:
        for mask, name in zip(masks_pred, ids):
            np.save(os.path.join(save_path, f'{name}.npy'), mask)

    return masks_pred

#### Metrics

In [None]:
def dice_score(y_true, y_pred, mean=True):
    """
    Compute the Dice score between ground-truth and predicted NIfTI masks.

    Both masks are dilated once before scoring. Files are paired by position
    in the two lists.

    Args:
        y_true (list of str): Paths to the ground-truth mask files.
        y_pred (list of str): Paths to the predicted mask files.
        mean (bool): If True and a prediction has one more axis than its
            ground truth, e.g. one channel per view, average that trailing
            axis before scoring.
    Returns:
        tuple: (mean, standard deviation) of the per-subject Dice scores.
    Raises:
        ValueError: If the lists are empty or have different lengths.
    """
    if len(y_true) != len(y_pred):
        raise ValueError(f'Got {len(y_true)} ground-truth files but {len(y_pred)} predictions.')
    if not y_true:
        raise ValueError('No files to score.')

    dice = []
    # Load pair by pair so only two volumes are held in memory at a time.
    for true_path, pred_path in zip(y_true, y_pred):
        true = binary_dilation(nib.load(true_path).get_fdata(), iterations=1)
        pred = binary_dilation(nib.load(pred_path).get_fdata(), iterations=1)
        if mean and pred.ndim == true.ndim + 1:
            # NOTE(review): the previous code averaged over axis 2, a spatial
            # axis, leaving a 2-D array that cannot match a 3-D ground truth.
            # Averaging the trailing axis is assumed to be the intent.
            pred = np.mean(pred, axis=-1)
        # dice_coef treats axis 0 as the batch axis, so add one per volume.
        true = tf.cast(true[np.newaxis], dtype=tf.float64)
        pred = tf.cast(pred[np.newaxis], dtype=tf.float64)
        dice.append(float(dice_coef(true, pred)))

    mean_dice = np.mean(dice)
    std_dice = np.std(dice)

    print(f'The mean dice score is: {mean_dice:.4f} and the standard deviation is: {std_dice:.4f}')
    return mean_dice, std_dice

## Training

### Set up

In [None]:
## Train data ------------------------------------------------------------------
# Files are paired (peaks <-> mask) by their position in the sorted listings.
peaks_path = 'path/to/data'  # set the path to the train peaks images
peaks_images = sorted(glob.glob(os.path.join(peaks_path, '*')))
mask_path = 'path/to/data'  # set the path to the train binary mask images
mask_images = sorted(glob.glob(os.path.join(mask_path, '*')))

val_peaks_path = 'path/to/data'  # set the path to the validation peaks images
val_peaks_images = sorted(glob.glob(os.path.join(val_peaks_path, '*')))
val_mask_path = 'path/to/data'  # set the path to the validation binary mask images
val_mask_images = sorted(glob.glob(os.path.join(val_mask_path, '*')))

if len(peaks_images) != len(mask_images) or len(val_peaks_images) != len(val_mask_images):
    raise ValueError('Each peaks image needs exactly one mask image (paired by sorted filename).')

# Save weights and predictions path ---------------------------------------------
save_model_path_coronal = 'save/directory/coronal/model'  # set the path to save your coronal model weights
save_model_path_axial = 'save/directory/axial/model'  # set the path to save your axial model weights
save_model_path_sagittal = 'save/directory/sagittal/model'  # set the path to save your sagittal model weights

### Workflow

In [None]:
# Training coronal view
print('Training coronal view ...')
print('Preprocessing data...')
coronal_peaks = np.asarray(prepare_data(peaks_images, 'peaks', 'coronal'))
coronal_masks = np.asarray(prepare_data(mask_images, 'masks', 'coronal'))
val_coronal_peaks = np.asarray(prepare_data(val_peaks_images, 'peaks', 'coronal'))
val_coronal_masks = np.asarray(prepare_data(val_mask_images, 'masks', 'coronal'))
print('Train...')
coronal_histories = train_unet(coronal_peaks, coronal_masks,
                               val_coronal_peaks, val_coronal_masks,
                               save_model_path_coronal,
                               batch_size=56, epochs=130, learning_rate=0.002,
                               shuffle=False, verbose=1)

In [None]:
# Training axial view
print('Training axial view ...')
print('Preprocessing data...')
axial_peaks = np.asarray(prepare_data(peaks_images, 'peaks', 'axial'))
axial_masks = np.asarray(prepare_data(mask_images, 'masks', 'axial'))
val_axial_peaks = np.asarray(prepare_data(val_peaks_images, 'peaks', 'axial'))
val_axial_masks = np.asarray(prepare_data(val_mask_images, 'masks', 'axial'))
print('Train...')
axial_histories = train_unet(axial_peaks, axial_masks,
                             val_axial_peaks, val_axial_masks,
                             save_model_path_axial,
                             batch_size=56, epochs=130, learning_rate=0.002,
                             shuffle=False, verbose=1)

In [None]:
# Training sagittal view
print('Training sagittal view ...')
print('Preprocessing data...')
sagittal_peaks = np.asarray(prepare_data(peaks_images, 'peaks', 'sagittal'))
sagittal_masks = np.asarray(prepare_data(mask_images, 'masks', 'sagittal'))
val_sagittal_peaks = np.asarray(prepare_data(val_peaks_images, 'peaks', 'sagittal'))
val_sagittal_masks = np.asarray(prepare_data(val_mask_images, 'masks', 'sagittal'))
print('Train...')
sagittal_histories = train_unet(sagittal_peaks, sagittal_masks,
                                val_sagittal_peaks, val_sagittal_masks,
                                save_model_path_sagittal,
                                batch_size=56, epochs=130, learning_rate=0.002,
                                shuffle=False, verbose=1)

## Test and Dice Scores


### Set Up

In [None]:
## Test data ------------------------------------------------------------------
test_peaks_path = 'path/to/data'  # set the path to the test peaks images
test_peaks_images = sorted(glob.glob(os.path.join(test_peaks_path, '*')))

# Reference masks supply the affine/header needed to save predictions as .nii;
# they are paired with the test peaks images by sorted filename.
reference_masks_path = 'path/to/data'  # set the path to reference mask images
reference_masks_images = sorted(glob.glob(os.path.join(reference_masks_path, '*')))

# Location of trained models and of the predictions directory --------------------
model_path_coronal = 'directory/coronal/model'  # set the path to the coronal model
model_path_axial = 'directory/axial/model'  # set the path to the axial model
model_path_sagittal = 'directory/sagittal/model'  # set the path to the sagittal model

save_results_path = 'save/directory/predictions/results'  # set the path to save your results (predictions)

### Test on new data

In [None]:
# Testing
# One predicted volume per test subject.
num_test_subjects = len(test_peaks_images)

print('Predicting coronal view ...')
coronal_predictions = predict_new_data(test_peaks_images, model_path_coronal, 'coronal', num_test_subjects)

print('Predicting axial view ...')
axial_predictions = predict_new_data(test_peaks_images, model_path_axial, 'axial', num_test_subjects)

print('Predicting sagittal view ...')
sagittal_predictions = predict_new_data(test_peaks_images, model_path_sagittal, 'sagittal', num_test_subjects)

print('Concatenating the results and saving ...')
# mean=True averages the three views and pads them back to the original grid.
# The threshold depends on the tracts of interest, so always check the results.
save_predictions(coronal_predictions, axial_predictions, sagittal_predictions,
                 flip=False, mean=True,
                 reference_files=reference_masks_images, save_path=save_results_path,
                 format='nii', threshold_factor=0.5)

### Dice scores from the predictions

In [None]:
# Dice scores: ground truth and predictions are paired by sorted filename.
y_true_path = 'path/to/ground/truth/'
y_pred_path = 'path/to/your/results/'
y_true = sorted(glob.glob(os.path.join(y_true_path, '*')))
y_pred = sorted(glob.glob(os.path.join(y_pred_path, '*')))

dice = dice_score(y_true, y_pred, mean=False)

The mean dice score is: 0.7238 and the standard deviation is: 0.0541


In [None]:
# Visualize the results: ground truth vs. prediction for one subject, on the
# same slice along axis 0 (the axis prepare_data slices for the 'sagittal' view).
subject_index = 1   # position in the sorted y_true / y_pred lists
slice_index = 75

ground_truth = nib.load(y_true[subject_index]).get_fdata()
ground_truth_slice = ground_truth[slice_index, :, :]

pred = nib.load(y_pred[subject_index]).get_fdata()
pred_image_slice = pred[slice_index, :, :]

fig = plt.figure()
ax1 = fig.add_subplot(1, 2, 1)
ax1.imshow(ground_truth_slice)
ax1.set_title('Ground truth')
ax2 = fig.add_subplot(1, 2, 2)
ax2.imshow(pred_image_slice)
ax2.set_title('Prediction')
plt.show()