# Retrain DAGMNet

## Load Model

In [1]:
import scipy
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import tensorflow
import random

In [2]:
# Path to the pretrained 2-channel DAGMNet weights. A raw string is used so the
# Windows backslashes are not parsed as (invalid) escape sequences such as "\A" or "\d".
DAGMNet_ch2_name = os.path.join(r"D:\ADS_Algorithm\ADSv1.3\data\Trained_Nets", 'DAGMNet_CH2.h5')

In [3]:
# Load the pretrained 2-channel DAGMNet without compiling it (compile=False);
# a truncated copy of it is compiled later for retraining.
DAGMNet_ch2 = tensorflow.keras.models.load_model(DAGMNet_ch2_name, compile=False)

In [4]:
# Inspect the expected input tensor: (batch, 96, 112, 48, 2) per the output below.
DAGMNet_ch2.input

<KerasTensor: shape=(None, 96, 112, 48, 2) dtype=float32 (created by layer 'input_1')>

In [5]:
# Load one training subject's SWI and 3D-TOF volumes, then read their voxel data as float arrays.
swi_sample_img, tof_sample_img = (
    nib.load(sample_path)
    for sample_path in (
        "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\SWI_Train\\2018-104_01-10113-D0MR_9_Ax_T2_GRE__SkullStripped_Training_Normalized.nii.gz",
        "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\TOF3D_Train\\2018-104_01-10113-D0MR_6_3D_TOF_LARGE__SkullStripped_Training_Normalized.nii.gz",
    )
)
swi_sample_data, tof_sample_data = (
    sample_img.get_fdata() for sample_img in (swi_sample_img, tof_sample_img)
)

In [6]:
# Insert a channel axis at position 3 so each volume becomes (X, Y, Z, 1).
swi_sample_data = np.expand_dims(swi_sample_data, axis=3)
tof_sample_data = np.expand_dims(tof_sample_data, axis=3)

In [7]:
# Stack SWI (channel 0) and TOF (channel 1), then prepend a batch axis of size 1.
dagmnet_input = np.concatenate((swi_sample_data, tof_sample_data), axis=3)[np.newaxis]

In [None]:
# Run the original (untruncated) pretrained model on the single sample volume.
dagmnet_output = DAGMNet_ch2.predict(dagmnet_input)

In [49]:
# Print the layer-by-layer architecture (output truncated below).
DAGMNet_ch2.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 96, 112, 48  0           []                               
                                , 2)]                                                             
                                                                                                  
 conv3d (Conv3D)                (None, 96, 112, 48,  1760        ['input_1[0][0]']                
                                 32)                                                              
                                                                                                  
 tf_op_layer_strided_slice (Ten  (None, 48, 56, 24,   0          ['input_1[0][0]']                
 sorFlowOpLayer)                2)                                                            

## Read In Training Images

In [4]:
# Directories holding the training SWI and 3D-TOF volumes.
swi_dir = "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\SWI_Train"
tof_dir = "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\TOF3D_Train"

# os.listdir returns entries in arbitrary order, so both listings are sorted to pair
# the i-th SWI file with the i-th TOF file. NOTE(review): this assumes SWI and TOF
# names share the same subject prefix and there is one file per subject -- confirm.
swi_files = sorted(file for file in os.listdir(swi_dir) if file.endswith('.nii.gz'))
tof_files = sorted(file for file in os.listdir(tof_dir) if file.endswith('.nii.gz'))

if len(swi_files) != len(tof_files):
    # zip() would silently drop the surplus files and leave all-zero rows in X_train.
    raise ValueError(
        f"SWI/TOF file count mismatch: {len(swi_files)} SWI vs {len(tof_files)} TOF files"
    )

# One sample per subject: (96, 112, 48) voxels x 2 channels (0 = SWI, 1 = TOF),
# matching DAGMNet_ch2.input.
X_train = np.zeros((len(swi_files), 96, 112, 48, 2))

for index, (swi_file, tof_file) in enumerate(zip(swi_files, tof_files)):
    swi_img = nib.load(os.path.join(swi_dir, swi_file))
    tof_img = nib.load(os.path.join(tof_dir, tof_file))
    swi_data = swi_img.get_fdata()
    # Bug fix: the TOF channel used to be read from swi_img, duplicating SWI in both channels.
    tof_data = tof_img.get_fdata()
    X_train[index] = np.stack((swi_data, tof_data), axis=-1)

In [5]:
# Directory holding the training lesion masks (labels).
mask_dir = "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\MASK_Train"

# Sorted because os.listdir order is arbitrary: the i-th mask must line up with the
# i-th (sorted) SWI/TOF sample in X_train. NOTE(review): assumes mask names share
# the images' subject prefix -- confirm.
mask_files = sorted(file for file in os.listdir(mask_dir) if file.endswith('.nii.gz'))

# One (96, 112, 48) label volume per subject.
y_train = np.zeros((len(mask_files), 96, 112, 48))

for index, mask_file in enumerate(mask_files):
    mask_img = nib.load(os.path.join(mask_dir, mask_file))
    y_train[index] = mask_img.get_fdata()

