Skip to content

Commit

Permalink
Merge pull request #221 from erikrose/replace
Browse files Browse the repository at this point in the history
Ctrl+H
  • Loading branch information
bbayles committed Jul 15, 2018
2 parents 524e183 + 382fba7 commit 7041f0a
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ Others

**New itertools**

.. autofunction:: replace
.. autofunction:: numeric_range(start, stop, step)
.. autofunction:: always_reversible
.. autofunction:: side_effect
Expand Down
111 changes: 97 additions & 14 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
groupby,
islice,
repeat,
starmap,
takewhile,
tee
)
Expand Down Expand Up @@ -61,6 +62,7 @@
'one',
'padded',
'peekable',
'replace',
'rlocate',
'rstrip',
'run_length',
Expand Down Expand Up @@ -1463,7 +1465,7 @@ def count_cycle(iterable, n=None):
return ((i, item) for i in counter for item in iterable)


def locate(iterable, pred=bool):
def locate(iterable, pred=bool, window_size=None):
"""Yield the index of each item in *iterable* for which *pred* returns
``True``.
Expand All @@ -1473,18 +1475,17 @@ def locate(iterable, pred=bool):
[1, 2, 4]
Set *pred* to a custom function to, e.g., find the indexes for a particular
item:
item.
>>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b'))
[1, 3]
Use with :func:`windowed` to find the indexes of a sub-sequence:
If *window_size* is given, then the *pred* function will be called with
that many items. This enables searching for sub-sequences:
>>> from more_itertools import windowed
>>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
>>> sub = [1, 2, 3]
>>> pred = lambda w: w == tuple(sub) # windowed() returns tuples
>>> list(locate(windowed(iterable, len(sub)), pred=pred))
>>> pred = lambda *args: args == (1, 2, 3)
>>> list(locate(iterable, pred=pred, window_size=3))
[1, 5, 9]
Use with :func:`seekable` to find indexes and then retrieve the associated
Expand All @@ -1502,7 +1503,14 @@ def locate(iterable, pred=bool):
106
"""
return compress(count(), map(pred, iterable))
if window_size is None:
return compress(count(), map(pred, iterable))

if window_size < 1:
raise ValueError('window size must be at least 1')

it = windowed(iterable, window_size, fillvalue=_marker)
return compress(count(), starmap(pred, it))


def lstrip(iterable, pred):
Expand Down Expand Up @@ -2096,7 +2104,7 @@ def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
return ret


def rlocate(iterable, pred=bool):
def rlocate(iterable, pred=bool, window_size=None):
"""Yield the index of each item in *iterable* for which *pred* returns
``True``, starting from the right and moving left.
Expand All @@ -2113,6 +2121,14 @@ def rlocate(iterable, pred=bool):
>>> list(rlocate(iterable, pred))
[3, 1]
If *window_size* is given, then the *pred* function will be called with
that many items. This enables searching for sub-sequences:
>>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
>>> pred = lambda *args: args == (1, 2, 3)
>>> list(rlocate(iterable, pred=pred, window_size=3))
[9, 5, 1]
Beware, this function won't return anything for infinite iterables.
If *iterable* is reversible, ``rlocate`` will reverse it and search from
the right. Otherwise, it will search from the left and return the results
Expand All @@ -2121,8 +2137,75 @@ def rlocate(iterable, pred=bool):
See :func:`locate` for other example applications.
"""
try:
len_iter = len(iterable)
return (len_iter - i - 1 for i in locate(reversed(iterable), pred))
except TypeError:
return reversed(list(locate(iterable, pred)))
if window_size is None:
try:
len_iter = len(iterable)
return (
len_iter - i - 1 for i in locate(reversed(iterable), pred)
)
except TypeError:
pass

return reversed(list(locate(iterable, pred, window_size)))


def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items of *iterable*, splicing in the items of *substitutes*
    wherever *pred* returns ``True``.

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1]
        >>> pred = lambda x: x == 0
        >>> substitutes = (2, 3)
        >>> list(replace(iterable, pred, substitutes))
        [1, 1, 2, 3, 1, 1, 2, 3, 1, 1]

    Pass *count* to cap the number of replacements that are made; later
    matches are yielded unchanged:

        >>> iterable = [1, 1, 0, 1, 1, 0, 1, 1, 0]
        >>> pred = lambda x: x == 0
        >>> substitutes = [None]
        >>> list(replace(iterable, pred, substitutes, count=2))
        [1, 1, None, 1, 1, None, 1, 1, 0]

    *window_size* sets how many consecutive items are passed to *pred* as
    positional arguments, which makes it possible to find and replace
    whole subsequences:

        >>> iterable = [0, 1, 2, 5, 0, 1, 2, 5]
        >>> window_size = 3
        >>> pred = lambda *args: args == (0, 1, 2)  # 3 items passed to pred
        >>> substitutes = [3, 4]  # Splice in these items
        >>> list(replace(iterable, pred, substitutes, window_size=window_size))
        [3, 4, 5, 3, 4, 5]

    A ``ValueError`` is raised if *window_size* is less than 1. Because this
    is a generator function, the error surfaces when iteration starts.

    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # *substitutes* is emitted once per match, so freeze it up front in case
    # the caller handed us a one-shot iterator.
    replacement = tuple(substitutes)

    # Each window overlaps the next (window_size - 1) windows. Pad the tail
    # with that many sentinels so the number of windows matches the length
    # of the iterable.
    overlap = window_size - 1
    padded = chain(iterable, [_marker] * overlap)
    windows = windowed(padded, window_size)

    replaced = 0
    for window in windows:
        # On a match (while still under the replacement limit), emit the
        # substitutes and discard the following windows that overlap this
        # one. E.g. for (0, 1, 2, 3, ...) with a window size of 2, the
        # windows are (0, 1), (1, 2), (2, 3), ...; a match on (0, 1) must
        # also drop (1, 2).
        if pred(*window) and ((count is None) or (replaced < count)):
            replaced += 1
            for item in replacement:
                yield item
            consume(windows, overlap)
            continue

        # No replacement happened here: pass the window's first item
        # through, unless it is missing or is only padding.
        if not window or (window[0] is _marker):
            continue
        yield window[0]
