From 5c9e39e6192e952ba8a5bb8e5485bc6067f91699 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Apr 2022 15:25:20 +0100 Subject: [PATCH] Track device list updates per room. (#12321) This is a first step in dealing with #7721. The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things: 1. we can defer calculating the set of remote servers that need to be poked about the change; and 2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed. However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled. There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly). Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table. --- changelog.d/12321.misc | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/config/server.py | 8 + synapse/handlers/device.py | 132 ++++++++++- synapse/replication/slave/storage/devices.py | 1 + synapse/storage/databases/main/__init__.py | 1 + synapse/storage/databases/main/devices.py | 217 ++++++++++++++++-- synapse/storage/schema/__init__.py | 1 + .../69/01device_list_oubound_by_room.sql | 38 +++ tests/federation/test_federation_sender.py | 23 +- tests/storage/test_devices.py | 14 +- 11 files changed, 390 insertions(+), 47 deletions(-) create mode 100644 changelog.d/12321.misc create mode 100644 synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql diff --git a/changelog.d/12321.misc b/changelog.d/12321.misc new file mode 100644 index 000000000000..200e7c44fe5d --- /dev/null +++ b/changelog.d/12321.misc @@ -0,0 +1 @@ +Add ground work for speeding up device list updates for users in large numbers of rooms. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index c38666da18e6..6324df883bc7 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -97,6 +97,7 @@ "users": ["shadow_banned"], "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], + "device_lists_changes_in_room": ["converted_to_destinations"], } diff --git a/synapse/config/server.py b/synapse/config/server.py index 0f90302c9566..b3a9e5075269 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -680,6 +680,14 @@ def read_config(self, config, **kwargs): config.get("use_account_validity_in_account_status") or False ) + # This is a temporary option that enables fully using the new + # `device_lists_changes_in_room` without the backwards compat code. This + # is primarily for testing. If enabled the server should *not* be + # downgraded, as it may lead to missing device list updates. + self.use_new_device_lists_changes_in_room = ( + config.get("use_new_device_lists_changes_in_room") or False + ) + self.rooms_to_exclude_from_sync: List[str] = ( config.get("exclude_rooms_from_sync") or [] ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d5ccaa0c37cc..c710c02cf97e 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -37,7 +37,10 @@ SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import ( JsonDict, StreamToken, @@ -278,6 +281,22 @@ def __init__(self, hs: "HomeServer"): hs.get_distributor().observe("user_left_room", self.user_left_room) + # Whether `_handle_new_device_update_async` is currently processing. + self._handle_new_device_update_is_processing = False + + # If a new device update may have happened while the loop was + # processing. + self._handle_new_device_update_new_data = False + + # On start up check if there are any updates pending. + hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + + # Used to decide if we calculate outbound pokes up front or not. By + # default we do to allow safely downgrading Synapse. + self.use_new_device_lists_changes_in_room = ( + hs.config.server.use_new_device_lists_changes_in_room + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -469,19 +488,26 @@ async def notify_device_update( # No changes to notify about, so this is a no-op. return - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id - ) + room_ids = await self.store.get_rooms_for_user(user_id) + + hosts: Optional[Set[str]] = None + if not self.use_new_device_lists_changes_in_room: + hosts = set() - hosts: Set[str] = set() - if self.hs.is_mine_id(user_id): - hosts.update(get_domain_from_id(u) for u in users_who_share_room) - hosts.discard(self.server_name) + if self.hs.is_mine_id(user_id): + for room_id in room_ids: + joined_users = await self.store.get_users_in_room(room_id) + hosts.update(get_domain_from_id(u) for u in joined_users) - set_tag("target_hosts", hosts) + set_tag("target_hosts", hosts) + + hosts.discard(self.server_name) position = await self.store.add_device_change_to_streams( - user_id, device_ids, list(hosts) + user_id, + device_ids, + hosts=hosts, + room_ids=room_ids, ) if not position: @@ -495,9 +521,12 @@ async def notify_device_update( # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - users_to_notify = users_who_share_room.union({user_id}) + self.notifier.on_new_event( + "device_list_key", position, users={user_id}, rooms=room_ids + ) - self.notifier.on_new_event("device_list_key", position, users=users_to_notify) + # We may need to do some processing asynchronously. + self._handle_new_device_update_async() if hosts: logger.info( @@ -614,6 +643,85 @@ async def rehydrate_device( return {"success": True} + @wrap_as_background_process("_handle_new_device_update_async") + async def _handle_new_device_update_async(self) -> None: + """Called when we have a new local device list update that we need to + send out over federation. + + This happens in the background so as not to block the original request + that generated the device update. + """ + if self._handle_new_device_update_is_processing: + self._handle_new_device_update_new_data = True + return + + self._handle_new_device_update_is_processing = True + + # The stream ID we processed previous iteration (if any), and the set of + # hosts we've already poked about for this update. This is so that we + # don't poke the same remote server about the same update repeatedly. + current_stream_id = None + hosts_already_sent_to: Set[str] = set() + + try: + while True: + self._handle_new_device_update_new_data = False + rows = await self.store.get_uncoverted_outbound_room_pokes() + if not rows: + # If the DB returned nothing then there is nothing left to + # do, *unless* a new device list update happened during the + # DB query. + if self._handle_new_device_update_new_data: + continue + else: + return + + for user_id, device_id, room_id, stream_id, opentracing_context in rows: + joined_user_ids = await self.store.get_users_in_room(room_id) + hosts = {get_domain_from_id(u) for u in joined_user_ids} + hosts.discard(self.server_name) + + # Check if we've already sent this update to some hosts + if current_stream_id == stream_id: + hosts -= hosts_already_sent_to + + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + stream_id=stream_id, + hosts=hosts, + context=opentracing_context, + ) + + # Notify replication that we've updated the device list stream. + self.notifier.notify_replication() + + if hosts: + logger.info( + "Sending device list update notif for %r to: %r", + user_id, + hosts, + ) + for host in hosts: + self.federation_sender.send_device_messages( + host, immediate=False + ) + log_kv( + {"message": "sent device update to host", "host": host} + ) + + if current_stream_id != stream_id: + # Clear the set of hosts we've already sent to as we're + # processing a new update. + hosts_already_sent_to.clear() + + hosts_already_sent_to.update(hosts) + current_stream_id = stream_id + + finally: + self._handle_new_device_update_is_processing = False + def _update_device_from_client_ips( device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 0ffd34f1dad0..f040e33bfb41 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -44,6 +44,7 @@ def __init__( extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) device_list_max = self._device_list_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 1ea0b2aa6f0a..cdbe3872faa0 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -146,6 +146,7 @@ def __init__( extra_tables=[ ("user_signature_stream", "stream_id"), ("device_lists_outbound_pokes", "stream_id"), + ("device_lists_changes_in_room", "stream_id"), ], ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index f08f7834d39e..07eea4b3d217 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -810,6 +810,7 @@ def _get_all_device_list_changes_for_remotes(txn): SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes ) AS e WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? """ @@ -1528,7 +1529,11 @@ def _update_remote_device_list_cache_txn( ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + self, + user_id: str, + device_ids: Collection[str], + hosts: Optional[Collection[str]], + room_ids: Collection[str], ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. @@ -1537,7 +1542,10 @@ async def add_device_change_to_streams( user_id: The ID of the user whose device changed. device_ids: The IDs of any changed devices. If empty, this function will return None. - hosts: The remote destinations that should be notified of the change. + hosts: The remote destinations that should be notified of the change. If + None then the set of hosts have *not* been calculated, and will be + calculated later by a background task. + room_ids: The rooms that the user is in Returns: The maximum stream ID of device list updates that were added to the database, or @@ -1546,34 +1554,62 @@ async def add_device_change_to_streams( if not device_ids: return None - async with self._device_list_id_gen.get_next_mult( - len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_change_to_stream", - self._add_device_change_to_stream_txn, + context = get_active_span_text_map() + + def add_device_changes_txn( + txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes + ): + self._add_device_change_to_stream_txn( + txn, user_id, device_ids, - stream_ids, + stream_ids_for_device_change, ) - if not hosts: - return stream_ids[-1] + self._add_device_outbound_room_poke_txn( + txn, + user_id, + device_ids, + room_ids, + stream_ids_for_device_change, + context, + hosts_have_been_calculated=hosts is not None, + ) - context = get_active_span_text_map() - async with self._device_list_id_gen.get_next_mult( - len(hosts) * len(device_ids) - ) as stream_ids: - await self.db_pool.runInteraction( - "add_device_outbound_poke_to_stream", - self._add_device_outbound_poke_to_stream_txn, + # If the set of hosts to send to has not been calculated yet (and so + # `hosts` is None) or there are no `hosts` to send to, then skip + # trying to persist them to the DB. + if not hosts: + return + + self._add_device_outbound_poke_to_stream_txn( + txn, user_id, device_ids, hosts, - stream_ids, + stream_ids_for_outbound_pokes, context, ) + # `device_lists_stream` wants a stream ID per device update. + num_stream_ids = len(device_ids) + + if hosts: + # `device_lists_outbound_pokes` wants a different stream ID for + # each row, which is a row per host per device update. + num_stream_ids += len(hosts) * len(device_ids) + + async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids: + stream_ids_for_device_change = stream_ids[: len(device_ids)] + stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :] + + await self.db_pool.runInteraction( + "add_device_change_to_stream", + add_device_changes_txn, + stream_ids_for_device_change, + stream_ids_for_outbound_pokes, + ) + return stream_ids[-1] def _add_device_change_to_stream_txn( @@ -1617,7 +1653,7 @@ def _add_device_outbound_poke_to_stream_txn( user_id: str, device_ids: Iterable[str], hosts: Collection[str], - stream_ids: List[str], + stream_ids: List[int], context: Dict[str, str], ) -> None: for host in hosts: @@ -1628,8 +1664,9 @@ def _add_device_outbound_poke_to_stream_txn( ) now = self._clock.time_msec() - next_stream_id = iter(stream_ids) + stream_id_iterator = iter(stream_ids) + encoded_context = json_encoder.encode(context) self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", @@ -1645,16 +1682,146 @@ def _add_device_outbound_poke_to_stream_txn( values=[ ( destination, - next(next_stream_id), + next(stream_id_iterator), user_id, device_id, False, now, - json_encoder.encode(context) - if whitelisted_homeserver(destination) - else "{}", + encoded_context if whitelisted_homeserver(destination) else "{}", ) for destination in hosts for device_id in device_ids ], ) + + def _add_device_outbound_room_poke_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Iterable[str], + room_ids: Collection[str], + stream_ids: List[str], + context: Dict[str, str], + hosts_have_been_calculated: bool, + ) -> None: + """Record the user in the room has updated their device. + + Args: + hosts_have_been_calculated: True if `device_lists_outbound_pokes` + has been updated already with the updates. + """ + + # We only need to convert to outbound pokes if they are our user. + converted_to_destinations = ( + hosts_have_been_calculated or not self.hs.is_mine_id(user_id) + ) + + encoded_context = json_encoder.encode(context) + + # The `device_lists_changes_in_room.stream_id` column matches the + # corresponding `stream_id` of the update in the `device_lists_stream` + # table, i.e. all rows persisted for the same device update will have + # the same `stream_id` (but different room IDs). + self.db_pool.simple_insert_many_txn( + txn, + table="device_lists_changes_in_room", + keys=( + "user_id", + "device_id", + "room_id", + "stream_id", + "converted_to_destinations", + "opentracing_context", + ), + values=[ + ( + user_id, + device_id, + room_id, + stream_id, + converted_to_destinations, + encoded_context, + ) + for room_id in room_ids + for device_id, stream_id in zip(device_ids, stream_ids) + ], + ) + + async def get_uncoverted_outbound_room_pokes( + self, limit: int = 10 + ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: + """Get device list changes by room that have not yet been handled and + written to `device_lists_outbound_pokes`. + + Returns: + A list of user ID, device ID, room ID, stream ID and optional opentracing context. + """ + + sql = """ + SELECT user_id, device_id, room_id, stream_id, opentracing_context + FROM device_lists_changes_in_room + WHERE NOT converted_to_destinations + ORDER BY stream_id + LIMIT ? + """ + + def get_uncoverted_outbound_room_pokes_txn(txn): + txn.execute(sql, (limit,)) + return txn.fetchall() + + return await self.db_pool.runInteraction( + "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn + ) + + async def add_device_list_outbound_pokes( + self, + user_id: str, + device_id: str, + room_id: str, + stream_id: int, + hosts: Collection[str], + context: Optional[Dict[str, str]], + ) -> None: + """Queue the device update to be sent to the given set of hosts, + calculated from the room ID. + + Marks the associated row in `device_lists_changes_in_room` as handled. + """ + + def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): + if hosts: + self._add_device_outbound_poke_to_stream_txn( + txn, + user_id=user_id, + device_ids=[device_id], + hosts=hosts, + stream_ids=stream_ids, + context=context, + ) + + self.db_pool.simple_update_txn( + txn, + table="device_lists_changes_in_room", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "stream_id": stream_id, + "room_id": room_id, + }, + updatevalues={"converted_to_destinations": True}, + ) + + if not hosts: + # If there are no hosts then we don't try and generate stream IDs. + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + [], + ) + + async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: + return await self.db_pool.runInteraction( + "add_device_list_outbound_pokes", + add_device_list_outbound_pokes_txn, + stream_ids, + ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index ea900e0f3d35..151f2aa9bbfe 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -60,6 +60,7 @@ new events. Changes in SCHEMA_VERSION = 69: + - We now write to `device_lists_changes_in_room` table. - Use sequence to generate future `application_services_txns.txn_id`s """ diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql new file mode 100644 index 000000000000..b5b1782b2aae --- /dev/null +++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql @@ -0,0 +1,38 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE device_lists_changes_in_room ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + room_id TEXT NOT NULL, + + -- This initially matches `device_lists_stream.stream_id`. Note that we + -- delete older values from `device_lists_stream`, so we can't use a foreign + -- constraint here. + -- + -- The table will contain rows with the same `stream_id` but different + -- `room_id`, as for each device update we store a row per room the user is + -- joined to. Therefore `(stream_id, room_id)` gives a unique index. + stream_id BIGINT NOT NULL, + + -- We have a background process which goes through this table and converts + -- entries into rows in `device_lists_outbound_pokes`. Once we have processed + -- a row, we mark it as such by setting `converted_to_destinations=TRUE`. + converted_to_destinations BOOLEAN NOT NULL, + opentracing_context TEXT +); + +CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id); +CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations; diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index e90592855ad9..a6e91956af2f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -14,6 +14,7 @@ from typing import Optional from unittest.mock import Mock +from parameterized import parameterized_class from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -154,6 +155,12 @@ def test_send_receipts_with_backoff(self): ) +@parameterized_class( + [ + {"enable_room_poke_code_path": False}, + {"enable_room_poke_code_path": True}, + ] +) class FederationSenderDevicesTestCases(HomeserverTestCase): servlets = [ admin.register_servlets, @@ -168,17 +175,21 @@ def make_homeserver(self, reactor, clock): def default_config(self): c = super().default_config() c["send_federation"] = True + c["use_new_device_lists_changes_in_room"] = self.enable_room_poke_code_path return c def prepare(self, reactor, clock, hs): - # stub out get_users_who_share_room_with_user so that it claims that - # `@user2:host2` is in the room - def get_users_who_share_room_with_user(user_id): + # stub out `get_rooms_for_user` and `get_users_in_room` so that the + # server thinks the user shares a room with `@user2:host2` + def get_rooms_for_user(user_id): + return defer.succeed({"!room:host1"}) + + hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user + + def get_users_in_room(room_id): return defer.succeed({"@user2:host2"}) - hs.get_datastores().main.get_users_who_share_room_with_user = ( - get_users_who_share_room_with_user - ) + hs.get_datastores().main.get_users_in_room = get_users_in_room # whenever send_transaction is called, record the edu data self.edus = [] diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 21ffc5a9095b..d1227dd4ac02 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -96,7 +96,9 @@ def test_get_device_updates_by_remote(self): # Add two device updates with sequential `stream_id`s self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get all device updates ever meant for this remote @@ -122,7 +124,9 @@ def test_get_device_updates_by_remote_can_limit_properly(self): "device_id5", ] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get device updates meant for this remote @@ -144,7 +148,9 @@ def test_get_device_updates_by_remote_can_limit_properly(self): # Add some more device updates to ensure it still resumes properly device_ids = ["device_id6", "device_id7"] self.get_success( - self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) + self.store.add_device_change_to_streams( + "user_id", device_ids, ["somehost"], ["!some:room"] + ) ) # Get the next batch of device updates @@ -220,7 +226,7 @@ def test_get_device_updates_by_remote_cross_signing_key_updates( self.get_success( self.store.add_device_change_to_streams( - "@user_id:test", device_ids, ["somehost"] + "@user_id:test", device_ids, ["somehost"], ["!some:room"] ) )