This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Extend ResponseCache to pass a context object to the callback
... allowing the callback to specify whether or not the result should be
cached.
richvdh committed Jun 9, 2021
1 parent f48542e commit 63e6c20
Showing 2 changed files with 95 additions and 6 deletions.
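For orientation, here is a minimal, hypothetical sketch of how calling code might use the new hook. The cached_lookup/fetch_profile helpers and the user-id key are invented for illustration; only ResponseCache, ResponseCacheContext, wrap(..., cache_context=True) and should_cache come from this change.

from typing import Optional

from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext


async def fetch_profile(user_id: str) -> Optional[dict]:
    # stand-in for real work (e.g. a database or federation request)
    return {"displayname": "Alice"} if user_id == "@alice:example.com" else None


async def cached_lookup(cache: ResponseCache, user_id: str) -> Optional[dict]:
    async def _do_lookup(cache_context: ResponseCacheContext[str]) -> Optional[dict]:
        profile = await fetch_profile(user_id)
        if profile is None:
            # nothing useful to remember: ask the cache to drop this result
            # as soon as the request completes
            cache_context.should_cache = False
        return profile

    # cache_context=True makes wrap() pass the ResponseCacheContext kwarg above
    return await cache.wrap(user_id, _do_lookup, cache_context=True)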
52 changes: 47 additions & 5 deletions synapse/util/caches/response_cache.py
@@ -14,6 +14,8 @@
import logging
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar

import attr

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -30,6 +32,28 @@
RV = TypeVar("RV")


@attr.s(auto_attribs=True)
class ResponseCacheContext(Generic[KV]):
    """Information about a missed ResponseCache hit

    This object can be passed into the callback for additional feedback
    """

    cache_key: KV
    """The cache key that caused the cache miss

    This should be considered read-only.

    TODO: in attrs 20.1, make it frozen with an on_setattr.
    """

    should_cache: bool = True
    """Whether the result should be cached once the request completes.

    This can be modified by the callback if it decides its result should not be cached.
    """


class ResponseCache(Generic[KV]):
    """
    This caches a deferred response. Until the deferred completes it will be
@@ -79,7 +103,9 @@ def get(self, key: KV) -> Optional[defer.Deferred]:
            self._metrics.inc_misses()
            return None

-    def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred:
    def _set(
        self, context: ResponseCacheContext[KV], deferred: defer.Deferred
    ) -> defer.Deferred:
        """Set the entry for the given key to the given deferred.

        *deferred* should run its callbacks in the sentinel logcontext (ie,
@@ -90,21 +116,26 @@ def set(self, key: KV, deferred: defer.Deferred) -> defer.Deferred:
        You will probably want to make_deferred_yieldable the result.

        Args:
-            key: key to get/set in the cache
            context: Information about the cache miss
            deferred: The deferred which resolves to the result.

        Returns:
            A new deferred which resolves to the actual result.
        """
        result = ObservableDeferred(deferred, consumeErrors=True)
        key = context.cache_key
        self.pending_result_cache[key] = result

        def on_complete(r):
-            if self.timeout_sec:
            # if this cache has a non-zero timeout, and the callback has not cleared
            # the should_cache bit, we leave it in the cache for now and schedule
            # its removal later.
            if self.timeout_sec and context.should_cache:
                self.clock.call_later(
                    self.timeout_sec, self.pending_result_cache.pop, key, None
                )
            else:
                # otherwise, remove the result immediately.
                self.pending_result_cache.pop(key, None)
            return r

@@ -115,7 +146,12 @@ def on_complete(r):
        return result.observe()

    async def wrap(
-        self, key: KV, callback: Callable[..., Awaitable[RV]], *args: Any, **kwargs: Any
        self,
        key: KV,
        callback: Callable[..., Awaitable[RV]],
        *args: Any,
        cache_context: bool = False,
        **kwargs: Any,
    ) -> RV:
        """Wrap together a *get* and *set* call, taking care of logcontexts
@@ -145,6 +181,9 @@ async def handle_request(request):
            *args: positional parameters to pass to the callback, if it is used
            cache_context: if set, the callback will be given a `cache_context` kw arg,
                which will be a ResponseCacheContext object.
            **kwargs: named parameters to pass to the callback, if it is used

        Returns:
@@ -155,8 +194,11 @@ async def handle_request(request):
            logger.debug(
                "[%s]: no cached result for [%s], calculating new one", self._name, key
            )
            context = ResponseCacheContext(cache_key=key)
            if cache_context:
                kwargs["cache_context"] = context
            d = run_in_background(callback, *args, **kwargs)
-            result = self.set(key, d)
            result = self._set(context, d)
        elif not isinstance(result, defer.Deferred) or result.called:
            logger.info("[%s]: using completed cached result for [%s]", self._name, key)
        else:
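The ResponseCacheContext docstring above carries a TODO to make cache_key frozen once attrs 20.1 is available. One possible shape for that, not part of this commit, is a sketch using the field-level on_setattr support added in attrs 20.1:

from typing import Generic, TypeVar

import attr

KV = TypeVar("KV")


@attr.s(auto_attribs=True)
class ResponseCacheContext(Generic[KV]):
    """Sketch only: cache_key becomes read-only, should_cache stays mutable."""

    # assigning to cache_key after construction raises
    # attr.exceptions.FrozenAttributeError; should_cache can still be cleared
    cache_key: KV = attr.ib(on_setattr=attr.setters.frozen)
    should_cache: bool = True

Since _set reads cache_key once and callbacks only ever write should_cache, nothing else in this diff would appear to need changing.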
49 changes: 48 additions & 1 deletion tests/util/caches/test_response_cache.py
@@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parameterized import parameterized

from twisted.internet import defer

-from synapse.util.caches.response_cache import ResponseCache
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext

from tests.server import get_clock
from tests.unittest import TestCase
@@ -142,3 +143,49 @@ def test_cache_wait_expire(self):
        self.reactor.pump((2,))

        self.assertIsNone(cache.get(0), "cache should not have the result now")

    @parameterized.expand([(True,), (False,)])
    def test_cache_context_nocache(self, should_cache: bool):
        """If the callback clears the should_cache bit, the result should not be cached"""
        cache = self.with_cache("medium_cache", ms=3000)

        expected_result = "howdy"

        call_count = [0]

        async def non_caching(o: str, cache_context: ResponseCacheContext[int]):
            call_count[0] += 1
            await self.clock.sleep(1)
            cache_context.should_cache = should_cache
            return o

        wrap_d = defer.ensureDeferred(
            cache.wrap(0, non_caching, expected_result, cache_context=True)
        )
        # there should be no result to start with
        self.assertNoResult(wrap_d)

        # a second call should also return a pending deferred
        wrap2_d = defer.ensureDeferred(
            cache.wrap(0, non_caching, expected_result, cache_context=True)
        )
        self.assertNoResult(wrap2_d)

        # and there should have been exactly one call
        self.assertEqual(call_count[0], 1)

        # let the call complete
        self.reactor.advance(1)

        # both results should have completed
        self.assertEqual(expected_result, self.successResultOf(wrap_d))
        self.assertEqual(expected_result, self.successResultOf(wrap2_d))

        if should_cache:
            self.assertEqual(
                expected_result,
                self.successResultOf(cache.get(0)),
                "cache should still have the result",
            )
        else:
            self.assertIsNone(cache.get(0), "cache should not have the result")
