Skip to content

Commit

Permalink
Merge pull request #2478 from jhlegarreta/DealAppropriatelyWithUserWa…
Browse files Browse the repository at this point in the history
…rnings

ENH: Deal appropriately with user warnings
  • Loading branch information
skoudoro committed Nov 6, 2021
2 parents c22e18d + db632c6 commit 7ca2d6c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 22 deletions.
6 changes: 3 additions & 3 deletions dipy/core/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=50, atol=1e-2,

# checking for the correctness of bvals
if b0_threshold < bvals.min():
warn("b0_threshold (value: {0}) is too low, increase your \
b0_threshold. It should be higher than the lowest b0 value \
({1}).".format(b0_threshold, bvals.min()))
warn("b0_threshold (value: {0}) is too low, increase your "
"b0_threshold. It should be higher than the lowest b0 value "
"({1}).".format(b0_threshold, bvals.min()))

bvecs = np.where(np.isnan(bvecs), 0, bvecs)
bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol
Expand Down
2 changes: 1 addition & 1 deletion dipy/reconst/shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def real_sh_descoteaux(sh_order, theta, phi,
return real_sh, m, n


@deprecate_with_version('dipy.reconst.shm.real_sym_sh_mrtix is deprecated, '
@deprecate_with_version('dipy.reconst.shm.real_sym_sh_mrtrix is deprecated, '
'Please use dipy.reconst.shm.real_sh_tournier instead',
since='1.3', until='2.0')
def real_sym_sh_mrtrix(sh_order, theta, phi):
Expand Down
3 changes: 2 additions & 1 deletion dipy/reconst/tests/test_ivim.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def setup_module():
120., 140., 160., 180., 200., 300., 400.,
500., 600., 700., 800., 900., 1000.])

gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T, b0_threshold=0)
with pytest.warns(UserWarning):
gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T, b0_threshold=0)

bvals_with_multiple_b0 = np.array([0., 0., 0., 0., 40., 60., 80., 100.,
120., 140., 160., 180., 200., 300.,
Expand Down
44 changes: 28 additions & 16 deletions dipy/reconst/tests/test_mcsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,20 @@ def test_multi_shell_fiber_response():
def test_mask_for_response_msmt():
gtab, data, masks_gt, _ = get_test_data()

wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)
with warnings.catch_warnings(record=True) as w:
wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)

npt.assert_equal(len(w), 1)
npt.assert_(issubclass(w[0].category, UserWarning))
npt.assert_("""Some b-values are higher than 1200.""" in
str(w[0].message))

# Verifies that masks are not empty:
masks_sum = int(np.sum(wm_mask) + np.sum(gm_mask) + np.sum(csf_mask))
Expand All @@ -214,14 +220,20 @@ def test_mask_for_response_msmt():
def test_mask_for_response_msmt_nvoxels():
gtab, data, _, _ = get_test_data()

wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)
with warnings.catch_warnings(record=True) as w:
wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)

npt.assert_equal(len(w), 1)
npt.assert_(issubclass(w[0].category, UserWarning))
npt.assert_("""Some b-values are higher than 1200.""" in
str(w[0].message))

wm_nvoxels = np.sum(wm_mask)
gm_nvoxels = np.sum(gm_mask)
Expand Down
18 changes: 17 additions & 1 deletion dipy/reconst/tests/test_shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import numpy as np
import numpy.linalg as npl
import numpy.testing as npt

from dipy.testing import assert_true
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
Expand Down Expand Up @@ -117,7 +118,22 @@ def test_gen_dirac():

def test_real_sym_sh_mrtrix():
coef, expected, sphere = mrtrix_spherical_functions()
basis, m, n = real_sym_sh_mrtrix(8, sphere.theta, sphere.phi)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
basis, m, n = real_sym_sh_mrtrix(8, sphere.theta, sphere.phi)

npt.assert_equal(len(w), 2)
npt.assert_(issubclass(w[0].category, DeprecationWarning))
npt.assert_(
"dipy.reconst.shm.real_sym_sh_mrtrix is deprecated, Please use "
"dipy.reconst.shm.real_sh_tournier instead" in str(w[0].message))
npt.assert_(issubclass(w[1].category, PendingDeprecationWarning))
npt.assert_(
"The legacy tournier07 basis is outdated and will be deprecated "
"in a future release of DIPY. Consider using the new tournier07 "
"basis." in str(w[1].message))

func = np.dot(coef, basis.T)
assert_array_almost_equal(func, expected, 4)

Expand Down

0 comments on commit 7ca2d6c

Please sign in to comment.