Skip to content
Permalink
Browse files

Merge pull request #1849 from scott-trinkle/add_iter_to_CSD

Adds control for number of iterations in CSD recon
  • Loading branch information...
skoudoro committed Jul 17, 2019
2 parents b838f96 + 97edac6 commit 5196d0bbb109fd8aaa3e97462d004ebb4f9440a9
Showing with 22 additions and 2 deletions.
  1. +5 −2 dipy/reconst/csdeconv.py
  2. +17 −0 dipy/reconst/tests/test_csdeconv.py
@@ -61,7 +61,7 @@ def on_sphere(self, sphere):
class ConstrainedSphericalDeconvModel(SphHarmModel):

def __init__(self, gtab, response, reg_sphere=None, sh_order=8, lambda_=1,
tau=0.1):
tau=0.1, convergence=50):
r""" Constrained Spherical Deconvolution (CSD) [1]_.
Spherical deconvolution computes a fiber orientation distribution
@@ -99,6 +99,8 @@ def __init__(self, gtab, response, reg_sphere=None, sh_order=8, lambda_=1,
zero. However, to improve the stability of the algorithm, tau is
set to tau*100 % of the mean fODF amplitude (here, 10% by default)
(see [1]_). Default: 0.1
convergence : int
Maximum number of iterations to allow the deconvolution to converge.
References
----------
@@ -174,14 +176,15 @@ def __init__(self, gtab, response, reg_sphere=None, sh_order=8, lambda_=1,
self.B_reg *= lambda_
self.sh_order = sh_order
self.tau = tau
self.convergence = convergence
self._X = X = self.R.diagonal() * self.B_dwi
self._P = np.dot(X.T, X)

@multi_voxel_fit
def fit(self, data):
dwi_data = data[self._where_dwi]
shm_coeff, _ = csdeconv(dwi_data, self._X, self.B_reg, self.tau,
P=self._P)
convergence=self.convergence, P=self._P)
return SphHarmFit(self, shm_coeff, None)

def predict(self, sh_coeff, gtab=None, S0=1.):
@@ -577,6 +577,23 @@ def test_csd_superres():
assert_(all(cos_sim > .99))


def test_csd_convergence():
""" Check existence of `convergence` keyword in CSD model """
_, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
gtab = gradient_table(bvals, bvecs)

evals = np.array([[1.5, .3, .3]]) * [[1.], [1.]] / 1000.
S, sticks = multi_tensor(gtab, evals, snr=None, fractions=[55., 45.])

model_w_conv = ConstrainedSphericalDeconvModel(gtab, (evals[0], 3.),
sh_order=8, convergence=50)
model_wo_conv = ConstrainedSphericalDeconvModel(gtab, (evals[0], 3.),
sh_order=8)

assert_equal(model_w_conv.fit(S).shm_coeff, model_wo_conv.fit(S).shm_coeff)


if __name__ == '__main__':
# run_module_suite()
test_csdeconv()

0 comments on commit 5196d0b

Please sign in to comment.
You can’t perform that action at this time.