Skip to content

Commit

Permalink
count() method for unicode string
Browse files Browse the repository at this point in the history
  • Loading branch information
rdesai16 committed Jun 19, 2019
1 parent f67296d commit 82b7e57
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
31 changes: 31 additions & 0 deletions numba/tests/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def find_usecase(x, y):
return x.find(y)


def count_usecase(x, y):
return x.count(y)


def count_with_startend_usecase(x,y,start,end):
return x.count(y, start, end)


def startswith_usecase(x, y):
return x.startswith(y)

Expand Down Expand Up @@ -337,6 +345,29 @@ def test_find(self, flags=no_pyobj_flags):
cfunc(a, substr),
"'%s'.find('%s')?" % (a, substr))

def test_count(self):
pyfunc = count_usecase
cfunc = njit(pyfunc)

for s in UNICODE_EXAMPLES:
extras = [' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s]
for sub in [x for x in extras]:
self.assertEqual(pyfunc(sub, s),
cfunc(sub, s),
"'%s' in '%s'?" % (sub, s))

def test_count_with_startend(self):
pyfunc = count_with_startend_usecase
cfunc = njit(pyfunc)

for s in UNICODE_EXAMPLES:
extras = [' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s]
for sub in [x for x in extras]:
for i , j in zip(range(-2,4), (0,6)):
self.assertEqual(pyfunc(sub, s, 2, -1),
cfunc(sub, s, 2, -1),
"'%s' in '%s'?" % (sub, s))

def test_getitem(self):
pyfunc = getitem_usecase
cfunc = njit(pyfunc)
Expand Down
61 changes: 61 additions & 0 deletions numba/unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,67 @@ def find_impl(a, b):
return find_impl


@overload_method(types.UnicodeType, 'count')
def unicode_count(src, sub, start=None, end=None):

if not (start is None or isinstance(start, (types.Omitted, types.Integer, types.Optional))):
raise TypingError("Start arg must be of type Integer, Omitted or Optional")

if isinstance(sub, types.UnicodeType):
def count_impl(src, sub, start=start, e=end):
count = 0
src_len = len(src)
sub_len = len(sub)

if start is None: #This means no optional arguments are given
begin = 0
new_end = src_len
else:
begin = start
new_end = e
if (begin < 0 and new_end < 0 and begin < new_end):
cond_flag = 1
else:
cond_flag = 0
if (begin >= 0 or cond_flag == 1):
if(new_end < 0 and cond_flag == 0):
new_end = src_len + new_end # Positive end bound makes the search easier
if(new_end < begin):
return 0
if(cond_flag == 1): # If both bounds are negative then turn it into positive limits
begin = begin + src_len
new_end = new_end + src_len
i = begin
if( (new_end - begin) == sub_len):
temp_end = new_end
else:
temp_end = new_end - sub_len + 1
while (i < temp_end):
temp_count = 0
offset = 0
for j in range(sub_len):
src_char = _get_code_point(src, i + offset)
sub_char = _get_code_point(sub, j)
if src_char != sub_char:
break
else:
temp_count = temp_count + 1
offset = offset + 1
if (i + offset) > new_end:
break
if temp_count == sub_len:
count = count + 1
i = i + sub_len
else:
i = i + 1
if i == 0:
break
return count
else:
return 0
return count_impl


@overload_method(types.UnicodeType, 'startswith')
def unicode_startswith(a, b):
if isinstance(b, types.UnicodeType):
Expand Down

0 comments on commit 82b7e57

Please sign in to comment.