## Load and pre-process DWI data

In [None]:
import sys
import numpy as np
sys.path.append("./utils")
from data_handling import *

In [None]:
# `os` is used directly below; import it here instead of relying on it leaking
# in through `from data_handling import *`.
import os

# Directories (relative to the working directory) holding the DWI volume and
# its brain mask; each is searched for a NIfTI file (*.nii or *.nii.gz).
dwi_path = "./data/dwi"
mask_path = "./data/mask"

curr_path = os.getcwd()
dwi_file = get_file_path(curr_path, dwi_path, "*.nii*")
mask_file = get_file_path(curr_path, mask_path, "*.nii*")

In [None]:
# Load the DWI series as float32 and the brain mask in its on-disk dtype.
# `get_data()` was deprecated in nibabel 3.0 and removed in 5.0; `get_fdata`
# is the float replacement, and `np.asanyarray(img.dataobj)` reproduces what
# `get_data()` used to return (keeps the stored dtype).
dwi_data = nib.load(dwi_file)
dwi = dwi_data.get_fdata(dtype=np.float32)
mask = np.asanyarray(nib.load(mask_file).dataobj)

In [None]:
from dipy.io import read_bvals_bvecs

# The gradient table lives next to the DWI volume: look up the b-values file
# first, then the b-vectors file, with the same helper used for the NIfTIs.
bval_file, bvec_file = (
    get_file_path(curr_path, dwi_path, suffix)
    for suffix in ("*.bvals", "*.bvecs")
)

bvals, bvecs = read_bvals_bvecs(bval_file, bvec_file)

In [None]:
# Mask the raw DWI and resample it (directions=None leaves the target
# directions up to resample_dwi; sh_order/smooth presumably configure a
# spherical-harmonics fit — confirm in resample_dwi).
resampled_dwi = resample_dwi(
    mask_dwi(dwi, mask),
    bvals,
    bvecs,
    directions=None,
    sh_order=8,
    smooth=0.006,
)
# Re-apply the mask after resampling and rescale by 255.
resampled_dwi = 255 * mask_dwi(resampled_dwi, mask)

In [None]:
from train_utils import *

In [None]:
mask_path = "./data/WM_mask"
wm_mask_file = get_file_path(curr_path, mask_path, "*.nii*")
# White-matter mask subsampled by 2 along each axis before computing the mean
# DWI signal. NOTE(review): this presumably puts the WM mask on the coarser
# DWI grid so it can index `resampled_dwi` — confirm the grids differ by 2x.
# (`np.asanyarray(img.dataobj)` replaces nibabel's removed `get_data()`.)
wm_mask = np.asanyarray(nib.load(wm_mask_file).dataobj)[::2, ::2, ::2]

dwi_means = calc_mean_dwi(resampled_dwi, wm_mask)

## Load trained network

In [None]:
from keras.models import model_from_json

In [None]:
# Locate the serialized architecture (JSON) and the weights file of the trained
# network. NOTE(review): "*.h*" presumably targets a Keras .h5/.hdf5 file;
# confirm the directory holds only one file matching it.
trained_model_path = "./trained_model"
json_file, weights_file = (
    get_file_path(curr_path, trained_model_path, pattern)
    for pattern in ("*.json*", "*.h*")
)

In [None]:
# Rebuild the network architecture from its JSON description, then load the
# trained weights into it. The `with` block closes the file even if read() fails.
with open(json_file, 'r') as model_json:
    loaded_model_json = model_json.read()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(weights_file)

# Run Tractography

In [None]:
from test_utils import *

#### Randomize seed points within WM mask

In [None]:
# Reload the WM mask at full resolution for seeding and the in-WM check; an
# earlier cell bound `wm_mask` to a 2x-subsampled copy to compute `dwi_means`.
# (`np.asanyarray(img.dataobj)` replaces nibabel's removed `get_data()`.)
mask_path = "./data/WM_mask"
wm_mask_file = get_file_path(curr_path, mask_path, "*.nii*")
wm_mask = np.asanyarray(nib.load(wm_mask_file).dataobj)

In [None]:
N_seeds = 400000 # set number of seed points
# Axis 1 of the network input is the time axis; it bounds the tracking steps.
N_time_steps = loaded_model.input_shape[1]

Loc_seeds = init_seeds(wm_mask,N_seeds,N_time_steps)

