In [None]:
!pip install tensorflow-gpu==2.1.0
!pip install keras==2.3.1



In [None]:
%cd drive/MyDrive/DeepAnat/
!ls

/content/drive/MyDrive/DeepAnat
cnn_models.py  DeepAnat.ipynb  generator      __pycache__
cnn_utils.py   example_data    pretrain_data  SpectralNormalizationKeras.py


In [None]:
# %% load modules
import os
import glob
import scipy.io as sio
import numpy as np
import nibabel as nib
import tensorflow as tf
from matplotlib import pyplot as plt

from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras.models import load_model

# Project helpers used below (normalize_image, block_ind, extract_block, block2brain,
# save_nii, mean_absolute_error_weighted). The project folder listing shows cnn_utils.py,
# so fall back to it when aini_utils is not importable.
try:
    import aini_utils as utils
except ImportError:
    import cnn_utils as utils
from cnn_models import unet_3d_model

nb = nib  # short alias: some cells refer to nibabel as `nb`

# for compatibility: TF1-style session configuration
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
from tensorflow.compat.v1 import GPUOptions

# cap this process at 90% of GPU memory and enable allow_growth (allocate on demand)
gpu_options = GPUOptions(per_process_gpu_memory_fraction=0.9)
config = ConfigProto(gpu_options=gpu_options)
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

Using TensorFlow backend.


In [None]:
# Fail fast if TensorFlow cannot see a GPU.
# No `%tensorflow_version` magic here: TensorFlow is imported in an earlier cell, and the
# magic only takes effect before the first import.
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if not device_name: # gpu_device_name() returns '' when no GPU is visible
    raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


In [None]:
# %% set up path and parameters
# Project root = current working directory (the notebook's folder).
dpRoot = os.getcwd()
os.chdir(dpRoot)

block_size = 64
finetune = 1 # 1 for fine-tuning, 0 for training from scratch

# %% subjects
# Fine-tuning reads the example dataset; training from scratch reads the pre-training dataset.
# dpData is also used by the data-loading and apply cells.
dpData = os.path.join(dpRoot, 'example_data' if finetune else 'pretrain_data')
subjects = sorted(glob.glob(os.path.join(dpData, 'mwu*'))) # one folder per subject
print('{} subjects found in {}'.format(len(subjects), dpData))

In [None]:
# %% load data
# Cut every subject in `subjects` into training/validation blocks. Each subject folder holds
# a T1w target, a brain mask and the diffusion-derived inputs listed in `input_list`.
sz_block = 64
sz_pad = 1
flip = 1 # flip along x to augment training data
seed = 0 # seeds the per-subject block shuffle used for the fine-tuning split
input_list = ['diff_meanb0', 'diff_meandwi', 'diff_dtiL1', 'diff_dtiL2', 'diff_dtiL3',
              'diff_dtiDwi1', 'diff_dtiDwi2', 'diff_dtiDwi3', 'diff_dtiDwi4', 'diff_dtiDwi5', 'diff_dtiDwi6']
norm_ch = [0, 1, 5, 6, 7, 8, 9, 10] # channels passed to normalize_image; the DTI eigenvalues (2-4) are left out

rng = np.random.RandomState(seed)


def load_nii_channel(fp):
    """Load a NIfTI file as an array with an extra trailing channel axis."""
    # np.asanyarray(img.dataobj) replaces nibabel's deprecated get_data()
    return np.expand_dims(np.asanyarray(nib.load(fp).dataobj), -1)


def flip_augment(block):
    """If `flip` is set, append an x-flipped (axis 1) copy of the blocks along the block axis."""
    return np.concatenate((block, block[:, ::-1]), axis=0) if flip else block


def add_blocks(dest, out_blk, in_blk, msk_blk):
    """Flip-augment one subject's (target, input, mask) blocks and append them to the `dest` lists."""
    for block_list, block in zip(dest, (out_blk, in_blk, msk_blk)):
        block_list.append(flip_augment(block))


def stack_blocks(block_list):
    """Concatenate per-subject block arrays along axis 0; np.array([]) when there are none."""
    return np.concatenate(block_list, axis=0) if block_list else np.array([])


# per-subject block arrays, concatenated once after the loop: (target, input, mask)
train_lists = ([], [], [])
valid_lists = ([], [], [])

