Equality check between numpy string array and string element fails #6312

Open
njriasan opened this issue Oct 7, 2020 · 2 comments
Labels
bug - incorrect behavior Bugs: incorrect behavior

Comments

@njriasan
Contributor

njriasan commented Oct 7, 2020

Hi. I encountered a situation where comparing a NumPy array of strings with a single string doesn't perform an element-wise comparison and doesn't match the NumPy result. Here is a reproducer:

import numba
import numpy as np

# Create the array in Python
a = np.array(["hello", "world"])

# Run in plain NumPy
print(a == "world")

# Run under numba.njit
print(numba.njit(lambda a: a == "world")(a))

Running this example produces the output:

[False  True]
False

It seems that Numba compares the entire array with the single element as a whole, instead of performing an element-wise comparison.
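To make the difference between the two outputs concrete, here is a minimal pure-Python model. It does not use NumPy or Numba: a plain list stands in for the array, and `elementwise_eq` is a hypothetical helper introduced only for illustration.

```python
# Pure-Python model of the two behaviours; a plain list stands in for the array.

def elementwise_eq(values, scalar):
    """Compare every element with the scalar, as NumPy does for `a == "world"`."""
    return [v == scalar for v in values]


values = ["hello", "world"]

# Element-wise comparison: one boolean per element.
expected = elementwise_eq(values, "world")

# Whole-object comparison: a list is never equal to a str, so a single boolean.
collapsed = (values == "world")

print(expected)   # [False, True]
print(collapsed)  # False
```

The first result corresponds to the NumPy output in the reproducer, the second to the single `False` returned under `numba.njit`.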

@esc added the "bug - incorrect behavior" label Oct 7, 2020
@stuartarchibald
Contributor

Thanks for the report. This is a bug. The code here:

@overload(operator.eq)
def unicode_eq(a, b):
    if not (a.is_internal and b.is_internal):
        return
    accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq)
    a_unicode = isinstance(a, accept)
    b_unicode = isinstance(b, accept)
    if a_unicode and b_unicode:
        def eq_impl(a, b):
            # the str() is for UnicodeCharSeq, it's a nop else
            a = str(a)
            b = str(b)
            if len(a) != len(b):
                return False
            return _cmp_region(a, 0, b, 0, len(a)) == 0
        return eq_impl
    elif a_unicode ^ b_unicode:
        # one of the things is unicode, everything compares False
        def eq_impl(a, b):
            return False
        return eq_impl

needs to handle arrays of unicode char seq.
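One way to read the quoted overload is that an array operand is not in `accept`, so an array compared with a string ends up in the XOR branch and always yields `False`, which would match the reported output. The sketch below models only that branch selection. The classes are hypothetical stand-ins, not Numba's real type objects, and the `is_internal` check is left out.

```python
# Stand-in classes for illustration only; these are NOT numba.types objects.
class UnicodeType:
    pass


class StringLiteral:
    pass


class UnicodeCharSeq:
    pass


class Array:
    def __init__(self, dtype):
        self.dtype = dtype


ACCEPT = (UnicodeType, StringLiteral, UnicodeCharSeq)


def pick_branch(a, b):
    """Return the name of the branch the quoted overload would select."""
    a_unicode = isinstance(a, ACCEPT)
    b_unicode = isinstance(b, ACCEPT)
    if a_unicode and b_unicode:
        return "compare strings"
    elif a_unicode ^ b_unicode:
        return "always False"
    return "no implementation"


arr = Array(dtype=UnicodeCharSeq())

print(pick_branch(UnicodeType(), StringLiteral()))  # compare strings
print(pick_branch(arr, StringLiteral()))            # always False
print(pick_branch(arr, arr))                        # no implementation
```

In this model the array's `dtype` is never inspected, which is the gap the comment above points at.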

@elfjes

elfjes commented Sep 3, 2021

I've started working on an implementation for comparing unicode arrays. Currently Numba supports comparing unicode arrays when the UnicodeCharSeq elements have the same length, but not when the lengths differ (e.g. (array([unichr x 2], 1d, C), array([unichr x 4], 1d, C))), so that case could also be handled in the unicode_eq overload. So far I've come up with this:

@register_jitable
def _cmp_unicode(a, b):
   # the str() is for UnicodeCharSeq, it's a nop else
   a = str(a)
   b = str(b)
   if len(a) != len(b):
       return False
   return _cmp_region(a, 0, b, 0, len(a)) == 0


@register_jitable
def _cmp_unicode_with_array(arr, val):
   rv = np.zeros_like(arr, dtype=np.bool_)
   for i in range(arr.size):
       rv.flat[i] = _cmp_unicode(arr.flat[i], val)
   return rv


@overload(operator.eq)
def unicode_eq(a, b):
   if not (a.is_internal and b.is_internal):
       return
   accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq)
   a_unicode = isinstance(a, accept)
   b_unicode = isinstance(b, accept)
   a_unicode_array = isinstance(a, types.Array) and isinstance(a.dtype, accept)
   b_unicode_array = isinstance(b, types.Array) and isinstance(b.dtype, accept)

   if a_unicode and b_unicode:
       return _cmp_unicode

   elif a_unicode_array and b_unicode_array:
       def eq_impl(a, b):
           if a.size != b.size:
               raise ValueError("Cannot compare arrays of different size")
           rv = np.zeros_like(a, dtype=np.bool_)
           for i in range(a.size):
               rv.flat[i] = _cmp_unicode(a.flat[i], b.flat[i])
           return rv
       return eq_impl

   elif a_unicode_array and b_unicode:
       def eq_impl(a, b):
           return _cmp_unicode_with_array(a, b)
       return eq_impl

   elif a_unicode and b_unicode_array:
       def eq_impl(a, b):
           return _cmp_unicode_with_array(b, a)
       return eq_impl

   elif a_unicode ^ b_unicode:
       # one of the things is unicode, everything compares False
       def eq_impl(a, b):
           return False
       return eq_impl

This works nicely for comparing two arrays (which may have a different unicode length):

print(numba.njit(lambda a, b: a == b)(np.array(['a','a'], dtype='<U4'), np.array(['a', 'b'])))
# [True, False]
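As a rough illustration of why elements of different widths can still compare equal, the sketch below models fixed-width records in pure Python. It assumes that a fixed-width element is padded with NUL characters and that the `str()` conversion drops that padding; the helper names are hypothetical and nothing here calls NumPy or Numba.

```python
def to_charseq(text, width):
    """Model a fixed-width record (like dtype '<U4'): pad with NUL characters."""
    if len(text) > width:
        raise ValueError("text does not fit in the requested width")
    return text + "\0" * (width - len(text))


def charseq_to_str(seq):
    """Model the str() conversion: drop the trailing NUL padding."""
    return seq.rstrip("\0")


def charseq_eq(a, b):
    """Compare two padded sequences by their unpadded contents."""
    return charseq_to_str(a) == charseq_to_str(b)


wide = [to_charseq(s, 4) for s in ["a", "a"]]    # like dtype '<U4'
narrow = [to_charseq(s, 1) for s in ["a", "b"]]  # like dtype '<U1'

result = [charseq_eq(x, y) for x, y in zip(wide, narrow)]
print(result)  # [True, False]
```

Comparing the raw padded records directly would report a mismatch for the first pair, so the conversion step is what makes the widths irrelevant in this model.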

There are, however, a number of things that I do not understand:

  • It does not work for comparing an array with a str scalar, e.g.:
numba.njit(lambda a, b: a == b)(np.array(['a', 'b']), 'a')
# raises exception:
# numba.core.errors.LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
# unsupported type for input operand: unicode_type
  • Somehow broadcasting does work. I have no idea how that happened.
print(numba.njit(lambda a, b: a == b)(np.array(['a'], dtype='<U4'), np.array(['a', 'b'])))
# [True, False]
  • I would have expected the ValueError to be raised in that case, but it is not (in fact, it never seems to be raised at all).
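The last two observations can be contrasted in a small pure-Python model. `strict_eq` mirrors the size check in the draft, and `broadcast_eq_1d` is a hypothetical 1-D version of NumPy-style broadcasting. Only the second reproduces the observed `[True, False]`, so one possible reading (an assumption, not confirmed in this thread) is that the array-versus-array case is handled by a code path other than the draft's `eq_impl`.

```python
def strict_eq(a, b):
    """Mirror of the size check in the draft: reject unequal lengths."""
    if len(a) != len(b):
        raise ValueError("Cannot compare arrays of different size")
    return [x == y for x, y in zip(a, b)]


def broadcast_eq_1d(a, b):
    """1-D broadcasting: a length-1 operand is stretched to match the other."""
    if len(a) == len(b):
        pairs = zip(a, b)
    elif len(a) == 1:
        pairs = ((a[0], y) for y in b)
    elif len(b) == 1:
        pairs = ((x, b[0]) for x in a)
    else:
        raise ValueError("operands could not be broadcast together")
    return [x == y for x, y in pairs]


print(broadcast_eq_1d(["a"], ["a", "b"]))  # [True, False]

try:
    strict_eq(["a"], ["a", "b"])
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```

If broadcasting is the desired behaviour, a strict size check like the one in `strict_eq` would have to be relaxed to allow length-1 operands.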

Any ideas on how to improve on this?
