Skip to content

Commit

Permalink
Merge pull request #4858 from densmirn/feature/str_expandtabs
Browse files Browse the repository at this point in the history
Implement str.expandtabs() based on CPython
  • Loading branch information
seibert committed Dec 10, 2019
2 parents 8044d6e + e775317 commit 3ee0556
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/pysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ The following functions, attributes and methods are currently supported:
* ``==``, ``<``, ``<=``, ``>``, ``>=`` (comparison)
* ``.startswith()``
* ``.endswith()``
* ``.expandtabs()``
* ``.isspace()``
* ``.isidentifier()``
* ``.find()``
Expand Down
49 changes: 49 additions & 0 deletions numba/tests/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def endswith_usecase(x, y):
return x.endswith(y)


def expandtabs_usecase(s):
    """Return ``s`` with its tabs expanded using the default tab size."""
    expanded = s.expandtabs()
    return expanded


def expandtabs_with_tabsize_usecase(s, tabsize):
    """Return ``s`` with tabs expanded; ``tabsize`` is passed positionally."""
    expanded = s.expandtabs(tabsize)
    return expanded


def expandtabs_with_tabsize_kwarg_usecase(s, tabsize):
    """Return ``s`` with tabs expanded; ``tabsize`` is passed by keyword."""
    expanded = s.expandtabs(tabsize=tabsize)
    return expanded


def split_usecase(x, y):
    """Return the pieces of ``x`` obtained by splitting on separator ``y``."""
    pieces = x.split(y)
    return pieces

Expand Down Expand Up @@ -431,6 +443,43 @@ def test_endswith(self, flags=no_pyobj_flags):
cfunc(a, b),
'%s, %s' % (a, b))

def test_expandtabs(self):
    """Jitted ``str.expandtabs()`` must agree with the interpreter."""
    samples = ['', '\t', 't\tt\t', 'a\t', '\t⚡', 'a\tbc\nab\tc',
               '🐍\t⚡', '🐍⚡\n\t\t🐍\t', 'ab\rab\t\t\tab\r\n\ta']
    template = 'Results of "{}".expandtabs() must be equal'

    py_impl = expandtabs_usecase
    jit_impl = njit(py_impl)
    for text in samples:
        # Interpreter result is the reference; the jitted one is compared.
        self.assertEqual(py_impl(text), jit_impl(text),
                         msg=template.format(text))

def test_expandtabs_with_tabsize(self):
    """Jitted ``str.expandtabs(tabsize)`` must agree with the interpreter.

    Every sample string is checked with each tab size from -1 to 9
    inclusive, passing ``tabsize`` both positionally and by keyword.
    """
    pyfuncs = [expandtabs_with_tabsize_usecase,
               expandtabs_with_tabsize_kwarg_usecase]
    messages = ['Results of "{}".expandtabs({}) must be equal',
                'Results of "{}".expandtabs(tabsize={}) must be equal']

    # Wrap each use case exactly once.  Calling njit() inside the loops
    # builds a new dispatcher for every (string, tabsize) combination,
    # so nothing compiled for an earlier combination gets reused.
    variants = [(pyfunc, njit(pyfunc), msg)
                for pyfunc, msg in zip(pyfuncs, messages)]

    cases = ['', '\t', 't\tt\t', 'a\t', '\t⚡', 'a\tbc\nab\tc',
             '🐍\t⚡', '🐍⚡\n\t\t🐍\t', 'ab\rab\t\t\tab\r\n\ta']

    for s in cases:
        for tabsize in range(-1, 10):
            for pyfunc, cfunc, msg in variants:
                self.assertEqual(pyfunc(s, tabsize), cfunc(s, tabsize),
                                 msg=msg.format(s, tabsize))

def test_expandtabs_exception_noninteger_tabsize(self):
    """A float ``tabsize`` must be rejected with a ``TypingError``."""
    jit_impl = njit(expandtabs_with_tabsize_usecase)

    allowed = (types.Integer, int)
    expected = '"tabsize" must be {}, not float'.format(allowed)

    with self.assertRaises(TypingError) as ctx:
        jit_impl('\t', 2.4)
    self.assertIn(expected, str(ctx.exception))

