diff --git a/dipy/align/vector_fields.pyx b/dipy/align/vector_fields.pyx
index 346ec197fa..974310b5d2 100644
--- a/dipy/align/vector_fields.pyx
+++ b/dipy/align/vector_fields.pyx
@@ -431,7 +431,7 @@ cdef inline int _interpolate_scalar_nn_3d(number[:, :, :] volume, double dkk,
     return 1


-def interpolate_scalar_3d(floating[:, :, :] image, double[:, :] locations):
+def interpolate_scalar_3d(floating[:, :, :] image, locations):
     r"""Trilinear interpolation of a 3D scalar image

     Interpolates the 3D image at the given locations. This function is
@@ -461,10 +461,11 @@ def interpolate_scalar_3d(floating[:, :, :] image, double[:, :] locations):
         cnp.npy_intp n = locations.shape[0]
         floating[:] out = np.zeros(shape=(n,), dtype=ftype)
         int[:] inside = np.empty(shape=(n,), dtype=np.int32)
+        double[:,:] _locations = np.array(locations, dtype=np.float64)
     with nogil:
         for i in range(n):
             inside[i] = _interpolate_scalar_3d[floating](image,
-                locations[i, 0], locations[i, 1], locations[i, 2], &out[i])
+                _locations[i, 0], _locations[i, 1], _locations[i, 2], &out[i])

     return np.asarray(out), np.asarray(inside)

diff --git a/dipy/tracking/streamline.py b/dipy/tracking/streamline.py
index b13510ef74..50b728527c 100644
--- a/dipy/tracking/streamline.py
+++ b/dipy/tracking/streamline.py
@@ -1,15 +1,19 @@
+from copy import deepcopy
 from warnings import warn
+import types

+from scipy.spatial.distance import cdist
 import numpy as np
 from nibabel.affines import apply_affine
+
 from dipy.tracking.streamlinespeed import set_number_of_points
 from dipy.tracking.streamlinespeed import length
 from dipy.tracking.streamlinespeed import compress_streamlines
 import dipy.tracking.utils as ut
 from dipy.tracking.utils import streamline_near_roi
 from dipy.core.geometry import dist_to_corner
-from scipy.spatial.distance import cdist
-from copy import deepcopy
+import dipy.align.vector_fields as vfu
+

 def unlist_streamlines(streamlines):
     """ Return the streamlines not as a list but as an array and an offset
@@ -323,3 +327,138 @@ def orient_by_rois(streamlines, roi1, roi2, affine=None, copy=True):
             new_sl[idx] = sl[::-1]

     return new_sl
+
+
+def _extract_vals(data, streamlines, affine=None, threedvec=False):
+    """
+    Helper function for use with `values_from_volume`.
+
+    Parameters
+    ----------
+    data : 3D or 4D array
+        Scalar (for 3D) and vector (for 4D) values to be extracted. For 4D
+        data, interpolation will be done on the 3 spatial dimensions in each
+        volume.
+
+    streamlines : ndarray or list
+        If array, of shape (n_streamlines, n_nodes, 3)
+        If list, len(n_streamlines) with (n_nodes, 3) array in
+        each element of the list.
+
+    affine : ndarray, shape (4, 4)
+        Affine transformation from voxels (image coordinates) to streamlines.
+        Default: identity.
+
+    threedvec : bool
+        Whether the last dimension has length 3. This is a special case in
+        which we can use :func:`vfu.interpolate_vector_3d` for the
+        interpolation of 4D volumes without looping over the elements of the
+        last dimension.
+
+    Returns
+    -------
+    array or list (depending on the input) : values interpolated at each
+    coordinate along the length of each streamline
+    """
+    data = data.astype(np.float)
+    if (isinstance(streamlines, list) or
+            isinstance(streamlines, types.GeneratorType)):
+        if affine is not None:
+            streamlines = ut.move_streamlines(streamlines,
+                                              np.linalg.inv(affine))
+
+        vals = []
+        for sl in streamlines:
+            if threedvec:
+                vals.append(list(vfu.interpolate_vector_3d(data,
+                                 sl.astype(np.float))[0]))
+            else:
+                vals.append(list(vfu.interpolate_scalar_3d(data,
+                                 sl.astype(np.float))[0]))
+
+    elif isinstance(streamlines, np.ndarray):
+        sl_shape = streamlines.shape
+        sl_cat = streamlines.reshape(sl_shape[0] *
+                                     sl_shape[1], 3).astype(np.float)
+
+        if affine is not None:
+            inv_affine = np.linalg.inv(affine)
+            sl_cat = (np.dot(sl_cat, inv_affine[:3, :3]) +
+                      inv_affine[:3, 3])
+
+        # So that we can index in one operation:
+        if threedvec:
+            vals = np.array(vfu.interpolate_vector_3d(data, sl_cat)[0])
+        else:
+            vals = np.array(vfu.interpolate_scalar_3d(data, sl_cat)[0])
+        vals = np.reshape(vals, (sl_shape[0], sl_shape[1], -1)).squeeze()
+
+    else:
+        raise RuntimeError("Extracting values from a volume "
+                           "requires streamlines input as an array, "
+                           "a list of arrays, or a streamline generator.")
+
+    return vals
+
+
+def values_from_volume(data, streamlines, affine=None):
+    """Extract values of a scalar/vector along each streamline from a volume.
+
+    Parameters
+    ----------
+    data : 3D or 4D array
+        Scalar (for 3D) and vector (for 4D) values to be extracted. For 4D
+        data, interpolation will be done on the 3 spatial dimensions in each
+        volume.
+
+    streamlines : ndarray or list
+        If array, of shape (n_streamlines, n_nodes, 3)
+        If list, len(n_streamlines) with (n_nodes, 3) array in
+        each element of the list.
+
+    affine : ndarray, shape (4, 4)
+        Affine transformation from voxels (image coordinates) to streamlines.
+        Default: identity. For example, if no affine is provided and the first
+        coordinate of the first streamline is ``[1, 0, 0]``, data[1, 0, 0]
+        would be returned as the value for that streamline coordinate.
+
+    Returns
+    -------
+    array or list (depending on the input) : values interpolated at each
+    coordinate along the length of each streamline.
+
+    Notes
+    -----
+    Values are extracted from the image based on the 3D coordinates of the
+    nodes that comprise the streamline, without any interpolation into
+    segments between the nodes. Using this function with streamlines that
+    have been resampled into a very small number of nodes will result in
+    very few values.
+    """
+    data = np.asarray(data)
+    if len(data.shape) == 4:
+        if data.shape[-1] == 3:
+            return _extract_vals(data, streamlines, affine=affine,
+                                 threedvec=True)
+        if isinstance(streamlines, types.GeneratorType):
+            streamlines = list(streamlines)
+        vals = []
+        for ii in range(data.shape[-1]):
+            vals.append(_extract_vals(data[..., ii], streamlines,
+                                      affine=affine))
+
+        if isinstance(vals[-1], np.ndarray):
+            return np.swapaxes(np.array(vals), 2, 1).T
+        else:
+            new_vals = []
+            for sl_idx in range(len(streamlines)):
+                sl_vals = []
+                for ii in range(data.shape[-1]):
+                    sl_vals.append(vals[ii][sl_idx])
+                new_vals.append(np.array(sl_vals).T)
+            return new_vals
+
+    elif len(data.shape) == 3:
+        return _extract_vals(data, streamlines, affine=affine)
+    else:
+        raise ValueError("Data needs to have 3 or 4 dimensions")
diff --git a/dipy/tracking/tests/test_streamline.py b/dipy/tracking/tests/test_streamline.py
index 1a76767165..fd7c407e88 100644
--- a/dipy/tracking/tests/test_streamline.py
+++ b/dipy/tracking/tests/test_streamline.py
@@ -9,6 +9,7 @@ from numpy.testing import (assert_array_equal,
                            assert_array_almost_equal,
                            assert_raises,
                            run_module_suite)
+import dipy.tracking.utils as ut
 from dipy.tracking.streamline import (set_number_of_points,
                                       length as ds_length,
                                       relist_streamlines,
@@ -18,7 +19,8 @@
                                       select_random_set_of_streamlines,
                                       compress_streamlines,
                                       select_by_rois,
-                                      orient_by_rois)
+                                      orient_by_rois,
+                                      values_from_volume)


 streamline = np.array([[82.20181274, 91.36505890, 43.15737152],
@@ -800,5 +802,83 @@ def test_orient_by_rois():
     npt.assert_equal(new_streamlines, flipped_sl)


+def test_values_from_volume():
+    decimal = 4
+    data3d = np.arange(2000).reshape(20, 10, 10)
+    # Test two cases of 4D data (handled differently)
+    # One where the last dimension is length 3:
+    data4d_3vec = np.arange(6000).reshape(20, 10, 10, 3)
+    # The other where the last dimension is not 3:
+    data4d_2vec = np.arange(4000).reshape(20, 10, 10, 2)
+    for dt in [np.float32, np.float64]:
+        for data in [data3d, data4d_3vec, data4d_2vec]:
+            sl1 = [np.array([[1, 0, 0],
+                             [1.5, 0, 0],
+                             [2, 0, 0],
+                             [2.5, 0, 0]]).astype(dt),
+                   np.array([[2, 0, 0],
+                             [3.1, 0, 0],
+                             [3.9, 0, 0],
+                             [4.1, 0, 0]]).astype(dt)]
+
+            ans1 = [[data[1, 0, 0],
+                     data[1, 0, 0] + (data[2, 0, 0] - data[1, 0, 0]) / 2,
+                     data[2, 0, 0],
+                     data[2, 0, 0] + (data[3, 0, 0] - data[2, 0, 0]) / 2],
+                    [data[2, 0, 0],
+                     data[3, 0, 0] + (data[4, 0, 0] - data[3, 0, 0]) * 0.1,
+                     data[3, 0, 0] + (data[4, 0, 0] - data[3, 0, 0]) * 0.9,
+                     data[4, 0, 0] + (data[5, 0, 0] - data[4, 0, 0]) * 0.1]]
+
+            vv = values_from_volume(data, sl1)
+            npt.assert_almost_equal(vv, ans1, decimal=decimal)
+
+            vv = values_from_volume(data, np.array(sl1))
+            npt.assert_almost_equal(vv, ans1, decimal=decimal)
+
+            affine = np.eye(4)
+            affine[:, 3] = [-100, 10, 1, 1]
+            x_sl1 = ut.move_streamlines(sl1, affine)
+            x_sl2 = ut.move_streamlines(sl1, affine)
+
+            vv = values_from_volume(data, x_sl1, affine=affine)
+            npt.assert_almost_equal(vv, ans1, decimal=decimal)
+
+            # The generator has already been consumed so needs to be
+            # regenerated:
+            x_sl1 = list(ut.move_streamlines(sl1, affine))
+            vv = values_from_volume(data, x_sl1, affine=affine)
+            npt.assert_almost_equal(vv, ans1, decimal=decimal)
+
+            # Test that the streamlines haven't mutated:
+            l_sl2 = list(x_sl2)
+            npt.assert_equal(x_sl1, l_sl2)
+
+            vv = values_from_volume(data, np.array(x_sl1), affine=affine)
+            npt.assert_almost_equal(vv, ans1, decimal=decimal)
+            npt.assert_equal(np.array(x_sl1), np.array(l_sl2))
+
+            # Test for lists of streamlines with different numbers of nodes:
+            sl2 = [sl1[0][:-1], sl1[1]]
+            ans2 = [ans1[0][:-1], ans1[1]]
+            vv = values_from_volume(data, sl2)
+            for ii, v in enumerate(vv):
+                npt.assert_almost_equal(v, ans2[ii], decimal=decimal)
+
+    # We raise an error if the streamlines provided don't make sense. In this
+    # case, a tuple instead of a list, generator, or array.
+    nonsense_sl = (np.array([[1, 0, 0],
+                             [1.5, 0, 0],
+                             [2, 0, 0],
+                             [2.5, 0, 0]]),
+                   np.array([[2, 0, 0],
+                             [3.1, 0, 0],
+                             [3.9, 0, 0],
+                             [4.1, 0, 0]]))
+
+    npt.assert_raises(RuntimeError, values_from_volume, data, nonsense_sl)
+
+
 if __name__ == '__main__':
     run_module_suite()
diff --git a/dipy/tracking/tests/test_utils.py b/dipy/tracking/tests/test_utils.py
index f23634e7da..504eff622a 100644
--- a/dipy/tracking/tests/test_utils.py
+++ b/dipy/tracking/tests/test_utils.py
@@ -4,6 +4,8 @@

 import numpy as np
 import nose
+import nibabel as nib
+
 from dipy.io.bvectxt import orientation_from_string
 from dipy.tracking.utils import (affine_for_trackvis, connectivity_matrix,
                                  density_map, length, move_streamlines,
diff --git a/dipy/tracking/utils.py b/dipy/tracking/utils.py
index 8bdbf5f08d..e2bc1bae8b 100644
--- a/dipy/tracking/utils.py
+++ b/dipy/tracking/utils.py
@@ -62,7 +62,7 @@
 import numpy as np
 from numpy import (asarray, ceil, dot, empty, eye, sqrt)
 from dipy.io.bvectxt import ornt_mapping
-from . import metrics
+from dipy.tracking import metrics

 # Import helper functions shared with vox2track
 from ._utils import (_mapping_to_voxel, _to_voxel_coordinates)
diff --git a/doc/examples/probabilistic_fiber_tracking.py b/doc/examples/probabilistic_fiber_tracking.py
index f4a096394c..2032f1da0d 100644
--- a/doc/examples/probabilistic_fiber_tracking.py
+++ b/doc/examples/probabilistic_fiber_tracking.py
@@ -106,4 +106,3 @@
 streamlines = LocalTracking(prob_dg, classifier, seeds, affine, step_size=.5)
 save_trk("probabilistic_peaks_from_model.trk", streamlines, affine,
          labels.shape)
-
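
Usage note (not part of the diff): `values_from_volume` is the new public entry point added above; it samples a volume at every streamline coordinate by trilinear interpolation. A minimal sketch of how it might be called, based on the docstring and tests in this patch; the `fa`, `vecs`, and `streamlines` arrays below are made up for illustration, not real diffusion data:

    import numpy as np
    from dipy.tracking.streamline import values_from_volume

    # Hypothetical scalar volume (e.g., an FA map); any 3D array works.
    fa = np.random.rand(20, 10, 10)

    # Two streamlines with different numbers of nodes, in voxel coordinates
    # (no affine is passed, so the identity transform is assumed).
    streamlines = [np.array([[1.0, 0.0, 0.0], [1.5, 0.0, 0.0], [2.0, 0.0, 0.0]]),
                   np.array([[2.0, 0.0, 0.0], [3.1, 0.0, 0.0]])]

    # Returns one list of interpolated values per streamline, one per node.
    vals = values_from_volume(fa, streamlines)

    # For a 4D volume whose last dimension is 3 (e.g., principal directions),
    # each node gets a length-3 vector instead of a scalar.
    vecs = np.random.rand(20, 10, 10, 3)
    vec_vals = values_from_volume(vecs, streamlines)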