## Load and pre-process DWI data

In [1]:
import os
import sys

import numpy as np

sys.path.append("./utils")
from data_handling import *

In [2]:
# Input folders, relative to the notebook's working directory.
dwi_path = "./data/dwi"
mask_path = "./data/mask"

curr_path = os.getcwd()
# Resolve the NIfTI image in each folder via get_file_path (utils/data_handling),
# DWI first, then the brain mask.
# NOTE(review): presumably exactly one file is expected to match "*.nii*" — confirm.
dwi_file, mask_file = (
    get_file_path(curr_path, folder, "*.nii*") for folder in (dwi_path, mask_path)
)

In [3]:
# Load the DWI volume as float32 and the brain mask.
# `get_data()` is deprecated in nibabel and was removed in 5.0:
#   * get_fdata(dtype=np.float32) returns the scaled image directly as float32;
#     caching="unchanged" avoids keeping a second, cached copy of the (large)
#     DWI array inside `dwi_data`, so `dwi` is an independent array as before.
#   * np.asanyarray(img.dataobj) is nibabel's documented drop-in for get_data()
#     and keeps the mask's on-disk dtype.
# `dwi_data` is kept because its header/affine may be needed later
# (e.g. the optional streamline re-alignment).
dwi_data = nib.load(dwi_file)
dwi = dwi_data.get_fdata(dtype=np.float32, caching="unchanged")
mask = np.asanyarray(nib.load(mask_file).dataobj)

In [4]:
from dipy.io import read_bvals_bvecs

# The gradient table (b-values and b-vectors) sits in the same folder as the
# DWI image; resolve the .bvals file first, then the .bvecs file.
bval_file, bvec_file = (
    get_file_path(curr_path, dwi_path, pattern) for pattern in ("*.bvals", "*.bvecs")
)

bvals, bvecs = read_bvals_bvecs(bval_file, bvec_file)

In [5]:
# Mask the raw signal, then resample it with resample_dwi (utils/data_handling)
# using a spherical-harmonic fit of order 8 and smoothing 0.006. The function
# reports the number of erroneous voxels it encountered (see output below).
# NOTE(review): the meaning of directions=None (default target directions)
# lives in data_handling — confirm there.
resampled_dwi = resample_dwi(
    mask_dwi(dwi, mask),
    bvals,
    bvecs,
    directions=None,
    sh_order=8,
    smooth=0.006,
)
# Re-apply the brain mask to the resampled signal and scale it by 255.
resampled_dwi = mask_dwi(resampled_dwi, mask) * 255

Nb. erroneous voxels: 211474


## Load and pre-process tractography data (labels)

In [7]:
# Reference tractography (.trk) whose streamlines serve as training labels.
tractogram_path = "./data/tractography"
tractogram_file = get_file_path(curr_path, tractogram_path, "*.trk")

# Keep the loaded object (not just its streamlines) so its header stays available.
tractogram_data = streamlines.load(tractogram_file)
tractogram = tractogram_data.streamlines



In [8]:
# Optional step, currently disabled: re-align the streamlines to the DWI voxel
# grid with align_streamlines_to_grid (a helper from the utils modules), using
# the loaded tractogram and DWI image objects.
# NOTE(review): presumably only needed when the .trk file is not already in the
# DWI's voxel space — confirm before training, since mis-aligned streamlines
# would pair each label with the wrong DWI voxels.
# tractogram = align_streamlines_to_grid(tractogram_data, dwi_data)

## Prepare data for training

In [9]:
from train_utils import *

In [10]:
# Calculate the mean DWI value in each volume, to be used later for
# normalization; calc_mean_dwi (train_utils) receives the white-matter mask,
# presumably to restrict the average to WM voxels — confirm in train_utils.

# Use a dedicated name so the brain-mask folder stored in `mask_path` is not clobbered.
wm_mask_path = "./data/WM_mask"
wm_mask_file = get_file_path(curr_path, wm_mask_path, "*.nii*")
# np.asanyarray(img.dataobj) replaces the removed get_data() and keeps the on-disk dtype.
# The mask is subsampled by 2 along each spatial axis.
# NOTE(review): this assumes the WM mask is stored at twice the DWI resolution —
# confirm it lines up voxel-for-voxel with resampled_dwi.
wm_mask = np.asanyarray(nib.load(wm_mask_file).dataobj)[::2, ::2, ::2]

dwi_means = calc_mean_dwi(resampled_dwi, wm_mask)

In [11]:
# Training targets derived from the streamlines by get_geometrical_labels
# (train_utils). There is one entry per streamline, since it is split jointly
# with `tractogram` in the train/validation split.
vector_labels = get_geometrical_labels(tractogram)

In [12]:
# Hold out a fraction of the streamlines (with their labels) for validation.
# The fixed random_state makes the split reproducible across runs.
from sklearn.model_selection import train_test_split

valid_set_ratio = 0.1
X_train, X_valid, y_train, y_valid = train_test_split(
    tractogram,
    vector_labels,
    test_size=valid_set_ratio,
    random_state=101,
)

