Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Clean up tests.test_visibility to remove legacy code. #11495

Merged
merged 10 commits into from
Dec 2, 2021
1 change: 1 addition & 0 deletions changelog.d/11495.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Clean up `tests.test_visibility` to remove legacy code.
1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ exclude = (?x)
|tests/test_server.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/test_visibility.py
|tests/unittest.py
|tests/util/caches/test_cached_call.py
|tests/util/caches/test_deferred_cache.py
Expand Down
142 changes: 72 additions & 70 deletions tests/test_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import Dict, List, Optional, Tuple, cast
from unittest.mock import Mock

from twisted.internet import defer
from twisted.internet.defer import succeed

from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.events import EventBase, FrozenEvent
from synapse.storage import Storage
from synapse.types import JsonDict
from synapse.visibility import filter_events_for_server

import tests.unittest
from tests.utils import create_room, setup_test_homeserver
from tests import unittest
from tests.utils import create_room

logger = logging.getLogger(__name__)

TEST_ROOM_ID = "!TEST:ROOM"


class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
self.storage = self.hs.get_storage()

yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))

@defer.inlineCallbacks
def test_filtering(self):
def test_filtering(self) -> None:
#
# The events to be filtered consist of 10 membership events (it doesn't
# really matter if they are joins or leaves, so let's make them joins).
Expand All @@ -51,18 +48,20 @@ def test_filtering(self):
#

# before we do that, we persist some other events to act as state.
yield self.inject_visibility("@admin:hs", "joined")
self.get_success(self._inject_visibility("@admin:hs", "joined"))
for i in range(0, 10):
yield self.inject_room_member("@resident%i:hs" % i)
self.get_success(self._inject_room_member("@resident%i:hs" % i))

events_to_filter = []

for i in range(0, 10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
evt = self.get_success(
self._inject_room_member(user, extra_content={"a": "b"})
)
events_to_filter.append(evt)

filtered = yield defer.ensureDeferred(
filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)

Expand All @@ -75,34 +74,31 @@ def test_filtering(self):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["a"], "b")

@defer.inlineCallbacks
def test_erased_user(self):
def test_erased_user(self) -> None:
# 4 message events, from erased and unerased users, with a membership
# change in the middle of them.
events_to_filter = []

evt = yield self.inject_message("@unerased:local_hs")
evt = self.get_success(self._inject_message("@unerased:local_hs"))
events_to_filter.append(evt)

evt = yield self.inject_message("@erased:local_hs")
evt = self.get_success(self._inject_message("@erased:local_hs"))
events_to_filter.append(evt)

evt = yield self.inject_room_member("@joiner:remote_hs")
evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
events_to_filter.append(evt)

evt = yield self.inject_message("@unerased:local_hs")
evt = self.get_success(self._inject_message("@unerased:local_hs"))
events_to_filter.append(evt)

evt = yield self.inject_message("@erased:local_hs")
evt = self.get_success(self._inject_message("@erased:local_hs"))
events_to_filter.append(evt)

# the erasey user gets erased
yield defer.ensureDeferred(
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
)
self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs"))

# ... and the filtering happens.
filtered = yield defer.ensureDeferred(
filtered = self.get_success(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)

Expand All @@ -123,8 +119,7 @@ def test_erased_user(self):
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)

@defer.inlineCallbacks
def inject_visibility(self, user_id, visibility):
def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
content = {"history_visibility": visibility}
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
Expand All @@ -137,18 +132,18 @@ def inject_visibility(self, user_id, visibility):
},
)

event, context = yield defer.ensureDeferred(
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
self.get_success(self.storage.persistence.persist_event(event, context))
return event

@defer.inlineCallbacks
def inject_room_member(
self, user_id, membership="join", extra_content: Optional[dict] = None
):
def _inject_room_member(
self,
user_id: str,
membership: str = "join",
extra_content: Optional[JsonDict] = None,
) -> EventBase:
content = {"membership": membership}
content.update(extra_content or {})
builder = self.event_builder_factory.for_room_version(
Expand All @@ -162,17 +157,16 @@ def inject_room_member(
},
)

event, context = yield defer.ensureDeferred(
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)

yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
self.get_success(self.storage.persistence.persist_event(event, context))
return event

@defer.inlineCallbacks
def inject_message(self, user_id, content=None):
def _inject_message(
self, user_id: str, content: Optional[JsonDict] = None
) -> EventBase:
if content is None:
content = {"body": "testytest", "msgtype": "m.text"}
builder = self.event_builder_factory.for_room_version(
Expand All @@ -185,17 +179,15 @@ def inject_message(self, user_id, content=None):
},
)

