Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Finish up type hints for federation client code #15465

Merged
merged 9 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,6 @@ exclude = (?x)
|synapse/storage/schema/
)$

[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False

[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False

[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.
Expand Down
8 changes: 2 additions & 6 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,11 @@ async def backfill(
logger.debug("backfill transaction_data=%r", transaction_data)

if not isinstance(transaction_data, dict):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("Backfill transaction_data is not a dict.")
raise InvalidResponseError("Backfill transaction_data is not a dict.")
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list):
# TODO we probably want an exception type specific to federation
# client validation.
raise TypeError("transaction_data.pdus is not a list.")
raise InvalidResponseError("transaction_data.pdus is not a list.")

room_version = await self.store.get_room_version(room_id)

Expand Down
63 changes: 51 additions & 12 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import urllib
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Expand All @@ -42,18 +43,25 @@
)
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.matrixfederationclient import (
ByteParser,
JsonDictParser,
LegacyJsonDictParser,
)
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.util import ExceptionBundle

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
Expand All @@ -80,6 +88,7 @@ async def get_room_state_ids(
path=path,
args={"event_id": event_id},
try_trailing_slash_on_400=True,
parser=JsonDictParser(),
)

async def get_room_state(
Expand Down Expand Up @@ -128,12 +137,16 @@ async def get_event(

path = _create_v1_path("/event/%s", event_id)
return await self.client.get_json(
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
destination,
path=path,
timeout=timeout,
try_trailing_slash_on_400=True,
parser=JsonDictParser(),
)

async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
) -> Optional[JsonDict]:
) -> Optional[Union[JsonDict, list]]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Requests `limit` previous PDUs in a given context before list of
PDUs.

Expand Down Expand Up @@ -248,6 +261,7 @@ async def send_transaction(
long_retries=True,
backoff_on_404=True, # If we get a 404 the other side has gone
try_trailing_slash_on_400=True,
parser=JsonDictParser(),
)

async def make_query(
Expand All @@ -268,6 +282,7 @@ async def make_query(
retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
ignore_backoff=ignore_backoff,
parser=JsonDictParser(),
)

async def make_membership_event(
Expand Down Expand Up @@ -329,6 +344,7 @@ async def make_membership_event(
retry_on_dns_fail=retry_on_dns_fail,
timeout=20000,
ignore_backoff=ignore_backoff,
parser=JsonDictParser(),
)

async def send_join_v1(
Expand Down Expand Up @@ -388,6 +404,7 @@ async def send_leave_v1(
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
parser=LegacyJsonDictParser(),
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
)

async def send_leave_v2(
Expand All @@ -404,6 +421,7 @@ async def send_leave_v2(
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
parser=JsonDictParser(),
)

async def send_knock_v1(
Expand Down Expand Up @@ -436,7 +454,10 @@ async def send_knock_v1(
path = _create_v1_path("/send_knock/%s/%s", room_id, event_id)

return await self.client.put_json(
destination=destination, path=path, data=content
destination=destination,
path=path,
data=content,
parser=JsonDictParser(),
)

async def send_invite_v1(
Expand All @@ -445,7 +466,11 @@ async def send_invite_v1(
path = _create_v1_path("/invite/%s/%s", room_id, event_id)

return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
destination=destination,
path=path,
data=content,
ignore_backoff=True,
parser=LegacyJsonDictParser(),
)

async def send_invite_v2(
Expand All @@ -454,7 +479,11 @@ async def send_invite_v2(
path = _create_v2_path("/invite/%s/%s", room_id, event_id)

return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
destination=destination,
path=path,
data=content,
ignore_backoff=True,
parser=JsonDictParser(),
)

async def get_public_rooms(
Expand Down Expand Up @@ -515,7 +544,11 @@ async def get_public_rooms(

try:
response = await self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
destination=remote_server,
path=path,
args=args,
ignore_backoff=True,
parser=JsonDictParser(),
)
except HttpResponseException as e:
if e.code == 403:
Expand All @@ -535,15 +568,17 @@ async def exchange_third_party_invite(
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)

return await self.client.put_json(
destination=destination, path=path, data=event_dict
destination=destination, path=path, data=event_dict, parser=JsonDictParser()
)

async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)

return await self.client.get_json(destination=destination, path=path)
return await self.client.get_json(
destination=destination, path=path, parser=JsonDictParser()
)

async def query_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
Expand Down Expand Up @@ -622,7 +657,7 @@ async def query_user_devices(
path = _create_v1_path("/user/devices/%s", user_id)

return await self.client.get_json(
destination=destination, path=path, timeout=timeout
destination=destination, path=path, timeout=timeout, parser=JsonDictParser()
)

async def claim_client_keys(
Expand Down Expand Up @@ -695,7 +730,9 @@ async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)

return await self.client.get_json(destination=destination, path=path)
return await self.client.get_json(
destination=destination, path=path, parser=JsonDictParser()
)

async def get_room_hierarchy(
self, destination: str, room_id: str, suggested_only: bool
Expand All @@ -712,6 +749,7 @@ async def get_room_hierarchy(
destination=destination,
path=path,
args={"suggested_only": "true" if suggested_only else "false"},
parser=JsonDictParser(),
)

async def get_room_hierarchy_unstable(
Expand All @@ -731,6 +769,7 @@ async def get_room_hierarchy_unstable(
destination=destination,
path=path,
args={"suggested_only": "true" if suggested_only else "false"},
parser=JsonDictParser(),
)

async def get_account_status(
Expand Down
70 changes: 54 additions & 16 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
import random
import sys
import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
Expand All @@ -30,9 +29,11 @@
Generic,
List,
Optional,
TextIO,
Tuple,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -183,20 +184,54 @@ def get_json(self) -> Optional[JsonDict]:
return self.json


class JsonParser(ByteParser[Union[JsonDict, list]]):
class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON."""

CONTENT_TYPE = "application/json"

def __init__(self) -> None:
def __init__(self, validator: Optional[Callable[[Any], bool]] = None) -> None:
clokep marked this conversation as resolved.
Show resolved Hide resolved
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
self._validator = validator

def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)

def finish(self) -> Union[JsonDict, list]:
return json_decoder.decode(self._buffer.getvalue())
def finish(self) -> T:
result = json_decoder.decode(self._buffer.getvalue())
if self._validator is not None and not self._validator(result):
raise ValueError("Unexpected JSON object")
clokep marked this conversation as resolved.
Show resolved Hide resolved
return result


class JsonParser(_BaseJsonParser[Union[JsonDict, list]]):
"""A parser that buffers the response and tries to parse it as JSON."""


class JsonDictParser(_BaseJsonParser[JsonDict]):
"""Ensure the response is a JSON object."""
def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
return isinstance(v, dict)


class LegacyJsonDictParser(_BaseJsonParser[Tuple[int, JsonDict]]):
"""Ensure the legacy responses of /send_join & /send_leave are correct."""
def __init__(self) -> None:
super().__init__(self._validate)

@staticmethod
def _validate(v: Any) -> bool:
# Match [integer, JSON dict]
return (
isinstance(v, list)
and len(v) == 2
and type(v[0]) == int
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
and isinstance(v[1], dict)
)


async def _handle_response(
Expand Down Expand Up @@ -313,9 +348,7 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""

def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file

Expand Down Expand Up @@ -825,8 +858,8 @@ async def put_json(
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, list, T]:
"""Sends the specified json data using PUT

Args:
Expand Down Expand Up @@ -902,7 +935,7 @@ async def put_json(
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

body = await _handle_response(
self.reactor,
Expand All @@ -924,7 +957,7 @@ async def post_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
) -> JsonDict:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
"""Sends the specified json data using POST

Args:
Expand Down Expand Up @@ -983,7 +1016,12 @@ async def post_json(
_sec_timeout = self.default_timeout

body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=JsonDictParser(),
)
return body

Expand Down Expand Up @@ -1024,8 +1062,8 @@ async def get_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
):
parser: Optional[ByteParser[T]] = None,
) -> Union[JsonDict, list, T]:
"""GETs some json from the given host homeserver and path

Args:
Expand Down Expand Up @@ -1091,7 +1129,7 @@ async def get_json(
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()
parser = cast(ByteParser[T], JsonParser())

body = await _handle_response(
self.reactor,
Expand Down