Skip to content

Commit

Permalink
Merge 281f997 into 61e5bd8
Browse files Browse the repository at this point in the history
  • Loading branch information
dwf committed Mar 16, 2015
2 parents 61e5bd8 + 281f997 commit 7483e93
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ before_install:
- conda update -q --yes conda
install:
# Install all Python dependencies
- conda install -q --yes python=$TRAVIS_PYTHON_VERSION pip coverage six toolz nose
- conda install -q --yes python=$TRAVIS_PYTHON_VERSION pip coverage six toolz nose numpy
- pip install -q nose2[coverage-plugin] coveralls
script:
- rm -f .coverage
Expand Down
10 changes: 10 additions & 0 deletions picklable_itertools/iter_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import io

import six
try:
import numpy
NUMPY_AVAILABLE = True
except ImportError:
numpy = None
NUMPY_AVAILABLE = False


from .base import BaseItertool


Expand All @@ -26,6 +34,8 @@ def iter_(obj):
return ordered_sequence_iterator(obj)
if isinstance(obj, xrange): # noqa
return range_iterator(obj)
if NUMPY_AVAILABLE and isinstance(obj, numpy.ndarray):
return ordered_sequence_iterator(obj)
if six.PY3 and isinstance(obj, dict_view):
return ordered_sequence_iterator(list(obj))
return iter(obj)
Expand Down
40 changes: 37 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
groupby, permutations, combinations, combinations_with_replacement,
xrange as _xrange
)
from picklable_itertools.iter_dispatch import numpy, NUMPY_AVAILABLE

_map = map if six.PY3 else itertools.imap
_zip = zip if six.PY3 else itertools.izip
_zip_longest = itertools.zip_longest if six.PY3 else itertools.izip_longest
Expand All @@ -29,7 +31,15 @@ def _identity(x):
return x


def verify_same(picklable_version, reference_version, n, *args, **kwargs):
def safe_assert_equal(expected_val, actual_val):
if NUMPY_AVAILABLE and (isinstance(expected_val, numpy.ndarray) or
isinstance(actual_val, numpy.ndarray)):
assert (expected_val == actual_val).all()
else:
assert expected_val == actual_val


def verify_same(picklable_version, reference_version, n, *args, **kwargs):
"""Take a reference version from itertools, verify the same operation
in our version.
"""
Expand All @@ -52,7 +62,7 @@ def verify_same(picklable_version, reference_version, n, *args, **kwargs):
except StopIteration:
assert False, "prematurely exhausted; expected {}".format(
str(expected_val))
assert expected_val == actual_val
safe_assert_equal(expected_val, actual_val)
done += 1


Expand All @@ -66,12 +76,19 @@ def verify_pickle(picklable_version, reference_version, n, m, *args, **kwargs):
while done != n:
expected_val = next(expected)
actual_val = next(actual)
assert expected_val == actual_val
safe_assert_equal(expected_val, actual_val)
if done == m:
actual = cPickle.loads(cPickle.dumps(actual))
done += 1


def conditional_run(condition, f, *args, **kwargs):
if condition:
f(*args, **kwargs)
else:
raise SkipTest


def check_stops(it):
"""Verify that an exhausted iterator yields StopIteration."""
try:
Expand All @@ -88,6 +105,23 @@ def test_ordered_sequence_iterator():
yield verify_same, ordered_sequence_iterator, iter, None, ("D", "X", "J")
yield verify_pickle, ordered_sequence_iterator, iter, 4, 3, [2, 9, 3, 4]
yield verify_pickle, ordered_sequence_iterator, iter, 3, 2, ['a', 'c', 'b']
array = numpy.array if NUMPY_AVAILABLE else list
numpy_pickle_test = partial(conditional_run, NUMPY_AVAILABLE,
verify_pickle)
numpy_same_test = partial(conditional_run, NUMPY_AVAILABLE, verify_same)
yield (numpy_same_test, ordered_sequence_iterator, iter, None,
array([4, 3, 9]))
yield (numpy_same_test, ordered_sequence_iterator, iter, None,
array([[4, 3, 9], [2, 9, 6]]))
yield (numpy_pickle_test, ordered_sequence_iterator, iter, 4, 3,
array([2, 9, 3, 4]))
yield (numpy_pickle_test, ordered_sequence_iterator, iter, 3, 2,
array([[2, 1], [2, 9], [9, 4], [3, 9]]))
# Make sure the range iterator is actually getting dispatched by iter_.
yield (numpy_pickle_test, iter_, iter, 4, 3,
array([2, 9, 3, 4]))
yield (numpy_pickle_test, iter_, iter, 3, 2,
array([[2, 1], [2, 9], [9, 4], [3, 9]]))


def test_dict_iterator():
Expand Down

0 comments on commit 7483e93

Please sign in to comment.