event, context = yield defer.ensureDeferred(
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)

yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
self.get_success(self.storage.persistence.persist_event(event, context))
return event

@defer.inlineCallbacks
def test_large_room(self):
# FIXME This test is broken!
reivilibre marked this conversation as resolved.
Show resolved Hide resolved
def test_large_room(self) -> None:
# see what happens when we have a large room with hundreds of thousands
# of membership events

Expand All @@ -222,7 +214,8 @@ def test_large_room(self):
"state_key": "",
"room_id": TEST_ROOM_ID,
"content": {"history_visibility": "joined"},
}
},
room_version=RoomVersions.V1,
)
room_state.append(history_visibility_evt)
test_store.add_event(history_visibility_evt)
Expand All @@ -237,12 +230,13 @@ def test_large_room(self):
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {"membership": "join", "extra": "zzz,"},
}
},
room_version=RoomVersions.V1,
)
room_state.append(evt)
test_store.add_event(evt)

events_to_filter = []
events_to_filter: List[EventBase] = []
for i in range(0, 10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
evt = FrozenEvent(
Expand All @@ -253,7 +247,8 @@ def test_large_room(self):
"sender": user,
"room_id": TEST_ROOM_ID,
"content": {"membership": "join", "extra": "zzz"},
}
},
room_version=RoomVersions.V1,
)
events_to_filter.append(evt)
room_state.append(evt)
Expand All @@ -273,8 +268,10 @@ def test_large_room(self):
storage.main = test_store
storage.state = test_store

filtered = yield defer.ensureDeferred(
filter_events_for_server(test_store, "test_server", events_to_filter)
filtered = self.get_success(
filter_events_for_server(
cast(Storage, test_store), "test_server", events_to_filter
)
)
logger.info("Filtering took %f seconds", time.time() - start)

Expand All @@ -292,7 +289,8 @@ def test_large_room(self):
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
self.assertEqual(filtered[i].content["extra"], "zzz")

test_large_room.skip = "Disabled by default because it's slow"
# (Mypy is not happy about a function having a field, but we ignore it because it's fine here.)
test_large_room.skip = "Disabled by default because it's slow" # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we use the @unittest.skip decorator (in general)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. Personally speaking I don't know how to skip tests, but also wouldn't write a test that I'd skip by default anyway :).



class _TestStore:
Expand All @@ -301,22 +299,26 @@ class _TestStore:

"""

def __init__(self):
def __init__(self) -> None:
# data for get_events: a map from event_id to event
self.events = {}
self.events: Dict[str, EventBase] = {}

# data for get_state_ids_for_events mock: a map from event_id to
# a map from (type_state_key) -> event_id for the state at that
# a map from (type, state_key) -> event_id for the state at that
# event
self.state_ids_for_events = {}
self.state_ids_for_events: Dict[str, Dict[Tuple[str, Optional[str]], str]] = {}
reivilibre marked this conversation as resolved.
Show resolved Hide resolved

def add_event(self, event):
def add_event(self, event) -> None:
self.events[event.event_id] = event

def set_state_ids_for_event(self, event, state):
def set_state_ids_for_event(
self, event: EventBase, state: Dict[Tuple[str, Optional[str]], str]
) -> None:
self.state_ids_for_events[event.event_id] = state

def get_state_ids_for_events(self, events, types):
def get_state_ids_for_events(
self, events: List[str], types: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[Tuple[str, Optional[str]], str]]:
res = {}
include_memberships = False
for (type, state_key) in types:
Expand All @@ -339,10 +341,10 @@ def get_state_ids_for_events(self, events, types):
hve = self.state_ids_for_events[event_id][k]
res[event_id] = {k: hve}

return succeed(res)
return res

def get_events(self, events):
return succeed({event_id: self.events[event_id] for event_id in events})
def get_events(self, events: List[str]) -> Dict[str, EventBase]:
return {event_id: self.events[event_id] for event_id in events}

def are_users_erased(self, users):
return succeed({u: False for u in users})
def are_users_erased(self, users: List[str]) -> Dict[str, bool]:
return {u: False for u in users}