This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Avoid sending massive replication updates when purging a room. #16510

Merged: 5 commits, Oct 18, 2023
1 change: 1 addition & 0 deletions changelog.d/16510.misc
@@ -0,0 +1 @@
Improve replication performance when purging rooms.
45 changes: 44 additions & 1 deletion synapse/replication/tcp/streams/events.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from collections import defaultdict
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast

import attr
@@ -51,8 +52,19 @@
* The state_key of the state which has changed
* The event id of the new state

A "state-all" row is sent whenever the "current state" in a room changes, but there are
too many state updates for a particular room in the same update. This replaces any
"state" rows on a per-room basis. The fields in the data part are:

* The room id for the state changes

"""

# Any room with at least _MAX_STATE_UPDATES_PER_ROOM state updates will send an
# EventsStreamAllStateRow instead of individual EventsStreamCurrentStateRow rows.
# This is predominantly useful when purging large rooms.
_MAX_STATE_UPDATES_PER_ROOM = 150


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamRow:
@@ -111,9 +123,17 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id: Optional[str]


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamAllStateRow(BaseEventsStreamRow):
TypeId = "state-all"

room_id: str


_EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow,
EventsStreamCurrentStateRow,
EventsStreamAllStateRow,
)

TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@@ -213,9 +233,28 @@ async def _update_function(
if stream_id <= upper_limit
)

# Separate out rooms that have many state updates; listeners should clear
# all state for those rooms.
state_updates_by_room = defaultdict(list)
for stream_id, room_id, _type, _state_key, _event_id in state_rows:
state_updates_by_room[room_id].append(stream_id)

state_all_rows = [
(stream_ids[-1], room_id)
for room_id, stream_ids in state_updates_by_room.items()
if len(stream_ids) >= _MAX_STATE_UPDATES_PER_ROOM
]
state_all_updates: Iterable[Tuple[int, Tuple]] = (
(max_stream_id, (EventsStreamAllStateRow.TypeId, (room_id,)))
for (max_stream_id, room_id) in state_all_rows
)

# Any remaining state updates are sent individually.
state_all_rooms = {room_id for _, room_id in state_all_rows}
state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
if rest[0] not in state_all_rooms
)

ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
@@ -224,7 +263,11 @@
)

# we need to return a sorted list, so merge them together.
updates = list(
heapq.merge(
event_updates, state_all_updates, state_updates, ex_outliers_updates
)
)
return updates, upper_limit, limited

@classmethod
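Taken together, the hunks above implement a simple collapse: group the per-room state rows, and for any room at or over the threshold emit one "state-all" row, stamped with that room's highest stream id, instead of its individual "state" rows. Here is a standalone sketch of that logic using bare (stream_id, room_id) tuples rather than Synapse's row types; the function and constant names here are illustrative, not from the codebase:

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

MAX_STATE_UPDATES_PER_ROOM = 150  # mirrors _MAX_STATE_UPDATES_PER_ROOM above


def collapse_state_rows(
    state_rows: List[Tuple[int, str]],  # (stream_id, room_id), ordered by stream_id
) -> Tuple[List[Tuple[int, str]], List[Tuple[int, str]]]:
    # Group the stream ids of each room's state updates.
    by_room: Dict[str, List[int]] = defaultdict(list)
    for stream_id, room_id in state_rows:
        by_room[room_id].append(stream_id)

    # Rooms with too many updates collapse to a single row, stamped with the
    # room's highest stream id so the merged output stays sorted.
    collapsed = [
        (stream_ids[-1], room_id)
        for room_id, stream_ids in by_room.items()
        if len(stream_ids) >= MAX_STATE_UPDATES_PER_ROOM
    ]
    collapsed_rooms: Set[str] = {room_id for _, room_id in collapsed}

    # Everything else is still sent row by row.
    individual = [row for row in state_rows if row[1] not in collapsed_rooms]
    return collapsed, individual
```

Note that the grouping relies on state_rows arriving in stream-id order, which is why stream_ids[-1] can stand in for max(stream_ids) both here and in the diff.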
8 changes: 8 additions & 0 deletions synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -264,6 +265,13 @@ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
(data.state_key,)
)
self.get_rooms_for_user.invalidate((data.state_key,)) # type: ignore[attr-defined]
elif row.type == EventsStreamAllStateRow.TypeId:
assert isinstance(data, EventsStreamAllStateRow)
# Similar to the above, but the caches are invalidated in their entirety.
# This is unfortunate for the membership caches, but they should recover quickly.
self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) # type: ignore[attr-defined]
self.get_rooms_for_user_with_stream_ordering.invalidate_all() # type: ignore[attr-defined]
self.get_rooms_for_user.invalidate_all() # type: ignore[attr-defined]
else:
raise Exception("Unknown events stream row type %s" % (row.type,))

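For context on how the handler above receives these rows: each events-stream update travels as a (TypeId, data) pair, and the TypeToRow mapping from events.py turns the TypeId back into a row class. A minimal sketch of that round trip, with a made-up room id and only the one row class registered:

```python
from typing import Dict, Tuple, Type

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventsStreamAllStateRow:
    TypeId = "state-all"  # class attribute, not an attrs field

    room_id: str


# Mirrors the TypeToRow registry in events.py, reduced to one entry.
TYPE_TO_ROW: Dict[str, Type[EventsStreamAllStateRow]] = {
    EventsStreamAllStateRow.TypeId: EventsStreamAllStateRow
}


def parse_row(type_id: str, data: Tuple[str, ...]) -> EventsStreamAllStateRow:
    # Look the row class up by its TypeId and rebuild it from the data part.
    return TYPE_TO_ROW[type_id](*data)


row = parse_row("state-all", ("!abcdefg:example.com",))
assert row.room_id == "!abcdefg:example.com"
```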
91 changes: 62 additions & 29 deletions tests/replication/tcp/streams/test_events.py
@@ -14,13 +14,17 @@

from typing import Any, List, Optional

from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
_MAX_STATE_UPDATES_PER_ROOM,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
@@ -106,11 +110,21 @@ def test_update_function_event_row_limit(self) -> None:

self.assertEqual([], received_rows)

@parameterized.expand(
[(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
)
def test_update_function_huge_state_change(
self, num_state_changes: int, collapse_state_changes: bool
) -> None:
"""Test replication with many state events

Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.

Args:
num_state_changes: The number of state changes to create.
collapse_state_changes: Whether the state changes are expected to be
collapsed or not.
"""

# we want to generate lots of state changes at a single stream ID.
@@ -145,7 +159,7 @@ def test_update_function_huge_state_change(self) -> None:

events = [
self._inject_state_event(sender=OTHER_USER)
for _ in range(num_state_changes)
]

self.replicate()
@@ -202,8 +216,7 @@ def test_update_function_huge_state_change(self) -> None:
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]

# first check the first two rows, which should be the state1 event.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -217,7 +230,7 @@ def test_update_function_huge_state_change(self) -> None:
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)

# now the last two rows, which should be the state2 event.
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
@@ -231,34 +244,54 @@ def test_update_function_huge_state_change(self) -> None:
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)

# Based on the number of state changes, the rows may have been collapsed
# into a single "state-all" row.
if collapse_state_changes:
    # That should leave us with the rows for the PL event; the state
    # changes get collapsed into a single row.
    self.assertEqual(len(received_rows), 2)

    stream_name, token, row = received_rows.pop(0)
    self.assertEqual("events", stream_name)
    self.assertIsInstance(row, EventsStreamRow)
    self.assertEqual(row.type, "ev")
    self.assertIsInstance(row.data, EventsStreamEventRow)
    self.assertEqual(row.data.event_id, pl_event.event_id)

    stream_name, token, row = received_rows.pop(0)
    self.assertIsInstance(row, EventsStreamRow)
    self.assertEqual(row.type, "state-all")
    self.assertIsInstance(row.data, EventsStreamAllStateRow)
    self.assertEqual(row.data.room_id, state2.room_id)

else:
    # That should leave us with the rows for the PL event.
    self.assertEqual(len(received_rows), len(events) + 2)

    stream_name, token, row = received_rows.pop(0)
    self.assertEqual("events", stream_name)
    self.assertIsInstance(row, EventsStreamRow)
    self.assertEqual(row.type, "ev")
    self.assertIsInstance(row.data, EventsStreamEventRow)
    self.assertEqual(row.data.event_id, pl_event.event_id)

    # the state rows are unsorted
    state_rows: List[EventsStreamCurrentStateRow] = []
    for stream_name, _, row in received_rows:
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        state_rows.append(row.data)

    state_rows.sort(key=lambda r: r.state_key)

    sr = state_rows.pop(0)
    self.assertEqual(sr.type, EventTypes.PowerLevels)
    self.assertEqual(sr.event_id, pl_event.event_id)
    for sr in state_rows:
        self.assertEqual(sr.type, "test_state_event")
        # "None" indicates the state has been deleted
        self.assertIsNone(sr.event_id)

def test_update_function_state_row_limit(self) -> None:
"""Test replication with many state events over several stream ids."""
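A side note on the test change: parameterized.expand generates one test case per tuple, unpacking each tuple into the test's arguments, which is how the single huge-state-change test now covers both the collapsed and uncollapsed paths. A reduced sketch of the pattern; the threshold and class names here are placeholders, not Synapse's:

```python
import unittest

from parameterized import parameterized

THRESHOLD = 150  # placeholder standing in for _MAX_STATE_UPDATES_PER_ROOM


class CollapseThresholdTests(unittest.TestCase):
    @parameterized.expand([(100, False), (THRESHOLD, True)])
    def test_collapse_expectation(
        self, num_state_changes: int, collapse_state_changes: bool
    ) -> None:
        # Mirrors the parameterization above: at or beyond the per-room
        # threshold, the stream is expected to collapse the state rows.
        self.assertEqual(num_state_changes >= THRESHOLD, collapse_state_changes)


if __name__ == "__main__":
    unittest.main()
```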