Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes issue #813 by not checking data type explicitly. #829

Merged
merged 4 commits into from Jan 12, 2016
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion dipy/segment/cythonutils.pxd
Expand Up @@ -30,7 +30,7 @@ cdef struct Shape:
cdef Shape shape_from_memview(Data data) nogil


cdef Shape tuple2shape(dims)
cdef Shape tuple2shape(dims) except *


cdef shape2tuple(Shape shape)
Expand Down
2 changes: 1 addition & 1 deletion dipy/segment/cythonutils.pyx
Expand Up @@ -30,7 +30,7 @@ cdef Shape shape_from_memview(Data data) nogil:
return shape


cdef Shape tuple2shape(dims):
cdef Shape tuple2shape(dims) except *:
""" Converts a Python's tuple into a `Shape` Cython's struct.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion dipy/segment/featurespeed.pyx
Expand Up @@ -41,7 +41,7 @@ cdef class Feature(object):
""" Cython version of `Feature.infer_shape`. """
with gil:
shape = self.infer_shape(np.asarray(datum))
if type(shape) is int:
if np.asarray(shape).ndim == 0:
return tuple2shape((1, shape))
elif len(shape) == 1:
return tuple2shape((1,) + shape)
Expand Down
39 changes: 29 additions & 10 deletions dipy/segment/metricspeed.pyx
Expand Up @@ -134,12 +134,12 @@ cdef class CythonMetric(Metric):
-----
This method calls its Cython version `self.c_are_compatible` accordingly.
"""
if type(shape1) is int:
if np.asarray(shape1).ndim == 0:
shape1 = (1, shape1)
elif len(shape1) == 1:
shape1 = (1,) + shape1

if type(shape2) is int:
if np.asarray(shape2).ndim == 0:
shape2 = (1, shape2)
elif len(shape2) == 1:
shape2 = (1,) + shape2
Expand All @@ -165,6 +165,29 @@ cdef class CythonMetric(Metric):
-----
This method calls its Cython version `self.c_dist` accordingly.
"""
# If needed, we convert features to 2D arrays.
features1 = np.asarray(features1)
if features1.ndim == 0:
features1 = features1[np.newaxis, np.newaxis]
elif features1.ndim == 1:
features1 = features1[np.newaxis]
elif features1.ndim == 2:
pass
else:
raise TypeError("Only scalar, 1D or 2D array features are"
" supported for parameter 'features1'!")

features2 = np.asarray(features2)
if features2.ndim == 0:
features2 = features2[np.newaxis, np.newaxis]
elif features2.ndim == 1:
features2 = features2[np.newaxis]
elif features2.ndim == 2:
pass
else:
raise TypeError("Only scalar, 1D or 2D array features are"
" supported for parameter 'features2'!")

if not self.are_compatible(features1.shape, features2.shape):
raise ValueError("Features are not compatible according to this metric!")

Expand Down Expand Up @@ -406,18 +429,14 @@ cpdef double dist(Metric metric, datum1, datum2) except -1:
double
Distance between two data points.
"""
shape1 = metric.feature.infer_shape(datum1)
shape2 = metric.feature.infer_shape(datum2)

if not metric.are_compatible(shape1, shape2):
raise ValueError("Data features' shapes must be compatible!")

datum1 = datum1 if datum1.flags.writeable and datum1.dtype is np.float32 else datum1.astype(np.float32)
datum2 = datum2 if datum2.flags.writeable and datum2.dtype is np.float32 else datum2.astype(np.float32)

cdef:
Data2D features1 = np.empty(shape1, np.float32)
Data2D features2 = np.empty(shape2, np.float32)
Shape shape1 = metric.feature.c_infer_shape(datum1)
Shape shape2 = metric.feature.c_infer_shape(datum2)
Data2D features1 = np.empty(shape2tuple(shape1), np.float32)
Data2D features2 = np.empty(shape2tuple(shape2), np.float32)

metric.feature.c_extract(datum1, features1)
metric.feature.c_extract(datum2, features2)
Expand Down
21 changes: 21 additions & 0 deletions dipy/segment/tests/test_feature.py
@@ -1,3 +1,4 @@
import sys
import numpy as np
import dipy.segment.metric as dipymetric
from dipy.segment.featurespeed import extract
Expand Down Expand Up @@ -284,6 +285,26 @@ def extract(self, streamline):
d2 = metric.dist(features1, features2)
assert_equal(d1, d2)

class ArcLengthFeature(dipymetric.Feature):
def infer_shape(self, streamline):
if sys.version_info > (3,):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this very ugly. Is there any way of avoiding it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this do it? MarcCote@1368756

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Less ugly but still hard to understand. Could y'all remind me why this is necessary? Isn't the return type dealt with by https://github.com/nipy/dipy/pull/829/files#diff-6102332b6a9436b5c5b18ab9ee552a9eR169 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh - sorry - I see you are trying to test the Windows behavior on non-Windows machines - maybe just a comment to that effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @matthew-brett said I'm trying to test the Windows 64 bits behavior, i.e. where it uses long instead of int for constant integers. I added a comment explaining why we do such test.

return 1

return long(1)

def extract(self, streamline):
return np.sum(np.sqrt(np.sum((streamline[1:] - streamline[:-1]) ** 2)))

# Test using Python Feature with Cython Metric
feature = ArcLengthFeature()
metric = dipymetric.EuclideanMetric(feature)
d1 = dipymetric.dist(metric, s1, s2)

features1 = metric.feature.extract(s1)
features2 = metric.feature.extract(s2)
d2 = metric.dist(features1, features2)
assert_equal(d1, d2)


if __name__ == '__main__':
run_module_suite()