
Add some missing type hints to cache datastore. #12216

Merged: 2 commits, merged on Mar 16, 2022
Changes from all commits
1 change: 1 addition & 0 deletions changelog.d/12216.misc
@@ -0,0 +1 @@
+Add missing type hints for cache storage.
57 changes: 36 additions & 21 deletions synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
+EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
@@ -82,7 +84,9 @@ async def get_all_updated_caches(
if last_id == current_id:
return [], current_id, False

-def get_all_updated_caches_txn(txn):
+def get_all_updated_caches_txn(
+    txn: LoggingTransaction,
+) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ def get_all_updated_caches_txn(txn):
"get_all_updated_caches", get_all_updated_caches_txn
)

-def process_replication_rows(self, stream_name, instance_name, token, rows):
+def process_replication_rows(
+    self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+) -> None:
Comment on lines +114 to +116 (Contributor, Author):

Note that the Iterable[Any] seems to match what we do for the other instances of process_replication_rows; it isn't ideal, but I think we want them all to match?
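
For context, a minimal sketch of the consistency constraint (the base class here is hypothetical, not Synapse's actual hierarchy): when the shared signature declares rows: Iterable[Any], tightening the parameter type in a single override would fail mypy's Liskov override check, so matching the loose signature everywhere is the pragmatic choice.

    from typing import Any, Iterable

    class ReplicationStoreSketch:
        def process_replication_rows(
            self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
        ) -> None:
            # Deliberately loose: each replication stream carries its own
            # row class, so a shared signature cannot be more specific.
            pass

    class CacheStoreSketch(ReplicationStoreSketch):
        def process_replication_rows(
            self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
        ) -> None:
            # Narrow per row in the body rather than in the signature.
            for row in rows:
                _ = row  # dispatch on stream_name / row type goes here
            super().process_replication_rows(stream_name, instance_name, token, rows)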

if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):

super().process_replication_rows(stream_name, instance_name, token, rows)

-def _process_event_stream_row(self, token, row):
+def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data

if row.type == EventsStreamEventRow.TypeId:
+assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ def _process_event_stream_row(self, token, row):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
-self._curr_state_delta_stream_cache.entity_has_changed(
-    row.data.room_id, token
-)
+assert isinstance(data, EventsStreamCurrentStateRow)
+self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)

if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@

def _invalidate_caches_for_event(
self,
-    stream_ordering,
-    event_id,
-    room_id,
-    etype,
-    state_key,
-    redacts,
-    relates_to,
-    backfilled,
-):
+    stream_ordering: int,
+    event_id: str,
+    room_id: str,
+    etype: str,
+    state_key: Optional[str],
+    redacts: Optional[str],
+    relates_to: Optional[str],
+    backfilled: bool,
+) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))

@@ -207,7 +213,9 @@ def _invalidate_caches_for_event(
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))

-async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+async def invalidate_cache_and_stream(
+    self, cache_name: str, keys: Tuple[Any, ...]
+) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.

@@ -227,7 +235,12 @@ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ..
keys,
)

-def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+def _invalidate_cache_and_stream(
+    self,
+    txn: LoggingTransaction,
+    cache_func: _CachedFunction,
+    keys: Tuple[Any, ...],
+) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.

@@ -238,7 +251,9 @@ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

-def _invalidate_all_cache_and_stream(self, txn, cache_func):
+def _invalidate_all_cache_and_stream(
+    self, txn: LoggingTransaction, cache_func: _CachedFunction
+) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -279,8 +294,8 @@ def _invalidate_state_caches_and_stream(
)

def _send_invalidation_to_replication(
-    self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-):
+    self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+) -> None:
"""Notifies replication that given cache has been invalidated.

Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ def _send_invalidation_to_replication(
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
"invalidation_ts": self._clock.time_msec(),
},
)

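A note on the new _CachedFunction annotation: the helpers in this diff only rely on the cached function exposing invalidate and __name__. A hypothetical Protocol stand-in capturing just that surface (the real _CachedFunction in synapse.util.caches.descriptors is more complete):

    from typing import Any, List, Protocol, Tuple

    class CachedFunctionSketch(Protocol):
        __name__: str

        def invalidate(self, key: Tuple[Any, ...]) -> None:
            ...

    def invalidate_and_stream(
        cache_func: CachedFunctionSketch, keys: Tuple[Any, ...]
    ) -> None:
        # Mirrors the shape of _invalidate_cache_and_stream: invalidate
        # the local entry, then announce the invalidation by cache name.
        cache_func.invalidate(keys)
        print(f"replicate: invalidate {cache_func.__name__} for {keys!r}")

    class FakeCachedFunction:
        __name__ = "get_thing"

        def __init__(self) -> None:
            self.invalidated: List[Tuple[Any, ...]] = []

        def invalidate(self, key: Tuple[Any, ...]) -> None:
            self.invalidated.append(key)

    invalidate_and_stream(FakeCachedFunction(), ("!room:example.org",))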
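On the added assert isinstance(...) lines: _process_event_stream_row now takes row: EventsStreamRow, whose data attribute is typed more broadly than the concrete per-type row classes, and assert isinstance is the standard way to let mypy narrow it inside each branch. A self-contained sketch of the pattern, using stand-in classes rather than Synapse's real ones:

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class EventRowSketch:
        event_id: str
        TypeId = "ev"

    @dataclass
    class StateRowSketch:
        room_id: str
        TypeId = "state"

    def handle_row(row_type: str, data: Union[EventRowSketch, StateRowSketch]) -> None:
        if row_type == EventRowSketch.TypeId:
            # After the assert, mypy treats `data` as EventRowSketch.
            assert isinstance(data, EventRowSketch)
            print("event", data.event_id)
        elif row_type == StateRowSketch.TypeId:
            assert isinstance(data, StateRowSketch)
            print("state", data.room_id)

    handle_row("ev", EventRowSketch(event_id="$abc"))

At runtime the assert also fails fast on a mismatched row type instead of raising AttributeError later.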