Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to tests.replication #14987

Merged
merged 6 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from typing import Any, Dict, List, Optional, Set, Tuple

from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

Expand Down Expand Up @@ -109,7 +110,7 @@ def _get_worker_hs_config(self) -> dict:
config["worker_replication_http_port"] = "8765"
return config

def _build_replication_data_handler(self):
def _build_replication_data_handler(self) -> "TestReplicationDataHandler":
return TestReplicationDataHandler(self.worker_hs)

def reconnect(self) -> None:
Expand Down Expand Up @@ -170,7 +171,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
requests: List[SynapseRequest] = []
real_request_factory = channel.requestFactory

def request_factory(*args, **kwargs):
def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest:
request = real_request_factory(*args, **kwargs)
requests.append(request)
return request
Expand Down Expand Up @@ -204,7 +205,7 @@ def request_factory(*args, **kwargs):

def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
) -> None:
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
Expand Down Expand Up @@ -246,7 +247,7 @@ def default_config(self) -> Dict[str, Any]:
base["redis"] = {"enabled": True}
return base

def setUp(self):
def setUp(self) -> None:
super().setUp()

# build a replication server
Expand Down Expand Up @@ -289,7 +290,7 @@ def setUp(self):
lambda: self._handle_http_replication_attempt(self.hs, 8765),
)

def create_test_resource(self):
def create_test_resource(self) -> ReplicationRestResource:
"""Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
Expand All @@ -303,7 +304,7 @@ def create_test_resource(self):
return resource

def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Expand Down Expand Up @@ -394,7 +395,7 @@ def replicate(self) -> None:
self.streamer.on_notifier_poke()
self.pump()

def _handle_http_replication_attempt(self, hs, repl_port):
def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> None:
"""Handles a connection attempt to the given HS replication HTTP
listener on the given port.
"""
Expand Down Expand Up @@ -442,8 +443,11 @@ def connect_any_redis_attempts(self) -> None:
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 6379)
client_protocol = client_factory.buildProtocol(client_address)

server_address = IPv4Address("TCP", host, port)
server_protocol = self._redis_server.buildProtocol(server_address)
Comment on lines -443 to +450
Copy link
Contributor

@DMRobertson DMRobertson Feb 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went for a different approach in #14988 (comment) of passing in a dummy address. I should probably just go for something simpler like this.


client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
Expand All @@ -465,7 +469,9 @@ def __init__(self, hs: HomeServer):
# list of received (stream_name, token, row) tuples
self.received_rdata_rows: List[Tuple[str, int, Any]] = []

async def on_rdata(self, stream_name, instance_name, token, rows):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
) -> None:
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
Expand All @@ -479,23 +485,25 @@ def __init__(self) -> None:
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)

def add_subscriber(self, conn, channel: bytes) -> None:
def add_subscriber(self, conn: "FakeRedisPubSubProtocol", channel: bytes) -> None:
"""A connection has called SUBSCRIBE"""
self._subscribers_by_channel[channel].add(conn)

def remove_subscriber(self, conn) -> None:
def remove_subscriber(self, conn: "FakeRedisPubSubProtocol") -> None:
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)

def publish(self, conn, channel: bytes, msg) -> int:
def publish(
self, conn: "FakeRedisPubSubProtocol", channel: bytes, msg: object
) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])

return len(self._subscribers_by_channel)

def buildProtocol(self, addr) -> "FakeRedisPubSubProtocol":
def buildProtocol(self, addr: IPv4Address) -> "FakeRedisPubSubProtocol":
return FakeRedisPubSubProtocol(self)


Expand All @@ -508,7 +516,7 @@ def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()

def dataReceived(self, data) -> None:
def dataReceived(self, data: bytes) -> None:
self._reader.feed(data)

# We might get multiple messages in one packet.
Expand All @@ -525,7 +533,7 @@ def dataReceived(self, data) -> None:

self.handle_command(msg[0], *msg[1:])

def handle_command(self, command, *args) -> None:
def handle_command(self, command: bytes, *args: bytes) -> None:
"""Received a Redis command from the client."""

# We currently only support pub/sub.
Expand All @@ -550,9 +558,9 @@ def handle_command(self, command, *args) -> None:
self.send("PONG")

else:
raise Exception(f"Unknown command: {command}")
raise Exception(f"Unknown command: {command!r}")

def send(self, msg) -> None:
def send(self, msg: object) -> None:
"""Send a message back to the client."""
assert self.transport is not None

Expand Down Expand Up @@ -583,5 +591,5 @@ def encode(self, obj: object) -> str:

raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)

def connectionLost(self, reason) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._server.remove_subscriber(self)
27 changes: 14 additions & 13 deletions tests/replication/tcp/streams/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

class TypingStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self) -> Mock:
return Mock(wraps=super()._build_replication_data_handler())
self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
return self.mock_handler

def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
Expand All @@ -43,8 +44,8 @@ def test_typing(self) -> None:
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")

self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
Expand All @@ -54,11 +55,11 @@ def test_typing(self) -> None:
# Now let's disconnect and insert some data.
self.disconnect()

self.test_handler.on_rdata.reset_mock()
self.mock_handler.on_rdata.reset_mock()

typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)

self.test_handler.on_rdata.assert_not_called()
self.mock_handler.on_rdata.assert_not_called()

self.reconnect()
self.pump(0.1)
Expand All @@ -71,8 +72,8 @@ def test_typing(self) -> None:
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)

self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
Expand All @@ -98,8 +99,8 @@ def test_reset(self) -> None:
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")

self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
Expand Down Expand Up @@ -134,15 +135,15 @@ def test_reset(self) -> None:
self.assert_request_is_get_repl_stream_updates(request, "typing")

# Reset the test code.
self.test_handler.on_rdata.reset_mock()
self.test_handler.on_rdata.assert_not_called()
self.mock_handler.on_rdata.reset_mock()
self.mock_handler.on_rdata.assert_not_called()

# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)

self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
Expand Down