Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rsplit #385

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
259 changes: 196 additions & 63 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from collections import Counter, defaultdict, deque, abc
from collections.abc import Sequence
from functools import partial, wraps
from functools import partial, wraps, reduce
from heapq import merge, heapify, heapreplace, heappop
from itertools import (
chain,
Expand All @@ -19,7 +19,7 @@
)
from math import exp, floor, log
from random import random, randrange, uniform
from operator import itemgetter, sub, gt, lt
from operator import itemgetter, sub, gt, lt, add
from sys import maxsize
from time import monotonic

Expand Down Expand Up @@ -1064,14 +1064,18 @@ def sliced(seq, n):
return takewhile(len, (seq[i : i + n] for i in count(0, n)))


def split_at(iterable, pred, maxsplit=-1):
def split_at(iterable, pred, maxsplit=-1, rsplit=False, keep_separator=False):
"""Yield lists of items from *iterable*, where each list is delimited by
an item where callable *pred* returns ``True``.
The lists do not include the delimiting items:
The lists do not include the delimiting items, unless you set
*keep_separator* to ``True``:

>>> list(split_at('abcdcba', lambda x: x == 'b'))
[['a'], ['c', 'd', 'c'], ['a']]

>>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True))
[['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']]

>>> list(split_at(range(10), lambda n: n % 2 == 1))
[[0], [2], [4], [6], [8], []]

Expand All @@ -1080,27 +1084,117 @@ def split_at(iterable, pred, maxsplit=-1):

>>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2))
[[0], [2], [4, 5, 6, 7, 8, 9]]

If *rsplit* is ``True``, then the splits are made from right to left:

>>> list(split_at(range(10), lambda n: n % 2 == 1,
... maxsplit=2, rsplit=True))
[[0, 1, 2, 3, 4, 5, 6], [8], []]

>>> list(split_at(range(10), lambda n: n % 2 == 1,
... maxsplit=2, rsplit=True, keep_separator=True))
[[0, 1, 2, 3, 4, 5, 6], [7], [8], [9], []]
"""
def _split_at(iterable, pred, maxsplit, keep_separator):
buf = []
it = iter(iterable)
for item in it:
if pred(item):
yield buf
if keep_separator:
yield [item]
if maxsplit == 1:
yield list(it)
return
buf = []
maxsplit -= 1
else:
buf.append(item)
yield buf

if maxsplit == 0:
yield list(iterable)
return

buf = []
it = iter(iterable)
for item in it:
if pred(item):
yield buf
if maxsplit == 1:
yield list(it)
return
buf = []
maxsplit -= 1
if rsplit and maxsplit > 0:
it = _concatenate_slice(_split_at(iterable, pred, maxsplit=-1,
keep_separator=True), -2 * maxsplit)
if keep_separator:
yield from it
else:
buf.append(item)
yield buf
yield from islice(it, None, None, 2)
else:
yield from _split_at(iterable, pred, maxsplit, keep_separator)


def _concatenate_slice(iterable, n=None):
"""Concatenate the *n* first elements of *iterable* and then iterate as
usual:

>>> list(_concatenate_slice(([x] for x in range(10)), 3))
[[0, 1, 2], [3], [4], [5], [6], [7], [8], [9]]

If *n* is absent or `None`, then apply *func* to all elements:

>>> list(_concatenate_slice(([x] for x in range(10))))
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]

If *n* is negative, then *-n* elements are omitted:

>>> list(_concatenate_slice(([x] for x in range(10)), -3))
[[0, 1, 2, 3, 4, 5, 6], [7], [8], [9]]

"""
it = _aggregate_slice(iterable, lambda it: reduce(add, it), n)
try:
yield next(it)
except TypeError: # empty iterator
pass
else:
yield from it


def _aggregate_slice(iterable, func, n=None):
"""Apply *func* to the *n* first elements of *iterable* and then iterate as
usual:

>>> list(_aggregate_slice(range(1, 10), sum, 3))
[6, 4, 5, 6, 7, 8, 9]

If *n* is absent or `None`, then apply *func* to all elements:

>>> list(_aggregate_slice(range(1, 10), sum))
[45]

If *n* is negative, then *-n* elements are omitted:

>>> list(_aggregate_slice(range(1, 10), sum, -3))
[21, 7, 8, 9]

"""
if n is None:
yield func(iterable)
return

if n < 0:
try:
total = len(iterable)
except TypeError:
iterable = tuple(iterable)
total = len(iterable)

if n < -total:
n = 0
else:
n = total + n

