diff --git a/changelog.d/13863.bugfix b/changelog.d/13863.bugfix new file mode 100644 index 000000000000..74264a4fab79 --- /dev/null +++ b/changelog.d/13863.bugfix @@ -0,0 +1 @@ +Fix `have_seen_event` cache not being invalidated after we persist an event, which caused inefficiencies such as extra `/state` federation calls. diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 9e554a865ee5..5a53782da994 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -31,7 +31,7 @@ federation_ip_range_blacklist: [] # Disable server rate-limiting rc_federation: window_size: 1000 - sleep_limit: 10 + sleep_limit: 99999 sleep_delay: 500 reject_limit: 99999 concurrent: 3 diff --git a/synapse/events/builder.py b/synapse/events/builder.py index e2ee10dd3ddc..87a03318686e 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,6 +132,11 @@ async def build( auth_event_ids = self._event_auth_handler.compute_auth_events( self, state_ids ) + logger.info( + "builder.build state_ids=%s auth_event_ids=%s", + state_ids, + auth_event_ids, + ) format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 73471fe04113..f429d0783237 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -38,7 +38,7 @@ from unpaddedbase64 import decode_base64 from synapse import event_auth -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, CodeMessageException, @@ -60,13 +60,7 @@ from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context -from synapse.logging.tracing import ( - SynapseTags, - set_attribute, - start_active_span, - tag_args, - trace, -) +from synapse.logging.tracing import SynapseTags, set_attribute, tag_args, trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import NOT_SPAM from synapse.replication.http.federation import ( @@ -125,6 +119,7 @@ class _BackfillPoint: event_id: str depth: int + stream_ordering: int type: _BackfillPointType @@ -231,16 +226,24 @@ async def _maybe_backfill_inner( processing. Only used for timing.
""" backwards_extremities = [ - _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY) - for event_id, depth in await self.store.get_backfill_points_in_room(room_id) + _BackfillPoint( + event_id, depth, stream_ordering, _BackfillPointType.BACKWARDS_EXTREMITY + ) + for event_id, depth, stream_ordering in await self.store.get_backfill_points_in_room( + room_id=room_id, + current_depth=current_depth, + ) ] insertion_events_to_be_backfilled: List[_BackfillPoint] = [] if self.hs.config.experimental.msc2716_enabled: insertion_events_to_be_backfilled = [ - _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT) - for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room( - room_id + _BackfillPoint( + event_id, depth, stream_ordering, _BackfillPointType.INSERTION_PONT + ) + for event_id, depth, stream_ordering in await self.store.get_insertion_event_backward_extremities_in_room( + room_id=room_id, + current_depth=current_depth, ) ] logger.debug( @@ -249,10 +252,6 @@ async def _maybe_backfill_inner( insertion_events_to_be_backfilled, ) - if not backwards_extremities and not insertion_events_to_be_backfilled: - logger.debug("Not backfilling as no extremeties found.") - return False - # we now have a list of potential places to backpaginate from. We prefer to # start with the most recent (ie, max depth), so let's sort the list. sorted_backfill_points: List[_BackfillPoint] = sorted( @@ -260,7 +259,7 @@ async def _maybe_backfill_inner( backwards_extremities, insertion_events_to_be_backfilled, ), - key=lambda e: -int(e.depth), + key=lambda e: (-e.depth, -e.stream_ordering, e.event_id), ) logger.debug( @@ -273,6 +272,29 @@ async def _maybe_backfill_inner( sorted_backfill_points, ) + # If we have no backfill points lower than the `current_depth` then + # either we can a) bail or b) still attempt to backfill. We opt to try + # backfilling anyway just in case we do get relevant events. + if not sorted_backfill_points and current_depth != MAX_DEPTH: + logger.debug( + "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway." + ) + return await self._maybe_backfill_inner( + room_id=room_id, + # We use `MAX_DEPTH` so that we find all backfill points next + # time (all events are below the `MAX_DEPTH`) + current_depth=MAX_DEPTH, + limit=limit, + processing_start_time=processing_start_time, + ) + elif not sorted_backfill_points and current_depth == MAX_DEPTH: + # Even after trying again with `MAX_DEPTH`, we didn't find any + # backward extremities to backfill from. + logger.debug( + "_maybe_backfill_inner: Not backfilling as no backward extremeties found." + ) + return False + # If we're approaching an extremity we trigger a backfill, otherwise we # no-op. # @@ -286,43 +308,16 @@ async def _maybe_backfill_inner( # XXX: shouldn't we do this *after* the filter by depth below? Again, we don't # care about events that have happened after our current position. # - max_depth = sorted_backfill_points[0].depth - if current_depth - 2 * limit > max_depth: + max_depth_of_backfill_points = sorted_backfill_points[0].depth + if current_depth - 2 * limit > max_depth_of_backfill_points: logger.debug( "Not backfilling as we don't need to. %d < %d - 2 * %d", - max_depth, + max_depth_of_backfill_points, current_depth, limit, ) return False - # We ignore extremities that have a greater depth than our current depth - # as: - # 1. we don't really care about getting events that have happened - # after our current position; and - # 2. 
we have likely previously tried and failed to backfill from that - # extremity, so to avoid getting "stuck" requesting the same - # backfill repeatedly we drop those extremities. - # - # However, we need to check that the filtered extremities are non-empty. - # If they are empty then either we can a) bail or b) still attempt to - # backfill. We opt to try backfilling anyway just in case we do get - # relevant events. - # - filtered_sorted_backfill_points = [ - t for t in sorted_backfill_points if t.depth <= current_depth - ] - if filtered_sorted_backfill_points: - logger.debug( - "_maybe_backfill_inner: backfill points before current depth: %s", - filtered_sorted_backfill_points, - ) - sorted_backfill_points = filtered_sorted_backfill_points - else: - logger.debug( - "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway." - ) - # For performance's sake, we only want to paginate from a particular extremity # if we can actually see the events we'll get. Otherwise, we'd just spend a lot # of resources to get redacted events. We check each extremity in turn and @@ -344,14 +339,10 @@ async def _maybe_backfill_inner( # attempting to paginate before backfill reached the visible history. extremities_to_request: List[str] = [] - for i, bp in enumerate(sorted_backfill_points): + for bp in sorted_backfill_points: if len(extremities_to_request) >= 5: break - set_attribute( - SynapseTags.RESULT_PREFIX + "backfill_point" + str(i), str(bp) - ) - # For regular backwards extremities, we don't have the extremity events # themselves, so we need to actually check the events that reference them - # their "successor" events. @@ -396,7 +387,7 @@ async def _maybe_backfill_inner( ) return False - logger.debug( + logger.info( "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request ) set_attribute( @@ -408,13 +399,12 @@ async def _maybe_backfill_inner( str(len(extremities_to_request)), ) - with start_active_span("getting likely_domains"): - # Now we need to decide which hosts to hit first. - # First we try hosts that are already in the room. - # TODO: HEURISTIC ALERT. - likely_domains = ( - await self._storage_controllers.state.get_current_hosts_in_room(room_id) - ) + # Now we need to decide which hosts to hit first. + # First we try hosts that are already in the room. + # TODO: HEURISTIC ALERT. + likely_domains = ( + await self._storage_controllers.state.get_current_hosts_in_room(room_id) + ) async def try_backfill(domains: Collection[str]) -> bool: # TODO: Should we try multiple of these at a time? 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 00a8860ff326..b82919b46a92 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -14,6 +14,7 @@ import collections import itertools +import json import logging from http import HTTPStatus from typing import ( @@ -137,6 +138,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -177,6 +179,7 @@ def __init__(self, hs: "HomeServer"): self._room_pdu_linearizer = Linearizer("fed_room_pdu") + @trace async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received via a federation /send/ transaction @@ -644,9 +647,108 @@ async def backfill( f"room {ev.room_id}, when we were backfilling in {room_id}" ) + # Figure out the ordering of the events we pulled over backfill. + # + # We expect the events from the `/backfill` response to start from + # `?v` and include events that preceded it (so the list will be + # newest -> oldest, reverse chronological). This is at most a + # convention between Synapse servers, as the order is not specced. + # + # We want to calculate the `stream_ordering` from newest -> oldest + # (so historical events sort in the correct order) and persist in + # oldest -> newest to minimise the amount of missing `prev_event` + # fetch thrashing. + reverse_chronological_events = events + # `[::-1]` is just slice syntax to reverse the list and give us a copy + chronological_events = reverse_chronological_events[::-1] + + logger.info( + "backfill assumed reverse_chronological_events=%s", + [ + "event_id=%s,depth=%d,body=%s(%s),prevs=%s\n" + % ( + event.event_id, + event.depth, + event.content.get("body", event.type), + getattr(event, "state_key", None), + event.prev_event_ids(), + ) + for event in reverse_chronological_events + ], + ) + + # logger.info( + # "backfill chronological_events=%s", + # [ + # "event_id=%s,depth=%d,body=%s(%s),prevs=%s\n" + # % ( + # event.event_id, + # event.depth, + # event.content.get("body", event.type), + # getattr(event, "state_key", None), + # event.prev_event_ids(), + # ) + # for event in chronological_events + # ], + # ) + + from synapse.storage.util.id_generators import AbstractStreamIdGenerator + + # This should only exist on instances that are configured to write + assert ( + self._instance_name in self.hs.config.worker.writers.events + ), "Can only assign backfill stream_ordering on an event persister instance" + + # Since we have been configured to write, we ought to have id generators, + # rather than id trackers.
+ assert isinstance(self._store._backfill_id_gen, AbstractStreamIdGenerator) + stream_ordering_manager = self._store._backfill_id_gen.get_next_mult( + len(reverse_chronological_events) + ) + async with stream_ordering_manager as stream_orderings: + for event, stream in zip( + reverse_chronological_events, stream_orderings + ): + event.internal_metadata.stream_ordering = stream + + logger.info( + "backfill chronological_events=%s", + [ + "event_id=%s,depth=%d,stream_ordering=%s,body=%s(%s),prevs=%s\n" + % ( + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + event.content.get("body", event.type), + getattr(event, "state_key", None), + event.prev_event_ids(), + ) + for event in chronological_events + ], + ) + logger.info( + "backfill reverse_chronological_events=%s", + [ + "event_id=%s,depth=%d,stream_ordering=%s,body=%s(%s),prevs=%s\n" + % ( + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + event.content.get("body", event.type), + getattr(event, "state_key", None), + event.prev_event_ids(), + ) + for event in reverse_chronological_events + ], + ) + await self._process_pulled_events( dest, - events, + # Expecting to persist in chronological order here (oldest -> + # newest) so that events are persisted before they're referenced + # as a `prev_event`. + chronological_events, + # reverse_chronological_events, backfilled=True, ) @@ -781,11 +883,12 @@ async def _process_pulled_events( "processing pulled backfilled=%s events=%s", backfilled, [ - "event_id=%s,depth=%d,body=%s,prevs=%s\n" + "event_id=%s,depth=%d,body=%s(%s),prevs=%s\n" % ( event.event_id, event.depth, event.content.get("body", event.type), + getattr(event, "state_key", None), event.prev_event_ids(), ) for event in events @@ -795,6 +898,26 @@ async def _process_pulled_events( # We want to sort these by depth so we process them and # tell clients about them in order. 
sorted_events = sorted(events, key=lambda x: x.depth) + + logger.info( + "_process_pulled_events backfill sorted_events=%s", + json.dumps( + [ + "event_id=%s,depth=%d,stream_ordering=%s,body=%s(%s),prevs=%s\n" + % ( + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + event.content.get("body", event.type), + getattr(event, "state_key", None), + event.prev_event_ids(), + ) + for event in sorted_events + ], + indent=4, + ), + ) + for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -1073,11 +1196,18 @@ async def _get_state_ids_after_missing_prev_event( destination, room_id, event_id=event_id ) - logger.debug( - "state_ids returned %i state events, %i auth events", + logger.info( + "_get_state_ids_after_missing_prev_event(event_id=%s): state_ids returned %i state events, %i auth events", + event_id, len(state_event_ids), len(auth_event_ids), ) + logger.info( + "_get_state_ids_after_missing_prev_event(event_id=%s): state_event_ids=%s auth_event_ids=%s", + event_id, + state_event_ids, + auth_event_ids, + ) # Start by checking events we already have in the DB desired_events = set(state_event_ids) @@ -1790,6 +1920,12 @@ async def _check_event_auth( # already have checked we have all the auth events, in # _load_or_fetch_auth_events_for_event above) if context.partial_state: + logger.info( + "_check_event_auth(event=%s) with partial_state - %s (%s)", + event.event_id, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) room_version = await self._store.get_room_version_id(event.room_id) local_state_id_map = await context.get_prev_state_ids() @@ -1807,15 +1943,40 @@ async def _check_event_auth( ) ) else: + logger.info( + "_check_event_auth(event=%s) with full state - %s (%s)", + event.event_id, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) event_types = event_auth.auth_types_for_event(event.room_version, event) state_for_auth_id_map = await context.get_prev_state_ids( StateFilter.from_types(event_types) ) + logger.info( + "_check_event_auth(event=%s) state_for_auth_id_map=%s - %s (%s)", + event.event_id, + state_for_auth_id_map, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) + calculated_auth_event_ids = self._event_auth_handler.compute_auth_events( event, state_for_auth_id_map, for_verification=True ) + logger.info( + "_check_event_auth(event=%s) match=%s claimed_auth_events=%s calculated_auth_event_ids=%s - %s (%s)", + event.event_id, + collections.Counter(event.auth_event_ids()) + == collections.Counter(calculated_auth_event_ids), + event.auth_event_ids(), + calculated_auth_event_ids, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) + # if those are the same, we're done here. 
if collections.Counter(event.auth_event_ids()) == collections.Counter( calculated_auth_event_ids diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 10b5dad03009..d53f8dc28d02 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1081,6 +1081,11 @@ async def create_new_client_event( # Do a quick sanity check here, rather than waiting until we've created the # event and then try to auth it (which fails with a somewhat confusing "No # create event in auth events") + logger.info( + "create_new_client_event allow_no_prev_events=%s auth_event_ids=%s", + allow_no_prev_events, + auth_event_ids, + ) if allow_no_prev_events: # We allow events with no `prev_events` but it better have some `auth_events` assert ( @@ -1101,6 +1106,13 @@ async def create_new_client_event( depth=depth, ) + logger.info( + "create_new_client_event(event=%s) state_event_ids=%s resultant auth_event_ids=%s", + event.event_id, + state_event_ids, + auth_event_ids, + ) + # Pass on the outlier property from the builder to the event # after it is created if builder.internal_metadata.outlier: diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 1414e575d6fc..fc303fa382fe 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -153,6 +153,7 @@ async def persist_state_events_at_start( self, state_events_at_start: List[JsonDict], room_id: str, + initial_prev_event_ids: List[str], initial_state_event_ids: List[str], app_service_requester: Requester, ) -> List[str]: @@ -178,10 +179,8 @@ async def persist_state_events_at_start( state_event_ids_at_start = [] state_event_ids = initial_state_event_ids.copy() - # Make the state events float off on their own by specifying no - # prev_events for the first one in the chain so we don't have a bunch of - # `@mxid joined the room` noise between each batch. - prev_event_ids_for_state_chain: List[str] = [] + # TODO: Here + prev_event_ids_for_state_chain: List[str] = initial_prev_event_ids for index, state_event in enumerate(state_events_at_start): assert_params_in_dict( @@ -269,6 +268,7 @@ async def persist_historical_events( events_to_create: List[JsonDict], room_id: str, inherited_depth: int, + state_chain_event_id_to_connect_to: str, initial_state_event_ids: List[str], app_service_requester: Requester, ) -> List[str]: @@ -301,10 +301,8 @@ async def persist_historical_events( # We expect the last event in a historical batch to be an batch event assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH - # Make the historical event chain float off on its own by specifying no - # prev_events for the first event in the chain which causes the HS to - # ask for the state at the start of the batch later. 
- prev_event_ids: List[str] = [] + # TODO: Here + prev_event_ids: List[str] = [state_chain_event_id_to_connect_to] event_ids = [] events_to_persist = [] @@ -390,6 +388,7 @@ async def handle_batch_of_events( events_to_create: List[JsonDict], room_id: str, batch_id_to_connect_to: str, + state_chain_event_id_to_connect_to: str, inherited_depth: int, initial_state_event_ids: List[str], app_service_requester: Requester, @@ -458,6 +457,7 @@ async def handle_batch_of_events( events_to_create=events_to_create, room_id=room_id, inherited_depth=inherited_depth, + state_chain_event_id_to_connect_to=state_chain_event_id_to_connect_to, initial_state_event_ids=initial_state_event_ids, app_service_requester=app_service_requester, ) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5f3c7ee46ca9..0b3998ba960c 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -634,6 +634,12 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: + logger.info( + "RoomMessageListRestServlet ========================================" + ) + logger.info( + "===================================================================" + ) processing_start_time = self.clock.time_msec() # Fire off and hope that we get a result by the end. # diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index dd91dabedd66..165efca87795 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -153,6 +153,7 @@ async def on_POST( await self.room_batch_handler.persist_state_events_at_start( state_events_at_start=body["state_events_at_start"], room_id=room_id, + initial_prev_event_ids=prev_event_ids_from_query, initial_state_event_ids=state_event_ids, app_service_requester=requester, ) @@ -222,6 +223,8 @@ async def on_POST( room_id=room_id, batch_id_to_connect_to=batch_id_to_connect_to, inherited_depth=inherited_depth, + # Connect the historical batch to the state chain + state_chain_event_id_to_connect_to=state_event_ids_at_start[-1], initial_state_event_ids=state_event_ids, app_service_requester=requester, ) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3787d35b244f..bb4dbb1edf7d 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -36,11 +36,13 @@ from frozendict import frozendict from prometheus_client import Counter, Histogram +from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import ContextResourceUsage +from synapse.logging.tracing import SynapseTags, log_kv, set_attribute, tag_args, trace from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -250,6 +252,8 @@ async def get_hosts_in_room_at_events( state = await entry.get_state(self._state_storage_controller, StateFilter.all()) return await self.store.get_joined_hosts(room_id, state, entry) + @trace + @tag_args async def compute_event_context( self, event: EventBase, @@ -282,6 +286,14 @@ async def compute_event_context( RuntimeError if `state_ids_before_event` is not provided and one or more prev events are missing or outliers. 
""" + set_attribute( + SynapseTags.RESULT_PREFIX + "event_type_and_state", + f"{event.type} - {getattr(event, 'state_key', None)}", + ) + set_attribute( + SynapseTags.RESULT_PREFIX + "event_body", + event.content.get("body", None), + ) assert not event.internal_metadata.is_outlier() @@ -289,6 +301,15 @@ async def compute_event_context( # first of all, figure out the state before the event, unless we # already have it. # + logger.info( + "compute_event_context(event=%s, state_ids_before_event=%s) - %s (%s)", + event.event_id, + state_ids_before_event, + event.content.get("body", event.type), + getattr(event, "state_key", None), + # stack_info=True, + ) + if state_ids_before_event: # if we're given the state before the event, then we use that state_group_before_event_prev_group = None @@ -304,6 +325,12 @@ async def compute_event_context( current_state_ids=state_ids_before_event, ) ) + log_kv( + { + "message": "Using state before event because `state_ids_before_event` was given:", + "state_group_before_event": state_group_before_event, + } + ) # the partial_state flag must be provided assert partial_state is not None @@ -324,7 +351,7 @@ async def compute_event_context( ) partial_state = any(incomplete_prev_events.values()) if partial_state: - logger.debug( + logger.info( "New/incoming event %s refers to prev_events %s with partial state", event.event_id, [k for (k, v) in incomplete_prev_events.items() if v], @@ -343,6 +370,24 @@ async def compute_event_context( deltas_to_state_group_before_event = entry.delta_ids state_ids_before_event = None + logger.info( + "compute_event_context(event=%s) resolve_state_groups_for_events entry.state_group=%s state_group_before_event_prev_group=%s deltas_to_state_group_before_event=%s - %s (%s)", + event.event_id, + entry.state_group, + state_group_before_event_prev_group, + deltas_to_state_group_before_event, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) + log_kv( + { + "message": "resolve_state_groups_for_events", + "entry.state_group": entry.state_group, + "state_group_before_event_prev_group": state_group_before_event_prev_group, + "deltas_to_state_group_before_event": deltas_to_state_group_before_event, + } + ) + # We make sure that we have a state group assigned to the state. 
if entry.state_group is None: # store_state_group requires us to have either a previous state group @@ -352,6 +397,12 @@ async def compute_event_context( state_ids_before_event = await entry.get_state( self._state_storage_controller, StateFilter.all() ) + log_kv( + { + "message": "state_group_before_event_prev_group was None so get state_ids_before_event", + "state_ids_before_event": state_ids_before_event, + } + ) state_group_before_event = ( await self._state_storage_controller.store_state_group( @@ -363,15 +414,27 @@ async def compute_event_context( ) ) entry.set_state_group(state_group_before_event) + log_kv( + { + "message": "entry.set_state_group(state_group_before_event)", + "state_group_before_event": state_group_before_event, + } + ) else: state_group_before_event = entry.state_group + log_kv( + { + "message": "Entry already has a state_group", + "state_group_before_event": state_group_before_event, + } + ) # # now if it's not a state event, we're done # if not event.is_state(): - return EventContext.with_state( + event_context = EventContext.with_state( storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, @@ -381,6 +444,29 @@ async def compute_event_context( partial_state=partial_state, ) + state_for_auth_id_map = await event_context.get_prev_state_ids( + StateFilter.from_types( + event_auth.auth_types_for_event(event.room_version, event) + ) + ) + log_kv( + { + "message": "Done creating context for non-state event", + "state_for_auth_id_map from event_context": str( + state_for_auth_id_map + ), + } + ) + logger.info( + "compute_event_context(event=%s) Done creating context=%s for non-state event - %s (%s)", + event.event_id, + event_context, + event.content.get("body", event.type), + getattr(event, "state_key", None), + ) + + return event_context + # # otherwise, we'll need to create a new state group for after the event # @@ -421,6 +507,7 @@ async def compute_event_context( ) @measure_func() + @trace async def resolve_state_groups_for_events( self, room_id: str, event_ids: Collection[str], await_full_state: bool = True ) -> _StateCacheEntry: @@ -448,6 +535,13 @@ async def resolve_state_groups_for_events( state_group_ids = state_groups.values() + logger.info( + "resolve_state_groups_for_events: state_group_ids=%s state_groups=%s", + state_group_ids, + state_groups, + ) + log_kv({"state_group_ids": state_group_ids, "state_groups": state_groups}) + # check if each event has same state group id, if so there's no state to resolve state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: @@ -458,6 +552,18 @@ async def resolve_state_groups_for_events( ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) + logger.info( + "resolve_state_groups_for_events: Returning state_group_id=%s prev_group=%s", + state_group_id, + prev_group, + ) + log_kv( + { + "message": "Returning state_group_id", + "state_group_id": state_group_id, + "prev_group": prev_group, + } + ) return _StateCacheEntry( state=None, state_group=state_group_id, @@ -465,6 +571,14 @@ async def resolve_state_groups_for_events( delta_ids=delta_ids, ) elif len(state_group_ids_set) == 0: + logger.info( + "resolve_state_groups_for_events: Returning empty state group since there are no state_group_ids" + ) + log_kv( + { + "message": "Returning empty state group since there are no state_group_ids", + } + ) return _StateCacheEntry(state={}, state_group=None) room_version = await self.store.get_room_version_id(room_id) 
@@ -480,6 +594,18 @@ async def resolve_state_groups_for_events( None, state_res_store=StateResolutionStore(self.store), ) + logger.info( + "resolve_state_groups_for_events: RResolving state groups and returning result state_to_resolve=%s result=%s", + state_to_resolve, + result, + ) + log_kv( + { + "message": "Resolving state groups and returning result", + "state_to_resolve": state_to_resolve, + "result": result, + } + ) return result async def update_current_state(self, room_id: str) -> None: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index bde7a6648ae7..de4c8c15f95d 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -43,7 +43,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.tracing import ( @@ -435,6 +435,22 @@ async def enqueue( else: events.append(event) + # We expect events to be persisted by this point and this makes + # mypy happy about `stream_ordering` not being optional below + assert event.internal_metadata.stream_ordering + # Invalidate related caches after we persist a new event + relation = relation_from_event(event) + self.main_store._invalidate_caches_for_event( + stream_ordering=event.internal_metadata.stream_ordering, + event_id=event.event_id, + room_id=event.room_id, + etype=event.type, + state_key=event.state_key if hasattr(event, "state_key") else None, + redacts=event.redacts, + relates_to=relation.parent_id if relation else None, + backfilled=backfilled, + ) + return ( events, self.main_store.get_room_max_token(), @@ -467,6 +483,22 @@ async def persist_event( replaced_event = replaced_events.get(event.event_id) if replaced_event: event = await self.main_store.get_event(replaced_event) + else: + # We expect events to be persisted by this point and this makes + # mypy happy about `stream_ordering` not being optional below + assert event.internal_metadata.stream_ordering + # Invalidate related caches after we persist a new event + relation = relation_from_event(event) + self.main_store._invalidate_caches_for_event( + stream_ordering=event.internal_metadata.stream_ordering, + event_id=event.event_id, + room_id=event.room_id, + etype=event.type, + state_key=event.state_key if hasattr(event, "state_key") else None, + redacts=event.redacts, + relates_to=relation.parent_id if relation else None, + backfilled=backfilled, + ) event_stream_id = event.internal_metadata.stream_ordering # stream ordering should have been assigned by now diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index db6ce83a2b32..b4c700025d6b 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -223,6 +223,7 @@ def _invalidate_caches_for_event( # process triggering the invalidation is responsible for clearing any external # cached objects. 
self._invalidate_local_get_event_cache(event_id) + self.have_seen_event.invalidate((room_id, event_id)) self._attempt_to_invalidate_cache("have_seen_event", (room_id, event_id)) self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 0261ff4ad614..b8c26adb7bea 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -726,22 +726,43 @@ def _get_auth_chain_difference_txn( async def get_backfill_points_in_room( self, room_id: str, - ) -> List[Tuple[str, int]]: + current_depth: int, + ) -> List[Tuple[str, int, int]]: """ Gets the oldest events(backwards extremities) in the room along with the - approximate depth. Sorted by depth, highest to lowest (descending). + approximate depth. Sorted by depth, highest to lowest (descending) so the closest + events to the `current_depth` are first in the list. + + We use this function so that we can compare and see if a client's + `current_depth` at their current scrollback is within pagination range + of the event extremities. If the `current_depth` is close to the depth + of given oldest event, we can trigger a backfill. + + We ignore extremities that have a greater depth than our `current_depth` + as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that extremity, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those extremities. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the users current scrollback position + because we only care about finding events older than the given + `current_depth` when scrolling and paginating backwards. Returns: - List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + List of (event_id, depth, stream_ordering) tuples. Sorted by depth, + highest to lowest (descending) so the closest events to the + `current_depth` are first in the list. Tie-broken with `stream_ordering`, + then `event_id` to get a stable sort. """ def get_backfill_points_in_room_txn( txn: LoggingTransaction, room_id: str - ) -> List[Tuple[str, int]]: + ) -> List[Tuple[str, int, int]]: # Assemble a tuple lookup of event_id -> depth for the oldest events # we know of in the room. Backwards extremeties are the oldest # events we know of in the room but we only know of them because @@ -750,7 +771,7 @@ def get_backfill_points_in_room_txn( # specifically). So we need to look for the approximate depth from # the events connected to the current backwards extremeties. sql = """ - SELECT backward_extrem.event_id, event.depth FROM events AS event + SELECT backward_extrem.event_id, event.depth, event.stream_ordering FROM events AS event /** * Get the edge connections from the event_edges table * so we can see whether this event's prev_events points @@ -784,6 +805,17 @@ def get_backfill_points_in_room_txn( * necessarily safe to assume that it will have been completed. */ AND edge.is_state is ? /* False */ + /** + * We only want backwards extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. 
+ * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * [0]<--[1]▼<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. @@ -802,7 +834,7 @@ def get_backfill_points_in_room_txn( * alphabetical order of the event_ids so we get a consistent * ordering which is nice when asserting things in tests. */ - ORDER BY event.depth DESC, backward_extrem.event_id DESC + ORDER BY event.depth DESC, event.stream_ordering DESC, backward_extrem.event_id DESC """ if isinstance(self.database_engine, PostgresEngine): @@ -817,13 +849,14 @@ def get_backfill_points_in_room_txn( ( room_id, False, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, ), ) - return cast(List[Tuple[str, int]], txn.fetchall()) + return cast(List[Tuple[str, int, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_backfill_points_in_room", @@ -835,26 +868,47 @@ def get_backfill_points_in_room_txn( async def get_insertion_event_backward_extremities_in_room( self, room_id: str, - ) -> List[Tuple[str, int]]: + current_depth: int, + ) -> List[Tuple[str, int, int]]: """ Get the insertion events we know about that we haven't backfilled yet along with the approximate depth. Sorted by depth, highest to lowest - (descending). + (descending) so the closest events to the `current_depth` are first + in the list. + + We use this function so that we can compare and see if someone's + `current_depth` at their current scrollback is within pagination range + of the insertion event. If the `current_depth` is close to the depth + of the given insertion event, we can trigger a backfill. + + We ignore insertion events that have a greater depth than our `current_depth` + as: + 1. we don't really care about getting events that have happened + after our current position; and + 2. by the nature of paginating and scrolling back, we have likely + previously tried and failed to backfill from that insertion event, so + to avoid getting "stuck" requesting the same backfill repeatedly + we drop those insertion events. Args: room_id: Room where we want to find the oldest events + current_depth: The depth at the user's current scrollback position because + we only care about finding events older than the given + `current_depth` when scrolling and paginating backwards. Returns: - List of (event_id, depth) tuples. Sorted by depth, highest to lowest - (descending) + List of (event_id, depth, stream_ordering) tuples. Sorted by depth, + highest to lowest (descending) so the closest events to the + `current_depth` are first in the list. Tie-broken with `stream_ordering`, + then `event_id` to get a stable sort.
""" def get_insertion_event_backward_extremities_in_room_txn( txn: LoggingTransaction, room_id: str - ) -> List[Tuple[str, int]]: + ) -> List[Tuple[str, int, int]]: sql = """ SELECT - insertion_event_extremity.event_id, event.depth + insertion_event_extremity.event_id, event.depth, event.stream_ordering /* We only want insertion events that are also marked as backwards extremities */ FROM insertion_event_extremities AS insertion_event_extremity /* Get the depth of the insertion event from the events table */ @@ -869,6 +923,17 @@ def get_insertion_event_backward_extremities_in_room_txn( AND failed_backfill_attempt_info.event_id = insertion_event_extremity.event_id WHERE insertion_event_extremity.room_id = ? + /** + * We only want extremities that are older than or at + * the same position of the given `current_depth` (where older + * means less than the given depth) because we're looking backwards + * from the `current_depth` when backfilling. + * + * current_depth (ignore events that come after this, ignore 2-4) + * | + * [0]<--[1]▼<--[2]<--[3]<--[4] + */ + AND event.depth <= ? /* current_depth */ /** * Exponential back-off (up to the upper bound) so we don't retry the * same backfill point over and over. ex. 2hr, 4hr, 8hr, 16hr, etc @@ -887,7 +952,7 @@ def get_insertion_event_backward_extremities_in_room_txn( * alphabetical order of the event_ids so we get a consistent * ordering which is nice when asserting things in tests. */ - ORDER BY event.depth DESC, insertion_event_extremity.event_id DESC + ORDER BY event.depth DESC, event.stream_ordering DESC, insertion_event_extremity.event_id DESC """ if isinstance(self.database_engine, PostgresEngine): @@ -901,12 +966,13 @@ def get_insertion_event_backward_extremities_in_room_txn( sql % (least_function,), ( room_id, + current_depth, self._clock.time_msec(), 1000 * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_SECONDS, 1000 * BACKFILL_EVENT_BACKOFF_UPPER_BOUND_SECONDS, ), ) - return cast(List[Tuple[str, int]], txn.fetchall()) + return cast(List[Tuple[str, int, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_insertion_event_backward_extremities_in_room", diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0b86a53085d5..648afeca446b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -209,7 +209,9 @@ async def _persist_events_and_state_updates( async with stream_ordering_manager as stream_orderings: for (event, _), stream in zip(events_and_contexts, stream_orderings): - event.internal_metadata.stream_ordering = stream + # foo + if event.internal_metadata.stream_ordering is None: + event.internal_metadata.stream_ordering = stream await self.db_pool.runInteraction( "persist_events", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 9f6b1fcef1ce..74991a4992e0 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1474,32 +1474,38 @@ async def have_seen_events( # the batches as big as possible. 
results: Set[str] = set() - for chunk in batch_iter(event_ids, 500): - r = await self._have_seen_events_dict( - [(room_id, event_id) for event_id in chunk] + for event_ids_chunk in batch_iter(event_ids, 500): + events_seen_dict = await self._have_seen_events_dict( + room_id, event_ids_chunk + ) + results.update( + eid for (eid, have_event) in events_seen_dict.items() if have_event ) - results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) return results - @cachedList(cached_method_name="have_seen_event", list_name="keys") + @cachedList(cached_method_name="have_seen_event", list_name="event_ids") async def _have_seen_events_dict( - self, keys: Collection[Tuple[str, str]] - ) -> Dict[Tuple[str, str], bool]: + self, + room_id: str, + event_ids: Collection[str], + ) -> Dict[str, bool]: """Helper for have_seen_events Returns: - a dict {(room_id, event_id)-> bool} + a dict {event_id -> bool} """ # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) - for (rid, eid) in keys - if await self._get_event_cache.contains((eid,)) + event_id + for event_id in event_ids + if await self._get_event_cache.contains((event_id,)) } results = dict.fromkeys(cache_results, True) - remaining = [k for k in keys if k not in cache_results] + remaining = [ + event_id for event_id in event_ids if event_id not in cache_results + ] if not remaining: return results @@ -1511,23 +1517,21 @@ def have_seen_events_txn(txn: LoggingTransaction) -> None: sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] + txn.database_engine, "e.event_id", remaining ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update( - {(rid, eid): (eid in found_events) for (rid, eid) in remaining} - ) + results.update({eid: (eid in found_events) for eid in remaining}) await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: - res = await self._have_seen_events_dict(((room_id, event_id),)) - return res[(room_id, event_id)] + res = await self._have_seen_events_dict(room_id, [event_id]) + return res[event_id] def _get_current_state_event_counts_txn( self, txn: LoggingTransaction, room_id: str diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 64cba763c45a..8e2153683630 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,6 @@ from synapse.api.constants import RelationTypes from synapse.events import EventBase -from synapse.logging.tracing import trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 3909f1caea24..0391966462e7 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -431,6 +431,12 @@ def __get__( cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args + if num_args != self.num_args: + raise Exception( + "Number of args (%s) does not match underlying cache_method_name=%s (%s)." 
+ % (self.num_args, self.cached_method_name, num_args) + ) + @functools.wraps(self.orig) def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 918010cddbf9..eb9c125e8d79 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -11,16 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +import json +from typing import Dict, List, Optional, Tuple from unittest import mock +from unittest.mock import Mock, patch +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion +from synapse.appservice import ApplicationService from synapse.event_auth import ( check_state_dependent_auth_rules, check_state_independent_auth_rules, ) -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.federation.transport.client import StateRequestResponse from synapse.logging.context import LoggingContext @@ -28,9 +32,15 @@ from synapse.rest.client import login, room from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort from synapse.types import JsonDict +from synapse.util.stringutils import random_string from tests import unittest from tests.test_utils import event_injection, make_awaitable +from tests.test_utils.event_injection import create_event, inject_event + +import logging + +logger = logging.getLogger(__name__) class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): @@ -45,10 +55,24 @@ def make_homeserver(self, reactor, clock): self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event"] ) - return super().setup_test_homeserver( - federation_transport_client=self.mock_federation_transport_client + + self.appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", ) + mock_load_appservices = Mock(return_value=[self.appservice]) + with patch( + "synapse.storage.databases.main.appservice.load_appservices", + mock_load_appservices, + ): + return super().setup_test_homeserver( + federation_transport_client=self.mock_federation_transport_client + ) + def test_process_pulled_event_with_missing_state(self) -> None: """Ensure that we correctly handle pulled events with lots of missing state @@ -848,3 +872,454 @@ async def get_room_state( bert_member_event.event_id, "Rejected kick event unexpectedly became part of room state.", ) + + def test_process_pulled_events_asdf(self) -> None: + main_store = self.hs.get_datastores().main + state_storage_controller = self.hs.get_storage_controllers().state + + def _debug_event_string(event: EventBase) -> str: + debug_body = event.content.get("body", event.type) + maybe_state_key = getattr(event, "state_key", None) + return f"event_id={event.event_id},depth={event.depth},stream_ordering={event.internal_metadata.stream_ordering},body={debug_body}({maybe_state_key}),prevs={event.prev_event_ids()}" + + known_event_dict: Dict[str, 
Tuple[EventBase, List[EventBase]]] = {} + + def _add_to_known_event_list( + event: EventBase, state_events: Optional[List[EventBase]] = None + ) -> None: + if state_events is None: + state_map = self.get_success( + state_storage_controller.get_state_for_event(event.event_id) + ) + state_events = list(state_map.values()) + + known_event_dict[event.event_id] = (event, state_events) + + async def get_room_state_ids( + destination: str, room_id: str, event_id: str + ) -> JsonDict: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + known_event_info = known_event_dict.get(event_id) + if known_event_info is None: + self.fail( + f"stubbed get_room_state_ids: Event ({event_id}) not part of our known events list" + ) + + known_event, known_event_state_list = known_event_info + logger.info( + "stubbed get_room_state_ids: destination=%s event_id=%s auth_event_ids=%s", + destination, + event_id, + known_event.auth_event_ids(), + ) + + # self.assertEqual(event_id, missing_event.event_id) + return { + "pdu_ids": [ + state_event.event_id for state_event in known_event_state_list + ], + "auth_chain_ids": known_event.auth_event_ids(), + } + + async def get_room_state( + room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> StateRequestResponse: + self.assertEqual(destination, self.OTHER_SERVER_NAME) + known_event_info = known_event_dict.get(event_id) + if known_event_info is None: + self.fail( + f"stubbed get_room_state: Event ({event_id}) not part of our known events list" + ) + + known_event, known_event_state_list = known_event_info + logger.info( + "stubbed get_room_state: destination=%s event_id=%s auth_event_ids=%s", + destination, + event_id, + known_event.auth_event_ids(), + ) + + auth_event_ids = known_event.auth_event_ids() + auth_events = [] + for auth_event_id in auth_event_ids: + known_event_info = known_event_dict.get(event_id) + if known_event_info is None: + self.fail( + f"stubbed get_room_state: Auth event ({auth_event_id}) is not part of our known events list" + ) + known_auth_event, _ = known_event_info + auth_events.append(known_auth_event) + + return StateRequestResponse( + state=known_event_state_list, + auth_events=auth_events, + ) + + async def get_event(destination: str, event_id: str, timeout=None): + self.assertEqual(destination, self.OTHER_SERVER_NAME) + known_event_info = known_event_dict.get(event_id) + if known_event_info is None: + self.fail( + f"stubbed get_event: Event ({event_id}) not part of our known events list" + ) + + known_event, _ = known_event_info + return {"pdus": [known_event.get_pdu_json()]} + + self.mock_federation_transport_client.get_room_state_ids.side_effect = ( + get_room_state_ids + ) + self.mock_federation_transport_client.get_room_state.side_effect = ( + get_room_state + ) + + self.mock_federation_transport_client.get_event.side_effect = get_event + + # create the room + room_creator = self.appservice.sender + room_id = self.helper.create_room_as( + room_creator=self.appservice.sender, tok=self.appservice.token + ) + room_version = self.get_success(main_store.get_room_version(room_id)) + + event_before = self.get_success( + inject_event( + self.hs, + room_id=room_id, + sender=room_creator, + type=EventTypes.Message, + content={"body": "eventBefore0", "msgtype": "m.text"}, + ) + ) + _add_to_known_event_list(event_before) + + event_after = self.get_success( + inject_event( + self.hs, + room_id=room_id, + sender=room_creator, + type=EventTypes.Message, + content={"body": "eventAfter0", "msgtype": "m.text"}, + ) + ) + 
_add_to_known_event_list(event_after) + + state_map = self.get_success( + state_storage_controller.get_state_for_event(event_before.event_id) + ) + + room_create_event = state_map.get((EventTypes.Create, "")) + pl_event = state_map.get((EventTypes.PowerLevels, "")) + as_membership_event = state_map.get((EventTypes.Member, room_creator)) + assert room_create_event is not None + assert pl_event is not None + assert as_membership_event is not None + + for state_event in state_map.values(): + _add_to_known_event_list(state_event) + + # This should be the successor of the event we want to insert next to + # (the successor of event_before is event_after). + inherited_depth = event_after.depth + + historical_base_auth_event_ids = [ + room_create_event.event_id, + pl_event.event_id, + ] + historical_state_events = list(state_map.values()) + historical_state_event_ids = [ + state_event.event_id for state_event in historical_state_events + ] + + maria_mxid = "@maria:test" + maria_membership_event, _ = self.get_success( + create_event( + self.hs, + room_id=room_id, + sender=maria_mxid, + state_key=maria_mxid, + type=EventTypes.Member, + content={ + "membership": "join", + }, + # It all works when I add a prev_event for the floating + # insertion event but the event no longer floats. + # It's able to resolve state at the prev_events though. + prev_event_ids=[event_before.event_id], + # allow_no_prev_events=True, + # prev_event_ids=[], + # auth_event_ids=historical_base_auth_event_ids, + # + # Because we're creating all of these events without persisting them yet, + # we have to explicitly provide some auth_events. For member events, we do it this way. + state_event_ids=historical_state_event_ids, + depth=inherited_depth, + ) + ) + _add_to_known_event_list(maria_membership_event, historical_state_events) + logger.info("maria_membership_event=%s", maria_membership_event.event_id) + + historical_state_events.append(maria_membership_event) + historical_state_event_ids.append(maria_membership_event.event_id) + + batch_id = random_string(8) + next_batch_id = random_string(8) + insertion_event, _ = self.get_success( + create_event( + self.hs, + room_id=room_id, + sender=room_creator, + type=EventTypes.MSC2716_INSERTION, + content={ + EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # The difference from the actual room /batch_send is that this is normally + # floating as well. But seems to work once we connect it to the + # floating historical state chain. 
+ prev_event_ids=[maria_membership_event.event_id], + # allow_no_prev_events=True, + # prev_event_ids=[], + # Because we're creating all of these events without persisting them yet, + # we have to explicitly provide some auth_events + auth_event_ids=[ + *historical_base_auth_event_ids, + as_membership_event.event_id, + ], + # state_event_ids=historical_state_event_ids, + depth=inherited_depth, + ) + ) + _add_to_known_event_list(insertion_event, historical_state_events) + historical_message_event, _ = self.get_success( + create_event( + self.hs, + room_id=room_id, + sender=maria_mxid, + type=EventTypes.Message, + content={"body": "Historical message", "msgtype": "m.text"}, + prev_event_ids=[insertion_event.event_id], + # Because we're creating all of these events without persisting them yet, + # we have to explicitly provide some auth_events + auth_event_ids=[ + *historical_base_auth_event_ids, + maria_membership_event.event_id, + ], + depth=inherited_depth, + ) + ) + _add_to_known_event_list(historical_message_event, historical_state_events) + batch_event, _ = self.get_success( + create_event( + self.hs, + room_id=room_id, + sender=room_creator, + type=EventTypes.MSC2716_BATCH, + content={ + EventContentFields.MSC2716_BATCH_ID: batch_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + prev_event_ids=[historical_message_event.event_id], + # Because we're creating all of these events without persisting them yet, + # we have to explicitly provide some auth_events + auth_event_ids=[ + *historical_base_auth_event_ids, + as_membership_event.event_id, + ], + depth=inherited_depth, + ) + ) + _add_to_known_event_list(batch_event, historical_state_events) + base_insertion_event, base_insertion_event_context = self.get_success( + create_event( + self.hs, + room_id=room_id, + sender=room_creator, + type=EventTypes.MSC2716_INSERTION, + content={ + EventContentFields.MSC2716_NEXT_BATCH_ID: batch_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + prev_event_ids=[event_before.event_id], + # Because we're creating all of these events without persisting them yet, + # we have to explicitly provide some auth_events + auth_event_ids=[ + *historical_base_auth_event_ids, + as_membership_event.event_id, + ], + # state_event_ids=historical_state_event_ids, + depth=inherited_depth, + ) + ) + _add_to_known_event_list(base_insertion_event, historical_state_events) + + # Chronological + pulled_events: List[EventBase] = [ + # Beginning of room (oldest messages) + # *list(state_map.values()), + room_create_event, + pl_event, + as_membership_event, + state_map.get((EventTypes.JoinRules, "")), + state_map.get((EventTypes.RoomHistoryVisibility, "")), + event_before, + # HISTORICAL MESSAGE END + insertion_event, + historical_message_event, + batch_event, + base_insertion_event, + # HISTORICAL MESSAGE START + event_after, + # Latest in the room (newest messages) + ] + + # pulled_events: List[EventBase] = [ + # # Beginning of room (oldest messages) + # # *list(state_map.values()), + # room_create_event, + # pl_event, + # as_membership_event, + # state_map.get((EventTypes.JoinRules, "")), + # state_map.get((EventTypes.RoomHistoryVisibility, "")), + # event_before, + # # HISTORICAL MESSAGE END + # insertion_event, + # historical_message_event, + # batch_event, + # base_insertion_event, + # # HISTORICAL MESSAGE START + # event_after, + # # Latest in the room (newest messages) + # ] + + # The order that we get after passing reverse chronological events in + # that mostly passes. 
Only the insertion event is rejected but the
+        # historical messages appear in /messages scrollback.
+        # pulled_events: List[EventBase] = [
+        #     # Beginning of room (oldest messages)
+        #     # *list(state_map.values()),
+        #     room_create_event,
+        #     pl_event,
+        #     as_membership_event,
+        #     state_map.get((EventTypes.JoinRules, "")),
+        #     state_map.get((EventTypes.RoomHistoryVisibility, "")),
+        #     event_before,
+        #     event_after,
+        #     base_insertion_event,
+        #     batch_event,
+        #     historical_message_event,
+        #     insertion_event,
+        #     # Latest in the room (newest messages)
+        # ]
+
+        logger.info(
+            "pulled_events=%s",
+            json.dumps(
+                [_debug_event_string(event) for event in pulled_events],
+                indent=4,
+            ),
+        )
+
+        for event, _ in known_event_dict.values():
+            if event.internal_metadata.outlier:
+                self.fail("Our pristine events should not be marked as an outlier")
+
+        # TODO: We currently don't set the `stream_ordering` on `pulled_events` here
+        # like we normally would via `backfill(...)` before passing it off to
+        # `_process_pulled_events(...)`
+        self.get_success(
+            self.hs.get_federation_event_handler()._process_pulled_events(
+                self.OTHER_SERVER_NAME,
+                [
+                    # Make copies of events since Synapse modifies the
+                    # internal_metadata in place and we want to keep our
+                    # pristine copies
+                    make_event_from_dict(pulled_event.get_pdu_json(), room_version)
+                    for pulled_event in pulled_events
+                ],
+                backfilled=True,
+            )
+        )
+
+        from_token = self.get_success(
+            self.hs.get_event_sources().get_current_token_for_pagination(room_id)
+        )
+        actual_events_in_room_reverse_chronological, _ = self.get_success(
+            main_store.paginate_room_events(
+                room_id, from_key=from_token.room_key, limit=100, direction="b"
+            )
+        )
+
+        # We have to reverse the list to make it chronological.
+        actual_events_in_room_chronological = list(
+            reversed(actual_events_in_room_reverse_chronological)
+        )
+
+        expected_event_order = [
+            # Beginning of room (oldest messages)
+            # *list(state_map.values()),
+            room_create_event,
+            as_membership_event,
+            pl_event,
+            state_map.get((EventTypes.JoinRules, "")),
+            state_map.get((EventTypes.RoomHistoryVisibility, "")),
+            event_before,
+            # HISTORICAL MESSAGE END
+            insertion_event,
+            historical_message_event,
+            batch_event,
+            base_insertion_event,
+            # HISTORICAL MESSAGE START
+            event_after,
+            # Latest in the room (newest messages)
+        ]
+
+        event_id_diff = {event.event_id for event in expected_event_order} - {
+            event.event_id for event in actual_events_in_room_chronological
+        }
+        event_diff_ordered = [
+            event for event in expected_event_order if event.event_id in event_id_diff
+        ]
+        event_id_extra = {
+            event.event_id for event in actual_events_in_room_chronological
+        } - {event.event_id for event in expected_event_order}
+        event_extra_ordered = [
+            event
+            for event in actual_events_in_room_chronological
+            if event.event_id in event_id_extra
+        ]
+        assertion_message = (
+            "Debug info:\nActual events missing from expected list: %s\nActual events contain %d additional events compared to expected: %s\nExpected event order: %s\nActual event order: %s"
+            % (
+                json.dumps(
+                    [_debug_event_string(event) for event in event_diff_ordered],
+                    indent=4,
+                ),
+                len(event_extra_ordered),
+                json.dumps(
+                    [_debug_event_string(event) for event in event_extra_ordered],
+                    indent=4,
+                ),
+                json.dumps(
+                    [_debug_event_string(event) for event in expected_event_order],
+                    indent=4,
+                ),
+                json.dumps(
+                    [
+                        _debug_event_string(event)
+                        for event in actual_events_in_room_chronological
+                    ],
+                    indent=4,
+                ),
+            )
+        )
+
+        # assert (
+        #
actual_events_in_room_chronological == expected_event_order + # ), assertion_message + + self.assertEqual( + [event.event_id for event in actual_events_in_room_chronological], + [event.event_id for event in expected_event_order], + assertion_message, + ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 67401272ac37..32a798d74bca 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -35,66 +35,45 @@ from synapse.util.async_helpers import yieldable_gather_results from tests import unittest +from tests.test_utils.event_injection import create_event, inject_event class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor, clock, hs): + self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main - # insert some test data - for rid in ("room1", "room2"): - self.get_success( - self.store.db_pool.simple_insert( - "rooms", - {"room_id": rid, "room_version": 4}, - ) - ) + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + self.room_id = self.helper.create_room_as(self.user, tok=self.token) self.event_ids: List[str] = [] - for idx, rid in enumerate( - ( - "room1", - "room1", - "room1", - "room2", - ) - ): - event_json = {"type": f"test {idx}", "room_id": rid} - event = make_event_from_dict(event_json, room_version=RoomVersions.V4) - event_id = event.event_id - - self.get_success( - self.store.db_pool.simple_insert( - "events", - { - "event_id": event_id, - "room_id": rid, - "topological_ordering": idx, - "stream_ordering": idx, - "type": event.type, - "processed": True, - "outlier": False, - }, + for i in range(3): + event = self.get_success( + inject_event( + hs, + room_version=RoomVersions.V7.identifier, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": f"foobarbaz{i}"}, ) ) - self.get_success( - self.store.db_pool.simple_insert( - "event_json", - { - "event_id": event_id, - "room_id": rid, - "json": json.dumps(event_json), - "internal_metadata": "{}", - "format_version": 3, - }, - ) - ) - self.event_ids.append(event_id) + + self.event_ids.append(event.event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) @@ -104,7 +83,9 @@ def test_simple(self): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -116,11 +97,86 @@ def test_query_via_event_cache(self): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0]]) + self.store.have_seen_events(self.room_id, [self.event_ids[0]]) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) + def test_persisting_event_invalidates_cache(self): + """ + Test to make 
sure that the `have_seen_event` cache + is invalidated after we persist an event and returns + the updated value. + """ + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "garply"}, + ) + ) + + with LoggingContext(name="test") as ctx: + # First, check `have_seen_event` for an event we have not seen yet + # to prime the cache with a `false` value. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, set()) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Persist the event which should invalidate or prefill the + # `have_seen_event` cache so we don't return stale values. + persistence = self.hs.get_storage_controllers().persistence + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + with LoggingContext(name="test") as ctx: + # Check `have_seen_event` again and we should see the updated fact + # that we have now seen the event after persisting it. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, {event.event_id}) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + def test_invalidate_cache_by_room_id(self): + """ + Test to make sure that all events associated with the given `(room_id,)` + are invalidated in the `have_seen_event` cache. + """ + with LoggingContext(name="test") as ctx: + # Prime the cache with some values + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Clear the cache with any events associated with the `room_id` + self.store.have_seen_event.invalidate((self.room_id,)) + + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # Since we cleared the cache, it should result in another db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8027c7a856e2..497ee188ca39 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -82,6 +82,11 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, + *, + allow_no_prev_events: Optional[bool] = False, + auth_event_ids: Optional[List[str]] = None, + state_event_ids: Optional[List[str]] = None, + depth: Optional[int] = None, **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: @@ -89,11 +94,21 @@ async def create_event( kwargs["room_id"] ) + import logging + + logger = logging.getLogger(__name__) + logger.info("kwargs=%s", kwargs) builder = hs.get_event_builder_factory().for_room_version( KNOWN_ROOM_VERSIONS[room_version], kwargs ) event, context = await hs.get_event_creation_handler().create_new_client_event( - builder, prev_event_ids=prev_event_ids + builder, + # Why does this need another default to pass: `Argument "allow_no_prev_events" 
to "create_new_client_event" of "EventCreationHandler" has incompatible type "Optional[bool]"; expected "bool"` + allow_no_prev_events=allow_no_prev_events or False, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + state_event_ids=state_event_ids, + depth=depth, ) return event, context diff --git a/tests/unittest.py b/tests/unittest.py index 00cb023198b5..1431c8e9d7ec 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -743,7 +743,7 @@ def inject_room_member(self, room: str, user: str, membership: str) -> None: """ Inject a membership event into a room. - Deprecated: use event_injection.inject_room_member directly + Deprecated: use event_injection.inject_member_event directly Args: room: Room ID to inject the event into. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac7419..90861fe522c2 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Set +from typing import Iterable, Set, Tuple from unittest import mock from twisted.internet import defer, reactor @@ -1008,3 +1008,34 @@ async def do_lookup(): obj.inner_context_was_finished, "Tried to restart a finished logcontext" ) self.assertEqual(current_context(), SENTINEL_CONTEXT) + + def test_num_args_mismatch(self): + """ + Make sure someone does not accidentally use @cachedList on a method with + a mismatch in the number args to the underlying single cache method. + """ + + class Cls: + @descriptors.cached(tree=True) + def fn(self, room_id, event_id): + pass + + # This is wrong ❌. `@cachedList` expects to be given the same number + # of arguments as the underlying cached function, just with one of + # the arguments being an iterable + @descriptors.cachedList(cached_method_name="fn", list_name="keys") + def list_fn(self, keys: Iterable[Tuple[str, str]]): + pass + + # Corrected syntax ✅ + # + # @cachedList(cached_method_name="fn", list_name="event_ids") + # async def list_fn( + # self, room_id: str, event_ids: Collection[str], + # ) + + obj = Cls() + + # Make sure this raises an error about the arg mismatch + with self.assertRaises(Exception): + obj.list_fn([("foo", "bar")])