From eb24cdb5cf0fcb6dc439bd3319a984f300548a69 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 25 Jan 2023 00:23:48 +0000 Subject: [PATCH 01/36] WIP mypy plugin to check `@cached` return types --- scripts-dev/mypy_synapse_plugin.py | 139 ++++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 3 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 2c377533c0fd..a46a973548a6 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -16,12 +16,22 @@ can crop up, e.g the cache descriptors. """ -from typing import Callable, Optional, Type +from typing import Callable, Optional, Tuple, Type +import mypy.types +from mypy.errorcodes import ErrorCode from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self -from mypy.types import CallableType, NoneType, UnionType +from mypy.types import ( + AnyType, + CallableType, + Instance, + NoneType, + TupleType, + TypeAliasType, + UnionType, +) class SynapsePlugin(Plugin): @@ -48,7 +58,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: """ # First we mark this as a bound function signature. - signature = bind_self(ctx.default_signature) + signature: CallableType = bind_self(ctx.default_signature) # Secondly, we remove any "cache_context" args. # @@ -98,9 +108,132 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_kinds=arg_kinds, ) + # 4. Complain loudly if we are returning something mutable + check_is_cacheable(signature, ctx) + return signature +def clean_message(value: str) -> str: + return value.replace("builtins.", "").replace("typing.", "") + + +def unwrap_awaitable_types(t: mypy.types.Type) -> mypy.types.Type: + if isinstance(t, Instance): + if t.type.fullname == "typing.Coroutine": + # We're assuming this is Awaitable[R]m aka Coroutine[None, None, R]. 
+ # TODO: assert yield type and send type are None + # Extract the `R` type from `Coroutine[Y, S, R]`, and reinspect + t = t.args[2] + elif t.type.fullname == "twisted.internet.defer.Deferred": + t = t.args[0] + + return mypy.types.get_proper_type(t) + + +def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: + return_type = unwrap_awaitable_types(signature.ret_type) + verbose = ctx.api.options.verbosity >= 1 + ok, note = is_cacheable(return_type, signature, verbose) + + if ok: + message = f"function {signature.name} is @cached, returning {return_type}" + else: + message = f"function {signature.name} is @cached, but has mutable return value {return_type}" + + if note: + message += f" ({note})" + message = clean_message(message) + + if ok and note: + ctx.api.note(message, ctx.context) # type: ignore[attr-defined] + elif not ok: + ctx.api.fail(message, ctx.context, code=AT_CACHED_MUTABLE_RETURN) + + +IMMUTABLE_BUILTIN_VALUE_TYPES = { + "builtins.bool", + "builtins.int", + "builtins.float", + "builtins.str", + "builtins.bytes", +} + +IMMUTABLE_BUILTIN_CONTAINER_TYPES = { + "builtins.frozenset", + "typing.AbstractSet", +} + +MUTABLE_BUILTIN_CONTAINER_TYPES = { + "builtins.set", + "builtins.list", + "builtins.dict", + "typing.Sequence", +} + +AT_CACHED_MUTABLE_RETURN = ErrorCode( + "synapse-@cached-mutable", + "@cached() should have an immutable return type", + "General", +) + + +def is_cacheable( + rt: mypy.types.Type, signature: CallableType, verbose: bool +) -> Tuple[bool, Optional[str]]: + """ + Returns: a 2-tuple (cacheable, message). + - cachable: False means the type is definitely not cacheable; + true means anything else. + - Optional message. + """ + + # This should probably be done via a TypeVisitor. Apologies to the reader! 
+ if isinstance(rt, AnyType): + return True, ("may be mutable" if verbose else None) + + elif isinstance(rt, Instance): + if rt.type.fullname in IMMUTABLE_BUILTIN_VALUE_TYPES: + return True, None + + elif rt.type.fullname == "typing.Mapping": + return is_cacheable(rt.args[1], signature, verbose) + + elif rt.type.fullname in IMMUTABLE_BUILTIN_CONTAINER_TYPES: + return is_cacheable(rt.args[0], signature, verbose) + + elif rt.type.fullname in MUTABLE_BUILTIN_CONTAINER_TYPES: + return False, None + + elif "attrs" in rt.type.metadata: + frozen = rt.type.metadata["attrs"].get("frozen", False) + if frozen: + # TODO: should really check that all of the fields are also cacheable + return True, None + else: + return False, "non-frozen attrs class" + + else: + return False, f"Don't know how to handle {rt.type.fullname}" + + elif isinstance(rt, NoneType): + return True, None + + elif isinstance(rt, (TupleType, UnionType)): + for item in rt.items: + ok, note = is_cacheable(item, signature, verbose) + if not ok: + return False, note + # This discards notes but that's probably fine + return True, None + + elif isinstance(rt, TypeAliasType): + return is_cacheable(mypy.types.get_proper_type(rt), signature, verbose) + + else: + return False, f"Don't know how to handle {type(rt).__qualname__} return type" + + def plugin(version: str) -> Type[SynapsePlugin]: # This is the entry point of the plugin, and lets us deal with the fact # that the mypy plugin interface is *not* stable by looking at the version From d2cbac3b2e77974383cf9499d48a24ae74a595c6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 17:39:22 +0000 Subject: [PATCH 02/36] Whoops, Sequence is immutable --- scripts-dev/mypy_synapse_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index a46a973548a6..17c801fec793 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ 
-162,13 +162,13 @@ def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: IMMUTABLE_BUILTIN_CONTAINER_TYPES = { "builtins.frozenset", "typing.AbstractSet", + "typing.Sequence", } MUTABLE_BUILTIN_CONTAINER_TYPES = { "builtins.set", "builtins.list", "builtins.dict", - "typing.Sequence", } AT_CACHED_MUTABLE_RETURN = ErrorCode( From 90631ac60b3f98f79511ebc44fdb64b4603286b1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Jan 2023 18:13:29 +0000 Subject: [PATCH 03/36] WIP --- scripts-dev/mypy_synapse_plugin.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 17c801fec793..6116e3ab6ccb 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -151,7 +151,7 @@ def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: ctx.api.fail(message, ctx.context, code=AT_CACHED_MUTABLE_RETURN) -IMMUTABLE_BUILTIN_VALUE_TYPES = { +IMMUTABLE_VALUE_TYPES = { "builtins.bool", "builtins.int", "builtins.float", @@ -159,13 +159,16 @@ def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: "builtins.bytes", } -IMMUTABLE_BUILTIN_CONTAINER_TYPES = { +IMMUTABLE_CONTAINER_TYPES_REQUIRING_HASHABLE_ELEMENTS = { "builtins.frozenset", "typing.AbstractSet", - "typing.Sequence", } -MUTABLE_BUILTIN_CONTAINER_TYPES = { +IMMUTABLE_CONTAINER_TYPES_ALLOWING_MUTABLE_ELEMENTS = { + "typing.Sequence" +} + +MUTABLE_CONTAINER_TYPES = { "builtins.set", "builtins.list", "builtins.dict", @@ -193,16 +196,20 @@ def is_cacheable( return True, ("may be mutable" if verbose else None) elif isinstance(rt, Instance): - if rt.type.fullname in IMMUTABLE_BUILTIN_VALUE_TYPES: + if ( + rt.type.fullname in IMMUTABLE_VALUE_TYPES + or rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_REQUIRING_HASHABLE_ELEMENTS + ): return True, None elif rt.type.fullname == "typing.Mapping": return 
is_cacheable(rt.args[1], signature, verbose) - elif rt.type.fullname in IMMUTABLE_BUILTIN_CONTAINER_TYPES: + elif rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_ALLOWING_MUTABLE_ELEMENTS: + # E.g. Collection[T] is cachable iff T is cachable. return is_cacheable(rt.args[0], signature, verbose) - elif rt.type.fullname in MUTABLE_BUILTIN_CONTAINER_TYPES: + elif rt.type.fullname in MUTABLE_CONTAINER_TYPES: return False, None elif "attrs" in rt.type.metadata: From ad1c28a56012f5a5dbb898bdcc1ca68bb1270eff Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Sep 2023 15:24:19 -0400 Subject: [PATCH 04/36] Update comments. --- scripts-dev/mypy_synapse_plugin.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index e2fa03802264..ce4c64e2d091 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -59,10 +59,10 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: 3. an optional keyword argument `on_invalidated` should be added. """ - # First we mark this as a bound function signature. + # 1. Mark this as a bound function signature. signature: CallableType = bind_self(ctx.default_signature) - # Secondly, we remove any "cache_context" args. + # 2. Remove any "cache_context" args. # # Note: We should be only doing this if `cache_context=True` is set, but if # it isn't then the code will raise an exception when its called anyway, so @@ -82,7 +82,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_names.pop(context_arg_index) arg_kinds.pop(context_arg_index) - # Third, we add an optional "on_invalidate" argument. + # 3. Add an optional "on_invalidate" argument. 
# # This is a either # - a callable which accepts no input and returns nothing, or @@ -104,7 +104,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_names.append("on_invalidate") arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. - # Finally we ensure the return type is a Deferred. + # 4. Ensure the return type is a Deferred. if ( isinstance(signature.ret_type, Instance) and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred" @@ -141,7 +141,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: ret_type=ret_type, ) - # 4. Complain loudly if we are returning something mutable + # 5. Complain loudly if we are returning something mutable check_is_cacheable(signature, ctx) return signature From 88323dd3f1774808a1f5232d60a8cf01da326d39 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Sep 2023 15:38:28 -0400 Subject: [PATCH 05/36] Simplify code due to other knowledge. --- scripts-dev/mypy_synapse_plugin.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index ce4c64e2d091..106b26a65b95 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -142,30 +142,20 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: ) # 5. Complain loudly if we are returning something mutable - check_is_cacheable(signature, ctx) + check_is_cacheable(signature, ctx, ret_type) return signature -def clean_message(value: str) -> str: - return value.replace("builtins.", "").replace("typing.", "") +def check_is_cacheable( + signature: CallableType, + ctx: MethodSigContext, + deferred_return_type: Instance, +) -> None: + # The previous code wraps the return type into a Deferred. 
+ assert deferred_return_type.type.fullname == "twisted.internet.defer.Deferred" + return_type = deferred_return_type.args[0] - -def unwrap_awaitable_types(t: mypy.types.Type) -> mypy.types.Type: - if isinstance(t, Instance): - if t.type.fullname == "typing.Coroutine": - # We're assuming this is Awaitable[R]m aka Coroutine[None, None, R]. - # TODO: assert yield type and send type are None - # Extract the `R` type from `Coroutine[Y, S, R]`, and reinspect - t = t.args[2] - elif t.type.fullname == "twisted.internet.defer.Deferred": - t = t.args[0] - - return mypy.types.get_proper_type(t) - - -def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: - return_type = unwrap_awaitable_types(signature.ret_type) verbose = ctx.api.options.verbosity >= 1 ok, note = is_cacheable(return_type, signature, verbose) @@ -176,7 +166,7 @@ def check_is_cacheable(signature: CallableType, ctx: MethodSigContext) -> None: if note: message += f" ({note})" - message = clean_message(message) + message = message.replace("builtins.", "").replace("typing.", "") if ok and note: ctx.api.note(message, ctx.context) # type: ignore[attr-defined] From 676c858c96d04d59ed04d30b2be78584ed2efb6e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Sep 2023 15:52:44 -0400 Subject: [PATCH 06/36] Treat containers more similarly. --- scripts-dev/mypy_synapse_plugin.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 106b26a65b95..5f7a4385a883 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -174,6 +174,7 @@ def check_is_cacheable( ctx.api.fail(message, ctx.context, code=AT_CACHED_MUTABLE_RETURN) +# Immutable simple values. 
IMMUTABLE_VALUE_TYPES = { "builtins.bool", "builtins.int", @@ -182,13 +183,13 @@ def check_is_cacheable( "builtins.bytes", } -IMMUTABLE_CONTAINER_TYPES_REQUIRING_HASHABLE_ELEMENTS = { +# Immutable containers only if the values are also immutable. +IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS = { "builtins.frozenset", "typing.AbstractSet", + "typing.Sequence", } -IMMUTABLE_CONTAINER_TYPES_ALLOWING_MUTABLE_ELEMENTS = {"typing.Sequence"} - MUTABLE_CONTAINER_TYPES = { "builtins.set", "builtins.list", @@ -217,16 +218,13 @@ def is_cacheable( return True, ("may be mutable" if verbose else None) elif isinstance(rt, Instance): - if ( - rt.type.fullname in IMMUTABLE_VALUE_TYPES - or rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_REQUIRING_HASHABLE_ELEMENTS - ): + if rt.type.fullname in IMMUTABLE_VALUE_TYPES: return True, None elif rt.type.fullname == "typing.Mapping": return is_cacheable(rt.args[1], signature, verbose) - elif rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_ALLOWING_MUTABLE_ELEMENTS: + elif rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS: # E.g. Collection[T] is cachable iff T is cachable. 
return is_cacheable(rt.args[0], signature, verbose) From 008ef3f1652434e16502a1c672d0c43c0de38b9a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 12 Sep 2023 15:57:42 -0400 Subject: [PATCH 07/36] cachedList wraps Mapping --- scripts-dev/mypy_synapse_plugin.py | 2 ++ synapse/storage/databases/main/devices.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 2 +- synapse/storage/databases/main/events_worker.py | 5 +++-- synapse/storage/databases/main/keys.py | 6 +++--- synapse/storage/databases/main/presence.py | 14 ++++++++++++-- synapse/storage/databases/main/push_rule.py | 2 +- synapse/storage/databases/main/receipts.py | 2 +- synapse/storage/databases/main/relations.py | 6 +++--- synapse/storage/databases/main/roommember.py | 8 ++++---- synapse/storage/databases/main/state.py | 14 ++++++++++++-- synapse/storage/databases/main/transactions.py | 4 ++-- .../storage/databases/main/user_erasure_store.py | 4 ++-- 13 files changed, 47 insertions(+), 24 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 5f7a4385a883..1be17102f19e 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -157,6 +157,8 @@ def check_is_cacheable( return_type = deferred_return_type.args[0] verbose = ctx.api.options.verbosity >= 1 + # TODO Technically a cachedList only needs immutable values, but forcing them + # to return Mapping instead of Dict is fine. 
ok, note = is_cacheable(return_type, signature, verbose) if ok: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 70faf4b1ecca..d3d8b866665e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1042,7 +1042,7 @@ async def get_device_list_last_stream_id_for_remote( ) async def get_device_list_last_stream_id_for_remotes( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index b49dea577cba..835cb37b338c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -846,7 +846,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDic ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 1eb313040ed9..b788d70fc500 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -24,6 +24,7 @@ Dict, Iterable, List, + Mapping, MutableMapping, Optional, Set, @@ -1633,7 +1634,7 @@ async def _have_seen_events_dict( self, room_id: str, event_ids: Collection[str], - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Helper for have_seen_events Returns: @@ -2325,7 +2326,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Checks which of the given events have partial state Args: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 41563371dcd2..889c578b9c97 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,7 +16,7 @@ import itertools import json import logging -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -130,7 +130,7 @@ def _get_server_keys_json( ) async def get_server_keys_json( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: + ) -> Mapping[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -200,7 +200,7 @@ def get_server_key_json_for_remote( ) async def get_server_keys_json_for_remote( self, server_name: str, key_ids: Iterable[str] - ) -> Dict[str, Optional[FetchKeyResultForRemote]]: + ) -> Mapping[str, Optional[FetchKeyResultForRemote]]: """Fetch the cached keys for the given server/key IDs. 
If we have multiple entries for a given key ID, returns the most recent. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index b51d20ac266c..194b4e031f73 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -11,7 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + cast, +) from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -249,7 +259,7 @@ def _get_presence_for_user(self, user_id: str) -> None: ) async def get_presence_for_users( self, user_ids: Iterable[str] - ) -> Dict[str, UserPresenceState]: + ) -> Mapping[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index bec0dc2afeeb..af69944008e1 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -216,7 +216,7 @@ def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool: @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: if not user_ids: return {} diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e4d10ff250d1..d3fb0ef6a02f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -353,7 +353,7 @@ def f(txn: 
LoggingTransaction) -> List[Dict[str, Any]]: ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, Sequence[JsonDict]]: + ) -> Mapping[str, Sequence[JsonDict]]: if not room_ids: return {} diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 96908f14ba35..6ba9c9651f97 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -519,7 +519,7 @@ def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") async def get_applicable_edits( self, event_ids: Collection[str] - ) -> Dict[str, Optional[EventBase]]: + ) -> Mapping[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given events. @@ -605,7 +605,7 @@ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: + ) -> Mapping[str, Optional[Tuple[int, EventBase]]]: """Get the number of threaded replies and the latest reply (if any) for the given events. Args: @@ -779,7 +779,7 @@ def get_thread_participated(self, event_id: str, user_id: str) -> bool: @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") async def get_threads_participated( self, event_ids: Collection[str], user_id: str - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Get whether the requesting user participated in the given threads. 
This is separate from get_thread_summaries since that can be cached across diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index fff259f74c95..7b503dd697e9 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -191,7 +191,7 @@ def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileIn ) async def get_subset_users_in_room_with_profiles( self, room_id: str, user_ids: Collection[str] - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for a list of users in a given room. @@ -676,7 +676,7 @@ async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: ) async def _get_rooms_for_users( self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[str]]: + ) -> Mapping[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. Returns: @@ -881,7 +881,7 @@ def _get_user_id_from_membership_event_id( ) async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: """For given set of member event_ids check if they point to a join event. @@ -1191,7 +1191,7 @@ async def _get_membership_from_event_id( ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> Dict[str, Optional[EventIdMembership]]: + ) -> Mapping[str, Optional[EventIdMembership]]: """Get user_id and membership of a set of event IDs. Returns: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index ebb2ae964f5a..5eaaff5b6864 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,17 @@ # limitations under the License. 
import collections.abc import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + Mapping, + Optional, + Set, + Tuple, +) import attr @@ -372,7 +382,7 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: ) async def _get_state_group_for_events( self, event_ids: Collection[str] - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Returns mapping event_id -> state_group. Raises: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index efd21b5bfceb..8f70eff80916 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -210,7 +210,7 @@ def _get_destination_retry_timings( ) async def get_destination_retry_timings_batch( self, destinations: StrCollection - ) -> Dict[str, Optional[DestinationRetryTimings]]: + ) -> Mapping[str, Optional[DestinationRetryTimings]]: rows = await self.db_pool.simple_select_many_batch( table="destinations", iterable=destinations, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index f79006533f3c..06fcbe5e54fd 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Iterable +from typing import Iterable, Mapping from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -40,7 +40,7 @@ async def is_user_erased(self, user_id: str) -> bool: return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") - async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: + async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]: """ Checks which users in a list have requested erasure From 8aa4e874f3aee5954f1cf1adbf591ebd04d133aa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 10:52:07 -0400 Subject: [PATCH 08/36] Fix-up errors in tests. --- scripts-dev/mypy_synapse_plugin.py | 7 ++++++ tests/util/caches/test_descriptors.py | 35 +++++++++++++++------------ 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 1be17102f19e..76e930abe008 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -31,6 +31,7 @@ NoneType, TupleType, TypeAliasType, + UninhabitedType, UnionType, ) @@ -188,6 +189,7 @@ def check_is_cacheable( # Immutable containers only if the values are also immutable. IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS = { "builtins.frozenset", + "builtins.tuple", "typing.AbstractSet", "typing.Sequence", } @@ -258,6 +260,11 @@ def is_cacheable( elif isinstance(rt, TypeAliasType): return is_cacheable(mypy.types.get_proper_type(rt), signature, verbose) + # The tests check what happens if you raise an Exception, so they don't return. + elif isinstance(rt, UninhabitedType) and rt.is_noreturn: + # There's no return value, just consider it cachable. 
+ return True, None + else: return False, f"Don't know how to handle {type(rt).__qualname__} return type" diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 168419f440fb..7e8725e610c7 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -15,10 +15,10 @@ import logging from typing import ( Any, - Dict, Generator, Iterable, List, + Mapping, NoReturn, Optional, Set, @@ -96,7 +96,7 @@ def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached(num_args=1) - def fn(self, arg1: int, arg2: int) -> mock.Mock: + def fn(self, arg1: int, arg2: int) -> str: return self.mock(arg1, arg2) obj = Cls() @@ -228,8 +228,9 @@ class Cls: call_count = 0 @cached() - def fn(self, arg1: int) -> Optional[Deferred]: + def fn(self, arg1: int) -> Deferred: self.call_count += 1 + assert self.result is not None return self.result obj = Cls() @@ -401,21 +402,21 @@ def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached(iterable=True) - def fn(self, arg1: int, arg2: int) -> List[str]: + def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]: return self.mock(arg1, arg2) obj = Cls() - obj.mock.return_value = ["spam", "eggs"] + obj.mock.return_value = ("spam", "eggs") r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) + self.assertEqual(r.result, ("spam", "eggs")) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = ["chips"] + obj.mock.return_value = ("chips",) r = obj.fn(1, 3) - self.assertEqual(r.result, ["chips"]) + self.assertEqual(r.result, ("chips",)) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() @@ -423,9 +424,9 @@ def fn(self, arg1: int, arg2: int) -> List[str]: self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) + self.assertEqual(r.result, ("spam", "eggs")) r = obj.fn(1, 3) - 
self.assertEqual(r.result, ["chips"]) + self.assertEqual(r.result, ("chips",)) obj.mock.assert_not_called() def test_cache_iterable_with_sync_exception(self) -> None: @@ -784,7 +785,9 @@ def fn(self, arg1: int, arg2: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]: + async def list_fn( + self, args1: Iterable[int], arg2: int + ) -> Mapping[int, str]: context = current_context() assert isinstance(context, LoggingContext) assert context.name == "c1" @@ -847,11 +850,11 @@ def fn(self, arg1: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - def list_fn(self, args1: List[int]) -> "Deferred[dict]": + def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]": return self.mock(args1) obj = Cls() - deferred_result: "Deferred[dict]" = Deferred() + deferred_result: "Deferred[Mapping[int, str]]" = Deferred() obj.mock.return_value = deferred_result # start off several concurrent lookups of the same key @@ -890,7 +893,7 @@ def fn(self, arg1: int, arg2: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]: + async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]: # we want this to behave like an asynchronous function await run_on_reactor() return self.mock(args1, arg2) @@ -929,7 +932,7 @@ def fn(self, arg1: int) -> None: pass @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args: List[int]) -> Dict[int, str]: + async def list_fn(self, args: List[int]) -> Mapping[int, str]: await complete_lookup return {arg: str(arg) for arg in args} @@ -964,7 +967,7 @@ def fn(self, arg1: int) -> None: pass @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args: List[int]) -> Dict[int, str]: + async def list_fn(self, args: List[int]) -> Mapping[int, 
str]: await make_deferred_yieldable(complete_lookup) self.inner_context_was_finished = current_context().finished return {arg: str(arg) for arg in args} From 0f3c03695faf4c7079a8429f66dcdbdbc3ebd0d9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 10:57:41 -0400 Subject: [PATCH 09/36] Ignore a few calls which purposefully (?) return mutable objects. --- scripts-dev/mypy_synapse_plugin.py | 1 + synapse/handlers/relations.py | 4 ++-- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/relations.py | 2 ++ synapse/storage/databases/main/roommember.py | 1 + 5 files changed, 7 insertions(+), 3 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 76e930abe008..9370ec549fba 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -171,6 +171,7 @@ def check_is_cacheable( message += f" ({note})" message = message.replace("builtins.", "").replace("typing.", "") + # TODO The context is the context of the caller, not the method itself. if ok and note: ctx.api.note(message, ctx.context) # type: ignore[attr-defined] elif not ok: diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index db97f7aedee6..74442b1eb917 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -310,7 +310,7 @@ async def _get_threads_for_events( event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] # Fetch thread summaries. - summaries = await self._main_store.get_thread_summaries(event_ids) + summaries = await self._main_store.get_thread_summaries(event_ids) # type: ignore[synapse-@cached-mutable] # Limit fetching whether the requester has participated in a thread to # events which are thread roots. @@ -504,7 +504,7 @@ async def _fetch_edits() -> None: Note that there is no use in limiting edits by ignored users since the parent event should be ignored in the first place if the user is ignored. 
""" - edits = await self._main_store.get_applicable_edits( + edits = await self._main_store.get_applicable_edits( # type: ignore[synapse-@cached-mutable] [ event_id for event_id, event in events_by_id.items() diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 278c7832ba01..33f04e7572c9 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -682,7 +682,7 @@ async def _get_joined_hosts( # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self.stores.main._get_joined_hosts_cache(room_id) + cache = await self.stores.main._get_joined_hosts_cache(room_id) # type: ignore[synapse-@cached-mutable] # If the state group in the cache matches, we already have the data we need. if state_entry.state_group == cache.state_group: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 6ba9c9651f97..ef08192fdb81 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -516,6 +516,7 @@ def _get_references_for_events_txn( def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() + # TODO: This returns a mutable object, which is generally bad. @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") async def get_applicable_edits( self, event_ids: Collection[str] @@ -602,6 +603,7 @@ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]: def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: raise NotImplementedError() + # TODO: This returns a mutable object, which is generally bad. 
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def get_thread_summaries( self, event_ids: Collection[str] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 7b503dd697e9..87932b9827c0 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1069,6 +1069,7 @@ async def _get_approximate_current_memberships_in_room( ) return {row["event_id"]: row["membership"] for row in rows} + # TODO This returns a mutable object, which is generally confusing when using a cache. @cached(max_entries=10000) def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache() From 9c945745e6825f81cc48ceb91760ac96a3efbb51 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:10:43 -0400 Subject: [PATCH 10/36] Data exfiltration is read-only and update admin APIs. --- synapse/app/admin_cmd.py | 14 +++++++------- synapse/handlers/admin.py | 18 +++++++++--------- synapse/rest/admin/users.py | 8 ++++---- .../databases/main/experimental_features.py | 7 +++---- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f9aada269a0a..aa24f7da6cae 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -17,7 +17,7 @@ import os import sys import tempfile -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Sequence from twisted.internet import defer, task @@ -57,7 +57,7 @@ from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore -from synapse.types import JsonDict, StateMap +from synapse.types import JsonMapping, StateMap from synapse.util import SYNAPSE_VERSION from synapse.util.logcontext import LoggingContext @@ -198,7 +198,7 @@
def write_knock( for event in state.values(): json.dump(event, fp=f) - def write_profile(self, profile: JsonDict) -> None: + def write_profile(self, profile: JsonMapping) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) profile_file = os.path.join(user_directory, "profile") @@ -206,7 +206,7 @@ def write_profile(self, profile: JsonDict) -> None: with open(profile_file, "a") as f: json.dump(profile, fp=f) - def write_devices(self, devices: List[JsonDict]) -> None: + def write_devices(self, devices: Sequence[JsonMapping]) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) device_file = os.path.join(user_directory, "devices") @@ -215,7 +215,7 @@ def write_devices(self, devices: List[JsonDict]) -> None: with open(device_file, "a") as f: json.dump(device, fp=f) - def write_connections(self, connections: List[JsonDict]) -> None: + def write_connections(self, connections: Sequence[JsonMapping]) -> None: user_directory = os.path.join(self.base_directory, "user_data") os.makedirs(user_directory, exist_ok=True) connection_file = os.path.join(user_directory, "connections") @@ -225,7 +225,7 @@ def write_connections(self, connections: List[JsonDict]) -> None: json.dump(connection, fp=f) def write_account_data( - self, file_name: str, account_data: Mapping[str, JsonDict] + self, file_name: str, account_data: Mapping[str, JsonMapping] ) -> None: account_data_directory = os.path.join( self.base_directory, "user_data", "account_data" @@ -237,7 +237,7 @@ def write_account_data( with open(account_data_file, "a") as f: json.dump(account_data, fp=f) - def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: file_directory = os.path.join(self.base_directory, "media_ids") os.makedirs(file_directory, exist_ok=True) media_id_file = os.path.join(file_directory, 
media_id) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 2f0e5f3b0a9e..987045019380 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,11 +14,11 @@ import abc import logging -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set from synapse.api.constants import Direction, Membership from synapse.events import EventBase -from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID +from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -35,7 +35,7 @@ def __init__(self, hs: "HomeServer"): self._state_storage_controller = self._storage_controllers.state self._msc3866_enabled = hs.config.experimental.msc3866.enabled - async def get_whois(self, user: UserID) -> JsonDict: + async def get_whois(self, user: UserID) -> JsonMapping: connections = [] sessions = await self._store.get_user_ip_and_agents(user) @@ -55,7 +55,7 @@ async def get_whois(self, user: UserID) -> JsonDict: return ret - async def get_user(self, user: UserID) -> Optional[JsonDict]: + async def get_user(self, user: UserID) -> Optional[JsonMapping]: """Function to get user details""" user_info_dict = await self._store.get_user_by_id(user.to_string()) if user_info_dict is None: @@ -349,7 +349,7 @@ def write_knock( raise NotImplementedError() @abc.abstractmethod - def write_profile(self, profile: JsonDict) -> None: + def write_profile(self, profile: JsonMapping) -> None: """Write the profile of a user. Args: @@ -358,7 +358,7 @@ def write_profile(self, profile: JsonDict) -> None: raise NotImplementedError() @abc.abstractmethod - def write_devices(self, devices: List[JsonDict]) -> None: + def write_devices(self, devices: Sequence[JsonMapping]) -> None: """Write the devices of a user. 
Args: @@ -367,7 +367,7 @@ def write_devices(self, devices: List[JsonDict]) -> None: raise NotImplementedError() @abc.abstractmethod - def write_connections(self, connections: List[JsonDict]) -> None: + def write_connections(self, connections: Sequence[JsonMapping]) -> None: """Write the connections of a user. Args: @@ -377,7 +377,7 @@ def write_connections(self, connections: List[JsonDict]) -> None: @abc.abstractmethod def write_account_data( - self, file_name: str, account_data: Mapping[str, JsonDict] + self, file_name: str, account_data: Mapping[str, JsonMapping] ) -> None: """Write the account data of a user. @@ -388,7 +388,7 @@ def write_account_data( raise NotImplementedError() @abc.abstractmethod - def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None: + def write_media_id(self, media_id: str, media_metadata: JsonMapping) -> None: """Write the media's metadata of a user. Exports only the metadata, as this can be fetched from the database via read only. In order to access the files, a connection to the correct diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 91898a5c135c..9aaa88e22987 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -39,7 +39,7 @@ from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -211,7 +211,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: await assert_requester_is_admin(self.auth, request) target_user = UserID.from_string(user_id) @@ -226,7 +226,7 @@ async def on_GET( async def on_PUT( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, 
JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester) @@ -658,7 +658,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/storage/databases/main/experimental_features.py b/synapse/storage/databases/main/experimental_features.py index cf3226ae5a70..654f924019a3 100644 --- a/synapse/storage/databases/main/experimental_features.py +++ b/synapse/storage/databases/main/experimental_features.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, FrozenSet from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main import CacheInvalidationWorkerStore -from synapse.types import StrCollection from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -34,7 +33,7 @@ def __init__( super().__init__(database, db_conn, hs) @cached() - async def list_enabled_features(self, user_id: str) -> StrCollection: + async def list_enabled_features(self, user_id: str) -> FrozenSet[str]: """ Checks to see what features are enabled for a given user Args: @@ -49,7 +48,7 @@ async def list_enabled_features(self, user_id: str) -> StrCollection: ["feature"], ) - return [feature["feature"] for feature in enabled] + return frozenset(feature["feature"] for feature in enabled) async def set_features_for_user( self, From f5fec7f8e739ae482e92a721212a8587439d47ed Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:11:34 -0400 Subject: [PATCH 11/36] Update account_data & tags methods to be immutable. 
--- synapse/handlers/sync.py | 29 +++++++++++-------- synapse/rest/client/account_data.py | 10 +++---- .../storage/databases/main/account_data.py | 12 ++++---- synapse/storage/databases/main/tags.py | 6 ++-- 4 files changed, 31 insertions(+), 26 deletions(-) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 0ccd7d250c4b..f6f001dbea77 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -57,6 +57,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, + JsonMapping, MutableStateMap, Requester, RoomStreamToken, @@ -1767,7 +1768,7 @@ async def _generate_sync_entry_for_account_data( since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - global_account_data = ( + updated_global_account_data = ( await self.store.get_updated_global_account_data_for_user( user_id, since_token.account_data_key ) @@ -1778,19 +1779,23 @@ async def _generate_sync_entry_for_account_data( ) if push_rules_changed: - global_account_data = dict(global_account_data) - global_account_data[ - AccountDataTypes.PUSH_RULES - ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) + global_account_data: JsonMapping = { + AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( + sync_config.user + ), + **updated_global_account_data, + } else: all_global_account_data = await self.store.get_global_account_data_for_user( user_id ) - global_account_data = dict(all_global_account_data) - global_account_data[ - AccountDataTypes.PUSH_RULES - ] = await self._push_rules_handler.push_rules_for_user(sync_config.user) + global_account_data = { + AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( + sync_config.user + ), + **all_global_account_data, + } account_data_for_user = ( await sync_config.filter_collection.filter_global_account_data( @@ -1894,7 +1899,7 @@ async def _generate_sync_entry_for_rooms( blocks_all_rooms or 
sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() ): - account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} + account_data_by_room: Mapping[str, Mapping[str, JsonMapping]] = {} elif since_token and not sync_result_builder.full_state: account_data_by_room = ( await self.store.get_updated_room_account_data_for_user( @@ -2334,8 +2339,8 @@ async def _generate_room_entry( sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Mapping[str, Mapping[str, Any]]], - account_data: Mapping[str, JsonDict], + tags: Optional[Mapping[str, JsonMapping]], + account_data: Mapping[str, JsonMapping], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index b1f9e9dc9ba5..ce0c4e774202 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -20,7 +20,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, RoomID +from synapse.types import JsonDict, JsonMapping, RoomID from ._base import client_patterns @@ -95,7 +95,7 @@ async def on_PUT( async def on_GET( self, request: SynapseRequest, user_id: str, account_data_type: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -106,7 +106,7 @@ async def on_GET( and account_data_type == AccountDataTypes.PUSH_RULES ): account_data: Optional[ - JsonDict + JsonMapping ] = await self._push_rules_handler.push_rules_for_user(requester.user) else: account_data = await self.store.get_global_account_data_by_type_for_user( @@ -236,7 +236,7 
@@ async def on_GET( user_id: str, room_id: str, account_data_type: str, - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -253,7 +253,7 @@ async def on_GET( self._hs.config.experimental.msc4010_push_rules_account_data and account_data_type == AccountDataTypes.PUSH_RULES ): - account_data: Optional[JsonDict] = {} + account_data: Optional[JsonMapping] = {} else: account_data = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 8f7bdbc61a7b..a29c57521d2e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,7 +43,7 @@ MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -119,7 +119,7 @@ def get_max_account_data_stream_id(self) -> int: @cached() async def get_global_account_data_for_user( self, user_id: str - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """ Get all the global client account_data for a user. @@ -164,7 +164,7 @@ def get_global_account_data_for_user( @cached() async def get_room_account_data_for_user( self, user_id: str - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """ Get all of the per-room client account_data for a user. 
@@ -213,7 +213,7 @@ def get_room_account_data_for_user_txn( @cached(num_args=2, max_entries=5000, tree=True) async def get_global_account_data_by_type_for_user( self, user_id: str, data_type: str - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """ Returns: The account data. @@ -265,7 +265,7 @@ def get_latest_stream_id_for_global_account_data_by_type_for_user_txn( @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the client account_data for a user for a room. Args: @@ -296,7 +296,7 @@ def get_account_data_for_room_txn( @cached(num_args=3, max_entries=5000, tree=True) async def get_account_data_for_room_and_type( self, user_id: str, room_id: str, account_data_type: str - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Get the client account_data of given type for a user for a room. Args: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index c149a9eacba7..61403a98cf95 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -23,7 +23,7 @@ from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.util.id_generators import AbstractStreamIdGenerator -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,7 +34,7 @@ class TagsWorkerStore(AccountDataWorkerStore): @cached() async def get_tags_for_user( self, user_id: str - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """Get all the tags for a user. 
@@ -109,7 +109,7 @@ def get_all_updated_tags_txn( async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Mapping[str, Mapping[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonMapping]]: """Get all the tags for the rooms where the tags have changed since the given version From 9a62053cbaa8fdb75b2043496e9440496e5c1986 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:19:12 -0400 Subject: [PATCH 12/36] Fix-up push related caching. --- scripts-dev/mypy_synapse_plugin.py | 10 +++++++++- synapse/push/bulk_push_rule_evaluator.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 9370ec549fba..48f683ec6c10 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -187,6 +187,11 @@ def check_is_cacheable( "builtins.bytes", } +# Types defined in Synapse which are known to be immutable. +IMMUTABLE_CUSTOM_TYPES = { + "synapse.synapse_rust.push.FilteredPushRules", +} + # Immutable containers only if the values are also immutable.
IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS = { "builtins.frozenset", @@ -223,7 +228,10 @@ def is_cacheable( return True, ("may be mutable" if verbose else None) elif isinstance(rt, Instance): - if rt.type.fullname in IMMUTABLE_VALUE_TYPES: + if ( + rt.type.fullname in IMMUTABLE_VALUE_TYPES + or rt.type.fullname in IMMUTABLE_CUSTOM_TYPES + ): return True, None elif rt.type.fullname == "typing.Mapping": diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 554634579ed0..14784312dcb7 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -131,7 +131,7 @@ def __init__(self, hs: "HomeServer"): async def _get_rules_for_event( self, event: EventBase, - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: """Get the push rules for all users who may need to be notified about the event. From cc61862aa4799eb0494b5c50011fc89188573830 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:25:56 -0400 Subject: [PATCH 13/36] Update filtering to return immutable objects. 
--- synapse/api/filtering.py | 8 ++++---- synapse/rest/client/filter.py | 4 ++-- synapse/storage/databases/main/filtering.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 0995ecbe832a..74ee8e9f3f96 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -37,7 +37,7 @@ from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase, relation_from_event -from synapse.types import JsonDict, RoomID, UserID +from synapse.types import JsonDict, JsonMapping, RoomID, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -191,7 +191,7 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None: class FilterCollection: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) @@ -219,7 +219,7 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict): def __repr__(self) -> str: return "" % (json.dumps(self._filter_json),) - def get_filter_json(self) -> JsonDict: + def get_filter_json(self) -> JsonMapping: return self._filter_json def timeline_limit(self) -> int: @@ -313,7 +313,7 @@ def blocks_all_room_timeline(self) -> bool: class Filter: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._hs = hs self._store = hs.get_datastores().main self.filter_json = filter_json diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 5da1e511a281..b5879496dbd5 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -19,7 +19,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types 
import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from ._base import client_patterns, set_timeline_upper_limit @@ -41,7 +41,7 @@ def __init__(self, hs: "HomeServer"): async def on_GET( self, request: SynapseRequest, user_id: str, filter_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 047de6283acc..7d94685caf91 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -25,7 +25,7 @@ LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -145,7 +145,7 @@ def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None: @cached(num_args=2) async def get_user_filter( self, user_id: UserID, filter_id: Union[int, str] - ) -> JsonDict: + ) -> JsonMapping: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: From a27a67f5bf68358a19b897c533f9b4890546a39c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:27:20 -0400 Subject: [PATCH 14/36] Update relations with immutable. --- synapse/handlers/relations.py | 14 ++++++++++++-- synapse/storage/databases/main/relations.py | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 74442b1eb917..f950e33dfaac 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,7 +13,17 @@ # limitations under the License. 
import enum import logging -from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Mapping, + Optional, + Sequence, +) import attr @@ -245,7 +255,7 @@ async def redact_events_related_to( async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[_RelatedEvent]]: + ) -> Mapping[str, Sequence[_RelatedEvent]]: """Get a list of references to the given events. Args: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index ef08192fdb81..b2e3b4dc1ad8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -465,7 +465,7 @@ async def get_references_for_event(self, event_id: str) -> List[JsonDict]: @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") async def get_references_for_events( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]: """Get a list of references to the given events. Args: @@ -933,7 +933,7 @@ async def get_threads( room_id: str, limit: int = 5, from_token: Optional[ThreadsNextBatch] = None, - ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]: """Get a list of thread IDs, ordered by topological ordering of their latest reply. From 8f7f4d7d0f8fcd681f618f01cf6e86b46f529af2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:30:46 -0400 Subject: [PATCH 15/36] Update receipts code. 
--- synapse/handlers/initial_sync.py | 3 ++- synapse/handlers/receipts.py | 13 +++++++------ synapse/storage/databases/main/receipts.py | 14 +++++++------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 5dc76ef588f7..5737f8014dd3 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -32,6 +32,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + JsonMapping, Requester, RoomStreamToken, StreamKeyType, @@ -454,7 +455,7 @@ async def get_presence() -> List[JsonDict]: for s in states ] - async def get_receipts() -> List[JsonDict]: + async def get_receipts() -> List[JsonMapping]: receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2bacdebfb5f9..152285ad3e17 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -19,6 +19,7 @@ from synapse.streams import EventSource from synapse.types import ( JsonDict, + JsonMapping, ReadReceipt, StreamKeyType, UserID, @@ -182,15 +183,15 @@ async def received_client_receipt( await self.federation_sender.send_read_receipt(receipt) -class ReceiptEventSource(EventSource[int, JsonDict]): +class ReceiptEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.config = hs.config @staticmethod def filter_out_private_receipts( - rooms: Sequence[JsonDict], user_id: str - ) -> List[JsonDict]: + rooms: Sequence[JsonMapping], user_id: str + ) -> List[JsonMapping]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) and removes private read receipts of other users. @@ -207,7 +208,7 @@ def filter_out_private_receipts( The same as rooms, but filtered. 
""" - result = [] + result: List[JsonMapping] = [] # Iterate through each room's receipt content. for room in rooms: @@ -260,7 +261,7 @@ async def get_new_events( room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -279,7 +280,7 @@ async def get_new_events( async def get_new_events_as( self, from_key: int, to_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new read receipt events that an appservice may be interested in. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index d3fb0ef6a02f..c19fdd5190b7 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -43,7 +43,7 @@ MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -218,7 +218,7 @@ async def get_receipts_for_user_with_orderings( @cached() async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str - ) -> JsonDict: + ) -> JsonMapping: """ Fetch receipts for all rooms that the given user is joined to. @@ -258,7 +258,7 @@ def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: async def get_linearized_receipts_for_rooms( self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. 
Args: @@ -287,7 +287,7 @@ async def get_linearized_receipts_for_rooms( async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. Args: @@ -310,7 +310,7 @@ async def get_linearized_receipts_for_room( @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -353,7 +353,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Mapping[str, Sequence[JsonDict]]: + ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -415,7 +415,7 @@ def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. From 9fdc5a1dbfc43e49dcbed724feb62a67c162a92a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:34:00 -0400 Subject: [PATCH 16/36] Update e2e keys & devices. 
--- synapse/handlers/e2e_keys.py | 24 +++++++------------ synapse/handlers/sync.py | 4 ++-- synapse/storage/databases/main/devices.py | 21 +++++++++++----- .../storage/databases/main/end_to_end_keys.py | 22 +++++++++-------- 4 files changed, 38 insertions(+), 33 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index ad075497c8b8..8c6432035d1c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( JsonDict, + JsonMapping, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -272,11 +273,7 @@ async def _query( delay_cancellation=True, ) - ret = {"device_keys": results, "failures": failures} - - ret.update(cross_signing_keys) - - return ret + return {"device_keys": results, "failures": failures, **cross_signing_keys} @trace async def _query_devices_for_destination( @@ -408,7 +405,7 @@ async def _query_devices_for_destination( @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: """Get cross-signing keys for users from the database Args: @@ -551,16 +548,13 @@ async def on_federation_query_client_keys( self.config.federation.allow_device_name_lookup_over_federation ), ) - ret = {"device_keys": res} # add in the cross-signing keys cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) - ret.update(cross_signing_keys) - - return ret + return 
{"device_keys": res, **cross_signing_keys} async def claim_local_one_time_keys( self, @@ -1127,7 +1121,7 @@ def _check_master_key_signature( user_id: str, master_key_id: str, signed_master_key: JsonDict, - stored_master_key: JsonDict, + stored_master_key: JsonMapping, devices: Dict[str, Dict[str, JsonDict]], ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. @@ -1278,7 +1272,7 @@ async def _process_other_signatures( async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Tuple[JsonDict, str, VerifyKey]: + ) -> Tuple[JsonMapping, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -1333,7 +1327,7 @@ async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, - ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: + ) -> Optional[Tuple[JsonMapping, str, VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1474,7 +1468,7 @@ def _check_device_signature( user_id: str, verify_key: VerifyKey, signed_device: JsonDict, - stored_device: JsonDict, + stored_device: JsonMapping, ) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. 
Throws an diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f6f001dbea77..b56b1d15ab5e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -235,7 +235,7 @@ class SyncResult: archived: List[ArchivedSyncResult] to_device: List[JsonDict] device_lists: DeviceListUpdates - device_one_time_keys_count: JsonDict + device_one_time_keys_count: JsonMapping device_unused_fallback_key_types: List[str] def __bool__(self) -> bool: @@ -1543,7 +1543,7 @@ async def generate_sync_result( logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_keys_count: JsonDict = {} + one_time_keys_count: JsonMapping = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d3d8b866665e..df596f35f9b3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -55,7 +55,12 @@ AbstractStreamIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key +from synapse.types import ( + JsonDict, + JsonMapping, + StrCollection, + get_verify_key_from_cross_signing_key, +) from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -746,7 +751,7 @@ def _add_user_signature_change_txn( @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]: """Get the devices (and keys if any) for remote users from the cache. 
Args: @@ -766,13 +771,13 @@ async def get_user_devices_from_cache( user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Mapping[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonMapping]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. - device_specific_results: Dict[str, Dict[str, JsonDict]] = {} + device_specific_results: Dict[str, Dict[str, JsonMapping]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) @@ -801,7 +806,9 @@ async def get_users_whose_devices_are_cached( return user_ids_in_cache @cached(num_args=2, tree=True) - async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + async def _get_cached_user_device( + self, user_id: str, device_id: str + ) -> JsonMapping: content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -811,7 +818,9 @@ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDic return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: + async def get_cached_devices_for_user( + self, user_id: str + ) -> Mapping[str, JsonMapping]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 835cb37b338c..3d00afd89773 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -52,7 +52,7 @@ from 
synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable @@ -125,7 +125,7 @@ def process_replication_rows( async def get_e2e_device_keys_for_federation_query( self, user_id: str - ) -> Tuple[int, List[JsonDict]]: + ) -> Tuple[int, Sequence[JsonMapping]]: """Get all devices (with any device keys) for a user Returns: @@ -174,7 +174,7 @@ async def get_e2e_device_keys_for_federation_query( @cached(iterable=True) async def _get_e2e_device_keys_for_federation_query_inner( self, user_id: str - ) -> List[JsonDict]: + ) -> Sequence[JsonMapping]: """Get all devices (with any device keys) for a user""" devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) @@ -578,7 +578,7 @@ def _add_e2e_one_time_keys_txn( @cached(max_entries=10000) async def count_e2e_one_time_keys( self, user_id: str, device_id: str - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Count the number of one time keys the server has for a device Returns: A mapping from algorithm to number of keys for that algorithm. @@ -812,7 +812,7 @@ async def get_e2e_unused_fallback_key_types( async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Returns a user's cross-signing key. Args: @@ -833,7 +833,9 @@ async def get_e2e_cross_signing_key( return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys( + self, user_id: str + ) -> Mapping[str, JsonMapping]: """Dummy function. 
Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -846,7 +848,7 @@ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDic ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Mapping[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -867,7 +869,7 @@ async def _get_bare_e2e_cross_signing_keys_bulk( ) # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) + return cast(Dict[str, Optional[Mapping[str, JsonMapping]]], result) def _get_bare_e2e_cross_signing_keys_bulk_txn( self, @@ -1026,7 +1028,7 @@ def _get_e2e_cross_signing_signatures_txn( @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. Args: @@ -1043,7 +1045,7 @@ async def get_e2e_cross_signing_keys_bulk( if from_user_id: result = cast( - Dict[str, Optional[Mapping[str, JsonDict]]], + Dict[str, Optional[Mapping[str, JsonMapping]]], await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, From d8cce5bf1717c5d7b0cf986d73761c2748b028b4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:35:02 -0400 Subject: [PATCH 17/36] Update appservice stuff. 
--- synapse/appservice/__init__.py | 6 +++--- synapse/appservice/api.py | 6 +++--- synapse/appservice/scheduler.py | 18 +++++++++--------- synapse/handlers/appservice.py | 9 +++++---- synapse/handlers/typing.py | 17 ++++++++++++----- synapse/storage/databases/main/appservice.py | 6 +++--- 6 files changed, 35 insertions(+), 27 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 2260a8f589b3..6f4aa53c93bd 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import DeviceListUpdates, JsonDict, UserID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -379,8 +379,8 @@ def __init__( service: ApplicationService, id: int, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index b1523be208e9..c42e1f11aa91 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -41,7 +41,7 @@ from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.logging import opentracing -from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -306,8 +306,8 @@ async def push_bulk( self, service: "ApplicationService", events: Sequence[EventBase], - ephemeral: List[JsonDict], - 
to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 3a319b0d42d9..c08eeb71a7cb 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -73,7 +73,7 @@ from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import Clock if TYPE_CHECKING: @@ -121,8 +121,8 @@ def enqueue_for_appservice( self, appservice: ApplicationService, events: Optional[Collection[EventBase]] = None, - ephemeral: Optional[Collection[JsonDict]] = None, - to_device_messages: Optional[Collection[JsonDict]] = None, + ephemeral: Optional[Collection[JsonMapping]] = None, + to_device_messages: Optional[Collection[JsonMapping]] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ @@ -180,9 +180,9 @@ def __init__( # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} - self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + self.queued_ephemeral: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [to_device_message_json]} - self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [device_list_summary]} self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} @@ -295,8 +295,8 @@ async def _compute_msc3202_otk_counts_and_fallback_keys( self, service: ApplicationService, events: Iterable[EventBase], - ephemerals: Iterable[JsonDict], - 
to_device_messages: Iterable[JsonDict], + ephemerals: Iterable[JsonMapping], + to_device_messages: Iterable[JsonMapping], ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, @@ -366,8 +366,8 @@ async def send( self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: Optional[List[JsonDict]] = None, - to_device_messages: Optional[List[JsonDict]] = None, + ephemeral: Optional[List[JsonMapping]] = None, + to_device_messages: Optional[List[JsonMapping]] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 6429545c98d5..7de7bd3289c8 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -46,6 +46,7 @@ from synapse.types import ( DeviceListUpdates, JsonDict, + JsonMapping, RoomAlias, RoomStreamToken, StreamKeyType, @@ -397,7 +398,7 @@ async def _notify_interested_services_ephemeral( async def _handle_typing( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the typing events since the given stream token that the given application service should receive. @@ -432,7 +433,7 @@ async def _handle_typing( async def _handle_receipts( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest read receipts that the given application service should receive. @@ -471,7 +472,7 @@ async def _handle_presence( service: ApplicationService, users: Collection[Union[str, UserID]], new_token: Optional[int], - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest presence updates that the given application service should receive. 
@@ -491,7 +492,7 @@ async def _handle_presence( A list of json dictionaries containing data derived from the presence events that should be sent to the given application service. """ - events: List[JsonDict] = [] + events: List[JsonMapping] = [] presence_source = self.event_sources.sources.presence from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 4b4227003d0a..bdefa7f26f2e 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -26,7 +26,14 @@ ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + StrCollection, + StreamKeyType, + UserID, +) from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.retryutils import filter_destinations_by_retry_limiter @@ -487,7 +494,7 @@ def process_replication_rows( raise Exception("Typing writer instance got typing info over replication") -class TypingNotificationEventSource(EventSource[int, JsonDict]): +class TypingNotificationEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main self.clock = hs.get_clock() @@ -497,7 +504,7 @@ def __init__(self, hs: "HomeServer"): # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id: str) -> JsonDict: + def _make_event_for(self, room_id: str) -> JsonMapping: typing = self.get_typing_handler()._room_typing[room_id] return { "type": EduTypes.TYPING, @@ -507,7 +514,7 @@ def _make_event_for(self, room_id: str) -> JsonDict: async def get_new_events_as( self, from_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new typing 
events that an appservice may be interested in. @@ -551,7 +558,7 @@ async def get_new_events( room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 484db175d090..0553a0621ace 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -45,7 +45,7 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -268,8 +268,8 @@ async def create_appservice_txn( self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, From ef61f3d2432319400708c257f52e06f553483237 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 11:46:13 -0400 Subject: [PATCH 18/36] Ensure current hosts is immutable. 
--- synapse/federation/federation_client.py | 4 ++-- synapse/handlers/federation_event.py | 2 +- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/roommember.py | 10 ++++++---- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 607013f121bf..c8bc46415d9d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -64,7 +64,7 @@ from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -1704,7 +1704,7 @@ async def send_request( async def timestamp_to_event( self, *, - destinations: List[str], + destinations: StrCollection, room_id: str, timestamp: int, direction: Direction, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index d32d224d5640..7691b31f5672 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1539,7 +1539,7 @@ async def _resync_device(self, sender: str) -> None: logger.exception("Failed to resync device for %s", sender) async def backfill_event_id( - self, destinations: List[str], room_id: str, event_id: str + self, destinations: StrCollection, room_id: str, event_id: str ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 33f04e7572c9..1854e2c447cc 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -582,7 +582,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: @trace @tag_args - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 87932b9827c0..6ec6c726f0a0 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -984,7 +984,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: ) @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """ Get current hosts in room based on current state. @@ -1013,12 +1013,14 @@ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: # `get_users_in_room` rather than funky SQL. domains = await self.get_current_hosts_in_room(room_id) - return list(domains) + return tuple(domains) # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: + def get_current_hosts_in_room_ordered_txn( + txn: LoggingTransaction, + ) -> Tuple[str, ...]: # Returns a list of servers currently joined in the room sorted by # longest in the room first (aka. with the lowest depth). 
The # heuristic of sorting by servers who have been in the room the @@ -1043,7 +1045,7 @@ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return [d for d, in txn if d is not None] + return tuple(d for d, in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn From 96faa341e251dc333514edc58f526ab697018d24 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Sep 2023 12:39:22 -0400 Subject: [PATCH 19/36] Properly check attrs for frozen-ness. --- scripts-dev/mypy_synapse_plugin.py | 17 ++++++++++++++--- synapse/federation/federation_server.py | 2 +- synapse/handlers/sync.py | 4 ++-- .../databases/main/event_push_actions.py | 5 +++-- synapse/storage/roommember.py | 1 + tests/rest/admin/test_server_notice.py | 2 +- tests/rest/client/test_shadow_banned.py | 2 +- 7 files changed, 23 insertions(+), 10 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 48f683ec6c10..8daf303ee30a 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -21,7 +21,7 @@ import mypy.types from mypy.erasetype import remove_instance_last_known_values from mypy.errorcodes import ErrorCode -from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import ARG_NAMED_OPT, Var from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self from mypy.types import ( @@ -190,6 +190,8 @@ def check_is_cacheable( # Types defined in Synapse which are known to be immutable. IMMUTABLE_CUSTOM_TYPES = { "synapse.synapse_rust.push.FilteredPushRules", + # This is technically not immutable, but close enough. + "signedjson.types.VerifyKey", } # Immutable containers only if the values are also immutable. 
@@ -198,6 +200,7 @@ def check_is_cacheable( "builtins.tuple", "typing.AbstractSet", "typing.Sequence", + "immutabledict.immutabledict", } MUTABLE_CONTAINER_TYPES = { @@ -245,9 +248,17 @@ def is_cacheable( return False, None elif "attrs" in rt.type.metadata: - frozen = rt.type.metadata["attrs"].get("frozen", False) + frozen = rt.type.metadata["attrs"]["frozen"] if frozen: - # TODO: should really check that all of the fields are also cacheable + for attribute in rt.type.metadata["attrs"]["attributes"]: + attribute_name = attribute["name"] + symbol_node = rt.type.names[attribute_name].node + assert isinstance(symbol_node, Var) + assert symbol_node.type is not None + ok, note = is_cacheable(symbol_node.type, signature, verbose) + if not ok: + return False, f"non-frozen attrs property: {attribute_name}" + # All attributes were frozen. return True, None else: return False, "non-frozen attrs class" diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index f9915e5a3f05..18dc0dab4739 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -710,7 +710,7 @@ async def on_send_join_request( state_event_ids: Collection[str] servers_in_room: Optional[Collection[str]] if caller_supports_partial_state: - summary = await self.store.get_room_summary(room_id) + summary = await self.store.get_room_summary(room_id) # type: ignore[synapse-@cached-mutable] state_event_ids = _get_event_ids_for_partial_state_join( event, prev_state_ids, summary ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b56b1d15ab5e..a63b1ce08ff4 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -796,7 +796,7 @@ async def compute_summary( ) # this is heavily cached, thus: fast. 
- details = await self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) # type: ignore[synapse-@cached-mutable] name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -1321,7 +1321,7 @@ async def unread_notifs_for_room_id( return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( + return await self.store.get_unread_event_push_actions_by_room_for_user( # type: ignore[synapse-@cached-mutable] room_id, sync_config.user.to_string(), ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b958a39aebb1..b7dd5787254c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -182,6 +182,7 @@ class UserPushAction(EmailPushAction): profile_tag: str +# TODO This is used as a cached value and is mutable. @attr.s(slots=True, auto_attribs=True) class NotifCounts: """ @@ -193,7 +194,7 @@ class NotifCounts: highlight_count: int = 0 -@attr.s(slots=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class RoomNotifCounts: """ The per-user, per-room count of notifications. Used by sync and push. @@ -201,7 +202,7 @@ class RoomNotifCounts: main_timeline: NotifCounts # Map of thread ID to the notification counts. - threads: Dict[str, NotifCounts] + threads: Mapping[str, NotifCounts] @staticmethod def empty() -> "RoomNotifCounts": diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2500381b7b85..cbfb32014c79 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -45,6 +45,7 @@ class ProfileInfo: display_name: Optional[str] +# TODO This is used as a cached value and is mutable. 
@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class MemberSummary: # A truncated list of (user_id, event_id) tuples for users of a given diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 28b999573e75..b41bbcd46e1e 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -424,7 +424,7 @@ def test_send_server_notice_delete_room(self) -> None: # It doesn't really matter what API we use here, we just want to assert # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(first_room_id)) + summary = self.get_success(self.store.get_room_summary(first_room_id)) # type: ignore[synapse-@cached-mutable] # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 9aecf88e4160..bd0ba3d057f5 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -182,7 +182,7 @@ def test_upgrade(self) -> None: # It doesn't really matter what API we use here, we just want to assert # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(new_room_id)) + summary = self.get_success(self.store.get_room_summary(new_room_id)) # type: ignore[synapse-@cached-mutable] # The summary should be empty since the room doesn't exist. 
self.assertEqual(summary, {}) From 2f759294f9091164c1b942d375035da929ab9565 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Sep 2023 10:20:20 -0400 Subject: [PATCH 20/36] Kick CI --- .github/workflows/tests.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fb117380d028..1b7948758f87 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -109,13 +109,13 @@ jobs: # Cribbed from # https://github.com/AustinScola/mypy-cache-github-action/blob/85ea4f2972abed39b33bd02c36e341b28ca59213/src/restore.ts#L10-L17 - - name: Restore/persist mypy's cache - uses: actions/cache@v3 - with: - path: | - .mypy_cache - key: mypy-cache-${{ github.context.sha }} - restore-keys: mypy-cache- +# - name: Restore/persist mypy's cache +# uses: actions/cache@v3 +# with: +# path: | +# .mypy_cache +# key: mypy-cache-${{ github.context.sha }} +# restore-keys: mypy-cache- - name: Run mypy run: poetry run mypy From 8b1b15b5ffcdf4499a55dcfbf9aa4cb82633fe38 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Sep 2023 10:58:14 -0400 Subject: [PATCH 21/36] Fix-up return value of get_latest_event_ids_in_room. 
--- synapse/handlers/admin.py | 8 +----- synapse/handlers/federation_event.py | 8 +++--- synapse/storage/controllers/persist_events.py | 9 +++---- .../databases/main/event_federation.py | 8 +++--- synapse/storage/databases/main/events.py | 2 +- tests/replication/tcp/streams/test_events.py | 10 +++---- tests/storage/test_cleanup_extrems.py | 14 +++++----- tests/test_federation.py | 26 ++++++++++++------- 8 files changed, 43 insertions(+), 42 deletions(-) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index f06ad81ab670..ba9704a065c5 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -18,13 +18,7 @@ from synapse.api.constants import Direction, Membership from synapse.events import EventBase -from synapse.types import ( - JsonMapping, - RoomStreamToken, - StateMap, - UserID, - UserInfo, -) +from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo from synapse.visibility import filter_events_for_client if TYPE_CHECKING: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7691b31f5672..7c62cdfaef5f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -723,12 +723,11 @@ async def _get_missing_events_for_pdu( if not prevs - seen: return - latest_list = await self._store.get_latest_event_ids_in_room(room_id) + latest_frozen = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen + latest = seen | latest_frozen logger.info( "Requesting missing events between %s and %s", @@ -1976,8 +1975,7 @@ async def _check_for_soft_fail( # partial and full state and may not be accurate. 
return - extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) + extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id) prev_event_ids = set(event.prev_event_ids()) if extrem_ids == prev_event_ids: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 6864f9309020..f39ae2d63536 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -19,6 +19,7 @@ from collections import deque from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -618,7 +619,7 @@ async def _persist_event_batch( ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = set( + latest_event_ids = ( await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( @@ -740,7 +741,7 @@ async def _calculate_new_extremities( self, room_id: str, event_contexts: List[Tuple[EventBase, EventContext]], - latest_event_ids: Collection[str], + latest_event_ids: AbstractSet[str], ) -> Set[str]: """Calculates the new forward extremities for a room given events to persist. 
@@ -758,8 +759,6 @@ async def _calculate_new_extremities( and not event.internal_metadata.is_soft_failed() ] - latest_event_ids = set(latest_event_ids) - # start with the existing forward extremities result = set(latest_event_ids) @@ -798,7 +797,7 @@ async def _get_new_state_after_events( self, room_id: str, events_context: List[Tuple[EventBase, EventContext]], - old_latest_event_ids: Set[str], + old_latest_event_ids: AbstractSet[str], new_latest_event_ids: Set[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]: """Calculate the current state dict after adding some new events to diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 09de8f55e277..afffa549853d 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -19,6 +19,7 @@ TYPE_CHECKING, Collection, Dict, + FrozenSet, Iterable, List, Optional, @@ -47,7 +48,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine, Sqlite3Engine -from synapse.types import JsonDict, StrCollection, StrSequence +from synapse.types import JsonDict, StrCollection from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -1179,13 +1180,14 @@ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]: ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence: - return await self.db_pool.simple_select_onecol( + async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]: + event_ids = await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", desc="get_latest_event_ids_in_room", ) + 
return frozenset(event_ids) async def get_min_depth(self, room_id: str) -> Optional[int]: """For the given room, get the minimum depth we have seen for it.""" diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0c1ed752406f..bc8474a5897e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -222,7 +222,7 @@ async def _persist_events_and_state_updates( for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( - (room_id,), list(latest_event_ids) + (room_id,), frozenset(latest_event_ids) ) async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 65ef4bb16055..128fc3e0460c 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -139,7 +139,7 @@ def test_update_function_huge_state_change(self) -> None: ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -294,7 +294,7 @@ def test_update_function_state_row_limit(self) -> None: ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -316,14 +316,14 @@ def test_update_function_state_row_limit(self) -> None: self.test_handler.received_rdata_rows.clear() # now roll back all that state by de-modding the users - prev_events = fork_point + prev_events = list(fork_point) pl_events = [] for u in user_ids: pls["users"][u] = 0 e = self.get_success( inject_event( self.hs, - prev_event_ids=list(prev_events), + prev_event_ids=prev_events, type=EventTypes.PowerLevels, state_key="", sender=self.user_id, diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 7de109966d61..ceb9597dd312 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -120,7 +120,7 @@ def test_soft_failed_extremities_handled_correctly(self) -> None: self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_4]) + self.assertEqual(latest_event_ids, {event_id_4}) def test_basic_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -147,7 +147,7 @@ def test_basic_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + 
self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -155,7 +155,7 @@ def test_basic_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_chain_of_fail_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -185,7 +185,7 @@ def test_chain_of_fail_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -193,7 +193,7 @@ def test_chain_of_fail_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_forked_graph_cleanup(self) -> None: r"""Test that extremities are correctly calculated in the presence of @@ -240,7 +240,7 @@ def test_forked_graph_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -248,7 +248,7 @@ def test_forked_graph_cleanup(self) -> None: latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_b, event_id_c}) class 
CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/test_federation.py b/tests/test_federation.py index f8ade6da3852..1b0504709edc 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -51,9 +51,15 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = self.hs.get_datastores().main # Figure out what the most recent event is - most_recent = self.get_success( - self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) + ) + ) + ) join_event = make_event_from_dict( { @@ -100,8 +106,8 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch( # Make sure we actually joined the room self.assertEqual( - self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], - "$join:test.serv", + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), + {"$join:test.serv"}, ) def test_cant_hide_direct_ancestors(self) -> None: @@ -127,9 +133,11 @@ async def post_json( self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + ) + ) # Now lie about an event lying_event = make_event_from_dict( @@ -165,7 +173,7 @@ async def post_json( # Make sure the invalid event isn't there extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) - self.assertEqual(extrem[0], "$join:test.serv") + self.assertEqual(extrem, {"$join:test.serv"}) def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and From 451c9b169f54240e5b73cfe46432a51f672b80f9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Sep 2023 10:58:21 
-0400 Subject: [PATCH 22/36] Revert "Kick CI" This reverts commit 2f759294f9091164c1b942d375035da929ab9565. --- .github/workflows/tests.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1b7948758f87..fb117380d028 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -109,13 +109,13 @@ jobs: # Cribbed from # https://github.com/AustinScola/mypy-cache-github-action/blob/85ea4f2972abed39b33bd02c36e341b28ca59213/src/restore.ts#L10-L17 -# - name: Restore/persist mypy's cache -# uses: actions/cache@v3 -# with: -# path: | -# .mypy_cache -# key: mypy-cache-${{ github.context.sha }} -# restore-keys: mypy-cache- + - name: Restore/persist mypy's cache + uses: actions/cache@v3 + with: + path: | + .mypy_cache + key: mypy-cache-${{ github.context.sha }} + restore-keys: mypy-cache- - name: Run mypy run: poetry run mypy From d52e30ccd3bb04312c8e934421adbde97b9bc217 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Sep 2023 11:25:45 -0400 Subject: [PATCH 23/36] Newsfragment --- changelog.d/14911.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/14911.misc diff --git a/changelog.d/14911.misc b/changelog.d/14911.misc new file mode 100644 index 000000000000..93ceaeafc9b9 --- /dev/null +++ b/changelog.d/14911.misc @@ -0,0 +1 @@ +Improve type hints. From 47b7ba730d907f91a266951fd9f52ff7fc402297 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Sep 2023 11:56:32 -0400 Subject: [PATCH 24/36] FIx-up sync changes. 
--- synapse/handlers/sync.py | 6 +++--- synapse/storage/databases/main/account_data.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 63bebbc4f024..dc5cb08de849 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1783,7 +1783,7 @@ async def _generate_sync_entry_for_account_data( since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - updated_global_account_data = ( + global_account_data = ( await self.store.get_updated_global_account_data_for_user( user_id, since_token.account_data_key ) @@ -1794,11 +1794,11 @@ async def _generate_sync_entry_for_account_data( ) if push_rules_changed: - global_account_data: JsonMapping = { + global_account_data = { AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user( sync_config.user ), - **updated_global_account_data, + **global_account_data, } else: all_global_account_data = await self.store.get_global_account_data_for_user( diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index a29c57521d2e..80f146dd530a 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -394,7 +394,7 @@ def get_updated_room_account_data_txn( async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get all the global account_data that's changed for a user. 
Args: From 745ad617188a69fd19336a93053ac8332d9b5581 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Sep 2023 16:10:14 +0100 Subject: [PATCH 25/36] Correct context --- scripts-dev/mypy_synapse_plugin.py | 60 +++++++++++++++++++++++----- synapse/util/caches/descriptors.py | 64 +++++++++++++++++++++++++----- 2 files changed, 104 insertions(+), 20 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 8daf303ee30a..6fbb88fef2ef 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -16,13 +16,13 @@ can crop up, e.g the cache descriptors. """ -from typing import Callable, Optional, Tuple, Type +from typing import Callable, Optional, Tuple, Type, Union import mypy.types from mypy.erasetype import remove_instance_last_known_values from mypy.errorcodes import ErrorCode -from mypy.nodes import ARG_NAMED_OPT, Var -from mypy.plugin import MethodSigContext, Plugin +from mypy.nodes import ARG_NAMED_OPT, TempNode, Var +from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin from mypy.typeops import bind_self from mypy.types import ( AnyType, @@ -47,6 +47,12 @@ def get_method_signature_hook( ) ): return cached_function_method_signature + + if fullname in ( + "synapse.util.caches.descriptors._CachedFunctionDescriptor.__call__", + ): + return check_is_cacheable_wrapper + return None @@ -142,21 +148,53 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: ret_type=ret_type, ) - # 5. 
Complain loudly if we are returning something mutable - check_is_cacheable(signature, ctx, ret_type) + return signature + + +def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: + signature: CallableType = ctx.default_signature + + if not isinstance(ctx.args[0][0], TempNode): + ctx.api.note("Cached function is not a TempNode?!", ctx.context) # type: ignore[attr-defined] + return signature + + orig_sig = ctx.args[0][0].type + if not isinstance(orig_sig, CallableType): + ctx.api.fail("Cached 'function' is not a callable", ctx.context) + return signature + + ret_arg = None + if isinstance(orig_sig.ret_type, Instance): + # If a coroutine, wrap the coroutine's return type in a Deferred. + if orig_sig.ret_type.type.fullname == "typing.Coroutine": + ret_arg = orig_sig.ret_type.args[2] + + # If an awaitable, wrap the awaitable's final value in a Deferred. + elif orig_sig.ret_type.type.fullname == "typing.Awaitable": + ret_arg = orig_sig.ret_type.args[0] + + elif orig_sig.ret_type.type.fullname == "twisted.internet.defer.Deferred": + ret_arg = orig_sig.ret_type.args[0] + + if ret_arg is None: + ret_arg = orig_sig.ret_type + + assert ret_arg is not None + + check_is_cacheable( + orig_sig, + ctx, + ret_arg, + ) return signature def check_is_cacheable( signature: CallableType, - ctx: MethodSigContext, - deferred_return_type: Instance, + ctx: Union[MethodSigContext, FunctionSigContext], + return_type: mypy.types.Type, ) -> None: - # The previous code wraps the return type into a Deferred. - assert deferred_return_type.type.fullname == "twisted.internet.defer.Deferred" - return_type = deferred_return_type.args[0] - verbose = ctx.api.options.verbosity >= 1 # TODO Technically a cachedList only needs immutable values, but forcing them # to return Mapping instead of Dict is fine. 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 8514a75a1c2f..e397c60b0a1b 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -36,6 +36,8 @@ ) from weakref import WeakValueDictionary +import attr + from twisted.internet import defer from twisted.python.failure import Failure @@ -466,6 +468,35 @@ def get_instance( ) +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedFunctionDescriptor(Generic[F]): + """Helper for `@cached`, we name it so that we can hook into it with mypy + plugin.""" + + max_entries: int + num_args: Optional[int] + uncached_args: Optional[Collection[str]] + tree: bool + cache_context: bool + iterable: bool + prune_unread_entries: bool + name: Optional[str] + + def __call__(self, orig: F) -> CachedFunction[F]: + d = DeferredCacheDescriptor( + orig, + max_entries=self.max_entries, + num_args=self.num_args, + uncached_args=self.uncached_args, + tree=self.tree, + cache_context=self.cache_context, + iterable=self.iterable, + prune_unread_entries=self.prune_unread_entries, + name=self.name, + ) + return cast(CachedFunction[F], d) + + def cached( *, max_entries: int = 1000, @@ -476,9 +507,8 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: - func = lambda orig: DeferredCacheDescriptor( - orig, +) -> _CachedFunctionDescriptor: + return _CachedFunctionDescriptor( max_entries=max_entries, num_args=num_args, uncached_args=uncached_args, @@ -489,7 +519,26 @@ def cached( name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class _CachedListFunctionDescriptor(Generic[F]): + """Helper for `@cachedList`, we name it so that we can hook into it with mypy + plugin.""" + + cached_method_name: str + list_name: str + num_args: Optional[int] = None + name: Optional[str] = None + + def __call__(self, orig: F) -> 
CachedFunction[F]: + d = DeferredCacheListDescriptor( + orig, + cached_method_name=self.cached_method_name, + list_name=self.list_name, + num_args=self.num_args, + name=self.name, + ) + return cast(CachedFunction[F], d) def cachedList( @@ -498,7 +547,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], CachedFunction[F]]: +) -> _CachedListFunctionDescriptor: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -527,16 +576,13 @@ def do_something(self, first_arg, second_arg): def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: DeferredCacheListDescriptor( - orig, + return _CachedListFunctionDescriptor( cached_method_name=cached_method_name, list_name=list_name, num_args=num_args, name=name, ) - return cast(Callable[[F], CachedFunction[F]], func) - def _get_cache_key_builder( param_names: Sequence[str], From 03b0e407a914f9c27ecac6d9f7f033e7ca298686 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 11:32:35 -0400 Subject: [PATCH 26/36] Remove ignores for call-sites. 
--- synapse/federation/federation_server.py | 2 +- synapse/handlers/relations.py | 4 ++-- synapse/handlers/sync.py | 4 ++-- synapse/storage/controllers/state.py | 2 +- tests/rest/admin/test_server_notice.py | 2 +- tests/rest/client/test_shadow_banned.py | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 18dc0dab4739..f9915e5a3f05 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -710,7 +710,7 @@ async def on_send_join_request( state_event_ids: Collection[str] servers_in_room: Optional[Collection[str]] if caller_supports_partial_state: - summary = await self.store.get_room_summary(room_id) # type: ignore[synapse-@cached-mutable] + summary = await self.store.get_room_summary(room_id) state_event_ids = _get_event_ids_for_partial_state_join( event, prev_state_ids, summary ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index f950e33dfaac..9b13448cdd7a 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -320,7 +320,7 @@ async def _get_threads_for_events( event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] # Fetch thread summaries. - summaries = await self._main_store.get_thread_summaries(event_ids) # type: ignore[synapse-@cached-mutable] + summaries = await self._main_store.get_thread_summaries(event_ids) # Limit fetching whether the requester has participated in a thread to # events which are thread roots. @@ -514,7 +514,7 @@ async def _fetch_edits() -> None: Note that there is no use in limiting edits by ignored users since the parent event should be ignored in the first place if the user is ignored. 
""" - edits = await self._main_store.get_applicable_edits( # type: ignore[synapse-@cached-mutable] + edits = await self._main_store.get_applicable_edits( [ event_id for event_id, event in events_by_id.items() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dc5cb08de849..7bd42f635fd0 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -811,7 +811,7 @@ async def compute_summary( ) # this is heavily cached, thus: fast. - details = await self.store.get_room_summary(room_id) # type: ignore[synapse-@cached-mutable] + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -1336,7 +1336,7 @@ async def unread_notifs_for_room_id( return RoomNotifCounts.empty() with Measure(self.clock, "unread_notifs_for_room_id"): - return await self.store.get_unread_event_push_actions_by_room_for_user( # type: ignore[synapse-@cached-mutable] + return await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), ) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 1854e2c447cc..10d219c0452e 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -682,7 +682,7 @@ async def _get_joined_hosts( # `get_joined_hosts` is called with the "current" state group for the # room, and so consecutive calls will be for consecutive state groups # which point to the previous state group. - cache = await self.stores.main._get_joined_hosts_cache(room_id) # type: ignore[synapse-@cached-mutable] + cache = await self.stores.main._get_joined_hosts_cache(room_id) # If the state group in the cache matches, we already have the data we need. 
if state_entry.state_group == cache.state_group: diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 647f754c6daf..dfd14f5751bf 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -438,7 +438,7 @@ def test_send_server_notice_delete_room(self) -> None: # It doesn't really matter what API we use here, we just want to assert # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(first_room_id)) # type: ignore[synapse-@cached-mutable] + summary = self.get_success(self.store.get_room_summary(first_room_id)) # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index bd0ba3d057f5..9aecf88e4160 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -182,7 +182,7 @@ def test_upgrade(self) -> None: # It doesn't really matter what API we use here, we just want to assert # that the room doesn't exist. - summary = self.get_success(self.store.get_room_summary(new_room_id)) # type: ignore[synapse-@cached-mutable] + summary = self.get_success(self.store.get_room_summary(new_room_id)) # The summary should be empty since the room doesn't exist. self.assertEqual(summary, {}) From 460ed3cc666ee6111f2530ac85f7e1308f16a7f2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 11:43:23 -0400 Subject: [PATCH 27/36] Add ignores at definition sites. 
--- synapse/handlers/room_list.py | 4 ++-- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/relations.py | 6 +++--- synapse/storage/databases/main/roommember.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index bb0bdb8e6f39..156442e18cbf 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -33,7 +33,7 @@ RequestSendFailed, SynapseError, ) -from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.types import JsonDict, ThirdPartyInstanceID, JsonMapping from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache @@ -256,7 +256,7 @@ async def generate_room_entry( cache_context: _CacheContext, with_alias: bool = True, allow_private: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Returns the entry for a room Args: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 46957723a14c..bb68399cadae 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -502,7 +502,7 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: return event.content.get("alias") - @cached() + @cached() # type: ignore[synapse-@cached-mutable] async def get_server_acl_for_room( self, room_id: str ) -> Optional[ServerAclEvaluator]: diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 2b95daced2eb..39556481ffc9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -484,7 +484,7 @@ def _get_unread_counts_by_room_for_user_txn( return room_to_count - @cached(tree=True, max_entries=5000, iterable=True) + @cached(tree=True, max_entries=5000, 
iterable=True) # type: ignore[synapse-@cached-mutable] async def get_unread_event_push_actions_by_room_for_user( self, room_id: str, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index b2e3b4dc1ad8..0136625e6ebb 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -458,7 +458,7 @@ async def event_is_target_of_relation(self, parent_id: str) -> bool: ) return result is not None - @cached() + @cached() # type: ignore[synapse-@cached-mutable] async def get_references_for_event(self, event_id: str) -> List[JsonDict]: raise NotImplementedError() @@ -512,7 +512,7 @@ def _get_references_for_events_txn( "_get_references_for_events_txn", _get_references_for_events_txn ) - @cached() + @cached() # type: ignore[synapse-@cached-mutable] def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() @@ -599,7 +599,7 @@ def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]: for original_event_id in event_ids } - @cached() + @cached() # type: ignore[synapse-@cached-mutable] def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: raise NotImplementedError() diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 6ec6c726f0a0..e93573f315d2 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -275,7 +275,7 @@ def _get_users_in_room_with_profiles( _get_users_in_room_with_profiles, ) - @cached(max_entries=100000) + @cached(max_entries=100000) # type: ignore[synapse-@cached-mutable] async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. 
@@ -1072,7 +1072,7 @@ async def _get_approximate_current_memberships_in_room( return {row["event_id"]: row["membership"] for row in rows} # TODO This returns a mutable object, which is generally confusing when using a cache. - @cached(max_entries=10000) + @cached(max_entries=10000) # type: ignore[synapse-@cached-mutable] def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache": return _JoinedHostsCache() From fbecb56790feea8a10bacf9780c5598e161b8afc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 11:45:36 -0400 Subject: [PATCH 28/36] Actually check cachedList. --- scripts-dev/mypy_synapse_plugin.py | 1 + synapse/storage/databases/main/relations.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 6fbb88fef2ef..5095e7c9755b 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -50,6 +50,7 @@ def get_method_signature_hook( if fullname in ( "synapse.util.caches.descriptors._CachedFunctionDescriptor.__call__", + "synapse.util.caches.descriptors._CachedListFunctionDescriptor.__call__", ): return check_is_cacheable_wrapper diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0136625e6ebb..9246b418f501 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -517,7 +517,7 @@ def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() # TODO: This returns a mutable object, which is generally bad. 
- @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") + @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_applicable_edits( self, event_ids: Collection[str] ) -> Mapping[str, Optional[EventBase]]: @@ -604,7 +604,7 @@ def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: raise NotImplementedError() # TODO: This returns a mutable object, which is generally bad. - @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") + @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") # type: ignore[synapse-@cached-mutable] async def get_thread_summaries( self, event_ids: Collection[str] ) -> Mapping[str, Optional[Tuple[int, EventBase]]]: From dba8e72ff5c31bd8fcdcc27c616d778b54fda2e0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 12:05:55 -0400 Subject: [PATCH 29/36] Fix incorrect generic. --- synapse/util/caches/descriptors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index e397c60b0a1b..ce736fdf75af 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -469,7 +469,7 @@ def get_instance( @attr.s(auto_attribs=True, slots=True, frozen=True) -class _CachedFunctionDescriptor(Generic[F]): +class _CachedFunctionDescriptor: """Helper for `@cached`, we name it so that we can hook into it with mypy plugin.""" @@ -521,7 +521,7 @@ def cached( @attr.s(auto_attribs=True, slots=True, frozen=True) -class _CachedListFunctionDescriptor(Generic[F]): +class _CachedListFunctionDescriptor: """Helper for `@cachedList`, we name it so that we can hook into it with mypy plugin.""" From 4f06d85d7d7db9ee54110857943223c310d72493 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 12:14:16 -0400 Subject: [PATCH 30/36] Lint --- synapse/handlers/room_list.py | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 156442e18cbf..36e2db897540 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -33,7 +33,7 @@ RequestSendFailed, SynapseError, ) -from synapse.types import JsonDict, ThirdPartyInstanceID, JsonMapping +from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache From 3875662080be9fc53da26e4fc70ff42113ee24f1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 13:19:37 -0400 Subject: [PATCH 31/36] Abstract shared code. --- scripts-dev/mypy_synapse_plugin.py | 82 +++++++++++------------------- 1 file changed, 31 insertions(+), 51 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 5095e7c9755b..7df30e2311b5 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -57,6 +57,25 @@ def get_method_signature_hook( return None +def _get_true_return_type(signature: CallableType) -> mypy.types.Type: + """Get the "final" return type of a maybe async/Deferred.""" + if isinstance(signature.ret_type, Instance): + # If a coroutine, unwrap the coroutine's return type. + if signature.ret_type.type.fullname == "typing.Coroutine": + return signature.ret_type.args[2] + + # If an awaitable, unwrap the awaitable's final value. + elif signature.ret_type.type.fullname == "typing.Awaitable": + return signature.ret_type.args[0] + + # If a Deferred, unwrap the Deferred's final value. + elif signature.ret_type.type.fullname == "twisted.internet.defer.Deferred": + return signature.ret_type.args[0] + + # Otherwise, return the raw value of the function. 
+ return signature.ret_type + + def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: """Fixes the `CachedFunction.__call__` signature to be correct. @@ -113,34 +132,15 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. # 4. Ensure the return type is a Deferred. - if ( - isinstance(signature.ret_type, Instance) - and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred" - ): - # If it is already a Deferred, nothing to do. - ret_type = signature.ret_type - else: - ret_arg = None - if isinstance(signature.ret_type, Instance): - # If a coroutine, wrap the coroutine's return type in a Deferred. - if signature.ret_type.type.fullname == "typing.Coroutine": - ret_arg = signature.ret_type.args[2] - - # If an awaitable, wrap the awaitable's final value in a Deferred. - elif signature.ret_type.type.fullname == "typing.Awaitable": - ret_arg = signature.ret_type.args[0] - - # Otherwise, wrap the return value in a Deferred. - if ret_arg is None: - ret_arg = signature.ret_type - - # This should be able to use ctx.api.named_generic_type, but that doesn't seem - # to find the correct symbol for anything more than 1 module deep. - # - # modules is not part of CheckerPluginInterface. The following is a combination - # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo. - sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] - ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)]) + ret_arg = _get_true_return_type(signature) + + # This should be able to use ctx.api.named_generic_type, but that doesn't seem + # to find the correct symbol for anything more than 1 module deep. + # + # modules is not part of CheckerPluginInterface. The following is a combination + # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo. 
+ sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] + ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)]) signature = signature.copy_modified( arg_types=arg_types, @@ -164,29 +164,9 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: ctx.api.fail("Cached 'function' is not a callable", ctx.context) return signature - ret_arg = None - if isinstance(orig_sig.ret_type, Instance): - # If a coroutine, wrap the coroutine's return type in a Deferred. - if orig_sig.ret_type.type.fullname == "typing.Coroutine": - ret_arg = orig_sig.ret_type.args[2] - - # If an awaitable, wrap the awaitable's final value in a Deferred. - elif orig_sig.ret_type.type.fullname == "typing.Awaitable": - ret_arg = orig_sig.ret_type.args[0] - - elif orig_sig.ret_type.type.fullname == "twisted.internet.defer.Deferred": - ret_arg = orig_sig.ret_type.args[0] - - if ret_arg is None: - ret_arg = orig_sig.ret_type - - assert ret_arg is not None - - check_is_cacheable( - orig_sig, - ctx, - ret_arg, - ) + # Unwrap the true return type from the cached function. + ret_arg = _get_true_return_type(orig_sig) + check_is_cacheable(orig_sig, ctx, ret_arg) return signature From fb4ff5dfdadbb30adec86d459b3458d843eb45a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 13:25:18 -0400 Subject: [PATCH 32/36] ServerAclEvaluator is immutable. --- scripts-dev/mypy_synapse_plugin.py | 1 + synapse/storage/controllers/state.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 7df30e2311b5..e5f9d7f140d2 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -208,6 +208,7 @@ def check_is_cacheable( # Types defined in Synapse which are known to be immutable. 
IMMUTABLE_CUSTOM_TYPES = { + "synapse.synapse_rust.acl.ServerAclEvaluator", "synapse.synapse_rust.push.FilteredPushRules", # This is technically not immutable, but close enough. "signedjson.types.VerifyKey", diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index bb68399cadae..46957723a14c 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -502,7 +502,7 @@ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: return event.content.get("alias") - @cached() # type: ignore[synapse-@cached-mutable] + @cached() async def get_server_acl_for_room( self, room_id: str ) -> Optional[ServerAclEvaluator]: From 06ddf659ab513306d04851c0e14daef3e40404ed Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 13:33:50 -0400 Subject: [PATCH 33/36] Update comments. --- scripts-dev/mypy_synapse_plugin.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index e5f9d7f140d2..035a6d5ecd12 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -255,19 +255,28 @@ def is_cacheable( rt.type.fullname in IMMUTABLE_VALUE_TYPES or rt.type.fullname in IMMUTABLE_CUSTOM_TYPES ): + # "Simple" types are generally immutable. return True, None elif rt.type.fullname == "typing.Mapping": - return is_cacheable(rt.args[1], signature, verbose) + # Generally mapping keys are immutable, but they only *have* to be + # hashable, which doesn't imply immutability. E.g. Mapping[K, V] + # is cachable iff K and V are cachable. + return is_cacheable(rt.args[0], signature, verbose) and is_cacheable( + rt.args[1], signature, verbose + ) elif rt.type.fullname in IMMUTABLE_CONTAINER_TYPES_REQUIRING_IMMUTABLE_ELEMENTS: # E.g. Collection[T] is cachable iff T is cachable. 
return is_cacheable(rt.args[0], signature, verbose) elif rt.type.fullname in MUTABLE_CONTAINER_TYPES: + # Mutable containers are mutable regardless of their underlying type. return False, None elif "attrs" in rt.type.metadata: + # attrs classes are only cachable iff it is frozen (immutable itself) + # and all attributes are cachable. frozen = rt.type.metadata["attrs"]["frozen"] if frozen: for attribute in rt.type.metadata["attrs"]["attributes"]: @@ -284,12 +293,16 @@ def is_cacheable( return False, "non-frozen attrs class" else: - return False, f"Don't know how to handle {rt.type.fullname}" + # Ensure we fail for unknown types, these generally means that the + # above code is not complete. + return False, f"Don't know how to handle {rt.type.fullname} return type instance" elif isinstance(rt, NoneType): + # None is cachable. return True, None elif isinstance(rt, (TupleType, UnionType)): + # Tuples and unions are cachable iff all their items are cachable. for item in rt.items: ok, note = is_cacheable(item, signature, verbose) if not ok: @@ -298,14 +311,17 @@ def is_cacheable( return True, None elif isinstance(rt, TypeAliasType): + # For a type alias, check if the underlying real type is cachable. return is_cacheable(mypy.types.get_proper_type(rt), signature, verbose) - # The tests check what happens if you raise an Exception, so they don't return. elif isinstance(rt, UninhabitedType) and rt.is_noreturn: - # There's no return value, just consider it cachable. + # There is no return value, just consider it cachable. This is only used + # in tests. return True, None else: + # Ensure we fail for unknown types, these generally means that the + # above code is not complete. 
return False, f"Don't know how to handle {type(rt).__qualname__} return type" From 9b7ee0386da33ede55059ef0b0750bc9720b5b97 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 29 Sep 2023 15:31:47 -0400 Subject: [PATCH 34/36] Lint --- scripts-dev/mypy_synapse_plugin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 035a6d5ecd12..a814a09e65b7 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -295,7 +295,10 @@ def is_cacheable( else: # Ensure we fail for unknown types, these generally means that the # above code is not complete. - return False, f"Don't know how to handle {rt.type.fullname} return type instance" + return ( + False, + f"Don't know how to handle {rt.type.fullname} return type instance", + ) elif isinstance(rt, NoneType): # None is cachable. From e2f599d98bddff0b10181a7193f235301ce9c57f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 2 Oct 2023 08:14:37 -0400 Subject: [PATCH 35/36] Update comments and remove unnecessary argument. --- scripts-dev/mypy_synapse_plugin.py | 33 +++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index a814a09e65b7..ef79dea799e7 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -58,7 +58,9 @@ def get_method_signature_hook( def _get_true_return_type(signature: CallableType) -> mypy.types.Type: - """Get the "final" return type of a maybe async/Deferred.""" + """ + Get the "final" return type of a callable which might return returned an Awaitable/Deferred. + """ if isinstance(signature.ret_type, Instance): # If a coroutine, unwrap the coroutine's return type. if signature.ret_type.type.fullname == "typing.Coroutine": @@ -84,6 +86,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: 1. 
the `self` argument needs to be marked as "bound"; 2. any `cache_context` argument should be removed; 3. an optional keyword argument `on_invalidated` should be added. + 4. Wrap the return type to always be a Deferred. """ # 1. Mark this as a bound function signature. @@ -93,7 +96,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: # # Note: We should be only doing this if `cache_context=True` is set, but if # it isn't then the code will raise an exception when its called anyway, so - # its not the end of the world. + # it's not the end of the world. context_arg_index = None for idx, name in enumerate(signature.arg_names): if name == "cache_context": @@ -153,6 +156,11 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: + """Asserts that the signature of a method returns a value which can be cached. + + Makes no changes to the provided method signature. + """ + # The true signature, this isn't being modified so this is what will be returned. signature: CallableType = ctx.default_signature if not isinstance(ctx.args[0][0], TempNode): @@ -164,9 +172,7 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: ctx.api.fail("Cached 'function' is not a callable", ctx.context) return signature - # Unwrap the true return type from the cached function. - ret_arg = _get_true_return_type(orig_sig) - check_is_cacheable(orig_sig, ctx, ret_arg) + check_is_cacheable(orig_sig, ctx) return signature @@ -174,8 +180,17 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType: def check_is_cacheable( signature: CallableType, ctx: Union[MethodSigContext, FunctionSigContext], - return_type: mypy.types.Type, ) -> None: + """ + Check if a callable returns a type which can be cached. + + Args: + signature: The callable to check. + ctx: The signature context, used for error reporting. 
+ """ + # Unwrap the true return type from the cached function. + return_type = _get_true_return_type(signature) + verbose = ctx.api.options.verbosity >= 1 # TODO Technically a cachedList only needs immutable values, but forcing them # to return Mapping instead of Dict is fine. @@ -190,7 +205,6 @@ def check_is_cacheable( message += f" ({note})" message = message.replace("builtins.", "").replace("typing.", "") - # TODO The context is the context of the caller, not the method itself. if ok and note: ctx.api.note(message, ctx.context) # type: ignore[attr-defined] elif not ok: @@ -240,6 +254,11 @@ def is_cacheable( rt: mypy.types.Type, signature: CallableType, verbose: bool ) -> Tuple[bool, Optional[str]]: """ + Check if a particular type is cachable. + + A type is cachable if it is immutable; for complex types this recurses to + check each type parameter. + Returns: a 2-tuple (cacheable, message). - cachable: False means the type is definitely not cacheable; true means anything else. From a2956a6e6e4154a315849ac0bd012e1ad4a61609 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 2 Oct 2023 08:26:00 -0400 Subject: [PATCH 36/36] Fix duplicate word. Co-authored-by: David Robertson --- scripts-dev/mypy_synapse_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index ef79dea799e7..6592a4a6b7ea 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -59,7 +59,7 @@ def get_method_signature_hook( def _get_true_return_type(signature: CallableType) -> mypy.types.Type: """ - Get the "final" return type of a callable which might return returned an Awaitable/Deferred. + Get the "final" return type of a callable which might return an Awaitable/Deferred. """ if isinstance(signature.ret_type, Instance): # If a coroutine, unwrap the coroutine's return type.