# Partition the seeds into batches of at most `seeds_batch_size` (the last one
# may be smaller). Batches are cut from the seeds actually returned, so no empty
# batch is created even if init_seeds returns a count other than N_seeds.
seeds_batch_size = 500
Loc_seeds_list = [
    Loc_seeds[start:start + seeds_batch_size]
    for start in range(0, len(Loc_seeds), seeds_batch_size)
]
repetitions = len(Loc_seeds_list)
if repetitions == 0:
    raise ValueError("init_seeds returned no seed points")

# Prepare the first batch: zero-pad it over all time steps and fill time step 0
# of the network input with the mean-subtracted DWI signal at the seed positions.
Loc_seeds = zero_pad_seeds(Loc_seeds_list[0], len(Loc_seeds_list[0]), N_time_steps)
DW_seeds = np.zeros((Loc_seeds.shape[0], Loc_seeds.shape[1], len(dwi_means)))
DW_seeds[:, 0, :] = eval_volume_at_3d_coordinates(resampled_dwi, Loc_seeds[:, 0, :]) - dwi_means

#### Set tracking parameters:

In [None]:
# 'deterministic' follows the most probable direction at each step; any other
# value (intended: 'probabilistic') samples a direction from the predicted pdf.
tractography_type = 'deterministic'
step_size = 0.5
max_angle = 60  # largest allowed turn between consecutive steps, in degrees
max_length = 200  # streamline length upper bound, in mm
min_length = 20  # streamline length lower bound, in mm
total_iters = N_time_steps  # one tracking step per network time step
voxel_size = [2, 2, 2]  # voxel dimensions in mm

# Per-step entropy ceiling a*exp(-t/b) + c: it starts at a + c and decays
# towards c, so early steps tolerate flatter direction distributions.
a, b, c = 3, 10, 4.5
t_vec = np.arange(total_iters)
entropy_th = a * np.exp(-t_vec / b) + c

# One streamline buffer per seed, each starting as a single-row array holding
# the seed position, plus a visit-count volume shaped like one DWI volume.
out_fibers = [seed[None, :] for seed in Loc_seeds[:, 0, :]]
count_map = np.zeros_like(resampled_dwi[:, :, :, 0])

####  Streamline Tractography

In [None]:
from itertools import compress
from scipy.stats import entropy

# Direction vocabulary: the vertices of dipy's 'repulsion724' sphere. A
# predicted index equal to sphere724.x.shape[0] (one past the last vertex) is
# treated below as "end of fiber" (EoF).
sphere724 = get_sphere('repulsion724')
angles724 = calc_angles_matrix(sphere724)
# Pad the angle matrix with a zero row and column so the EoF index can be
# looked up too and never trips the angle criterion.
angles725 = np.pad(angles724, ((0, 1), (0, 1)))
streamlines_list = []

dilated_wm_mask = mask_dilate(wm_mask)
# Reset visit counts so re-running this cell does not accumulate them.
count_map = np.zeros_like(resampled_dwi[:, :, :, 0])


def _init_batch(batch_seeds):
    """Prepare one batch of seeds for tracking.

    Zero-pads ``batch_seeds`` to ``N_time_steps`` steps, fills time step 0 of
    the network input with the mean-subtracted DWI signal at the seed
    positions, and starts one single-point streamline per seed.

    Returns ``(loc_seeds, dw_seeds, fibers)``.
    """
    loc_seeds = zero_pad_seeds(batch_seeds, len(batch_seeds), N_time_steps)
    dw_seeds = np.zeros((loc_seeds.shape[0], loc_seeds.shape[1], len(dwi_means)))
    dw_seeds[:, 0, :] = eval_volume_at_3d_coordinates(resampled_dwi, loc_seeds[:, 0, :]) - dwi_means
    fibers = [np.expand_dims(loc_seeds[i, 0, :], axis=0) for i in range(loc_seeds.shape[0])]
    return loc_seeds, dw_seeds, fibers


