Weakness in auth chain indexing allows DoS from remote room members
through disk fill and high CPU usage.

A malicious remote Matrix user who shares a room with Synapse instances
before 1.104.1 can send specially crafted events that exploit a weakness
in how the auth chain cover index is calculated. This can drive up CPU
consumption and accumulate excessive data in the database of such
instances, resulting in a denial of service.

Servers in private federations, or those that do not federate, are not
affected.
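
The disk-fill component of the attack stems from the old behaviour of also recording a link to every chain reachable from an auth event (the "step 2b" removed in this commit), so a crafted sequence of events can make the number of stored links grow roughly quadratically. A minimal, purely illustrative sketch (not Synapse code) of the row-count difference for a straight line of N auth chains:

def count_link_rows(num_chains: int, transitive: bool) -> int:
    """Rows written for a straight line of auth chains 1 <- 2 <- ... <- N."""
    rows = 0
    for chain in range(2, num_chains + 1):
        if transitive:
            # Old behaviour: also link the new chain to every chain already
            # reachable from its auth event, i.e. all earlier chains.
            rows += chain - 1
        else:
            # New behaviour: link only to the directly referenced chain and
            # leave indirect reachability to read-time graph traversal.
            rows += 1
    return rows

assert count_link_rows(1000, transitive=True) == 499_500  # roughly N^2 / 2 rows
assert count_link_rows(1000, transitive=False) == 999     # N - 1 rows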
erikjohnston committed Apr 23, 2024
1 parent fbb2573 commit 55b0aa8
Showing 4 changed files with 117 additions and 104 deletions.
1 change: 1 addition & 0 deletions changelog.d/17044.misc
@@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.
108 changes: 37 additions & 71 deletions synapse/storage/databases/main/events.py
@@ -19,6 +19,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import collections
import itertools
import logging
from collections import OrderedDict
@@ -53,6 +54,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.event_federation import EventFederationStore
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
@@ -768,40 +770,26 @@ def _add_chain_cover_index(
# that have the same chain ID as the event.
# 2. For each retained auth event we:
# a. Add a link from the event's to the auth event's chain
# ID/sequence number; and
# b. Add a link from the event to every chain reachable by the
# auth event.
# ID/sequence number

# Step 1, fetch all existing links from all the chains we've seen
# referenced.
chain_links = _LinkMap()
auth_chain_rows = cast(
List[Tuple[int, int, int, int]],
db_pool.simple_select_many_txn(
txn,
table="event_auth_chain_links",
column="origin_chain_id",
iterable={chain_id for chain_id, _ in chain_map.values()},
keyvalues={},
retcols=(
"origin_chain_id",
"origin_sequence_number",
"target_chain_id",
"target_sequence_number",
),
),
)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in auth_chain_rows:
chain_links.add_link(
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
new=False,
)

for links in EventFederationStore._get_chain_links(
txn, {chain_id for chain_id, _ in chain_map.values()}
):
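# Each batch yielded by _get_chain_links maps an origin chain ID to the
# (origin sequence number, target chain ID, target sequence number)
# links that originate from that chain.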
for origin_chain_id, inner_links in links.items():
for (
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in inner_links:
chain_links.add_link(
(origin_chain_id, origin_sequence_number),
(target_chain_id, target_sequence_number),
new=False,
)

# We do this in topological order to avoid adding redundant links.
for event_id in sorted_topologically(
@@ -836,18 +824,6 @@ def _add_chain_cover_index(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)

# Step 2b, add a link to chains reachable from the auth
# event.
for target_id, target_seq in chain_links.get_links_from(
(auth_chain_id, auth_sequence_number)
):
if target_id == chain_id:
continue

chain_links.add_link(
(chain_id, sequence_number), (target_id, target_seq)
)

db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@@ -2451,31 +2427,6 @@ def add_link(
current_links[src_seq] = target_seq
return True

def get_links_from(
self, src_tuple: Tuple[int, int]
) -> Generator[Tuple[int, int], None, None]:
"""Gets the chains reachable from the given chain/sequence number.
Yields:
The chain ID and sequence number the link points to.
"""
src_chain, src_seq = src_tuple
for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
for link_src_seq, target_seq in sequence_numbers.items():
if link_src_seq <= src_seq:
yield target_id, target_seq

def get_links_between(
self, source_chain: int, target_chain: int
) -> Generator[Tuple[int, int], None, None]:
"""Gets the links between two chains.
Yields:
The source and target sequence numbers.
"""

yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()

def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
"""Gets any newly added links.
@@ -2502,9 +2453,24 @@ def exists_path_from(
if src_chain == target_chain:
return target_seq <= src_seq

links = self.get_links_between(src_chain, target_chain)
for link_start_seq, link_end_seq in links:
if link_start_seq <= src_seq and target_seq <= link_end_seq:
return True
# We have to graph traverse the links to check for indirect paths.
visited_chains = collections.Counter()
search = [(src_chain, src_seq)]
while search:
chain, seq = search.pop()
visited_chains[chain] = max(seq, visited_chains[chain])
for tc, links in self.maps.get(chain, {}).items():
for ss, ts in links.items():
# Don't revisit chains we've already seen, unless the target
# sequence number is higher than last time.
if ts <= visited_chains.get(tc, 0):
continue

if ss <= seq:
if tc == target_chain:
if target_seq <= ts:
return True
else:
search.append((tc, ts))

return False
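
Since transitive links are no longer persisted, `exists_path_from` above now answers reachability by walking the direct links at read time. As a rough standalone illustration of the same idea (plain dicts and a hypothetical `reachable` helper rather than Synapse's `_LinkMap`), a sketch:

from typing import Dict, List, Tuple

# Illustrative sketch only: reachability over direct links.
# `links` maps (origin_chain, origin_seq) -> [(target_chain, target_seq), ...].
def reachable(
    links: Dict[Tuple[int, int], List[Tuple[int, int]]],
    src: Tuple[int, int],
    target: Tuple[int, int],
) -> bool:
    src_chain, src_seq = src
    target_chain, target_seq = target
    if src_chain == target_chain:
        return target_seq <= src_seq
    stack = [src]
    seen = set()
    while stack:
        chain, seq = stack.pop()
        for (o_chain, o_seq), targets in links.items():
            # A link is usable only if it starts at or before our position
            # in the current chain.
            if o_chain != chain or o_seq > seq:
                continue
            for t_chain, t_seq in targets:
                if t_chain == target_chain and target_seq <= t_seq:
                    return True
                if (t_chain, t_seq) not in seen:
                    seen.add((t_chain, t_seq))
                    stack.append((t_chain, t_seq))
    return False

# Mirrors the new test_exists_path_from case below: (1, 1) -> (2, 1) and
# (2, 1) -> (3, 1) give an indirect path from (1, 4) to (3, 1), but not to (3, 2).
links = {(1, 1): [(2, 1)], (2, 1): [(3, 1)]}
assert reachable(links, (1, 4), (3, 1))
assert not reachable(links, (1, 4), (3, 2))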
8 changes: 6 additions & 2 deletions synapse/storage/schema/__init__.py
@@ -132,12 +132,16 @@
Changes in SCHEMA_VERSION = 83
- The event_txn_id is no longer used.
Changes in SCHEMA_VERSION = 84
- No longer assumes that `event_auth_chain_links` holds transitive links, and
so read operations must do graph traversal.
"""


SCHEMA_COMPAT_VERSION = (
# The event_txn_id table and tables from MSC2716 no longer exist.
83
# Transitive links are no longer written to `event_auth_chain_links`
84
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
104 changes: 73 additions & 31 deletions tests/storage/test_event_chain.py
@@ -21,6 +21,8 @@

from typing import Dict, List, Set, Tuple, cast

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest

@@ -45,14 +47,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self._next_stream_ordering = 1

def test_simple(self) -> None:
@parameterized.expand([(False,), (True,)])
def test_simple(self, batched: bool) -> None:
"""Test that the example in `docs/auth_chain_difference_algorithm.md`
works.
"""

event_factory = self.hs.get_event_builder_factory()
bob = "@creator:test"
alice = "@alice:test"
charlie = "@charlie:test"
room_id = "!room:test"

# Ensure that we have a rooms entry so that we generate the chain index.
@@ -191,6 +195,26 @@ def test_simple(self) -> None:
)
)

charlie_invite = self.get_success(
event_factory.for_room_version(
RoomVersions.V6,
{
"type": EventTypes.Member,
"state_key": charlie,
"sender": alice,
"room_id": room_id,
"content": {"tag": "charlie_invite"},
},
).build(
prev_event_ids=[],
auth_event_ids=[
create.event_id,
alice_join2.event_id,
power_2.event_id,
],
)
)

events = [
create,
bob_join,
@@ -200,33 +224,41 @@ def test_simple(self) -> None:
bob_join_2,
power_2,
alice_join2,
charlie_invite,
]

expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
(bob_join_2, power),
(alice_join2, power_2),
(charlie_invite, alice_join2),
]

self.persist(events)
# We either persist as a batch or one-by-one depending on test
# parameter.
if batched:
self.persist(events)
else:
for event in events:
self.persist([event])

chain_map, link_map = self.fetch_chains(events)

# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))

for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]
event_map = {e.event_id: e for e in events}
reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}

self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
self.maxDiff = None
self.assertCountEqual(
expected_links,
[
(reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
for s1, s2, t1, t2 in link_map.get_additions()
],
)

# Test that everything can reach the create event, but the create event
# can't reach anything.
@@ -368,24 +400,23 @@ def test_out_of_order_events(self) -> None:

expected_links = [
(bob_join, create),
(power, create),
(power, bob_join),
(alice_invite, create),
(alice_invite, power),
(alice_invite, bob_join),
]

# Check that the expected links and only the expected links have been
# added.
self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
event_map = {e.event_id: e for e in events}
reverse_chain_map = {v: event_map[k] for k, v in chain_map.items()}

for start, end in expected_links:
start_id, start_seq = chain_map[start.event_id]
end_id, end_seq = chain_map[end.event_id]

self.assertIn(
(start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
)
self.maxDiff = None
self.assertCountEqual(
expected_links,
[
(reverse_chain_map[(s1, s2)], reverse_chain_map[(t1, t2)])
for s1, s2, t1, t2 in link_map.get_additions()
],
)

def persist(
self,
@@ -489,8 +520,6 @@ def test_simple(self) -> None:
link_map = _LinkMap()

link_map.add_link((1, 1), (2, 1), new=False)
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
self.assertCountEqual(link_map.get_additions(), [])
self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
@@ -499,18 +528,31 @@ def test_simple(self) -> None:

# Attempting to add a redundant link is ignored.
self.assertFalse(link_map.add_link((1, 4), (2, 1)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
self.assertCountEqual(link_map.get_additions(), [])

# Adding new non-redundant links works
self.assertTrue(link_map.add_link((1, 3), (2, 3)))
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3)])

self.assertTrue(link_map.add_link((2, 5), (1, 3)))
self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])

self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])

def test_exists_path_from(self) -> None:
"Check that `exists_path_from` can handle non-direct links"
link_map = _LinkMap()

link_map.add_link((1, 1), (2, 1), new=False)
link_map.add_link((2, 1), (3, 1), new=False)

self.assertTrue(link_map.exists_path_from((1, 4), (3, 1)))
self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))

link_map.add_link((1, 5), (2, 3), new=False)
link_map.add_link((2, 2), (3, 3), new=False)

self.assertTrue(link_map.exists_path_from((1, 6), (3, 2)))
self.assertFalse(link_map.exists_path_from((1, 4), (3, 2)))


class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
servlets = [
