In [1]:
# Automatically re-import edited modules (e.g. the local diGP package) before each cell runs.
%load_ext autoreload
%autoreload 2

import os
import sys
# Make the parent directory importable (presumably where the project-local
# `diGP` package lives). The membership check keeps re-runs of this cell from
# adding duplicate entries to sys.path.
module_path = os.path.abspath(os.pardir)
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [2]:
import numpy as np
import json
import matplotlib.pyplot as plt

from dipy.reconst import mapmri
import dipy.reconst.dti as dti
from dipy.viz import window, actor
from dipy.data import get_data, get_sphere
from dipy.core.gradients import gradient_table

from diGP.preprocessing import get_HCP_loader
from diGP.preprocessing_pipelines import preprocess_SPARC
from diGP.dataManipulations import log_q_squared
from diGP.model import GaussianProcessModel, get_default_kernel

In [3]:
# Project configuration; provides conf[<dataset>]['data_paths'] used when loading data.
with open('../config.json') as config_file:
    conf = json.load(config_file)

Load the data. The `dataset` variable selects the source: `'HCP'` or `'SPARC'`.

In [4]:
dataset = 'SPARC'  # one of 'HCP' or 'SPARC'

In [5]:
# Load the diffusion signal, gradient table and voxel size for the chosen dataset.
if dataset == 'HCP':
    subject_path = conf['HCP']['data_paths']['mgh_1007']
    loader = get_HCP_loader(subject_path)
    # Point the loader at a smaller data file (presumably a cropped subset of
    # the full volume, judging by the file name -- confirm).
    small_data_path = '{}/mri/small_data.npy'.format(subject_path)

    loader.update_filename_data(small_data_path)

    data = loader.data
    gtab = loader.gtab
    voxel_size = loader.voxel_size
elif dataset == 'SPARC':
    subject_path = conf['SPARC']['data_paths']['gradient_60']

    gtab, data, voxel_size = preprocess_SPARC(subject_path, normalize=True)
else:
    # Fail here with a clear message instead of a NameError on `gtab`/`data`
    # in a later cell.
    raise ValueError(
        "Unknown dataset {!r}; expected 'HCP' or 'SPARC'".format(dataset))

In [6]:
# Dense DSI q-space sampling scheme from dipy: column 0 holds the b-values and
# the remaining columns the gradient directions. Predictions and ODFs are
# evaluated on this scheme, using the acquisition's diffusion timings.
# Alternative (presumably coarser) table: np.loadtxt(get_data('dsi515btable'))
btable = np.loadtxt(get_data('dsi4169btable'))

gtab_dsi = gradient_table(bvals=btable[:, 0],
                          bvecs=btable[:, 1:],
                          big_delta=gtab.big_delta,
                          small_delta=gtab.small_delta)

Fit a MAPL model (MAP-MRI with Laplacian regularization) to the data.

In [7]:
# MAP-MRI basis up to radial order 6 with Laplacian regularization; the
# regularization weight is chosen by generalized cross-validation ('GCV').
mapl_options = dict(radial_order=6,
                    laplacian_regularization=True,
                    laplacian_weighting='GCV')
map_model_laplacian_aniso = mapmri.MapmriModel(gtab, **mapl_options)

mapfit_laplacian_aniso = map_model_laplacian_aniso.fit(data)

We want to use an FA image as the background of the ODF plots, which requires fitting a DTI model. The DTI predictions are also used as the mean of one of the GP models below.

In [8]:
# Diffusion tensor fit: `tenfit.fa` is the plot background, and its
# predictions serve as a GP mean later on.
tenmodel = dti.TensorModel(gtab=gtab)
tenfit = tenmodel.fit(data)

In [9]:
# Each model's predicted signal on the acquired gradient table, keeping index 0
# of the third (slice) axis; these are the GP mean functions.
fitted = {
    label: model_fit.predict(gtab)[:, :, 0]
    for label, model_fit in (('MAPL', mapfit_laplacian_aniso), ('DTI', tenfit))
}

Fit three Gaussian process (GP) models: one without a mean function, and two that use the DTI and MAPL predictions, respectively, as the mean.

In [10]:
# Two GP configurations are available. Both produce the same variable names, so
# fitting both would silently discard the first (and waste its expensive
# retraining). Select one explicitly:
#   False -> default kernels with a square-root q-magnitude transform
#   True  -> explicit default kernels (n_max=6 without mean, n_max=2 with a mean)
USE_EXPLICIT_KERNELS = False


