Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add helper to parse an enum from query args & use it. #14956

Merged
merged 9 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14956.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
15 changes: 10 additions & 5 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import attr
from prometheus_client import Counter

from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -1680,7 +1680,12 @@ async def send_request(
return result

async def timestamp_to_event(
self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
self,
*,
destinations: List[str],
room_id: str,
timestamp: int,
direction: Direction,
) -> Optional["TimestampToEventResponse"]:
"""
Calls each remote federating server from `destinations` asking for their closest
Expand All @@ -1693,7 +1698,7 @@ async def timestamp_to_event(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.

Returns:
Expand Down Expand Up @@ -1738,7 +1743,7 @@ async def _timestamp_to_event_from_destination(
return None

async def _timestamp_to_event_from_destination(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> "TimestampToEventResponse":
"""
Calls a remote federating server at `destination` asking for their
Expand All @@ -1751,7 +1756,7 @@ async def _timestamp_to_event_from_destination(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.

Returns:
Expand Down
12 changes: 9 additions & 3 deletions synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
from twisted.internet.abstract import isIPAddress
from twisted.python import failure

from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
from synapse.api.constants import (
Direction,
EduTypes,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.errors import (
AuthError,
Codes,
Expand Down Expand Up @@ -218,7 +224,7 @@ async def on_backfill_request(
return 200, res

async def on_timestamp_to_event_request(
self, origin: str, room_id: str, timestamp: int, direction: str
self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event.
Expand All @@ -228,7 +234,7 @@ async def on_timestamp_to_event_request(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.

Returns:
Expand Down
8 changes: 4 additions & 4 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import attr
import ijson

from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
Expand Down Expand Up @@ -169,7 +169,7 @@ async def backfill(
)

async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]:
"""
Calls a remote federating server at `destination` asking for their
Expand All @@ -180,7 +180,7 @@ async def timestamp_to_event(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.

Returns:
Expand All @@ -194,7 +194,7 @@ async def timestamp_to_event(
room_id,
)

args = {"ts": [str(timestamp)], "dir": [direction]}
args = {"ts": [str(timestamp)], "dir": [direction.value]}

remote_response = await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True
Expand Down
7 changes: 4 additions & 3 deletions synapse/federation/transport/server/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from typing_extensions import Literal

from synapse.api.constants import EduTypes
from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
Expand Down Expand Up @@ -234,9 +234,10 @@ async def on_GET(
room_id: str,
) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
direction = parse_string_from_args(
query, "dir", default="f", allowed_values=["f", "b"], required=True
direction_str = parse_string_from_args(
query, "dir", allowed_values=["f", "b"], required=True
)
direction = Direction(direction_str)

return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main

def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id()

async def get_new_events(
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,5 +315,5 @@ async def get_new_events_as(

return events, to_key

def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id()
9 changes: 5 additions & 4 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import synapse.events.snapshot
from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
GuestAccess,
Expand Down Expand Up @@ -1487,7 +1488,7 @@ async def get_event_for_timestamp(
requester: Requester,
room_id: str,
timestamp: int,
direction: str,
direction: Direction,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
Expand All @@ -1498,7 +1499,7 @@ async def get_event_for_timestamp(
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.

Returns:
Expand Down Expand Up @@ -1533,13 +1534,13 @@ async def get_event_for_timestamp(
local_event_id, allow_none=False, allow_rejected=False
)

if direction == "f":
if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
elif direction == "b":
elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (
Expand Down
65 changes: 65 additions & 0 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

""" This module contains base REST classes for constructing REST servlets. """
import enum
import logging
from http import HTTPStatus
from typing import (
Expand Down Expand Up @@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
Expand Down Expand Up @@ -413,6 +415,69 @@ def parse_string(
)


EnumT = TypeVar("EnumT", bound=enum.Enum)


@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: EnumT,
) -> EnumT:
...


@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
*,
required: Literal[True],
) -> EnumT:
...


def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: Optional[EnumT] = None,
required: bool = False,
) -> Optional[EnumT]:
"""
Parse an enum parameter from the request query string.

Args:
request: the twisted HTTP request.
name: the name of the query parameter.
E: the enum which represents valid values
default: enum value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.

Returns:
An enum value.

Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
# TODO Assert the enum values are strings.
clokep marked this conversation as resolved.
Show resolved Hide resolved
str_value = parse_string(
request,
name,
default=default.value if default is not None else None,
required=required,
allowed_values=[e.value for e in E],
)
if str_value is None:
return None
return E(str_value)


def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
Expand Down
12 changes: 3 additions & 9 deletions synapse/rest/admin/event_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
Expand Down Expand Up @@ -60,7 +61,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
direction = parse_string(request, "dir", default="b")
direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id")

Expand All @@ -78,13 +79,6 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
errcode=Codes.INVALID_PARAM,
)

if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)

event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)
Expand Down
7 changes: 4 additions & 3 deletions synapse/rest/admin/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple

from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder
Expand Down Expand Up @@ -79,7 +80,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
allowed_values=[dest.value for dest in DestinationSortOrder],
)

direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
Expand Down Expand Up @@ -192,7 +193,7 @@ async def on_GET(
errcode=Codes.INVALID_PARAM,
)

direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction
Expand Down