Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor EventContext #12689

Merged
merged 9 commits into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12689.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor `EventContext` class.
177 changes: 32 additions & 145 deletions synapse/events/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@
from frozendict import frozendict
from typing_extensions import Literal

from twisted.internet.defer import Deferred

from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict, StateMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,6 +58,9 @@ class EventContext:
If ``state_group`` is None (ie, the event is an outlier),
``state_group_before_event`` will always also be ``None``.

state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
then this is the delta of the state between the two groups.

prev_group: If it is known, ``state_group``'s prev_group. Note that this being
None does not necessarily mean that ``state_group`` does not have
a prev_group!
Expand All @@ -79,73 +79,47 @@ class EventContext:
app_service: If this event is being sent by a (local) application service, that
app service.

_current_state_ids: The room state map, including this event - ie, the state
in ``state_group``.

(type, state_key) -> event_id

For an outlier, this is {}

Note that this is a private attribute: it should be accessed via
``get_current_state_ids``. _AsyncEventContext impl calculates this
on-demand: it will be None until that happens.

_prev_state_ids: The room state map, excluding this event - ie, the state
in ``state_group_before_event``. For a non-state
event, this will be the same as _current_state_events.

Note that it is a completely different thing to prev_group!

(type, state_key) -> event_id

For an outlier, this is {}

As with _current_state_ids, this is a private attribute. It should be
accessed via get_prev_state_ids.

partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""

_storage: "Storage"
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None
prev_group: Optional[int] = None
delta_ids: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None

_current_state_ids: Optional[StateMap[str]] = None
_prev_state_ids: Optional[StateMap[str]] = None

partial_state: bool = False

@staticmethod
def with_state(
storage: "Storage",
state_group: Optional[int],
state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]],
state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool,
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
return EventContext(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
storage=storage,
state_group=state_group,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=state_delta_due_to_event,
prev_group=prev_group,
delta_ids=delta_ids,
partial_state=partial_state,
)

@staticmethod
def for_outlier() -> "EventContext":
def for_outlier(
storage: "Storage",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(
current_state_ids={},
prev_state_ids={},
)
return EventContext(storage=storage)

async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
Expand All @@ -158,24 +132,14 @@ async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
The serialized event.
"""

# We don't serialize the full state dicts, instead they get pulled out
# of the DB on the other side. However, the other side can't figure out
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None

return {
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.get_state_key(),
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
"state_delta_due_to_event": _encode_state_dict(
self._state_delta_due_to_event
),
"delta_ids": _encode_state_dict(self.delta_ids),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
Expand All @@ -193,16 +157,16 @@ def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
Returns:
The event context.
"""
context = _AsyncEventContextImpl(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So happy to see this logic disappear. 🔥

context = EventContext(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
storage=storage,
prev_state_id=input["prev_state_id"],
event_type=input["event_type"],
event_state_key=input["event_state_key"],
state_group=input["state_group"],
state_group_before_event=input["state_group_before_event"],
prev_group=input["prev_group"],
state_delta_due_to_event=_decode_state_dict(
input["state_delta_due_to_event"]
),
delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
Expand Down Expand Up @@ -250,8 +214,15 @@ async def get_current_state_ids(self) -> Optional[StateMap[str]]:
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

await self._ensure_fetched()
return self._current_state_ids
assert self._state_delta_due_to_event is not None

prev_state_ids = await self.get_prev_state_ids()

if self._state_delta_due_to_event:
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update(self._state_delta_due_to_event)

return prev_state_ids

async def get_prev_state_ids(self) -> StateMap[str]:
"""
Expand All @@ -266,94 +237,10 @@ async def get_prev_state_ids(self) -> StateMap[str]:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
await self._ensure_fetched()
# There *should* be previous state IDs now.
assert self._prev_state_ids is not None
return self._prev_state_ids

def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
"""Gets the current state IDs if we have them already cached.

It is an error to access this for a rejected event, since rejected state should
not make it into the room state. This method will raise an exception if
``rejected`` is set.

Returns:
Returns None if we haven't cached the state or if state_group is None
(which happens when the associated event is an outlier).

Otherwise, returns the the current state IDs.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")

return self._current_state_ids

async def _ensure_fetched(self) -> None:
return None


@attr.s(slots=True)
class _AsyncEventContextImpl(EventContext):
"""
An implementation of EventContext which fetches _current_state_ids and
_prev_state_ids from the database on demand.

Attributes:

_storage

_fetching_state_deferred: Resolves when *_state_ids have been calculated.
None if we haven't started calculating yet

_event_type: The type of the event the context is associated with.

_event_state_key: The state_key of the event the context is associated with.

_prev_state_id: If the event associated with the context is a state event,
then `_prev_state_id` is the event_id of the state that was replaced.
"""

# This needs to have a default as we're inheriting
_storage: "Storage" = attr.ib(default=None)
_prev_state_id: Optional[str] = attr.ib(default=None)
_event_type: str = attr.ib(default=None)
_event_state_key: Optional[str] = attr.ib(default=None)
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)

async def _ensure_fetched(self) -> None:
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)

await make_deferred_yieldable(self._fetching_state_deferred)

async def _fill_out_state(self) -> None:
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
# No state group means the event is an outlier. Usually the state_ids dicts are also
# pre-set to empty dicts, but they get reset when the context is serialized, so set
# them to empty dicts again here.
self._current_state_ids = {}
self._prev_state_ids = {}
return

current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event
)
# Set this separately so mypy knows current_state_ids is not None.
self._current_state_ids = current_state_ids
if self._event_state_key is not None:
self._prev_state_ids = dict(current_state_ids)

