New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable support for count() method for unicode string #4204
Conversation
Thanks for the PR. Was wondering why there are two merges included in this patchset? I think |
Thank you @stuartarchibald for the comments. As far as I understand the two merge requests were to sync my forked repo with original numba repo and then sync this branch my forked repo. But you're right. My forked repo wasn't in sync with original numba repo and that is why the old count method PR had issues in it. I made it right by updating my forked repo master branch. But if you feel like there are still some issues with this PR then let me know. |
@stuartarchibald This PR is ready for review in my opinion. |
Thanks @stuartarchibald for your quick response. I will clear the git history but meanwhile the patch can be reviewed on its own. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall this is a good start, but the code will need a number of stylistic and algorithmic improvements before it is ready to merge.
numba/tests/test_unicode.py
Outdated
cfunc(sub, s), | ||
"'%s' in '%s'?" % (sub, s)) | ||
|
||
def test_count_with_startend(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for readability, please rename this to test_count_with_start_end
.
numba/unicode.py
Outdated
else: | ||
begin = start | ||
new_end = e | ||
if (begin < 0 and new_end < 0 and begin < new_end): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The next section deals with normalizing and clipping the start
and end
arguments. For simplicity you should probably do that first for each argument separately and you can define a helper function to do so. So for example, if the start argument is negative, take src_length
and add the negative argument to it, if it then is still negative , clip it to zero. If it is positive on the other hand and beyond/larger than src_len
, clip it to src_len
. A similar approach can be taken for the end
argument, and using the helper function will help to reduce overall code length. After this, you will have a valid range/slice in which you will search for the substring. At this stage you can and should also check for early exit conditions, such as the range/slice being zero, the length of the substring being longer than the range/slice and so on (there are a few more). Then, if everything is fine, commence with the substring search proper. This approach should also help reduce the number of cascaded if
statements and make the code more readable overall.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand your suggestion and I have tried to make some changes to consider that. But the issue is that the algorithm is complex and less readable because of the slicing, bound checking conditions and python characteristic to give an answer without error even if the bounds(start,end args) are not correct or feasible for a range. Another issue with creating a helper function is that if you look at the algorithm carefully you see there are certain conditions that need to be checked in the main function for slicing or for returning count value and this complicates separating out some code into a helper function. So I feel that instead of changing the algorithm I can add more comments to make the code readable. Let me know what you think about that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the implementation could work as follows:
- Normalise the slice starts and ends to concrete integer values
- Check if the slice start is in bounds else return 0
- Slice the input string based on the normalised slice
- Deal with empty input string and empty substring cases
- Walk the sliced string using a sliding window of the length of the substring (use a slice?!) to check for matches and either bump the window start position by 1 (no match) or window length (match), accumulate the match count and return this.
Hope this helps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald I have tried to simplify the code to make it more readable in my new changes. Can you review that and see if that looks good? If it doesn't then I'll try simplifying using the algorithm you suggested.
@seibert we are getting very close to the next release candidate ( |
Thank you @esc for the review comments. All the comments are helpful and make sense. So I will make suggested changes and submit it again soon. |
@esc I have made changes as per suggestions. It is ready for review. |
To fix the history, we will soon force push this. |
@stuartarchibald I have used force push to fix the git history. So this PR can be reviewed further now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had another look at this today and here is a simpler implementation that I managed to come up with:
@register_jitable
def normalize_slice_arg(arg, slice_length, default):
if arg is None:
return default
elif arg < 0:
return max(slice_length + arg, default)
else:
return min(slice_length, arg)
@overload_method(types.UnicodeType, 'count')
def unicode_count(src, sub, start=None, end=None):
if not (start is None or isinstance(start, (types.Omitted,
types.Integer,
types.Optional))):
raise TypingError("Start arg must be of type Integer or None")
if isinstance(sub, types.UnicodeType):
def count_impl(src, sub, start=start, end=end):
count, src_len, sub_len = 0, len(src), len(sub)
start = normalize_slice_arg(start, src_len, 0)
end = normalize_slice_arg(end, src_len, src_len)
# Early exit if sub length is zero
if sub_len == 0:
return end - start + 1
i = start
while(i + sub_len <= end):
if src[i:i + sub_len] == sub:
count += 1
i += sub_len
else:
i += 1
return count
return count_impl
As far as I can tell, this passes all of the unit tests, let me know what you think and if there are any additional bugs that I haven't spotted yet.
Also, I think it would be good to have an 'overlapping' test. So for example given the string 'aaaa' and the substring 'aa' we need to find exactly two matches, whereas a naive implementation might find three.
numba/tests/test_unicode.py
Outdated
for s in UNICODE_EXAMPLES: | ||
extras = ['', ' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s] | ||
for sub in [x for x in extras]: | ||
self.assertEqual(pyfunc(sub, s), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe sub
and s
are in the wrong order here. This also applies to the other tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@esc Thanks for the code. I have tried this code and it passes all the tests except one where sub_string is empty. That part needs some change so I'll do that and resubmit it again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that is good, which test was failing? I didn't notice any failing tests. Did this happen after changing the order or sub
and s
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it happened after changing the order for sub and s. It was failing for a case when the slice_range is bigger than src_len. But apart from that the algorithm is really useful because the same will be used in rfind method as well.
|
||
if not (start is None or isinstance(start, (types.Omitted, types.Integer, types.Optional))): | ||
raise TypingError("Start arg must be of type Integer or None") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you will need to type check the end
arg here too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rdesai16 Thanks for the updates. Given this now contains some of @esc's code I'll continue the review. General feedback is that:
- This is much improved, good job with the fixes!
- The type checking needs some work, I've made some suggestions and written a code sample for your perusal.
- The tests could do with some further work to make them more suitable for testing the function proposed, suggestions are made inline.
Thanks again for your work on this.
numba/tests/test_unicode.py
Outdated
return x.count(y) | ||
|
||
|
||
def count_with_start_end_usecase(x,y,start,end): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
code style: spaces after commas
numba/tests/test_unicode.py
Outdated
return x.count(y, start, end) | ||
|
||
|
||
def count_with_start_only_usecase(x,y,start): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
code style: spaces after commas
numba/unicode.py
Outdated
end = normalize_slice_arg(end, src_len, src_len) | ||
|
||
if (sub_len == 0): #special case when substring is empty | ||
return _count_special_case(start, end, src_len) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this function just something like: len(src[slice(start, end)]) + 1
?
Also, does this need to be a function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald Are you asking about _count_special_case? If yes then it needs to be a function because it has multiple conditions so including that in a function makes the code more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is similar to len(src[slice(start, end)]) + 1 but not exactly like that so I think it will be better to put it like this in a separate function. Python has unique behavior when it comes to using such methods with empty strings and more than one condition is needed to cover all the corner cases related to this special case. But if you have other suggestions regarding this I am open for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that this is because the src
string isn't pre-sliced so the computation has to happen. Could the function body just be inlined here, it's small and only called from this location?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay will do that.
numba/unicode.py
Outdated
@overload_method(types.UnicodeType, 'count') | ||
def unicode_count(src, sub, start=None, end=None): | ||
|
||
if not (start is None or isinstance(start, (types.Omitted, types.Integer, types.Optional))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure that this type check is quite right. Perhaps something like this would help:
def check_kwargs(ty, name):
thety = ty
if isinstance(ty, (types.Omitted)): # if the type is omitted, the concrete type is the value
thety = ty.value
elif isinstance(ty, (types.Optional)): # if the type is optional, the concrete type is the captured type
thety = ty.type
accepted = (types.Integer, types.NoneType, int) # classes that are accepted
if not (thety is None or isinstance(thety, accepted)): # check if the resolved type is None or an accepted type
msg = "Argument '{}' must be of type Integer or None"
raise TypingError(msg.format(name))
check_kwargs(start, "start") # check start arg
check_kwargs(end, "end") # check end arg
this is still not ideal because the Python implementation does not have start
and end
as kwargs, seems like they are positional only args, but there's nothing that can be done about this here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unit tests will need writing to ensure that the type rejection logic is correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my confusion but I am bit confused here. So are you suggesting that we should have a separate function for typechecking and then separate unit test for testing that function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald we have several other functions in unicode.py that do typechecking in similar ways. So in my opinion if we create a separate function then it should be flexible enough so that it can be used everywhere or we can keep it this way to maintain code format. what do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that if it is going to be used repeatedly, pulling the function out to module level makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my confusion but I am bit confused here. So are you suggesting that we should have a separate function for typechecking and then separate unit test for testing that function?
I do not mind how the type checking is implemented, but it seems like a separate function would be useful as exactly the same checks are needed for start
and end
. There should be separate unit tests to check that compilation fails gracefully during type inference for unsupported/invalid types, an example is here:
numba/numba/tests/test_unicode.py
Lines 465 to 470 in 37ab142
def test_repeat_exception_float(self): | |
self.disable_leak_check() | |
cfunc = njit(repeat_usecase) | |
with self.assertRaises(TypingError) as raises: | |
cfunc('hi', 2.5) | |
self.assertIn('Invalid use of Function(<built-in function mul>)', str(raises.exception)) |
this checks the multiplication of a string by a floating point value fails during typing.
Hope this helps.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh okay! Got it. I'll make it in a separate function then.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald I have a doubt about this test showed as an example. What does the following line do?
self.assertIn('Invalid use of Function()', str(raises.exception))
Can you explain briefly? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.assertIn('Invalid use of Function()', str(raises.exception))
this line is calling the assertIn
function which asserts something contains something else, str(raises.exception)
is converting the exception captured in the raises
object yielded from the with self.assertRaises
into a string so that it can be compared to an expected string. This is a convenient way of checking that the right class of exception was raised, along with the contents of the exception message being correct.
numba/tests/test_unicode.py
Outdated
for s in UNICODE_EXAMPLES: | ||
extras = ['', ' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s] | ||
for sub in [x for x in extras]: | ||
for i , j in zip(range(-2,4), (0,6)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will yield the values:
-2 0
-1 6
was that intended?!
I expect that itertools.product
would be of use here:
for i , j in itertools.product(range(-20, 20), range(-20, 20)):
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that is intended because I wanted to cover all kinds of cases in the test. I had thought about using itertools but I can't quite remember now why I ended up using zip. Does itertool have significant benefit over zip?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the requirement is to just testing two values then perhaps just spell those out instead of using range? On the basis that it seems like testing more values would be useful, consider the following:
In [51]: for i, j in zip(range(-2,4), (0,6)):
...: print(i, j)
...:
-2 0
-1 6
In [52]: for i, j in ((-2, 0), (-1, 6)):
...: print(i, j)
...:
-2 0
-1 6
In [53]: for i, j in itertools.product(range(-3, 3), range(-2, 2)):
...: print(i, j)
...:
-3 -2
-3 -1
-3 0
-3 1
-2 -2
-2 -1
-2 0
-2 1
-1 -2
-1 -1
-1 0
-1 1
0 -2
0 -1
0 0
0 1
1 -2
1 -1
1 0
1 1
2 -2
2 -1
2 0
2 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ooh I realize my mistake now! Sorry I had some misunderstanding. I hadn't realized until now that I should have used "range" in front of (0,6). I get what you're saying now. I'll fix it. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that if you use two zipped ranges then the result will be monotonic.
In [56]: for i , j in zip(range(-2,4), range(0,6)):
...: print(i, j)
...:
...:
-2 0
-1 1
0 2
1 3
2 4
3 5
whereas itertools.product
will cover the space start > end, start == end, start < end
with a variety of negative and positive values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes you're right! Using itertools.product will cover good range. I will change that.
numba/tests/test_unicode.py
Outdated
for s in UNICODE_EXAMPLES: | ||
extras = ['', ' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s] | ||
for sub in [x for x in extras]: | ||
for i in range(-2, 3): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The magnitude of the range here should be larger to encompass access before and after the valid range, this to accommodate checking with the longer unicode strings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True. I'll change that.
numba/unicode.py
Outdated
if arg is None: | ||
return default | ||
elif arg < 0: | ||
return max( arg + slice_len, default) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return max( arg + slice_len, default) | |
return max(arg + slice_len, default) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh thanks for bringing that to my notice! Will fix it
numba/unicode.py
Outdated
@@ -381,6 +381,18 @@ def _find(substr, s): | |||
return i | |||
return -1 | |||
|
|||
# The following funciton handles special case for count method where substring is empty |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# The following funciton handles special case for count method where substring is empty | |
# The following function handles special case for count method where substring is empty |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again! Will fix it
numba/unicode.py
Outdated
@@ -381,6 +381,18 @@ def _find(substr, s): | |||
return i | |||
return -1 | |||
|
|||
# The following funciton handles special case for count method where substring is empty | |||
@njit(_nrt=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be @register_jitable
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes it should be. Will change that.
pyfunc = count_usecase | ||
cfunc = njit(pyfunc) | ||
|
||
for s in UNICODE_EXAMPLES + ['aaaa']: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure that UNICODE_EXAMPLES
contains suitable material for this test? There's no zero length string, no repeats, no nested repeats. Perhaps take a look at the cpython tests for some inspiration around what would be good to check: https://github.com/python/cpython/blob/5623ac87bbe5de481957eca5eeae06347612fbeb/Lib/test/test_unicode.py#L177-L201
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point! I will add some more tests form cpython add improve it.
Thanks @stuartarchibald for your review. All the comments are very helpful. I have some doubts which I have mentioned in my comments. apart from that I will make necessary changes and submit it. |
@stuartarchibald @esc Ignoring the Travis CI and install errors I have made all the changes as per the comments. You can review it again now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fixes, this is close to working now. However, adding more tests has exposed more issues that need addressing, please see the comments below. Most serious is that there seems to be an algorithmic bug.
Finally, the documentation will need updating to note that this feature has been implemented. The appropriate location is here: https://github.com/numba/numba/blob/81b91f93c60cd4bfe9d6c1cfce1191e64bdf2a5c/docs/source/reference/pysupported.rst#str
Thanks again.
numba/unicode.py
Outdated
@@ -817,6 +872,13 @@ def unicode_strip_types_check(chars): | |||
raise TypingError('The arg must be a UnicodeType or None') | |||
|
|||
|
|||
def args_types_check(arg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function should probably be called _count_arg_type_check
.
numba/tests/test_unicode.py
Outdated
extras = ['', ' ', 'xx', s[::-1], s[:-2], s[3:], s, s + s, 'a', 'a_', | ||
'\u0102', 'a\u0102', 'a\U00100304', '\u0102_', '\u0102\U00100304'] | ||
for sub in [x for x in extras]: | ||
for i , j in product(range(-10,10), (0,20)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for i , j in product(range(-10,10), (0,20)): | |
for i , j in product(range(-10, 10), (0, 20)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, might it make sense for these ranges extend symmetrically about zero and to the length of the longest test string+1 ? Same comment applies below to the start only
case.
For example, right now, this fails and was not caught by this test:
from numba import njit
@njit
def foo(x, y):
return x.count(y, -40, -40)
a = "ascii"
b = ""
print(foo.py_func(a, b))
print(foo(a, b))
the implementation in this PR erroneously returns 6
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@stuartarchibald I am not clear about the range suggestion. So are you suggesting to have the range like (-20,20) or (0,20)?
@@ -869,6 +872,13 @@ def unicode_strip_types_check(chars): | |||
raise TypingError('The arg must be a UnicodeType or None') | |||
|
|||
|
|||
def args_types_check(arg): | |||
if not (arg is None or isinstance(arg, (types.Omitted, | |||
types.Optional, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type of the Optional
should be checked to ensure it also is an integer type, for example:
from numba import njit
@njit
def foo(x, y):
if len(x) > 14:
maybe_none = 'oops'
else:
maybe_none = None
return x.count(y, maybe_none, 1)
a = "ascii"
b = "a"
print(foo.py_func(a, b))
print(foo(a, b))
fails strangely. One way to do this is something like if isinstance(arg, types.Optional) and not isinstance(arg.type, types.Integer: raise ...
. A test should be added to ensure this is correctly implemented.
numba/unicode.py
Outdated
def args_types_check(arg): | ||
if not (arg is None or isinstance(arg, (types.Omitted, | ||
types.Optional, | ||
types.Integer))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also accept types.NoneType
for cases like:
from numba import njit
@njit
def foo(x, y, none_ty=None):
return x.count(y, none_ty, 1)
a = "ascii"
b = "a"
print(foo.py_func(a, b))
print(foo(a, b))
where start is a NoneType
opposed to the literal None
. A test should be added to ensure this is correctly implemented.
Thanks @stuartarchibald for the review. I'll make the required changes and also add the documentation for it. |
Hi, |
@Vyacheslav-Smirnov sounds good! |
Thanks for writing this patch for Numba. This PR is now unfortunately stale with respect to the current master branch. If you wish to resubmit for review please do so by reopening this PR and rebasing the patch(es) against current master. Thanks for your help in improving Numba! |
enables support for count() method for unicode string