110 changes: 110 additions & 0 deletions more_itertools/tests/test_more.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,26 @@ def test_custom_pred(self):
expected = [0, 3, 5, 6]
self.assertEqual(actual, expected)

def test_window_size(self):
    """With a window of 2, the start index of each match is reported."""
    def is_match(*window):
        return window == ('0', 1)

    source = ['0', 1, 1, '0', 1, '0', '0']
    found = list(mi.locate(source, is_match, window_size=2))
    self.assertEqual(found, [0, 3])

def test_window_size_large(self):
    """A window longer than the input still yields a single index."""
    def always(a, b, c, d, e):
        # Requires exactly five positional arguments
        return True

    found = list(mi.locate([1, 2, 3, 4], always, window_size=5))
    self.assertEqual(found, [0])

def test_window_size_zero(self):
    """A window size of zero is rejected with ``ValueError``."""
    def no_args():
        return True

    source = [1, 2, 3, 4]
    with self.assertRaises(ValueError):
        list(mi.locate(source, no_args, window_size=0))


class StripFunctionTests(TestCase):
def test_hashable(self):
Expand Down Expand Up @@ -1962,3 +1982,93 @@ def test_efficient_reversal(self):
pred = lambda x: x == target # Find-able from the right
actual = next(mi.rlocate(iterable, pred))
self.assertEqual(actual, target)

def test_window_size(self):
    """Windowed matches come back right-to-left for lists and iterators."""
    def is_match(*window):
        return window == ('0', 1)

    source = ['0', 1, 1, '0', 1, '0', '0']
    # Exercise both a sequence and a one-shot iterator
    for candidate in (source, iter(source)):
        found = list(mi.rlocate(candidate, is_match, window_size=2))
        self.assertEqual(found, [3, 0])

def test_window_size_large(self):
    """A window longer than the input still yields a single index, for
    both a list and a one-shot iterator."""
    iterable = [1, 2, 3, 4]
    pred = lambda a, b, c, d, e: True
    for it in (iterable, iter(iterable)):
        # Pass the loop variable (not the original list) so that the
        # iterator case is actually exercised.
        actual = list(mi.rlocate(it, pred, window_size=5))
        expected = [0]
        self.assertEqual(actual, expected)

def test_window_size_zero(self):
    """``rlocate`` rejects a window size of zero with ``ValueError``, for
    both a list and a one-shot iterator."""
    iterable = [1, 2, 3, 4]
    pred = lambda: True
    for it in (iterable, iter(iterable)):
        with self.assertRaises(ValueError):
            # This is the rlocate test class: call rlocate (not locate)
            # and pass the loop variable so both input kinds are checked.
            list(mi.rlocate(it, pred, window_size=0))


class ReplaceTests(TestCase):
    """Tests for ``replace()``"""

    def test_basic(self):
        # Empty substitutes: every matching item is simply dropped
        def is_even(x):
            return x % 2 == 0

        result = list(mi.replace(range(10), is_even, []))
        self.assertEqual(result, [1, 3, 5, 7, 9])

    def test_count(self):
        # Only the first four matches are replaced, so 8 passes through
        def is_even(x):
            return x % 2 == 0

        result = list(mi.replace(range(10), is_even, [], count=4))
        self.assertEqual(result, [1, 3, 5, 7, 8, 9])

    def test_window_size(self):
        # A match at the very start of the iterable
        def is_match(*window):
            return window == (0, 1, 2)

        result = list(mi.replace(range(10), is_match, [], window_size=3))
        self.assertEqual(result, [3, 4, 5, 6, 7, 8, 9])

    def test_window_size_end(self):
        # A match at the very end of the iterable
        def is_match(*window):
            return window == (7, 8, 9)

        result = list(mi.replace(range(10), is_match, [], window_size=3))
        self.assertEqual(result, [0, 1, 2, 3, 4, 5, 6])

    def test_window_size_count(self):
        # Two windows match, but count=1 allows only the first replacement
        def is_match(*window):
            return (window == (0, 1, 2)) or (window == (7, 8, 9))

        result = list(
            mi.replace(range(10), is_match, [], count=1, window_size=3)
        )
        self.assertEqual(result, [3, 4, 5, 6, 7, 8, 9])

    def test_window_size_large(self):
        # The window is longer than the iterable itself
        def always(a, b, c, d, e):
            return True

        result = list(mi.replace(range(4), always, [5, 6, 7], window_size=5))
        self.assertEqual(result, [5, 6, 7])

    def test_window_size_zero(self):
        def always(*window):
            return True

        source = range(10)
        with self.assertRaises(ValueError):
            list(mi.replace(source, always, [], window_size=0))

    def test_iterable_substitutes(self):
        # Substitutes given as a one-shot iterator are reused for each match
        def is_even(x):
            return x % 2 == 0

        one_shot = iter('__')
        result = list(mi.replace(range(5), is_even, one_shot))
        self.assertEqual(result, ['_', '_', 1, '_', '_', 3, '_', '_'])

0 comments on commit 7041f0a

Please sign in to comment.