From 825eb0f08bcfa549ce83abeb567e0862b23286e1 Mon Sep 17 00:00:00 2001 From: David Warde-Farley Date: Mon, 16 Mar 2015 18:26:36 -0400 Subject: [PATCH 1/2] iter_: support NumPy arrays --- picklable_itertools/iter_dispatch.py | 10 +++++++ tests/__init__.py | 40 +++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/picklable_itertools/iter_dispatch.py b/picklable_itertools/iter_dispatch.py index 12a5b08..17de62d 100644 --- a/picklable_itertools/iter_dispatch.py +++ b/picklable_itertools/iter_dispatch.py @@ -1,6 +1,14 @@ import io import six +try: + import numpy + NUMPY_AVAILABLE = True +except ImportError: + numpy = None + NUMPY_AVAILABLE = False + + from .base import BaseItertool @@ -26,6 +34,8 @@ def iter_(obj): return ordered_sequence_iterator(obj) if isinstance(obj, xrange): # noqa return range_iterator(obj) + if NUMPY_AVAILABLE and isinstance(obj, numpy.ndarray): + return ordered_sequence_iterator(obj) if six.PY3 and isinstance(obj, dict_view): return ordered_sequence_iterator(list(obj)) return iter(obj) diff --git a/tests/__init__.py b/tests/__init__.py index 3d39df5..4a1c33e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -17,6 +17,8 @@ groupby, permutations, combinations, combinations_with_replacement, xrange as _xrange ) +from picklable_itertools.iter_dispatch import numpy, NUMPY_AVAILABLE + _map = map if six.PY3 else itertools.imap _zip = zip if six.PY3 else itertools.izip _zip_longest = itertools.zip_longest if six.PY3 else itertools.izip_longest @@ -29,7 +31,15 @@ def _identity(x): return x -def verify_same(picklable_version, reference_version, n, *args, **kwargs): +def safe_assert_equal(expected_val, actual_val): + if NUMPY_AVAILABLE and (isinstance(expected_val, numpy.ndarray) or + isinstance(actual_val, numpy.ndarray)): + assert (expected_val == actual_val).all() + else: + assert expected_val == actual_val + + +def verify_same(picklable_version, reference_version, n, *args, **kwargs): """Take a reference version from itertools, verify the same operation in our version. """ @@ -52,7 +62,7 @@ def verify_same(picklable_version, reference_version, n, *args, **kwargs): except StopIteration: assert False, "prematurely exhausted; expected {}".format( str(expected_val)) - assert expected_val == actual_val + safe_assert_equal(expected_val, actual_val) done += 1 @@ -66,12 +76,19 @@ def verify_pickle(picklable_version, reference_version, n, m, *args, **kwargs): while done != n: expected_val = next(expected) actual_val = next(actual) - assert expected_val == actual_val + safe_assert_equal(expected_val, actual_val) if done == m: actual = cPickle.loads(cPickle.dumps(actual)) done += 1 +def conditional_run(condition, f, *args, **kwargs): + if condition: + f(*args, **kwargs) + else: + raise SkipTest + + def check_stops(it): """Verify that an exhausted iterator yields StopIteration.""" try: @@ -88,6 +105,23 @@ def test_ordered_sequence_iterator(): yield verify_same, ordered_sequence_iterator, iter, None, ("D", "X", "J") yield verify_pickle, ordered_sequence_iterator, iter, 4, 3, [2, 9, 3, 4] yield verify_pickle, ordered_sequence_iterator, iter, 3, 2, ['a', 'c', 'b'] + array = numpy.array if NUMPY_AVAILABLE else list + numpy_pickle_test = partial(conditional_run, NUMPY_AVAILABLE, + verify_pickle) + numpy_same_test = partial(conditional_run, NUMPY_AVAILABLE, verify_same) + yield (numpy_same_test, ordered_sequence_iterator, iter, None, + array([4, 3, 9])) + yield (numpy_same_test, ordered_sequence_iterator, iter, None, + array([[4, 3, 9], [2, 9, 6]])) + yield (numpy_pickle_test, ordered_sequence_iterator, iter, 4, 3, + array([2, 9, 3, 4])) + yield (numpy_pickle_test, ordered_sequence_iterator, iter, 3, 2, + array([[2, 1], [2, 9], [9, 4], [3, 9]])) + # Make sure the range iterator is actually getting dispatched by iter_. + yield (numpy_pickle_test, iter_, iter, 4, 3, + array([2, 9, 3, 4])) + yield (numpy_pickle_test, iter_, iter, 3, 2, + array([[2, 1], [2, 9], [9, 4], [3, 9]])) def test_dict_iterator(): From 281f9972a7b6dddea400cc9fd2730426b47bc17c Mon Sep 17 00:00:00 2001 From: David Warde-Farley Date: Mon, 16 Mar 2015 18:32:32 -0400 Subject: [PATCH 2/2] Add NumPy install to Travis. --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d4ae4e4..57a948c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,7 +19,7 @@ before_install: - conda update -q --yes conda install: # Install all Python dependencies - - conda install -q --yes python=$TRAVIS_PYTHON_VERSION pip coverage six toolz nose + - conda install -q --yes python=$TRAVIS_PYTHON_VERSION pip coverage six toolz nose numpy - pip install -q nose2[coverage-plugin] coveralls script: - rm -f .coverage