This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Track and deduplicate in-flight requests to _get_state_for_groups. #10870

Merged: 36 commits into develop from rei/gsfg_1 on Feb 18, 2022
Commits (36)
363565e
Add a stub implementation of `StateFilter.approx_difference`
reivilibre Sep 21, 2021
1fceefd
Add method to request 1 state group, tracking the inflight request
reivilibre Sep 21, 2021
d998903
Add method to gather in-flight requests and calculate left-over filter
reivilibre Sep 21, 2021
247f558
Add method to get 1 state group, both using and updating in-flight cache
reivilibre Sep 21, 2021
7dcdab4
Use the in-flight caches for `_get_state_for_groups`
reivilibre Sep 21, 2021
831a7a4
Add some basic tests about requests getting deduplicated
reivilibre Sep 21, 2021
857b2d2
Newsfile
reivilibre Sep 21, 2021
a30042b
Convert to review comments
reivilibre Sep 21, 2021
cc76d9f
Use `yieldable_gather_results` helper because it's more elegant
reivilibre Sep 27, 2021
58ef32e
Revert "Add a stub implementation of `StateFilter.approx_difference`"
reivilibre Oct 12, 2021
af85ac4
Merge remote-tracking branch 'origin/develop' into rei/gsfg_1
reivilibre Oct 12, 2021
f78a082
Fix up log contexts in _get_state_for_group_fire_request
reivilibre Nov 4, 2021
8ea530a
Check != against StateFilter.none() for clarity
reivilibre Nov 4, 2021
5352f21
Directly await fetched state for simplicity
reivilibre Nov 4, 2021
db840d2
Simplify gatherResults and fix log contexts
reivilibre Nov 4, 2021
e38f795
Merge branch 'develop' into rei/gsfg_1
reivilibre Nov 4, 2021
8325ddd
Merge branch 'develop' into rei/gsfg_1
reivilibre Dec 20, 2021
232be1d
Fix up misunderstanding (fixes tests)
reivilibre Dec 20, 2021
a3ec20c
Add licence header
reivilibre Jan 17, 2022
af7b61d
Simplify `super` calls
reivilibre Jan 17, 2022
f5321a7
Use less keyboard-feline-sounding names
reivilibre Jan 17, 2022
ae49d99
Update synapse/storage/databases/state/store.py
reivilibre Jan 17, 2022
a4ececa
Simplify the fake get_state_for_groups
reivilibre Jan 17, 2022
e9cb9b0
Merge branch 'develop' into rei/gsfg_1
reivilibre Feb 10, 2022
e2d1ea3
Use an AbstractObservableDeferred
reivilibre Feb 10, 2022
61f85a0
Consume the errors from the ObservableDeferred
reivilibre Feb 10, 2022
6328c65
Update tests/storage/databases/test_state_store.py
reivilibre Feb 16, 2022
950a8ca
Update synapse/storage/databases/state/store.py
reivilibre Feb 16, 2022
f4f756a
Convert set to list before doubly iterating
reivilibre Feb 16, 2022
a57e7ce
Collapse for loop onto one line
reivilibre Feb 16, 2022
969d45e
Use patch.object in prepare() for adding a patch
reivilibre Feb 16, 2022
0d6bcef
Add docstring on test_duplicate_requests_deduplicated
reivilibre Feb 16, 2022
7c2949b
Restructure _get_state_for_group_gather_inflight_requests a bit
reivilibre Feb 16, 2022
fcc7786
Describe the flow in the docstring
reivilibre Feb 16, 2022
2c8b936
Restructure the way state is gathered to prevent building an intermed…
reivilibre Feb 16, 2022
1b5cabe
Merge branch 'develop' into rei/gsfg_1
reivilibre Feb 18, 2022
1 change: 1 addition & 0 deletions changelog.d/10870.misc
@@ -0,0 +1 @@
Deduplicate in-flight requests in `_get_state_for_groups`.
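
Before the per-file diffs, a note on the core idea: when several callers ask for overlapping state at the same time, the new code tracks the request that is already in flight and lets later callers observe its result instead of issuing another database query. The following is a minimal, self-contained sketch of that deduplication pattern, written with asyncio futures purely for illustration — the actual change below uses Synapse's ObservableDeferred, Twisted log-context helpers, and a per-state-group dictionary keyed by StateFilter, so none of the names here are Synapse APIs.

import asyncio
from typing import Awaitable, Callable, Dict


class InflightCache:
    """Deduplicates concurrent fetches for the same key (illustrative pattern only)."""

    def __init__(self) -> None:
        # key -> future that will carry the result of the request already in flight
        self._inflight: Dict[str, "asyncio.Future[str]"] = {}

    async def get(self, key: str, fetch: Callable[[str], Awaitable[str]]) -> str:
        existing = self._inflight.get(key)
        if existing is not None:
            # Someone is already fetching this key: wait for their result
            # instead of issuing a second, identical fetch.
            return await existing

        future: "asyncio.Future[str]" = asyncio.get_running_loop().create_future()
        self._inflight[key] = future
        try:
            result = await fetch(key)
            future.set_result(result)
            return result
        except Exception as exc:
            future.set_exception(exc)
            raise
        finally:
            # Drop the entry as soon as the request finishes (success or
            # failure) so that later callers fetch fresh data.
            del self._inflight[key]


async def main() -> None:
    calls = 0

    async def slow_fetch(key: str) -> str:
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.1)
        return f"value-for-{key}"

    cache = InflightCache()
    # Two concurrent requests for the same key result in a single fetch.
    a, b = await asyncio.gather(cache.get("42", slow_fetch), cache.get("42", slow_fetch))
    assert a == b == "value-for-42" and calls == 1


asyncio.run(main())

The extra wrinkle in the real change is that the cache key is a StateFilter rather than an exact key, so a new request can also be partially satisfied by in-flight requests that merely overlap with it.
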
201 changes: 177 additions & 24 deletions synapse/storage/databases/state/store.py
@@ -13,11 +13,23 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
Optional,
Sequence,
Set,
Tuple,
)

import attr

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -29,6 +41,12 @@
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import MutableStateMap, StateKey, StateMap
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import (
AbstractObservableDeferred,
ObservableDeferred,
yieldable_gather_results,
)
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache

@@ -37,7 +55,6 @@

logger = logging.getLogger(__name__)


MAX_STATE_DELTA_HOPS = 100


@@ -106,6 +123,12 @@ def __init__(
500000,
)

# Current ongoing get_state_for_groups in-flight requests
# {group ID -> {StateFilter -> ObservableDeferred}}
self._state_group_inflight_requests: Dict[
int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
] = {}

def get_max_state_group_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0] # type: ignore
@@ -157,7 +180,7 @@ def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
)

async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
self, groups: Sequence[int], state_filter: StateFilter
) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -228,6 +251,150 @@ def _get_state_for_group_using_cache(

return state_filter.filter_state(state_dict_ids), not missing_types

def _get_state_for_group_gather_inflight_requests(
self, group: int, state_filter_left_over: StateFilter
) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
"""
Attempts to gather in-flight requests and re-use them to retrieve state
for the given state group, filtered with the given state filter.

Used as part of _get_state_for_group_using_inflight_cache.

Returns:
Tuple of two values:
A sequence of ObservableDeferreds to observe
A StateFilter representing what else needs to be requested to fulfill the request
"""

inflight_requests = self._state_group_inflight_requests.get(group)
if inflight_requests is None:
# no requests for this group, need to retrieve it all ourselves
return (), state_filter_left_over

# The list of ongoing requests which will help narrow the current request.
reusable_requests = []
for (request_state_filter, request_deferred) in inflight_requests.items():
new_state_filter_left_over = state_filter_left_over.approx_difference(
request_state_filter
)
if new_state_filter_left_over == state_filter_left_over:
# Reusing this request would not gain us anything, so don't bother.
continue

reusable_requests.append(request_deferred)
state_filter_left_over = new_state_filter_left_over
if state_filter_left_over == StateFilter.none():
# we have managed to collect enough of the in-flight requests
# to cover our StateFilter and give us the state we need.
break

return reusable_requests, state_filter_left_over

async def _get_state_for_group_fire_request(
self, group: int, state_filter: StateFilter
) -> StateMap[str]:
"""
Fires off a request to get the state at a state group,
potentially filtering by type and/or state key.

This request will be tracked in the in-flight request cache and automatically
removed when it is finished.

Used as part of _get_state_for_group_using_inflight_cache.

Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
"""
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence

# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()

async def _the_request() -> StateMap[str]:
group_to_state_dict = await self._get_state_groups_from_groups(
(group,), state_filter=db_state_filter
)

# Now let's update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)

# Remove ourselves from the in-flight cache
group_request_dict = self._state_group_inflight_requests[group]
del group_request_dict[db_state_filter]
if not group_request_dict:
# If there are no more requests in-flight for this group,
# clean up the cache by removing the empty dictionary
del self._state_group_inflight_requests[group]

return group_to_state_dict[group]

# We don't immediately await the result, so must use run_in_background
# But we DO await the result before the current log context (request)
# finishes, so don't need to run it as a background process.
request_deferred = run_in_background(_the_request)
observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)

# Insert the ObservableDeferred into the cache
group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
group_request_dict[db_state_filter] = observable_deferred

return await make_deferred_yieldable(observable_deferred.observe())

async def _get_state_for_group_using_inflight_cache(
self, group: int, state_filter: StateFilter
) -> MutableStateMap[str]:
"""
Gets the state at a state group, potentially filtering by type and/or
state key.

1. Calls _get_state_for_group_gather_inflight_requests to gather any
ongoing requests which might overlap with the current request.
2. Fires a new request, using _get_state_for_group_fire_request,
for any state which cannot be gathered from ongoing requests.

Args:
group: ID of the state group for which we want to get state
state_filter: the state filter used to fetch state from the database
Returns:
state map
"""

# first, figure out whether we can re-use any in-flight requests
# (and if so, what would be left over)
(
reusable_requests,
state_filter_left_over,
) = self._get_state_for_group_gather_inflight_requests(group, state_filter)

if state_filter_left_over != StateFilter.none():
# Fetch remaining state
remaining = await self._get_state_for_group_fire_request(
group, state_filter_left_over
)
assembled_state: MutableStateMap[str] = dict(remaining)
else:
assembled_state = {}

gathered = await make_deferred_yieldable(
defer.gatherResults(
(r.observe() for r in reusable_requests), consumeErrors=True
)
).addErrback(unwrapFirstError)

# assemble our result.
for result_piece in gathered:
assembled_state.update(result_piece)

# Filter out any state that may be more than what we asked for.
return state_filter.filter_state(assembled_state)

async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Dict[int, MutableStateMap[str]]:
@@ -269,30 +436,16 @@ async def _get_state_for_groups(
if not incomplete_groups:
return state

cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence

# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
incomplete_groups_list = list(incomplete_groups)

group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
results_from_requests = await yieldable_gather_results(
self._get_state_for_group_using_inflight_cache,
incomplete_groups_list,
state_filter,
)

# Now lets update the caches
self._insert_into_cache(
group_to_state_dict,
db_state_filter,
cache_seq_num_members=cache_sequence_m,
cache_seq_num_non_members=cache_sequence_nm,
)

# And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database.
for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
for group, group_result in zip(incomplete_groups_list, results_from_requests):
state[group] = group_result

return state
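
To make the reuse decision in _get_state_for_group_gather_inflight_requests easier to follow, here is a simplified stand-in in which a state filter is modelled as a plain set of (event type, state key) pairs and ordinary set difference plays the role of StateFilter.approx_difference (the real method may over-approximate what is left over, which is safe because the result is filtered again at the end). The function name, request handles, and example values below are made up for illustration and are not part of the change.

from typing import Dict, FrozenSet, List, Tuple

StateKey = Tuple[str, str]
Filter = FrozenSet[StateKey]  # simplified stand-in for StateFilter


def gather_inflight_requests(
    wanted: Filter, inflight: Dict[Filter, str]
) -> Tuple[List[str], Filter]:
    """Pick in-flight requests worth reusing and report what is still missing."""
    reusable: List[str] = []
    left_over = wanted

    for request_filter, request_handle in inflight.items():
        narrowed = left_over - request_filter  # approx_difference in the real code
        if narrowed == left_over:
            # This in-flight request covers nothing we still need; skip it.
            continue
        reusable.append(request_handle)
        left_over = narrowed
        if not left_over:
            # Everything we need is already being fetched by other requests.
            break

    return reusable, left_over


# One request is already fetching two members' state; we want one of those
# members plus the room name, so only the room name still needs a new fetch.
inflight = {frozenset({("m.room.member", "@a:hs"), ("m.room.member", "@b:hs")}): "req-1"}
wanted = frozenset({("m.room.member", "@a:hs"), ("m.room.name", "")})
reusable, left_over = gather_inflight_requests(wanted, inflight)
assert reusable == ["req-1"]
assert left_over == frozenset({("m.room.name", "")})
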

133 changes: 133 additions & 0 deletions tests/storage/databases/test_state_store.py
@@ -0,0 +1,133 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Dict, List, Sequence, Tuple
from unittest.mock import patch

from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor

from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

if typing.TYPE_CHECKING:
from synapse.server import HomeServer


class StateGroupInflightCachingTestCase(HomeserverTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer"
) -> None:
self.state_storage = homeserver.get_storage().state
self.state_datastore = homeserver.get_datastores().state
# Patch out the `_get_state_groups_from_groups`.
# This is useful because it lets us pretend we have a slow database.
get_state_groups_patch = patch.object(
self.state_datastore,
"_get_state_groups_from_groups",
self._fake_get_state_groups_from_groups,
)
get_state_groups_patch.start()

self.addCleanup(get_state_groups_patch.stop)
self.get_state_group_calls: List[
Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
] = []

def _fake_get_state_groups_from_groups(
self, groups: Sequence[int], state_filter: StateFilter
) -> "Deferred[Dict[int, StateMap[str]]]":
d: Deferred[Dict[int, StateMap[str]]] = Deferred()
self.get_state_group_calls.append((tuple(groups), state_filter, d))
return d

def _complete_request_fake(
self,
groups: Tuple[int, ...],
state_filter: StateFilter,
d: "Deferred[Dict[int, StateMap[str]]]",
) -> None:
"""
Assemble a fake database response and complete the database request.
"""

result: Dict[int, StateMap[str]] = {}

for group in groups:
group_result: MutableStateMap[str] = {}
result[group] = group_result

for state_type, state_keys in state_filter.types.items():
if state_keys is None:
group_result[(state_type, "a")] = "xyz"
group_result[(state_type, "b")] = "xyz"
else:
for state_key in state_keys:
group_result[(state_type, state_key)] = "abc"

if state_filter.include_others:
group_result[("other.event.type", "state.key")] = "123"

d.callback(result)

def test_duplicate_requests_deduplicated(self) -> None:
[Review comment — Contributor] This test seems nuanced enough that it could use a docstring saying what the overall steps are that are happening.
[Reply — reivilibre (Author)] Added one; any good?
"""
Tests that duplicate requests for state are deduplicated.

This test:
- requests some state (state group 42, 'all' state filter)
- requests it again, before the first request finishes
- checks to see that only one database query was made
- completes the database query
- checks that both requests see the same retrieved state
"""
req1 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)

# This should have gone to the database
self.assertEqual(len(self.get_state_group_calls), 1)
self.assertFalse(req1.called)

req2 = ensureDeferred(
self.state_datastore._get_state_for_group_using_inflight_cache(
42, StateFilter.all()
)
)
self.pump(by=0.1)

# No more calls should have gone to the database
self.assertEqual(len(self.get_state_group_calls), 1)
self.assertFalse(req1.called)
self.assertFalse(req2.called)

groups, sf, d = self.get_state_group_calls[0]
self.assertEqual(groups, (42,))
self.assertEqual(sf, StateFilter.all())

# Now we can complete the request
self._complete_request_fake(groups, sf, d)

self.assertEqual(
self.get_success(req1), {("other.event.type", "state.key"): "123"}
)
self.assertEqual(
self.get_success(req2), {("other.event.type", "state.key"): "123"}
)
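
As an aside, a hypothetical follow-up test (not part of this PR) could cover the other half of the lifecycle: once a request completes, its entry is removed from the in-flight map, so a later identical request has to go back to the database (the external state-group caches are consulted by _get_state_for_groups, not by this helper). A sketch of such a test, reusing the harness above; the test name and flow are illustrative only.

    def test_inflight_entry_removed_on_completion(self) -> None:
        """
        Hypothetical sketch: after the first request finishes, an identical
        request should trigger a fresh database call.
        """
        req1 = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)
        self.assertEqual(len(self.get_state_group_calls), 1)

        groups, sf, d = self.get_state_group_calls[0]
        self._complete_request_fake(groups, sf, d)
        self.get_success(req1)

        # The in-flight entry was cleaned up on completion, so this request
        # cannot observe it and issues a second database call.
        req2 = ensureDeferred(
            self.state_datastore._get_state_for_group_using_inflight_cache(
                42, StateFilter.all()
            )
        )
        self.pump(by=0.1)
        self.assertEqual(len(self.get_state_group_calls), 2)

        groups2, sf2, d2 = self.get_state_group_calls[1]
        self._complete_request_fake(groups2, sf2, d2)
        self.assertEqual(
            self.get_success(req2), {("other.event.type", "state.key"): "123"}
        )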