## Set Aside a Test Set of Images

In [8]:
# Train/test folders for each modality (mask, SWI, 3D-TOF), all under one training root.
training_root = "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training"
mask_train_dir = training_root + "\\MASK_Train"
mask_test_dir = training_root + "\\MASK_Test"
swi_train_dir = training_root + "\\SWI_Train"
swi_test_dir = training_root + "\\SWI_Test"
tof_train_dir = training_root + "\\TOF3D_Train"
tof_test_dir = training_root + "\\TOF3D_Test"

In [9]:
# Seed Python's `random` module so random.sample draws the same test indexes on every run
# (this does not seed np.random).
random.seed(777)

In [10]:
# Pick 100 distinct positions (without replacement) in the training-mask listing to hold out.
n_train_cases = len(os.listdir(mask_train_dir))
test_indexes = random.sample(range(n_train_cases), 100)

In [11]:
# Move the held-out masks into the test folder.
# The directory is listed ONCE (not once per index) and sorted, because os.listdir order
# is arbitrary and the same test_indexes must pick the same subjects in every modality folder.
# NOTE: this moves files on disk; re-running the cell moves another batch.
mask_train_listing = sorted(os.listdir(mask_train_dir))
file_list = [mask_train_listing[index] for index in test_indexes]
os.makedirs(mask_test_dir, exist_ok=True)  # os.rename fails if the target folder is missing
for file in file_list:
    os.rename(os.path.join(mask_train_dir, file), os.path.join(mask_test_dir, file))

In [12]:
# Move the held-out SWI volumes into the test folder.
# The directory is listed ONCE (not once per index) and sorted, because os.listdir order
# is arbitrary and the same test_indexes must pick the same subjects in every modality folder.
# NOTE: this moves files on disk; re-running the cell moves another batch.
swi_train_listing = sorted(os.listdir(swi_train_dir))
file_list = [swi_train_listing[index] for index in test_indexes]
os.makedirs(swi_test_dir, exist_ok=True)  # os.rename fails if the target folder is missing
for file in file_list:
    os.rename(os.path.join(swi_train_dir, file), os.path.join(swi_test_dir, file))

In [13]:
# Move the held-out 3D-TOF volumes into the test folder.
# The directory is listed ONCE (not once per index) and sorted, because os.listdir order
# is arbitrary and the same test_indexes must pick the same subjects in every modality folder.
# NOTE: this moves files on disk; re-running the cell moves another batch.
tof_train_listing = sorted(os.listdir(tof_train_dir))
file_list = [tof_train_listing[index] for index in test_indexes]
os.makedirs(tof_test_dir, exist_ok=True)  # os.rename fails if the target folder is missing
for file in file_list:
    os.rename(os.path.join(tof_train_dir, file), os.path.join(tof_test_dir, file))

## Define Custom Loss Functions

In [12]:
def l1_loss(y_pred):
    """Return the L1 norm of ``y_pred``: the sum of its absolute values, as a scalar tensor."""
    absolute_values = tensorflow.math.abs(y_pred)
    return tensorflow.math.reduce_sum(absolute_values)

In [13]:
def dice_loss(y_true, y_pred, smoothing_factor=1):
    """Return the smoothed soft Dice loss ``1 - Dice(y_true, y_pred)``.

    Every sum runs over all elements of the inputs (the whole batch at once).
    ``smoothing_factor`` is added to both numerator and denominator, so two
    empty masks give a loss of 0 instead of 0/0.
    """
    overlap = tensorflow.reduce_sum(y_true * y_pred)
    numerator = 2 * overlap + smoothing_factor
    denominator = tensorflow.reduce_sum(y_true) + tensorflow.reduce_sum(y_pred) + smoothing_factor
    return 1 - numerator / denominator

In [14]:
def combined_loss(y_true, y_pred):
    """Return a weighted sum of three terms.

    The terms are binary cross-entropy with "sum" reduction (weight 1.0), the
    Dice loss (weight 1.0), and an L1 penalty on the predictions (weight 1e-5).
    """
    bce_weight, dice_weight, l1_weight = 1., 1., 0.00001
    bce = tensorflow.keras.losses.BinaryCrossentropy(reduction="sum")

    bce_term = bce_weight * bce(y_true, y_pred)
    dice_term = dice_weight * dice_loss(y_true, y_pred)
    l1_term = l1_weight * l1_loss(y_pred)
    return bce_term + dice_term + l1_term

In [15]:
# Dummy near-empty "prediction" (about 2 positive voxels per million) used to sanity-check dice_loss.
# Shaped like y_train rather than a hard-coded (404, 96, 112, 48), so it still matches
# after the training set changes size (e.g. once the test split is moved out).
test_pred = np.random.choice(np.array([0., 1.]), y_train.shape, p=[0.999998, 0.000002])

In [16]:
# Sanity check: an almost-empty random mask should give a Dice loss close to 1 against the real labels.
dice_loss(y_train, test_pred)

<tf.Tensor: shape=(), dtype=float64, numpy=0.9987063389391979>

## Restructure Network

