diff --git a/more_itertools/more.py b/more_itertools/more.py index d9d57a40..e062a844 100644 --- a/more_itertools/more.py +++ b/more_itertools/more.py @@ -420,34 +420,65 @@ def with_iter(context_manager): yield item -def one(iterable): - """Return the only element from the iterable. +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. - Raise ValueError if the iterable is empty or longer than 1 element. For - example, assert that a DB query returns a single, unique result. + :func:`one` is useful for ensuring that an iterable contains only one item. + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. - >>> one(['val']) - 'val' + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: - >>> one(['val', 'other']) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: too many values to unpack (expected 1) + ValueError: too many items in iterable (expected 1)' + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: - >>> one([]) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: not enough values to unpack (expected 1, got 0) + RuntimeError - ``one()`` attempts to advance the iterable twice in order to ensure there - aren't further items. Because this discards any second item, ``one()`` is - not suitable in situations where you want to catch its exception and then - try an alternative treatment of the iterable. It should be used only when a - iterable longer than 1 item is, in fact, an error. + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. If there is more than one, both items will be discarded. + See :func:`spy` or :func:`peekable` to check iterable contents less + destructively. """ - element, = iterable - return element + it = iter(iterable) + + try: + value = next(it) + except StopIteration: + raise too_short or ValueError('too few items in iterable (expected 1)') + + try: + next(it) + except StopIteration: + pass + else: + raise too_long or ValueError('too many items in iterable (expected 1)') + + return value def distinct_permutations(iterable): @@ -1156,7 +1187,6 @@ def always_iterable(obj, base_type=(text_type, binary_type)): >>> list(always_iterable(obj)) [1] - If *obj* is ``None``, return an empty iterable: >>> obj = None @@ -1184,8 +1214,6 @@ def always_iterable(obj, base_type=(text_type, binary_type)): >>> obj = 'foo' >>> list(always_iterable(obj, base_type=None)) ['f', 'o', 'o'] - - """ if obj is None: return iter(()) diff --git a/more_itertools/tests/test_more.py b/more_itertools/tests/test_more.py index ffda1379..d8b3b5ab 100644 --- a/more_itertools/tests/test_more.py +++ b/more_itertools/tests/test_more.py @@ -415,12 +415,22 @@ def test_with_iter(self): class OneTests(TestCase): - def test_one(self): - """Test the ``one()`` cases that aren't covered by its doctests.""" - # Infinite iterables - numbers = count() - self.assertRaises(ValueError, lambda: mi.one(numbers)) # burn 0 and 1 - self.assertEqual(next(numbers), 2) + def test_basic(self): + it = iter(['item']) + self.assertEqual(mi.one(it), 'item') + + def test_too_short(self): + it = iter([]) + self.assertRaises(ValueError, lambda: mi.one(it)) + self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) + + def test_too_long(self): + it = count() + self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 + self.assertEqual(next(it), 2) + self.assertRaises( + OverflowError, lambda: mi.one(it, too_long=OverflowError) + ) class IntersperseTest(TestCase):