Skip to content

Commit

Permalink
BUG: core: handle sub-arrays in dtype comparisons
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiebe authored and pv committed Oct 31, 2010
1 parent 33b3e60 commit 5012504
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
37 changes: 36 additions & 1 deletion numpy/core/src/multiarray/multiarraymodule.c
Expand Up @@ -1361,6 +1361,37 @@ _equivalent_units(PyObject *meta1, PyObject *meta2)
&& (data1->events == data2->events));
}

/*
* Compare the subarray data for two types.
* Return 1 if they are the same, 0 if not.
*/
static int
_equivalent_subarrays(PyArray_ArrayDescr *sub1, PyArray_ArrayDescr *sub2)
{
int val;

if (sub1 == sub2) {
return 1;

}
if (sub1 == NULL || sub2 == NULL) {
return 0;
}

#if defined(NPY_PY3K)
val = PyObject_RichCompareBool(sub1->shape, sub2->shape, Py_EQ);
if (val != 1 || PyErr_Occurred()) {
#else
val = PyObject_Compare(sub1->shape, sub2->shape);
if (val != 0 || PyErr_Occurred()) {
#endif
PyErr_Clear();
return 0;
}

return PyArray_EquivTypes(sub1->base, sub2->base);
}


/*NUMPY_API
*
Expand All @@ -1381,6 +1412,10 @@ PyArray_EquivTypes(PyArray_Descr *typ1, PyArray_Descr *typ2)
if (PyArray_ISNBO(typ1->byteorder) != PyArray_ISNBO(typ2->byteorder)) {
return FALSE;
}
if (typ1->subarray || typ2->subarray) {
return ((typenum1 == typenum2)
&& _equivalent_subarrays(typ1->subarray, typ2->subarray));
}
if (typenum1 == PyArray_VOID
|| typenum2 == PyArray_VOID) {
return ((typenum1 == typenum2)
Expand Down Expand Up @@ -1874,7 +1909,7 @@ array_arange(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kws) {
}
range = PyArray_ArangeObj(o_start, o_stop, o_step, typecode);
Py_XDECREF(typecode);
return range;
return range;
}

/*NUMPY_API
Expand Down
27 changes: 27 additions & 0 deletions numpy/core/tests/test_dtype.py
Expand Up @@ -55,6 +55,33 @@ def test_not_lists(self):
self.assertRaises(TypeError, np.dtype,
dict(names=['A', 'B'], formats=set(['f8', 'i4'])))

class TestShape(TestCase):
def test_equal(self):
"""Test some data types that are equal"""
self.assertEqual(np.dtype('f8'), np.dtype(('f8',tuple())))
self.assertEqual(np.dtype('f8'), np.dtype(('f8',1)))
self.assertEqual(np.dtype((np.int,2)), np.dtype((np.int,(2,))))
self.assertEqual(np.dtype(('<f4',(3,2))), np.dtype(('<f4',(3,2))))
d = ([('a','f4',(1,2)),('b','f8',(3,1))],(3,2))
self.assertEqual(np.dtype(d), np.dtype(d))

def test_simple(self):
"""Test some simple cases that shouldn't be equal"""
self.assertNotEqual(np.dtype('f8'), np.dtype(('f8',(1,))))
self.assertNotEqual(np.dtype(('f8',(1,))), np.dtype(('f8',(1,1))))
self.assertNotEqual(np.dtype(('f4',(3,2))), np.dtype(('f4',(2,3))))

def test_monster(self):
"""Test some more complicated cases that shouldn't be equal"""
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
np.dtype(([('a','f4',(1,2)), ('b','f8',(1,3))],(2,2))))
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
np.dtype(([('a','f4',(2,1)), ('b','i8',(1,3))],(2,2))))
self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
np.dtype(([('e','f8',(1,3)), ('d','f4',(2,1))],(2,2))))
self.assertNotEqual(np.dtype(([('a',[('a','i4',6)],(2,1)), ('b','f8',(1,3))],(2,2))),
np.dtype(([('a',[('a','u4',6)],(2,1)), ('b','f8',(1,3))],(2,2))))

class TestSubarray(TestCase):
def test_single_subarray(self):
a = np.dtype((np.int, (2)))
Expand Down

0 comments on commit 5012504

Please sign in to comment.