Skip to content

Commit

Permalink
Call __eq__ from __ne__ for user defined __eq__
Browse files Browse the repository at this point in the history
If a user manually defines an `__eq__` for a `Struct` class, the default
`__ne__` implementation will now call the user-defined `__eq__` and
invert the result, rather than applying the standard `__ne__` logic.
This makes it easier for users to manually override `__eq__`, and
matches the behavior of standard python classes.
  • Loading branch information
jcrist committed Nov 25, 2023
1 parent ae9a77e commit 36b6aa8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
22 changes: 22 additions & 0 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -7173,6 +7173,28 @@ Struct_richcompare(PyObject *self, PyObject *other, int op) {
Py_RETURN_NOTIMPLEMENTED;
}

if (
MS_UNLIKELY(op == Py_NE && (Py_TYPE(self)->tp_richcompare != Struct_richcompare))
) {
/* This case is hit when a subclass has manually defined `__eq__` but
* not `__ne__`. In this case we want to dispatch to `__eq__` and invert
* the result, rather than relying on the default `__ne__` implementation.
*/
PyObject *out = Py_TYPE(self)->tp_richcompare(self, other, Py_EQ);
if (out != NULL && out != Py_NotImplemented) {
int is_true = PyObject_IsTrue(out);
Py_DECREF(out);
if (is_true < 0) {
out = NULL;
}
else {
out = is_true ? Py_False : Py_True;
Py_INCREF(out);
}
}
return out;
}

int equal = 1;
PyObject *left = NULL, *right = NULL;

Expand Down
15 changes: 15 additions & 0 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1562,6 +1562,21 @@ class Test2(Test):
self.assert_neq(x, Test(2, 2))
self.assert_neq(x, Test2(1, 2))

def test_struct_override_eq(self):
class Ex(Struct):
a: int
b: int

def __eq__(self, other):
return self.a == other.a

x = Ex(1, 2)
y = Ex(1, 3)
z = Ex(2, 3)

self.assert_eq(x, y)
self.assert_neq(x, z)

def test_struct_eq_identity_fastpath(self):
class Bad:
def __eq__(self, other):
Expand Down

0 comments on commit 36b6aa8

Please sign in to comment.