Skip to content

Commit

Permalink
Merge pull request #415 from MrBago/fixup_predict
Browse files Browse the repository at this point in the history
RF - move around some of the predict stuff
  • Loading branch information
arokem committed Sep 23, 2014
2 parents 94ec5bf + 667cfda commit 9881004
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 159 deletions.
6 changes: 5 additions & 1 deletion dipy/data/__init__.py
Expand Up @@ -26,7 +26,7 @@ def loads_compat(bytes):
import gzip
import numpy as np
from dipy.core.gradients import GradientTable, gradient_table
from dipy.core.sphere import Sphere
from dipy.core.sphere import Sphere, HemiSphere
from dipy.sims.voxel import SticksAndBall
import numpy as np
from dipy.data.fetcher import (fetch_scil_b0,
Expand Down Expand Up @@ -169,6 +169,10 @@ def get_sphere(name='symmetric362'):
faces=as_native_array(res['faces']))


default_sphere = HemiSphere.from_sphere(get_sphere('symmetric724'))
small_sphere = HemiSphere.from_sphere(get_sphere('symmetric362'))


def get_data(name='small_64D'):
""" provides filenames of some test datasets or other useful parametrisations
Expand Down
120 changes: 44 additions & 76 deletions dipy/reconst/csdeconv.py
Expand Up @@ -110,6 +110,7 @@ def __init__(self, gtab, response, reg_sphere=None, sh_order=8, lambda_=1,
self.response = response

self.S_r = estimate_response(gtab, self.response[0], self.response[1])
self.response_scaling = response[1]

r_sh = np.linalg.lstsq(self.B_dwi, self.S_r[self._where_dwi])[0]
r_rh = sh_to_rh(r_sh, m, n)
Expand All @@ -134,12 +135,50 @@ def fit(self, data):
return SphHarmFit(self, shm_coeff, None)


def predict(self, sh_coeff, S0=1):
"""
Predict a signal from sh coefficients for this model
def predict(self, sh_coeff, gtab=None, S0=1):
"""Compute a signal prediction given spherical harmonic coefficients
and (optionally) a response function for the provided GradientTable
class instance.
Parameters
----------
sh_coeff : ndarray
The spherical harmonic representation of the FOD from which to make
the signal prediction.
gtab : GradientTable
The gradients for which the signal will be predicted. Use the
model's gradient table by default.
S0 : ndarray or float
The non diffusion-weighted signal value.
Returns
-------
pred_sig : ndarray
The predicted signal.
"""
return csd_predict(sh_coeff, self.gtab, response=self.response, S0=S0,
R=self.R)
if gtab is None or gtab is self.gtab:
SH_basis = self.B_dwi
gtab = self.gtab
else:
x, y, z = gtab.gradients[~gtab.b0s_mask].T
r, theta, phi = cart2sphere(x, y, z)
SH_basis, m, n = real_sym_sh_basis(self.sh_order, theta, phi)

# Because R is diagonal, the matrix multiply is written as a multiply
predict_matrix = SH_basis * self.R.diagonal()
S0 = np.asarray(S0)[..., None]
scaling = S0 / self.response_scaling

# This is the key operation: convolve and multiply by S0:
pre_pred_sig = scaling * np.dot(predict_matrix, sh_coeff)

# Now put everything in its right place:
pred_sig = np.zeros(pre_pred_sig.shape[:-1] + (gtab.bvals.shape[0],))
pred_sig[..., ~gtab.b0s_mask] = pre_pred_sig
pred_sig[..., gtab.b0s_mask] = S0

return pred_sig


class ConstrainedSDTModel(OdfModel, Cache):
Expand Down Expand Up @@ -638,74 +677,3 @@ def auto_response(gtab, data, roi_center=None, roi_radius=10, fa_thr=0.7):
ratio = evals[1]/evals[0]
return response, ratio


def csd_predict(sh_coeff, gtab, response=None, S0=1, R=None):
"""
Compute a signal prediction given spherical harmonic coefficients and
(optionally) a response function for the provided GradientTable class
instance
Parameters
----------
sh_coeff : ndarray
Spherical harmonic coefficients
gtab : GradientTable class instance
response : tuple
A tuple with two elements. The first is the eigen-values as an (3,)
ndarray and the second is the signal value for the response
function without diffusion weighting.
Default: (np.array([0.0015, 0.0003, 0.0003]), 1)
S0 : ndarray or float
The non diffusion-weighted signal value.
R : ndarray
Optionally, provide an R matrix. If not provided, calculated from the
gtab, response function, etc.
Returns
-------
pred_sig : ndarray
The signal predicted from the provided SH coefficients for a
measurement with the provided GradientTable. The last dimension of the
resulting array is the same as the number of bvals/bvecs in the
GradientTable. The first dimensions have shape: `sh_coeff.shape[:-1]`.
"""
n_coeff = sh_coeff.shape[-1]
sh_order = order_from_ncoef(n_coeff)
x, y, z = gtab.gradients[~gtab.b0s_mask].T
r, theta, phi = cart2sphere(x, y, z)
SH_basis, m, n = real_sym_sh_basis(sh_order, theta, phi)
if R is None:
# for the gradient sphere
B_dwi = real_sph_harm(m, n, theta[:, None], phi[:, None])

