Reduce device lists replication traffic. #17333

Merged 11 commits on Jun 24, 2024
1 change: 1 addition & 0 deletions changelog.d/17333.misc
@@ -0,0 +1 @@
Handle device lists notifications for large accounts more efficiently in worker mode.
19 changes: 12 additions & 7 deletions synapse/replication/tcp/client.py
@@ -114,13 +114,19 @@ async def on_rdata(
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)

# If we're sending federation we need to update the device lists
# outbound pokes stream change cache with updated hosts.
if self.send_handler and any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
self.store.device_lists_outbound_pokes_have_changed(hosts, token)

self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
@@ -433,12 +439,11 @@ async def process_replication_rows(
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
hosts = {
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
await self.federation_sender.send_device_messages(hosts, immediate=False)
if any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
await self.federation_sender.send_device_messages(
hosts, immediate=False
)

elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
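To illustrate the new split in `on_rdata` and `process_replication_rows`, here is a minimal standalone sketch; the class and helper names below are stand-ins rather than Synapse's own. Rows that are neither signature updates nor pre-calculated outbound pokes drive the room-change bookkeeping, while rows flagged `hosts_calculated` tell federation senders to refresh the outbound-pokes cache and look destinations up by stream ID.

```python
# Stand-in sketch of the row handling; nothing here imports Synapse.
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class DeviceListsStreamRow:
    user_id: str
    is_signature: bool
    # True when the writer already resolved which remote hosts need this
    # update (i.e. the row came from device_lists_outbound_pokes).
    hosts_calculated: bool


def needs_room_recalculation(rows: List[DeviceListsStreamRow]) -> bool:
    # Replaces the old `row.entity.startswith("@") and not row.is_signature`
    # check: a row is a local device-list change when it is neither a
    # signature update nor a pre-calculated outbound poke.
    return any(not r.is_signature and not r.hosts_calculated for r in rows)


def needs_outbound_poke_refresh(rows: List[DeviceListsStreamRow]) -> bool:
    # Federation senders only act on rows whose hosts were pre-calculated;
    # the destinations themselves are then fetched by stream ID.
    return any(r.hosts_calculated for r in rows)


rows = [
    DeviceListsStreamRow("@alice:example.org", is_signature=False, hosts_calculated=False),
    DeviceListsStreamRow("@alice:example.org", is_signature=False, hosts_calculated=True),
]
assert needs_room_recalculation(rows)
assert needs_outbound_poke_refresh(rows)
```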
12 changes: 8 additions & 4 deletions synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity: str
user_id: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool

# Indicates if this is a notification that we've calculated the hosts we
# need to send the update to.
hosts_calculated: bool

NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow

@@ -594,13 +598,13 @@ async def _update_function(
upper_limit_token = min(upper_limit_token, signatures_to_token)

device_updates = [
(stream_id, (entity, False))
for stream_id, (entity,) in device_updates
(stream_id, (entity, False, hosts))
for stream_id, (entity, hosts) in device_updates
if stream_id <= upper_limit_token
]

signatures_updates = [
(stream_id, (entity, True))
(stream_id, (entity, True, False))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]
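For clarity, a simplified sketch with assumed example inputs (not Synapse's actual query results) of how `_update_function` now merges the two sources into three-element tuples that parse into `DeviceListsStreamRow(user_id, is_signature, hosts_calculated)`:

```python
# Assumed example inputs; shapes mirror the comprehensions in the diff above.
from typing import List, Tuple

# (stream_id, (user_id, hosts_calculated)) from the device updates query.
device_updates: List[Tuple[int, Tuple[str, bool]]] = [
    (10, ("@alice:example.org", False)),  # row from device_lists_stream
    (11, ("@alice:example.org", True)),   # row from device_lists_outbound_pokes
]
# (stream_id, (user_id,)) from the cross-signing signatures stream.
signatures_updates: List[Tuple[int, Tuple[str]]] = [(12, ("@bob:example.org",))]

upper_limit_token = 12

merged = [
    (stream_id, (user_id, False, hosts))
    for stream_id, (user_id, hosts) in device_updates
    if stream_id <= upper_limit_token
] + [
    (stream_id, (user_id, True, False))
    for stream_id, (user_id,) in signatures_updates
    if stream_id <= upper_limit_token
]

# Each inner tuple parses into (user_id, is_signature, hosts_calculated).
assert merged[1] == (11, ("@alice:example.org", False, True))
assert merged[2] == (12, ("@bob:example.org", True, False))
```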
93 changes: 58 additions & 35 deletions synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ def __init__(
prefilled_cache=user_signature_stream_prefill,
)

(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
self._device_list_federation_stream_cache = None
if hs.should_send_federation():
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)

if hs.config.worker.run_background_tasks:
self._clock.looping_call(
@@ -207,23 +209,30 @@ def _invalidate_caches_for_devices(
) -> None:
for row in rows:
if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token)
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue

# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))

else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
if not row.hosts_calculated:
self._device_list_stream_cache.entity_has_changed(row.user_id, token)
self.get_cached_devices_for_user.invalidate((row.user_id,))
self._get_cached_user_device.invalidate((row.user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate(
(row.user_id,)
)

def device_lists_outbound_pokes_have_changed(
self, destinations: StrCollection, token: int
) -> None:
assert self._device_list_federation_stream_cache is not None

for destination in destinations:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)

def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
@@ -363,6 +372,11 @@ async def get_device_updates_by_remote(
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
if from_stream_id == now_stream_id:
return now_stream_id, []

if self._device_list_federation_stream_cache is None:
raise Exception("Func can only be used on federation senders")

has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ def _get_all_device_list_changes_for_remotes(
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
SELECT stream_id, user_id, hosts FROM (
SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ def get_device_list_changes_in_room_txn(
get_device_list_changes_in_room_txn,
)

async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
return await self.db_pool.simple_select_onecol(
table="device_lists_outbound_pokes",
keyvalues={"stream_id": stream_id},
retcol="destination",
desc="get_destinations_for_device",
)


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -2112,12 +2134,13 @@ def _add_device_outbound_poke_to_stream_txn(
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
if self._device_list_federation_stream_cache:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)

now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)
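A standalone `sqlite3` sketch (heavily trimmed, assumed schema) of why the outbound-pokes half of the union now uses `SELECT DISTINCT`, and how `get_destinations_for_device` recovers the full host list afterwards: a single stream ID can fan out to many destinations, but replication only needs one `hosts = true` row per ID.

```python
# Standalone sqlite3 sketch with an assumed, trimmed-down schema; 0/1 stand
# in for the false/true literals used in the real query.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE device_lists_stream (stream_id INTEGER, user_id TEXT);
    CREATE TABLE device_lists_outbound_pokes (
        stream_id INTEGER, user_id TEXT, destination TEXT
    );
    INSERT INTO device_lists_stream VALUES (5, '@alice:example.org');
    INSERT INTO device_lists_outbound_pokes VALUES
        (6, '@alice:example.org', 'remote1.example.com'),
        (6, '@alice:example.org', 'remote2.example.com');
    """
)

rows = conn.execute(
    """
    SELECT stream_id, user_id, hosts FROM (
        SELECT stream_id, user_id, 0 AS hosts FROM device_lists_stream
        UNION ALL
        SELECT DISTINCT stream_id, user_id, 1 AS hosts
            FROM device_lists_outbound_pokes
    ) AS e
    WHERE ? < stream_id AND stream_id <= ?
    ORDER BY stream_id ASC
    """,
    (0, 10),
).fetchall()
# One replication row per stream ID, even though ID 6 has two destinations.
assert rows == [(5, "@alice:example.org", 0), (6, "@alice:example.org", 1)]

# The per-stream-ID lookup that get_destinations_for_device performs via
# simple_select_onecol on the federation sender.
destinations = conn.execute(
    "SELECT destination FROM device_lists_outbound_pokes WHERE stream_id = ?",
    (6,),
).fetchall()
assert sorted(d for (d,) in destinations) == [
    "remote1.example.com",
    "remote2.example.com",
]
```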
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -123,9 +123,9 @@ def process_replication_rows(
if stream_name == DeviceListsStream.NAME:
for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
if row.entity.startswith("@"):
if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
(row.entity,)
(row.user_id,)
)

super().process_replication_rows(stream_name, instance_name, token, rows)
8 changes: 8 additions & 0 deletions tests/storage/test_devices.py
@@ -36,6 +36,14 @@ class DeviceStoreTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main

def default_config(self) -> JsonDict:
config = super().default_config()

# We 'enable' federation otherwise `get_device_updates_by_remote` will
# throw an exception.
config["federation_sender_instances"] = ["master"]
return config

def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
"""Add a device list change for the given device to
`device_lists_outbound_pokes` table.
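The test opts into being a federation sender because the store now builds `_device_list_federation_stream_cache` only when `hs.should_send_federation()` returns true, and `get_device_updates_by_remote` refuses to run without it. A tiny sketch of that gating pattern, using a stand-in class rather than the real store:

```python
# Stand-in sketch, not the real store: the cache exists only on federation
# senders, so callers either assert it is present or skip the cache update.
from typing import Dict, Iterable, Optional


class DeviceStoreSketch:
    def __init__(self, sends_federation: bool) -> None:
        # Stand-in for StreamChangeCache: destination -> last stream_id.
        self._federation_cache: Optional[Dict[str, int]] = (
            {} if sends_federation else None
        )

    def outbound_pokes_have_changed(
        self, destinations: Iterable[str], token: int
    ) -> None:
        # Mirrors device_lists_outbound_pokes_have_changed: only federation
        # senders call this, hence the assert rather than a silent skip.
        assert self._federation_cache is not None
        for destination in destinations:
            self._federation_cache[destination] = token

    def add_outbound_poke(self, hosts: Iterable[str], stream_id: int) -> None:
        # Mirrors _add_device_outbound_poke_to_stream_txn: any worker may
        # write pokes, so the cache update is skipped when the cache is absent.
        if self._federation_cache is not None:
            for host in hosts:
                self._federation_cache[host] = stream_id


sender = DeviceStoreSketch(sends_federation=True)
sender.add_outbound_poke(["remote.example.com"], 7)
sender.outbound_pokes_have_changed(["remote.example.com"], 8)

worker = DeviceStoreSketch(sends_federation=False)
worker.add_outbound_poke(["remote.example.com"], 7)  # no-op, no error
```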