Skip to content

Commit

Permalink
Cast to float input, reduce test accuracy.
Browse files Browse the repository at this point in the history
  • Loading branch information
arokem committed Jan 20, 2016
1 parent f0049cc commit 8475809
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
10 changes: 7 additions & 3 deletions dipy/tracking/streamline.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,16 +369,20 @@ def _extract_vals(data, streamlines, affine=None, threedvec=False):
isinstance(streamlines, types.GeneratorType)):
vals = []
for sl in streamlines:

if threedvec:
vals.append(list(vfu.interpolate_vector_3d(data, sl)[0]))
vals.append(list(vfu.interpolate_vector_3d(data,
sl.astype(np.float))[0]))
else:
vals.append(list(vfu.interpolate_scalar_3d(data, sl)[0]))
vals.append(list(vfu.interpolate_scalar_3d(data,
sl.astype(np.float))[0]))

elif isinstance(streamlines, np.ndarray):
sl_shape = streamlines.shape
sl_cat = streamlines.reshape(sl_shape[0] * sl_shape[1], 3)
if affine is not None:
sl_cat = np.dot(sl_cat, affine[:3, :3]) + affine[:3, 3]
sl_cat = (np.dot(sl_cat, affine[:3, :3]) +
affine[:3, 3]).astype(np.float)

# So that we can index in one operation:
if threedvec:
Expand Down
13 changes: 7 additions & 6 deletions dipy/tracking/tests/test_streamline.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def test_orient_by_rois():


def test_values_from_volume():
decimal = 4
data3d = np.arange(2000).reshape(20, 10, 10)
# Test two cases of 4D data (handled differently)
# One where the last dimension is length 3:
Expand Down Expand Up @@ -830,33 +831,33 @@ def test_values_from_volume():
data[4, 0, 0] + (data[5, 0, 0] - data[4, 0, 0]) * 0.1]]

vv = values_from_volume(data, sl1)
npt.assert_almost_equal(vv, ans1)
npt.assert_almost_equal(vv, ans1, decimal=decimal)

vv = values_from_volume(data, np.array(sl1))
npt.assert_almost_equal(vv, ans1)
npt.assert_almost_equal(vv, ans1, decimal=decimal)

affine = np.eye(4)
affine[:, 3] = [-100, 10, 1, 1]
x_sl1 = ut.move_streamlines(sl1, affine)

vv = values_from_volume(data, x_sl1, affine=affine)
npt.assert_almost_equal(vv, ans1)
npt.assert_almost_equal(vv, ans1, decimal=decimal)

# The generator has already been consumed so needs to be
# regenerated:
x_sl1 = list(ut.move_streamlines(sl1, affine))
vv = values_from_volume(data, x_sl1, affine=affine)
npt.assert_almost_equal(vv, ans1)
npt.assert_almost_equal(vv, ans1, decimal=decimal)

vv = values_from_volume(data, np.array(x_sl1), affine=affine)
npt.assert_almost_equal(vv, ans1)
npt.assert_almost_equal(vv, ans1, decimal=decimal)

# Test for lists of streamlines with different numbers of nodes:
sl2 = [sl1[0][:-1], sl1[1]]
ans2 = [ans1[0][:-1], ans1[1]]
vv = values_from_volume(data, sl2)
for ii, v in enumerate(vv):
npt.assert_almost_equal(v, ans2[ii])
npt.assert_almost_equal(v, ans2[ii], decimal=decimal)

# We raise an error if the streamlines fed don't make sense. In this
# case, a tuple instead of a list, generator or array
Expand Down

0 comments on commit 8475809

Please sign in to comment.