diff --git a/changelog.d/19211.misc b/changelog.d/19211.misc new file mode 100644 index 00000000000..d8a4a44662e --- /dev/null +++ b/changelog.d/19211.misc @@ -0,0 +1 @@ +Expire sliding sync connections that are too old or have too much pending data. diff --git a/synapse/handlers/sliding_sync/room_lists.py b/synapse/handlers/sliding_sync/room_lists.py index 3d119022367..fa4ff22b645 100644 --- a/synapse/handlers/sliding_sync/room_lists.py +++ b/synapse/handlers/sliding_sync/room_lists.py @@ -34,10 +34,12 @@ EventTypes, Membership, ) +from synapse.api.errors import SlidingSyncUnknownPosition from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import StrippedStateEvent from synapse.events.utils import parse_stripped_state_event from synapse.logging.opentracing import start_active_span, trace +from synapse.storage.databases.main.sliding_sync import UPDATE_INTERVAL_LAST_USED_TS_MS from synapse.storage.databases.main.state import ( ROOM_UNKNOWN_SENTINEL, Sentinel as StateSentinel, @@ -68,6 +70,7 @@ ) from synapse.types.state import StateFilter from synapse.util import MutableOverlayMapping +from synapse.util.constants import MILLISECONDS_PER_SECOND, ONE_HOUR_SECONDS from synapse.util.sentinel import Sentinel if TYPE_CHECKING: @@ -77,6 +80,27 @@ logger = logging.getLogger(__name__) +# Minimum time in milliseconds since the last sync before we consider expiring +# the connection due to too many rooms to send. This stops from getting into +# tight loops with clients that request lots of data at once. +# +# c.f. `NUM_ROOMS_THRESHOLD`. These values are somewhat arbitrary picked. +MINIMUM_NOT_USED_AGE_EXPIRY_MS = ONE_HOUR_SECONDS * MILLISECONDS_PER_SECOND + +# How many rooms with updates we allow before we consider the connection expired +# due to too many rooms to send. +# +# c.f. `MINIMUM_NOT_USED_AGE_EXPIRY_MS`. These values are somewhat arbitrary +# picked. +NUM_ROOMS_THRESHOLD = 100 + +# Sanity check that our minimum age is sensible compared to the update interval, +# i.e. if `MINIMUM_NOT_USED_AGE_EXPIRY_MS` is too small then we might expire the +# connection even if it is actively being used (and we're just not updating the +# DB frequently enough). We arbitrarily double the update interval to give some +# wiggle room. +assert 2 * UPDATE_INTERVAL_LAST_USED_TS_MS < MINIMUM_NOT_USED_AGE_EXPIRY_MS + # Helper definition for the types that we might return. We do this to avoid # copying data between types (which can be expensive for many rooms). RoomsForUserType = RoomsForUserStateReset | RoomsForUser | RoomsForUserSlidingSync @@ -176,6 +200,7 @@ def __init__(self, hs: "HomeServer"): self.storage_controllers = hs.get_storage_controllers() self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync self.is_mine_id = hs.is_mine_id + self._clock = hs.get_clock() async def compute_interested_rooms( self, @@ -857,11 +882,41 @@ async def _filter_relevant_rooms_to_send( # We only need to check for new events since any state changes # will also come down as new events. - rooms_that_have_updates = ( - self.store.get_rooms_that_might_have_updates( + + rooms_that_have_updates = await ( + self.store.get_rooms_that_have_updates_since_sliding_sync_table( relevant_room_map.keys(), from_token.room_key ) ) + + # Check if we have lots of updates to send, if so then its + # better for us to tell the client to do a full resync + # instead (to try and avoid long SSS response times when + # there is new data). + # + # Due to the construction of the SSS API, the client is in + # charge of setting the range of rooms to request updates + # for. Generally, it will start with a small range and then + # expand (and occasionally it may contract the range again + # if its been offline for a while). If we know there are a + # lot of updates, it's better to reset the connection and + # wait for the client to start again (with a much smaller + # range) than to try and send down a large number of updates + # (which can take a long time). + # + # We only do this if the last sync was over + # `MINIMUM_NOT_USED_AGE_EXPIRY_MS` to ensure we don't get + # into tight loops with clients that keep requesting large + # sliding sync windows. + if len(rooms_that_have_updates) > NUM_ROOMS_THRESHOLD: + last_sync_ts = previous_connection_state.last_used_ts + if ( + last_sync_ts is not None + and (self._clock.time_msec() - last_sync_ts) + > MINIMUM_NOT_USED_AGE_EXPIRY_MS + ): + raise SlidingSyncUnknownPosition() + rooms_should_send.update(rooms_that_have_updates) relevant_rooms_to_send_map = { room_id: room_sync_config diff --git a/synapse/handlers/sliding_sync/store.py b/synapse/handlers/sliding_sync/store.py index 7bcd5f27eae..d01fab271f6 100644 --- a/synapse/handlers/sliding_sync/store.py +++ b/synapse/handlers/sliding_sync/store.py @@ -75,7 +75,7 @@ async def get_and_clear_connection_positions( """ # If this is our first request, there is no previous connection state to fetch out of the database if from_token is None or from_token.connection_position == 0: - return PerConnectionState() + return PerConnectionState(last_used_ts=None) conn_id = sync_config.conn_id or "" diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index 2b67e75ac47..8cd3de8f40c 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -20,6 +20,7 @@ from synapse.api.errors import SlidingSyncUnknownPosition from synapse.logging.opentracing import log_kv +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, @@ -36,6 +37,12 @@ RoomSyncConfig, ) from synapse.util.caches.descriptors import cached +from synapse.util.constants import ( + MILLISECONDS_PER_SECOND, + ONE_DAY_SECONDS, + ONE_HOUR_SECONDS, + ONE_MINUTE_SECONDS, +) from synapse.util.json import json_encoder if TYPE_CHECKING: @@ -45,6 +52,21 @@ logger = logging.getLogger(__name__) +# How often to update the `last_used_ts` column on +# `sliding_sync_connection_positions` when the client uses a connection +# position. We don't want to update it on every use to avoid excessive +# writes, but we want it to be reasonably up-to-date to help with +# cleaning up old connection positions. +UPDATE_INTERVAL_LAST_USED_TS_MS = 5 * ONE_MINUTE_SECONDS * MILLISECONDS_PER_SECOND + +# Time in milliseconds the connection hasn't been used before we consider it +# expired and delete it. +CONNECTION_EXPIRY_MS = 7 * ONE_DAY_SECONDS * MILLISECONDS_PER_SECOND + +# How often we run the background process to delete old sliding sync connections. +CONNECTION_EXPIRY_FREQUENCY_MS = ONE_HOUR_SECONDS * MILLISECONDS_PER_SECOND + + class SlidingSyncStore(SQLBaseStore): def __init__( self, @@ -76,6 +98,12 @@ def __init__( replaces_index="sliding_sync_membership_snapshots_user_id", ) + if self.hs.config.worker.run_background_tasks: + self.clock.looping_call( + self.delete_old_sliding_sync_connections, + CONNECTION_EXPIRY_FREQUENCY_MS, + ) + async def get_latest_bump_stamp_for_room( self, room_id: str, @@ -202,6 +230,7 @@ def persist_per_connection_state_txn( "effective_device_id": device_id, "conn_id": conn_id, "created_ts": self.clock.time_msec(), + "last_used_ts": self.clock.time_msec(), }, returning=("connection_key",), ) @@ -384,7 +413,7 @@ def _get_and_clear_connection_positions_txn( # The `previous_connection_position` is a user-supplied value, so we # need to make sure that the one they supplied is actually theirs. sql = """ - SELECT connection_key + SELECT connection_key, last_used_ts FROM sliding_sync_connection_positions INNER JOIN sliding_sync_connections USING (connection_key) WHERE @@ -396,7 +425,20 @@ def _get_and_clear_connection_positions_txn( if row is None: raise SlidingSyncUnknownPosition() - (connection_key,) = row + (connection_key, last_used_ts) = row + + # Update the `last_used_ts` if it's due to be updated. We don't update + # every time to avoid excessive writes. + now = self.clock.time_msec() + if last_used_ts is None or now - last_used_ts > UPDATE_INTERVAL_LAST_USED_TS_MS: + self.db_pool.simple_update_txn( + txn, + table="sliding_sync_connections", + keyvalues={ + "connection_key": connection_key, + }, + updatevalues={"last_used_ts": now}, + ) # Now that we have seen the client has received and used the connection # position, we can delete all the other connection positions. @@ -480,12 +522,30 @@ def _get_and_clear_connection_positions_txn( logger.warning("Unrecognized sliding sync stream in DB %r", stream) return PerConnectionStateDB( + last_used_ts=last_used_ts, rooms=RoomStatusMap(rooms), receipts=RoomStatusMap(receipts), account_data=RoomStatusMap(account_data), room_configs=room_configs, ) + @wrap_as_background_process("delete_old_sliding_sync_connections") + async def delete_old_sliding_sync_connections(self) -> None: + """Delete sliding sync connections that have not been used for a long time.""" + cutoff_ts = self.clock.time_msec() - CONNECTION_EXPIRY_MS + + def delete_old_sliding_sync_connections_txn(txn: LoggingTransaction) -> None: + sql = """ + DELETE FROM sliding_sync_connections + WHERE last_used_ts IS NOT NULL AND last_used_ts < ? + """ + txn.execute(sql, (cutoff_ts,)) + + await self.db_pool.runInteraction( + "delete_old_sliding_sync_connections", + delete_old_sliding_sync_connections_txn, + ) + @attr.s(auto_attribs=True, frozen=True) class PerConnectionStateDB: @@ -498,6 +558,8 @@ class PerConnectionStateDB: When persisting this *only* contains updates to the state. """ + last_used_ts: int | None + rooms: "RoomStatusMap[str]" receipts: "RoomStatusMap[str]" account_data: "RoomStatusMap[str]" @@ -553,6 +615,7 @@ async def from_state( ) return PerConnectionStateDB( + last_used_ts=per_connection_state.last_used_ts, rooms=RoomStatusMap(rooms), receipts=RoomStatusMap(receipts), account_data=RoomStatusMap(account_data), @@ -596,6 +659,7 @@ async def to_state(self, store: "DataStore") -> "PerConnectionState": } return PerConnectionState( + last_used_ts=self.last_used_ts, rooms=RoomStatusMap(rooms), receipts=RoomStatusMap(receipts), account_data=RoomStatusMap(account_data), diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8644ff412ec..8fa1e2e5a97 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -740,7 +740,14 @@ async def get_rooms_that_have_updates_since_sliding_sync_table( from_key: RoomStreamToken, ) -> StrCollection: """Return the rooms that probably have had updates since the given - token (changes that are > `from_key`).""" + token (changes that are > `from_key`). + + May return false positives, but must not return false negatives. + + If `have_finished_sliding_sync_background_jobs` is False, then we return + all the room IDs, as we can't be sure that the sliding sync table is + fully populated. + """ # If the stream change cache is valid for the stream token, we can just # use the result of that. if from_key.stream >= self._events_stream_cache.get_earliest_known_position(): @@ -748,6 +755,11 @@ async def get_rooms_that_have_updates_since_sliding_sync_table( room_ids, from_key.stream ) + if not self.have_finished_sliding_sync_background_jobs(): + # If the table hasn't been populated yet, we have to assume all rooms + # have updates. + return room_ids + def get_rooms_that_have_updates_since_sliding_sync_table_txn( txn: LoggingTransaction, ) -> StrCollection: diff --git a/synapse/storage/schema/main/delta/93/03_sss_pos_last_used.sql b/synapse/storage/schema/main/delta/93/03_sss_pos_last_used.sql new file mode 100644 index 00000000000..747ba7a144b --- /dev/null +++ b/synapse/storage/schema/main/delta/93/03_sss_pos_last_used.sql @@ -0,0 +1,27 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Add a timestamp for when the sliding sync connection position was last used, +-- only updated with a small granularity. +-- +-- This should be NOT NULL, but we need to consider existing rows. In future we +-- may want to either backfill this or delete all rows with a NULL value (and +-- then make it NOT NULL). +ALTER TABLE sliding_sync_connections ADD COLUMN last_used_ts BIGINT; + +-- Note: We don't add an index on this column to allow HOT updates on PostgreSQL +-- to reduce the cost of the updates to the column. c.f. +-- https://www.postgresql.org/docs/current/storage-hot.html +-- +-- We do query this column directly to find expired connections, but we expect +-- that to be an infrequent operation and a sequential scan should be fine. diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 494e3570d05..03b3bcb3caf 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -850,12 +850,16 @@ class PerConnectionState: since the last time you made a sync request. Attributes: + last_used_ts: The time this connection was last used, in milliseconds. + This is only accurate to `UPDATE_CONNECTION_STATE_EVERY_MS`. rooms: The status of each room for the events stream. receipts: The status of each room for the receipts stream. room_configs: Map from room_id to the `RoomSyncConfig` of all rooms that we have previously sent down. """ + last_used_ts: int | None = None + rooms: RoomStatusMap[RoomStreamToken] = attr.Factory(RoomStatusMap) receipts: RoomStatusMap[MultiWriterStreamToken] = attr.Factory(RoomStatusMap) account_data: RoomStatusMap[int] = attr.Factory(RoomStatusMap) @@ -867,6 +871,7 @@ def get_mutable(self) -> "MutablePerConnectionState": room_configs = cast(MutableMapping[str, RoomSyncConfig], self.room_configs) return MutablePerConnectionState( + last_used_ts=self.last_used_ts, rooms=self.rooms.get_mutable(), receipts=self.receipts.get_mutable(), account_data=self.account_data.get_mutable(), @@ -875,6 +880,7 @@ def get_mutable(self) -> "MutablePerConnectionState": def copy(self) -> "PerConnectionState": return PerConnectionState( + last_used_ts=self.last_used_ts, rooms=self.rooms.copy(), receipts=self.receipts.copy(), account_data=self.account_data.copy(), @@ -889,6 +895,8 @@ def __len__(self) -> int: class MutablePerConnectionState(PerConnectionState): """A mutable version of `PerConnectionState`""" + last_used_ts: int | None + rooms: MutableRoomStatusMap[RoomStreamToken] receipts: MutableRoomStatusMap[MultiWriterStreamToken] account_data: MutableRoomStatusMap[int] diff --git a/synapse/util/constants.py b/synapse/util/constants.py index 7a3d073df55..f4491b58856 100644 --- a/synapse/util/constants.py +++ b/synapse/util/constants.py @@ -18,5 +18,6 @@ # readability and catching bugs. ONE_MINUTE_SECONDS = 60 ONE_HOUR_SECONDS = 60 * ONE_MINUTE_SECONDS +ONE_DAY_SECONDS = 24 * ONE_HOUR_SECONDS MILLISECONDS_PER_SECOND = 1000 diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py index 16d13fcc860..cdf63317e39 100644 --- a/tests/rest/client/sliding_sync/test_connection_tracking.py +++ b/tests/rest/client/sliding_sync/test_connection_tracking.py @@ -12,6 +12,7 @@ # . # import logging +from unittest.mock import patch from parameterized import parameterized, parameterized_class @@ -19,8 +20,11 @@ import synapse.rest.admin from synapse.api.constants import EventTypes +from synapse.api.errors import Codes +from synapse.handlers.sliding_sync import room_lists from synapse.rest.client import login, room, sync from synapse.server import HomeServer +from synapse.storage.databases.main.sliding_sync import CONNECTION_EXPIRY_MS from synapse.util.clock import Clock from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase @@ -395,3 +399,107 @@ def test_rooms_timeline_incremental_sync_NEVER(self) -> None: ) self.assertEqual(response_body["rooms"][room_id1]["limited"], True) self.assertEqual(response_body["rooms"][room_id1]["initial"], True) + + @patch("synapse.handlers.sliding_sync.room_lists.NUM_ROOMS_THRESHOLD", new=5) + def test_sliding_sync_connection_expires_with_too_much_data(self) -> None: + """ + Test that if we have too much data to send down for incremental sync, + we expire the connection and ask the client to do a full resync. + + Connections are only expired if they have not been used for a minimum + amount of time (MINIMUM_NOT_USED_AGE_EXPIRY_MS) to avoid expiring + connections that are actively being used. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create enough rooms that we can later trigger the too much data case + room_ids = [] + for _ in range(room_lists.NUM_ROOMS_THRESHOLD + 2): + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + room_ids.append(room_id) + + # Make sure we don't hit ratelimits + self.reactor.advance(60 * 1000) + + # Make the Sliding Sync request + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1000]], + "required_state": [], + "timeline_limit": 1, + } + } + } + + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + # Check we got all the rooms down + for room_id in room_ids: + self.assertIn(room_id, response_body["rooms"]) + + # Send a lot of events to cause the connection to expire + for room_id in room_ids: + self.helper.send(room_id, "msg", tok=user2_tok) + + # If we don't advance the clock then we won't expire the connection. + response_body, from_token = self.do_sync(sync_body, tok=user1_tok) + + # Send some more events. + for room_id in room_ids: + self.helper.send(room_id, "msg", tok=user2_tok) + + # Advance the clock to ensure that the last_used_ts is old enough + self.reactor.advance(2 * room_lists.MINIMUM_NOT_USED_AGE_EXPIRY_MS / 1000) + + # This sync should now raise SlidingSyncUnknownPosition + channel = self.make_sync_request(sync_body, since=from_token, tok=user1_tok) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], Codes.UNKNOWN_POS) + + def test_sliding_sync_connection_expires_after_time(self) -> None: + """ + Test that if we don't use a sliding sync connection for a long time, + we expire the connection and ask the client to do a full resync. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + # Make the Sliding Sync request + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1000]], + "required_state": [], + "timeline_limit": 1, + } + } + } + + _, from_token = self.do_sync(sync_body, tok=user1_tok) + + # We can keep syncing so long as the interval between requests is less + # than CONNECTION_EXPIRY_MS + for _ in range(5): + self.reactor.advance(0.5 * CONNECTION_EXPIRY_MS / 1000) + + _, from_token = self.do_sync(sync_body, tok=user1_tok) + + # ... but if we wait too long, the connection expires + self.reactor.advance(1 + CONNECTION_EXPIRY_MS / 1000) + + # This sync should now raise SlidingSyncUnknownPosition + channel = self.make_sync_request(sync_body, since=from_token, tok=user1_tok) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], Codes.UNKNOWN_POS) diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py index c27a712088d..bcd22d15ca6 100644 --- a/tests/rest/client/sliding_sync/test_sliding_sync.py +++ b/tests/rest/client/sliding_sync/test_sliding_sync.py @@ -46,7 +46,7 @@ from synapse.util.stringutils import random_string from tests import unittest -from tests.server import TimedOutException +from tests.server import FakeChannel, TimedOutException from tests.test_utils.event_injection import create_event logger = logging.getLogger(__name__) @@ -80,12 +80,10 @@ def default_config(self) -> JsonDict: config["experimental_features"] = {"msc3575_enabled": True} return config - def do_sync( + def make_sync_request( self, sync_body: JsonDict, *, since: str | None = None, tok: str - ) -> tuple[JsonDict, str]: - """Do a sliding sync request with given body. - - Asserts the request was successful. + ) -> FakeChannel: + """Make a sliding sync request with given body. Attributes: sync_body: The full request body to use @@ -106,6 +104,24 @@ def do_sync( content=sync_body, access_token=tok, ) + return channel + + def do_sync( + self, sync_body: JsonDict, *, since: str | None = None, tok: str + ) -> tuple[JsonDict, str]: + """Do a sliding sync request with given body. + + Asserts the request was successful. + + Attributes: + sync_body: The full request body to use + since: Optional since token + tok: Access token to use + + Returns: + A tuple of the response body and the `pos` field. + """ + channel = self.make_sync_request(sync_body, since=since, tok=tok) self.assertEqual(channel.code, 200, channel.json_body) return channel.json_body, channel.json_body["pos"]