Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve get auth chain difference algorithm. #7095

Merged
merged 9 commits into from Mar 18, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Expand Up @@ -671,7 +671,7 @@ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
chain.

Returns:
Deferred[Set[str]]
Deferred[Set[str]]: Set of event IDs.
"""

return self.store.get_auth_chain_difference(state_sets)
21 changes: 3 additions & 18 deletions synapse/state/v2.py
Expand Up @@ -227,25 +227,10 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
auth_sets = []
for state_set in state_sets:
auth_sets.append(
{
eid
for key, eid in iteritems(state_set)
if (
key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
or key
in (
(EventTypes.PowerLevels, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
)
)
}
)

difference = yield state_res_store.get_auth_chain_difference(auth_sets)
difference = yield state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)

return difference

Expand Down
43 changes: 24 additions & 19 deletions synapse/storage/data_stores/main/event_federation.py
Expand Up @@ -174,23 +174,24 @@ def _get_auth_chain_difference_txn(
}

# We need to get the depth of the initial events for sorting purposes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we trust this value of depth?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope! Though hopefully, as explained above, we're using it as a hint, so in the worst case, if it's wrong, we basically pull the full auth chains out of the DB, as we're currently doing.

rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=initial_events,
keyvalues={},
retcols=("event_id", "depth"),
sql = """
SELECT depth, event_id FROM events
WHERE %s
ORDER BY depth ASC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", initial_events
)
txn.execute(sql % (clause,), args)

# The sorted list of events we should walk the auth chain off.
search = sorted((row["depth"], row["event_id"]) for row in rows)
# The sorted list of events whose auth chains we should walk.
search = txn.fetchall()
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]

base_sql = """
SELECT depth, a.event_id, auth_id
SELECT a.event_id, auth_id, depth
FROM event_auth AS a
INNER JOIN events AS e ON (e.event_id = a.auth_id)
WHERE
Expand All @@ -210,18 +211,25 @@ def _get_auth_chain_difference_txn(
)
txn.execute(base_sql + clause, args)

for depth, event_id, auth_id in txn:
event_to_auth_events.setdefault(event_id, set()).add(auth_id)
for event_id, auth_event_id, auth_event_depth in txn:
event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)

if auth_id not in event_to_missing_sets:
sets = event_to_missing_sets.get(auth_event_id)
if sets is None:
# First time we're seeing this event, so we add it to the
# queue of things to fetch.
search.append((depth, auth_id))
search.append((auth_event_depth, auth_event_id))

# Assume that this event is unreachable from any of the
# state sets until proven otherwise
sets = event_to_missing_sets.setdefault(
auth_event_id, set(range(len(state_sets)))
)
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
else:
# We've previously seen this event, so look up its auth
# events and recursively mark all ancestors as reachable
# by the current events state set.
a_ids = event_to_auth_events.get(auth_id)
# by the current event's state set.
a_ids = event_to_auth_events.get(auth_event_id)
while a_ids:
new_aids = set()
for a_id in a_ids:
Expand All @@ -236,9 +244,6 @@ def _get_auth_chain_difference_txn(
a_ids = new_aids

# Mark that the auth event is reachable by the appropriate sets.
sets = event_to_missing_sets.setdefault(
auth_id, set(range(len(state_sets)))
)
sets.intersection_update(event_to_missing_sets[event_id])

search.sort()
Expand Down