Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Deal appropriately with user warnings #2478

Merged
merged 3 commits into from
Nov 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions dipy/core/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=50, atol=1e-2,

# checking for the correctness of bvals
if b0_threshold < bvals.min():
warn("b0_threshold (value: {0}) is too low, increase your \
b0_threshold. It should be higher than the lowest b0 value \
({1}).".format(b0_threshold, bvals.min()))
warn("b0_threshold (value: {0}) is too low, increase your "
"b0_threshold. It should be higher than the lowest b0 value "
"({1}).".format(b0_threshold, bvals.min()))

bvecs = np.where(np.isnan(bvecs), 0, bvecs)
bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol
Expand Down
2 changes: 1 addition & 1 deletion dipy/reconst/shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def real_sh_descoteaux(sh_order, theta, phi,
return real_sh, m, n


@deprecate_with_version('dipy.reconst.shm.real_sym_sh_mrtix is deprecated, '
@deprecate_with_version('dipy.reconst.shm.real_sym_sh_mrtrix is deprecated, '
'Please use dipy.reconst.shm.real_sh_tournier instead',
since='1.3', until='2.0')
def real_sym_sh_mrtrix(sh_order, theta, phi):
Expand Down
3 changes: 2 additions & 1 deletion dipy/reconst/tests/test_ivim.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def setup_module():
120., 140., 160., 180., 200., 300., 400.,
500., 600., 700., 800., 900., 1000.])

gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T, b0_threshold=0)
with pytest.warns(UserWarning):
gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T, b0_threshold=0)

bvals_with_multiple_b0 = np.array([0., 0., 0., 0., 40., 60., 80., 100.,
120., 140., 160., 180., 200., 300.,
Expand Down
44 changes: 28 additions & 16 deletions dipy/reconst/tests/test_mcsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,20 @@ def test_multi_shell_fiber_response():
def test_mask_for_response_msmt():
gtab, data, masks_gt, _ = get_test_data()

wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)
with warnings.catch_warnings(record=True) as w:
wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)

npt.assert_equal(len(w), 1)
npt.assert_(issubclass(w[0].category, UserWarning))
npt.assert_("""Some b-values are higher than 1200.""" in
str(w[0].message))

# Verifies that masks are not empty:
masks_sum = int(np.sum(wm_mask) + np.sum(gm_mask) + np.sum(csf_mask))
Expand All @@ -214,14 +220,20 @@ def test_mask_for_response_msmt():
def test_mask_for_response_msmt_nvoxels():
gtab, data, _, _ = get_test_data()

wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)
with warnings.catch_warnings(record=True) as w:
wm_mask, gm_mask, csf_mask = mask_for_response_msmt(gtab, data,
roi_center=None,
roi_radii=(1, 1, 0),
wm_fa_thr=0.7,
gm_fa_thr=0.3,
csf_fa_thr=0.15,
gm_md_thr=0.001,
csf_md_thr=0.0032)

npt.assert_equal(len(w), 1)
npt.assert_(issubclass(w[0].category, UserWarning))
npt.assert_("""Some b-values are higher than 1200.""" in
str(w[0].message))

wm_nvoxels = np.sum(wm_mask)
gm_nvoxels = np.sum(gm_mask)
Expand Down
18 changes: 17 additions & 1 deletion dipy/reconst/tests/test_shm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import numpy as np
import numpy.linalg as npl
import numpy.testing as npt

from dipy.testing import assert_true
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
Expand Down Expand Up @@ -117,7 +118,22 @@ def test_gen_dirac():

def test_real_sym_sh_mrtrix():
coef, expected, sphere = mrtrix_spherical_functions()
basis, m, n = real_sym_sh_mrtrix(8, sphere.theta, sphere.phi)

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
basis, m, n = real_sym_sh_mrtrix(8, sphere.theta, sphere.phi)

npt.assert_equal(len(w), 2)
npt.assert_(issubclass(w[0].category, DeprecationWarning))
npt.assert_(
"dipy.reconst.shm.real_sym_sh_mrtrix is deprecated, Please use "
"dipy.reconst.shm.real_sh_tournier instead" in str(w[0].message))
npt.assert_(issubclass(w[1].category, PendingDeprecationWarning))
npt.assert_(
"The legacy tournier07 basis is outdated and will be deprecated "
"in a future release of DIPY. Consider using the new tournier07 "
"basis." in str(w[1].message))

func = np.dot(coef, basis.T)
assert_array_almost_equal(func, expected, 4)

Expand Down