diff --git a/numba/pythonapi.py b/numba/pythonapi.py index 40d3ed9a38a..d53fa2be0e3 100644 --- a/numba/pythonapi.py +++ b/numba/pythonapi.py @@ -562,12 +562,21 @@ def object_richcompare(self, lhs, rhs, opstr): of the opid. """ ops = ['<', '<=', '==', '!=', '>', '>='] - opid = ops.index(opstr) - assert 0 <= opid < len(ops) - fnty = Type.function(self.pyobj, [self.pyobj, self.pyobj, Type.int()]) - fn = self._get_function(fnty, name="PyObject_RichCompare") - lopid = self.context.get_constant(types.int32, opid) - return self.builder.call(fn, (lhs, rhs, lopid)) + if opstr in ops: + opid = ops.index(opstr) + fnty = Type.function(self.pyobj, [self.pyobj, self.pyobj, Type.int()]) + fn = self._get_function(fnty, name="PyObject_RichCompare") + lopid = self.context.get_constant(types.int32, opid) + return self.builder.call(fn, (lhs, rhs, lopid)) + elif opstr == 'is': + bitflag = self.builder.icmp(lc.ICMP_EQ, lhs, rhs) + return self.from_native_value(bitflag, types.boolean) + elif opstr == 'is not': + bitflag = self.builder.icmp(lc.ICMP_NE, lhs, rhs) + return self.from_native_value(bitflag, types.boolean) + else: + raise NotImplementedError("Unknown operator {op!r}".format( + op=opstr)) def iter_next(self, iterobj): fnty = Type.function(self.pyobj, [self.pyobj]) diff --git a/numba/tests/test_optional.py b/numba/tests/test_optional.py index d09137a97c6..2644fa35ba0 100644 --- a/numba/tests/test_optional.py +++ b/numba/tests/test_optional.py @@ -1,7 +1,7 @@ from __future__ import print_function, absolute_import import numpy import numba.unittest_support as unittest -from numba.compiler import compile_isolated +from numba.compiler import compile_isolated, Flags from numba import types, typeof, njit from numba.pythonapi import NativeError from numba import lowering @@ -75,6 +75,16 @@ def test_is_this_a_none(self): for v in [-1, 0, 1, 2]: self.assertEqual(pyfunc(v), cfunc(v)) + def test_is_this_a_none_objmode(self): + pyfunc = is_this_a_none + flags = Flags() + flags.set('force_pyobject') + cres = compile_isolated(pyfunc, [types.intp], flags=flags) + cfunc = cres.entry_point + self.assertTrue(cres.objectmode) + for v in [-1, 0, 1, 2]: + self.assertEqual(pyfunc(v), cfunc(v)) + def test_a_is_b_intp(self): pyfunc = a_is_b with self.assertRaises(lowering.LoweringError):