Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up state res in rare case we don't have all events #16116

Merged
merged 4 commits into from Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16116.bugfix
@@ -0,0 +1 @@
Fix performance of state resolutions for large, old rooms that did not have the full auth chain persisted.
184 changes: 161 additions & 23 deletions synapse/storage/databases/main/event_federation.py
Expand Up @@ -452,33 +452,56 @@ def _get_auth_chain_difference_using_cover_index_txn(
# sets.
seen_chains: Set[int] = set()

sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
# Fetch the chain cover index for the initial set of events we're
# considering.
def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(events_to_fetch, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)

for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
for event_id, chain_id, sequence_number in txn:
chain_info[event_id] = (chain_id, sequence_number)
seen_chains.add(chain_id)
chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id

fetch_chain_info(initial_events)

# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)

# The result set to return, i.e. the auth chain difference.
result: Set[str] = set()

if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
# For some reason we have events we haven't calculated the chain
# index for, so we need to handle those separately. This should only
# happen for older rooms where the server doesn't have all the auth
# events.
result = self._fixup_auth_chain_difference_sets(
txn,
room_id,
events_missing_chain_info,
state_sets=state_sets,
events_missing_chain_info=events_missing_chain_info,
events_that_have_chain_index=chain_info,
)
raise _NoChainCoverIndex(room_id)

# We now need to refetch any events that we have added to the state
# sets.
new_events_to_fetch = {
event_id
for state_set in state_sets
for event_id in state_set
if event_id not in initial_events
}

fetch_chain_info(new_events_to_fetch)

# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
Expand All @@ -487,8 +510,8 @@ def _get_auth_chain_difference_using_cover_index_txn(
chains: Dict[int, int] = {}
set_to_chain.append(chains)

for event_id in state_set:
chain_id, seq_no = chain_info[event_id]
for state_id in state_set:
chain_id, seq_no = chain_info[state_id]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just to make mypy happy, as the changes above caused the type of event_id to change :(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grr at how we don't get Rust's shadowing rules in Python


chains[chain_id] = max(seq_no, chains.get(chain_id, 0))

Expand Down Expand Up @@ -532,7 +555,6 @@ def _get_auth_chain_difference_using_cover_index_txn(
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
result = set()

# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
Expand Down Expand Up @@ -588,6 +610,122 @@ def _get_auth_chain_difference_using_cover_index_txn(

return result

def _fixup_auth_chain_difference_sets(
self,
txn: LoggingTransaction,
room_id: str,
state_sets: List[Set[str]],
events_missing_chain_info: Set[str],
events_that_have_chain_index: Collection[str],
) -> Set[str]:
"""Helper for `_get_auth_chain_difference_using_cover_index_txn` to
handle the case where we haven't calculated the chain cover index for
all events.

This modifies `state_sets` so that they only include events that have a
chain cover index, and returns a set of event IDs that are part of the
auth difference.
Comment on lines +625 to +627
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, here's the hidden mut keyword :p

Took me a moment to notice that modifying state_sets is part of the result but not sure I have a better suggestion

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's icky, but here we are. I mostly wanted to avoid needlessly copying the state sets when it's easier to just mutate them

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, reasonable. Half tempted to suggest a _mut suffix convention for this type of thing but probably gets overbearing soon enough. Meh, it will do.

"""

# This works similarly to the handling of unpersisted events in
# `synapse.state.v2_get_auth_chain_difference`. We use the observation
# that if you can split the set of events into two classes X and Y,
# where no events in Y have events in X in their auth chain, then we can
# calculate the auth difference by considering X and Y separately.
#
# We do this in three steps:
# 1. Compute the set of events without chain cover index belonging to
# the auth difference.
# 2. Replace the un-indexed events in the state_sets with their auth
# events, recursively, until the state_sets contain only indexed
# events. We can then calculate the auth difference of those state
# sets using the chain cover index.
# 3. Add the results of 1 and 2 together.

# By construction we know that all events that we haven't persisted the
# chain cover index for are contained in
# `event_auth_chain_to_calculate`, so we pull out the events from those
# rather than doing recursive queries to walk the auth chain.
#
# We pull out those events with their auth events, which gives us enough
# information to construct the auth chain of an event up to auth events
# that have the chain cover index.
sql = """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if I'm following along properly, this pulls out all (event_id, authing_event_id) pairs with event_id being in the set of chain-cover-uncalculated events for this room,

then annotates these (event_id, authing_event_id) pairs with whether the auth event has a chain cover...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then annotates these (event_id, authing_event_id) pairs with whether the auth event has a chain cover...

I'm not sure what you mean by "annotate" here? We look at all the events we've pulled out to see if they are indexed and mark them as such?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'annotate' as in labels that keypair with a bool value in a map ;-)

SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
FROM event_auth_chain_to_calculate AS tc
LEFT JOIN event_auth AS ea USING (event_id)
LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
WHERE tc.room_id = ?
"""
txn.execute(sql, (room_id,))
event_to_auth_ids: Dict[str, Set[str]] = {}
events_that_have_chain_index = set(events_that_have_chain_index)
for event_id, auth_id, auth_id_has_chain in txn:
s = event_to_auth_ids.setdefault(event_id, set())
if auth_id is not None:
s.add(auth_id)
if auth_id_has_chain:
events_that_have_chain_index.add(auth_id)

if events_missing_chain_info - event_to_auth_ids.keys():
# Uh oh, we somehow haven't correctly done the chain cover index,
# bail and fall back to the old method.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info - event_to_auth_ids.keys(),
)
raise _NoChainCoverIndex(room_id)

# Create a map from event IDs we care about to their partial auth chain.
event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
for event_id, auth_ids in event_to_auth_ids.items():
if not any(event_id in state_set for state_set in state_sets):
continue

processing = set(auth_ids)
to_add = set()
while processing:
auth_id = processing.pop()
to_add.add(auth_id)

sub_auth_ids = event_to_auth_ids.get(auth_id)
if sub_auth_ids is None:
continue

processing.update(sub_auth_ids - to_add)

event_id_to_partial_auth_chain[event_id] = to_add
Comment on lines +686 to +698
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so event_id_to_partial_auth_chain[event_id] is the transitive auth-event closure of event_id?

We say 'partial auth chain' here... what's the partial in respect to?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't pull out the full auth chain, we stop when the auth events are indexed.


# Now we do two things:
# 1. Update the state sets to only include indexed events; and
# 2. Create a new list containing the auth chains of the un-indexed
# events
unindexed_state_sets: List[Set[str]] = []
for state_set in state_sets:
unindexed_state_set = set()
for event_id, auth_chain in event_id_to_partial_auth_chain.items():
if event_id not in state_set:
continue

unindexed_state_set.add(event_id)

state_set.discard(event_id)
state_set.difference_update(auth_chain)
for auth_id in auth_chain:
if auth_id in events_that_have_chain_index:
state_set.add(auth_id)
else:
unindexed_state_set.add(auth_id)

unindexed_state_sets.append(unindexed_state_set)

# Calculate and return the auth difference of the un-indexed events.
union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])

return union - intersection

def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
Expand Down