## Set DeepTract Network Architecture

In [13]:
from Network import get_DeepTract_network

Using TensorFlow backend.


In [14]:
# Number of gradient directions in the resampled DWI (its last axis); this is
# the number of input features per time step.
grad_directions = resampled_dwi.shape[3]
# Number of time steps per sequence: the longest streamline in the tractogram.
# NOTE(review): int() truncates if get_streamlines_lengths returns fractional
# lengths — confirm it returns point counts, otherwise round up.
max_streamline_length = np.max(get_streamlines_lengths(tractogram))
N_time_steps = int(max_streamline_length)
# Width of each stacked GRU layer (one entry per layer).
num_neurons = [1000] * 5
# Number of output classes (set to the number of possible directions).
num_outputs = 725
# Dropout between recurrent layers and its rate.
use_dropout = True
dropout_prob = 0.3

In [15]:
# Build the DeepTract network (get_DeepTract_network from Network.py) using the
# sequence length, input width, layer sizes, output size and dropout settings above.
model = get_DeepTract_network(N_time_steps, grad_directions, num_neurons, num_outputs, use_dropout, dropout_prob)

In [16]:
# Print the layer stack with output shapes and parameter counts.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking_1 (Masking)          (None, 339, 100)          0         
_________________________________________________________________
gru_1 (GRU)                  (None, 339, 1000)         3303000   
_________________________________________________________________
dropout_1 (Dropout)          (None, 339, 1000)         0         
_________________________________________________________________
gru_2 (GRU)                  (None, 339, 1000)         6003000   
_________________________________________________________________
dropout_2 (Dropout)          (None, 339, 1000)         0         
_________________________________________________________________
gru_3 (GRU)                  (None, 339, 1000)         6003000   
_________________________________________________________________
dropout_3 (Dropout)          (None, 339, 1000)         0         
__________

## Set training parameters

In [17]:
from keras.losses import categorical_crossentropy
from keras.optimizers import Adam
from keras.metrics import top_k_categorical_accuracy, categorical_accuracy
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

from os.path import join

In [18]:
# Optimiser step size.
learning_rate = 1.0e-5
# NOTE(review): with a single epoch, the patience-based EarlyStopping and
# ReduceLROnPlateau callbacks configured below can never trigger —
# presumably a demo value; increase it for real training.
num_epochs = 1
# Streamlines drawn per generator step. Each batch is augmented inside
# train_generator, so the batch actually used in training is twice this size.
BATCH_SIZE = 16

In [19]:
# Categorical cross-entropy over the direction classes, optimised with Adam.
# categorical_accuracy is tracked so that 'val_categorical_accuracy' exists for
# the callbacks to monitor.
# NOTE(review): `lr=` is the legacy Keras keyword; newer Keras expects `learning_rate=`.
loss = categorical_crossentropy
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss=loss, metrics=[categorical_accuracy])

In [20]:
# Directory that receives the best weights checkpoint.
output_model_path = "./trained_model"
# exist_ok=True keeps the cell idempotent without a separate exists() check,
# which removes the check-then-create race. Unlike exists(), it raises if the
# path is already taken by a regular file, instead of failing later at save time.
os.makedirs(output_model_path, exist_ok=True)

In [21]:
weights_file = join(output_model_path, 'DeepTract_weights.hdf5')

# Every callback watches validation accuracy, where larger is better.
_monitor_kwargs = dict(monitor='val_categorical_accuracy', mode='max')
callbacks = [
    # Stop after 5 epochs without an improvement of at least 1e-5.
    EarlyStopping(patience=5, verbose=1, min_delta=1e-5, **_monitor_kwargs),
    # Halve the learning rate after 3 stagnant epochs.
    ReduceLROnPlateau(factor=0.5, patience=3, verbose=1, min_delta=1e-5, **_monitor_kwargs),
    # Keep only the best-scoring weights on disk.
    ModelCheckpoint(filepath=weights_file, save_best_only=True, **_monitor_kwargs),
]

## Train Network

In [None]:
# Keras counts steps in whole batches. np.ceil returns a numpy float, so cast to
# int; the ceiling makes each epoch cover every streamline even when the set
# size is not a multiple of BATCH_SIZE.
train_steps = int(np.ceil(len(X_train) / float(BATCH_SIZE)))
valid_steps = int(np.ceil(len(X_valid) / float(BATCH_SIZE)))

# NOTE(review): fit_generator is deprecated in tf.keras (model.fit accepts
# generators there); it is kept here for the standalone Keras used by this notebook.
train_history = model.fit_generator(
    generator=train_generator(resampled_dwi, X_train, y_train, N_time_steps, num_outputs, BATCH_SIZE, dwi_means),
    steps_per_epoch=train_steps,
    epochs=num_epochs,
    verbose=1,
    callbacks=callbacks,
    validation_data=valid_generator(resampled_dwi, X_valid, y_valid, N_time_steps, num_outputs, BATCH_SIZE, dwi_means),
    validation_steps=valid_steps,
)