Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 3324 #3490

Merged
merged 12 commits into from Jul 2, 2013
81 changes: 41 additions & 40 deletions numpy/core/src/multiarray/arrayobject.c
Expand Up @@ -1264,7 +1264,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
PyArrayObject *array_other;
PyObject *result = NULL;
PyArray_Descr *dtype = NULL;

switch (cmp_op) {
case Py_LT:
Expand All @@ -1280,28 +1279,30 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_False);
return Py_False;
}
/* Make sure 'other' is an array */
if (PyArray_TYPE(self) == NPY_OBJECT) {
dtype = PyArray_DTYPE(self);
Py_INCREF(dtype);
}
array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0,
NULL);
result = PyArray_GenericBinaryFunction(self,
(PyObject *)other,
n_ops.equal);
if (result && result != Py_NotImplemented)
break;

/*
* If not successful, indicate that the items cannot be compared
* this way.
* The ufunc does not support void/structured types, so these
* need to be handled specifically. Only a few cases are supported.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

result = PyArray_GenericBinaryFunction(self,
(PyObject *)array_other,
n_ops.equal);
if ((result == Py_NotImplemented) &&
(PyArray_TYPE(self) == NPY_VOID)) {
if (PyArray_TYPE(self) == NPY_VOID) {
array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0,
NULL);
/*
* If not successful, indicate that the items cannot be compared
* this way.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

int _res;

_res = PyObject_RichCompareBool
Expand All @@ -1325,7 +1326,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
* two array objects can not be compared together;
* indicate that
*/
Py_DECREF(array_other);
if (result == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
Expand All @@ -1337,27 +1337,29 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
Py_INCREF(Py_True);
return Py_True;
}
/* Make sure 'other' is an array */
if (PyArray_TYPE(self) == NPY_OBJECT) {
dtype = PyArray_DTYPE(self);
Py_INCREF(dtype);
}
array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0, 0,
NULL);
result = PyArray_GenericBinaryFunction(self, (PyObject *)other,
n_ops.not_equal);
if (result && result != Py_NotImplemented)
break;

/*
* If not successful, indicate that the items cannot be compared
* this way.
* The ufunc does not support void/structured types, so these
* need to be handled specifically. Only a few cases are supported.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other,
n_ops.not_equal);
if ((result == Py_NotImplemented) &&
(PyArray_TYPE(self) == NPY_VOID)) {
if (PyArray_TYPE(self) == NPY_VOID) {
array_other = (PyArrayObject *)PyArray_FromAny(other, NULL, 0, 0, 0,
NULL);
/*
* If not successful, indicate that the items cannot be compared
* this way.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

int _res;

_res = PyObject_RichCompareBool(
Expand All @@ -1377,7 +1379,6 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return result;
}

Py_DECREF(array_other);
if (result == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
Expand Down
7 changes: 7 additions & 0 deletions numpy/core/src/umath/ufunc_object.c
Expand Up @@ -486,6 +486,13 @@ _has_reflected_op(PyObject *op, char *name)
_GETATTR_(bitwise_and, rand);
_GETATTR_(bitwise_xor, rxor);
_GETATTR_(bitwise_or, ror);
/* Comparisons */
_GETATTR_(equal, eq);
_GETATTR_(not_equal, ne);
_GETATTR_(greater, lt);
_GETATTR_(less, gt);
_GETATTR_(greater_equal, le);
_GETATTR_(less_equal, ge);
return 0;
}

Expand Down
147 changes: 147 additions & 0 deletions numpy/core/tests/test_multiarray.py
Expand Up @@ -2870,5 +2870,152 @@ def test_mem_seteventhook(self):
test_pydatamem_seteventhook_end()


class PriorityNdarray():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should be PriorityNdarray(object): for Py2.4

__array_priority__ = 1000

def __init__(self, array):
self.array = array

def __lt__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array < array)

def __gt__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array > array)

def __le__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array <= array)

def __ge__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array >= array)

def __eq__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array == array)

def __ne__(self, array):
if isinstance(array, PriorityNdarray):
array = array.array
return PriorityNdarray(self.array != array)


class TestArrayPriority(TestCase):
def test_lt(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l < r
res2 = l < rp
res3 = lp < r
res4 = lp < rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))

def test_gt(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l > r
res2 = l > rp
res3 = lp > r
res4 = lp > rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))

def test_le(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l <= r
res2 = l <= rp
res3 = lp <= r
res4 = lp <= rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))

def test_ge(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l >= r
res2 = l >= rp
res3 = lp >= r
res4 = lp >= rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))

def test_eq(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l == r
res2 = l == rp
res3 = lp == r
res4 = lp == rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))

def test_ne(self):
l = np.asarray([0., -1., 1.], dtype=dtype)
r = np.asarray([0., 1., -1.], dtype=dtype)
lp = PriorityNdarray(l)
rp = PriorityNdarray(r)
res1 = l != r
res2 = l != rp
res3 = lp != r
res4 = lp != rp

assert_array_equal(res1, res2.array)
assert_array_equal(res1, res3.array)
assert_array_equal(res1, res4.array)
assert_(isinstance(res1, np.ndarray))
assert_(isinstance(res2, PriorityNdarray))
assert_(isinstance(res3, PriorityNdarray))
assert_(isinstance(res4, PriorityNdarray))


if __name__ == "__main__":
run_module_suite()