Skip to content

Commit

Permalink
Allow one() to throw user-specified exceptions.
Browse files Browse the repository at this point in the history
The use-case for me is that I often want to present a different error
message to the user depending on whether there were too many or too few
items in the iterable.
  • Loading branch information
kalekundert committed Dec 1, 2017
1 parent d3a4674 commit 97df8d3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
29 changes: 22 additions & 7 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,10 @@ def with_iter(context_manager):
yield item


def one(iterable):
def one(iterable, too_short=None, too_long=None):
"""Return the only element from the iterable.
Raise ValueError if the iterable is empty or longer than 1 element. For
Raise an exception if the iterable is empty or longer than 1 element. For
example, assert that a DB query returns a single, unique result.
>>> one(['val'])
Expand All @@ -439,15 +439,33 @@ def one(iterable):
...
ValueError: not enough values to unpack (expected 1, got 0)
By default, ``one()`` will raise a ValueError if the iterable has the wrong
number of elements. However, you can also provide custom exceptions via
the ``too_short`` and ``too_long`` arguments to raise if the iterable is
either too short (i.e. empty) or too long (i.e. more than one element).
``one()`` attempts to advance the iterable twice in order to ensure there
aren't further items. Because this discards any second item, ``one()`` is
not suitable in situations where you want to catch its exception and then
try an alternative treatment of the iterable. It should be used only when a
iterable longer than 1 item is, in fact, an error.
"""
element, = iterable
return element
it = iter(iterable)

try:
value = next(it)
except StopIteration:
raise too_short or ValueError("not enough values to unpack (expected 1, got 0)")

try:
next(it)
except StopIteration:
pass
else:
raise too_long or ValueError("too many values to unpack (expected 1)")

return value


def distinct_permutations(iterable):
Expand Down Expand Up @@ -1156,7 +1174,6 @@ def always_iterable(obj, base_type=(text_type, binary_type)):
>>> list(always_iterable(obj))
[1]
If *obj* is ``None``, return an empty iterable:
>>> obj = None
Expand Down Expand Up @@ -1184,8 +1201,6 @@ def always_iterable(obj, base_type=(text_type, binary_type)):
>>> obj = 'foo'
>>> list(always_iterable(obj, base_type=None))
['f', 'o', 'o']
"""
if obj is None:
return iter(())
Expand Down
6 changes: 6 additions & 0 deletions more_itertools/tests/test_more.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ def test_one(self):
self.assertRaises(ValueError, lambda: mi.one(numbers)) # burn 0 and 1
self.assertEqual(next(numbers), 2)

# Custom exceptions
self.assertRaises(ZeroDivisionError,
lambda: mi.one([], ZeroDivisionError, OverflowError))
self.assertRaises(OverflowError,
lambda: mi.one([1,2], ZeroDivisionError, OverflowError))


class IntersperseTest(TestCase):
""" Tests for intersperse() """
Expand Down

0 comments on commit 97df8d3

Please sign in to comment.