Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Stop patching EventBase.__eq__ in tests. #16349

Merged
merged 2 commits into from Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16349.misc
@@ -0,0 +1 @@
Avoid patching code in tests.
17 changes: 12 additions & 5 deletions tests/replication/storage/_base.py
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Optional
from typing import Any, Callable, Iterable, Optional
from unittest.mock import Mock

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -47,24 +47,31 @@ def replicate(self) -> None:
self.pump(0.1)

def check(
self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None
self,
method: str,
args: Iterable[Any],
expected_result: Optional[Any] = None,
asserter: Optional[Callable[[Any, Any, Optional[Any]], None]] = None,
) -> None:
if asserter is None:
asserter = self.assertEqual

master_result = self.get_success(getattr(self.master_store, method)(*args))
worker_result = self.get_success(getattr(self.worker_store, method)(*args))
if expected_result is not None:
self.assertEqual(
asserter(
master_result,
expected_result,
"Expected master result to be %r but was %r"
% (expected_result, master_result),
)
self.assertEqual(
asserter(
worker_result,
expected_result,
"Expected worker result to be %r but was %r"
% (expected_result, worker_result),
)
self.assertEqual(
asserter(
master_result,
worker_result,
"Worker result %r does not match master result %r"
Expand Down
49 changes: 18 additions & 31 deletions tests/replication/storage/test_events.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Iterable, List, Optional, Tuple
from typing import Any, Iterable, List, Optional, Tuple

from canonicaljson import encode_canonical_json
from parameterized import parameterized
Expand All @@ -21,7 +21,7 @@

from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
Expand All @@ -46,32 +46,9 @@
logger = logging.getLogger(__name__)


def dict_equals(self: EventBase, other: EventBase) -> bool:
me = encode_canonical_json(self.get_pdu_json())
them = encode_canonical_json(other.get_pdu_json())
return me == them


def patch__eq__(cls: object) -> Callable[[], None]:
eq = getattr(cls, "__eq__", None)
cls.__eq__ = dict_equals # type: ignore[assignment]

def unpatch() -> None:
if eq is not None:
cls.__eq__ = eq # type: ignore[method-assign]

return unpatch


class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
STORE_TYPE = EventsWorkerStore

def setUp(self) -> None:
# Patch up the equality operator for events so that we can check
# whether lists of events match using assertEqual
self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
super().setUp()

def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)

Expand All @@ -84,8 +61,14 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
)
)

def tearDown(self) -> None:
[unpatch() for unpatch in self.unpatches]
def assertEventsEqual(
self, first: EventBase, second: EventBase, msg: Optional[Any] = None
) -> None:
self.assertEqual(
encode_canonical_json(first.get_pdu_json()),
encode_canonical_json(second.get_pdu_json()),
msg,
)

def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
Expand All @@ -107,7 +90,7 @@ def test_redactions(self) -> None:

msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)

redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
self.replicate()
Expand All @@ -119,15 +102,17 @@ def test_redactions(self) -> None:
redacted = make_event_from_dict(
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
)
self.check("get_event", [msg.event_id], redacted)
self.check(
"get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
)

def test_backfilled_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")

msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual)

redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
Expand All @@ -141,7 +126,9 @@ def test_backfilled_redactions(self) -> None:
redacted = make_event_from_dict(
msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
)
self.check("get_event", [msg.event_id], redacted)
self.check(
"get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual
)

def test_invites(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
Expand Down