it = iter(iterable)
if n != 0:
yield func(islice(it, n))
yield from it


def split_before(iterable, pred, maxsplit=-1):
def split_before(iterable, pred, maxsplit=-1, rsplit=False):
"""Yield lists of items from *iterable*, where each list ends just before
an item for which callable *pred* returns ``True``:

Expand All @@ -1115,26 +1209,39 @@ def split_before(iterable, pred, maxsplit=-1):

>>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2))
[[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]

If *rsplit* is ``True``, then the splits are made from right to left:

>>> list(split_before(range(10), lambda n: n % 3 == 0,
... maxsplit=2, rsplit=True))
[[0, 1, 2, 3, 4, 5], [6, 7, 8], [9]]
"""
def _split_before(iterable, pred, maxsplit):
buf = []
it = iter(iterable)
for item in it:
if pred(item) and buf:
yield buf
if maxsplit == 1:
yield [item] + list(it)
return
buf = []
maxsplit -= 1
buf.append(item)
yield buf

if maxsplit == 0:
yield list(iterable)
return

buf = []
it = iter(iterable)
for item in it:
if pred(item) and buf:
yield buf
if maxsplit == 1:
yield [item] + list(it)
return
buf = []
maxsplit -= 1
buf.append(item)
yield buf
if rsplit and maxsplit > 0:
yield from _concatenate_slice(_split_before(iterable, pred,
maxsplit=-1), -maxsplit)
else:
yield from _split_before(iterable, pred, maxsplit)


def split_after(iterable, pred, maxsplit=-1):
def split_after(iterable, pred, maxsplit=-1, rsplit=False):
"""Yield lists of items from *iterable*, where each list ends with an
item where callable *pred* returns ``True``:

Expand All @@ -1150,27 +1257,40 @@ def split_after(iterable, pred, maxsplit=-1):
>>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
[[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]

If *rsplit* is ``True``, then the splits are made from right to left:

>>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2,
... rsplit=True))
[[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

"""
if maxsplit == 0:
yield list(iterable)
return

buf = []
it = iter(iterable)
for item in it:
buf.append(item)
if pred(item) and buf:
def _split_after(iterable, pred, maxsplit):
buf = []
it = iter(iterable)
for item in it:
buf.append(item)
if pred(item) and buf:
yield buf
if maxsplit == 1:
yield list(it)
return
buf = []
maxsplit -= 1
if buf:
yield buf
if maxsplit == 1:
yield list(it)
return
buf = []
maxsplit -= 1
if buf:
yield buf

if rsplit and maxsplit > 0:
yield from _concatenate_slice(_split_after(iterable, pred,
maxsplit=-1), -maxsplit)
else:
yield from _split_after(iterable, pred, maxsplit)


def split_when(iterable, pred, maxsplit=-1):
def split_when(iterable, pred, maxsplit=-1, rsplit=False):
"""Split *iterable* into pieces based on the output of *pred*.
*pred* should be a function that takes successive pairs of items and
returns ``True`` if the iterable should be split in between them.
Expand All @@ -1188,31 +1308,44 @@ def split_when(iterable, pred, maxsplit=-1):
... lambda x, y: x > y, maxsplit=2))
[[1, 2, 3, 3], [2, 5], [2, 4, 2]]

If *rsplit* is ``True``, then the splits are made from right to left:

>>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2],
... lambda x, y: x > y, maxsplit=2, rsplit=True))
[[1, 2, 3, 3, 2, 5], [2, 4], [2]]

"""
if maxsplit == 0:
yield list(iterable)
return
def _split_when(iterable, pred, maxsplit):
it = iter(iterable)
try:
cur_item = next(it)
except StopIteration:
return

it = iter(iterable)
try:
cur_item = next(it)
except StopIteration:
return
buf = [cur_item]
for next_item in it:
if pred(cur_item, next_item):
yield buf
if maxsplit == 1:
yield [next_item] + list(it)
return
buf = []
maxsplit -= 1

buf = [cur_item]
for next_item in it:
if pred(cur_item, next_item):
yield buf
if maxsplit == 1:
yield [next_item] + list(it)
return
buf = []
maxsplit -= 1
buf.append(next_item)
cur_item = next_item

buf.append(next_item)
cur_item = next_item
yield buf

yield buf
if maxsplit == 0:
yield list(iterable)
return

if rsplit and maxsplit > 0:
yield from _concatenate_slice(_split_when(iterable, pred,
maxsplit=-1), -maxsplit)
else:
yield from _split_when(iterable, pred, maxsplit)


def split_into(iterable, sizes):
Expand Down