Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Speed up updating state in large rooms #15971

Merged
merged 3 commits into from Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15971.misc
@@ -0,0 +1 @@
Speed up updating state in large rooms.
9 changes: 4 additions & 5 deletions synapse/handlers/message.py
Expand Up @@ -1565,12 +1565,11 @@ async def cache_joined_hosts_for_events(
if state_entry.state_group in self._external_cache_joined_hosts_updates:
return

state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
with opentracing.start_active_span("get_joined_hosts"):
joined_hosts = await self.store.get_joined_hosts(
event.room_id, state, state_entry
joined_hosts = (
await self._storage_controllers.state.get_joined_hosts(
event.room_id, state_entry
)
)

# Note that the expiry times must be larger than the expiry time in
Expand Down
3 changes: 1 addition & 2 deletions synapse/state/__init__.py
Expand Up @@ -268,8 +268,7 @@ async def get_hosts_in_room_at_events(
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
return await self._state_storage_controller.get_joined_hosts(room_id, entry)

@trace
@tag_args
Expand Down
137 changes: 135 additions & 2 deletions synapse/storage/controllers/state.py
Expand Up @@ -12,36 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)

from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.storage.roommember import ProfileInfo
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached
from synapse.util.cancellation import cancellable
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry
from synapse.storage.databases import Databases


logger = logging.getLogger(__name__)


Expand All @@ -52,10 +61,15 @@ class StateStorageController:

def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self._clock = hs.get_clock()
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)

# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")

def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)

Expand Down Expand Up @@ -627,3 +641,122 @@ async def get_users_in_room_with_profiles(
await self._partial_state_room_tracker.await_full_state(room_id)

return await self.stores.main.get_users_in_room_with_profiles(room_id)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()

assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
)

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, its important that its never None, since two
# current_state's with a state_group of None are likely to be different.
#
# The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
assert state_entry.state_group is None or state_entry.state_group == state_group

# We use a secondary cache of previous work to allow us to build up the
# joined hosts for the given state group based on previous state groups.
#
# We cache one object per room containing the results of the last state
# group we got joined hosts for. The idea is that generally
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self.stores.main._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
return frozenset(cache.hosts_to_joined_users)

# Since we'll mutate the cache we need to lock.
async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
# the lock.
pass
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue

host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = cache.hosts_to_joined_users.setdefault(host, set())

event = await self.stores.main.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)

if not known_joins:
cache.hosts_to_joined_users.pop(host, None)
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
#
# We need to fetch all hosts joined to the room according to `state` by
# inspecting all join memberships in `state`. However, if the `state` is
# relatively recent then many of its events are likely to be held in
# the current state of the room, which is easily available and likely
# cached.
#
# We therefore compute the set of `state` events not in the
# current state and only fetch those.
current_memberships = (
await self.stores.main._get_approximate_current_memberships_in_room(
room_id
)
)
unknown_state_events = {}
joined_users_in_current_state = []

state = await state_entry.get_state(
self, StateFilter.from_types([(EventTypes.Member, None)])
)

for (type, state_key), event_id in state.items():
if event_id not in current_memberships:
unknown_state_events[type, state_key] = event_id
elif current_memberships[event_id] == Membership.JOIN:
joined_users_in_current_state.append(state_key)

joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
room_id, unknown_state_events
)

cache.hosts_to_joined_users = {}
for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

if state_entry.state_group:
cache.state_group = state_entry.state_group
else:
cache.state_group = object()

return frozenset(cache.hosts_to_joined_users)
122 changes: 0 additions & 122 deletions synapse/storage/databases/main/roommember.py
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand Down Expand Up @@ -57,15 +56,12 @@
StrCollection,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry

logger = logging.getLogger(__name__)

Expand All @@ -91,10 +87,6 @@ def __init__(
):
super().__init__(database, db_conn, hs)

# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")

self._server_notices_mxid = hs.config.servernotices.server_notices_mxid

if (
Expand Down Expand Up @@ -1057,120 +1049,6 @@ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)

async def get_joined_hosts(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()

assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state, state_entry=state_entry
)

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, its important that its never None, since two
# current_state's with a state_group of None are likely to be different.
#
# The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
assert state_entry.state_group is None or state_entry.state_group == state_group

# We use a secondary cache of previous work to allow us to build up the
# joined hosts for the given state group based on previous state groups.
#
# We cache one object per room containing the results of the last state
# group we got joined hosts for. The idea is that generally
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
return frozenset(cache.hosts_to_joined_users)

# Since we'll mutate the cache we need to lock.
async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
# the lock.
pass
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue

host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = cache.hosts_to_joined_users.setdefault(host, set())

event = await self.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)

if not known_joins:
cache.hosts_to_joined_users.pop(host, None)
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
#
# We need to fetch all hosts joined to the room according to `state` by
# inspecting all join memberships in `state`. However, if the `state` is
# relatively recent then many of its events are likely to be held in
# the current state of the room, which is easily available and likely
# cached.
#
# We therefore compute the set of `state` events not in the
# current state and only fetch those.
current_memberships = (
await self._get_approximate_current_memberships_in_room(room_id)
)
unknown_state_events = {}
joined_users_in_current_state = []

for (type, state_key), event_id in state.items():
if event_id not in current_memberships:
unknown_state_events[type, state_key] = event_id
elif current_memberships[event_id] == Membership.JOIN:
joined_users_in_current_state.append(state_key)

joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, unknown_state_events
)

cache.hosts_to_joined_users = {}
for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

if state_entry.state_group:
cache.state_group = state_entry.state_group
else:
cache.state_group = object()

return frozenset(cache.hosts_to_joined_users)

async def _get_approximate_current_memberships_in_room(
self, room_id: str
) -> Mapping[str, Optional[str]]:
Expand Down