key = (self._event_type, self._event_state_key)
if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = current_state_ids


def _encode_state_dict(
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ async def do_knock(
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -848,7 +848,7 @@ async def on_invite_request(
)
)

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down Expand Up @@ -877,7 +877,7 @@ async def do_remotely_reject_invite(

await self.federation_client.send_leave(host_list, event)

context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True

context = EventContext.for_outlier()
context = EventContext.for_outlier(self._storage)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
Expand Down Expand Up @@ -1874,10 +1874,10 @@ async def _update_context_for_auth_events(
)

return EventContext.with_state(
storage=self._storage,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
state_delta_due_to_event=state_updates,
prev_group=prev_group,
delta_ids=state_updates,
partial_state=context.partial_state,
Expand Down
6 changes: 5 additions & 1 deletion synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ async def deduplicate_state_event(
The previous version of the event is returned, if it is found in the
event context. Otherwise, None is returned.
"""
if event.internal_metadata.is_outlier():
# This can happen due to out of band memberships
return None

prev_state_ids = await context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
Expand Down Expand Up @@ -1001,7 +1005,7 @@ async def create_new_client_event(
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
context = EventContext.for_outlier()
context = EventContext.for_outlier(self.storage)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
Expand Down
4 changes: 4 additions & 0 deletions synapse/push/action_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,9 @@ def __init__(self, hs: "HomeServer"):
async def handle_push_actions_for_event(
self, event: EventBase, context: EventContext
) -> None:
if event.internal_metadata.is_outlier():
# This can happen due to out of band memberships
return

with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
9 changes: 5 additions & 4 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(self, hs: "HomeServer"):
self.state_store = hs.get_storage().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage()

@overload
async def get_current_state(
Expand Down Expand Up @@ -361,10 +362,10 @@ async def compute_event_context(

if not event.is_state():
return EventContext.with_state(
storage=self._storage,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
current_state_ids=state_ids_before_event,
prev_state_ids=state_ids_before_event,
state_delta_due_to_event={},
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=partial_state,
Expand Down Expand Up @@ -393,10 +394,10 @@ async def compute_event_context(
)

return EventContext.with_state(
storage=self._storage,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
current_state_ids=state_ids_after_event,
prev_state_ids=state_ids_before_event,
state_delta_due_to_event=delta_ids,
prev_group=state_group_before_event,
delta_ids=delta_ids,
partial_state=partial_state,
Expand Down
Loading