if response is None:
response = (np.array([0.0015, 0.0003, 0.0003]), 1)
else:
response = response

S_r = estimate_response(gtab, response[0], response[1])
r_sh = np.linalg.lstsq(B_dwi, S_r[~gtab.b0s_mask])[0]
r_rh = sh_to_rh(r_sh, m, n)
R = forward_sdeconv_mat(r_rh, n)

predict_matrix = np.dot(SH_basis, R)

if np.iterable(S0):
# If it's an array, we need to give it one more dimension:
S0 = S0[..., None]

# This is the key operation: convolve and multiply by S0:
pre_pred_sig = S0 * np.dot(predict_matrix, sh_coeff)

# Now put everything in its right place:
pred_sig = np.zeros(pre_pred_sig.shape[:-1] + (gtab.bvals.shape[0],))
pred_sig[..., ~gtab.b0s_mask] = pre_pred_sig
pred_sig[..., gtab.b0s_mask] = S0

return pred_sig


4 changes: 1 addition & 3 deletions dipy/reconst/peaks.py
Expand Up @@ -14,12 +14,10 @@

from .recspeed import local_maxima, remove_similar_vertices, search_descending
from dipy.core.sphere import HemiSphere, Sphere
from dipy.data import get_sphere
from dipy.data import default_sphere
from dipy.core.ndindex import ndindex
from dipy.reconst.shm import sh_to_sf_matrix

default_sphere = HemiSphere.from_sphere(get_sphere('symmetric724'))


