Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use an enum for direction. #14927

Merged
merged 2 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14927.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
7 changes: 7 additions & 0 deletions synapse/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

"""Contains constants from the specification."""

import enum

from typing_extensions import Final

# the max size of a (canonical-json-encoded) event
Expand Down Expand Up @@ -290,3 +292,8 @@ class ApprovalNoticeMedium:

NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"


class Direction(enum.Enum):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this file is the best place for it, but they are constants from the spec?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a reasonable place to put them if they come from the spec!

BACKWARDS = "b"
FORWARDS = "f"
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
Expand Down Expand Up @@ -197,7 +197,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f"
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
break
Expand Down
16 changes: 14 additions & 2 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
from synapse.api.constants import (
AccountDataTypes,
Direction,
EduTypes,
EventTypes,
Membership,
)
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
Expand Down Expand Up @@ -57,7 +63,13 @@ def __init__(self, hs: "HomeServer"):
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
str,
Optional[StreamToken],
Optional[StreamToken],
Direction,
int,
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from twisted.python.failure import Failure

from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
Expand Down Expand Up @@ -448,7 +448,7 @@ async def get_messages(

if pagin_config.from_token:
from_token = pagin_config.from_token
elif pagin_config.direction == "f":
elif pagin_config.direction == Direction.FORWARDS:
from_token = (
await self.hs.get_event_sources().get_start_token_for_pagination(
room_id
Expand Down Expand Up @@ -476,7 +476,7 @@ async def get_messages(
room_id, requester, allow_departed_users=True
)

if pagin_config.direction == "b":
if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
Expand Down
8 changes: 6 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import attr

from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.context import make_deferred_yieldable, run_in_background
Expand Down Expand Up @@ -413,7 +413,11 @@ async def _get_threads_for_events(

# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.THREAD, direction="f"
event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)

# Filter out ignored users.
Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import attr

from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -168,7 +168,7 @@ async def get_relations_for_event(
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
Expand All @@ -181,8 +181,8 @@ async def get_relations_for_event(
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
direction: Whether to fetch the most recent first (backwards) or the
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.

Expand Down
59 changes: 31 additions & 28 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

from twisted.internet import defer

from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
Expand Down Expand Up @@ -86,7 +87,6 @@
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"


# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
Expand All @@ -104,7 +104,7 @@ class _EventsAround:


def generate_pagination_where_clause(
direction: str,
direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
Expand All @@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.

Args:
direction: Whether we're paginating backwards("b") or forwards ("f").
direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if
direction is "b".
minimum bound if direction is forwards, and an inclusive maximum bound if
direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if
direction is "b".
maximum bound if direction is forwards, and an exclusive minimum bound if
direction is backwards.
engine: The database engine to generate the clauses for

Returns:
The sql expression
"""
assert direction in ("b", "f")

where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
Expand All @@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
Expand All @@ -171,7 +170,7 @@ def generate_pagination_where_clause(


def generate_pagination_bounds(
direction: str,
direction: Direction,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
) -> Tuple[
Expand All @@ -181,7 +180,7 @@ def generate_pagination_bounds(
Generate a start and end point for this page of events.

Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
from_token: The token to start pagination at, or None to start at the first value.
to_token: The token to end pagination at, or None to not limit the end point.

Expand All @@ -201,7 +200,7 @@ def generate_pagination_bounds(
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
Expand All @@ -215,7 +214,7 @@ def generate_pagination_bounds(
if from_token:
if from_token.topological is not None:
from_bound = from_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
from_bound = (
None,
from_token.get_max_stream_pos(),
Expand All @@ -230,7 +229,7 @@ def generate_pagination_bounds(
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
to_bound = (
None,
to_token.stream,
Expand All @@ -245,20 +244,20 @@ def generate_pagination_bounds(


def generate_next_token(
direction: str, last_topo_ordering: int, last_stream_ordering: int
direction: Direction, last_topo_ordering: int, last_stream_ordering: int
) -> RoomStreamToken:
"""
Generate the next room stream token based on the currently returned data.

Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
last_topo_ordering: The last topological ordering being returned.
last_stream_ordering: The last stream ordering being returned.

Returns:
A new RoomStreamToken to return to the client.
"""
if direction == "b":
if direction == Direction.BACKWARDS:
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
Expand Down Expand Up @@ -1201,7 +1200,7 @@ def _get_events_around_txn(
txn,
room_id,
before_token,
direction="b",
direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
Expand All @@ -1211,7 +1210,7 @@ def _get_events_around_txn(
txn,
room_id,
after_token,
direction="f",
direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
Expand Down Expand Up @@ -1374,7 +1373,7 @@ def _paginate_room_events_txn(
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
Expand All @@ -1385,8 +1384,8 @@ def _paginate_room_events_txn(
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
Expand Down Expand Up @@ -1489,8 +1488,12 @@ def _paginate_room_events_txn(
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
lower_token=to_token if direction == "b" else from_token,
upper_token=from_token if direction == "b" else to_token,
lower_token=to_token
if direction == Direction.BACKWARDS
else from_token,
upper_token=from_token
if direction == Direction.BACKWARDS
else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
Expand All @@ -1514,7 +1517,7 @@ async def paginate_room_events(
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
Expand All @@ -1524,8 +1527,8 @@ async def paginate_room_events(
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.

Expand Down
11 changes: 8 additions & 3 deletions synapse/streams/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import attr

from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
Expand All @@ -34,7 +35,7 @@ class PaginationConfig:

from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
direction: str
direction: Direction
limit: int

@classmethod
Expand All @@ -45,9 +46,13 @@ async def from_request(
default_limit: int,
default_dir: str = "f",
) -> "PaginationConfig":
direction = parse_string(
request, "dir", default=default_dir, allowed_values=["f", "b"]
direction_str = parse_string(
request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
direction = Direction(direction_str)
Comment on lines +49 to +55
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do this often we might want to make a parse_string_to_enum or something, but for this single instance it was easy enough...especially because this only has 2 values.


from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
Expand Down