diff --git a/docs/api.rst b/docs/api.rst index 3a2c3a08..c0afbe32 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -129,7 +129,8 @@ These tools return summarized or aggregated data from an iterable. .. autofunction:: ilen .. autofunction:: first(iterable[, default]) .. autofunction:: last(iterable[, default]) -.. autofunction:: one +.. autofunction:: one(iterable, too_short=ValueError, too_long=ValueError) +.. autofunction:: only(iterable, default=None, too_long=ValueError) .. autofunction:: unique_to_each .. autofunction:: locate(iterable, pred=bool, window_size=None) .. autofunction:: rlocate(iterable, pred=bool, window_size=None) diff --git a/more_itertools/more.py b/more_itertools/more.py index 45a2fd8b..63796e2f 100644 --- a/more_itertools/more.py +++ b/more_itertools/more.py @@ -54,6 +54,7 @@ 'map_reduce', 'numeric_range', 'one', + 'only', 'padded', 'partitions', 'peekable', @@ -503,9 +504,8 @@ def one(iterable, too_short=None, too_long=None): RuntimeError Note that :func:`one` attempts to advance *iterable* twice to ensure there - is only one item. If there is more than one, both items will be discarded. - See :func:`spy` or :func:`peekable` to check iterable contents less - destructively. + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. """ it = iter(iterable) @@ -2401,3 +2401,35 @@ def time_limited(limit_seconds, iterable): if monotonic() - start_time > limit_seconds: break yield item + + +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than item, raise the exception given by *too_short*, + which is ``ValueError`` by default. + + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1)' + + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + value = next(it, default) + + try: + next(it) + except StopIteration: + pass + else: + raise too_long or ValueError('too many items in iterable (expected 1)') + + return value diff --git a/more_itertools/tests/test_more.py b/more_itertools/tests/test_more.py index 9f3df462..b0d97a86 100644 --- a/more_itertools/tests/test_more.py +++ b/more_itertools/tests/test_more.py @@ -2450,3 +2450,22 @@ def test_zero_limit(self): def test_invalid_limit(self): with self.assertRaises(ValueError): list(mi.time_limited(-0.1, count())) + + +class OnlyTests(TestCase): + def test_defaults(self): + self.assertEqual(mi.only([]), None) + self.assertEqual(mi.only([1]), 1) + self.assertRaises(ValueError, lambda: mi.only([1, 2])) + + def test_custom_value(self): + self.assertEqual(mi.only([], default='!'), '!') + self.assertEqual(mi.only([1], default='!'), 1) + self.assertRaises(ValueError, lambda: mi.only([1, 2], default='!')) + + def test_custom_exception(self): + self.assertEqual(mi.only([], too_long=RuntimeError), None) + self.assertEqual(mi.only([1], too_long=RuntimeError), 1) + self.assertRaises( + RuntimeError, lambda: mi.only([1, 2], too_long=RuntimeError) + )