for reps in range(repetitions):
    # Every batch (including the first) is built here, so the cell does not
    # depend on state left behind by earlier cells or by a previous run.
    Loc_seeds, DW_seeds, out_fibers = _init_batch(Loc_seeds_list[reps])
    next_positions = Loc_seeds[:, 0, :]
    # Per-seed stopping flags; once set they stay set for the rest of the batch.
    EoF_mask = np.zeros(len(DW_seeds), dtype=bool)
    entropy_mask = np.zeros(len(DW_seeds), dtype=bool)
    angle_mask = np.zeros(len(DW_seeds), dtype=bool)
    inWM_mask = np.ones(len(DW_seeds), dtype=bool)
    print('Processing batch number ', reps+1, ' out of ', repetitions)

    for t in range(total_iters):
        print('tracking step number ', t)

        # The whole (partially filled) input sequence is fed to the model and
        # the direction distribution for this step is read from output t.
        pdf_pred = loaded_model.predict_on_batch(DW_seeds)
        pdf_t = pdf_pred[:, t, :]
        if tractography_type == 'deterministic':
            direction_idx_pred = argmax_from_pdf(pdf_t)
        else:
            direction_idx_pred = sample_from_pdf(pdf_t, 1)[:, 0]

        # Stop seeds that turn more than max_angle relative to their previous step.
        if t > 0:
            d_angles = angles725[direction_idx_pred, direction_idx_previous]
            angle_mask = angle_mask | (d_angles > max_angle)
        direction_idx_previous = direction_idx_pred

        # Stop seeds whose predicted distribution is too flat (entropy above
        # the per-step ceiling); the epsilon keeps log() finite at p == 0.
        curr_entropys = -np.sum(pdf_t * np.log(pdf_t + 0.000000001), axis=1)
        entropy_mask = entropy_mask | (curr_entropys > entropy_th[t])
        direction_vec_pred = idx2direction(direction_idx_pred, sphere724)

        # Stop seeds that predict the end-of-fiber index.
        EoF_mask = EoF_mask | (direction_idx_pred == sphere724.x.shape[0])

        # Only seeds that are still active (and were inside the mask before this step) move.
        moving = ~EoF_mask & ~entropy_mask & ~angle_mask & inWM_mask
        next_positions = next_positions + step_size * direction_vec_pred * np.expand_dims(moving, axis=1)
        # NOTE(review): positions are doubled before the mask test, presumably
        # because the WM mask has twice the DWI grid resolution (cf. the
        # [::2, ::2, ::2] subsampling used for dwi_means) — confirm.
        inWM_mask = inWM_mask & is_within_mask(2 * next_positions, dilated_wm_mask).astype(bool)

        valids_mask = ~EoF_mask & ~entropy_mask & ~angle_mask & inWM_mask
        if not valids_mask.any():
            break

        # np.add.at counts every streamline; a fancy-indexed `+= 1` would bump
        # a voxel only once even when several streamlines land in it.
        vox = next_positions[valids_mask].astype(int)
        np.add.at(count_map, (vox[:, 0], vox[:, 1], vox[:, 2]), 1)
        for k in np.flatnonzero(valids_mask):
            out_fibers[k] = np.vstack((out_fibers[k], next_positions[k, :]))

        # Sample the signal at the new positions as input for the next step.
        if t+1 < DW_seeds.shape[1]:
            DW_seeds[:, t+1, :] = eval_volume_at_3d_coordinates(resampled_dwi, next_positions) - dwi_means

    print('\n')
    # Keep only streamlines whose length lies strictly between min_length and max_length.
    lengths_vec = fiber_lengths(out_fibers, voxel_size)
    filtered_out_fibers = [
        fiber for fiber, length in zip(out_fibers, lengths_vec)
        if min_length < length < max_length
    ]
    streamlines_list.append(filtered_out_fibers)
    out_fibers = []

In [None]:
# The tracking loop stores one list of streamlines per seed batch in
# `streamlines_list`; flatten them into a single list. (`all_fibers` was
# previously used here without ever being defined.)
# NOTE(review): assumes output_tractogram expects a flat list of streamlines — confirm.
all_fibers = [fiber for batch in streamlines_list for fiber in batch]
out_tractogram = output_tractogram(all_fibers)

#### Visualize tractogram

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

fig = plt.figure(5)
ax = fig.add_subplot(111, projection='3d')
for streamline in out_tractogram:

    x = streamline[:,0]
    y = streamline[:,1]
    z = streamline[:,2]

    ax.plot(streamline[:,0], streamline[:,1], streamline[:,2])

ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
ax.view_init(elev=0., azim=0)