Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix client reader sharding tests #7853

Merged
merged 5 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7853.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for handling registration requests across multiple client reader workers.
24 changes: 23 additions & 1 deletion synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
IReactorPluggableNameResolver,
IResolutionReceiver,
)
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
Expand Down Expand Up @@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False


_EPSILON = 0.00000001


def _make_scheduler(reactor):
    """Build a scheduler callable for a `Cooperator` that uses `reactor`.

    This is effectively a copy of the scheduler in `twisted.internet.task`,
    except that the delayed call is registered on the reactor passed in here.

    Args:
        reactor: object providing `callLater(delay, callable)`.

    Returns:
        A one-argument function which schedules its argument on `reactor`
        after `_EPSILON` seconds and returns whatever `callLater` returned.
    """

    def _schedule_soon(work):
        # `callLater` is looked up on every call rather than bound once up
        # front, exactly as the nested function has always done.
        delayed_call = reactor.callLater(_EPSILON, work)
        return delayed_call

    return _schedule_soon
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved


class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
Expand Down Expand Up @@ -212,6 +228,10 @@ def __init__(
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)

# We use this for our body producers to ensure that they use the correct
# reactor.
self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

self.user_agent = self.user_agent.encode("ascii")

if self._ip_blacklist:
Expand Down Expand Up @@ -292,7 +312,9 @@ def request(self, method, uri, data=None, headers=None):
try:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(BytesIO(data))
body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator,
)

