From d128efcaa24c1541d6fa4f9cf761b3c036868f5e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Mar 2020 13:08:07 +0000 Subject: [PATCH 1/9] Port event federation tests to new style --- tests/storage/test_event_federation.py | 36 ++++++++++++++------------ 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a331517f4d53..abe2fd93a011 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import tests.unittest import tests.utils -class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_get_prev_events_for_room(self): room_id = "@ROOM:local" @@ -61,15 +56,14 @@ def insert_event(txn, i): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i) + self.get_success(self.store.db.runInteraction("insert", insert_event, i)) # this should get the last ten - r = yield self.store.get_prev_events_for_room(room_id) + r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) for i in range(0, 10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) - @defer.inlineCallbacks def test_get_rooms_with_many_extremities(self): room1 = "#room1" room2 = "#room2" @@ -86,25 +80,33 @@ def insert_event(txn, i, room_id): ) for i in range(0, 20): - yield self.store.db.runInteraction("insert", insert_event, i, room1) - yield self.store.db.runInteraction("insert", insert_event, i, room2) - yield self.store.db.runInteraction("insert", insert_event, i, room3) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room1) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room2) + ) + self.get_success( + self.store.db.runInteraction("insert", insert_event, i, room3) + ) # Test simple case - r = yield self.store.get_rooms_with_many_extremities(5, 5, []) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [])) self.assertEqual(len(r), 3) # Does filter work? - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1])) self.assertTrue(room2 in r) self.assertTrue(room3 in r) self.assertEqual(len(r), 2) - r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + r = self.get_success( + self.store.get_rooms_with_many_extremities(5, 5, [room1, room2]) + ) self.assertEqual(r, [room3]) # Does filter and limit work? - r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1]) + r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) From 278686c1880782cdcf09cc18a94ecd030f4b8623 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Mar 2020 16:16:02 +0000 Subject: [PATCH 2/9] Improve get auth chain difference algorithm. It was originally implemented by pulling the full auth chain of all state sets out of the database and doing set comparison. However, that can take a lot work if the state and auth chains are large. Instead, lets try and fetch the auth chains at the same time and calcualte the difference on the fly, allowing us to bail early if all the auth chains converge. Assuming that the auth chains do converge more often than not, this should improve performance. Hopefully. --- synapse/state/__init__.py | 28 +--- synapse/state/v2.py | 41 ++--- .../data_stores/main/event_federation.py | 145 +++++++++++++++++- tests/state/test_v2.py | 13 +- tests/storage/test_event_federation.py | 116 ++++++++++++++ 5 files changed, 292 insertions(+), 51 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index df7a4f6a893a..f37c43a419bc 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -662,28 +662,16 @@ def get_events(self, event_ids, allow_rejected=False): allow_rejected=allow_rejected, ) - def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]): - """Gets the full auth chain for a set of events (including rejected - events). - - Includes the given event IDs in the result. - - Note that: - 1. All events must be state events. - 2. For v1 rooms this may not have the full auth chain in the - presence of rejected events - - Args: - event_ids: The event IDs of the events to fetch the auth chain for. - Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. Returns: - Deferred[list[str]]: List of event IDs of the auth chain. + Deferred[Set[str]] """ - return self.store.get_auth_chain_ids( - event_ids, include_given=True, ignore_events=ignore_events, - ) + return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 0ffe6d8c1428..46e63ffde34d 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -227,36 +227,27 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - common = set(itervalues(state_sets[0])).intersection( - *(itervalues(s) for s in state_sets[1:]) - ) - auth_sets = [] for state_set in state_sets: - auth_ids = { - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), + auth_sets.append( + { + eid + for key, eid in iteritems(state_set) + if ( + key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) + or key + in ( + (EventTypes.PowerLevels, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + ) ) - ) - and eid not in common - } - - auth_chain = yield state_res_store.get_auth_chain(auth_ids, common) - auth_ids.update(auth_chain) - - auth_sets.append(auth_ids) + } + ) - intersection = set(auth_sets[0]).intersection(*auth_sets[1:]) - union = set().union(*auth_sets) + difference = yield state_res_store.get_auth_chain_difference(auth_sets) - return union - intersection + return difference def _seperate(state_sets): diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 49a7b8b4338d..68a95c3a4f74 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,7 @@ # limitations under the License. import itertools import logging -from typing import List, Optional, Set +from typing import Dict, List, Optional, Set from six.moves.queue import Empty, PriorityQueue @@ -103,6 +103,149 @@ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): return list(results) + def get_auth_chain_difference(self, state_sets: List[Set[str]]): + """Given sets of state events figure out the auth chain difference (as + per state res v2 algorithm). + + This equivalent to fetching the full auth chain for each set of state + and returning the events that don't appear in each and every auth + chain. + + Returns: + Deferred[Set[str]] + """ + + return self.db.runInteraction( + "get_auth_chain_difference", + self._get_auth_chain_difference_txn, + state_sets, + ) + + def _get_auth_chain_difference_txn( + self, txn, state_sets: List[Set[str]] + ) -> Set[str]: + + # Algorithm Description + # ~~~~~~~~~~~~~~~~~~~~~ + # + # The idea here is to basically walk the auth graph of each state set in + # tandem, keeping track of which auth events are reachable by each state + # set. If we reach an auth event we've already visited (via a different + # state set) then we mark that auth event and all ancestors as reachable + # by the state set. This requires that we keep track of the auth chains + # in memory. + # + # Doing it in a such a way means that we can stop early if all auth + # events we're currently walking are reachable by all state sets. + # + # *Note*: We can't stop walking an event's auth chain if it is reachable + # by all state sets. This is because other auth chains we're walking + # might be reachable only via the original auth chain. For example, + # given the following auth chain: + # + # A -> C -> D -> E + # / / + # B -´---------´ + # + # and state sets {A} and {B} then walking the auth chains of A and B + # would immediatley show that C is reachable by both. However, if we + # stopped at C then we'd only reach E via the auth chain of B and so E + # would errornously get included in the returned difference. + # + # The other thing that we do is limit the number of auth chains we walk + # at once, due to practical limits (i.e. we can only query the database + # with a limited set of parameters). We pick the auth chains we walk + # each iteration based on their depth, in the hope that events with a + # lower depth are likely reachable by those with higher depths. + # + # We could use any ordering that we believe would give a rough + # toplogical ordfering, e.g. origin server timestamp. If the ordering + # chosen is not topological then the algorithm still produces the right + # result, but perhaps a bit more inefficiently. This is why it is safe + # to use "depth" here. + + initial_events = set(state_sets[0]).union(*state_sets[1:]) + + # Dict from events in auth chains to which sets *cannot* reach them. + # I.e. if the set is empty then all sets can reach the event. + event_to_missing_sets = { + event_id: {i for i, a in enumerate(state_sets) if event_id not in a} + for event_id in initial_events + } + + # We need to get the depth of the initial events for sorting purposes. + rows = self.db.simple_select_many_txn( + txn, + table="events", + column="event_id", + iterable=initial_events, + keyvalues={}, + retcols=("event_id", "depth"), + ) + + # The sorted list of events we should walk the auth chain off. + search = sorted((row["depth"], row["event_id"]) for row in rows) + + # Map from event to its auth events + event_to_auth_events = {} # type: Dict[str, Set[str]] + + base_sql = """ + SELECT depth, a.event_id, auth_id + FROM event_auth AS a + INNER JOIN events AS e ON (e.event_id = a.auth_id) + WHERE + """ + + while search: + # Check whether all our current walks are reachable by all state + # sets. If so we can bail. + if all(not event_to_missing_sets[eid] for _, eid in search): + break + + # Fetch the auth events and their depths of the N last events we're + # currently walking + search, chunk = search[:-100], search[-100:] + clause, args = make_in_list_sql_clause( + txn.database_engine, "a.event_id", [e_id for _, e_id in chunk] + ) + txn.execute(base_sql + clause, args) + + for depth, event_id, auth_id in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_id) + + if auth_id not in event_to_missing_sets: + # First time we're seeing this event, so we add it to the + # queue of things to fetch. + search.append((depth, auth_id)) + else: + # We've previously seen this event, so look up its auth + # events and recursively mark all ancestors as reachable + # by the current events state set. + a_ids = event_to_auth_events.get(auth_id) + while a_ids: + new_aids = set() + for a_id in a_ids: + event_to_missing_sets[a_id].intersection_update( + event_to_missing_sets[event_id] + ) + + b = event_to_auth_events.get(a_id) + if b: + new_aids.update(b) + + a_ids = new_aids + + # Mark that the auth event is reachable by the approriate sets. + sets = event_to_missing_sets.setdefault( + auth_id, set(range(len(state_sets))) + ) + sets.intersection_update(event_to_missing_sets[event_id]) + + search.sort() + + # Return all events where not all sets can reach them. + return {eid for eid, n in event_to_missing_sets.items() if n} + def get_oldest_events_in_room(self, room_id): return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5059ade8503f..a44960203e06 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ def get_events(self, event_ids, allow_rejected=False): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids, ignore_events): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -617,9 +617,6 @@ def get_auth_chain(self, event_ids, ignore_events): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -629,7 +626,7 @@ def get_auth_chain(self, event_ids, ignore_events): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result or event_id in ignore_events: + if event_id in result: continue result.add(event_id) @@ -639,3 +636,9 @@ def get_auth_chain(self, event_ids, ignore_events): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index abe2fd93a011..7314f84fedf4 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -110,3 +110,119 @@ def insert_event(txn, i, room_id): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) + + def test_auth_difference(self): + room_id = "@ROOM:local" + + # The silly auth graph we use to test th eauth difference algorithm, + # where the top are the most recent events. + # + # A B + # \ / + # D E + # \ | + # ` F C + # | /| + # G ´ | + # | \ | + # H I + # | | + # K J + + auth_graph = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], + } + + depth_map = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, + } + + # We rudely fiddle with the appropriate tables directly, as that's much + # easier than constructing events properly. + + def insert_event(txn, event_id): + + depth = depth_map[event_id] + + self.store.db.simple_insert_txn( + txn, + table="events", + values={ + "event_id": event_id, + "room_id": room_id, + "depth": depth, + "topological_ordering": depth, + "type": "m.test", + "processed": True, + "outlier": False, + }, + ) + + self.store.db.simple_insert_many_txn( + txn, + table="event_auth", + values=[ + {"event_id": event_id, "room_id": room_id, "auth_id": a} + for a in auth_graph[event_id] + ], + ) + + for event_id in auth_graph: + self.get_success( + self.store.db.runInteraction("insert", insert_event, event_id) + ) + + # Now actually test that various combinations give the right result: + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "d", "e"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + ) + self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) + + difference = self.get_success( + self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + ) + self.assertSetEqual(difference, {"a", "b"}) + + difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + self.assertSetEqual(difference, set()) From 91a6f78be321738728e87dbf19cc40228bbcb58b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Mar 2020 16:21:18 +0000 Subject: [PATCH 3/9] Newsfile --- changelog.d/7095.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/7095.misc diff --git a/changelog.d/7095.misc b/changelog.d/7095.misc new file mode 100644 index 000000000000..44fc9f616f08 --- /dev/null +++ b/changelog.d/7095.misc @@ -0,0 +1 @@ +Attempt to improve performance of state res v2 algorithm. From 0ebf292c8d59b832da43b9a5bab25db185a9acbf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Mar 2020 16:39:07 +0000 Subject: [PATCH 4/9] Fix unit tests --- tests/storage/test_event_federation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 7314f84fedf4..cc39e1b1a766 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -175,6 +175,7 @@ def insert_event(txn, event_id): "type": "m.test", "processed": True, "outlier": False, + "stream_ordering": depth, }, ) From bd2e18ad2519b99f1b9c0abb6a9634b53c10c86b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Mar 2020 16:53:36 +0000 Subject: [PATCH 5/9] Really fix unit tests --- tests/storage/test_event_federation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index cc39e1b1a766..6d638572cbba 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -160,7 +160,7 @@ def test_auth_difference(self): # We rudely fiddle with the appropriate tables directly, as that's much # easier than constructing events properly. - def insert_event(txn, event_id): + def insert_event(txn, event_id, stream_ordering): depth = depth_map[event_id] @@ -175,7 +175,7 @@ def insert_event(txn, event_id): "type": "m.test", "processed": True, "outlier": False, - "stream_ordering": depth, + "stream_ordering": stream_ordering, }, ) @@ -188,9 +188,13 @@ def insert_event(txn, event_id): ], ) + next_stream_ordering = 0 for event_id in auth_graph: + next_stream_ordering += 1 self.get_success( - self.store.db.runInteraction("insert", insert_event, event_id) + self.store.db.runInteraction( + "insert", insert_event, event_id, next_stream_ordering + ) ) # Now actually test that various combinations give the right result: From d373279c965addb92ae52f235f3cedfb9e243c08 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 09:46:17 +0000 Subject: [PATCH 6/9] Fix typos Co-Authored-By: Matthew Hodgson --- synapse/storage/data_stores/main/event_federation.py | 4 ++-- tests/storage/test_event_federation.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 68a95c3a4f74..0813b619eb9c 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -148,7 +148,7 @@ def _get_auth_chain_difference_txn( # B -´---------´ # # and state sets {A} and {B} then walking the auth chains of A and B - # would immediatley show that C is reachable by both. However, if we + # would immediately show that C is reachable by both. However, if we # stopped at C then we'd only reach E via the auth chain of B and so E # would errornously get included in the returned difference. # @@ -159,7 +159,7 @@ def _get_auth_chain_difference_txn( # lower depth are likely reachable by those with higher depths. # # We could use any ordering that we believe would give a rough - # toplogical ordfering, e.g. origin server timestamp. If the ordering + # topological ordering, e.g. origin server timestamp. If the ordering # chosen is not topological then the algorithm still produces the right # result, but perhaps a bit more inefficiently. This is why it is safe # to use "depth" here. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 6d638572cbba..3aeec0dc0f52 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -114,7 +114,7 @@ def insert_event(txn, i, room_id): def test_auth_difference(self): room_id = "@ROOM:local" - # The silly auth graph we use to test th eauth difference algorithm, + # The silly auth graph we use to test the auth difference algorithm, # where the top are the most recent events. # # A B From fa04467bbeba81616d4cb796252c89a7ef97f3d3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 13:46:03 +0000 Subject: [PATCH 7/9] Review comments --- synapse/state/__init__.py | 2 +- synapse/state/v2.py | 21 ++------- .../data_stores/main/event_federation.py | 43 +++++++++++-------- 3 files changed, 28 insertions(+), 38 deletions(-) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index f37c43a419bc..4afefc6b1d2e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -671,7 +671,7 @@ def get_auth_chain_difference(self, state_sets: List[Set[str]]): chain. Returns: - Deferred[Set[str]] + Deferred[Set[str]]: Set of event IDs. """ return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 46e63ffde34d..18484e2fa6f9 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -227,25 +227,10 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Returns: Deferred[set[str]]: Set of event IDs """ - auth_sets = [] - for state_set in state_sets: - auth_sets.append( - { - eid - for key, eid in iteritems(state_set) - if ( - key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite) - or key - in ( - (EventTypes.PowerLevels, ""), - (EventTypes.Create, ""), - (EventTypes.JoinRules, ""), - ) - ) - } - ) - difference = yield state_res_store.get_auth_chain_difference(auth_sets) + difference = yield state_res_store.get_auth_chain_difference( + [set(state_set.values()) for state_set in state_sets] + ) return difference diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 0813b619eb9c..87ae5b22730e 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -174,23 +174,24 @@ def _get_auth_chain_difference_txn( } # We need to get the depth of the initial events for sorting purposes. - rows = self.db.simple_select_many_txn( - txn, - table="events", - column="event_id", - iterable=initial_events, - keyvalues={}, - retcols=("event_id", "depth"), + sql = """ + SELECT depth, event_id FROM events + WHERE %s + ORDER BY depth ASC + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", initial_events ) + txn.execute(sql % (clause,), args) - # The sorted list of events we should walk the auth chain off. - search = sorted((row["depth"], row["event_id"]) for row in rows) + # The sorted list of events whose auth chains we should walk. + search = txn.fetchall() # Map from event to its auth events event_to_auth_events = {} # type: Dict[str, Set[str]] base_sql = """ - SELECT depth, a.event_id, auth_id + SELECT a.event_id, auth_id, depth FROM event_auth AS a INNER JOIN events AS e ON (e.event_id = a.auth_id) WHERE @@ -210,18 +211,25 @@ def _get_auth_chain_difference_txn( ) txn.execute(base_sql + clause, args) - for depth, event_id, auth_id in txn: - event_to_auth_events.setdefault(event_id, set()).add(auth_id) + for event_id, auth_event_id, auth_event_depth in txn: + event_to_auth_events.setdefault(event_id, set()).add(auth_event_id) - if auth_id not in event_to_missing_sets: + sets = event_to_missing_sets.get(auth_event_id) + if sets is None: # First time we're seeing this event, so we add it to the # queue of things to fetch. - search.append((depth, auth_id)) + search.append((auth_event_depth, auth_event_id)) + + # Assume that this event is unreachable from any of the + # state sets until proven otherwise + sets = event_to_missing_sets.setdefault( + auth_event_id, set(range(len(state_sets))) + ) else: # We've previously seen this event, so look up its auth # events and recursively mark all ancestors as reachable - # by the current events state set. - a_ids = event_to_auth_events.get(auth_id) + # by the current event's state set. + a_ids = event_to_auth_events.get(auth_event_id) while a_ids: new_aids = set() for a_id in a_ids: @@ -236,9 +244,6 @@ def _get_auth_chain_difference_txn( a_ids = new_aids # Mark that the auth event is reachable by the approriate sets. - sets = event_to_missing_sets.setdefault( - auth_id, set(range(len(state_sets))) - ) sets.intersection_update(event_to_missing_sets[event_id]) search.sort() From f47a4177e023c1e3f3d11bf9afd31aea4866f035 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 15:14:01 +0000 Subject: [PATCH 8/9] Apply suggestions from code review Co-Authored-By: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> --- synapse/storage/data_stores/main/event_federation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 87ae5b22730e..52f3d9da9a40 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -185,7 +185,7 @@ def _get_auth_chain_difference_txn( txn.execute(sql % (clause,), args) # The sorted list of events whose auth chains we should walk. - search = txn.fetchall() + search = txn.fetchall() # type: List[Tuple[int, str]] # Map from event to its auth events event_to_auth_events = {} # type: Dict[str, Set[str]] @@ -222,9 +222,7 @@ def _get_auth_chain_difference_txn( # Assume that this event is unreachable from any of the # state sets until proven otherwise - sets = event_to_missing_sets.setdefault( - auth_event_id, set(range(len(state_sets))) - ) + sets = event_to_missing_sets[auth_event_id] = set(range(len(state_sets))) else: # We've previously seen this event, so look up its auth # events and recursively mark all ancestors as reachable From b20946666bc0f85913550880441492b23c485001 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Mar 2020 15:18:17 +0000 Subject: [PATCH 9/9] pep8 --- synapse/storage/data_stores/main/event_federation.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 52f3d9da9a40..62d4e9f59977 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -14,7 +14,7 @@ # limitations under the License. import itertools import logging -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple from six.moves.queue import Empty, PriorityQueue @@ -185,7 +185,7 @@ def _get_auth_chain_difference_txn( txn.execute(sql % (clause,), args) # The sorted list of events whose auth chains we should walk. - search = txn.fetchall() # type: List[Tuple[int, str]] + search = txn.fetchall() # type: List[Tuple[int, str]] # Map from event to its auth events event_to_auth_events = {} # type: Dict[str, Set[str]] @@ -222,7 +222,9 @@ def _get_auth_chain_difference_txn( # Assume that this event is unreachable from any of the # state sets until proven otherwise - sets = event_to_missing_sets[auth_event_id] = set(range(len(state_sets))) + sets = event_to_missing_sets[auth_event_id] = set( + range(len(state_sets)) + ) else: # We've previously seen this event, so look up its auth # events and recursively mark all ancestors as reachable