Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Only store data in caches, not "smart" objects #9845

Merged
merged 9 commits into from
Apr 23, 2021
1 change: 1 addition & 0 deletions changelog.d/9845.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores.
161 changes: 92 additions & 69 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()

# Used by `RulesForRoom` to ensure only one thing mutates the cache at a
# time. Keyed off room_id.
self._rules_linearizer = Linearizer(name="rules_for_room")

erikjohnston marked this conversation as resolved.
Show resolved Hide resolved
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
Expand All @@ -123,7 +127,16 @@ async def _get_rules_for_event(
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = self._get_rules_for_room(room_id)

rules_for_room_data = self._get_rules_for_room(room_id)
rules_for_room = RulesForRoom(
hs=self.hs,
room_id=room_id,
rules_for_room_cache=self._get_rules_for_room.cache,
room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
linearizer=self._rules_linearizer,
cached_data=rules_for_room_data,
)

rules_by_user = await rules_for_room.get_rules(event, context)

Expand All @@ -142,17 +155,12 @@ async def _get_rules_for_event(
return rules_by_user

@lru_cache()
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
"""Get the current RulesForRoomData object for the given room id"""
# It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
return RulesForRoom(
self.hs,
room_id,
self._get_rules_for_room.cache,
self.room_push_rule_cache_metrics,
)
return RulesForRoomData()

async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
Expand Down Expand Up @@ -282,11 +290,49 @@ def _condition_checker(
return True


@attr.s(slots=True)
class RulesForRoomData:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could probably do with a docstring describing what it contains.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can haz docstring?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/me mutters at github hiding things

"""The data stored in the cache by `RulesForRoom`.

We don't store `RulesForRoom` directly in the cache as we want our caches to
*only* include data, and not references to e.g. the data stores.
"""

# event_id -> (user_id, state)
member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
# user_id -> rules
rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)

# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
state_group = attr.ib(type=Union[object, int], factory=object)

# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
sequence = attr.ib(type=int, default=0)

# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
uninteresting_user_set = attr.ib(type=Set[str], factory=set)


class RulesForRoom:
"""Caches push rules for users in a room.

This efficiently handles users joining/leaving the room by not invalidating
the entire cache for the room.

A new instance is constructed for each call to
`BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
previous calls passed in.
"""

def __init__(
Expand All @@ -295,6 +341,8 @@ def __init__(
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
linearizer: Linearizer,
cached_data: RulesForRoomData,
):
"""
Args:
Expand All @@ -303,38 +351,21 @@ def __init__(
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics: The metrics object
linearizer: The linearizer used to ensure only one thing mutates
the cache at a time. Keyed off room_id
cached_data: Cached data from previous calls to `self.get_rules`,
can be mutated.
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastore()
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics

self.linearizer = Linearizer(name="rules_for_room")

# event_id -> (user_id, state)
self.member_map = {} # type: Dict[str, Tuple[str, str]]
# user_id -> rules
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]

# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
# result.
# On invalidation of the rules themselves (if the user changes them),
# we invalidate everything and set state_group to `object()`
self.state_group = object()

# A sequence number to keep track of when we're allowed to update the
# cache. We bump the sequence number when we invalidate the cache. If
# the sequence number changes while we're calculating stuff we should
# not update the cache with it.
self.sequence = 0

# A cache of user_ids that we *know* aren't interesting, e.g. user_ids
# owned by AS's, or remote users, etc. (I.e. users we will never need to
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
self.uninteresting_user_set = set() # type: Set[str]
# Used to ensure only one thing mutates the cache at a time. Keyed off
# room_id.
self.linearizer = linearizer

self.data = cached_data

# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
Expand All @@ -352,25 +383,25 @@ async def get_rules(
"""
state_group = context.state_group

if state_group and self.state_group == state_group:
if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user
return self.data.rules_by_user

with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group:
with (await self.linearizer.queue(self.room_id)):
if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user
return self.data.rules_by_user

self.room_push_rule_cache_metrics.inc_misses()

ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.state_group == context.prev_group:
if state_group and self.data.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
ret_rules_by_user = self.rules_by_user
ret_rules_by_user = self.data.rules_by_user
current_state_ids = context.delta_ids

push_rules_delta_state_cache_metric.inc_hits()
Expand All @@ -393,24 +424,24 @@ async def get_rules(
if typ != EventTypes.Member:
continue

if user_id in self.uninteresting_user_set:
if user_id in self.data.uninteresting_user_set:
continue

if not self.is_mine_id(user_id):
self.uninteresting_user_set.add(user_id)
self.data.uninteresting_user_set.add(user_id)
continue

if self.store.get_if_app_services_interested_in_user(user_id):
self.uninteresting_user_set.add(user_id)
self.data.uninteresting_user_set.add(user_id)
continue

event_id = current_state_ids[key]

res = self.member_map.get(event_id, None)
res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
rules = self.rules_by_user.get(user_id, None)
rules = self.data.rules_by_user.get(user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
continue
Expand All @@ -430,7 +461,7 @@ async def get_rules(
else:
# The push rules didn't change but lets update the cache anyway
self.update_cache(
self.sequence,
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
Expand Down Expand Up @@ -461,7 +492,7 @@ async def _update_rules_with_member_event_ids(
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
sequence = self.sequence
sequence = self.data.sequence

rows = await self.store.get_membership_from_event_ids(member_event_ids.values())

Expand Down Expand Up @@ -501,23 +532,11 @@ async def _update_rules_with_member_event_ids(

self.update_cache(sequence, members, ret_rules_by_user, state_group)

def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
# `self.invalidate_all_cb`
logger.debug("Invalidating RulesForRoom for %r", self.room_id)
self.sequence += 1
self.state_group = object()
self.member_map = {}
self.rules_by_user = {}
push_rules_invalidation_counter.inc()

def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
self.state_group = state_group
if sequence == self.data.sequence:
self.data.member_map.update(members)
self.data.rules_by_user = rules_by_user
self.data.state_group = state_group


@attr.attrs(slots=True, frozen=True)
Expand All @@ -535,6 +554,10 @@ class _Invalidation:
room_id = attr.ib(type=str)

def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
rules_data = self.cache.get(self.room_id, None, update_metrics=False)
if rules_data:
rules_data.sequence += 1
rules_data.state_group = object()
rules_data.member_map = {}
rules_data.rules_by_user = {}
push_rules_invalidation_counter.inc()
Loading