In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import random
import sys
sys.path.insert(0,'..')

from dfibert.tracker.nn.rl import Agent
import dfibert.envs.RLTractEnvironment_fast as RLTe

from dfibert.tracker import save_streamlines

import matplotlib.pyplot as plt
%matplotlib notebook

#from train import load_model

# I. HCP Tracking
The environment can run tracking on a fixed set of datasets; at the moment it can load HCP data as well as ISMRM data. The following cells show the initialisation of our environment on HCP subject `100307`. With `seeds = None`, seed points are placed automatically at all voxels with an FA value >= 0.2.
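FA-based seeding can be sketched as follows. This is a minimal illustration, not the environment's code: it assumes the diffusion tensor eigenvalues are available as an array of shape `(X, Y, Z, 3)` and uses the standard eigenvalue formula for fractional anisotropy. `fractional_anisotropy` and `seeds_from_fa` are hypothetical helpers; we do not know how `RLTractEnvironment` computes FA internally.

```python
import numpy as np

def fractional_anisotropy(evals: np.ndarray) -> np.ndarray:
    """FA from diffusion tensor eigenvalues stored in the last axis (..., 3)."""
    mean = evals.mean(axis=-1, keepdims=True)
    num = np.sqrt(((evals - mean) ** 2).sum(axis=-1))
    den = np.sqrt((evals ** 2).sum(axis=-1))
    # Background voxels have all-zero eigenvalues (0 / 0); map them to FA = 0.
    with np.errstate(invalid="ignore", divide="ignore"):
        fa = np.sqrt(1.5) * num / den
    return np.nan_to_num(fa)

def seeds_from_fa(fa: np.ndarray, threshold: float = 0.2) -> np.ndarray:
    """IJK indices (N, 3) of all voxels whose FA is at least `threshold`."""
    return np.argwhere(fa >= threshold)

# Toy 2x2x1 volume: one linear tensor, one isotropic tensor, two empty voxels.
evals = np.zeros((2, 2, 1, 3))
evals[0, 0, 0] = [1.0, 0.0, 0.0]  # perfectly linear tensor, FA = 1
evals[1, 1, 0] = [1.0, 1.0, 1.0]  # isotropic tensor, FA = 0
fa = fractional_anisotropy(evals)
print(seeds_from_fa(fa))  # [[0 0 0]]
```

Only the linear tensor passes the 0.2 threshold, so a single seed is returned.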

In [None]:
env = RLTe.RLTractEnvironment(step_width=0.8, dataset = '100307',
                              device = 'cpu', seeds = None, tracking_in_RAS = False,
                              odf_state = False, odf_mode = "DTI")

In [None]:
streamlines = env.track()

We can also visualise our streamlines directly in this notebook with `ax.plot3D`. A single streamline, however, says little about the tractogram as a whole, so this plot is only a qualitative tool for spotting major bugs in our tracking code.

In [None]:
%matplotlib notebook
streamline_index = 9
streamline_np = np.stack(streamlines[streamline_index])

fig = plt.figure()
ax = plt.axes(projection='3d')
#ax.plot3D(env.referenceStreamline_ijk.T[0], env.referenceStreamline_ijk.T[1], env.referenceStreamline_ijk.T[2], '-*')
ax.plot3D(streamline_np[:,0], streamline_np[:,1], streamline_np[:,2])
#plt.legend(['gt', 'agent'])
plt.legend(['agent'])  # labels must be a list; a bare string is split into per-character labels
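As a quantitative complement to the single-streamline plot, a quick summary of streamline lengths over the whole tractogram can reveal gross problems such as many one-step streamlines. This is a sketch: `streamline_length` and `length_summary` are hypothetical helpers, not part of `dfibert`, and they assume every streamline can be converted to an `(N, 3)` NumPy array. Lengths are in whatever units the points are in (voxel units here, presumably, since `tracking_in_RAS=False`).

```python
import numpy as np

def streamline_length(points: np.ndarray) -> float:
    """Arc length of a polyline given as an (N, 3) array of points."""
    if len(points) < 2:
        return 0.0
    segments = np.diff(points, axis=0)  # (N - 1, 3) step vectors
    return float(np.linalg.norm(segments, axis=1).sum())

def length_summary(streamlines) -> dict:
    """Count, min, median and max length of a list of (N, 3) streamlines."""
    lengths = np.array([streamline_length(np.asarray(sl, dtype=np.float64)) for sl in streamlines])
    return {
        "count": int(len(lengths)),
        "min": float(lengths.min()),
        "median": float(np.median(lengths)),
        "max": float(lengths.max()),
    }

# Toy tractogram: two straight streamlines sampled with a step width of 0.8.
toy = [
    np.array([[0.0, 0.0, 0.0], [0.8, 0.0, 0.0], [1.6, 0.0, 0.0]]),
    np.array([[0.0, 0.0, 0.0], [0.0, 0.8, 0.0]]),
]
print(length_summary(toy))
```

For the CPU-tracked HCP streamlines, `length_summary([np.stack(sl) for sl in streamlines])` should work in the same way as the `np.stack` call in the cell above. If every step has exactly length `step_width`, a streamline with `n` points has length `0.8 * (n - 1)`.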

# II. Evaluation of the Corticospinal Tract @ ISMRM benchmark data
We now use our environment, along with our reward function, to track streamlines on the ISMRM dataset. For this purpose, we first initialise the environment and place the seed points in the corticospinal tract (CST). We precomputed these seed points in IJK voxel coordinates for the ISMRM dataset; the next cell loads them so they can be passed to the environment.

In [3]:
seeds_CST = np.load('data/ismrm_seeds_CST.npy')
seeds_CST = torch.from_numpy(seeds_CST)
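We do not know how `data/ismrm_seeds_CST.npy` was generated. One plausible approach, sketched here purely as an assumption, is to take the voxels of a binary CST mask as seeds and optionally jitter several seeds within each voxel. `seeds_from_mask` is a hypothetical helper, and treating integer indices as voxel centres is also an assumption.

```python
import numpy as np

def seeds_from_mask(mask: np.ndarray, per_voxel: int = 1, rng=None) -> np.ndarray:
    """Return (N, 3) float IJK seed coordinates inside a binary mask.

    With per_voxel > 1, each mask voxel yields several seeds jittered uniformly
    within +/- 0.5 voxel around the voxel centre (assumed at the integer index).
    """
    rng = np.random.default_rng() if rng is None else rng
    voxels = np.argwhere(mask > 0).astype(np.float64)
    if per_voxel == 1:
        return voxels
    repeated = np.repeat(voxels, per_voxel, axis=0)
    return repeated + rng.uniform(-0.5, 0.5, size=repeated.shape)

# Toy mask with two voxels; three jittered seeds per voxel.
mask = np.zeros((4, 4, 4), dtype=bool)
mask[1:3, 2, 2] = True
seeds = seeds_from_mask(mask, per_voxel=3, rng=np.random.default_rng(0))
print(seeds.shape)  # (6, 3)
```

The resulting array could be wrapped with `torch.from_numpy`, exactly as the cell above does with the loaded seeds.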

In [12]:
env = RLTe.RLTractEnvironment(dataset = 'ISMRM', step_width=0.8,
                              device = 'cuda:0', seeds = seeds_CST, action_space=20,
                              odf_mode = "DTI",
                              fa_threshold=0.2, tracking_in_RAS=False)

Will be deprecated by NARLTractEnvironment as soon as Jos fixes all bugs in the reward function.
Loading dataset #  ISMRM
Interpolating ODF as state Value
Init tract masks for neuroanatomical reward
torch.Size([90, 108, 90, 25])


In [14]:
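# Sanity check: query the brain mask at voxel (0, 0, 0), a corner of the volume that should lie outside the brain
env.brainMask_interpolator(torch.from_numpy(np.array([0,0,0])).to(env.device))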

tensor([0.], device='cuda:0')
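We do not know how `brainMask_interpolator` is implemented; its name suggests that it evaluates the brain mask at continuous coordinates. One common scheme is trilinear interpolation, sketched below in plain NumPy. `trilinear` is a hypothetical helper; it assumes every volume dimension has at least two voxels and returns 0 outside the field of view.

```python
import numpy as np

def trilinear(volume: np.ndarray, point) -> float:
    """Trilinearly interpolate a 3-D volume at a continuous IJK coordinate.

    Points outside the volume return 0.0. Assumes every dimension has size >= 2.
    """
    p = np.asarray(point, dtype=np.float64)
    shape = np.array(volume.shape)
    if np.any(p < 0) or np.any(p > shape - 1):
        return 0.0
    # Lower corner of the enclosing cell; clamp so p on the upper border still works.
    lo = np.minimum(np.floor(p).astype(int), shape - 2)
    frac = p - lo
    value = 0.0
    for corner in np.ndindex(2, 2, 2):
        c = np.array(corner)
        # Weight is the product of per-axis distances to the opposite corner.
        weight = np.prod(np.where(c == 1, frac, 1.0 - frac))
        value += weight * volume[tuple(lo + c)]
    return float(value)

vol = np.zeros((3, 3, 3))
vol[1, 1, 1] = 1.0  # a single "brain" voxel in the centre
print(trilinear(vol, [0, 0, 0]))    # 0.0
print(trilinear(vol, [1, 1, 1]))    # 1.0
print(trilinear(vol, [0.5, 1, 1]))  # 0.5
```

Querying the corner voxel returns 0, which mirrors the output of the sanity-check cell above.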

Tracking itself is done by calling the `.track()` function, which tracks streamlines from each of the provided seed points in both the forward and the backward direction.
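The idea of bidirectional tracking can be illustrated with a deterministic toy tracker. This is not how `env.track()` chooses its steps (the environment may use an agent or DTI peaks); it is a minimal sketch assuming a hypothetical `direction_fn` that returns a fibre direction at a point and a hypothetical `inside_fn` that marks the tracking region. Each half is integrated with fixed-length Euler steps of `step_width`, the backward half starts with the negated initial direction, and the two halves are joined at the seed.

```python
import numpy as np

def track_one_way(seed, direction_fn, inside_fn, sign, step_width, max_steps):
    """Euler integration from `seed`; `sign` (+1 or -1) orients the first step."""
    points = [np.asarray(seed, dtype=np.float64)]
    prev_dir = None
    for _ in range(max_steps):
        d = np.asarray(direction_fn(points[-1]), dtype=np.float64)
        d = d / np.linalg.norm(d)
        if prev_dir is None:
            d = sign * d
        elif np.dot(d, prev_dir) < 0:
            # Fibre directions are sign-ambiguous: keep the one consistent with the last step.
            d = -d
        nxt = points[-1] + step_width * d
        if not inside_fn(nxt):
            break  # stop before leaving the tracking region
        points.append(nxt)
        prev_dir = d
    return np.stack(points)

def track_bidirectional(seed, direction_fn, inside_fn, step_width=0.8, max_steps=100):
    fwd = track_one_way(seed, direction_fn, inside_fn, +1.0, step_width, max_steps)
    bwd = track_one_way(seed, direction_fn, inside_fn, -1.0, step_width, max_steps)
    # Reverse the backward half so it ends at the seed, then append the forward
    # half without its first point (the seed) to avoid duplicating it.
    return np.concatenate([bwd[::-1], fwd[1:]], axis=0)

# Toy example: a constant direction field along x inside the box [0, 4]^3.
def direction(p):
    return np.array([1.0, 0.0, 0.0])

def inside(p):
    return bool(np.all((p >= 0.0) & (p <= 4.0)))

sl = track_bidirectional(np.array([2.0, 2.0, 2.0]), direction, inside)
print(sl.shape)  # (5, 3)
```

Starting at x = 2, the forward half reaches about 3.6 and the backward half about 0.4 before the next step would leave the box, giving five points spaced 0.8 apart.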

In [20]:
streamlines = env.track()

100%|██████████| 5701/5701 [06:36<00:00, 14.39it/s] 


Next, we convert the streamlines from IJK to RAS coordinates and store them as a VTK file. A nice property of this format is that the streamlines can be imported directly into 3D Slicer via the SlicerDMRI extension.

In [21]:
streamlines_ras = [env.dataset.to_ras(torch.stack(sl).cpu().numpy()) for sl in streamlines]
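Mapping voxel (IJK) coordinates to world (RAS) coordinates is usually done with the image's 4x4 affine in homogeneous coordinates. We assume, but have not verified, that `env.dataset.to_ras` does something equivalent. The sketch below uses a hypothetical helper `apply_affine` and a made-up affine; nibabel ships a function with the same purpose, `nibabel.affines.apply_affine`.

```python
import numpy as np

def apply_affine(affine: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Map (N, 3) IJK voxel coordinates to world coordinates with a 4x4 affine."""
    homogeneous = np.hstack([points, np.ones((len(points), 1))])  # (N, 4)
    return (homogeneous @ affine.T)[:, :3]

# Made-up affine: 2 mm isotropic voxels with a translation to the scanner origin.
affine = np.array([
    [2.0, 0.0, 0.0, -90.0],
    [0.0, 2.0, 0.0, -126.0],
    [0.0, 0.0, 2.0, -72.0],
    [0.0, 0.0, 0.0, 1.0],
])
ras = apply_affine(affine, np.array([[0, 0, 0], [45, 63, 36]]))
print(ras)  # voxel (0, 0, 0) maps to the translation, voxel (45, 63, 36) to the origin
```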

In [22]:
save_streamlines(streamlines=streamlines_ras, path="ismrm_CST_ras_test10_noPeakFinding_20a.vtk")
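`save_streamlines` takes care of writing the file. To show what a streamline VTK file can contain, here is a minimal writer for the legacy ASCII VTK `POLYDATA` format, with one `LINES` cell per streamline. `write_vtk_polydata` is a hypothetical helper; we do not know whether `save_streamlines` writes legacy or XML VTK, or ASCII or binary data.

```python
import numpy as np

def write_vtk_polydata(streamlines, path: str) -> None:
    """Write streamlines as a legacy ASCII VTK POLYDATA file, one LINES cell each."""
    points = np.concatenate([np.asarray(sl, dtype=np.float32) for sl in streamlines], axis=0)
    # The LINES size is the total count of integers listed: one length + one index per point.
    size = sum(len(sl) + 1 for sl in streamlines)
    with open(path, "w") as f:
        f.write("# vtk DataFile Version 3.0\n")
        f.write("streamlines\n")
        f.write("ASCII\n")
        f.write("DATASET POLYDATA\n")
        f.write(f"POINTS {len(points)} float\n")
        for x, y, z in points:
            f.write(f"{x:.6f} {y:.6f} {z:.6f}\n")
        f.write(f"LINES {len(streamlines)} {size}\n")
        offset = 0
        for sl in streamlines:
            n = len(sl)
            f.write(f"{n} " + " ".join(str(i) for i in range(offset, offset + n)) + "\n")
            offset += n

import os
import tempfile

toy = [
    np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
    np.array([[0.0, 1.0, 0.0], [0.0, 2.0, 0.0], [0.0, 3.0, 0.0]]),
]
path = os.path.join(tempfile.mkdtemp(), "toy.vtk")
write_vtk_polydata(toy, path)
with open(path) as f:
    print(f.read())
```

All points are stored once in a shared `POINTS` list, and each line cell lists its point count followed by indices into that list.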

We can also write each bundle mask into its own NIfTI file for ease of visualisation.

In [None]:
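import nibabel as nib

# Write one NIfTI file per bundle mask (25 bundles in this dataset).
# Note: np.eye(4) as affine stores the masks in voxel (IJK) space, without the scanner geometry.
for i in range(25):
    data = env.tractMasksAllBundles[i,:,:,:].squeeze().cpu().numpy()
    new_image = nib.Nifti1Image(data, np.eye(4))
    new_image.set_data_dtype(data.dtype)

    nib.save(new_image, 'mask_%s_%d.nii.gz' % (env.bundleNames[i], i))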