Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert more cached return values to immutable types (#16356)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Sep 20, 2023
1 parent d7c89c5 commit 7ec0a14
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 36 deletions.
1 change: 1 addition & 0 deletions changelog.d/16356.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
8 changes: 4 additions & 4 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase, relation_from_event
from synapse.types import JsonDict, RoomID, UserID
from synapse.types import JsonDict, JsonMapping, RoomID, UserID

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -191,7 +191,7 @@ def check_valid_filter(self, user_filter_json: JsonDict) -> None:


class FilterCollection:
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._filter_json = filter_json

room_filter_json = self._filter_json.get("room", {})
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

def get_filter_json(self) -> JsonDict:
def get_filter_json(self) -> JsonMapping:
return self._filter_json

def timeline_limit(self) -> int:
Expand Down Expand Up @@ -313,7 +313,7 @@ def blocks_all_room_timeline(self) -> bool:


class Filter:
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._hs = hs
self._store = hs.get_datastores().main
self.filter_json = filter_json
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
Expand Down Expand Up @@ -1704,7 +1704,7 @@ async def send_request(
async def timestamp_to_event(
self,
*,
destinations: List[str],
destinations: StrCollection,
room_id: str,
timestamp: int,
direction: Direction,
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,7 +1538,7 @@ async def _resync_device(self, sender: str) -> None:
logger.exception("Failed to resync device for %s", sender)

async def backfill_event_id(
self, destinations: List[str], room_id: str, event_id: str
self, destinations: StrCollection, room_id: str, event_id: str
) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
Expand Down
14 changes: 12 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,17 @@
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Sequence,
)

import attr

Expand Down Expand Up @@ -245,7 +255,7 @@ async def redact_events_related_to(

async def get_references_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
) -> Dict[str, List[_RelatedEvent]]:
) -> Mapping[str, Sequence[_RelatedEvent]]:
"""Get a list of references to the given events.
Args:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonMapping, UserID

from ._base import client_patterns, set_timeline_upper_limit

Expand All @@ -41,7 +41,7 @@ def __init__(self, hs: "HomeServer"):

async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
) -> Tuple[int, JsonDict]:
) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:

@trace
@tag_args
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
Expand Down Expand Up @@ -145,7 +145,7 @@ def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
@cached(num_args=2)
async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str]
) -> JsonDict:
) -> JsonMapping:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
async def get_references_for_events(
self, event_ids: Collection[str]
) -> Mapping[str, Optional[List[_RelatedEvent]]]:
) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
"""Get a list of references to the given events.
Args:
Expand Down Expand Up @@ -931,7 +931,7 @@ async def get_threads(
room_id: str,
limit: int = 5,
from_token: Optional[ThreadsNextBatch] = None,
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.
Expand Down
10 changes: 6 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
)

@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""
Get current hosts in room based on current state.
Expand Down Expand Up @@ -1013,12 +1013,14 @@ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
# `get_users_in_room` rather than funky SQL.

domains = await self.get_current_hosts_in_room(room_id)
return list(domains)
return tuple(domains)

# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.

def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
def get_current_hosts_in_room_ordered_txn(
txn: LoggingTransaction,
) -> Tuple[str, ...]:
# Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the
Expand All @@ -1043,7 +1045,7 @@ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
"""
txn.execute(sql, (room_id,))
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
return [d for d, in txn if d is not None]
return tuple(d for d, in txn if d is not None)

return await self.db_pool.runInteraction(
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
Expand Down
35 changes: 19 additions & 16 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import logging
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Set,
Expand Down Expand Up @@ -96,7 +96,7 @@ def __init__(self) -> None:
self.mock = mock.Mock()

@descriptors.cached(num_args=1)
def fn(self, arg1: int, arg2: int) -> mock.Mock:
def fn(self, arg1: int, arg2: int) -> str:
return self.mock(arg1, arg2)

obj = Cls()
Expand Down Expand Up @@ -228,8 +228,9 @@ class Cls:
call_count = 0

@cached()
def fn(self, arg1: int) -> Optional[Deferred]:
def fn(self, arg1: int) -> Deferred:
self.call_count += 1
assert self.result is not None
return self.result

obj = Cls()
Expand Down Expand Up @@ -401,31 +402,31 @@ def __init__(self) -> None:
self.mock = mock.Mock()

@descriptors.cached(iterable=True)
def fn(self, arg1: int, arg2: int) -> List[str]:
def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
return self.mock(arg1, arg2)

obj = Cls()

obj.mock.return_value = ["spam", "eggs"]
obj.mock.return_value = ("spam", "eggs")
r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"])
self.assertEqual(r.result, ("spam", "eggs"))
obj.mock.assert_called_once_with(1, 2)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = ["chips"]
obj.mock.return_value = ("chips",)
r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"])
self.assertEqual(r.result, ("chips",))
obj.mock.assert_called_once_with(1, 3)
obj.mock.reset_mock()

# the two values should now be cached
self.assertEqual(len(obj.fn.cache.cache), 3)

r = obj.fn(1, 2)
self.assertEqual(r.result, ["spam", "eggs"])
self.assertEqual(r.result, ("spam", "eggs"))
r = obj.fn(1, 3)
self.assertEqual(r.result, ["chips"])
self.assertEqual(r.result, ("chips",))
obj.mock.assert_not_called()

def test_cache_iterable_with_sync_exception(self) -> None:
Expand Down Expand Up @@ -784,7 +785,9 @@ def fn(self, arg1: int, arg2: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: Iterable[int], arg2: int) -> Dict[int, str]:
async def list_fn(
self, args1: Iterable[int], arg2: int
) -> Mapping[int, str]:
context = current_context()
assert isinstance(context, LoggingContext)
assert context.name == "c1"
Expand Down Expand Up @@ -847,11 +850,11 @@ def fn(self, arg1: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1: List[int]) -> "Deferred[dict]":
def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]":
return self.mock(args1)

obj = Cls()
deferred_result: "Deferred[dict]" = Deferred()
deferred_result: "Deferred[Mapping[int, str]]" = Deferred()
obj.mock.return_value = deferred_result

# start off several concurrent lookups of the same key
Expand Down Expand Up @@ -890,7 +893,7 @@ def fn(self, arg1: int, arg2: int) -> None:
pass

@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1: List[int], arg2: int) -> Dict[int, str]:
async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]:
# we want this to behave like an asynchronous function
await run_on_reactor()
return self.mock(args1, arg2)
Expand Down Expand Up @@ -929,7 +932,7 @@ def fn(self, arg1: int) -> None:
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]:
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await complete_lookup
return {arg: str(arg) for arg in args}

Expand Down Expand Up @@ -964,7 +967,7 @@ def fn(self, arg1: int) -> None:
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args: List[int]) -> Dict[int, str]:
async def list_fn(self, args: List[int]) -> Mapping[int, str]:
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}
Expand Down

0 comments on commit 7ec0a14

Please sign in to comment.