In [1]:
# Automatically reload imported modules (e.g. diGP while it is being edited)
# before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Put the parent directory on sys.path so the local `diGP` package below can be
# imported (assumes the notebook is run from a subdirectory of the repository).
module_path = os.path.abspath(os.path.join(os.curdir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import GPy
from diGP.preprocessing import (readHCP,
                                averageb0Volumes,
                                createBrainMaskFromb0Data,
                                replaceNegativeData,
                                normalize_data)
from diGP.dataManipulations import (DataHandler,
                                    log_q_squared,
                                    generateCoordinates)
from diGP.generateSyntheticData import combineCoordinatesAndqVecs

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

ImportError: cannot import name 'normalize_data'

In [None]:
# Directory of the preprocessed diffusion data (subject mgh_1007).
# Set the HCP_DATA_PATH environment variable to point elsewhere instead of
# editing the notebook; the original local path is kept as the fallback.
dataPath = os.environ.get(
    'HCP_DATA_PATH',
    'C:\\Users\\sesjojen\\Documents\\Data\\HumanConnectomeProject\\mgh_1007\\diff\\preproc')
print(dataPath)

In [None]:
gtab, data, voxelSize = readHCP(dataPath)
# Assuming gtab is a dipy GradientTable: its `info` property prints the summary
# itself and returns None, so wrapping it in print() also printed a stray "None".
gtab.info

In [None]:
data = replaceNegativeData(data, gtab)

b0 = averageb0Volumes(data, gtab)
mask = createBrainMaskFromb0Data(b0)
# Indices of the voxels inside the brain mask (the tuple of index arrays from
# np.nonzero). normalize_signal needs them; this line used to be commented out,
# which left `maskIdx` undefined for that call.
maskIdx = np.nonzero(mask)
# TODO: a truncated `data = normalize_d` statement (NameError) was removed here.
# `normalize_data` also cannot be imported from diGP.preprocessing (ImportError
# in the import cell), so `data` stays un-normalized at this point, although the
# plots further down use a display range of vmin=0, vmax=1.

def normalize_signal(data, maskIdx, b0, maxSignal=1.5):
    """Divide the masked signal by b0 and clip values above ``maxSignal``.

    Parameters
    ----------
    data : ndarray
        4-D signal array indexed as ``data[i, j, k, volume]``.
    maskIdx : tuple of index arrays
        Voxel coordinates to extract, e.g. ``np.nonzero(mask)``; only the
        first three entries are used.
    b0 : ndarray
        3-D array indexed like the first three axes of ``data``. Assumed to be
        nonzero at the selected voxels (zeros would produce inf/nan).
    maxSignal : float, optional
        Normalized values larger than this are replaced by it (default 1.5).

    Returns
    -------
    ndarray, shape (n_selected_voxels, data.shape[3])
        ``data / b0`` at the selected voxels with values above ``maxSignal``
        set to ``maxSignal``. ``data`` itself is not modified (fancy indexing
        produces a copy).
    """
    i, j, k = maskIdx[0], maskIdx[1], maskIdx[2]
    S = data[i, j, k, :] / b0[i, j, k, np.newaxis]
    exceeding = S > maxSignal
    # An empty selection would otherwise give 0/0 -> nan (with a warning).
    percentExceedingMaxSignal = 100 * np.count_nonzero(exceeding) / S.size if S.size else 0.0
    print('Replacing the top {} % values with {}.'.format(percentExceedingMaxSignal, maxSignal))
    S[exceeding] = maxSignal
    return S

# b0-normalized, clipped signal with one row per voxel selected by maskIdx.
S = normalize_signal(data=data, maskIdx=maskIdx, b0=b0)

To detect problems further down the road early, we will for now reduce the spatial extent used: only every second slice goes into the training data.

In [None]:
# Training data: every second slice only, so the slice spacing along the third
# axis is doubled. (For quick debugging, a small region such as
# data[60:70, 60:70, 25:55:2, :] can be substituted.)
trainVoxelSize = (voxelSize[0], voxelSize[1], 2*voxelSize[2])
handler = DataHandler(gtab, data[:, :, ::2, :],
                      voxelSize=trainVoxelSize,
                      qMagnitudeTransform=log_q_squared)

In [None]:
spatialLengthScale = 5
bValLengthScale = 3

# One RBF factor per spatial input column (0, 1, 2), all starting from the same
# length scale.
spatialKernels = [GPy.kern.RBF(input_dim=1, active_dims=[axis],
                               variance=1,
                               lengthscale=spatialLengthScale)
                  for axis in (0, 1, 2)]
# Column 3 (the b-value axis) gets a Matern 5/2 factor.
bValKernel = GPy.kern.Matern52(input_dim=1, active_dims=[3],
                               variance=1,
                               lengthscale=bValLengthScale)
# Columns 4-6 feed a Legendre polynomial factor with even orders 0, 2, 4.
angularKernel = GPy.kern.LegendrePolynomial(
    input_dim=3,
    coefficients=np.array((2, 0.5, 0.05)),
    orders=(0, 2, 4),
    active_dims=(4, 5, 6))

# Multiply strictly left to right (same as k0*k1*k2*k3*k4) so the parts of the
# product kernel keep this order.
kernel = spatialKernels[0]
for factor in spatialKernels[1:] + [bValKernel, angularKernel]:
    kernel = kernel * factor

# Pin the variances of the four non-Legendre factors to 1 — presumably so the
# overall amplitude is carried only by the Legendre coefficients (a product of
# free variances would not be identifiable).
for part in kernel.parts[:4]:
    part.variance.fix(value=1)

In [None]:
# Sanity check of the training-target size before building the grid model.
np.shape(handler.y)

In [None]:
# Group the seven input columns into grid factors: the three spatial columns
# separately, and columns 3-6 (b-value and Legendre-kernel inputs) together.
grid_dims = [[0], [1], [2], [3, 4, 5, 6]]
model = GPy.models.GPRegressionGrid(handler.X, handler.y, kernel,
                                    grid_dims=grid_dims)

In [None]:
# Optimize the model's free (non-fixed) hyperparameters; messages=True shows
# the optimizer's progress.
model.optimize(messages=True)

In [None]:
# Fitted hyperparameters, followed by the learned Legendre coefficients of the
# angular factor in the product kernel.
print(model)
legendreCoefficients = model.mul.LegendrePolynomial.coefficients
print("Legendre coefficients: {}".format(legendreCoefficients))

As a test, let's try some simple spatial interpolation: let X be every second slice and Xnew the slices in between. A more advanced test would use a checkerboard pattern.

In [None]:
# Held-out data: the odd slices (1::2), i.e. those between the training slices,
# using the same doubled slice spacing along the third axis.
predVoxelSize = (voxelSize[0], voxelSize[1], 2*voxelSize[2])
handlerPred = DataHandler(gtab, data[:, :, 1::2, :],
                          voxelSize=predVoxelSize,
                          qMagnitudeTransform=log_q_squared)

In [None]:
# Debug check: which source file provides model.predict_noiseless (to confirm
# the expected GPy build is being used). TODO: remove before sharing.
import inspect
predictNoiselessFile = inspect.getfile(model.predict_noiseless)
predictNoiselessFile

In [None]:
# Predict at the held-out slice inputs. compute_var=False presumably skips the
# predictive variance, which may be too memory-intensive at this size.
# NOTE(review): mu is used directly as an array in later cells, so this call is
# assumed to return only the mean when compute_var=False — confirm in the GPy fork.
mu = model.predict_noiseless(handlerPred.X, compute_var=False)

Visualize the results: the slice below, the interpolated slice, and the slice above.
**TODO:** also show the predictive variance? That needs more thought about the current prediction implementation, which will probably be too memory-intensive.

In [None]:
# Distribution of the residuals on the held-out slices (measured minus
# predicted); ideally a narrow peak around 0.
residuals = handlerPred.y - mu
fig, ax = plt.subplots()
ax.hist(residuals, bins=500)
ax.set(title='Residuals on held-out slices',
       xlabel='measured - predicted', ylabel='count');

In [None]:
# Bring the training targets and the predictions back to each handler's
# originalShape. Note: yTrue holds the *training* slices (the neighbours of the
# interpolated ones), not the held-out ground truth.
yTrue = np.reshape(handler.y, handler.originalShape)
mu = np.reshape(mu, handlerPred.originalShape)
np.shape(mu)

In [None]:
# Training slices are data[:, :, ::2] and predictions are for data[:, :, 1::2],
# so (assuming DataHandler keeps slice order) mu[:, :, k] lies between
# yTrue[:, :, k] and yTrue[:, :, k + 1].
sliceIdx = 5     # slice index within the training / prediction stacks
nVolumes = 5     # number of leading indices along the last axis to show
columnTitles = ('slice below (data)', 'interpolated (GP mean)', 'slice above (data)')

sns.set_style("dark")
f, axs = plt.subplots(nVolumes, 3, squeeze=False)
f.set_figheight(16*3)
f.set_figwidth(16)
for ax, title in zip(axs[0], columnTitles):
    ax.set_title(title)
for i in range(nVolumes):
    axs[i, 0].imshow(yTrue[:, :, sliceIdx, i], vmin=0, vmax=1)
    axs[i, 1].imshow(mu[:, :, sliceIdx, i], vmin=0, vmax=1)
    axs[i, 2].imshow(yTrue[:, :, sliceIdx + 1, i], vmin=0, vmax=1)
    axs[i, 0].set_ylabel('volume {}'.format(i))