Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Separate creating an event context and persisting the event in the fed handler #9800

Merged
merged 13 commits into from
Apr 14, 2021
1 change: 1 addition & 0 deletions changelog.d/9800.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.
175 changes: 131 additions & 44 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@

@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
"""Holds information about a received event, ready for passing to _auth_and_persist_events

Attributes:
event: the received event
Expand Down Expand Up @@ -808,7 +808,12 @@ async def _process_received_pdu(
logger.debug("Processing event: %s", event)

try:
await self._handle_new_event(origin, event, state=state)
context, auth_events = await self._calculate_event_context(
event, state=state
)
await self._auth_and_persist_event(
origin, event, context, auth_events=auth_events, state=state
)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

Expand Down Expand Up @@ -1011,7 +1016,9 @@ async def backfill(
)

if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)

# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
Expand All @@ -1024,10 +1031,14 @@ async def backfill(
# non-outliers
assert not event.internal_metadata.is_outlier()

context, context_auth_events = await self._calculate_event_context(event)

# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True)
await self._auth_and_persist_event(
dest, event, context, auth_events=context_auth_events, backfilled=True
)

return events

Expand Down Expand Up @@ -1361,7 +1372,7 @@ async def get_event(event_id: str):

event_infos.append(_NewEventInfo(event, None, auth))

await self._handle_new_events(
await self._auth_and_persist_events(
destination,
room_id,
event_infos,
Expand Down Expand Up @@ -1667,10 +1678,16 @@ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin

context = await self._handle_new_event(origin, event)
# Calculate the event context and persist the event.
context, auth_events = await self._calculate_event_context(
event, state=None, auth_events=None
)
context = await self._auth_and_persist_event(
origin, event, context, auth_events=auth_events, backfilled=False
)

logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
"on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1879,10 +1896,13 @@ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:

event.internal_metadata.outlier = False

await self._handle_new_event(origin, event)
context, auth_events = await self._calculate_event_context(event)
await self._auth_and_persist_event(
origin, event, context, auth_events=auth_events
)

logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
"on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
Expand Down Expand Up @@ -1990,16 +2010,45 @@ async def get_persisted_pdu(
async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)

async def _handle_new_event(
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
auth_events: MutableStateMap[EventBase],
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
"""
Process an event by performing auth checks and then persisting to the database.

Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.

NB that this function potentially modifies it.
state: The state events used to auth the event.
auth_events:
Map from (event_type, state_key) to event

Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
clokep marked this conversation as resolved.
Show resolved Hide resolved
backfilled: True if the event was backfilled.

Returns:
The event context.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
)

try:
Expand All @@ -2023,7 +2072,7 @@ async def _handle_new_event(

return context

async def _handle_new_events(
async def _auth_and_persist_events(
self,
origin: str,
room_id: str,
Expand All @@ -2041,11 +2090,17 @@ async def _handle_new_events(
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._prep_event(
origin,
res, auth_events = await self._calculate_event_context(
event,
state=ev_info.state,
auth_events=ev_info.auth_events,
)
res = await self._check_event_auth(
origin,
event,
res,
state=ev_info.state,
auth_events=auth_events,
backfilled=backfilled,
)
return res
Expand Down Expand Up @@ -2178,14 +2233,31 @@ async def _persist_auth_tree(
room_id, [(event, new_event_context)]
)

async def _prep_event(
async def _calculate_event_context(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
) -> Tuple[EventContext, MutableStateMap[EventBase]]:
"""
Calculate the context and auth events for a given event.

Args:
event: The event itself.
state: The state events to calculate the event context from.
clokep marked this conversation as resolved.
Show resolved Hide resolved
auth_events:
Map from (event_type, state_key) to event

Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.

Also NB that this function adds entries to it.

Returns:
The event context and auth events.
"""
context = await self.state_handler.compute_event_context(event, old_state=state)

if not auth_events:
clokep marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -2206,20 +2278,7 @@ async def _prep_event(
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c

context = await self.do_auth(origin, event, context, auth_events=auth_events)

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

return context
return context, auth_events

async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
Expand Down Expand Up @@ -2331,19 +2390,27 @@ async def on_get_missing_events(

return missing_events

async def do_auth(
async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]],
auth_events: MutableStateMap[EventBase],
backfilled: bool,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).

Args:
origin:
event:
origin: The host the event originates from.
event: The event itself.
context:
The event context.

NB that this function potentially modifies it.
state: The state events to calculate the event context from. This is
ignored if context is provided.
auth_events:
Map from (event_type, state_key) to event

Expand All @@ -2353,8 +2420,10 @@ async def do_auth(
server.

Also NB that this function adds entries to it.
backfilled: True if the event was backfilled.

Returns:
updated context object
The updated context object.
"""
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
Expand All @@ -2380,6 +2449,17 @@ async def do_auth(
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR

if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)

if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)

# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)

richvdh marked this conversation as resolved.
Show resolved Hide resolved
return context

async def _update_auth_events_and_context_for_auth(
Expand All @@ -2389,7 +2469,7 @@ async def _update_auth_events_and_context_for_auth(
context: EventContext,
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
"""Helper for _check_event_auth. See there for docs.

Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
Expand Down Expand Up @@ -2465,13 +2545,20 @@ async def _update_auth_events_and_context_for_auth(
(e.type, e.state_key): e
for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
} # type: MutableStateMap[EventBase]
e.internal_metadata.outlier = True

logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context, auth = await self._calculate_event_context(
e, auth_events=auth
)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
)
await self._handle_new_event(origin, e, auth_events=auth)

if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
Expand Down
6 changes: 4 additions & 2 deletions tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ def setUp(self):
)

self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
self.handler._check_event_auth = (
lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
Expand Down