From 1b3085eafeb8582c8f7a11cabce7d2462ece4044 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 7 Jan 2016 12:14:53 -0500 Subject: [PATCH 1/4] Fixes issue #813 by not checking data type explicitly. --- dipy/segment/cythonutils.pxd | 2 +- dipy/segment/cythonutils.pyx | 2 +- dipy/segment/featurespeed.pyx | 2 +- dipy/segment/metricspeed.pyx | 39 ++++++++++++++++++++++-------- dipy/segment/tests/test_feature.py | 17 +++++++++++++ 5 files changed, 49 insertions(+), 13 deletions(-) diff --git a/dipy/segment/cythonutils.pxd b/dipy/segment/cythonutils.pxd index 86d79216b4..7a98864b50 100644 --- a/dipy/segment/cythonutils.pxd +++ b/dipy/segment/cythonutils.pxd @@ -30,7 +30,7 @@ cdef struct Shape: cdef Shape shape_from_memview(Data data) nogil -cdef Shape tuple2shape(dims) +cdef Shape tuple2shape(dims) except * cdef shape2tuple(Shape shape) diff --git a/dipy/segment/cythonutils.pyx b/dipy/segment/cythonutils.pyx index 6489002125..92994900f6 100644 --- a/dipy/segment/cythonutils.pyx +++ b/dipy/segment/cythonutils.pyx @@ -30,7 +30,7 @@ cdef Shape shape_from_memview(Data data) nogil: return shape -cdef Shape tuple2shape(dims): +cdef Shape tuple2shape(dims) except *: """ Converts a Python's tuple into a `Shape` Cython's struct. Parameters diff --git a/dipy/segment/featurespeed.pyx b/dipy/segment/featurespeed.pyx index 78c91591b4..a0915195d3 100644 --- a/dipy/segment/featurespeed.pyx +++ b/dipy/segment/featurespeed.pyx @@ -41,7 +41,7 @@ cdef class Feature(object): """ Cython version of `Feature.infer_shape`. """ with gil: shape = self.infer_shape(np.asarray(datum)) - if type(shape) is int: + if np.asarray(shape).ndim == 0: return tuple2shape((1, shape)) elif len(shape) == 1: return tuple2shape((1,) + shape) diff --git a/dipy/segment/metricspeed.pyx b/dipy/segment/metricspeed.pyx index 468e393e08..3a4aa2266b 100644 --- a/dipy/segment/metricspeed.pyx +++ b/dipy/segment/metricspeed.pyx @@ -134,12 +134,12 @@ cdef class CythonMetric(Metric): ----- This method calls its Cython version `self.c_are_compatible` accordingly. """ - if type(shape1) is int: + if np.asarray(shape1).ndim == 0: shape1 = (1, shape1) elif len(shape1) == 1: shape1 = (1,) + shape1 - if type(shape2) is int: + if np.asarray(shape2).ndim == 0: shape2 = (1, shape2) elif len(shape2) == 1: shape2 = (1,) + shape2 @@ -165,6 +165,29 @@ cdef class CythonMetric(Metric): ----- This method calls its Cython version `self.c_dist` accordingly. """ + # If needed, we convert features to 2D arrays. + features1 = np.asarray(features1) + if features1.ndim == 0: + features1 = features1[np.newaxis, np.newaxis] + elif features1.ndim == 1: + features1 = features1[np.newaxis] + elif features1.ndim == 2: + pass + else: + raise TypeError("Only scalar, 1D or 2D array features are" + " supported for parameter 'features1'!") + + features2 = np.asarray(features2) + if features2.ndim == 0: + features2 = features2[np.newaxis, np.newaxis] + elif features2.ndim == 1: + features2 = features2[np.newaxis] + elif features2.ndim == 2: + pass + else: + raise TypeError("Only scalar, 1D or 2D array features are" + " supported for parameter 'features2'!") + if not self.are_compatible(features1.shape, features2.shape): raise ValueError("Features are not compatible according to this metric!") @@ -406,18 +429,14 @@ cpdef double dist(Metric metric, datum1, datum2) except -1: double Distance between two data points. """ - shape1 = metric.feature.infer_shape(datum1) - shape2 = metric.feature.infer_shape(datum2) - - if not metric.are_compatible(shape1, shape2): - raise ValueError("Data features' shapes must be compatible!") - datum1 = datum1 if datum1.flags.writeable and datum1.dtype is np.float32 else datum1.astype(np.float32) datum2 = datum2 if datum2.flags.writeable and datum2.dtype is np.float32 else datum2.astype(np.float32) cdef: - Data2D features1 = np.empty(shape1, np.float32) - Data2D features2 = np.empty(shape2, np.float32) + Shape shape1 = metric.feature.c_infer_shape(datum1) + Shape shape2 = metric.feature.c_infer_shape(datum2) + Data2D features1 = np.empty(shape2tuple(shape1), np.float32) + Data2D features2 = np.empty(shape2tuple(shape2), np.float32) metric.feature.c_extract(datum1, features1) metric.feature.c_extract(datum2, features2) diff --git a/dipy/segment/tests/test_feature.py b/dipy/segment/tests/test_feature.py index b8aece1c8e..8093ed1bad 100644 --- a/dipy/segment/tests/test_feature.py +++ b/dipy/segment/tests/test_feature.py @@ -284,6 +284,23 @@ def extract(self, streamline): d2 = metric.dist(features1, features2) assert_equal(d1, d2) + class ArcLengthFeature(dipymetric.Feature): + def infer_shape(self, streamline): + return long(1) + + def extract(self, streamline): + return np.sum(np.sqrt(np.sum((streamline[1:] - streamline[:-1]) ** 2))) + + # Test using Python Feature with Cython Metric + feature = ArcLengthFeature() + metric = dipymetric.EuclideanMetric(feature) + d1 = dipymetric.dist(metric, s1, s2) + + features1 = metric.feature.extract(s1) + features2 = metric.feature.extract(s2) + d2 = metric.dist(features1, features2) + assert_equal(d1, d2) + if __name__ == '__main__': run_module_suite() From 13687569f666d499f84735c97910cd4312ba41b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sat, 9 Jan 2016 10:31:06 -0500 Subject: [PATCH 2/4] Python3 fix. Use np.int64 instead of long. --- dipy/segment/tests/test_feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dipy/segment/tests/test_feature.py b/dipy/segment/tests/test_feature.py index 8093ed1bad..34142d5816 100644 --- a/dipy/segment/tests/test_feature.py +++ b/dipy/segment/tests/test_feature.py @@ -286,7 +286,7 @@ def extract(self, streamline): class ArcLengthFeature(dipymetric.Feature): def infer_shape(self, streamline): - return long(1) + return np.int64(1) def extract(self, streamline): return np.sum(np.sqrt(np.sum((streamline[1:] - streamline[:-1]) ** 2))) From 1dadcd89c111184cfd39399fe52090da38f3b7f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sat, 9 Jan 2016 10:43:42 -0500 Subject: [PATCH 3/4] Python3 fix. Return 1 or long(1) depending of sys.version. --- dipy/segment/tests/test_feature.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dipy/segment/tests/test_feature.py b/dipy/segment/tests/test_feature.py index 34142d5816..392157550e 100644 --- a/dipy/segment/tests/test_feature.py +++ b/dipy/segment/tests/test_feature.py @@ -1,3 +1,4 @@ +import sys import numpy as np import dipy.segment.metric as dipymetric from dipy.segment.featurespeed import extract @@ -286,7 +287,10 @@ def extract(self, streamline): class ArcLengthFeature(dipymetric.Feature): def infer_shape(self, streamline): - return np.int64(1) + if sys.version_info > (3,): + return 1 + + return long(1) def extract(self, streamline): return np.sum(np.sqrt(np.sum((streamline[1:] - streamline[:-1]) ** 2))) From 65ffa1eaa8ab2ac1e9f141eb639b8d963e0370c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 11 Jan 2016 09:43:30 -0500 Subject: [PATCH 4/4] Added a comment explaining the relevance of the new unit test. --- dipy/segment/tests/test_feature.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dipy/segment/tests/test_feature.py b/dipy/segment/tests/test_feature.py index 392157550e..dacc73ac09 100644 --- a/dipy/segment/tests/test_feature.py +++ b/dipy/segment/tests/test_feature.py @@ -285,10 +285,13 @@ def extract(self, streamline): d2 = metric.dist(features1, features2) assert_equal(d1, d2) + # Python 2.7 on Windows 64 bits uses long type instead of int for + # constants integer. We make sure the code is robust to such behaviour + # by explicitly testing it. class ArcLengthFeature(dipymetric.Feature): def infer_shape(self, streamline): if sys.version_info > (3,): - return 1 + return 1 # In Python 3, constant integer are of type long. return long(1)