def peak_directions_nl(sphere_eval, relative_peak_threshold=.25,
min_separation_angle=25, sphere=default_sphere,
Expand Down
70 changes: 5 additions & 65 deletions dipy/reconst/shm.py
Expand Up @@ -517,44 +517,6 @@ def fit(self, data, mask=None):
return SphHarmFit(self, coef, mask)


def _shm_predict(fit, gtab, S0=1):
"""
Helper function for the implementation of model prediction from the
ConstrainedSphericalDeconvFit class. This is necessary, because in
multi-voxel data, the multi_vox_fit kicks in.
Parameters
----------
fit : A ConstrainedSphericalDeconvFit class instance
The prediction will be done using the parameters in this object.
gtab : A GradientTable class instance
The prediction will be done for the bval/bvec combinations in this
object.
S0 : float or ndarray (optional)
The mean non-diffusion weighted signal in the voxel or volume.
Returns
-------
pred_sig : ndarray
The predicted signal in the gtab for this fit.
"""
sphere = Sphere(xyz=gtab.bvecs[~gtab.b0s_mask])
prediction_matrix = fit.prediction_matrix(sphere, gtab)

if np.iterable(S0):
# If it's an array, we need to give it one more dimension:
S0 = S0[..., None]

# This is the key operation: convolve and multiply by S0:
pre_pred_sig = S0 * np.dot(prediction_matrix, fit._shm_coef)
# Now put everything in its right place:
pred_sig = np.zeros(pre_pred_sig.shape[:-1] + (gtab.bvals.shape[0],))
pred_sig[..., ~gtab.b0s_mask] = pre_pred_sig
pred_sig[..., gtab.b0s_mask] = S0

return pred_sig


class SphHarmFit(OdfFit):
"""Diffusion data fit to a spherical harmonic model"""

Expand Down Expand Up @@ -626,32 +588,7 @@ def shm_coeff(self):
return self._shm_coef


def prediction_matrix(self, sphere, gtab):
"""
A matrix used to predict the signal from an estimated ODF
Parameters
----------
sphere : a Sphere class instance
gtab : a GradientTable class instance
"""
prediction_matrix = self.model.cache_get("prediction_matrix", (sphere,
gtab))
if prediction_matrix is None:
pred_gtab = grad.gradient_table(
gtab.bvals[~gtab.b0s_mask],
gtab.bvecs[~gtab.b0s_mask])
x, y, z = pred_gtab.gradients.T
r, theta, phi = cart2sphere(x, y, z)
SH_basis, m, n = real_sym_sh_basis(self.model.sh_order, theta, phi)
# The prediction matrix needs to be normalized to the response S0:
prediction_matrix = (np.dot(SH_basis, self.model.R) /
self.model.response[1])

return prediction_matrix


def predict(self, gtab, S0=1.0):
def predict(self, gtab=None, S0=1.0):
"""
Predict the diffusion signal from the model coefficients.
Expand All @@ -665,7 +602,10 @@ def predict(self, gtab, S0=1.0):
all voxels
"""
return _shm_predict(self, gtab, S0)
if not hasattr(self.model, 'predict'):
msg = "This model does not have prediction implemented yet"
raise NotImplementedError(msg)
return self.model.predict(self.shm_coeff, gtab, S0)


class CsaOdfModel(SphHarmModel):
Expand Down
41 changes: 27 additions & 14 deletions dipy/reconst/tests/test_csdeconv.py
Expand Up @@ -4,20 +4,18 @@
from numpy.testing import (assert_, assert_equal, assert_almost_equal,
assert_array_almost_equal, run_module_suite,
assert_array_equal)
from dipy.data import get_sphere, get_data
from dipy.data import get_sphere, get_data, default_sphere, small_sphere
from dipy.sims.voxel import (multi_tensor,
single_tensor,
multi_tensor_odf,
all_tensor_evecs)
from dipy.core.gradients import gradient_table
from dipy.core.sphere import HemiSphere
from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel,
ConstrainedSDTModel,
forward_sdeconv_mat,
odf_deconv,
odf_sh_to_sharp,
auto_response,
csd_predict)
auto_response)
from dipy.reconst.peaks import peak_directions, default_sphere
from dipy.core.sphere_stats import angular_similarity
from dipy.reconst.shm import (sf_to_sh, sh_to_sf, QballModel,
Expand Down Expand Up @@ -274,19 +272,34 @@ def test_csd_predict():
angles = [(0, 0), (60, 0)]
S, sticks = multi_tensor(gtab, mevals, S0, angles=angles,
fractions=[50, 50], snr=SNR)
sphere = get_sphere('symmetric362')
sphere = small_sphere
odf_gt = multi_tensor_odf(sphere.vertices, mevals, angles, [50, 50])
response = (np.array([0.0015, 0.0003, 0.0003]), S0)

csd = ConstrainedSphericalDeconvModel(gtab, response)
csd_fit = csd.fit(S)
prediction = csd_predict(csd_fit.shm_coeff, gtab, response=response, S0=S0)
npt.assert_equal(prediction.shape[0], S.shape[0])
model_prediction = csd.predict(csd_fit.shm_coeff)
assert_array_almost_equal(prediction, model_prediction)
# Roundtrip tests (quite inaccurate, because of regularization):
assert_array_almost_equal(csd_fit.predict(gtab, S0=S0),S,decimal=1)
assert_array_almost_equal(csd.predict(csd_fit.shm_coeff, S0=S0),S,decimal=1)

# Predicting from a fit should give the same result as predicting from a
# model, S0 is 1 by default
prediction1 = csd_fit.predict()
prediction2 = csd.predict(csd_fit.shm_coeff)
npt.assert_array_equal(prediction1, prediction2)
npt.assert_array_equal(prediction1[..., gtab.b0s_mask], 1.)

# Same with a different S0
prediction1 = csd_fit.predict(S0=123.)
prediction2 = csd.predict(csd_fit.shm_coeff, S0=123.)
npt.assert_array_equal(prediction1, prediction2)
npt.assert_array_equal(prediction1[..., gtab.b0s_mask], 123.)

# For "well behaved" coefficients, the model should be able to find the
# coefficients from the predicted signal.
coeff = np.random.random(csd_fit.shm_coeff.shape) - .5
coeff[..., 0] = 10.
S = csd.predict(coeff)
csd_fit = csd.fit(S)
npt.assert_array_almost_equal(coeff, csd_fit.shm_coeff)


def test_sphere_scaling_csdmodel():
"""Check that mirroring regulization sphere does not change the result of
Expand All @@ -305,8 +318,8 @@ def test_sphere_scaling_csdmodel():
S, sticks = multi_tensor(gtab, mevals, 100., angles=angles,
fractions=[50, 50], snr=None)

sphere = get_sphere('symmetric362')
hemi = HemiSphere.from_sphere(sphere)
hemi = small_sphere
sphere = hemi.mirror()

response = (np.array([0.0015, 0.0003, 0.0003]), 100)
model_full = ConstrainedSphericalDeconvModel(gtab, response,
Expand Down

0 comments on commit 9881004

Please sign in to comment.