diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix
new file mode 100644
index 000000000000..e1f89cee35c8
--- /dev/null
+++ b/changelog.d/14723.bugfix
@@ -0,0 +1 @@
+Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b4dad47b45ad..d1c45e647805 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -148,6 +148,9 @@ async def on_rdata(
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)

         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7f690129c940..e8092b85f137 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -59,7 +59,22 @@ def process_replication_rows(  # noqa: B027 (no-op by design)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data.
+        These must not update any ID generators; use `process_replication_position`.
+        """
+
+    def process_replication_position(  # noqa: B027 (no-op by design)
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+    ) -> None:
+        """
+        Used by storage classes to advance ID generators based on incoming replication
+        data. This is called after `process_replication_rows` such that caches are
+        invalidated before any token positions advance.
+        """

     def _invalidate_state_caches(
         self, room_id: str, members_changed: Collection[str]
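Editor's note: the contract these two hooks establish is easiest to see in isolation. Below is a minimal sketch of the two-phase handling; `SketchWorkerStore` and its `_cache`/`_current_token` attributes are illustrative stand-ins, not Synapse's actual store hierarchy. Stale cache entries are dropped while the advertised token still points at the old position, and only then is the new token published, so any reader that observes the new token is guaranteed to also see the invalidations.

from typing import Any, Dict, Iterable


class SketchWorkerStore:
    """Two-phase replication handling: invalidate first, then advance."""

    def __init__(self) -> None:
        self._current_token = 0
        self._cache: Dict[Any, Any] = {}

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Phase 1: drop cache entries made stale by the rows. The advertised
        # token has NOT moved yet, so readers still see the old position and
        # do not yet trust data up to `token`.
        for row in rows:
            self._cache.pop(row, None)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        # Phase 2: only now advertise the new position. Anyone observing this
        # token afterwards is guaranteed to miss the stale entries.
        self._current_token = max(self._current_token, token)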
+ """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 913efddcdb4a..f7874b7be650 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -415,10 +415,7 @@ def process_replication_rows( token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -433,6 +430,15 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index bfa9561adcbd..858575dfddbd 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ def process_replication_rows( backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ def process_replication_rows( super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data @@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - # TODO: Nothing to do here, handled in events_worker, cleanup? 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index bfa9561adcbd..858575dfddbd 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ def process_replication_rows(
                 backfilled=True,
             )
         elif stream_name == CachesStream.NAME:
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(instance_name, token)
-
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
                     if row.keys is None:
@@ -182,6 +179,14 @@ def process_replication_rows(

         super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == CachesStream.NAME:
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data

@@ -198,8 +203,14 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            # TODO: Nothing to do here, handled in events_worker, cleanup?
-            pass
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
+
+            if data.type == EventTypes.Member:
+                self.get_rooms_for_user_with_stream_ordering.invalidate(
+                    (data.state_key,)
+                )
+                self.get_rooms_for_user.invalidate((data.state_key,))
         else:
             raise Exception("Unknown events stream row type %s" % (row.type,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 87a0798e8f55..89b85ec9caa7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -160,10 +160,15 @@ def process_replication_rows(
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-            # Important that the ID gen advances after stream change caches
-            self._device_inbox_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 16d39d7bfd20..db1cbfb506ef 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -163,15 +163,20 @@ def process_replication_rows(
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._invalidate_caches_for_devices(token, rows)
-            # Important that the ID gen advances after stream change caches
-            self._device_list_id_gen.advance(instance_name, token)
         elif stream_name == UserSignatureStream.NAME:
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-            # Important that the ID gen advances after stream change caches
-            self._device_list_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index db90c8769ea1..7f74fddc9d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -249,22 +249,6 @@ def __init__(
             prefilled_cache=curr_state_delta_prefill,
         )

-        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
-            db_conn,
-            "events",
-            entity_column="room_id",
-            stream_column="stream_ordering",
-            max_value=events_max,
-        )
-        self._events_stream_cache = StreamChangeCache(
-            "EventsRoomStreamChangeCache",
-            min_event_val,
-            prefilled_cache=event_cache_prefill,
-        )
-        self._membership_stream_cache = StreamChangeCache(
-            "MembershipStreamChangeCache", events_max
-        )
-
         if hs.config.worker.run_background_tasks:
             # We periodically clean out old transaction ID mappings
             self._clock.looping_call(
@@ -332,28 +316,27 @@ def process_replication_rows(
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        # Process event stream replication rows, handling both the ID generators from the events
-        # worker store and the stream change caches in this store as the two are interlinked.
-        if stream_name == EventsStream.NAME:
+        if stream_name == UnPartialStatedEventStream.NAME:
             for row in rows:
-                if row.type == EventsStreamEventRow.TypeId:
-                    self._events_stream_cache.entity_has_changed(
-                        row.data.room_id, token
-                    )
-                    if row.data.type == EventTypes.Member:
-                        self._membership_stream_cache.entity_has_changed(
-                            row.data.state_key, token
-                        )
-                if row.type == EventsStreamCurrentStateRow.TypeId:
-                    self._curr_state_delta_stream_cache.entity_has_changed(
-                        row.data.room_id, token
-                    )
-            # Important that the ID gen advances after stream change caches
+                assert isinstance(row, UnPartialStatedEventStreamRow)
+
+                self.is_partial_state_event.invalidate((row.event_id,))
+
+                if row.rejection_status_changed:
+                    # If the partial-stated event became rejected or unrejected
+                    # when it wasn't before, we need to invalidate this cache.
+                    self._invalidate_local_get_event_cache(row.event_id)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
             self._backfill_id_gen.advance(instance_name, -token)
-
-        super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)

     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d0c..7b60815043a6 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ def process_replication_rows(
         rows: Iterable[Any],
     ) -> None:
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46ad44..d4e4b777da95 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -154,6 +154,13 @@ def process_replication_rows(
             self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     @cached(max_entries=5000)
     async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a6ab1..7f24a3b6ec5e 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()

-    def process_replication_rows(
-        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
     ) -> None:
         if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(instance_name, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)

     async def get_pushers_by_app_id_and_pushkey(
         self, app_id: str, pushkey: str
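Editor's note: every store changed above follows the same shape — handle your own stream, then unconditionally chain to super(). The worker stores are mixed into a single data-store class, so dropping the super() call would silently stop sibling stores from ever advancing their positions. A self-contained sketch of the pattern, with placeholder class and stream names:

from typing import Dict


class BaseSketchStore:
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        """No-op anchor for the super() chain, mirroring _base.py above."""


class PusherSketchStore(BaseSketchStore):
    def __init__(self) -> None:
        # Stand-in for a MultiWriterIdGenerator: one position per writer.
        self._positions: Dict[str, int] = {}

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == "pushers":
            # Stand-in for self._pushers_id_gen.advance(instance_name, token):
            # positions move monotonically forward, never backwards.
            current = self._positions.get(instance_name, 0)
            self._positions[instance_name] = max(current, token)
        # Always chain up so sibling stores see the notification too.
        super().process_replication_position(stream_name, instance_name, token)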
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 89414097720f..b5cf2d8c3511 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -600,11 +600,15 @@ def process_replication_rows(
                     row.room_id, row.receipt_type, row.user_id
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-            # Important that the ID gen advances after stream change caches
-            self._receipts_id_gen.advance(instance_name, token)
-
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _insert_linearized_receipt_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 62f65f4cab54..6924add6c96e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -71,6 +71,7 @@
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable

 if TYPE_CHECKING:
@@ -396,6 +397,23 @@ def __init__(
         # during startup which would cause one to die.
         self._need_to_reset_federation_stream_positions = self._send_federation

+        events_max = self.get_room_max_stream_ordering()
+        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
+            db_conn,
+            "events",
+            entity_column="room_id",
+            stream_column="stream_ordering",
+            max_value=events_max,
+        )
+        self._events_stream_cache = StreamChangeCache(
+            "EventsRoomStreamChangeCache",
+            min_event_val,
+            prefilled_cache=event_cache_prefill,
+        )
+        self._membership_stream_cache = StreamChangeCache(
+            "MembershipStreamChangeCache", events_max
+        )
+
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a30d..e23c927e02f2 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ def process_replication_rows(
         rows: Iterable[Any],
     ) -> None:
         if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)

         super().process_replication_rows(stream_name, instance_name, token, rows)

+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+

 class TagsStore(TagsWorkerStore):
     pass
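Editor's note: the stream change caches that stream.py now constructs are the other half of the invariant. They answer "might this entity have changed since token T?", which is only accurate if changes are recorded before T becomes visible to readers. A usage sketch of StreamChangeCache follows; the token values are illustrative.

from synapse.util.caches.stream_change_cache import StreamChangeCache

# A cache that believes nothing has changed since position 10.
cache = StreamChangeCache("ExampleChangeCache", 10)

# A replication row arrives for token 11: record the change first...
cache.entity_has_changed("@user:example.com", 11)

# ...and only then advance the advertised position to 11. A reader holding
# token 10 now correctly sees that the user may have changed since then.
assert cache.has_entity_changed("@user:example.com", 10)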