Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Factor out MultiWriter token from RoomStreamToken (#16427)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston committed Oct 5, 2023
1 parent ab9c1e8 commit 009b47b
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 61 deletions.
1 change: 1 addition & 0 deletions changelog.d/16427.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Factor out `MultiWriter` token from `RoomStreamToken`.
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
else:
stream_ordering = room.stream_ordering

from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
from_key = RoomStreamToken(topological=0, stream=0)
to_key = RoomStreamToken(stream=stream_ordering)

# Events that we've processed in this room
written_events: Set[str] = set()
Expand Down
3 changes: 1 addition & 2 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ async def handle_room(event: RoomsForUser) -> None:
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
None,
event.stream_ordering,
stream=event.stream_ordering,
)
deferred_room_state = run_in_background(
self._state_storage_controller.get_state_for_events,
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@ async def get_new_events(

if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
from_key = RoomStreamToken(None, from_key.stream)
from_key = RoomStreamToken(stream=from_key.stream)

app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2333,7 +2333,7 @@ async def _get_room_changes_for_initial_sync(
continue

leave_token = now_token.copy_and_replace(
StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def on_POST(
# RoomStreamToken expects [int] not Optional[int]
assert event.internal_metadata.stream_ordering is not None
room_token = RoomStreamToken(
event.depth, event.internal_metadata.stream_ordering
topological=event.depth, stream=event.internal_metadata.stream_ordering
)
token = await room_token.to_string(self.store)

Expand Down
22 changes: 13 additions & 9 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def generate_next_token(
# when we are going backwards so we subtract one from the
# stream part.
last_stream_ordering -= 1
return RoomStreamToken(last_topo_ordering, last_stream_ordering)
return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)


def _make_generic_sql_bound(
Expand Down Expand Up @@ -558,7 +558,7 @@ def get_room_max_token(self) -> RoomStreamToken:
if p > min_pos
}

return RoomStreamToken(None, min_pos, immutabledict(positions))
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))

async def get_room_events_stream_for_rooms(
self,
Expand Down Expand Up @@ -708,7 +708,7 @@ def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
ret.reverse()

if rows:
key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
Expand Down Expand Up @@ -969,7 +969,7 @@ async def get_current_room_stream_token_for_room_id(
topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return RoomStreamToken(topo, stream_ordering)
return RoomStreamToken(topological=topo, stream=stream_ordering)

@overload
def get_stream_id_for_event_txn(
Expand Down Expand Up @@ -1033,7 +1033,9 @@ async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToke
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
return RoomStreamToken(
topological=row["topological_ordering"], stream=row["stream_ordering"]
)

async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
Expand Down Expand Up @@ -1114,8 +1116,8 @@ def _set_before_and_after(
else:
topo = None
internal = event.internal_metadata
internal.before = RoomStreamToken(topo, stream - 1)
internal.after = RoomStreamToken(topo, stream)
internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
internal.after = RoomStreamToken(topological=topo, stream=stream)
internal.order = (int(topo) if topo else 0, int(stream))

async def get_events_around(
Expand Down Expand Up @@ -1191,11 +1193,13 @@ def _get_events_around_txn(
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
results["topological_ordering"] - 1, results["stream_ordering"]
topological=results["topological_ordering"] - 1,
stream=results["stream_ordering"],
)

after_token = RoomStreamToken(
results["topological_ordering"], results["stream_ordering"]
topological=results["topological_ordering"],
stream=results["stream_ordering"],
)

rows, start_token = self._paginate_room_events_txn(
Expand Down
132 changes: 91 additions & 41 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
from synapse.util.stringutils import parse_and_validate_server_name

if TYPE_CHECKING:
from typing_extensions import Self

from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
Expand Down Expand Up @@ -437,7 +439,78 @@ def f2(m: Match[bytes]) -> bytes:


@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken:
class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
"""An abstract stream token class for streams that supports multiple
writers.
This works by keeping track of the stream position of each writer,
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
"""

stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)

instance_map: "immutabledict[str, int]" = attr.ib(
factory=immutabledict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
value_validator=attr.validators.instance_of(int),
mapping_validator=attr.validators.instance_of(immutabledict),
),
kw_only=True,
)

@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
"""Parse the string representation of the token."""
...

@abc.abstractmethod
async def to_string(self, store: "DataStore") -> str:
"""Serialize the token into its string representation."""
...

def copy_and_advance(self, other: "Self") -> "Self":
"""Return a new token such that if an event is after both this token and
the other token, then its after the returned token too.
"""

max_stream = max(self.stream, other.stream)

instance_map = {
instance: max(
self.instance_map.get(instance, self.stream),
other.instance_map.get(instance, other.stream),
)
for instance in set(self.instance_map).union(other.instance_map)
}

return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)

def get_max_stream_pos(self) -> int:
"""Get the maximum stream position referenced in this token.
The corresponding "min" position is, by definition just `self.stream`.
This is used to handle tokens that have non-empty `instance_map`, and so
reference stream positions after the `self.stream` position.
"""
return max(self.instance_map.values(), default=self.stream)

def get_stream_pos_for_instance(self, instance_name: str) -> int:
"""Get the stream position that the given writer was at at this token."""

# If we don't have an entry for the instance we can assume that it was
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)


@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
Expand Down Expand Up @@ -514,16 +587,8 @@ class RoomStreamToken:

topological: Optional[int] = attr.ib(
validator=attr.validators.optional(attr.validators.instance_of(int)),
)
stream: int = attr.ib(validator=attr.validators.instance_of(int))

instance_map: "immutabledict[str, int]" = attr.ib(
factory=immutabledict,
validator=attr.validators.deep_mapping(
key_validator=attr.validators.instance_of(str),
value_validator=attr.validators.instance_of(int),
mapping_validator=attr.validators.instance_of(immutabledict),
),
kw_only=True,
default=None,
)

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -583,17 +648,7 @@ def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
if self.topological or other.topological:
raise Exception("Can't advance topological tokens")

max_stream = max(self.stream, other.stream)

instance_map = {
instance: max(
self.instance_map.get(instance, self.stream),
other.instance_map.get(instance, other.stream),
)
for instance in set(self.instance_map).union(other.instance_map)
}

return RoomStreamToken(None, max_stream, immutabledict(instance_map))
return super().copy_and_advance(other)

def as_historical_tuple(self) -> Tuple[int, int]:
"""Returns a tuple of `(topological, stream)` for historical tokens.
Expand All @@ -619,16 +674,6 @@ def get_stream_pos_for_instance(self, instance_name: str) -> int:
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)

def get_max_stream_pos(self) -> int:
"""Get the maximum stream position referenced in this token.
The corresponding "min" position is, by definition just `self.stream`.
This is used to handle tokens that have non-empty `instance_map`, and so
reference stream positions after the `self.stream` position.
"""
return max(self.instance_map.values(), default=self.stream)

async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
Expand Down Expand Up @@ -838,23 +883,28 @@ def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
return getattr(self, key.value)


StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedEventPosition:
"""Position of a newly persisted event with instance that persisted it.
This can be used to test whether the event is persisted before or after a
RoomStreamToken.
"""
class PersistedPosition:
"""Position of a newly persisted row with instance that persisted it."""

instance_name: str
stream: int

def persisted_after(self, token: RoomStreamToken) -> bool:
def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
return token.get_stream_pos_for_instance(self.instance_name) < self.stream


@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedEventPosition(PersistedPosition):
"""Position of a newly persisted event with instance that persisted it.
This can be used to test whether the event is persisted before or after a
RoomStreamToken.
"""

def to_room_stream_token(self) -> RoomStreamToken:
"""Converts the position to a room stream token such that events
persisted in the same room after this position will be after the
Expand All @@ -865,7 +915,7 @@ def to_room_stream_token(self) -> RoomStreamToken:
"""
# Doing the naive thing satisfies the desired properties described in
# the docstring.
return RoomStreamToken(None, self.stream)
return RoomStreamToken(stream=self.stream)


@attr.s(slots=True, frozen=True, auto_attribs=True)
Expand Down
8 changes: 4 additions & 4 deletions tests/handlers/test_appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_notify_interested_services(self) -> None:
[event],
]
)
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.handler.notify_interested_services(RoomStreamToken(stream=1))

self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
interested_service, events=[event]
Expand All @@ -107,7 +107,7 @@ def test_query_user_exists_unknown_user(self) -> None:
]
)
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.handler.notify_interested_services(RoomStreamToken(stream=0))

self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)

Expand All @@ -126,7 +126,7 @@ def test_query_user_exists_known_user(self) -> None:
]
)

self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.handler.notify_interested_services(RoomStreamToken(stream=0))

self.assertFalse(
self.mock_as_api.query_user.called,
Expand Down Expand Up @@ -441,7 +441,7 @@ def _notify_interested_services(self) -> None:
self.get_success(
self.hs.get_application_service_handler()._notify_interested_services(
RoomStreamToken(
None, self.hs.get_application_service_handler().current_max
stream=self.hs.get_application_service_handler().current_max
)
)
)
Expand Down

0 comments on commit 009b47b

Please sign in to comment.