In [30]:
# Use the output of the 5th-from-last layer as the new model output, dropping the last 4 layers.
# NOTE(review): which layers are dropped is not visible here -- confirm against DAGMNet_ch2.summary().
last_layer_to_keep = DAGMNet_ch2.layers[-5].output

In [31]:
# Functional model from the original input to the kept layer; it reuses the original
# model's layer objects (and therefore their pretrained weights).
New_DAGMNet_ch2 = tensorflow.keras.Model(inputs=DAGMNet_ch2.input, outputs=last_layer_to_keep)

In [32]:
# Adam (lr = 3e-4) with the combined BCE + Dice + L1 objective; accuracy is reported as a metric.
optimizer = tensorflow.keras.optimizers.Adam(learning_rate=0.0003)
New_DAGMNet_ch2.compile(
    optimizer=optimizer,
    loss=combined_loss,
    metrics=["accuracy"],
)

In [33]:
# Retrain for 20 epochs with batch size 4. Keras holds out the LAST 20% of samples
# (selected before shuffling) as validation data.
New_DAGMNet_ch2.fit(X_train, y_train, batch_size=4, epochs=20, validation_split=0.2)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x25f2d95f760>

In [34]:
# Persist the retrained, truncated model (".h5" suffix -> HDF5 file).
New_DAGMNet_ch2.save("D:\\ADS_Inference\\Retrained_ADS\\Retrained_DAGMNet_ch2_404.h5")

## Test Inference of Retrained Model

In [59]:
# Load one held-out subject's SWI and 3D-TOF volumes, then read their voxel data as float arrays.
swi_test_img, tof_test_img = (
    nib.load(test_path)
    for test_path in (
        "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\SWI_Test\\2018-104_02-10410-D0MR_17_T2_EG_TRA_fl2d1r_SkullStripped_Training_Normalized.nii.gz",
        "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\TOF3D_Test\\2018-104_02-10410-D0MR_13_TOF_3D_RAPIDE_fl3d1r_t70_SkullStripped_Training_Normalized.nii.gz",
    )
)
swi_test_data, tof_test_data = (
    test_img.get_fdata() for test_img in (swi_test_img, tof_test_img)
)

In [60]:
# Insert a channel axis at position 3 so each volume becomes (X, Y, Z, 1).
swi_test_data = np.expand_dims(swi_test_data, axis=3)
tof_test_data = np.expand_dims(tof_test_data, axis=3)

In [61]:
# Stack SWI (channel 0) and TOF (channel 1), then prepend a batch axis of size 1.
new_dagmnet_input = np.concatenate((swi_test_data, tof_test_data), axis=3)[np.newaxis]

In [62]:
# Predict on the single held-out test subject (batch of 1) with the retrained model.
new_dagmnet_output = New_DAGMNet_ch2.predict(new_dagmnet_input)



In [63]:
def save_array_to_nifti1(array, original_img, destination_path, output_name):
    """Write ``array`` as a NIfTI-1 file at ``destination_path/output_name``.

    The affine is copied from ``original_img`` so the saved volume keeps the
    spatial placement of the image it was derived from.
    """
    output_file = os.path.join(destination_path, output_name)
    nifti_img = nib.Nifti1Image(array, original_img.affine)
    nib.save(nifti_img, output_file)

In [64]:
def Stroke_closing(img, structure_size=2):
    """Morphologically close a binary stroke-prediction mask.

    Applies a binary closing (dilation followed by erosion) with an all-ones
    structuring element of edge length ``structure_size``. The element has the
    same number of dimensions as ``img``, so a 3-D mask gets a
    structure_size**3 cube. This bridges small gaps in the predicted lesion.

    Args:
        img: Array-like mask; nonzero voxels are treated as foreground.
        structure_size: Edge length of the structuring element (default 2,
            i.e. the original 2x2x2 cube for 3-D input).

    Returns:
        Boolean ndarray with the same shape as ``img``.
    """
    # Import the submodule explicitly: on older SciPy releases a bare
    # ``import scipy`` does not make ``scipy.ndimage`` available.
    from scipy import ndimage

    structure = np.ones((structure_size,) * np.ndim(img))
    return ndimage.binary_closing(img, structure=structure)

In [69]:
# Drop singleton axes, binarize at 0.5, then close small gaps and fill enclosed holes.
new_dagmnet_output_squeezed = new_dagmnet_output.squeeze()
predicted_mask_no_postprocessing = np.where(new_dagmnet_output_squeezed > 0.5, 1.0, 0.0)
closed_mask = Stroke_closing(predicted_mask_no_postprocessing)
predicted_mask = scipy.ndimage.binary_fill_holes(closed_mask).astype("float64")

In [70]:
# Save the thresholded mask WITHOUT morphological post-processing, using the SWI test image's affine.
# NOTE(review): the post-processed `predicted_mask` from the previous cell is not saved here.
save_array_to_nifti1(predicted_mask_no_postprocessing, swi_test_img, "D:\\data_processed_ETIS\\Resized_Images\\DAGMNet_Training\\Predict_Test", "2018-104_02-10410-D0MR_17_T2_EG_TRA_fl2d1r_Predict_NoPP.nii.gz")