From 36b6aa859a17a3182769a08e5c7838f747f86859 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 24 Nov 2023 20:59:02 -0600 Subject: [PATCH] Call `__eq__` from `__ne__` for user defined `__eq__` If a user manually defines an `__eq__` for a `Struct` class, the default `__ne__` implementation will now call the user-defined `__eq__` and invert the result, rather than applying the standard `__ne__` logic. This makes it easier for users to manually override `__eq__`, and matches the behavior of standard python classes. --- msgspec/_core.c | 22 ++++++++++++++++++++++ tests/test_struct.py | 15 +++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/msgspec/_core.c b/msgspec/_core.c index 507eda8a..8ca24066 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -7173,6 +7173,28 @@ Struct_richcompare(PyObject *self, PyObject *other, int op) { Py_RETURN_NOTIMPLEMENTED; } + if ( + MS_UNLIKELY(op == Py_NE && (Py_TYPE(self)->tp_richcompare != Struct_richcompare)) + ) { + /* This case is hit when a subclass has manually defined `__eq__` but + * not `__ne__`. In this case we want to dispatch to `__eq__` and invert + * the result, rather than relying on the default `__ne__` implementation. + */ + PyObject *out = Py_TYPE(self)->tp_richcompare(self, other, Py_EQ); + if (out != NULL && out != Py_NotImplemented) { + int is_true = PyObject_IsTrue(out); + Py_DECREF(out); + if (is_true < 0) { + out = NULL; + } + else { + out = is_true ? Py_False : Py_True; + Py_INCREF(out); + } + } + return out; + } + int equal = 1; PyObject *left = NULL, *right = NULL; diff --git a/tests/test_struct.py b/tests/test_struct.py index 80255f52..d00b7e4a 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -1562,6 +1562,21 @@ class Test2(Test): self.assert_neq(x, Test(2, 2)) self.assert_neq(x, Test2(1, 2)) + def test_struct_override_eq(self): + class Ex(Struct): + a: int + b: int + + def __eq__(self, other): + return self.a == other.a + + x = Ex(1, 2) + y = Ex(1, 3) + z = Ex(2, 3) + + self.assert_eq(x, y) + self.assert_neq(x, z) + def test_struct_eq_identity_fastpath(self): class Bad: def __eq__(self, other):