Skip to content

Commit

Permalink
Merge pull request #167 from erikrose/bucket-validator
Browse files Browse the repository at this point in the history
Add validator parameter to bucket()
  • Loading branch information
bbayles committed Nov 24, 2017
2 parents e08c285 + 5cdc4a8 commit f7a1817
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
29 changes: 24 additions & 5 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ class bucket(object):
child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
>>> s = bucket(iterable, key=lambda s: s[0])
>>> s = bucket(iterable, key=lambda x: x[0])
>>> a_iterable = s['a']
>>> next(a_iterable)
'a1'
Expand All @@ -620,16 +620,32 @@ class bucket(object):
The original iterable will be advanced and its items will be cached until
they are used by the child iterables. This may require significant storage.
Be aware that attempting to select a bucket that no items correspond to
will exhaust the iterable and cache all values.
By default, attempting to select a bucket to which no items belong will
exhaust the iterable and cache all values.
If you specify a *validator* function, selected buckets will instead be
checked against it.
>>> from itertools import count
>>> it = count(1, 2) # Infinite sequence of odd numbers
>>> key = lambda x: x % 10 # Bucket by last digit
>>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
>>> s = bucket(it, key=key, validator=validator)
>>> 2 in s
False
>>> list(s[2])
[]
"""
def __init__(self, iterable, key):
def __init__(self, iterable, key, validator=None):
self._it = iter(iterable)
self._key = key
self._cache = defaultdict(deque)
self._validator = validator or (lambda x: True)

def __contains__(self, value):
if not self._validator(value):
return False

try:
item = next(self[value])
except StopIteration:
Expand Down Expand Up @@ -659,10 +675,13 @@ def _get_values(self, value):
if item_value == value:
yield item
break
else:
elif self._validator(item_value):
self._cache[item_value].append(item)

def __getitem__(self, value):
if not self._validator(value):
return iter(())

return self._get_values(value)


Expand Down
10 changes: 10 additions & 0 deletions more_itertools/tests/test_more.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,16 @@ def test_in(self):
# Checking in-ness shouldn't advance the iterator
self.assertEqual(next(D[10]), 10)

def test_validator(self):
    """Keys rejected by the validator give empty, uncached buckets."""
    def first_digit(number):
        # Bucket each number by its leading decimal digit.
        return int(str(number)[0])

    def is_nonzero_digit(digit):
        # Valid keys lie strictly between 0 and 10 (no leading zeros).
        return 0 < digit < 10

    buckets = mi.bucket(count(0), first_digit, validator=is_nonzero_digit)
    self.assertEqual(mi.take(3, buckets[1]), [1, 10, 11])
    # An invalid key is never reported as present...
    self.assertNotIn(0, buckets)
    # ...and items with an invalid key are not stored in the cache.
    self.assertNotIn(0, buckets._cache)
    self.assertEqual(list(buckets[0]), [])


class SpyTests(TestCase):
"""Tests for ``spy()``"""
Expand Down

0 comments on commit f7a1817

Please sign in to comment.