request_deferred = treq.request(
method,
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client
Expand Down Expand Up @@ -141,3 +142,7 @@ class HomeServer(object):
pass
# Type stub: declared to return a dict with `str` keys and `Stream` values.
# NOTE(review): presumably keyed by stream name - confirm against the
# implementation in `synapse/server.py`.
def get_replication_streams(self) -> Dict[str, Stream]:
pass
# Type stub: declared to return a `MatrixFederationHttpClient`.
# NOTE(review): the accessor name suggests a general-purpose HTTP client
# rather than the federation client - confirm this return annotation is the
# intended type.
def get_http_client(
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass
168 changes: 165 additions & 3 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import attr

Expand All @@ -26,16 +26,17 @@
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.http import streams
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
from tests.server import FakeTransport
from tests.server import FakeTransport, render

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -180,6 +181,159 @@ def assert_request_is_get_repl_stream_updates(
self.assertEqual(request.method, b"GET")


class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
    """Base class for tests running multiple workers.

    Automatically handles HTTP replication requests from workers to master,
    unlike `BaseStreamTestCase`.

    Subclasses list the servlets they need in `servlets`; these are registered
    on the master resource and on the resource created for each worker.
    """

    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]

    def setUp(self):
        """Set up the master side of replication and hook the HTTP listener."""
        super().setUp()

        # build a replication server
        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
        self.streamer = self.hs.get_replication_streamer()

        store = self.hs.get_datastore()
        self.database = store.db

        # Workers are configured to talk to "testserv" (see
        # `_get_worker_hs_config`); resolve it to a fixed fake IP.
        self.reactor.lookups["testserv"] = "1.2.3.4"

        # Maps each worker HomeServer to the resource that serves requests
        # rendered on it (see `render_on_worker`).
        self._worker_hs_to_resource = {}

        # When we see a connection attempt to the master replication listener we
        # automatically set up the connection. This is so that tests don't
        # manually have to go and explicitly set it up each time (plus sometimes
        # it is impossible to write the handling explicitly in the tests).
        self.reactor.add_tcp_client_callback(
            "1.2.3.4", 8765, self._handle_http_replication_attempt
        )

    def create_test_json_resource(self):
        """Overrides `HomeserverTestCase.create_test_json_resource`.

        Returns:
            A `ReplicationRestResource` for the master HS with every servlet
            in `self.servlets` registered on it.
        """
        # We override this so that it automatically registers all the HTTP
        # replication servlets, without having to explicitly do that in all
        # subclasses.

        resource = ReplicationRestResource(self.hs)

        for servlet in self.servlets:
            servlet(self.hs, resource)

        return resource

    def make_worker_hs(
        self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
    ) -> HomeServer:
        """Make a new worker HS instance, correctly connecting the replication
        stream to the master HS.

        Args:
            worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
            extra_config: Any extra config to use for this instance. Defaults
                to no extra config.
            **kwargs: Options that get passed to `self.setup_test_homeserver`,
                useful to e.g. pass some mocks for things like `http_client`

        Returns:
            The new worker HomeServer instance.
        """

        config = self._get_worker_hs_config()
        config["worker_app"] = worker_app
        # `None` stands in for "no extra config" so that no mutable default
        # object is shared between calls.
        if extra_config is not None:
            config.update(extra_config)

        worker_hs = self.setup_test_homeserver(
            homeserverToUse=GenericWorkerServer,
            config=config,
            reactor=self.reactor,
            **kwargs
        )

        # Point the worker's store at the master's database pool.
        store = worker_hs.get_datastore()
        store.db._db_pool = self.database._db_pool

        # Wire the worker's replication client protocol directly to a server
        # protocol built by the master's factory, using in-memory transports.
        repl_handler = ReplicationCommandHandler(worker_hs)
        client = ClientReplicationStreamProtocol(
            worker_hs, "client", "test", self.clock, repl_handler,
        )
        server = self.server_factory.buildProtocol(None)

        client_transport = FakeTransport(server, self.reactor)
        client.makeConnection(client_transport)

        server_transport = FakeTransport(client, self.reactor)
        server.makeConnection(server_transport)

        # Set up a resource for the worker
        # NOTE(review): the resource is constructed with the master `self.hs`
        # while the servlets below are registered against `worker_hs` -
        # confirm this is intended.
        resource = ReplicationRestResource(self.hs)

        for servlet in self.servlets:
            servlet(worker_hs, resource)

        self._worker_hs_to_resource[worker_hs] = resource

        return worker_hs

    def _get_worker_hs_config(self) -> dict:
        """Return the default config extended with the replication HTTP
        host/port that workers should connect to.
        """
        config = self.default_config()
        config["worker_replication_host"] = "testserv"
        config["worker_replication_http_port"] = "8765"
        return config

    def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
        """Render `request` against the resource registered for `worker_hs`.

        Args:
            worker_hs: A worker previously returned by `make_worker_hs`.
            request: The request to render.

        Raises:
            KeyError: if `worker_hs` was not created via `make_worker_hs`.
        """
        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)

    def replicate(self):
        """Tell the master side of replication that something has happened, and then
        wait for the replication to occur.
        """
        self.streamer.on_notifier_poke()
        self.pump()

    def _handle_http_replication_attempt(self):
        """Handles a connection attempt to the master replication HTTP
        listener.
        """

        # We should have at least one outbound connection attempt, where the
        # last is one to the HTTP replication IP/port.
        clients = self.reactor.tcpClients
        self.assertGreaterEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
        self.assertEqual(host, "1.2.3.4")
        self.assertEqual(port, 8765)

        # Set up client side protocol
        client_protocol = client_factory.buildProtocol(None)

        request_factory = OneShotRequestFactory()

        # Set up the server side protocol
        channel = _PushHTTPChannel(self.reactor)
        channel.requestFactory = request_factory
        channel.site = self.site

        # Connect client to server and vice versa.
        client_to_server_transport = FakeTransport(
            channel, self.reactor, client_protocol
        )
        client_protocol.makeConnection(client_to_server_transport)

        server_to_client_transport = FakeTransport(
            client_protocol, self.reactor, channel
        )
        channel.makeConnection(server_to_client_transport)

        # Note: at this point we've wired everything up, but we need to return
        # before the data starts flowing over the connections as this is called
        # inside `connectTCP` before the connection has been passed back to the
        # code that requested the TCP connection.


class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""

Expand Down Expand Up @@ -241,6 +395,14 @@ def unregisterProducer(self):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()

def checkPersistence(self, request, version):
    """Report that the connection must not be re-used after this request.

    Args:
        request: the request being answered; its response headers are
            updated to carry `connection: close`.
        version: ignored.

    Returns:
        Always `False`.
    """
    # Persistence is deliberately always refused: one request per connection
    # keeps the wiring in `_handle_http_replication_attempt` simple.
    close_values = [b"close"]
    request.responseHeaders.setRawHeaders(b"connection", close_values)
    return False


class _PullToPushProducer:
"""A push producer that wraps a pull producer.
Expand Down
59 changes: 11 additions & 48 deletions tests/replication/test_client_reader_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,26 @@
import logging

from synapse.api.constants import LoginType
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.client.v2_alpha import register

from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, render
from tests.server import FakeChannel

logger = logging.getLogger(__name__)


class ClientReaderTestCase(unittest.HomeserverTestCase):
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""

servlets = [
register.register_servlets,
]
servlets = [register.register_servlets]

def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()

store = hs.get_datastore()
self.database = store.db

self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker

self.reactor.lookups["testserv"] = "1.2.3.4"

def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)

worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
)

store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool

# Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
resource = JsonResource(self.hs)

for servlet in self.servlets:
servlet(worker_hs, resource)

# Essentially HomeserverTestCase.render.
def _render(request):
render(request, self.resource, self.reactor)

return worker_hs, _render

def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
Expand All @@ -82,14 +45,14 @@ def _get_worker_hs_config(self) -> dict:
def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
_, worker_render = self.make_worker_hs()
worker_hs = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render(request_1)
self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)

# Grab the session
Expand All @@ -99,7 +62,7 @@ def test_register_single_worker(self):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render(request_2)
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)

# We're given a registered user.
Expand All @@ -108,15 +71,15 @@ def test_register_single_worker(self):
def test_register_multi_worker(self):
"""Test that registration works when using multiple client reader workers.
"""
_, worker_render_1 = self.make_worker_hs()
_, worker_render_2 = self.make_worker_hs()
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render_1(request_1)
self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)

# Grab the session
Expand All @@ -126,7 +89,7 @@ def test_register_multi_worker(self):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render_2(request_2)
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)

# We're given a registered user.
Expand Down
Loading