diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 0ebd2cc7d..ce166cdfc 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -148,6 +148,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: if HAS_NUMPY and isinstance(orig, (np.integer, np.bool_, np.byte)): return orig == new + if HAS_NUMPY and isinstance(orig, np.void): + if orig.dtype != new.dtype: + return False + return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 4a4d9f2b1..b8d789401 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -298,6 +298,20 @@ def test_numpy(): assert comparator(ak, al) assert not comparator(ai, ak) + dt = np.dtype([('name', 'S10'), ('age', np.int32)]) + a_struct = np.array([('Alice', 25)], dtype=dt) + b_struct = np.array([('Alice', 25)], dtype=dt) + c_struct = np.array([('Bob', 30)], dtype=dt) + + a_void = a_struct[0] + b_void = b_struct[0] + c_void = c_struct[0] + + assert isinstance(a_void, np.void) + assert comparator(a_void, b_void) + assert not comparator(a_void, c_void) + + def test_scipy(): try: