This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Handle replication commands synchronously where possible (#7876)
Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them.
richvdh committed Jul 27, 2020
1 parent 7c2e2c2 commit f57b99a
Showing 5 changed files with 113 additions and 86 deletions.
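
The idea, in miniature: each on_<COMMAND> handler becomes a plain function that returns an Awaitable only when there is real asynchronous work to do, and the caller fires up a background process only in that case. Here is a rough, self-contained asyncio sketch of the pattern — the names are hypothetical, and Synapse itself uses Twisted and run_as_background_process rather than asyncio:

import asyncio
from inspect import isawaitable
from typing import Awaitable, Optional

async def _do_async_work() -> None:
    # stand-in for real asynchronous work, e.g. a database write
    await asyncio.sleep(0)

def on_EXAMPLE(is_master: bool) -> Optional[Awaitable[None]]:
    # new-style handler: runs synchronously, and returns a coroutine
    # only when there is something to await
    if is_master:
        return _do_async_work()  # deliberately not awaited here
    return None

async def dispatch(is_master: bool) -> None:
    res = on_EXAMPLE(is_master)  # cheap, synchronous call
    if isawaitable(res):
        # only now pay the cost of scheduling background work
        asyncio.create_task(res)
        await asyncio.sleep(0)  # give the task a chance to run

asyncio.run(dispatch(True))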
1 change: 1 addition & 0 deletions changelog.d/7876.bugfix
@@ -0,0 +1 @@
+Fix an `AssertionError` exception introduced in v1.18.0rc1.
1 change: 1 addition & 0 deletions changelog.d/7876.misc
@@ -0,0 +1 @@
+Further optimise queueing of inbound replication commands.
115 changes: 66 additions & 49 deletions synapse/replication/tcp/handler.py
@@ -16,6 +16,7 @@
 import logging
 from typing import (
     Any,
+    Awaitable,
     Dict,
     Iterable,
     Iterator,
@@ -33,6 +34,7 @@
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ def __init__(self, hs):
         # When POSITION or RDATA commands arrive, we stick them in a queue and process
         # them in order in a separate background process.
 
-        # the streams which are currently being processed by _unsafe_process_stream
+        # the streams which are currently being processed by _unsafe_process_queue
         self._processing_streams = set()  # type: Set[str]
 
         # for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@ def __init__(self, hs):
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
-    async def _add_command_to_stream_queue(
+    def _add_command_to_stream_queue(
         self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
@@ -199,33 +201,34 @@ async def _add_command_to_stream_queue(
             logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
             return
 
-        # if we're already processing this stream, stick the new command in the
-        # queue, and we're done.
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
         if stream_name in self._processing_streams:
-            queue.append((cmd, conn))
             return
 
-        # otherwise, process the new command.
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
 
-        # arguably we should start off a new background process here, but nothing
-        # will be too upset if we don't return for ages, so let's save the overhead
-        # and use the existing logcontext.
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
+        assert stream_name not in self._processing_streams
 
         self._processing_streams.add(stream_name)
         try:
-            # might as well skip the queue for this one, since it must be empty
-            assert not queue
-
-            await self._process_command(cmd, conn, stream_name)
-
-            # now process any other commands that have built up while we were
-            # dealing with that one.
+            queue = self._command_queues_by_stream.get(stream_name)
             while queue:
                 cmd, conn = queue.popleft()
                 try:
                     await self._process_command(cmd, conn, stream_name)
                 except Exception:
                     logger.exception("Failed to handle command %s", cmd)
 
         finally:
             self._processing_streams.discard(stream_name)

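The queueing discipline above is worth spelling out: a command is always appended to its stream's queue first, and a background processor is started only if one is not already draining that queue, which keeps commands for a given stream strictly ordered. A toy asyncio version of the same discipline — the class and method names are hypothetical, not Synapse's API:

import asyncio
from collections import deque
from typing import Deque, Dict, Set

class StreamQueues:
    # toy per-stream command queues; add_command must be called from a
    # running event loop, since it may spawn a processor task
    def __init__(self) -> None:
        self._queues: Dict[str, Deque[str]] = {}
        self._processing: Set[str] = set()

    def add_command(self, stream: str, cmd: str) -> None:
        # always enqueue first: a processor that is already running will
        # pick the new entry up in due course
        self._queues.setdefault(stream, deque()).append(cmd)
        if stream in self._processing:
            return
        # no processor for this stream yet: fire one off
        asyncio.create_task(self._unsafe_process(stream))

    async def _unsafe_process(self, stream: str) -> None:
        # "unsafe": assumes no other processor exists for this stream
        assert stream not in self._processing
        self._processing.add(stream)
        try:
            queue = self._queues[stream]
            while queue:
                cmd = queue.popleft()
                await asyncio.sleep(0)  # real code would await a handler here
                print("processed", stream, cmd)
        finally:
            self._processing.discard(stream)
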
@@ -299,7 +302,7 @@ def get_streams_to_replicate(self) -> List[Stream]:
"""
return self._streams_to_replicate

async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)

def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@ def send_positions_to_connection(self, conn: AbstractConnection):
                 )
             )

-    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+    def on_USER_SYNC(
+        self, conn: AbstractConnection, cmd: UserSyncCommand
+    ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
         if self._is_master:
-            await self._presence_handler.update_external_syncs_row(
+            return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
+        else:
+            return None
 
-    async def on_CLEAR_USER_SYNC(
+    def on_CLEAR_USER_SYNC(
         self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         if self._is_master:
-            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+        else:
+            return None

-    async def on_FEDERATION_ACK(
-        self, conn: AbstractConnection, cmd: FederationAckCommand
-    ):
+    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    async def on_REMOVE_PUSHER(
+    def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         remove_pusher_counter.inc()
 
         if self._is_master:
-            await self._store.delete_pusher_by_app_id_pushkey_user_id(
-                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-            )
+            return self._handle_remove_pusher(cmd)
+        else:
+            return None
+
+    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+        await self._store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+        )
 
-            self._notifier.on_new_replication_data()
+        self._notifier.on_new_replication_data()

-    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+    def on_USER_IP(
+        self, conn: AbstractConnection, cmd: UserIpCommand
+    ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
         if self._is_master:
-            await self._store.insert_client_ip(
-                cmd.user_id,
-                cmd.access_token,
-                cmd.ip,
-                cmd.user_agent,
-                cmd.device_id,
-                cmd.last_seen,
-            )
+            return self._handle_user_ip(cmd)
+        else:
+            return None
+
+    async def _handle_user_ip(self, cmd: UserIpCommand):
+        await self._store.insert_client_ip(
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
+            cmd.last_seen,
+        )
 
-            if self._server_notices_sender:
-                await self._server_notices_sender.on_user_ip(cmd.user_id)
+        assert self._server_notices_sender is not None
+        await self._server_notices_sender.on_user_ip(cmd.user_id)

-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -382,7 +401,7 @@ async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         # 2. so we don't race with getting a POSITION command and fetching
         # missing RDATA.
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)

     async def _process_rdata(
         self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@ async def on_rdata(
                 stream_name, instance_name, token, rows
             )
 
-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)

     async def _process_position(
         self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@ async def _process_position(

         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    async def on_REMOTE_SERVER_UP(
-        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
-    ):
+    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)

45 changes: 28 additions & 17 deletions synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@
 import fcntl
 import logging
 import struct
+from inspect import isawaitable
 from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
     command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+    `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+    if so, that will get run as a background process.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -166,9 +169,9 @@ def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):

         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler-%s" % self.conn_id
-        )
+        ctx_name = "replication-conn-%s" % self.conn_id
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+        self._logging_context.request = ctx_name
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -246,18 +249,17 @@ def _parse_and_dispatch_line(self, line: bytes):

         tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)

-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
         First calls `self.on_<COMMAND>` if it exists, then calls
-        `self.command_handler.on_<COMMAND>` if it exists. This allows for
-        protocol level handling of commands (e.g. PINGs), before delegating to
-        the handler.
+        `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+        return an Awaitable).
+
+        This allows for protocol level handling of commands (e.g. PINGs), before
+        delegating to the handler.
 
         Args:
             cmd: received command
@@ -268,13 +270,22 @@ async def handle_command(self, cmd: Command):
         # specific handling.
         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            cmd_func(cmd)
             handled = True
 
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(self, cmd)
+            res = cmd_func(self, cmd)
+
+            # the handler might be a coroutine: fire it off as a background process
+            # if so.
+
+            if isawaitable(res):
+                run_as_background_process(
+                    "replication-" + cmd.get_logcontext_id(), lambda: res
+                )
+
             handled = True
 
         if not handled:
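
One subtlety in the hunk above: by the time isawaitable(res) is checked, cmd_func has already created the coroutine, while run_as_background_process expects a callable that returns an awaitable — hence lambda: res, which hands over the existing coroutine instead of creating a new one. The same shape in a self-contained form, where run_in_background is a hypothetical stand-in for Synapse's helper:

import asyncio
from inspect import isawaitable
from typing import Any, Awaitable, Callable

def run_in_background(name: str, func: Callable[[], Awaitable[Any]]) -> None:
    # hypothetical stand-in for Synapse's run_as_background_process
    asyncio.create_task(func(), name=name)

async def main() -> None:
    async def on_cmd() -> None:
        await asyncio.sleep(0)

    res = on_cmd()  # the coroutine already exists, but is not yet running
    if isawaitable(res):
        # wrap it in a zero-argument callable, as the helper expects
        run_in_background("replication-cmd", lambda: res)
    await asyncio.sleep(0)  # let the scheduled task run

asyncio.run(main())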
@@ -350,10 +361,10 @@ def _send_pending_commands(self):
         for cmd in pending:
             self.send_command(cmd)
 
-    async def on_PING(self, line):
+    def on_PING(self, line):
         self.received_ping = True
 
-    async def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -448,7 +459,7 @@ def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    async def on_NAME(self, cmd):
+    def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data

@@ -477,7 +488,7 @@ def connectionMade(self):
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    async def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
