Skip to content

Commit

Permalink
Merge pull request #1669 from arokem/flow_csd_sh_order
Browse files Browse the repository at this point in the history
Flow csd sh order
  • Loading branch information
skoudoro committed Nov 27, 2018
2 parents a7bb518 + 0c884a6 commit 72f9516
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 77 deletions.
12 changes: 5 additions & 7 deletions dipy/workflows/reconst.py
Expand Up @@ -520,14 +520,12 @@ def run(self, input_files, bvalues, bvectors, mask_files,
atol=bvecs_tol)
mask_vol = nib.load(maskfile).get_data().astype(np.bool)

sh_order = 8
if data.shape[-1] < 15:
n_params = ((sh_order + 1) * (sh_order + 2)) / 2
if data.shape[-1] < n_params:
raise ValueError(
'You need at least 15 unique DWI volumes to '
'compute fiber odfs. You currently have: {0}'
' DWI volumes.'.format(data.shape[-1]))
elif data.shape[-1] < 30:
sh_order = 6
'You need at least {0} unique DWI volumes to '
'compute fiber odfs. You currently have: {1}'
' DWI volumes.'.format(n_params, data.shape[-1]))

if frf is None:
logging.info('Computing response function')
Expand Down
155 changes: 85 additions & 70 deletions dipy/workflows/tests/test_reconst_csa_csd.py
Expand Up @@ -13,6 +13,7 @@

from dipy.data import get_data
from dipy.workflows.reconst import ReconstCSDFlow, ReconstCSAFlow
from dipy.reconst.shm import sph_harm_ind_list
logging.getLogger().setLevel(logging.INFO)


Expand All @@ -35,77 +36,91 @@ def reconst_flow_core(flow):
nib.save(mask_img, mask_path)

reconst_flow = flow()

reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, extract_pam_values=True)

gfa_path = reconst_flow.last_generated_outputs['out_gfa']
gfa_data = nib.load(gfa_path).get_data()
assert_equal(gfa_data.shape, volume.shape[:-1])

peaks_dir_path = reconst_flow.last_generated_outputs['out_peaks_dir']
peaks_dir_data = nib.load(peaks_dir_path).get_data()
assert_equal(peaks_dir_data.shape[-1], 15)
assert_equal(peaks_dir_data.shape[:-1], volume.shape[:-1])

peaks_idx_path = \
reconst_flow.last_generated_outputs['out_peaks_indices']
peaks_idx_data = nib.load(peaks_idx_path).get_data()
assert_equal(peaks_idx_data.shape[-1], 5)
assert_equal(peaks_idx_data.shape[:-1], volume.shape[:-1])

peaks_vals_path = \
reconst_flow.last_generated_outputs['out_peaks_values']
peaks_vals_data = nib.load(peaks_vals_path).get_data()
assert_equal(peaks_vals_data.shape[-1], 5)
assert_equal(peaks_vals_data.shape[:-1], volume.shape[:-1])

shm_path = reconst_flow.last_generated_outputs['out_shm']
shm_data = nib.load(shm_path).get_data()
assert_equal(shm_data.shape[-1], 45)
assert_equal(shm_data.shape[:-1], volume.shape[:-1])

pam = load_peaks(reconst_flow.last_generated_outputs['out_pam'])
npt.assert_allclose(pam.peak_dirs.reshape(peaks_dir_data.shape),
peaks_dir_data)
npt.assert_allclose(pam.peak_values, peaks_vals_data)
npt.assert_allclose(pam.peak_indices, peaks_idx_data)
npt.assert_allclose(pam.shm_coeff, shm_data)
npt.assert_allclose(pam.gfa, gfa_data)

bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path)
bvals[0] = 5.
bvecs = generate_bvecs(len(bvals))

tmp_bval_path = pjoin(out_dir, "tmp.bval")
tmp_bvec_path = pjoin(out_dir, "tmp.bvec")
np.savetxt(tmp_bval_path, bvals)
np.savetxt(tmp_bvec_path, bvecs.T)
reconst_flow._force_overwrite = True
with npt.assert_raises(BaseException):
npt.assert_warns(UserWarning, reconst_flow.run, data_path,
tmp_bval_path, tmp_bvec_path, mask_path,
out_dir=out_dir, extract_pam_values=True)

