Skip to content
Permalink
master
Go to file
 
 
Cannot retrieve contributors at this time
2476 lines (1987 sloc) 82.3 KB
import sys
import operator
import numpy as np
from llvmlite.ir import IntType, Constant
from numba.core.extending import (
models,
register_model,
make_attribute_wrapper,
unbox,
box,
NativeValue,
overload,
overload_method,
intrinsic,
register_jitable,
)
from numba.core.imputils import (lower_constant, lower_cast, lower_builtin,
iternext_impl, impl_ret_new_ref, RefType)
from numba.core.datamodel import register_default, StructModel
from numba.core import utils, types, cgutils
from numba.core.pythonapi import (
PY_UNICODE_1BYTE_KIND,
PY_UNICODE_2BYTE_KIND,
PY_UNICODE_4BYTE_KIND,
PY_UNICODE_WCHAR_KIND,
)
from numba._helperlib import c_helpers
from numba.cpython.hashing import _Py_hash_t
from numba.core.unsafe.bytes import memcpy_region
from numba.core.errors import TypingError
from numba.cpython.unicode_support import (_Py_TOUPPER, _Py_TOLOWER, _Py_UCS4,
_Py_ISALNUM,
_PyUnicode_ToUpperFull,
_PyUnicode_ToLowerFull,
_PyUnicode_ToFoldedFull,
_PyUnicode_ToTitleFull,
_PyUnicode_IsPrintable,
_PyUnicode_IsSpace,
_Py_ISSPACE,
_PyUnicode_IsXidStart,
_PyUnicode_IsXidContinue,
_PyUnicode_IsCased,
_PyUnicode_IsCaseIgnorable,
_PyUnicode_IsUppercase,
_PyUnicode_IsLowercase,
_PyUnicode_IsLineBreak,
_Py_ISLINEBREAK,
_Py_ISLINEFEED,
_Py_ISCARRIAGERETURN,
_PyUnicode_IsTitlecase,
_Py_ISLOWER,
_Py_ISUPPER,
_Py_TAB,
_Py_LINEFEED,
_Py_CARRIAGE_RETURN,
_Py_SPACE,
_PyUnicode_IsAlpha,
_PyUnicode_IsNumeric,
_Py_ISALPHA,
_PyUnicode_IsDigit,
_PyUnicode_IsDecimalDigit)
from numba.cpython import slicing
_py38_or_later = utils.PYVERSION >= (3, 8)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L84-L85 # noqa: E501
_MAX_UNICODE = 0x10ffff
# DATA MODEL
@register_model(types.UnicodeType)
class UnicodeModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('data', types.voidptr),
('length', types.intp),
('kind', types.int32),
('is_ascii', types.uint32),
('hash', _Py_hash_t),
('meminfo', types.MemInfoPointer(types.voidptr)),
# A pointer to the owner python str/unicode object
('parent', types.pyobject),
]
models.StructModel.__init__(self, dmm, fe_type, members)
make_attribute_wrapper(types.UnicodeType, 'data', '_data')
make_attribute_wrapper(types.UnicodeType, 'length', '_length')
make_attribute_wrapper(types.UnicodeType, 'kind', '_kind')
make_attribute_wrapper(types.UnicodeType, 'is_ascii', '_is_ascii')
make_attribute_wrapper(types.UnicodeType, 'hash', '_hash')
@register_default(types.UnicodeIteratorType)
class UnicodeIteratorModel(StructModel):
def __init__(self, dmm, fe_type):
members = [('index', types.EphemeralPointer(types.uintp)),
('data', fe_type.data)]
super(UnicodeIteratorModel, self).__init__(dmm, fe_type, members)
# CAST
def compile_time_get_string_data(obj):
"""Get string data from a python string for use at compile-time to embed
the string data into the LLVM module.
"""
from ctypes import (
CFUNCTYPE, c_void_p, c_int, c_uint, c_ssize_t, c_ubyte, py_object,
POINTER, byref,
)
extract_unicode_fn = c_helpers['extract_unicode']
proto = CFUNCTYPE(c_void_p, py_object, POINTER(c_ssize_t), POINTER(c_int),
POINTER(c_uint), POINTER(c_ssize_t))
fn = proto(extract_unicode_fn)
length = c_ssize_t()
kind = c_int()
is_ascii = c_uint()
hashv = c_ssize_t()
data = fn(obj, byref(length), byref(kind), byref(is_ascii), byref(hashv))
if data is None:
raise ValueError("cannot extract unicode data from the given string")
length = length.value
kind = kind.value
is_ascii = is_ascii.value
nbytes = (length + 1) * _kind_to_byte_width(kind)
out = (c_ubyte * nbytes).from_address(data)
return bytes(out), length, kind, is_ascii, hashv.value
def make_string_from_constant(context, builder, typ, literal_string):
"""
Get string data by `compile_time_get_string_data()` and return a
unicode_type LLVM value
"""
databytes, length, kind, is_ascii, hashv = \
compile_time_get_string_data(literal_string)
mod = builder.module
gv = context.insert_const_bytes(mod, databytes)
uni_str = cgutils.create_struct_proxy(typ)(context, builder)
uni_str.data = gv
uni_str.length = uni_str.length.type(length)
uni_str.kind = uni_str.kind.type(kind)
uni_str.is_ascii = uni_str.is_ascii.type(is_ascii)
# Set hash to -1 to indicate that it should be computed.
# We cannot bake in the hash value because of hashseed randomization.
uni_str.hash = uni_str.hash.type(-1)
return uni_str._getvalue()
@lower_cast(types.StringLiteral, types.unicode_type)
def cast_from_literal(context, builder, fromty, toty, val):
return make_string_from_constant(
context, builder, toty, fromty.literal_value,
)
# CONSTANT
@lower_constant(types.unicode_type)
def constant_unicode(context, builder, typ, pyval):
return make_string_from_constant(context, builder, typ, pyval)
# BOXING
@unbox(types.UnicodeType)
def unbox_unicode_str(typ, obj, c):
"""
Convert a unicode str object to a native unicode structure.
"""
ok, data, length, kind, is_ascii, hashv = \
c.pyapi.string_as_string_size_and_kind(obj)
uni_str = cgutils.create_struct_proxy(typ)(c.context, c.builder)
uni_str.data = data
uni_str.length = length
uni_str.kind = kind
uni_str.is_ascii = is_ascii
uni_str.hash = hashv
uni_str.meminfo = c.pyapi.nrt_meminfo_new_from_pyobject(
data, # the borrowed data pointer
obj, # the owner pyobject; the call will incref it.
)
uni_str.parent = obj
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(uni_str._getvalue(), is_error=is_error)
@box(types.UnicodeType)
def box_unicode_str(typ, val, c):
"""
Convert a native unicode structure to a unicode string
"""
uni_str = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
res = c.pyapi.string_from_kind_and_data(
uni_str.kind, uni_str.data, uni_str.length)
# hash isn't needed now, just compute it so it ends up in the unicodeobject
# hash cache, cpython doesn't always do this, depends how a string was
# created it's safe, just burns the cycles required to hash on @box
c.pyapi.object_hash(res)
c.context.nrt.decref(c.builder, typ, val)
return res
# HELPER FUNCTIONS
def make_deref_codegen(bitsize):
def codegen(context, builder, signature, args):
data, idx = args
ptr = builder.bitcast(data, IntType(bitsize).as_pointer())
ch = builder.load(builder.gep(ptr, [idx]))
return builder.zext(ch, IntType(32))
return codegen
@intrinsic
def deref_uint8(typingctx, data, offset):
sig = types.uint32(types.voidptr, types.intp)
return sig, make_deref_codegen(8)
@intrinsic
def deref_uint16(typingctx, data, offset):
sig = types.uint32(types.voidptr, types.intp)
return sig, make_deref_codegen(16)
@intrinsic
def deref_uint32(typingctx, data, offset):
sig = types.uint32(types.voidptr, types.intp)
return sig, make_deref_codegen(32)
@intrinsic
def _malloc_string(typingctx, kind, char_bytes, length, is_ascii):
"""make empty string with data buffer of size alloc_bytes.
Must set length and kind values for string after it is returned
"""
def details(context, builder, signature, args):
[kind_val, char_bytes_val, length_val, is_ascii_val] = args
# fill the struct
uni_str_ctor = cgutils.create_struct_proxy(types.unicode_type)
uni_str = uni_str_ctor(context, builder)
# add null padding character
nbytes_val = builder.mul(char_bytes_val,
builder.add(length_val,
Constant(length_val.type, 1)))
uni_str.meminfo = context.nrt.meminfo_alloc(builder, nbytes_val)
uni_str.kind = kind_val
uni_str.is_ascii = is_ascii_val
uni_str.length = length_val
# empty string has hash value -1 to indicate "need to compute hash"
uni_str.hash = context.get_constant(_Py_hash_t, -1)
uni_str.data = context.nrt.meminfo_data(builder, uni_str.meminfo)
# Set parent to NULL
uni_str.parent = cgutils.get_null_value(uni_str.parent.type)
return uni_str._getvalue()
sig = types.unicode_type(types.int32, types.intp, types.intp, types.uint32)
return sig, details
@register_jitable
def _empty_string(kind, length, is_ascii=0):
char_width = _kind_to_byte_width(kind)
s = _malloc_string(kind, char_width, length, is_ascii)
_set_code_point(s, length, np.uint32(0)) # Write NULL character
return s
# Disable RefCt for performance.
@register_jitable(_nrt=False)
def _get_code_point(a, i):
if a._kind == PY_UNICODE_1BYTE_KIND:
return deref_uint8(a._data, i)
elif a._kind == PY_UNICODE_2BYTE_KIND:
return deref_uint16(a._data, i)
elif a._kind == PY_UNICODE_4BYTE_KIND:
return deref_uint32(a._data, i)
else:
# there's also a wchar kind, but that's one of the above,
# so skipping for this example
return 0
####
def make_set_codegen(bitsize):
def codegen(context, builder, signature, args):
data, idx, ch = args
if bitsize < 32:
ch = builder.trunc(ch, IntType(bitsize))
ptr = builder.bitcast(data, IntType(bitsize).as_pointer())
builder.store(ch, builder.gep(ptr, [idx]))
return context.get_dummy_value()
return codegen
@intrinsic
def set_uint8(typingctx, data, idx, ch):
sig = types.void(types.voidptr, types.int64, types.uint32)
return sig, make_set_codegen(8)
@intrinsic
def set_uint16(typingctx, data, idx, ch):
sig = types.void(types.voidptr, types.int64, types.uint32)
return sig, make_set_codegen(16)
@intrinsic
def set_uint32(typingctx, data, idx, ch):
sig = types.void(types.voidptr, types.int64, types.uint32)
return sig, make_set_codegen(32)
@register_jitable(_nrt=False)
def _set_code_point(a, i, ch):
# WARNING: This method is very dangerous:
# * Assumes that data contents can be changed (only allowed for new
# strings)
# * Assumes that the kind of unicode string is sufficiently wide to
# accept ch. Will truncate ch to make it fit.
# * Assumes that i is within the valid boundaries of the function
if a._kind == PY_UNICODE_1BYTE_KIND:
set_uint8(a._data, i, ch)
elif a._kind == PY_UNICODE_2BYTE_KIND:
set_uint16(a._data, i, ch)
elif a._kind == PY_UNICODE_4BYTE_KIND:
set_uint32(a._data, i, ch)
else:
raise AssertionError(
"Unexpected unicode representation in _set_code_point")
@register_jitable
def _pick_kind(kind1, kind2):
if kind1 == PY_UNICODE_WCHAR_KIND or kind2 == PY_UNICODE_WCHAR_KIND:
raise AssertionError("PY_UNICODE_WCHAR_KIND unsupported")
if kind1 == PY_UNICODE_1BYTE_KIND:
return kind2
elif kind1 == PY_UNICODE_2BYTE_KIND:
if kind2 == PY_UNICODE_4BYTE_KIND:
return kind2
else:
return kind1
elif kind1 == PY_UNICODE_4BYTE_KIND:
return kind1
else:
raise AssertionError("Unexpected unicode representation in _pick_kind")
@register_jitable
def _pick_ascii(is_ascii1, is_ascii2):
if is_ascii1 == 1 and is_ascii2 == 1:
return types.uint32(1)
return types.uint32(0)
@register_jitable
def _kind_to_byte_width(kind):
if kind == PY_UNICODE_1BYTE_KIND:
return 1
elif kind == PY_UNICODE_2BYTE_KIND:
return 2
elif kind == PY_UNICODE_4BYTE_KIND:
return 4
elif kind == PY_UNICODE_WCHAR_KIND:
raise AssertionError("PY_UNICODE_WCHAR_KIND unsupported")
else:
raise AssertionError("Unexpected unicode encoding encountered")
@register_jitable(_nrt=False)
def _cmp_region(a, a_offset, b, b_offset, n):
if n == 0:
return 0
elif a_offset + n > a._length:
return -1
elif b_offset + n > b._length:
return 1
for i in range(n):
a_chr = _get_code_point(a, a_offset + i)
b_chr = _get_code_point(b, b_offset + i)
if a_chr < b_chr:
return -1
elif a_chr > b_chr:
return 1
return 0
@register_jitable
def _codepoint_to_kind(cp):
"""
Compute the minimum unicode kind needed to hold a given codepoint
"""
if cp < 256:
return PY_UNICODE_1BYTE_KIND
elif cp < 65536:
return PY_UNICODE_2BYTE_KIND
else:
# Maximum code point of Unicode 6.0: 0x10ffff (1,114,111)
MAX_UNICODE = 0x10ffff
if cp > MAX_UNICODE:
msg = "Invalid codepoint. Found value greater than Unicode maximum"
raise ValueError(msg)
return PY_UNICODE_4BYTE_KIND
@register_jitable
def _codepoint_is_ascii(ch):
"""
Returns true if a codepoint is in the ASCII range
"""
return ch < 128
# PUBLIC API
@overload(str)
def unicode_str(s):
if isinstance(s, types.UnicodeType):
return lambda s: s
@overload(len)
def unicode_len(s):
if isinstance(s, types.UnicodeType):
def len_impl(s):
return s._length
return len_impl
@overload(operator.eq)
def unicode_eq(a, b):
if not (a.is_internal and b.is_internal):
return
accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq)
a_unicode = isinstance(a, accept)
b_unicode = isinstance(b, accept)
if a_unicode and b_unicode:
def eq_impl(a, b):
# the str() is for UnicodeCharSeq, it's a nop else
a = str(a)
b = str(b)
if len(a) != len(b):
return False
return _cmp_region(a, 0, b, 0, len(a)) == 0
return eq_impl
elif a_unicode ^ b_unicode:
# one of the things is unicode, everything compares False
def eq_impl(a, b):
return False
return eq_impl
@overload(operator.ne)
def unicode_ne(a, b):
if not (a.is_internal and b.is_internal):
return
accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq)
a_unicode = isinstance(a, accept)
b_unicode = isinstance(b, accept)
if a_unicode and b_unicode:
def ne_impl(a, b):
return not (a == b)
return ne_impl
elif a_unicode ^ b_unicode:
# one of the things is unicode, everything compares True
def eq_impl(a, b):
return True
return eq_impl
@overload(operator.lt)
def unicode_lt(a, b):
a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral))
b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral))
if a_unicode and b_unicode:
def lt_impl(a, b):
minlen = min(len(a), len(b))
eqcode = _cmp_region(a, 0, b, 0, minlen)
if eqcode == -1:
return True
elif eqcode == 0:
return len(a) < len(b)
return False
return lt_impl
@overload(operator.gt)
def unicode_gt(a, b):
a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral))
b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral))
if a_unicode and b_unicode:
def gt_impl(a, b):
minlen = min(len(a), len(b))
eqcode = _cmp_region(a, 0, b, 0, minlen)
if eqcode == 1:
return True
elif eqcode == 0:
return len(a) > len(b)
return False
return gt_impl
@overload(operator.le)
def unicode_le(a, b):
a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral))
b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral))
if a_unicode and b_unicode:
def le_impl(a, b):
return not (a > b)
return le_impl
@overload(operator.ge)
def unicode_ge(a, b):
a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral))
b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral))
if a_unicode and b_unicode:
def ge_impl(a, b):
return not (a < b)
return ge_impl
@overload(operator.contains)
def unicode_contains(a, b):
if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeType):
def contains_impl(a, b):
# note parameter swap: contains(a, b) == b in a
return _find(a, b) > -1
return contains_impl
def unicode_idx_check_type(ty, name):
"""Check object belongs to one of specific types
ty: type
Type of the object
name: str
Name of the object
"""
thety = ty
# if the type is omitted, the concrete type is the value
if isinstance(ty, types.Omitted):
thety = ty.value
# if the type is optional, the concrete type is the captured type
elif isinstance(ty, types.Optional):
thety = ty.type
accepted = (types.Integer, types.NoneType)
if thety is not None and not isinstance(thety, accepted):
raise TypingError('"{}" must be {}, not {}'.format(name, accepted, ty))
def unicode_sub_check_type(ty, name):
"""Check object belongs to unicode type"""
if not isinstance(ty, types.UnicodeType):
msg = '"{}" must be {}, not {}'.format(name, types.UnicodeType, ty)
raise TypingError(msg)
def generate_finder(find_func):
"""Generate finder either left or right."""
def impl(data, substr, start=None, end=None):
length = len(data)
sub_length = len(substr)
if start is None:
start = 0
if end is None:
end = length
start, end = _adjust_indices(length, start, end)
if end - start < sub_length:
return -1
return find_func(data, substr, start, end)
return impl
@register_jitable
def _finder(data, substr, start, end):
"""Left finder."""
if len(substr) == 0:
return start
for i in range(start, min(len(data), end) - len(substr) + 1):
if _cmp_region(data, i, substr, 0, len(substr)) == 0:
return i
return -1
@register_jitable
def _rfinder(data, substr, start, end):
"""Right finder."""
if len(substr) == 0:
return end
for i in range(min(len(data), end) - len(substr), start - 1, -1):
if _cmp_region(data, i, substr, 0, len(substr)) == 0:
return i
return -1
_find = register_jitable(generate_finder(_finder))
_rfind = register_jitable(generate_finder(_rfinder))
@overload_method(types.UnicodeType, 'find')
def unicode_find(data, substr, start=None, end=None):
"""Implements str.find()"""
if isinstance(substr, types.UnicodeCharSeq):
def find_impl(data, substr, start=None, end=None):
return data.find(str(substr))
return find_impl
unicode_idx_check_type(start, 'start')
unicode_idx_check_type(end, 'end')
unicode_sub_check_type(substr, 'substr')
return _find
@overload_method(types.UnicodeType, 'rfind')
def unicode_rfind(data, substr, start=None, end=None):
"""Implements str.rfind()"""
if isinstance(substr, types.UnicodeCharSeq):
def rfind_impl(data, substr, start=None, end=None):
return data.rfind(str(substr))
return rfind_impl
unicode_idx_check_type(start, 'start')
unicode_idx_check_type(end, 'end')
unicode_sub_check_type(substr, 'substr')
return _rfind
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12831-L12857 # noqa: E501
@overload_method(types.UnicodeType, 'rindex')
def unicode_rindex(s, sub, start=None, end=None):
"""Implements str.rindex()"""
unicode_idx_check_type(start, 'start')
unicode_idx_check_type(end, 'end')
unicode_sub_check_type(sub, 'sub')
def rindex_impl(s, sub, start=None, end=None):
result = s.rfind(sub, start, end)
if result < 0:
raise ValueError('substring not found')
return result
return rindex_impl
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11692-L11718 # noqa: E501
@overload_method(types.UnicodeType, 'index')
def unicode_index(s, sub, start=None, end=None):
"""Implements str.index()"""
unicode_idx_check_type(start, 'start')
unicode_idx_check_type(end, 'end')
unicode_sub_check_type(sub, 'sub')
def index_impl(s, sub, start=None, end=None):
result = s.find(sub, start, end)
if result < 0:
raise ValueError('substring not found')
return result
return index_impl
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12922-L12976 # noqa: E501
@overload_method(types.UnicodeType, 'partition')
def unicode_partition(data, sep):
"""Implements str.partition()"""
thety = sep
# if the type is omitted, the concrete type is the value
if isinstance(sep, types.Omitted):
thety = sep.value
# if the type is optional, the concrete type is the captured type
elif isinstance(sep, types.Optional):
thety = sep.type
accepted = (types.UnicodeType, types.UnicodeCharSeq)
if thety is not None and not isinstance(thety, accepted):
msg = '"{}" must be {}, not {}'.format('sep', accepted, sep)
raise TypingError(msg)
def impl(data, sep):
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/partition.h#L7-L60 # noqa: E501
sep = str(sep)
empty_str = _empty_string(data._kind, 0, data._is_ascii)
sep_length = len(sep)
if data._kind < sep._kind or len(data) < sep_length:
return data, empty_str, empty_str
if sep_length == 0:
raise ValueError('empty separator')
pos = data.find(sep)
if pos < 0:
return data, empty_str, empty_str
return data[0:pos], sep, data[pos + sep_length:len(data)]
return impl
@overload_method(types.UnicodeType, 'count')
def unicode_count(src, sub, start=None, end=None):
_count_args_types_check(start)
_count_args_types_check(end)
if isinstance(sub, types.UnicodeType):
def count_impl(src, sub, start=None, end=None):
count = 0
src_len = len(src)
sub_len = len(sub)
start = _normalize_slice_idx_count(start, src_len, 0)
end = _normalize_slice_idx_count(end, src_len, src_len)
if end - start < 0 or start > src_len:
return 0
src = src[start : end]
src_len = len(src)
start, end = 0, src_len
if sub_len == 0:
return src_len + 1
while(start + sub_len <= src_len):
if src[start : start + sub_len] == sub:
count += 1
start += sub_len
else:
start += 1
return count
return count_impl
error_msg = "The substring must be a UnicodeType, not {}"
raise TypingError(error_msg.format(type(sub)))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12979-L13033 # noqa: E501
@overload_method(types.UnicodeType, 'rpartition')
def unicode_rpartition(data, sep):
"""Implements str.rpartition()"""
thety = sep
# if the type is omitted, the concrete type is the value
if isinstance(sep, types.Omitted):
thety = sep.value
# if the type is optional, the concrete type is the captured type
elif isinstance(sep, types.Optional):
thety = sep.type
accepted = (types.UnicodeType, types.UnicodeCharSeq)
if thety is not None and not isinstance(thety, accepted):
msg = '"{}" must be {}, not {}'.format('sep', accepted, sep)
raise TypingError(msg)
def impl(data, sep):
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/partition.h#L62-L115 # noqa: E501
sep = str(sep)
empty_str = _empty_string(data._kind, 0, data._is_ascii)
sep_length = len(sep)
if data._kind < sep._kind or len(data) < sep_length:
return empty_str, empty_str, data
if sep_length == 0:
raise ValueError('empty separator')
pos = data.rfind(sep)
if pos < 0:
return empty_str, empty_str, data
return data[0:pos], sep, data[pos + sep_length:len(data)]
return impl
@overload_method(types.UnicodeType, 'startswith')
def unicode_startswith(a, b):
if isinstance(b, types.UnicodeType):
def startswith_impl(a, b):
return _cmp_region(a, 0, b, 0, len(b)) == 0
return startswith_impl
if isinstance(b, types.UnicodeCharSeq):
def startswith_impl(a, b):
return a.startswith(str(b))
return startswith_impl
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9342-L9354 # noqa: E501
@register_jitable
def _adjust_indices(length, start, end):
if end > length:
end = length
if end < 0:
end += length
if end < 0:
end = 0
if start < 0:
start += length
if start < 0:
start = 0
return start, end
@overload_method(types.UnicodeType, 'endswith')
def unicode_endswith(s, substr, start=None, end=None):
if not (start is None or isinstance(start, (types.Omitted,
types.Integer,
types.NoneType))):
raise TypingError('The arg must be a Integer or None')
if not (end is None or isinstance(end, (types.Omitted,
types.Integer,
types.NoneType))):
raise TypingError('The arg must be a Integer or None')
if isinstance(substr, (types.Tuple, types.UniTuple)):
def endswith_impl(s, substr, start=None, end=None):
for item in substr:
if s.endswith(item, start, end) is True:
return True
return False
return endswith_impl
if isinstance(substr, types.UnicodeType):
def endswith_impl(s, substr, start=None, end=None):
length = len(s)
sub_length = len(substr)
if start is None:
start = 0
if end is None:
end = length
start, end = _adjust_indices(length, start, end)
if end - start < sub_length:
return False
if sub_length == 0:
return True
s = s[start:end]
offset = len(s) - sub_length
return _cmp_region(s, offset, substr, 0, sub_length) == 0
return endswith_impl
if isinstance(substr, types.UnicodeCharSeq):
def endswith_impl(s, substr, start=None, end=None):
return s.endswith(str(substr), start, end)
return endswith_impl
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11519-L11595 # noqa: E501
@overload_method(types.UnicodeType, 'expandtabs')
def unicode_expandtabs(data, tabsize=8):
"""Implements str.expandtabs()"""
thety = tabsize
# if the type is omitted, the concrete type is the value
if isinstance(tabsize, types.Omitted):
thety = tabsize.value
# if the type is optional, the concrete type is the captured type
elif isinstance(tabsize, types.Optional):
thety = tabsize.type
accepted = (types.Integer, int)
if thety is not None and not isinstance(thety, accepted):
raise TypingError(
'"tabsize" must be {}, not {}'.format(accepted, tabsize))
def expandtabs_impl(data, tabsize=8):
length = len(data)
j = line_pos = 0
found = False
for i in range(length):
code_point = _get_code_point(data, i)
if code_point == _Py_TAB:
found = True
if tabsize > 0:
# cannot overflow
incr = tabsize - (line_pos % tabsize)
if j > sys.maxsize - incr:
raise OverflowError('new string is too long')
line_pos += incr
j += incr
else:
if j > sys.maxsize - 1:
raise OverflowError('new string is too long')
line_pos += 1
j += 1
if code_point in (_Py_LINEFEED, _Py_CARRIAGE_RETURN):
line_pos = 0
if not found:
return data
res = _empty_string(data._kind, j, data._is_ascii)
j = line_pos = 0
for i in range(length):
code_point = _get_code_point(data, i)
if code_point == _Py_TAB:
if tabsize > 0:
incr = tabsize - (line_pos % tabsize)
line_pos += incr
for idx in range(j, j + incr):
_set_code_point(res, idx, _Py_SPACE)
j += incr
else:
line_pos += 1
_set_code_point(res, j, code_point)
j += 1
if code_point in (_Py_LINEFEED, _Py_CARRIAGE_RETURN):
line_pos = 0
return res
return expandtabs_impl
@overload_method(types.UnicodeType, 'split')
def unicode_split(a, sep=None, maxsplit=-1):
if not (maxsplit == -1 or
isinstance(maxsplit, (types.Omitted, types.Integer,
types.IntegerLiteral))):
return None # fail typing if maxsplit is not an integer
if isinstance(sep, types.UnicodeCharSeq):
def split_impl(a, sep=None, maxsplit=-1):
return a.split(str(sep), maxsplit=maxsplit)
return split_impl
if isinstance(sep, types.UnicodeType):
def split_impl(a, sep=None, maxsplit=-1):
a_len = len(a)
sep_len = len(sep)
if sep_len == 0:
raise ValueError('empty separator')
parts = []
last = 0
idx = 0
if sep_len == 1 and maxsplit == -1:
sep_code_point = _get_code_point(sep, 0)
for idx in range(a_len):
if _get_code_point(a, idx) == sep_code_point:
parts.append(a[last:idx])
last = idx + 1
else:
split_count = 0
while idx < a_len and (maxsplit == -1 or
split_count < maxsplit):
if _cmp_region(a, idx, sep, 0, sep_len) == 0:
parts.append(a[last:idx])
idx += sep_len
last = idx
split_count += 1
else:
idx += 1
if last <= a_len:
parts.append(a[last:])
return parts
return split_impl
elif sep is None or isinstance(sep, types.NoneType) or \
getattr(sep, 'value', False) is None:
def split_whitespace_impl(a, sep=None, maxsplit=-1):
a_len = len(a)
parts = []
last = 0
idx = 0
split_count = 0
in_whitespace_block = True
for idx in range(a_len):
code_point = _get_code_point(a, idx)
is_whitespace = _PyUnicode_IsSpace(code_point)
if in_whitespace_block:
if is_whitespace:
pass # keep consuming space
else:
last = idx # this is the start of the next string
in_whitespace_block = False
else:
if not is_whitespace:
pass # keep searching for whitespace transition
else:
parts.append(a[last:idx])
in_whitespace_block = True
split_count += 1
if maxsplit != -1 and split_count == maxsplit:
break
if last <= a_len and not in_whitespace_block:
parts.append(a[last:])
return parts
return split_whitespace_impl
def generate_rsplit_whitespace_impl(isspace_func):
"""Generate whitespace rsplit func based on either ascii or unicode"""
def rsplit_whitespace_impl(data, sep=None, maxsplit=-1):
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L192-L240 # noqa: E501
if maxsplit < 0:
maxsplit = sys.maxsize
result = []
i = len(data) - 1
while maxsplit > 0:
while i >= 0:
code_point = _get_code_point(data, i)
if not isspace_func(code_point):
break
i -= 1
if i < 0:
break
j = i
i -= 1
while i >= 0:
code_point = _get_code_point(data, i)
if isspace_func(code_point):
break
i -= 1
result.append(data[i + 1:j + 1])
maxsplit -= 1
if i >= 0:
# Only occurs when maxsplit was reached
# Skip any remaining whitespace and copy to beginning of string
while i >= 0:
code_point = _get_code_point(data, i)
if not isspace_func(code_point):
break
i -= 1
if i >= 0:
result.append(data[0:i + 1])
return result[::-1]
return rsplit_whitespace_impl
unicode_rsplit_whitespace_impl = register_jitable(
generate_rsplit_whitespace_impl(_PyUnicode_IsSpace))
ascii_rsplit_whitespace_impl = register_jitable(
generate_rsplit_whitespace_impl(_Py_ISSPACE))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L13095-L13108 # noqa: E501
@overload_method(types.UnicodeType, 'rsplit')
def unicode_rsplit(data, sep=None, maxsplit=-1):
"""Implements str.unicode_rsplit()"""
def _unicode_rsplit_check_type(ty, name, accepted):
"""Check object belongs to one of specified types"""
thety = ty
# if the type is omitted, the concrete type is the value
if isinstance(ty, types.Omitted):
thety = ty.value
# if the type is optional, the concrete type is the captured type
elif isinstance(ty, types.Optional):
thety = ty.type
if thety is not None and not isinstance(thety, accepted):
raise TypingError(
'"{}" must be {}, not {}'.format(name, accepted, ty))
_unicode_rsplit_check_type(sep, 'sep', (types.UnicodeType,
types.UnicodeCharSeq,
types.NoneType))
_unicode_rsplit_check_type(maxsplit, 'maxsplit', (types.Integer, int))
if sep is None or isinstance(sep, (types.NoneType, types.Omitted)):
def rsplit_whitespace_impl(data, sep=None, maxsplit=-1):
if data._is_ascii:
return ascii_rsplit_whitespace_impl(data, sep, maxsplit)
return unicode_rsplit_whitespace_impl(data, sep, maxsplit)
return rsplit_whitespace_impl
def rsplit_impl(data, sep=None, maxsplit=-1):
sep = str(sep)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L286-L333 # noqa: E501
if data._kind < sep._kind or len(data) < len(sep):
return [data]
def _rsplit_char(data, ch, maxsplit):
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L242-L284 # noqa: E501
result = []
ch_code_point = _get_code_point(ch, 0)
i = j = len(data) - 1
while i >= 0 and maxsplit > 0:
data_code_point = _get_code_point(data, i)
if data_code_point == ch_code_point:
result.append(data[i + 1 : j + 1])
j = i = i - 1
maxsplit -= 1
i -= 1
if j >= -1:
result.append(data[0 : j + 1])
return result[::-1]
if maxsplit < 0:
maxsplit = sys.maxsize
sep_length = len(sep)
if sep_length == 0:
raise ValueError('empty separator')
if sep_length == 1:
return _rsplit_char(data, sep, maxsplit)
result = []
j = len(data)
while maxsplit > 0:
pos = data.rfind(sep, start=0, end=j)
if pos < 0:
break
result.append(data[pos + sep_length:j])
j = pos
maxsplit -= 1
result.append(data[0:j])
return result[::-1]
return rsplit_impl
@overload_method(types.UnicodeType, 'center')
def unicode_center(string, width, fillchar=' '):
if not isinstance(width, types.Integer):
raise TypingError('The width must be an Integer')
if isinstance(fillchar, types.UnicodeCharSeq):
def center_impl(string, width, fillchar=' '):
return string.center(width, str(fillchar))
return center_impl
if not (fillchar == ' ' or
isinstance(fillchar, (types.Omitted, types.UnicodeType))):
raise TypingError('The fillchar must be a UnicodeType')
def center_impl(string, width, fillchar=' '):
str_len = len(string)
fillchar_len = len(fillchar)
if fillchar_len != 1:
raise ValueError('The fill character must be exactly one '
'character long')
if width <= str_len:
return string
allmargin = width - str_len
lmargin = (allmargin // 2) + (allmargin & width & 1)
rmargin = allmargin - lmargin
l_string = fillchar * lmargin
if lmargin == rmargin:
return l_string + string + l_string
else:
return l_string + string + (fillchar * rmargin)
return center_impl
def gen_unicode_Xjust(STRING_FIRST):
def unicode_Xjust(string, width, fillchar=' '):
if not isinstance(width, types.Integer):
raise TypingError('The width must be an Integer')
if isinstance(fillchar, types.UnicodeCharSeq):
if STRING_FIRST:
def ljust_impl(string, width, fillchar=' '):
return string.ljust(width, str(fillchar))
return ljust_impl
else:
def rjust_impl(string, width, fillchar=' '):
return string.rjust(width, str(fillchar))
return rjust_impl
if not (fillchar == ' ' or
isinstance(fillchar, (types.Omitted, types.UnicodeType))):
raise TypingError('The fillchar must be a UnicodeType')
def impl(string, width, fillchar=' '):
str_len = len(string)
fillchar_len = len(fillchar)
if fillchar_len != 1:
raise ValueError('The fill character must be exactly one '
'character long')
if width <= str_len:
return string
newstr = (fillchar * (width - str_len))
if STRING_FIRST:
return string + newstr
else:
return newstr + string
return impl
return unicode_Xjust
overload_method(types.UnicodeType, 'rjust')(gen_unicode_Xjust(False))
overload_method(types.UnicodeType, 'ljust')(gen_unicode_Xjust(True))
def generate_splitlines_func(is_line_break_func):
"""Generate splitlines performer based on ascii or unicode line breaks."""
def impl(data, keepends):
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L335-L389 # noqa: E501
length = len(data)
result = []
i = j = 0
while i < length:
# find a line and append it
while i < length:
code_point = _get_code_point(data, i)
if is_line_break_func(code_point):
break
i += 1
# skip the line break reading CRLF as one line break
eol = i
if i < length:
if i + 1 < length:
cur_cp = _get_code_point(data, i)
next_cp = _get_code_point(data, i + 1)
if _Py_ISCARRIAGERETURN(cur_cp) and _Py_ISLINEFEED(next_cp):
i += 1
i += 1
if keepends:
eol = i
result.append(data[j:eol])
j = i
return result
return impl
_ascii_splitlines = register_jitable(generate_splitlines_func(_Py_ISLINEBREAK))
_unicode_splitlines = register_jitable(generate_splitlines_func(
_PyUnicode_IsLineBreak))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L10196-L10229 # noqa: E501
@overload_method(types.UnicodeType, 'splitlines')
def unicode_splitlines(data, keepends=False):
"""Implements str.splitlines()"""
thety = keepends
# if the type is omitted, the concrete type is the value
if isinstance(keepends, types.Omitted):
thety = keepends.value
# if the type is optional, the concrete type is the captured type
elif isinstance(keepends, types.Optional):
thety = keepends.type
accepted = (types.Integer, int, types.Boolean, bool)
if thety is not None and not isinstance(thety, accepted):
raise TypingError(
'"{}" must be {}, not {}'.format('keepends', accepted, keepends))
def splitlines_impl(data, keepends=False):
if data._is_ascii:
return _ascii_splitlines(data, keepends)
return _unicode_splitlines(data, keepends)
return splitlines_impl
@register_jitable
def join_list(sep, parts):
parts_len = len(parts)
if parts_len == 0:
return ''
# Precompute size and char_width of result
sep_len = len(sep)
length = (parts_len - 1) * sep_len
kind = sep._kind
is_ascii = sep._is_ascii
for p in parts:
length += len(p)
kind = _pick_kind(kind, p._kind)
is_ascii = _pick_ascii(is_ascii, p._is_ascii)
result = _empty_string(kind, length, is_ascii)
# populate string
part = parts[0]
_strncpy(result, 0, part, 0, len(part))
dst_offset = len(part)
for idx in range(1, parts_len):
_strncpy(result, dst_offset, sep, 0, sep_len)
dst_offset += sep_len
part = parts[idx]
_strncpy(result, dst_offset, part, 0, len(part))
dst_offset += len(part)
return result
@overload_method(types.UnicodeType, 'join')
def unicode_join(sep, parts):
if isinstance(parts, types.List):
if isinstance(parts.dtype, types.UnicodeType):
def join_list_impl(sep, parts):
return join_list(sep, parts)
return join_list_impl
elif isinstance(parts.dtype, types.UnicodeCharSeq):
def join_list_impl(sep, parts):
_parts = [str(p) for p in parts]
return join_list(sep, _parts)
return join_list_impl
else:
pass # lists of any other type not supported
elif isinstance(parts, types.IterableType):
def join_iter_impl(sep, parts):
parts_list = [p for p in parts]
return join_list(sep, parts_list)
return join_iter_impl
elif isinstance(parts, types.UnicodeType):
# Temporary workaround until UnicodeType is iterable
def join_str_impl(sep, parts):
parts_list = [parts[i] for i in range(len(parts))]
return join_list(sep, parts_list)
return join_str_impl
@overload_method(types.UnicodeType, 'zfill')
def unicode_zfill(string, width):
if not isinstance(width, types.Integer):
raise TypingError("<width> must be an Integer")
def zfill_impl(string, width):
str_len = len(string)
if width <= str_len:
return string
first_char = string[0] if str_len else ''
padding = '0' * (width - str_len)
if first_char in ['+', '-']:
newstr = first_char + padding + string[1:]
else:
newstr = padding + string
return newstr
return zfill_impl
# ------------------------------------------------------------------------------
# Strip functions
# ------------------------------------------------------------------------------
@register_jitable
def unicode_strip_left_bound(string, chars):
str_len = len(string)
i = 0
if chars is not None:
for i in range(str_len):
if string[i] not in chars:
return i
else:
for i in range(str_len):
if not _PyUnicode_IsSpace(string[i]):
return i
return str_len
@register_jitable
def unicode_strip_right_bound(string, chars):
str_len = len(string)
i = 0
if chars is not None:
for i in range(str_len - 1, -1, -1):
if string[i] not in chars:
i += 1
break
else:
for i in range(str_len - 1, -1, -1):
if not _PyUnicode_IsSpace(string[i]):
i += 1
break
return i
def unicode_strip_types_check(chars):
if isinstance(chars, types.Optional):
chars = chars.type # catch optional type with invalid non-None type
if not (chars is None or isinstance(chars, (types.Omitted,
types.UnicodeType,
types.NoneType))):
raise TypingError('The arg must be a UnicodeType or None')
def _count_args_types_check(arg):
if isinstance(arg, types.Optional):
arg = arg.type
if not (arg is None or isinstance(arg, (types.Omitted,
types.Integer,
types.NoneType))):
raise TypingError("The slice indices must be an Integer or None")
@overload_method(types.UnicodeType, 'lstrip')
def unicode_lstrip(string, chars=None):
if isinstance(chars, types.UnicodeCharSeq):
def lstrip_impl(string, chars=None):
return string.lstrip(str(chars))
return lstrip_impl
unicode_strip_types_check(chars)
def lstrip_impl(string, chars=None):
return string[unicode_strip_left_bound(string, chars):]
return lstrip_impl
@overload_method(types.UnicodeType, 'rstrip')
def unicode_rstrip(string, chars=None):
if isinstance(chars, types.UnicodeCharSeq):
def rstrip_impl(string, chars=None):
return string.rstrip(str(chars))
return rstrip_impl
unicode_strip_types_check(chars)
def rstrip_impl(string, chars=None):
return string[:unicode_strip_right_bound(string, chars)]
return rstrip_impl
@overload_method(types.UnicodeType, 'strip')
def unicode_strip(string, chars=None):
if isinstance(chars, types.UnicodeCharSeq):
def strip_impl(string, chars=None):
return string.strip(str(chars))
return strip_impl
unicode_strip_types_check(chars)
def strip_impl(string, chars=None):
lb = unicode_strip_left_bound(string, chars)
rb = unicode_strip_right_bound(string, chars)
return string[lb:rb]
return strip_impl
# ------------------------------------------------------------------------------
# Slice functions
# ------------------------------------------------------------------------------
@register_jitable
def normalize_str_idx(idx, length, is_start=True):
"""
Parameters
----------
idx : int or None
the index
length : int
the string length
is_start : bool; optional with defaults to True
Is it the *start* or the *stop* of the slice?
Returns
-------
norm_idx : int
normalized index
"""
if idx is None:
if is_start:
return 0
else:
return length
elif idx < 0:
idx += length
if idx < 0 or idx >= length:
raise IndexError("string index out of range")
return idx
@register_jitable
def _normalize_slice_idx_count(arg, slice_len, default):
"""
Used for unicode_count
If arg < -slice_len, returns 0 (prevents circle)
If arg is within slice, e.g -slice_len <= arg < slice_len
returns its real index via arg % slice_len
If arg > slice_len, returns arg (in this case count must
return 0 if it is start index)
"""
if arg is None:
return default
if -slice_len <= arg < slice_len:
return arg % slice_len
return 0 if arg < 0 else arg
@intrinsic
def _normalize_slice(typingctx, sliceobj, length):
"""Fix slice object.
"""
sig = sliceobj(sliceobj, length)
def codegen(context, builder, sig, args):
[slicetype, lengthtype] = sig.args
[sliceobj, length] = args
slice = context.make_helper(builder, slicetype, sliceobj)
slicing.guard_invalid_slice(context, builder, slicetype, slice)
slicing.fix_slice(builder, slice, length)
return slice._getvalue()
return sig, codegen
@intrinsic
def _slice_span(typingctx, sliceobj):
"""Compute the span from the given slice object.
"""
sig = types.intp(sliceobj)
def codegen(context, builder, sig, args):
[slicetype] = sig.args
[sliceobj] = args
slice = context.make_helper(builder, slicetype, sliceobj)
result_size = slicing.get_slice_length(builder, slice)
return result_size
return sig, codegen
@register_jitable(_nrt=False)
def _strncpy(dst, dst_offset, src, src_offset, n):
if src._kind == dst._kind:
byte_width = _kind_to_byte_width(src._kind)
src_byte_offset = byte_width * src_offset
dst_byte_offset = byte_width * dst_offset
nbytes = n * byte_width
memcpy_region(dst._data, dst_byte_offset, src._data,
src_byte_offset, nbytes, align=1)
else:
for i in range(n):
_set_code_point(dst, dst_offset + i,
_get_code_point(src, src_offset + i))
@intrinsic
def _get_str_slice_view(typingctx, src_t, start_t, length_t):
"""Create a slice of a unicode string using a view of its data to avoid
extra allocation.
"""
assert src_t == types.unicode_type
def codegen(context, builder, sig, args):
src, start, length = args
in_str = cgutils.create_struct_proxy(
types.unicode_type)(context, builder, value=src)
view_str = cgutils.create_struct_proxy(
types.unicode_type)(context, builder)
view_str.meminfo = in_str.meminfo
view_str.kind = in_str.kind
view_str.is_ascii = in_str.is_ascii
view_str.length = length
# hash value -1 to indicate "need to compute hash"
view_str.hash = context.get_constant(_Py_hash_t, -1)
# get a pointer to start of slice data
bw_typ = context.typing_context.resolve_value_type(_kind_to_byte_width)
bw_sig = bw_typ.get_call_type(
context.typing_context, (types.int32,), {})
bw_impl = context.get_function(bw_typ, bw_sig)
byte_width = bw_impl(builder, (in_str.kind,))
offset = builder.mul(start, byte_width)
view_str.data = builder.gep(in_str.data, [offset])
# Set parent pyobject to NULL
view_str.parent = cgutils.get_null_value(view_str.parent.type)
# incref original string
if context.enable_nrt:
context.nrt.incref(builder, sig.args[0], src)
return view_str._getvalue()
sig = types.unicode_type(types.unicode_type, types.intp, types.intp)
return sig, codegen
@overload(operator.getitem)
def unicode_getitem(s, idx):
if isinstance(s, types.UnicodeType):
if isinstance(idx, types.Integer):
def getitem_char(s, idx):
idx = normalize_str_idx(idx, len(s))
cp = _get_code_point(s, idx)
kind = _codepoint_to_kind(cp)
is_ascii = _codepoint_is_ascii(cp)
ret = _empty_string(kind, 1, is_ascii)
_set_code_point(ret, 0, cp)
return ret
return getitem_char
elif isinstance(idx, types.SliceType):
def getitem_slice(s, idx):
slice_idx = _normalize_slice(idx, len(s))
span = _slice_span(slice_idx)
cp = _get_code_point(s, slice_idx.start)
kind = _codepoint_to_kind(cp)
is_ascii = _codepoint_is_ascii(cp)
# Check slice to see if it's homogeneous in kind
for i in range(slice_idx.start + slice_idx.step,
slice_idx.stop, slice_idx.step):
cp = _get_code_point(s, i)
is_ascii |= _codepoint_is_ascii(cp)
new_kind = _codepoint_to_kind(cp)
if kind != new_kind:
kind = _pick_kind(kind, new_kind)
# TODO: it might be possible to break here if the kind
# is PY_UNICODE_4BYTE_KIND but there are potentially
# strings coming from other internal functions that are
# this wide and also actually ASCII (i.e. kind is larger
# than actually required for storing the code point), so
# it's necessary to continue.
if slice_idx.step == 1 and kind == s._kind:
# Can return a view, the slice has the same kind as the
# string itself and it's a stride slice 1.
return _get_str_slice_view(s, slice_idx.start, span)
else:
# It's heterogeneous in kind OR stride != 1
ret = _empty_string(kind, span, is_ascii)
cur = slice_idx.start
for i in range(span):
_set_code_point(ret, i, _get_code_point(s, cur))
cur += slice_idx.step
return ret
return getitem_slice
# ------------------------------------------------------------------------------
# String operations
# ------------------------------------------------------------------------------
@overload(operator.add)
@overload(operator.iadd)
def unicode_concat(a, b):
if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeType):
def concat_impl(a, b):
new_length = a._length + b._length
new_kind = _pick_kind(a._kind, b._kind)
new_ascii = _pick_ascii(a._is_ascii, b._is_ascii)
result = _empty_string(new_kind, new_length, new_ascii)
for i in range(len(a)):
_set_code_point(result, i, _get_code_point(a, i))
for j in range(len(b)):
_set_code_point(result, len(a) + j, _get_code_point(b, j))
return result
return concat_impl
if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeCharSeq):
def concat_impl(a, b):
return a + str(b)
return concat_impl
@register_jitable
def _repeat_impl(str_arg, mult_arg):
if str_arg == '' or mult_arg < 1:
return ''
elif mult_arg == 1:
return str_arg
else:
new_length = str_arg._length * mult_arg
new_kind = str_arg._kind
result = _empty_string(new_kind, new_length, str_arg._is_ascii)
# make initial copy into result
len_a = len(str_arg)
_strncpy(result, 0, str_arg, 0, len_a)
# loop through powers of 2 for efficient copying
copy_size = len_a
while 2 * copy_size <= new_length:
_strncpy(result, copy_size, result, 0, copy_size)
copy_size *= 2
if not 2 * copy_size == new_length:
# if copy_size not an exact multiple it then needs
# to complete the rest of the copies
rest = new_length - copy_size
_strncpy(result, copy_size, result, copy_size - rest, rest)
return result
@overload(operator.mul)
def unicode_repeat(a, b):
if isinstance(a, types.UnicodeType) and isinstance(b, types.Integer):
def wrap(a, b):
return _repeat_impl(a, b)
return wrap
elif isinstance(a, types.Integer) and isinstance(b, types.UnicodeType):
def wrap(a, b):
return _repeat_impl(b, a)
return wrap
@overload(operator.not_)
def unicode_not(a):
if isinstance(a, types.UnicodeType):
def impl(a):
return len(a) == 0
return impl
@overload_method(types.UnicodeType, 'replace')
def unicode_replace(s, old_str, new_str, count=-1):
thety = count
if isinstance(count, types.Omitted):
thety = count.value
elif isinstance(count, types.Optional):
thety = count.type
if not isinstance(thety, (int, types.Integer)):
raise TypingError('Unsupported parameters. The parametrs '
'must be Integer. Given count: {}'.format(count))
if not isinstance(old_str, (types.UnicodeType, types.NoneType)):
raise TypingError('The object must be a UnicodeType.'
' Given: {}'.format(old_str))
if not isinstance(new_str, types.UnicodeType):
raise TypingError('The object must be a UnicodeType.'
' Given: {}'.format(new_str))
def impl(s, old_str, new_str, count=-1):
if count == 0:
return s
if old_str == '':
schars = list(s)
if count == -1:
return new_str + new_str.join(schars) + new_str
split_result = [new_str]
min_count = min(len(schars), count)
for i in range(min_count):
split_result.append(schars[i])
if i + 1 != min_count:
split_result.append(new_str)
else:
split_result.append(''.join(schars[(i + 1):]))
if count > len(schars):
split_result.append(new_str)
return ''.join(split_result)
schars = s.split(old_str, count)
result = new_str.join(schars)
return result
return impl
# ------------------------------------------------------------------------------
# String `is*()` methods
# ------------------------------------------------------------------------------
# generates isalpha/isalnum
def gen_isAlX(ascii_func, unicode_func):
def unicode_isAlX(data):
def impl(data):
length = len(data)
if length == 0:
return False
if length == 1:
code_point = _get_code_point(data, 0)
if data._is_ascii:
return ascii_func(code_point)
else:
return unicode_func(code_point)
if data._is_ascii:
for i in range(length):
code_point = _get_code_point(data, i)
if not ascii_func(code_point):
return False
for i in range(length):
code_point = _get_code_point(data, i)
if not unicode_func(code_point):
return False
return True
return impl
return unicode_isAlX
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11928-L11964 # noqa: E501
overload_method(types.UnicodeType, 'isalpha')(gen_isAlX(_Py_ISALPHA,
_PyUnicode_IsAlpha))
_unicode_is_alnum = register_jitable(lambda x:
(_PyUnicode_IsNumeric(x) or
_PyUnicode_IsAlpha(x)))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11975-L12006 # noqa: E501
overload_method(types.UnicodeType, 'isalnum')(gen_isAlX(_Py_ISALNUM,
_unicode_is_alnum))
def _is_upper(is_lower, is_upper, is_title):
# impl is an approximate translation of:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11794-L11827 # noqa: E501
# mixed with:
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L218-L242 # noqa: E501
def impl(a):
l = len(a)
if l == 1:
return is_upper(_get_code_point(a, 0))
if l == 0:
return False
cased = False
for idx in range(l):
code_point = _get_code_point(a, idx)
if is_lower(code_point) or is_title(code_point):
return False
elif(not cased and is_upper(code_point)):
cased = True
return cased
return impl
_always_false = register_jitable(lambda x: False)
_ascii_is_upper = register_jitable(_is_upper(_Py_ISLOWER, _Py_ISUPPER,
_always_false))
_unicode_is_upper = register_jitable(_is_upper(_PyUnicode_IsLowercase,
_PyUnicode_IsUppercase,
_PyUnicode_IsTitlecase))
@overload_method(types.UnicodeType, 'isupper')
def unicode_isupper(a):
"""
Implements .isupper()
"""
def impl(a):
if a._is_ascii:
return _ascii_is_upper(a)
else:
return _unicode_is_upper(a)
return impl
if utils.PYVERSION >= (3, 7):
@overload_method(types.UnicodeType, 'isascii')
def unicode_isascii(data):
"""Implements UnicodeType.isascii()"""
def impl(data):
return data._is_ascii
return impl
@overload_method(types.UnicodeType, 'istitle')
def unicode_istitle(data):
"""
Implements UnicodeType.istitle()
The algorithm is an approximate translation from CPython:
https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11829-L11885 # noqa: E501
"""
def impl(data):
length = len(data)
if length == 1:
char = _get_code_point(data, 0)
return _PyUnicode_IsUppercase(char) or _PyUnicode_IsTitlecase(char)
if length == 0:
return False
cased = False
previous_is_cased = False
for idx in range(length):
char = _get_code_point(data, idx)
if _PyUnicode_IsUppercase(char) or _PyUnicode_IsTitlecase(char):
if previous_is_cased:
return False
previous_is_cased = True
cased = True
elif _PyUnicode_IsLowercase(char):
if not previous_is_cased:
return False
previous_is_cased = True
cased = True
else:
previous_is_cased = False
return cased
return impl
@overload_method(types.UnicodeType, 'islower')
def unicode_islower(data):
"""
impl is an approximate translation of:
https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L11900-L11933 # noqa: E501
mixed with:
https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/bytes_methods.c#L131-L156 # noqa: E501
"""
def impl(data):
length = len(data)
if length == 1:
return _PyUnicode_IsLowercase(_get_code_point(data, 0))
if length == 0:
return False
cased = False
for idx in range(length):
cp = _get_code_point(data, idx)
if _PyUnicode_IsUppercase(cp) or _PyUnicode_IsTitlecase(cp):
return False
elif not cased and _PyUnicode_IsLowercase(cp):
cased = True
return cased
return impl
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12126-L12161 # noqa: E501
@overload_method(types.UnicodeType, 'isidentifier')
def unicode_isidentifier(data):
"""Implements UnicodeType.isidentifier()"""
def impl(data):
length = len(data)
if length == 0:
return False
first_cp = _get_code_point(data, 0)
if not _PyUnicode_IsXidStart(first_cp) and first_cp != 0x5F:
return False
for i in range(1, length):
code_point = _get_code_point(data, i)
if not _PyUnicode_IsXidContinue(code_point):
return False
return True
return impl
# generator for simple unicode "isX" methods
def gen_isX(_PyUnicode_IS_func, empty_is_false=True):
def unicode_isX(data):
def impl(data):
length = len(data)
if length == 1:
return _PyUnicode_IS_func(_get_code_point(data, 0))
if empty_is_false and length == 0:
return False
for i in range(length):
code_point = _get_code_point(data, i)
if not _PyUnicode_IS_func(code_point):
return False
return True
return impl
return unicode_isX
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11896-L11925 # noqa: E501
overload_method(types.UnicodeType, 'isspace')(gen_isX(_PyUnicode_IsSpace))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12096-L12124 # noqa: E501
overload_method(types.UnicodeType, 'isnumeric')(gen_isX(_PyUnicode_IsNumeric))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12056-L12085 # noqa: E501
overload_method(types.UnicodeType, 'isdigit')(gen_isX(_PyUnicode_IsDigit))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12017-L12045 # noqa: E501
overload_method(types.UnicodeType, 'isdecimal')(
gen_isX(_PyUnicode_IsDecimalDigit))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12188-L12213 # noqa: E501
overload_method(types.UnicodeType, 'isprintable')(
gen_isX(_PyUnicode_IsPrintable, False))
# ------------------------------------------------------------------------------
# String methods that apply a transformation to the characters themselves
# ------------------------------------------------------------------------------
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9863-L9908 # noqa: E501
def case_operation(ascii_func, unicode_func):
"""Generate common case operation performer."""
def impl(data):
length = len(data)
if length == 0:
return _empty_string(data._kind, length, data._is_ascii)
if data._is_ascii:
res = _empty_string(data._kind, length, 1)
ascii_func(data, res)
return res
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9863-L9908 # noqa: E501
tmp = _empty_string(PY_UNICODE_4BYTE_KIND, 3 * length, data._is_ascii)
# maxchar should be inside of a list to be pass as argument by reference
maxchars = [0]
newlength = unicode_func(data, length, tmp, maxchars)
maxchar = maxchars[0]
newkind = _codepoint_to_kind(maxchar)
res = _empty_string(newkind, newlength, _codepoint_is_ascii(maxchar))
for i in range(newlength):
_set_code_point(res, i, _get_code_point(tmp, i))
return res
return impl
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9856-L9883 # noqa: E501
@register_jitable
def _handle_capital_sigma(data, length, idx):
"""This is a translation of the function that handles the capital sigma."""
c = 0
j = idx - 1
while j >= 0:
c = _get_code_point(data, j)
if not _PyUnicode_IsCaseIgnorable(c):
break
j -= 1
final_sigma = (j >= 0 and _PyUnicode_IsCased(c))
if final_sigma:
j = idx + 1
while j < length:
c = _get_code_point(data, j)
if not _PyUnicode_IsCaseIgnorable(c):
break
j += 1
final_sigma = (j == length or (not _PyUnicode_IsCased(c)))
return 0x3c2 if final_sigma else 0x3c3
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9885-L9895 # noqa: E501
@register_jitable
def _lower_ucs4(code_point, data, length, idx, mapped):
"""This is a translation of the function that lowers a character."""
if code_point == 0x3A3:
mapped[0] = _handle_capital_sigma(data, length, idx)
return 1
return _PyUnicode_ToLowerFull(code_point, mapped)
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9946-L9965 # noqa: E501
def _gen_unicode_upper_or_lower(lower):
def _do_upper_or_lower(data, length, res, maxchars):
k = 0
for idx in range(length):
mapped = np.zeros(3, dtype=_Py_UCS4)
code_point = _get_code_point(data, idx)
if lower:
n_res = _lower_ucs4(code_point, data, length, idx, mapped)
else:
# might be needed if call _do_upper_or_lower in unicode_upper
n_res = _PyUnicode_ToUpperFull(code_point, mapped)
for m in mapped[:n_res]:
maxchars[0] = max(maxchars[0], m)
_set_code_point(res, k, m)
k += 1
return k
return _do_upper_or_lower
_unicode_upper = register_jitable(_gen_unicode_upper_or_lower(False))
_unicode_lower = register_jitable(_gen_unicode_upper_or_lower(True))
def _gen_ascii_upper_or_lower(func):
def _ascii_upper_or_lower(data, res):
for idx in range(len(data)):
code_point = _get_code_point(data, idx)
_set_code_point(res, idx, func(code_point))
return _ascii_upper_or_lower
_ascii_upper = register_jitable(_gen_ascii_upper_or_lower(_Py_TOUPPER))
_ascii_lower = register_jitable(_gen_ascii_upper_or_lower(_Py_TOLOWER))
@overload_method(types.UnicodeType, 'lower')
def unicode_lower(data):
"""Implements .lower()"""
return case_operation(_ascii_lower, _unicode_lower)
@overload_method(types.UnicodeType, 'upper')
def unicode_upper(data):
"""Implements .upper()"""
return case_operation(_ascii_upper, _unicode_upper)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9819-L9834 # noqa: E501
@register_jitable
def _unicode_casefold(data, length, res, maxchars):
k = 0
mapped = np.zeros(3, dtype=_Py_UCS4)
for idx in range(length):
mapped.fill(0)
code_point = _get_code_point(data, idx)
n_res = _PyUnicode_ToFoldedFull(code_point, mapped)
for m in mapped[:n_res]:
maxchar = maxchars[0]
maxchars[0] = max(maxchar, m)
_set_code_point(res, k, m)
k += 1
return k
@register_jitable
def _ascii_casefold(data, res):
for idx in range(len(data)):
code_point = _get_code_point(data, idx)
_set_code_point(res, idx, _Py_TOLOWER(code_point))
@overload_method(types.UnicodeType, 'casefold')
def unicode_casefold(data):
"""Implements str.casefold()"""
return case_operation(_ascii_casefold, _unicode_casefold)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9737-L9759 # noqa: E501
@register_jitable
def _unicode_capitalize(data, length, res, maxchars):
k = 0
maxchar = 0
mapped = np.zeros(3, dtype=_Py_UCS4)
code_point = _get_code_point(data, 0)
# https://github.com/python/cpython/commit/b015fc86f7b1f35283804bfee788cce0a5495df7/Objects/unicodeobject.c#diff-220e5da0d1c8abf508b25c02da6ca16c # noqa: E501
if _py38_or_later:
n_res = _PyUnicode_ToTitleFull(code_point, mapped)
else:
n_res = _PyUnicode_ToUpperFull(code_point, mapped)
for m in mapped[:n_res]:
maxchar = max(maxchar, m)
_set_code_point(res, k, m)
k += 1
for idx in range(1, length):
mapped.fill(0)
code_point = _get_code_point(data, idx)
n_res = _lower_ucs4(code_point, data, length, idx, mapped)
for m in mapped[:n_res]:
maxchar = max(maxchar, m)
_set_code_point(res, k, m)
k += 1
maxchars[0] = maxchar
return k
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L361-L382 # noqa: E501
@register_jitable
def _ascii_capitalize(data, res):
code_point = _get_code_point(data, 0)
_set_code_point(res, 0, _Py_TOUPPER(code_point))
for idx in range(1, len(data)):
code_point = _get_code_point(data, idx)
_set_code_point(res, idx, _Py_TOLOWER(code_point))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L10765-L10774 # noqa: E501
@overload_method(types.UnicodeType, 'capitalize')
def unicode_capitalize(data):
return case_operation(_ascii_capitalize, _unicode_capitalize)
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9996-L10021 # noqa: E501
@register_jitable
def _unicode_title(data, length, res, maxchars):
"""This is a translation of the function that titles a unicode string."""
k = 0
previous_cased = False
mapped = np.empty(3, dtype=_Py_UCS4)
for idx in range(length):
mapped.fill(0)
code_point = _get_code_point(data, idx)
if previous_cased:
n_res = _lower_ucs4(code_point, data, length, idx, mapped)
else:
n_res = _PyUnicode_ToTitleFull(_Py_UCS4(code_point), mapped)
for m in mapped[:n_res]:
maxchar, = maxchars
maxchars[0] = max(maxchar, m)
_set_code_point(res, k, m)
k += 1
previous_cased = _PyUnicode_IsCased(_Py_UCS4(code_point))
return k
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L332-L352 # noqa: E501
@register_jitable
def _ascii_title(data, res):
""" Does .title() on an ASCII string """
previous_is_cased = False
for idx in range(len(data)):
code_point = _get_code_point(data, idx)
if _Py_ISLOWER(code_point):
if not previous_is_cased:
code_point = _Py_TOUPPER(code_point)
previous_is_cased = True
elif _Py_ISUPPER(code_point):
if previous_is_cased:
code_point = _Py_TOLOWER(code_point)
previous_is_cased = True
else:
previous_is_cased = False
_set_code_point(res, idx, code_point)
# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L10023-L10069 # noqa: E501
@overload_method(types.UnicodeType, 'title')
def unicode_title(data):
"""Implements str.title()"""
# https://docs.python.org/3/library/stdtypes.html#str.title
return case_operation(_ascii_title, _unicode_title)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L391-L408 # noqa: E501
@register_jitable
def _ascii_swapcase(data, res):
for idx in range(len(data)):
code_point = _get_code_point(data, idx)
if _Py_ISUPPER(code_point):
code_point = _Py_TOLOWER(code_point)
elif _Py_ISLOWER(code_point):
code_point = _Py_TOUPPER(code_point)
_set_code_point(res, idx, code_point)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9761-L9784 # noqa: E501
@register_jitable
def _unicode_swapcase(data, length, res, maxchars):
k = 0
maxchar = 0
mapped = np.empty(3, dtype=_Py_UCS4)
for idx in range(length):
mapped.fill(0)
code_point = _get_code_point(data, idx)
if _PyUnicode_IsUppercase(code_point):
n_res = _lower_ucs4(code_point, data, length, idx, mapped)
elif _PyUnicode_IsLowercase(code_point):
n_res = _PyUnicode_ToUpperFull(code_point, mapped)
else:
n_res = 1
mapped[0] = code_point
for m in mapped[:n_res]:
maxchar = max(maxchar, m)
_set_code_point(res, k, m)
k += 1
maxchars[0] = maxchar
return k
@overload_method(types.UnicodeType, 'swapcase')
def unicode_swapcase(data):
return case_operation(_ascii_swapcase, _unicode_swapcase)
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/bltinmodule.c#L1781-L1824 # noqa: E501
@overload(ord)
def ol_ord(c):
if isinstance(c, types.UnicodeType):
def impl(c):
lc = len(c)
if lc != 1:
# CPython does TypeError
raise TypeError("ord() expected a character")
return _get_code_point(c, 0)
return impl
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L2005-L2028 # noqa: E501
# This looks a bit different to the cpython implementation but, with the
# exception of a latin1 fast path is logically the same. It finds the "kind" of
# the codepoint `ch`, creates a length 1 string of that kind and then injects
# the code point into the zero position of that string. Cpython does similar but
# branches for each kind (this is encapsulated in Numba's _set_code_point).
@register_jitable
def _unicode_char(ch):
assert ch <= _MAX_UNICODE
kind = _codepoint_to_kind(ch)
ret = _empty_string(kind, 1, kind == PY_UNICODE_1BYTE_KIND)
_set_code_point(ret, 0, ch)
return ret
_out_of_range_msg = "chr() arg not in range(0x%hx)" % _MAX_UNICODE
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L3045-L3055 # noqa: E501
@register_jitable
def _PyUnicode_FromOrdinal(ordinal):
if (ordinal < 0 or ordinal > _MAX_UNICODE):
raise ValueError(_out_of_range_msg)
return _unicode_char(_Py_UCS4(ordinal))
# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/bltinmodule.c#L715-L720 # noqa: E501
@overload(chr)
def ol_chr(i):
if isinstance(i, types.Integer):
def impl(i):
return _PyUnicode_FromOrdinal(i)
return impl
@overload(str)
def integer_str(n):
if isinstance(n, types.Integer):
ten = n(10)
def impl(n):
flag = False
if n < 0:
n = -n
flag = True
if n == 0:
return '0'
length = flag + 1 + int(np.floor(np.log10(n)))
kind = PY_UNICODE_1BYTE_KIND
char_width = _kind_to_byte_width(kind)
s = _malloc_string(kind, char_width, length, True)
if flag:
_set_code_point(s, 0, ord('-'))
idx = length - 1
while n > 0:
n, digit = divmod(n, ten)
c = ord('0') + digit
_set_code_point(s, idx, c)
idx -= 1
return s
return impl
# ------------------------------------------------------------------------------
# iteration
# ------------------------------------------------------------------------------
@lower_builtin('getiter', types.UnicodeType)
def getiter_unicode(context, builder, sig, args):
[ty] = sig.args
[data] = args
iterobj = context.make_helper(builder, sig.return_type)
# set the index to zero
zero = context.get_constant(types.uintp, 0)
indexptr = cgutils.alloca_once_value(builder, zero)
iterobj.index = indexptr
# wire in the unicode type data
iterobj.data = data
# incref as needed
if context.enable_nrt:
context.nrt.incref(builder, ty, data)
res = iterobj._getvalue()
return impl_ret_new_ref(context, builder, sig.return_type, res)
@lower_builtin('iternext', types.UnicodeIteratorType)
# a new ref counted object is put into result._yield so set the new_ref to True!
@iternext_impl(RefType.NEW)
def iternext_unicode(context, builder, sig, args, result):
[iterty] = sig.args
[iter] = args
tyctx = context.typing_context
# get ref to unicode.__getitem__
fnty = tyctx.resolve_value_type(operator.getitem)
getitem_sig = fnty.get_call_type(tyctx, (types.unicode_type, types.uintp),
{})
getitem_impl = context.get_function(fnty, getitem_sig)
# get ref to unicode.__len__
fnty = tyctx.resolve_value_type(len)
len_sig = fnty.get_call_type(tyctx, (types.unicode_type,), {})
len_impl = context.get_function(fnty, len_sig)
# grab unicode iterator struct
iterobj = context.make_helper(builder, iterty, value=iter)
# find the length of the string
strlen = len_impl(builder, (iterobj.data,))
# find the current index
index = builder.load(iterobj.index)
# see if the index is in range
is_valid = builder.icmp_unsigned('<', index, strlen)
result.set_valid(is_valid)
with builder.if_then(is_valid):
# return value at index
gotitem = getitem_impl(builder, (iterobj.data, index,))
result.yield_(gotitem)
# bump index for next cycle
nindex = cgutils.increment_index(builder, index)
builder.store(nindex, iterobj.index)