From a47ee14ed5712983b6eee059e79c8b75c70b629f Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Thu, 3 Nov 2022 12:23:50 -0700 Subject: [PATCH 01/14] add a function to store state groups for batched events/contexts --- synapse/storage/databases/state/store.py | 112 ++++++++++++++++++++++- 1 file changed, 107 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 1a7232b27665..454f8f7f77af 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -18,6 +18,8 @@ import attr from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -404,6 +406,111 @@ def _insert_into_cache( fetched_keys=non_member_types, ) + async def store_state_deltas_for_batched( + self, + events_and_context: List[Tuple[EventBase, EventContext]], + room_id: str, + prev_group: int, + ) -> List[int]: + """Generate and store state deltas for a group of events and contexts created to be + batch persisted. + + Args: + events_and_context: the events to generate and store a state groups for + and their associated contexts + room_id: the id of the room the events were created for + prev_group: the state group of the last event persisted before the batched events + were created + """ + + def insert_deltas_group_txn( + txn: LoggingTransaction, + events_and_context: List[Tuple[EventBase, EventContext]], + prev_group: int, + ) -> List[int]: + """Generate and store state groups for the provided events and contexts. + + Requires that we have the state as a delta from the last persisted state group. 
+ + Returns: + A list of state groups + """ + is_in_db = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + num_state_groups = len(events_and_context) + + state_groups = self._state_group_seq_gen.get_next_mult_txn( + txn, num_state_groups + ) + + index = 0 + for event, context in events_and_context: + context._state_group = state_groups[index] + # The first prev_group will be the last persisted state group, which is passed in + # else it will be the group most recently assigned + if index > 0: + context.prev_group = state_groups[index - 1] + context.state_group_before_event = state_groups[index - 1] + else: + context.prev_group = prev_group + context.state_group_before_event = prev_group + context.delta_ids = {(event.type, event.state_key): event.event_id} + context._state_delta_due_to_event = { + (event.type, event.state_key): event.event_id + } + index += 1 + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups", + keys=("id", "room_id", "event_id"), + values=[ + (context._state_group, room_id, event.event_id) + for event, context in events_and_context + ], + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_group_edges", + keys=("state_group", "prev_state_group"), + values=[ + (context._state_group, context.prev_group) + for _, context in events_and_context + ], + ) + + for _, context in events_and_context: + assert context.delta_ids is not None + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=[ + (context._state_group, room_id, key[0], key[1], state_id) + for key, state_id in context.delta_ids.items() + ], + ) + return state_groups + + return await self.db_pool.runInteraction( + 
"store_state_deltas_for_batched.insert_deltas_group", + insert_deltas_group_txn, + events_and_context, + prev_group, + ) + async def store_state_group( self, event_id: str, @@ -413,10 +520,8 @@ async def store_state_group( current_state_ids: Optional[StateMap[str]], ) -> int: """Store a new set of state, returning a newly assigned state group. - At least one of `current_state_ids` and `prev_group` must be provided. Whenever `prev_group` is not None, `delta_ids` must also not be None. - Args: event_id: The event ID for which the state was calculated room_id @@ -426,7 +531,6 @@ async def store_state_group( `current_state_ids`. current_state_ids: The state to store. Map of (type, state_key) to event_id. - Returns: The state group ID """ @@ -441,9 +545,7 @@ def insert_delta_group_txn( txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str] ) -> Optional[int]: """Try and persist the new group as a delta. - Requires that we have the state as a delta from a previous state group. - Returns: The state group if successfully created, or None if the state needs to be persisted as a full state. From 8f4282cdc37f4a72b296fec442929f6ec310eafd Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Wed, 25 Jan 2023 19:04:06 -0800 Subject: [PATCH 02/14] add a function to batch persist unpersisted event contexts --- synapse/events/snapshot.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index e0d82ad81cf9..68e68b98262e 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers + from synapse.storage.databases import StateGroupDataStore from synapse.storage.databases.main import DataStore from synapse.types.state import StateFilter @@ -348,6 +349,43 @@ class UnpersistedEventContext(UnpersistedEventContextBase): partial_state: bool state_map_before_event: Optional[StateMap[str]] = None + @classmethod + async def batch_persist_unpersisted_contexts( + cls, + events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]], + room_id: str, + last_known_state_group: int, + datastore: "StateGroupDataStore", + ) -> List[Tuple[EventBase, EventContext]]: + """ + Takes a list of events and their associated unpersisted contexts and persists + the unpersisted contexts, returning a list of events and persisted contexts. 
+ + Args: + events_and_context: A list of events and their unpersisted contexts + room_id: the room_id for the events + last_known_state_group: the last persisted state group + datastore: a state datastore + """ + amended_events_and_context = await datastore.store_state_deltas_for_batched( + events_and_context, room_id, last_known_state_group + ) + + events_and_persisted_context = [] + for event, unpersisted_context in amended_events_and_context: + assert unpersisted_context.partial_state is not None + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_after_event, + delta_ids=unpersisted_context.delta_ids_to_state_group_after_event, + ) + events_and_persisted_context.append((event, context)) + return events_and_persisted_context + async def get_prev_state_ids( self, state_filter: Optional["StateFilter"] = None ) -> StateMap[str]: From b9bef4223a76e99a7aa4cecca235863d4110059b Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Wed, 25 Jan 2023 19:13:45 -0800 Subject: [PATCH 03/14] refactor function that batch persists state groups to use unpersisted event contexts --- synapse/storage/databases/state/store.py | 32 ++++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 454f8f7f77af..baa79da036f0 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -19,7 +19,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -408,10 +408,10 @@ def _insert_into_cache( async def store_state_deltas_for_batched( self, - events_and_context: List[Tuple[EventBase, EventContext]], + events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]], room_id: str, prev_group: int, - ) -> List[int]: + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: """Generate and store state deltas for a group of events and contexts created to be batch persisted. @@ -425,9 +425,9 @@ async def store_state_deltas_for_batched( def insert_deltas_group_txn( txn: LoggingTransaction, - events_and_context: List[Tuple[EventBase, EventContext]], + events_and_context: List[Tuple[EventBase, UnpersistedEventContext]], prev_group: int, - ) -> List[int]: + ) -> List[Tuple[EventBase, UnpersistedEventContext]]: """Generate and store state groups for the provided events and contexts. Requires that we have the state as a delta from the last persisted state group. 
@@ -456,17 +456,17 @@ def insert_deltas_group_txn( index = 0 for event, context in events_and_context: - context._state_group = state_groups[index] + context.state_group_after_event = state_groups[index] # The first prev_group will be the last persisted state group, which is passed in # else it will be the group most recently assigned if index > 0: - context.prev_group = state_groups[index - 1] + context.prev_group_for_state_group_after_event = state_groups[index - 1] context.state_group_before_event = state_groups[index - 1] else: - context.prev_group = prev_group + context.prev_group_for_state_group_after_event = prev_group context.state_group_before_event = prev_group - context.delta_ids = {(event.type, event.state_key): event.event_id} - context._state_delta_due_to_event = { + context.delta_ids_to_state_group_after_event = {(event.type, event.state_key): event.event_id} + context.state_delta_due_to_event = { (event.type, event.state_key): event.event_id } index += 1 @@ -476,7 +476,7 @@ def insert_deltas_group_txn( table="state_groups", keys=("id", "room_id", "event_id"), values=[ - (context._state_group, room_id, event.event_id) + (context.state_group_after_event, room_id, event.event_id) for event, context in events_and_context ], ) @@ -486,23 +486,23 @@ def insert_deltas_group_txn( table="state_group_edges", keys=("state_group", "prev_state_group"), values=[ - (context._state_group, context.prev_group) + (context.state_group_after_event, context.prev_group_for_state_group_after_event) for _, context in events_and_context ], ) for _, context in events_and_context: - assert context.delta_ids is not None + assert context.delta_ids_to_state_group_after_event is not None self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", keys=("state_group", "room_id", "type", "state_key", "event_id"), values=[ - (context._state_group, room_id, key[0], key[1], state_id) - for key, state_id in context.delta_ids.items() + (context.state_group_after_event, room_id, 
key[0], key[1], state_id) + for key, state_id in context.delta_ids_to_state_group_after_event.items() ], ) - return state_groups + return events_and_context return await self.db_pool.runInteraction( "store_state_deltas_for_batched.insert_deltas_group", From 4f991a07205a9f4d50555f4cfddc750f5e89c6bf Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 25 Jan 2023 19:14:23 -0800 Subject: [PATCH 04/14] batch persist state groups when creating room --- synapse/handlers/room.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0e759b8a5d7a..be0a9ba38454 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -50,6 +50,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase +from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM @@ -211,7 +212,7 @@ async def upgrade_room( # the required power level to send the tombstone event. ( tombstone_event, - tombstone_context, + tombstone_unpersisted_context, ) = await self.event_creation_handler.create_event( requester, { @@ -225,6 +226,9 @@ async def upgrade_room( }, }, ) + tombstone_context = await tombstone_unpersisted_context.persist( + tombstone_event + ) validate_event_for_room_version(tombstone_event) await self._event_auth_handler.check_auth_rules_from_context( tombstone_event @@ -1091,7 +1095,7 @@ async def create_event( content: JsonDict, for_batch: bool, **kwargs: Any, - ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: """ Creates an event and associated event context. 
Args: @@ -1110,20 +1114,23 @@ async def create_event( event_dict = create_event_dict(etype, content, **kwargs) - new_event, new_context = await self.event_creation_handler.create_event( + ( + new_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( creator, event_dict, prev_event_ids=prev_event, depth=depth, state_map=state_map, for_batch=for_batch, - current_state_group=current_state_group, ) + depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - return new_event, new_context + return new_event, new_unpersisted_context try: config = self._presets_dict[preset_config] @@ -1133,10 +1140,10 @@ async def create_event( ) creation_content.update({"creator": creator_id}) - creation_event, creation_context = await create_event( + creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False ) - + creation_context = await unpersisted_creation_context.persist(creation_event) logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1180,7 +1187,6 @@ async def create_event( power_event, power_context = await create_event( EventTypes.PowerLevels, pl_content, True ) - current_state_group = power_context._state_group events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { @@ -1229,14 +1235,12 @@ async def create_event( power_level_content, True, ) - current_state_group = pl_context._state_group events_to_send.append((pl_event, pl_context)) if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True ) - current_state_group = room_alias_context._state_group events_to_send.append((room_alias_event, room_alias_context)) if (EventTypes.JoinRules, "") not in initial_state: @@ -1245,7 
+1249,6 @@ async def create_event( {"join_rule": config["join_rules"]}, True, ) - current_state_group = join_rules_context._state_group events_to_send.append((join_rules_event, join_rules_context)) if (EventTypes.RoomHistoryVisibility, "") not in initial_state: @@ -1254,7 +1257,6 @@ async def create_event( {"history_visibility": config["history_visibility"]}, True, ) - current_state_group = visibility_context._state_group events_to_send.append((visibility_event, visibility_context)) if config["guest_can_join"]: @@ -1264,14 +1266,12 @@ async def create_event( {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, True, ) - current_state_group = guest_access_context._state_group events_to_send.append((guest_access_event, guest_access_context)) for (etype, state_key), content in initial_state.items(): event, context = await create_event( etype, content, True, state_key=state_key ) - current_state_group = context._state_group events_to_send.append((event, context)) if config["encrypted"]: @@ -1283,9 +1283,16 @@ async def create_event( ) events_to_send.append((encryption_event, encryption_context)) + datastore = self.hs.get_datastores().state + events_and_context = ( + await UnpersistedEventContext.batch_persist_unpersisted_contexts( + events_to_send, room_id, current_state_group, datastore + ) + ) + last_event = await self.event_creation_handler.handle_new_client_event( creator, - events_to_send, + events_and_context, ignore_shadow_ban=True, ratelimit=False, ) From eba04f57fac1b9325e07ce0ce903ca03eb9566c8 Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Wed, 25 Jan 2023 19:15:22 -0800 Subject: [PATCH 05/14] persist contexts after `create_event` instead of during --- synapse/handlers/message.py | 15 +++++++-------- synapse/handlers/room_batch.py | 4 ++-- synapse/handlers/room_member.py | 13 ++++++++++--- tests/handlers/test_message.py | 19 +++++++++++++------ tests/handlers/test_register.py | 3 ++- tests/push/test_bulk_push_rule_evaluator.py | 6 ++++-- tests/rest/client/test_rooms.py | 4 ++-- tests/storage/test_event_chain.py | 6 ++++-- tests/unittest.py | 4 ++-- 9 files changed, 46 insertions(+), 28 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3e30f52e4d9f..bab00a118122 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -574,7 +574,7 @@ async def create_event( state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -721,7 +721,6 @@ async def create_event( current_state_group=current_state_group, ) - context = await unpersisted_context.persist(event) # In an ideal world we wouldn't need the second part of this condition. 
However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this @@ -739,7 +738,7 @@ async def create_event( assert state_map is not None prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: - prev_state_ids = await context.get_prev_state_ids( + prev_state_ids = await unpersisted_context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) @@ -764,8 +763,7 @@ async def create_event( ) self.validator.validate_new(event, self.config) - - return event, context + return event, unpersisted_context async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1003,7 @@ async def create_and_send_nonmember_event( max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1016,6 +1014,7 @@ async def create_and_send_nonmember_event( historical=historical, depth=depth, ) + context = await unpersisted_context.persist(event) assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( event.sender, @@ -1190,7 +1189,6 @@ async def create_new_client_event( if for_batch: assert prev_event_ids is not None assert state_map is not None - assert current_state_group is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth @@ -2042,7 +2040,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool: max_retries = 5 for i in range(max_retries): try: - event, context = await self.create_event( + event, unpersisted_context = await self.create_event( requester, { "type": EventTypes.Dummy, @@ -2051,6 +2049,7 @@ async def _send_dummy_event_for_room(self, room_id: str) -> bool: "sender": user_id, }, ) + context = await 
unpersisted_context.persist(event) event.internal_metadata.proactively_send = False diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index c73d2adaad47..7cfa2c28bc4f 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -327,7 +327,7 @@ async def persist_historical_events( # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - event, context = await self.event_creation_handler.create_event( + event, unpersisted_context = await self.event_creation_handler.create_event( await self.create_requester_for_user_id_from_app_service( ev["sender"], app_service_requester.app_service ), @@ -345,7 +345,7 @@ async def persist_historical_events( historical=True, depth=inherited_depth, ) - + context = await unpersisted_context.persist(event) assert context._state_group # Normally this is done when persisting the event but we have to diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d236cc09b526..9406fd3097a2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -398,7 +398,10 @@ async def _local_membership_update( max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, { "type": EventTypes.Member, @@ -419,7 +422,7 @@ async def _local_membership_update( outlier=outlier, historical=historical, ) - + context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -1892,7 +1895,10 @@ async def _generate_local_out_of_band_leave( max_retries = 5 for i in range(max_retries): try: - event, context = await self.event_creation_handler.create_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_event( requester, event_dict, 
txn_id=txn_id, @@ -1900,6 +1906,7 @@ async def _generate_local_out_of_band_leave( auth_event_ids=auth_event_ids, outlier=True, ) + context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True result_event = ( diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index c4727ab917fd..04ab940cd53a 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -106,7 +106,8 @@ def test_duplicated_txn_id(self) -> None: txn_id = "something_suitably_random" - event1, context = self._create_duplicate_event(txn_id) + event1, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event1)) ret_event1 = self.get_success( self.handler.handle_new_client_event( @@ -118,7 +119,8 @@ def test_duplicated_txn_id(self) -> None: self.assertEqual(event1.event_id, ret_event1.event_id) - event2, context = self._create_duplicate_event(txn_id) + event2, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event2)) # We want to test that the deduplication at the persit event end works, # so we want to make sure we test with different events. @@ -139,7 +141,9 @@ def test_duplicated_txn_id(self) -> None: # Let's test that calling `persist_event` directly also does the right # thing. - event3, context = self._create_duplicate_event(txn_id) + event3, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event3)) + self.assertNotEqual(event1.event_id, event3.event_id) ret_event3, event_pos3, _ = self.get_success( @@ -153,7 +157,8 @@ def test_duplicated_txn_id(self) -> None: # Let's test that calling `persist_events` directly also does the right # thing. 
- event4, context = self._create_duplicate_event(txn_id) + event4, unpersisted_context = self._create_duplicate_event(txn_id) + context = self.get_success(unpersisted_context.persist(event4)) self.assertNotEqual(event1.event_id, event3.event_id) events, _ = self.get_success( @@ -173,8 +178,10 @@ def test_duplicated_txn_id_one_call(self) -> None: txn_id = "something_else_suitably_random" # Create two duplicate events to persist at the same time - event1, context1 = self._create_duplicate_event(txn_id) - event2, context2 = self._create_duplicate_event(txn_id) + event1, unpersisted_context1 = self._create_duplicate_event(txn_id) + context1 = self.get_success(unpersisted_context1.persist(event1)) + event2, unpersisted_context2 = self._create_duplicate_event(txn_id) + context2 = self.get_success(unpersisted_context2.persist(event2)) # Ensure their event IDs are different to start with self.assertNotEqual(event1.event_id, event2.event_id) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index b9332d97dcdc..a4ba89afd7f5 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -503,7 +503,7 @@ def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: # Lower the permissions of the inviter. 
event_creation_handler = self.hs.get_event_creation_handler() requester = create_requester(inviter) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creation_handler.create_event( requester, { @@ -515,6 +515,7 @@ def test_auto_create_auto_join_room_preset_invalid_permissions(self) -> None: }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_creation_handler.handle_new_client_event( requester, events_and_context=[(event, context)] diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 7567756135b7..46df79f730a5 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -131,7 +131,7 @@ def test_action_for_event_by_user_handles_noninteger_room_power_levels( # Create a new message event, and try to evaluate it under the dodgy # power level event. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -146,6 +146,7 @@ def test_action_for_event_by_user_handles_noninteger_room_power_levels( prev_event_ids=[pl_event_id], ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise @@ -171,7 +172,7 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: """Ensure that push rules are not calculated when disabled in the config""" # Create a new message event which should cause a notification. 
- event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -185,6 +186,7 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: }, ) ) + context = self.get_success(unpersisted_context.persist(event)) bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Mock the method which calculates push rules -- we do this instead of diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9222cab19801..ede630a4f981 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -715,7 +715,7 @@ def test_post_room_no_keys(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(30, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -728,7 +728,7 @@ def test_post_room_initial_state(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c070278db80f..4d70d24b8d37 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -522,7 +522,7 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: latest_event_ids = self.get_success( self.store.get_prev_events_for_room(room_id) ) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ 
-535,6 +535,7 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] @@ -542,7 +543,7 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: ) state1 = set(self.get_success(context.get_current_state_ids()).values()) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_handler.create_event( self.requester, { @@ -555,6 +556,7 @@ def _generate_room(self) -> Tuple[str, List[Set[str]]]: prev_event_ids=latest_event_ids, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( event_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] diff --git a/tests/unittest.py b/tests/unittest.py index 68e59a88dc0f..1ba4f14c2b53 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -716,7 +716,7 @@ def create_and_send_event( event_creator = self.hs.get_event_creation_handler() requester = create_requester(user) - event, context = self.get_success( + event, unpersisted_context = self.get_success( event_creator.create_event( requester, { @@ -728,7 +728,7 @@ def create_and_send_event( prev_event_ids=prev_event_ids, ) ) - + context = self.get_success(unpersisted_context.persist(event)) if soft_failed: event.internal_metadata.soft_failed = True From bc0ccd1902eec6a9eaaeaedda67ecce87e954bc9 Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Wed, 25 Jan 2023 19:18:05 -0800 Subject: [PATCH 06/14] lints --- synapse/handlers/message.py | 1 - synapse/storage/databases/state/store.py | 21 +++++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bab00a118122..05144981008b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -721,7 +721,6 @@ async def create_event( current_state_group=current_state_group, ) - # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index baa79da036f0..b0268c2850e4 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -460,12 +460,16 @@ def insert_deltas_group_txn( # The first prev_group will be the last persisted state group, which is passed in # else it will be the group most recently assigned if index > 0: - context.prev_group_for_state_group_after_event = state_groups[index - 1] + context.prev_group_for_state_group_after_event = state_groups[ + index - 1 + ] context.state_group_before_event = state_groups[index - 1] else: context.prev_group_for_state_group_after_event = prev_group context.state_group_before_event = prev_group - context.delta_ids_to_state_group_after_event = {(event.type, event.state_key): event.event_id} + context.delta_ids_to_state_group_after_event = { + (event.type, event.state_key): event.event_id + } context.state_delta_due_to_event = { (event.type, event.state_key): event.event_id } @@ -486,7 +490,10 @@ def insert_deltas_group_txn( table="state_group_edges", keys=("state_group", "prev_state_group"), values=[ - (context.state_group_after_event, context.prev_group_for_state_group_after_event) + ( + 
context.state_group_after_event, + context.prev_group_for_state_group_after_event, + ) for _, context in events_and_context ], ) @@ -498,7 +505,13 @@ def insert_deltas_group_txn( table="state_groups_state", keys=("state_group", "room_id", "type", "state_key", "event_id"), values=[ - (context.state_group_after_event, room_id, key[0], key[1], state_id) + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) for key, state_id in context.delta_ids_to_state_group_after_event.items() ], ) From 7ec7b9a7dd51470a7f356aceaf6c8aa08a7c31ca Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 25 Jan 2023 19:18:10 -0800 Subject: [PATCH 07/14] newsfragment --- changelog.d/14918.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14918.misc diff --git a/changelog.d/14918.misc b/changelog.d/14918.misc new file mode 100644 index 000000000000..828794354acd --- /dev/null +++ b/changelog.d/14918.misc @@ -0,0 +1 @@ +Batch up storing state groups when creating a new room. \ No newline at end of file From c4646cbc7add683d02aa0b554fee4c3d9fa30872 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 8 Feb 2023 11:17:57 -0800 Subject: [PATCH 08/14] requested changes --- synapse/events/snapshot.py | 2 ++ synapse/storage/databases/state/store.py | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 68e68b98262e..b02fca12a68c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -360,6 +360,8 @@ async def batch_persist_unpersisted_contexts( """ Takes a list of events and their associated unpersisted contexts and persists the unpersisted contexts, returning a list of events and persisted contexts. + Note that all the events must be in a linear chain (ie a <- b <- c) + and must be state events. 
Args: events_and_context: A list of events and their unpersisted contexts diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b0268c2850e4..176c8afb97a7 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -413,7 +413,8 @@ async def store_state_deltas_for_batched( prev_group: int, ) -> List[Tuple[EventBase, UnpersistedEventContext]]: """Generate and store state deltas for a group of events and contexts created to be - batch persisted. + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c) + and must be state events. Args: events_and_context: the events to generate and store a state groups for @@ -454,8 +455,7 @@ def insert_deltas_group_txn( txn, num_state_groups ) - index = 0 - for event, context in events_and_context: + for index, (event, context) in enumerate(events_and_context): context.state_group_after_event = state_groups[index] # The first prev_group will be the last persisted state group, which is passed in # else it will be the group most recently assigned @@ -533,8 +533,10 @@ async def store_state_group( current_state_ids: Optional[StateMap[str]], ) -> int: """Store a new set of state, returning a newly assigned state group. + At least one of `current_state_ids` and `prev_group` must be provided. Whenever `prev_group` is not None, `delta_ids` must also not be None. + Args: event_id: The event ID for which the state was calculated room_id @@ -544,6 +546,7 @@ async def store_state_group( `current_state_ids`. current_state_ids: The state to store. Map of (type, state_key) to event_id. + Returns: The state group ID """ From 4f6f50fadeebe39786d3e0b9a232a0c331d16a9b Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Mon, 13 Feb 2023 17:59:12 -0800 Subject: [PATCH 09/14] requested changes --- synapse/events/snapshot.py | 8 ++- synapse/storage/databases/state/store.py | 58 ++++++++++++--------- tests/push/test_bulk_push_rule_evaluator.py | 3 +- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index b02fca12a68c..96fae2a46689 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -360,8 +360,7 @@ async def batch_persist_unpersisted_contexts( """ Takes a list of events and their associated unpersisted contexts and persists the unpersisted contexts, returning a list of events and persisted contexts. - Note that all the events must be in a linear chain (ie a <- b <- c) - and must be state events. + Note that all the events must be in a linear chain (ie a <- b <- c). Args: events_and_context: A list of events and their unpersisted contexts @@ -375,15 +374,14 @@ async def batch_persist_unpersisted_contexts( events_and_persisted_context = [] for event, unpersisted_context in amended_events_and_context: - assert unpersisted_context.partial_state is not None context = EventContext( storage=unpersisted_context._storage, state_group=unpersisted_context.state_group_after_event, state_group_before_event=unpersisted_context.state_group_before_event, state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.prev_group_for_state_group_after_event, - delta_ids=unpersisted_context.delta_ids_to_state_group_after_event, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, ) events_and_persisted_context.append((event, context)) return events_and_persisted_context diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 176c8afb97a7..7873064b5e05 100644 --- 
a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -413,8 +413,7 @@ async def store_state_deltas_for_batched( prev_group: int, ) -> List[Tuple[EventBase, UnpersistedEventContext]]: """Generate and store state deltas for a group of events and contexts created to be - batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c) - and must be state events. + batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c). Args: events_and_context: the events to generate and store a state groups for @@ -449,31 +448,32 @@ def insert_deltas_group_txn( % (prev_group,) ) - num_state_groups = len(events_and_context) + num_state_groups = 0 + for event, _ in events_and_context: + if event.is_state(): + num_state_groups += 1 state_groups = self._state_group_seq_gen.get_next_mult_txn( txn, num_state_groups ) + sg_before = prev_group for index, (event, context) in enumerate(events_and_context): - context.state_group_after_event = state_groups[index] - # The first prev_group will be the last persisted state group, which is passed in - # else it will be the group most recently assigned - if index > 0: - context.prev_group_for_state_group_after_event = state_groups[ - index - 1 - ] - context.state_group_before_event = state_groups[index - 1] - else: - context.prev_group_for_state_group_after_event = prev_group - context.state_group_before_event = prev_group - context.delta_ids_to_state_group_after_event = { + if not event.is_state(): + context.state_group_after_event = sg_before + context.state_group_before_event = sg_before + pass + + sg_after = state_groups[index] + context.state_group_after_event = sg_after + context.state_group_before_event = sg_before + context.delta_ids_to_state_group_before_event = { (event.type, event.state_key): event.event_id } context.state_delta_due_to_event = { (event.type, event.state_key): event.event_id } - index += 1 + sg_before = sg_after 
self.db_pool.simple_insert_many_txn( txn, @@ -492,19 +492,20 @@ def insert_deltas_group_txn( values=[ ( context.state_group_after_event, - context.prev_group_for_state_group_after_event, + context.state_group_before_event, ) for _, context in events_and_context ], ) + values = [] for _, context in events_and_context: - assert context.delta_ids_to_state_group_after_event is not None - self.db_pool.simple_insert_many_txn( - txn, - table="state_groups_state", - keys=("state_group", "room_id", "type", "state_key", "event_id"), - values=[ + assert context.delta_ids_to_state_group_before_event is not None + for ( + key, + state_id, + ) in context.delta_ids_to_state_group_before_event.items(): + values.append( ( context.state_group_after_event, room_id, @@ -512,9 +513,14 @@ def insert_deltas_group_txn( key[1], state_id, ) - for key, state_id in context.delta_ids_to_state_group_after_event.items() - ], - ) + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), + values=values, + ) return events_and_context return await self.db_pool.runInteraction( diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 46df79f730a5..ea61dd00ae09 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -379,7 +379,7 @@ def test_suppress_edits(self) -> None: bulk_evaluator = BulkPushRuleEvaluator(self.hs) # Create & persist an event to use as the parent of the relation. 
- event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -393,6 +393,7 @@ def test_suppress_edits(self) -> None: }, ) ) + context = self.get_success(unpersisted_context.persist(event)) self.get_success( self.event_creation_handler.handle_new_client_event( self.requester, events_and_context=[(event, context)] From 1f752ff52e6c602dd265350c971e001391624501 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Mon, 13 Feb 2023 18:10:59 -0800 Subject: [PATCH 10/14] re-add mistakenly removed spaces --- synapse/storage/databases/state/store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7873064b5e05..bf2213e207ec 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -567,7 +567,9 @@ def insert_delta_group_txn( txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str] ) -> Optional[int]: """Try and persist the new group as a delta. + Requires that we have the state as a delta from a previous state group. + Returns: The state group if successfully created, or None if the state needs to be persisted as a full state. From 8634ba668aac94bbbe37249b0e7729b003c90aa3 Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Tue, 14 Feb 2023 17:28:44 -0800 Subject: [PATCH 11/14] deal more effectively with state/non-state --- synapse/events/snapshot.py | 29 ++++++++++++++++-------- synapse/storage/databases/state/store.py | 14 ++++-------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 96fae2a46689..a91a5d1e3cbe 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -374,15 +374,26 @@ async def batch_persist_unpersisted_contexts( events_and_persisted_context = [] for event, unpersisted_context in amended_events_and_context: - context = EventContext( - storage=unpersisted_context._storage, - state_group=unpersisted_context.state_group_after_event, - state_group_before_event=unpersisted_context.state_group_before_event, - state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, - partial_state=unpersisted_context.partial_state, - prev_group=unpersisted_context.state_group_before_event, - delta_ids=unpersisted_context.state_delta_due_to_event, - ) + if event.is_state(): + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.state_group_before_event, + delta_ids=unpersisted_context.state_delta_due_to_event, + ) + else: + context = EventContext( + storage=unpersisted_context._storage, + state_group=unpersisted_context.state_group_after_event, + state_group_before_event=unpersisted_context.state_group_before_event, + state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, + partial_state=unpersisted_context.partial_state, + prev_group=unpersisted_context.prev_group_for_state_group_before_event, + 
delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, + ) events_and_persisted_context.append((event, context)) return events_and_persisted_context diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index bf2213e207ec..c4c2b271fcae 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -448,10 +448,9 @@ def insert_deltas_group_txn( % (prev_group,) ) - num_state_groups = 0 - for event, _ in events_and_context: - if event.is_state(): - num_state_groups += 1 + num_state_groups = sum( + 1 for event, _ in events_and_context if event.is_state() + ) state_groups = self._state_group_seq_gen.get_next_mult_txn( txn, num_state_groups @@ -467,9 +466,6 @@ def insert_deltas_group_txn( sg_after = state_groups[index] context.state_group_after_event = sg_after context.state_group_before_event = sg_before - context.delta_ids_to_state_group_before_event = { - (event.type, event.state_key): event.event_id - } context.state_delta_due_to_event = { (event.type, event.state_key): event.event_id } @@ -500,11 +496,11 @@ def insert_deltas_group_txn( values = [] for _, context in events_and_context: - assert context.delta_ids_to_state_group_before_event is not None + assert context.state_delta_due_to_event is not None for ( key, state_id, - ) in context.delta_ids_to_state_group_before_event.items(): + ) in context.state_delta_due_to_event.items(): values.append( ( context.state_group_after_event, From 3d564efb7ca381f0ec85c6e0a761ff95306be342 Mon Sep 17 00:00:00 2001 From: "H. 
Shay" Date: Tue, 14 Feb 2023 20:33:56 -0800 Subject: [PATCH 12/14] fix tests typing --- tests/handlers/test_message.py | 6 ++++-- tests/push/test_bulk_push_rule_evaluator.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9540b13b8514..9691d66b48a0 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -79,7 +79,9 @@ def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: return memberEvent, memberEventContext - def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + def _create_duplicate_event( + self, txn_id: str + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event with the given transaction ID. All events produced by this method will be considered duplicates. """ diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index ea61dd00ae09..ad2ddfd88ad1 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -203,7 +203,7 @@ def _create_and_process( ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. - event, context = self.get_success( + event, unpersisted_context = self.get_success( self.event_creation_handler.create_event( self.requester, { @@ -214,7 +214,7 @@ def _create_and_process( }, ) ) - + context = self.get_success(unpersisted_context.persist(event)) # Execute the push rule machinery. 
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) From 5811fd97f22d323cd4425ee27cf1f047b51bac22 Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 22 Feb 2023 21:01:18 -0800 Subject: [PATCH 13/14] requested changes --- synapse/storage/databases/state/store.py | 41 +++++++++++------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index c4c2b271fcae..daa96716b509 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -457,13 +457,14 @@ def insert_deltas_group_txn( ) sg_before = prev_group - for index, (event, context) in enumerate(events_and_context): + state_group_iter = iter(state_groups) + for event, context in events_and_context: if not event.is_state(): context.state_group_after_event = sg_before context.state_group_before_event = sg_before - pass + continue - sg_after = state_groups[index] + sg_after = next(state_group_iter) context.state_group_after_event = sg_after context.state_group_before_event = sg_before context.state_delta_due_to_event = { @@ -478,6 +479,7 @@ def insert_deltas_group_txn( values=[ (context.state_group_after_event, room_id, event.event_id) for event, context in events_and_context + if event.is_state() ], ) @@ -490,32 +492,27 @@ def insert_deltas_group_txn( context.state_group_after_event, context.state_group_before_event, ) - for _, context in events_and_context + for event, context in events_and_context + if event.is_state() ], ) - values = [] - for _, context in events_and_context: - assert context.state_delta_due_to_event is not None - for ( - key, - state_id, - ) in context.state_delta_due_to_event.items(): - values.append( - ( - context.state_group_after_event, - room_id, - key[0], - key[1], - state_id, - ) - ) - self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", keys=("state_group", "room_id", "type", "state_key", "event_id"), - 
values=values, + values=[ + ( + context.state_group_after_event, + room_id, + key[0], + key[1], + state_id, + ) + for event, context in events_and_context + if context.state_delta_due_to_event is not None + for key, state_id in context.state_delta_due_to_event.items() + ], ) return events_and_context From d92b3a3253b8dda2cb48845b61c36684d48ca6df Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Wed, 22 Feb 2023 21:02:15 -0800 Subject: [PATCH 14/14] add a test for store_state_deltas_for_batched --- tests/storage/test_state.py | 126 ++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index f730b888f7d2..ddac728dbd0c 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -496,3 +496,129 @@ def test_get_state_for_event(self) -> None: self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) + + def test_batched_state_group_storing(self) -> None: + creation_event = self.inject_state_event( + self.room, self.u_alice, EventTypes.Create, "", {} + ) + state_to_event = self.get_success( + self.storage.state.get_state_groups( + self.room.to_string(), [creation_event.event_id] + ) + ) + current_state_group = list(state_to_event.keys())[0] + + # create some unpersisted events and event contexts to store against room + events_and_context = [] + builder = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Name, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"name": "first rename of room"}, + }, + ) + + event1, unpersisted_context1 = self.get_success( + self.event_creation_handler.create_new_client_event(builder) + ) + events_and_context.append((event1, unpersisted_context1)) + + builder2 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + 
"state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "private"}, + }, + ) + + event2, unpersisted_context2 = self.get_success( + self.event_creation_handler.create_new_client_event(builder2) + ) + events_and_context.append((event2, unpersisted_context2)) + + builder3 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.Message, + "sender": self.u_alice.to_string(), + "room_id": self.room.to_string(), + "content": {"body": "hello from event 3", "msgtype": "m.text"}, + }, + ) + + event3, unpersisted_context3 = self.get_success( + self.event_creation_handler.create_new_client_event(builder3) + ) + events_and_context.append((event3, unpersisted_context3)) + + builder4 = self.event_builder_factory.for_room_version( + RoomVersions.V1, + { + "type": EventTypes.JoinRules, + "sender": self.u_alice.to_string(), + "state_key": "", + "room_id": self.room.to_string(), + "content": {"join_rule": "public"}, + }, + ) + + event4, unpersisted_context4 = self.get_success( + self.event_creation_handler.create_new_client_event(builder4) + ) + events_and_context.append((event4, unpersisted_context4)) + + processed_events_and_context = self.get_success( + self.hs.get_datastores().state.store_state_deltas_for_batched( + events_and_context, self.room.to_string(), current_state_group + ) + ) + + # check that only state events are in state_groups, and all state events are in state_groups + res = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups", + keyvalues=None, + retcols=("event_id",), + ) + ) + + events = [] + for result in res: + self.assertNotIn(event3.event_id, result) + events.append(result.get("event_id")) + + for event, _ in processed_events_and_context: + if event.is_state(): + self.assertIn(event.event_id, events) + + # check that each unique state has state group in state_groups_state and that the + # type/state key is correct, and check that each state event's state group + # has an 
entry and prev event in state_group_edges + for event, context in processed_events_and_context: + if event.is_state(): + state = self.get_success( + self.store.db_pool.simple_select_list( + table="state_groups_state", + keyvalues={"state_group": context.state_group_after_event}, + retcols=("type", "state_key"), + ) + ) + self.assertEqual(event.type, state[0].get("type")) + self.assertEqual(event.state_key, state[0].get("state_key")) + + groups = self.get_success( + self.store.db_pool.simple_select_list( + table="state_group_edges", + keyvalues={"state_group": str(context.state_group_after_event)}, + retcols=("*",), + ) + ) + self.assertEqual( + context.state_group_before_event, groups[0].get("prev_state_group") + )