if flow.get_short_name() == 'csd':

reconst_flow = flow()
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=[15, 5, 5])
reconst_flow = flow()
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf='15, 5, 5')
reconst_flow = flow()
for sh_order in [4, 6, 8]:
if flow.get_short_name() == 'csd':

reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
sh_order=sh_order,
out_dir=out_dir, extract_pam_values=True)

elif flow.get_short_name() == 'csa':

reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
sh_order=sh_order,
odf_to_sh_order=sh_order,
out_dir=out_dir, extract_pam_values=True)

gfa_path = reconst_flow.last_generated_outputs['out_gfa']
gfa_data = nib.load(gfa_path).get_data()
assert_equal(gfa_data.shape, volume.shape[:-1])

peaks_dir_path =\
reconst_flow.last_generated_outputs['out_peaks_dir']
peaks_dir_data = nib.load(peaks_dir_path).get_data()
assert_equal(peaks_dir_data.shape[-1], 15)
assert_equal(peaks_dir_data.shape[:-1], volume.shape[:-1])

peaks_idx_path = \
reconst_flow.last_generated_outputs['out_peaks_indices']
peaks_idx_data = nib.load(peaks_idx_path).get_data()
assert_equal(peaks_idx_data.shape[-1], 5)
assert_equal(peaks_idx_data.shape[:-1], volume.shape[:-1])

peaks_vals_path = \
reconst_flow.last_generated_outputs['out_peaks_values']
peaks_vals_data = nib.load(peaks_vals_path).get_data()
assert_equal(peaks_vals_data.shape[-1], 5)
assert_equal(peaks_vals_data.shape[:-1], volume.shape[:-1])

shm_path = reconst_flow.last_generated_outputs['out_shm']
shm_data = nib.load(shm_path).get_data()
# Test that the number of coefficients is what you would expect
# given the order of the sh basis:
assert_equal(shm_data.shape[-1],
sph_harm_ind_list(sh_order)[0].shape[0])
assert_equal(shm_data.shape[:-1], volume.shape[:-1])

pam = load_peaks(reconst_flow.last_generated_outputs['out_pam'])
npt.assert_allclose(pam.peak_dirs.reshape(peaks_dir_data.shape),
peaks_dir_data)
npt.assert_allclose(pam.peak_values, peaks_vals_data)
npt.assert_allclose(pam.peak_indices, peaks_idx_data)
npt.assert_allclose(pam.shm_coeff, shm_data)
npt.assert_allclose(pam.gfa, gfa_data)

bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path)
bvals[0] = 5.
bvecs = generate_bvecs(len(bvals))

tmp_bval_path = pjoin(out_dir, "tmp.bval")
tmp_bvec_path = pjoin(out_dir, "tmp.bvec")
np.savetxt(tmp_bval_path, bvals)
np.savetxt(tmp_bvec_path, bvecs.T)
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=None)
reconst_flow2 = flow()
reconst_flow2._force_overwrite = True
reconst_flow2.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=None,
roi_center=[10, 10, 10])
with npt.assert_raises(BaseException):
npt.assert_warns(UserWarning, reconst_flow.run, data_path,
tmp_bval_path, tmp_bvec_path, mask_path,
out_dir=out_dir, extract_pam_values=True)

if flow.get_short_name() == 'csd':

reconst_flow = flow()
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=[15, 5, 5])
reconst_flow = flow()
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf='15, 5, 5')
reconst_flow = flow()
reconst_flow._force_overwrite = True
reconst_flow.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=None)
reconst_flow2 = flow()
reconst_flow2._force_overwrite = True
reconst_flow2.run(data_path, bval_path, bvec_path, mask_path,
out_dir=out_dir, frf=None,
roi_center=[10, 10, 10])


if __name__ == '__main__':
Expand Down

0 comments on commit 72f9516

Please sign in to comment.