In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded before code is executed; edits to the local package are then picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Make the parent of the current working directory importable (presumably the
# repository root that holds the `diGP` package -- confirm against the layout).
# The entry is appended only once, so re-running this cell is harmless.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import GPy
from diGP.preprocessing_pipelines import get_SPARC_train_and_test
from diGP.dataManipulations import (DataHandler, log_q_squared)
from diGP.model import Model
from diGP.evaluation import get_SPARC_metrics

%matplotlib inline

In [None]:
# Read the JSON configuration that sits one directory above the notebook.
with open('../config.json', mode='r') as config_file:
    conf = json.load(config_file)

# Only the SPARC section is needed here: the mapping of data directories and
# the path handed to the data loader as `q_test_path`.
sparc_conf = conf['SPARC']
data_paths = sparc_conf['data_paths']
q_test_path = sparc_conf['q_test_path']

Load data to use for prediction.

In [None]:
# `source` is the key into `data_paths` for the data set used in this notebook;
# the 'goldstandard' entry is passed alongside it (presumably as the test
# reference -- confirm in get_SPARC_train_and_test).
source = 'gradient_20'
gtab, data, voxelSize = get_SPARC_train_and_test(
    data_paths[source],
    data_paths['goldstandard'],
    q_test_path,
)

Fit various base models that could be used as the mean of the GP.

In [None]:
import dipy.reconst.dti as dti
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel

# Diffusion tensor model fitted to the training measurements.
tenmodel = dti.TensorModel(gtab['train'])
tenfit = tenmodel.fit(data['train'])

# Constrained spherical deconvolution with a fixed response.
# NOTE(review): `response` is presumably (tensor eigenvalues, S0) as dipy
# expects -- confirm that the values suit this data set.
response = ((3e-3, 2e-4, 0), 1)
csd_model = ConstrainedSphericalDeconvModel(gtab['train'], response, sh_order=4)
csd_fit = csd_model.fit(data['train'])

# For every base model keep (a) the residual on the training acquisition and
# (b) the prediction on the test acquisition.
fits = {'DTI': tenfit, 'CSD': csd_fit}
residuals = {name: data['train'] - fit.predict(gtab['train'])
             for name, fit in fits.items()}
pred = {name: fit.predict(gtab['test']) for name, fit in fits.items()}

In [None]:
# MAP-MRI signals are not fitted here; they are read from .npy files stored in
# the directory of the selected `source` data set.
map_train_path = os.path.join(data_paths[source], 'map_mri_train.npy')
map_test_path = os.path.join(data_paths[source], 'map_mri_test.npy')

residuals['MAP'] = data['train'] - np.load(map_train_path)
pred['MAP'] = np.load(map_test_path)

In [None]:
# Base model whose residuals are inspected here and modelled by the GP below.
base_model = 'MAP'

# Display index 20 along the third axis of the residual array as a grayscale
# image (the third axis presumably enumerates measurements -- confirm).
residual_image = residuals[base_model][:, :, 20]
plt.imshow(residual_image, cmap='gray')

It is clear that there are spatial correlations in the residuals.

In [None]:
# Transform handed to both handlers as `qMagnitudeTransform`.
qMagnitudeTransform = np.sqrt

# Training handler: wraps the residuals of the chosen base model.
handler = DataHandler(
    gtab['train'],
    residuals[base_model],
    qMagnitudeTransform=qMagnitudeTransform,
    voxelSize=voxelSize[:2],
)

# Prediction handler: carries no data, only the test gradient table and the
# in-plane shape taken from the first two axes of the test data.
handlerPred = DataHandler(
    gtab['test'],
    data=None,
    spatial_shape=data['test'].shape[:2],
    qMagnitudeTransform=qMagnitudeTransform,
    voxelSize=voxelSize[:2],
)

# Reference table: b = 0, 1000, ..., 5000 and the matching q.
# NOTE(review): presumably b = (2*pi*q)**2 * tau with tau = 0.07 and 1e-3 as a
# unit conversion -- confirm the units before relying on these numbers.
b = np.arange(0, 6000, 1000)
q = np.sqrt(b / 0.07 * 1e-3) / (2 * np.pi)
print(q)


# Histogram of the residuals that the GP is trained on. The array is looked up
# with `base_model` (the same key used for the residual image and for the
# training DataHandler) so this diagnostic follows the selected base model
# instead of always showing 'MAP'.
plt.hist(residuals[base_model].ravel(), bins=100)
plt.xlabel('Residual')
plt.ylabel('Count')
plt.title('Training residuals of the {} base model'.format(base_model));

In [None]:
# Initial length scales of the RBF factors: one shared by the two spatial
# inputs and one for the input in column 2.
spatialLengthScale = 2
bValLengthScale = 1

# Grouping of the six input columns passed to Model as `grid_dims`.
grid_dims = [[0], [1], [2, 3, 4, 5]]

# The two kernel parameterisations used in this notebook. They differ in where
# the free overall scale lives:
#   'free_legendre'  - every RBF variance is fixed to 1 and the Legendre
#                      coefficients (orders 0, 2, 4) are optimised.
#   'fixed_legendre' - the Legendre coefficients are fixed to (1, 0) and the
#                      variance of the third RBF factor is optimised.
kernel_settings = {
    'free_legendre': dict(coefficients=(1e-2, 1e-3, 1e-4), orders=(0, 2, 4),
                          fix_bval_variance=True, fix_coefficients=False),
    'fixed_legendre': dict(coefficients=(1, 0), orders=(0, 2),
                           fix_bval_variance=False, fix_coefficients=True),
}


def build_kernel(coefficients, orders, fix_bval_variance, fix_coefficients):
    """Build the product kernel RBF(0) * RBF(1) * RBF(2) * Legendre(3, 4, 5).

    Parameters
    ----------
    coefficients : tuple of float
        Initial Legendre coefficients, one per entry of `orders`.
    orders : tuple of int
        Polynomial orders given to the Legendre kernel.
    fix_bval_variance : bool
        If True, fix the variance of the RBF factor on input column 2 to 1.
    fix_coefficients : bool
        If True, fix the Legendre coefficients at `coefficients`.

    Returns
    -------
    The GPy product kernel. The variances of the two spatial RBF factors are
    always fixed to 1.
    """
    product = (GPy.kern.RBF(input_dim=1, active_dims=[0],
                            variance=1,
                            lengthscale=spatialLengthScale) *
               GPy.kern.RBF(input_dim=1, active_dims=[1],
                            variance=1,
                            lengthscale=spatialLengthScale) *
               GPy.kern.RBF(input_dim=1, active_dims=[2],
                            variance=1,
                            lengthscale=bValLengthScale) *
               GPy.kern.LegendrePolynomial(
                   input_dim=3,
                   coefficients=np.array(coefficients),
                   orders=orders,
                   active_dims=(3, 4, 5)))

    product.parts[0].variance.fix(value=1)
    product.parts[1].variance.fix(value=1)
    if fix_bval_variance:
        product.parts[2].variance.fix(value=1)
    if fix_coefficients:
        product.parts[3].coefficients.fix(value=coefficients)
    return product


# Only the parameterisation that is actually trained is built; constructing
# the other kernel and Model first and then overwriting them was wasted work.
kernel = build_kernel(**kernel_settings['fixed_legendre'])

model = Model(handler, kernel, data_handler_pred=handlerPred, grid_dims=grid_dims, verbose=False)

In [None]:
# Seed NumPy's global RNG right before training (presumably so that the
# restarts requested below are reproducible -- confirm in Model.train).
np.random.seed(0)
model.train(restarts=True)

# Summary of the fitted GP, followed by the Legendre coefficients.
fitted_gp = model.GP_model
print(fitted_gp)
print("\nLegendre coefficients: \n{}".format(
    fitted_gp.mul.LegendrePolynomial.coefficients))

In [None]:
# Predict with the trained model; the variance computation is switched off.
mu = model.predict(compute_var=False)

# Pass the prediction through the prediction handler's `untransform`.
pred_handler = model.data_handler_pred
pred_residuals = pred_handler.untransform(mu)

In [None]:
# Final prediction = base-model prediction + GP-predicted residual, stored
# under "<base model> + GP" next to the plain base-model predictions.
gp_key = "{} + GP".format(base_model)
pred[gp_key] = pred[base_model] + pred_residuals

In [None]:
# Print the SPARC metrics of every stored prediction against the test data.
for model_name in pred:
    print("\n{}:".format(model_name))
    get_SPARC_metrics(gtab['test'], data['test'], pred[model_name], verbose=True)

In [None]:
import os
import pickle

# NOTE(review): this rebinds `pred`, discarding the predictions computed above;
# re-running earlier cells after this one will operate on the loaded object.
# NOTE(review): the directory key is hard-coded to 'gradient_20' rather than
# `source` -- confirm whether the file is meant to follow `source`.
# NOTE(security): pickle.load can execute arbitrary code; only load files
# produced by a trusted batch run.
results_path = os.path.join(data_paths['gradient_20'], 'batch_run_prediction_results.p')
with open(results_path, 'rb') as results_file:
    pred = pickle.load(results_file)