for ii, dpSub in enumerate(subjects):
    sj = os.path.basename(dpSub) # `subjects` holds the subject folder paths returned by glob
    print(sj)

    fpT1w = os.path.join(dpSub, sj + '_t1w.nii.gz')
    t1w = load_nii_channel(fpT1w)
    fpMask = os.path.join(dpSub, sj + '_mask.nii.gz')
    mask = load_nii_channel(fpMask)
    # stack the diffusion inputs along the channel axis, in input_list order
    inputs = np.concatenate(
        [load_nii_channel(os.path.join(dpSub, sj + '_' + name + '.nii.gz')) for name in input_list],
        axis=-1)

    t1w_norm = utils.normalize_image(t1w, t1w, mask)[0]
    inputs_norm = utils.normalize_image(inputs, inputs, mask, norm_ch)[0]

    t1w_norm = t1w_norm * mask # exclude non-brain content from loss calculation
    inputs_norm = inputs_norm * mask

    ind_block = utils.block_ind(mask, sz_block=sz_block, sz_pad=sz_pad)[0]

    t1w_norm_block = utils.extract_block(t1w_norm, ind_block)
    inputs_norm_block = utils.extract_block(inputs_norm, ind_block)
    mask_block = utils.extract_block(mask, ind_block)

    # the mask is appended as an extra target channel (presumably read by the weighted loss — confirm)
    t1w_norm_block = np.concatenate((t1w_norm_block, mask_block), axis=-1)

    if finetune:
        # hold out 20% of this subject's blocks for validation; split BEFORE flipping so a
        # validation block's mirror image never ends up in the training set
        n_blocks = mask_block.shape[0]
        val_num = int(np.round(n_blocks * 0.2))
        idx = rng.permutation(n_blocks)
        valid_idx, train_idx = idx[:val_num], idx[val_num:]
        add_blocks(valid_lists, t1w_norm_block[valid_idx], inputs_norm_block[valid_idx], mask_block[valid_idx])
        add_blocks(train_lists, t1w_norm_block[train_idx], inputs_norm_block[train_idx], mask_block[train_idx])
    elif np.mod(ii + 2, 5) == 0: # training from scratch: 1 out of 5 subjects for validation
        print('validation subject')
        add_blocks(valid_lists, t1w_norm_block, inputs_norm_block, mask_block)
    else:
        print('training subject')
        add_blocks(train_lists, t1w_norm_block, inputs_norm_block, mask_block)

train_block_out, train_block_in, train_block_mask = (stack_blocks(b) for b in train_lists)
valid_block_out, valid_block_in, valid_block_mask = (stack_blocks(b) for b in valid_lists)

assert train_block_in.size > 0, 'no training blocks were extracted'
assert valid_block_in.size > 0, 'no validation blocks were extracted'
print('training blocks:', train_block_in.shape[0], '| validation blocks:', valid_block_in.shape[0])

mwu119126
test subject
mwu120212
mwu120414
mwu121921
mwu125222
mwu126325
test subject
mwu126426
mwu127226
mwu128026
mwu128632
mwu129937
test subject
mwu130417
mwu130619
mwu130720
mwu132017
mwu133827
test subject
mwu135528
mwu135629
mwu135730
mwu137431
mwu138332
test subject
mwu144125
mwu144226
mwu146634
mwu146735
mwu147636
test subject
mwu148133
mwu148941
mwu149337
mwu151223
mwu151627
test subject
mwu151930
mwu152225
mwu153227
mwu158843


In [None]:


# %% set up models
# Load the pre-trained U-Net (fine-tuning) or build a fresh one (training from scratch),
# then compile it with utils.mean_absolute_error_weighted as the loss.
num_ch = train_block_in.shape[-1] # number of input channels per block
num_epochs = 30

if finetune: # load pre-trained model
    fnCp = 'unet_all2t1w_hcp'
    fpCp = os.path.join(dpRoot, fnCp, fnCp + '.h5')
    # the custom loss must be registered so Keras can deserialize the saved model
    model_unet = load_model(fpCp, custom_objects={'mean_absolute_error_weighted': utils.mean_absolute_error_weighted})
    print('model loaded:', fnCp)
else: # train from scratch
    model_unet = unet_3d_model(num_ch)

# compile in both cases so fine-tuning and training from scratch use the same optimizer settings
adam_opt_unet = Adam(learning_rate=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0)
model_unet.compile(loss=utils.mean_absolute_error_weighted, optimizer=adam_opt_unet)
model_unet.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, None, N 0                                            
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, None, None, N 1344        input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, N 0           conv3d_1[0][0]                   
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, None, None, N 62256       activation_1[0][0]               
____________________________________________________________________________________________

In [None]:
# %% train
# Fit the U-Net on the training blocks, keep the checkpoint with the lowest validation
# loss, then store both loss curves in a .mat file next to it.
sz_batch = 1
fnCp = 'unet_MGHfinetuned'
dpCnn = os.path.join(dpRoot, fnCp)
if not os.path.exists(dpCnn):
    os.mkdir(dpCnn)
    print('create directory')

fpCp = os.path.join(dpCnn, fnCp + '.h5')
fpLoss = os.path.join(dpCnn, fnCp + '.mat')
cp = ModelCheckpoint(fpCp, monitor='val_loss', save_best_only=True)