def fit_gp(mean=None, **model_kwargs):
    """Fit a freshly constructed 2-D spatial GP to the squeezed data.

    Parameters
    ----------
    mean : array or None
        Signal used as the GP mean (e.g. ``fitted['DTI']``); ``None`` fits
        without a mean.
    **model_kwargs
        Extra keyword arguments forwarded to ``GaussianProcessModel``
        (e.g. ``kernel`` or ``q_magnitude_transform``).

    Returns
    -------
    Whatever ``GaussianProcessModel.fit`` returns.
    """
    # A new model per fit, in case `fit` stores state on the model object
    # (the previous version reused one model for all three fits).
    gp_model = GaussianProcessModel(gtab, spatial_dims=2, verbose=False, **model_kwargs)
    return gp_model.fit(np.squeeze(data), mean=mean,
                        voxel_size=voxel_size[0:2], retrain=True)


if USE_EXPLICIT_KERNELS:
    gp_fit = fit_gp(kernel=get_default_kernel(n_max=6, spatial_dims=2))
    gp_dti_fit = fit_gp(mean=fitted['DTI'],
                        kernel=get_default_kernel(n_max=2, spatial_dims=2))
    gp_mapl_fit = fit_gp(mean=fitted['MAPL'],
                         kernel=get_default_kernel(n_max=2, spatial_dims=2))
else:
    gp_fit = fit_gp(q_magnitude_transform=np.sqrt)
    gp_dti_fit = fit_gp(mean=fitted['DTI'], q_magnitude_transform=np.sqrt)
    gp_mapl_fit = fit_gp(mean=fitted['MAPL'], q_magnitude_transform=np.sqrt)

In [11]:
# Predicted signal on the dense DSI scheme (index 0 of the third axis); used
# as the GP means when computing the GP ODFs.
pred = {
    label: model_fit.predict(gtab_dsi)[:, :, 0]
    for label, model_fit in (('MAPL', mapfit_laplacian_aniso), ('DTI', tenfit))
}

### Compute the ODFs

Load an ODF reconstruction sphere and subdivide it once for finer angular sampling.

In [12]:
# dipy's 724-vertex symmetric sphere, refined by one subdivision step.
base_sphere = get_sphere('symmetric724')
sphere = base_sphere.subdivide(1)

The radial moment $s$ of the MAPL ODF can be increased to sharpen the results, but it may
also make the ODFs noisier. Note that a "proper" ODF corresponds to $s=0$.

In [13]:
# ODFs on `sphere`; s=0 gives the standard (non-sharpened) MAPL ODF.
odf = {}
odf['MAPL'] = mapfit_laplacian_aniso.odf(sphere, s=0)
odf['DTI'] = tenfit.odf(sphere)

In [14]:
# GP ODFs evaluated on the DSI scheme, each with the mean it was fitted with.
# The inserted singleton third axis presumably matches the (x, y, z, vertices)
# layout of the MAPL/DTI ODFs expected by the slicer -- confirm.
gp_odf_inputs = [
    ('GP', gp_fit, None),
    ('DTI_GP', gp_dti_fit, pred['DTI']),
    ('MAPL_GP', gp_mapl_fit, pred['MAPL']),
]
for odf_key, gp_result, mean_signal in gp_odf_inputs:
    odf_2d = gp_result.odf(sphere, gtab_dsi=gtab_dsi, mean=mean_signal)
    odf[odf_key] = odf_2d[:, :, None, :]

## Display the ODFs

In [15]:
def save_odf_figure(odf_volume, out_path):
    """Render ODF glyphs for slice z=0 over the DTI FA map and save to `out_path`."""
    scene = window.Renderer()
    scene.background((1, 1, 1))  # white

    odf_actor = actor.odf_slicer(odf_volume, sphere=sphere, scale=0.5, colormap='jet')
    odf_actor.display(z=0)
    odf_actor.RotateZ(90)

    fa_actor = actor.slicer(tenfit.fa, opacity=1)
    fa_actor.display(z=0)
    fa_actor.RotateZ(90)
    # Offset the FA slice along z, presumably so it sits behind the glyphs.
    fa_actor.SetPosition(0, 0, -1)

    scene.add(fa_actor)
    scene.add(odf_actor)

    window.record(scene, out_path=out_path, size=(1000, 1000))


# One PNG per ODF method, e.g. odfs_MAPL.png.
for name, odf_volume in odf.items():
    save_odf_figure(odf_volume, 'odfs_{}.png'.format(name))