def test_in(self, flags=no_pyobj_flags):
pyfunc = in_usecase
cfunc = njit(pyfunc)
Expand Down
68 changes: 68 additions & 0 deletions numba/unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
_PyUnicode_IsCased, _PyUnicode_IsCaseIgnorable,
_PyUnicode_IsUppercase, _PyUnicode_IsLowercase,
_PyUnicode_IsTitlecase, _Py_ISLOWER, _Py_ISUPPER,
_Py_TAB, _Py_LINEFEED,
_Py_CARRIAGE_RETURN, _Py_SPACE,
_PyUnicode_IsAlpha, _PyUnicode_IsNumeric,
_Py_ISALPHA,)

Expand Down Expand Up @@ -745,6 +747,72 @@ def endswith_impl(a, b):
return endswith_impl


# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11519-L11595 # noqa: E501
@overload_method(types.UnicodeType, 'expandtabs')
def unicode_expandtabs(data, tabsize=8):
    """Implements str.expandtabs()"""
    # Pick what to validate: an omitted argument carries its default value,
    # an optional argument carries the wrapped type, anything else is used
    # as given.
    if isinstance(tabsize, types.Omitted):
        checked = tabsize.value
    elif isinstance(tabsize, types.Optional):
        checked = tabsize.type
    else:
        checked = tabsize

    allowed = (types.Integer, int)
    if not (checked is None or isinstance(checked, allowed)):
        raise TypingError(
            '"tabsize" must be {}, not {}'.format(allowed, tabsize))

    def expandtabs_impl(data, tabsize=8):
        n_chars = len(data)

        # Pass 1: compute the length of the expanded string and note
        # whether the input contains any tab at all.
        out_len = 0
        column = 0
        has_tab = False
        for pos in range(n_chars):
            ch = _get_code_point(data, pos)
            if ch != _Py_TAB:
                if out_len > sys.maxsize - 1:
                    raise OverflowError('new string is too long')
                column += 1
                out_len += 1
                # a line break restarts the column count
                if ch == _Py_LINEFEED or ch == _Py_CARRIAGE_RETURN:
                    column = 0
            else:
                has_tab = True
                # a non-positive tabsize contributes no characters
                if tabsize > 0:
                    pad = tabsize - (column % tabsize)
                    if out_len > sys.maxsize - pad:
                        raise OverflowError('new string is too long')
                    column += pad
                    out_len += pad

        if not has_tab:
            return data

        # Pass 2: fill the result, writing spaces in place of each tab.
        out = _empty_string(data._kind, out_len, data._is_ascii)
        write_at = 0
        column = 0
        for pos in range(n_chars):
            ch = _get_code_point(data, pos)
            if ch != _Py_TAB:
                column += 1
                _set_code_point(out, write_at, ch)
                write_at += 1
                if ch == _Py_LINEFEED or ch == _Py_CARRIAGE_RETURN:
                    column = 0
            elif tabsize > 0:
                pad = tabsize - (column % tabsize)
                column += pad
                for _ in range(pad):
                    _set_code_point(out, write_at, _Py_SPACE)
                    write_at += 1

        return out

    return expandtabs_impl


@overload_method(types.UnicodeType, 'split')
def unicode_split(a, sep=None, maxsplit=-1):
if not (maxsplit == -1 or
Expand Down
6 changes: 6 additions & 0 deletions numba/unicode_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@
#


_Py_TAB = 0x9
_Py_LINEFEED = 0xa
_Py_CARRIAGE_RETURN = 0xd
_Py_SPACE = 0x20


class _PyUnicode_TyperecordMasks(IntEnum):
ALPHA_MASK = 0x01
DECIMAL_MASK = 0x02
Expand Down

0 comments on commit 3ee0556

Please sign in to comment.