# the network takes the diffusion inputs and the brain mask as two separate inputs
train_x = [train_block_in, train_block_mask]
valid_xy = ([valid_block_in, valid_block_mask], valid_block_out)

history = model_unet.fit(x=train_x,
                         y=train_block_out,
                         validation_data=valid_xy,
                         batch_size=sz_batch,
                         epochs=num_epochs,
                         shuffle=True,
                         callbacks=[cp],
                         verbose=1)

losses = history.history
sio.savemat(fpLoss, {'loss_train': losses['loss'], 'loss_val': losses['val_loss']})

Train on 714 samples, validate on 180 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30

KeyboardInterrupt: ignored

In [None]:
# %% apply
# Predict T1w images for the evaluation subjects block by block, reassemble the volumes,
# report the per-subject MSE and save the predictions as NIfTI files.
print('Applying...')

# %% load data
sz_block = 64
sz_pad = 5
sz_crop = 3
input_list = ['diff_meanb0', 'diff_meandwi', 'diff_dtiL1', 'diff_dtiL2', 'diff_dtiL3',
              'diff_dtiDwi1', 'diff_dtiDwi2', 'diff_dtiDwi3', 'diff_dtiDwi4', 'diff_dtiDwi5', 'diff_dtiDwi6']
norm_ch = [0, 1, 5, 6, 7, 8, 9, 10] # do not normalize DTI metrics

# Evaluation subjects live in their own folder; iterate over what is there rather than
# over the training `subjects` list.
eval_subjects = sorted(glob.glob(os.path.join(dpData, 'evaluation_subjects', 'mwu*')))
assert eval_subjects, 'no evaluation subjects found'

mse = []

for dpSub in eval_subjects:
    sj = os.path.basename(dpSub)
    print(sj)

    fpT1w = os.path.join(dpSub, sj + '_t1w.nii.gz')
    t1w = np.expand_dims(np.asanyarray(nib.load(fpT1w).dataobj), -1)

    fpMask = os.path.join(dpSub, sj + '_mask.nii.gz')
    mask = np.expand_dims(np.asanyarray(nib.load(fpMask).dataobj), -1)

    # stack the diffusion inputs along the last (channel) axis, in input_list order;
    # np.asanyarray(img.dataobj) replaces nibabel's deprecated get_data()
    inputs = np.concatenate(
        [np.expand_dims(np.asanyarray(nib.load(os.path.join(dpSub, sj + '_' + name + '.nii.gz')).dataobj), -1)
         for name in input_list],
        axis=-1)

    t1w_norm = utils.normalize_image(t1w, t1w, mask)[0]
    inputs_norm = utils.normalize_image(inputs, inputs, mask, norm_ch)[0]

    ind_block = utils.block_ind(mask, sz_block=sz_block, sz_pad=sz_pad)[0]
    inputs_norm_block = utils.extract_block(inputs_norm, ind_block)
    mask_block = utils.extract_block(mask, ind_block)

    # predict one block at a time and drop the model's last output channel
    # (presumably the mask channel appended for the weighted loss — confirm)
    t1w_pred_block = np.zeros(mask_block.shape)
    for mm in range(mask_block.shape[0]):
        pred = model_unet.predict([inputs_norm_block[mm:mm+1], mask_block[mm:mm+1]])
        t1w_pred_block[mm:mm+1] = pred[..., :-1]

    t1w_pred_vol = utils.block2brain(t1w_pred_block, ind_block, mask, sz_crop)[0]

    fpPred = os.path.join(dpSub, sj + fnCp + '_predimg_norm.nii.gz')
    utils.save_nii(fpPred, t1w_pred_vol, fpMask)

    # rescale standardized intensities by (x + 3) / 6 before computing the in-mask MSE
    in_brain = mask > 0.5
    t1w_mse = (t1w_norm + 3) / 6
    pred_mse = (t1w_pred_vol + 3) / 6
    mse_subject = np.mean((t1w_mse[in_brain] - pred_mse[in_brain]) ** 2)
    print('mean squared error:', mse_subject)
    mse.append(mse_subject)

    # transform standardized intensities to normal range; here the subject's own T1w
    # statistics are used (without a ground-truth T1w, a training subject's could be used)
    img_mean = np.mean(t1w[in_brain])
    img_std = np.std(t1w[in_brain])

    t1w_pred_final = (t1w_pred_vol * img_std + img_mean) * mask
    fpPred = os.path.join(dpSub, sj + fnCp + '_predimg_final.nii.gz')
    utils.save_nii(fpPred, t1w_pred_final, fpMask)

fpMSE = os.path.join(dpCnn, fnCp + '_mse.mat')
sio.savemat(fpMSE, {'mse': mse})
print('Applying finished')

Applying...
mwu159744
evaluation subject
Applying finished
