Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Always notify replication when a stream advances (#14877)
Browse files Browse the repository at this point in the history
This ensures that all other workers are told about stream updates in a timely manner, without having to remember to manually poke replication.
  • Loading branch information
erikjohnston committed Jan 20, 2023
1 parent cf18fea commit 65d0386
Show file tree
Hide file tree
Showing 19 changed files with 104 additions and 29 deletions.
1 change: 1 addition & 0 deletions changelog.d/14877.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Always notify replication when a stream advances automatically.
4 changes: 4 additions & 0 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
make_deferred_yieldable,
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
Expand Down Expand Up @@ -260,6 +261,9 @@ def get_instance_name(self) -> str:
def should_send_federation(self) -> bool:
return False

def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()


class Porter:
def __init__(
Expand Down
31 changes: 26 additions & 5 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []

# Called when there are new things to stream over replication
self.replication_callbacks: List[Callable[[], None]] = []
self._replication_notifier = hs.get_replication_notifier()
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []

self._federation_client = hs.get_federation_http_client()
Expand Down Expand Up @@ -279,7 +278,7 @@ def add_replication_callback(self, cb: Callable[[], None]) -> None:
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self.replication_callbacks.append(cb)
self._replication_notifier.add_replication_callback(cb)

def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room.
Expand Down Expand Up @@ -741,8 +740,7 @@ def _user_joined_room(self, user_id: str, room_id: str) -> None:

def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
self._replication_notifier.notify_replication()

def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks:
Expand All @@ -759,3 +757,26 @@ def notify_remote_server_up(self, server: str) -> None:
# Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)


@attr.s(auto_attribs=True)
class ReplicationNotifier:
"""Tracks callbacks for things that need to know about stream changes.
This is separate from the notifier to avoid circular dependencies.
"""

_replication_callbacks: List[Callable[[], None]] = attr.Factory(list)

def add_replication_callback(self, cb: Callable[[], None]) -> None:
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self._replication_callbacks.append(cb)

def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self._replication_callbacks:
cb()
6 changes: 5 additions & 1 deletion synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.notifier import Notifier
from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
Expand Down Expand Up @@ -389,6 +389,10 @@ def get_federation_server(self) -> FederationServer:
def get_notifier(self) -> Notifier:
return Notifier(self)

@cache_in_self
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()

@cache_in_self
def get_auth(self) -> Auth:
return Auth(self)
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
Expand All @@ -95,6 +96,7 @@ def __init__(
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
Expand All @@ -101,7 +102,7 @@ def __init__(
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_inbox", "stream_id"
db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)

max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,7 +1181,10 @@ def __init__(
super().__init__(database, db_conn, hs)

self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
db_conn,
hs.get_replication_notifier(),
"e2e_cross_signing_keys",
"stream_id",
)

async def set_e2e_device_keys(
Expand Down
10 changes: 9 additions & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __init__(
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
Expand All @@ -200,6 +201,7 @@ def __init__(
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
Expand All @@ -217,12 +219,14 @@ def __init__(
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
Expand Down Expand Up @@ -300,6 +304,7 @@ def get_chain_id_txn(txn: Cursor) -> int:
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
tables=[
Expand All @@ -311,7 +316,10 @@ def get_chain_id_txn(txn: Cursor) -> int:
)
else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_event_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_event_stream",
"stream_id",
)

def get_un_partial_stated_events_token(self) -> int:
Expand Down
3 changes: 2 additions & 1 deletion synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
Expand All @@ -85,7 +86,7 @@ def __init__(
)
else:
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)

self.hs = hs
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
Expand Down
1 change: 1 addition & 0 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
Expand All @@ -91,6 +92,7 @@ def __init__(
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
Expand Down
6 changes: 5 additions & 1 deletion synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
Expand All @@ -137,7 +138,10 @@ def __init__(
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
db_conn, "un_partial_stated_room_stream", "stream_id"
db_conn,
hs.get_replication_notifier(),
"un_partial_stated_room_stream",
"stream_id",
)

async def store_room(
Expand Down
26 changes: 24 additions & 2 deletions synapse/storage/util/id_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
Expand Down Expand Up @@ -49,6 +50,9 @@
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator

if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
Expand All @@ -205,6 +210,8 @@ def __init__(
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()

self._notifier = notifier

def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
Expand All @@ -227,6 +234,8 @@ def manager() -> Generator[int, None, None]:
with self._lock:
self._unfinished_ids.pop(next_id)

self._notifier.notify_replication()

return _AsyncCtxManagerWrapper(manager())

def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
Expand All @@ -250,6 +259,8 @@ def manager() -> Generator[Sequence[int], None, None]:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)

self._notifier.notify_replication()

return _AsyncCtxManagerWrapper(manager())

def get_current_token(self) -> int:
Expand Down Expand Up @@ -296,6 +307,7 @@ def __init__(
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
Expand All @@ -304,6 +316,7 @@ def __init__(
positive: bool = True,
) -> None:
self._db = db
self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
Expand Down Expand Up @@ -535,7 +548,9 @@ def get_next(self) -> AsyncContextManager[int]:
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
return cast(
AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
)

def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
Expand All @@ -544,7 +559,10 @@ def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
raise Exception("Tried to allocate stream ID on non-writer")

# Cast safety: see get_next.
return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
return cast(
AsyncContextManager[List[int]],
_MultiWriterCtxManager(self, self._notifier, n),
)

def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
Expand All @@ -563,6 +581,7 @@ def get_next_txn(self, txn: LoggingTransaction) -> int:

txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
txn.call_after(self._notifier.notify_replication)

# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
Expand Down Expand Up @@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""

id_gen: MultiWriterIdGenerator
notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)

Expand Down Expand Up @@ -814,6 +834,8 @@ async def __aexit__(
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)

self.notifier.notify_replication()

if exc_type is not None:
return False

Expand Down
Loading

0 comments on commit 65d0386

Please sign in to comment.