diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index f820544d89ca..34c662c4dbd7 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -14,6 +14,8 @@ import logging from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar +import attr + from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -30,6 +32,28 @@ RV = TypeVar("RV") +@attr.s(auto_attribs=True) +class ResponseCacheContext(Generic[KV]): + """Information about a missed ResponseCache hit + + This object can be passed into the callback for additional feedback + """ + + cache_key: KV + """The cache key that caused the cache miss + + This should be considered read-only. + + TODO: in attrs 20.1, make it frozen with an on_setattr. + """ + + should_cache: bool = True + """Whether the result should be cached once the request completes. + + This can be modified by the callback if it decides its result should not be cached. + """ + + class ResponseCache(Generic[KV]): """ This caches a deferred response. Until the deferred completes it will be @@ -79,7 +103,9 @@ def get(self, key: KV) -> Optional[defer.Deferred]: self._metrics.inc_misses() return None - def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: + def _set( + self, context: ResponseCacheContext[KV], deferred: defer.Deferred + ) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -90,21 +116,26 @@ def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred: You will probably want to make_deferred_yieldable the result. Args: - key: key to get/set in the cache + context: Information about the cache miss deferred: The deferred which resolves to the result. Returns: A new deferred which resolves to the actual result. """ result = ObservableDeferred(deferred, consumeErrors=True) + key = context.cache_key self.pending_result_cache[key] = result def on_complete(r): - if self.timeout_sec: + # if this cache has a non-zero timeout, and the callback has not cleared + # the should_cache bit, we leave it in the cache for now and schedule + # its removal later. + if self.timeout_sec and context.should_cache: self.clock.call_later( self.timeout_sec, self.pending_result_cache.pop, key, None ) else: + # otherwise, remove the result immediately. self.pending_result_cache.pop(key, None) return r @@ -115,7 +146,12 @@ def on_complete(r): return result.observe() async def wrap( - self, key: KV, callback: Callable[..., Awaitable[RV]], *args: Any, **kwargs: Any + self, + key: KV, + callback: Callable[..., Awaitable[RV]], + *args: Any, + cache_context: bool = False, + **kwargs: Any, ) -> RV: """Wrap together a *get* and *set* call, taking care of logcontexts @@ -145,6 +181,9 @@ async def handle_request(request): *args: positional parameters to pass to the callback, if it is used + cache_context: if set, the callback will be given a `cache_context` kw arg, + which will be a ResponseCacheContext object. + **kwargs: named parameters to pass to the callback, if it is used Returns: @@ -155,8 +194,11 @@ async def handle_request(request): logger.debug( "[%s]: no cached result for [%s], calculating new one", self._name, key ) + context = ResponseCacheContext(cache_key=key) + if cache_context: + kwargs["cache_context"] = context d = run_in_background(callback, *args, **kwargs) - result = self.set(key, d) + result = self._set(context, d) elif not isinstance(result, defer.Deferred) or result.called: logger.info("[%s]: using completed cached result for [%s]", self._name, key) else: diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index d2f3c2c7fa82..f69419766fac 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from parameterized import parameterized from twisted.internet import defer -from synapse.util.caches.response_cache import ResponseCache +from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext from tests.server import get_clock from tests.unittest import TestCase @@ -142,3 +143,49 @@ def test_cache_wait_expire(self): self.reactor.pump((2,)) self.assertIsNone(cache.get(0), "cache should not have the result now") + + @parameterized.expand([(True,), (False,)]) + def test_cache_context_nocache(self, should_cache: bool): + """If the callback clears the should_cache bit, the result should not be cached""" + cache = self.with_cache("medium_cache", ms=3000) + + expected_result = "howdy" + + call_count = [0] + + async def non_caching(o: str, cache_context: ResponseCacheContext[int]): + call_count[0] += 1 + await self.clock.sleep(1) + cache_context.should_cache = should_cache + return o + + wrap_d = defer.ensureDeferred( + cache.wrap(0, non_caching, expected_result, cache_context=True) + ) + # there should be no result to start with + self.assertNoResult(wrap_d) + + # a second call should also return a pending deferred + wrap2_d = defer.ensureDeferred( + cache.wrap(0, non_caching, expected_result, cache_context=True) + ) + self.assertNoResult(wrap2_d) + + # and there should have been exactly one call + self.assertEqual(call_count[0], 1) + + # let the call complete + self.reactor.advance(1) + + # both results should have completed + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + self.assertEqual(expected_result, self.successResultOf(wrap2_d)) + + if should_cache: + self.assertEqual( + expected_result, + self.successResultOf(cache.get(0)), + "cache should still have the result", + ) + else: + self.assertIsNone(cache.get(0), "cache should not have the result")