Skip to content

Commit

Permalink
Define bucket.__iter__ (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbayles committed Jan 11, 2020
1 parent 39306e8 commit 3150ad2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
12 changes: 11 additions & 1 deletion more_itertools/more.py
Expand Up @@ -772,7 +772,9 @@ class bucket:
child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
>>> s = bucket(iterable, key=lambda x: x[0])
>>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
>>> sorted(list(s)) # Get the keys
['a', 'b', 'c']
>>> a_iterable = s['a']
>>> next(a_iterable)
'a1'
Expand Down Expand Up @@ -846,6 +848,14 @@ def _get_values(self, value):
elif self._validator(item_value):
self._cache[item_value].append(item)

def __iter__(self):
for item in self._it:
item_value = self._key(item)
if self._validator(item_value):
self._cache[item_value].append(item)

yield from self._cache.keys()

def __getitem__(self, value):
if not self._validator(value):
return iter(())
Expand Down
20 changes: 18 additions & 2 deletions tests/test_more.py
Expand Up @@ -822,8 +822,6 @@ def test_reverse(self):


class BucketTests(TestCase):
"""Tests for ``bucket()``"""

def test_basic(self):
iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
Expand Down Expand Up @@ -859,6 +857,24 @@ def test_validator(self):
self.assertNotIn(0, D._cache) # Don't store non-valid entries
self.assertEqual(list(D[0]), [])

def test_list(self):
iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
D = mi.bucket(iterable, key=lambda x: 10 * (x // 10))
self.assertEqual(list(D[10]), [10, 11, 12])
self.assertEqual(list(D[20]), [20, 21, 22, 23])
self.assertEqual(list(D[30]), [30, 31, 33])
self.assertEqual(set(D), {10, 20, 30})

def test_list_validator(self):
iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33]
key = lambda x: 10 * (x // 10)
validator = lambda x: x != 20
D = mi.bucket(iterable, key, validator=validator)
self.assertEqual(set(D), {10, 30})
self.assertEqual(list(D[10]), [10, 11, 12])
self.assertEqual(list(D[20]), [])
self.assertEqual(list(D[30]), [30, 31, 33])


class SpyTests(TestCase):
"""Tests for ``spy()``"""
Expand Down

0 comments on commit 3150ad2

Please sign in to comment.