Skip to content

Commit

Permalink
Merge pull request #25525 from mhvk/string-replace-small-optimization
Browse files Browse the repository at this point in the history
MAINT: optimization and broadcasting for .replace() method for strings.
  • Loading branch information
ngoldbaum committed Jan 8, 2024
2 parents ea56f1c + 01e9f72 commit 59f44a1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 10 deletions.
26 changes: 16 additions & 10 deletions numpy/_core/defchararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,16 +1357,16 @@ def replace(a, old, new, count=None):
----------
a : array-like of str or unicode
old, new : str or unicode
old, new : scalar or array-like str or unicode
count : int, optional
count : scalar or array-like int, optional
If the optional argument `count` is given, only the first
`count` occurrences are replaced.
`count` occurrences are replaced. If negative, replace all.
Returns
-------
out : ndarray
Output array of str or unicode, depending on input type
Output array of str or unicode, depending on input type.
See Also
--------
Expand All @@ -1383,19 +1383,25 @@ def replace(a, old, new, count=None):
array(['The dwash was fresh', 'Thwas was it'], dtype='<U19')
"""
a_arr = numpy.asarray(a)
a_arr = numpy.asanyarray(a)
old = numpy.asanyarray(old)
new = numpy.asanyarray(new)
max_int64 = numpy.iinfo(numpy.int64).max
count = count if count is not None else max_int64

counts = numpy._core.umath.count(a_arr, old, 0, max_int64)
if count is not None:
count = numpy.asanyarray(count)
counts = numpy.where(count < 0, counts,
numpy.minimum(counts, count))

buffersizes = (
numpy._core.umath.str_len(a_arr)
+ counts * (numpy._core.umath.str_len(new) -
numpy._core.umath.str_len(old))
)
max_buffersize = numpy.max(buffersizes)
out = numpy.empty(a_arr.shape, dtype=f"{a_arr.dtype.char}{max_buffersize}")
numpy._core.umath._replace(a_arr, old, new, count, out=out)
# buffersizes is properly broadcast along all inputs.
out = numpy.empty_like(a_arr, shape=buffersizes.shape,
dtype=f"{a_arr.dtype.char}{buffersizes.max()}")
numpy._core.umath._replace(a_arr, old, new, counts, out=out)
return out


Expand Down
34 changes: 34 additions & 0 deletions numpy/_core/tests/test_defchararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,40 @@ def test_replace(self):
S4 = self.A.replace(b'3', b'', count=0)
assert_array_equal(S4, self.A)

# Check that replace() honours `count` and that the result dtype is sized
# for the longest string actually produced, not for a worst case.
def test_replace_count_and_size(self):
# Four strings: '0123456789' repeated 0, 1, 2 and 3 times, so the longest
# input has 3*10 characters and contains three '5's.
a = np.array(['0123456789' * i for i in range(4)]
).view(np.char.chararray)
# No count given: every '5' is replaced.
r1 = a.replace('5', 'ABCDE')
# Each replacement swaps 1 character for 5, i.e. adds 4 characters; the
# trailing * 4 converts characters to bytes (numpy's unicode dtype uses
# 4 bytes per character).
assert r1.dtype.itemsize == (3*10 + 3*4) * 4
assert_array_equal(r1, np.array(['01234ABCDE6789' * i
for i in range(4)]))
# count=1: at most one replacement per element, so the longest result is
# only 4 characters longer than the longest input.
r2 = a.replace('5', 'ABCDE', count=1)
assert r2.dtype.itemsize == (3*10 + 4) * 4
# count=0: nothing is replaced, so neither the dtype nor the values change.
r3 = a.replace('5', 'ABCDE', count=0)
assert r3.dtype.itemsize == a.dtype.itemsize
assert_array_equal(r3, a)
# Negative values mean to replace all.
r4 = a.replace('5', 'ABCDE', count=-1)
assert r4.dtype.itemsize == (3*10 + 3*4) * 4
assert_array_equal(r4, r1)
# We can do count on an element-by-element basis.
# Here only the last (longest) element is limited to one replacement, so
# the required size matches the count=1 case above.
r5 = a.replace('5', 'ABCDE', count=[-1, -1, -1, 1])
assert r5.dtype.itemsize == (3*10 + 4) * 4
assert_array_equal(r5, np.array(
['01234ABCDE6789' * i for i in range(3)]
+ ['01234ABCDE6789' + '0123456789' * 2]))

# Check that replace() broadcasts the input string, `old`, `new` and `count`
# against each other.
def test_replace_broadcasting(self):
# A single string wrapped as a chararray; every other operand below is
# broadcast against it.
a = np.array('0,0,0').view(np.char.chararray)
# count = [0, 1, 2] gives one result per count value.
r1 = a.replace('0', '1', count=np.arange(3))
# '0' -> '1' keeps the string length, so the dtype must be unchanged.
assert r1.dtype == a.dtype
assert_array_equal(r1, np.array(['0,0,0', '1,0,0', '1,1,0']))
# `new` as a 2x1 nested list against count = [1, 2, 3] yields a 2x3 result:
# rows vary the replacement string, columns vary the count.
r2 = a.replace('0', [['1'], ['2']], count=np.arange(1, 4))
assert_array_equal(r2, np.array([['1,0,0', '1,1,0', '1,1,1'],
['2,0,0', '2,2,0', '2,2,2']]))
# `old` may be array-like too: each pattern is applied to the same input.
r3 = a.replace(['0', '0,0', '0,0,0'], 'X')
assert_array_equal(r3, np.array(['X,X,X', 'X,0', 'X']))

def test_rjust(self):
assert_(issubclass(self.A.rjust(10).dtype.type, np.bytes_))

Expand Down

0 comments on commit 59f44a1

Please sign in to comment.