Add types to StreamToken and RoomStreamToken (#8279)
The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too.
erikjohnston committed Sep 8, 2020
1 parent 094896a commit 63c0e9e
Showing 5 changed files with 95 additions and 91 deletions.
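As a rough sketch of what the typed classes look like in use (this example is not part of the diff; the values are invented, the import path is the module being changed): `StreamToken.room_key` remains a plain string for now, and code that needs its structure parses it into a `RoomStreamToken` explicitly.

from synapse.types import RoomStreamToken

room_key = "t31-1255"                      # StreamToken.room_key is still a str after this commit
parsed = RoomStreamToken.parse(room_key)   # RoomStreamToken(topological=31, stream=1255)
assert str(parsed) == room_key             # serialises back to the same token string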
1 change: 1 addition & 0 deletions changelog.d/8279.misc
@@ -0,0 +1 @@
Add type hints to `StreamToken` and `RoomStreamToken` classes.
5 changes: 2 additions & 3 deletions synapse/handlers/sync.py
@@ -1310,12 +1310,11 @@ async def _generate_sync_entry_for_presence(
         presence_source = self.event_sources.sources["presence"]
 
         since_token = sync_result_builder.since_token
+        presence_key = None
+        include_offline = False
         if since_token and not sync_result_builder.full_state:
            presence_key = since_token.presence_key
            include_offline = True
-        else:
-            presence_key = None
-            include_offline = False
 
         presence, presence_key = await presence_source.get_new_events(
             user=user,
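The likely motivation for this hunk (not stated in the commit message) is that assigning the defaults before the `if` leaves `presence_key` and `include_offline` bound on every code path, which reads more simply and is friendlier to a type checker. The net result of the handler fragment is just:

presence_key = None
include_offline = False
if since_token and not sync_result_builder.full_state:
    presence_key = since_token.presence_key
    include_offline = True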
7 changes: 3 additions & 4 deletions synapse/storage/databases/main/devices.py
@@ -481,7 +481,7 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
         }
 
     async def get_users_whose_devices_changed(
-        self, from_key: str, user_ids: Iterable[str]
+        self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
         are in the given list of user_ids.
@@ -493,7 +493,6 @@ async def get_users_whose_devices_changed(
         Returns:
             The set of user_ids whose devices have changed since `from_key`
         """
-        from_key = int(from_key)
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
@@ -527,7 +526,7 @@ def _get_users_whose_devices_changed_txn(txn):
         )
 
     async def get_users_whose_signatures_changed(
-        self, user_id: str, from_key: str
+        self, user_id: str, from_key: int
     ) -> Set[str]:
         """Get the users who have new cross-signing signatures made by `user_id` since
         `from_key`.
@@ -539,7 +538,7 @@ async def get_users_whose_signatures_changed(
         Returns:
             A set of user IDs with updated signatures.
         """
-        from_key = int(from_key)
+
         if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
             sql = """
                 SELECT DISTINCT user_ids FROM user_signature_stream
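With `from_key` typed as `int`, the `int(from_key)` coercion disappears from the storage layer; this lines up with the `StreamToken` fields below being declared as `int`, so call sites can pass the key straight through. A small sketch (token value invented):

from synapse.types import StreamToken

since_token = StreamToken.from_string("s100_5_3_7_1_2_9_12_1")
assert isinstance(since_token.device_list_key, int)  # 12, not "12"
# ...so it can be handed directly to get_users_whose_devices_changed() or
# get_users_whose_signatures_changed() without an int() call.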
21 changes: 11 additions & 10 deletions synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
-    from_token: Optional[Tuple[int, int]],
-    to_token: Optional[Tuple[int, int]],
+    from_token: Optional[Tuple[Optional[int], int]],
+    to_token: Optional[Tuple[Optional[int], int]],
     engine: BaseDatabaseEngine,
 ) -> str:
     """Creates an SQL expression to bound the columns by the pagination
@@ -535,13 +535,13 @@ async def get_recent_event_ids_for_room(
         if limit == 0:
             return [], end_token
 
-        end_token = RoomStreamToken.parse(end_token)
+        parsed_end_token = RoomStreamToken.parse(end_token)
 
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
-            from_token=end_token,
+            from_token=parsed_end_token,
             limit=limit,
         )
 
@@ -989,8 +989,8 @@ def _paginate_room_events_txn(
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token,
-            to_token=to_token,
+            from_token=from_token.as_tuple(),
+            to_token=to_token.as_tuple() if to_token else None,
             engine=self.database_engine,
         )

@@ -1083,16 +1083,17 @@ async def paginate_room_events(
         and `to_key`).
         """
 
-        from_key = RoomStreamToken.parse(from_key)
+        parsed_from_key = RoomStreamToken.parse(from_key)
+        parsed_to_key = None
         if to_key:
-            to_key = RoomStreamToken.parse(to_key)
+            parsed_to_key = RoomStreamToken.parse(to_key)
 
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
-            from_key,
-            to_key,
+            parsed_from_key,
+            parsed_to_key,
             direction,
             limit,
             event_filter,
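The `parsed_from_key` / `parsed_to_key` names presumably exist because rebinding a parameter annotated as `str` to a `RoomStreamToken` would trip mypy once these functions gain annotations; a minimal sketch of the pattern (the function name and annotations are illustrative, not from the diff):

from typing import Optional

from synapse.types import RoomStreamToken


def parse_bounds(from_key: str, to_key: Optional[str] = None):
    # Leave the str parameters untouched; give the parsed values their own names.
    parsed_from_key = RoomStreamToken.parse(from_key)
    parsed_to_key = RoomStreamToken.parse(to_key) if to_key else None
    return parsed_from_key, parsed_to_key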
152 changes: 78 additions & 74 deletions synapse/types.py
@@ -18,7 +18,7 @@
 import string
 import sys
 from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -362,38 +362,95 @@ def f2(m):
return username.decode("ascii")


class StreamToken(
namedtuple(
"Token",
(
"room_key",
"presence_key",
"typing_key",
"receipt_key",
"account_data_key",
"push_rules_key",
"to_device_key",
"device_list_key",
"groups_key",
),
@attr.s(frozen=True, slots=True)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
| |
[0] V [1] V [2]
Tokens can either be a point in the live event stream or a cursor going
through historic events.
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""

topological = attr.ib(
type=Optional[int],
validator=attr.validators.optional(attr.validators.instance_of(int)),
)
):
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))

@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

@classmethod
def parse_stream_token(cls, string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)

def __str__(self) -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)


@attr.s(slots=True, frozen=True)
class StreamToken:
room_key = attr.ib(type=str)
presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int)
account_data_key = attr.ib(type=int)
push_rules_key = attr.ib(type=int)
to_device_key = attr.ib(type=int)
device_list_key = attr.ib(type=int)
groups_key = attr.ib(type=int)

_SEPARATOR = "_"
START = None # type: StreamToken

@classmethod
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
while len(keys) < len(cls._fields):
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys)
return cls(keys[0], *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")

def to_string(self):
return self._SEPARATOR.join([str(k) for k in self])
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])

@property
def room_stream_id(self):
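A quick round-trip through the parsing rules documented in the class above (values invented):

from synapse.types import RoomStreamToken

live = RoomStreamToken.parse("s1255")         # stream position only
historic = RoomStreamToken.parse("t31-1255")  # topological + stream position
assert live.topological is None and str(live) == "s1255"
assert historic.topological == 31 and str(historic) == "t31-1255"

RoomStreamToken.parse_stream_token("s1255")   # accepted
# parse_stream_token("t31-1255") would raise SynapseError(400, ...): that helper
# only accepts live ("s...") tokens.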
@@ -435,63 +492,10 @@ def copy_and_advance(self, key, new_value):
         return self
 
     def copy_and_replace(self, key, new_value):
-        return self._replace(**{key: new_value})
+        return attr.evolve(self, **{key: new_value})
 
 
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, followed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-
-    __slots__ = []  # type: list
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == "t":
-                parts = string[1:].split("-", 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == "s":
-                return cls(topological=None, stream=int(string[1:]))
-        except Exception:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+StreamToken.START = StreamToken.from_string("s0_0")
 
 
 class ThirdPartyInstanceID(
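The `StreamToken` side can be exercised the same way; `copy_and_replace` now goes through `attr.evolve`, and `START` is rebuilt from its canonical string form (values below are invented):

from synapse.types import StreamToken

token = StreamToken.from_string("s2633508_17_338_1_1_2_9_12_1")
assert token.presence_key == 17                    # the non-room keys are real ints now
assert token.to_string() == "s2633508_17_338_1_1_2_9_12_1"

newer = token.copy_and_replace("typing_key", 339)  # attr.evolve under the hood
assert newer.typing_key == 339 and token.typing_key == 338  # frozen, so a new instance

assert StreamToken.START.room_key == "s0"          # from StreamToken.from_string("s0_0")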
