Commit e11d008

Merge pull request #84 from evo-company/support-custom-cache-key-func

support custom cache key func

kindermax committed Oct 4, 2022
2 parents f4a08a5 + 1226bd2 commit e11d008

Showing 3 changed files with 99 additions and 28 deletions.
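In short: Engine now takes a CacheSettings object instead of a bare BaseCache,
and CacheSettings can carry an optional CacheKeyFn callback that mixes
request-context data into every cache key. A minimal usage sketch, assuming an
existing BaseCache implementation named my_cache and a ThreadPoolExecutor named
thread_pool (both names are illustrative, and the executor import path is an
assumption mirroring the test suite's usage):

    from hiku.cache import CacheSettings
    from hiku.engine import Engine
    from hiku.executors.threads import ThreadsExecutor

    def cache_key(ctx, hasher):
        # Mix the request locale into every cache key so cached entries
        # are kept separate per locale (same callback as in the tests).
        hasher.update(ctx['locale'].encode('utf-8'))

    engine = Engine(
        ThreadsExecutor(thread_pool),
        CacheSettings(my_cache, cache_key),
    )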
66 changes: 59 additions & 7 deletions hiku/cache.py
@@ -7,14 +7,17 @@
     deque,
 )
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
     Union,
     Deque,
     Iterator,
     Optional,
+    Callable,
 )

+from hiku.compat import Protocol
 from hiku.result import Index
 from hiku.graph import (
     Many,
@@ -28,10 +31,20 @@
     Link as QueryLink,
 )

+if TYPE_CHECKING:
+    from hiku.engine import Context
+
 CACHE_VERSION = '1'


+class Hasher(Protocol):
+    def update(self, data: bytes) -> None:
+        ...
+
+
+CacheKeyFn = Callable[['Context', Hasher], None]
+
+
 class BaseCache(abc.ABC):
     @abc.abstractmethod
     def get_many(self, keys: List[str]) -> Dict[str, Any]:
@@ -43,6 +56,40 @@ def set_many(self, items: Dict[str, Any], ttl: int) -> None:
         raise NotImplementedError()


+class CacheSettings:
+    def __init__(
+        self,
+        cache: BaseCache,
+        cache_key: Optional[CacheKeyFn] = None
+    ):
+        self.cache = cache
+        self.cache_key = cache_key
+
+
+class CacheInfo:
+    def __init__(
+        self,
+        cache_settings: CacheSettings
+    ):
+        self.cache = cache_settings.cache
+        self.cache_key = cache_settings.cache_key
+
+    def query_hash(
+        self, ctx: 'Context', query_link: QueryLink, req: Any
+    ) -> str:
+        hasher = hashlib.sha1()
+        get_query_hash(hasher, query_link, req)
+        if self.cache_key:
+            self.cache_key(ctx, hasher)
+        return hasher.hexdigest()
+
+    def get_many(self, keys: List[str]) -> Dict[str, Any]:
+        return self.cache.get_many(keys)
+
+    def set_many(self, items: Dict[str, Any], ttl: int) -> None:
+        self.cache.set_many(items, ttl)
+
+
 class HashVisitor(QueryVisitor):
     def __init__(self, hasher) -> None:  # type: ignore
         self._hasher = hasher
@@ -59,7 +106,10 @@ class CacheVisitor(QueryVisitor):
     """Visit cached query link to extract all data from index
     that needs to be cached
     """
-    def __init__(self, index: Index, graph: Graph, node: Node) -> None:
+    def __init__(
+        self, cache: CacheInfo, index: Index, graph: Graph, node: Node
+    ) -> None:
+        self._cache = cache
         self._index = index
         self._graph = graph
         self._node = deque([node])
@@ -110,7 +160,9 @@ def _visit_ctx(req: Any) -> Iterator:

         self._node.pop()

-    def process(self, link: QueryLink, ids: List, reqs: List) -> Dict:
+    def process(
+        self, link: QueryLink, ids: List, reqs: List, ctx: 'Context'
+    ) -> Dict:
         to_cache = {}
         for i, req in zip(ids, reqs):
             node = self._node[-1]
@@ -121,17 +173,18 @@ def process(self, link: QueryLink, ids: List, reqs: List) -> Dict:
             self.visit(link)

             self._to_cache[-1][node.name] = self._data.pop()
-            to_cache[get_query_hash(link, req)] = dict(self._to_cache.pop())
+            key = self._cache.query_hash(ctx, link, req)
+            to_cache[key] = dict(self._to_cache.pop())
             self._node_idx.pop()

         return to_cache


 def get_query_hash(
+    hasher: Hasher,
     query_link: Union[QueryLink, QueryField],
     req: Any
-) -> str:
-    hasher = hashlib.sha1()
+) -> None:
     hash_visitor = HashVisitor(hasher)
     hash_visitor.visit(query_link)
@@ -141,4 +194,3 @@ def get_query_hash(
     else:
         hasher.update(str(hash(req)).encode('utf-8'))
     hasher.update(CACHE_VERSION.encode('utf-8'))
-    return hasher.hexdigest()
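For reference, a minimal BaseCache implementation against the abstract
interface above — a sketch under stated assumptions, not part of this commit
(the test suite's InMemoryCache plays the same role):

    from typing import Any, Dict, List
    from hiku.cache import BaseCache

    class DictCache(BaseCache):
        """Naive in-process cache backed by a dict; ttl is accepted but
        ignored, which is acceptable only for tests and sketches."""

        def __init__(self) -> None:
            self._store: Dict[str, Any] = {}

        def get_many(self, keys: List[str]) -> Dict[str, Any]:
            # Return a partial dict containing only the keys present.
            return {k: self._store[k] for k in keys if k in self._store}

        def set_many(self, items: Dict[str, Any], ttl: int) -> None:
            self._store.update(items)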
20 changes: 11 additions & 9 deletions hiku/engine.py
@@ -23,9 +23,9 @@
 from collections.abc import Sequence, Mapping, Hashable

 from .cache import (
-    BaseCache,
     CacheVisitor,
-    get_query_hash,
+    CacheInfo,
+    CacheSettings,
 )
 from .compat import Concatenate, ParamSpec
 from .executors.base import SyncAsyncExecutor
@@ -456,7 +456,7 @@ def __init__(
         graph: Graph,
         query: QueryNode,
         ctx: 'Context',
-        cache: BaseCache = None
+        cache: CacheInfo = None
     ) -> None:
         self._queue = queue
         self._task_set = task_set
@@ -643,7 +643,9 @@ def _update_index_from_cache(
         assert self._cache is not None
         key_info = []
         for i, req in zip(ids, reqs):
-            key_info.append((get_query_hash(query_link, req), i, req))
+            key_info.append(
+                (self._cache.query_hash(self._ctx, query_link, req), i, req)
+            )

         keys = set(info[0] for info in key_info)
         dep = self._submit(self._cache.get_many, list(keys))
@@ -717,9 +719,9 @@ def store_link_cache() -> None:
             assert self._cache is not None
             cached = query_link.directives_map['cached']
             reqs = link_reqs(self._index, node, graph_link, ids)
-            to_cache = CacheVisitor(self._index, self._graph, node).process(
-                query_link, ids, reqs
-            )
+            to_cache = CacheVisitor(
+                self._cache, self._index, self._graph, node
+            ).process(query_link, ids, reqs, self._ctx)

             self._submit(self._cache.set_many, to_cache, cached.ttl)

@@ -773,10 +775,10 @@ class Engine:
     def __init__(
         self,
         executor: SyncAsyncExecutor,
-        cache: BaseCache = None,
+        cache: CacheSettings = None,
     ) -> None:
         self.executor = executor
-        self.cache = CacheInfo(cache) if cache else None

     def execute(
         self,
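Engine wraps CacheSettings in CacheInfo once, and both the read path
(_update_index_from_cache, via get_many) and the write path (store_link_cache,
via set_many) compute keys through CacheInfo.query_hash, so the two paths
always agree. A small sketch of the resulting invariant, reusing the cache_key
and DictCache sketches above; link is a hypothetical placeholder for a parsed
@cached query link (like company_link in the tests below):

    from hiku.cache import CacheInfo, CacheSettings

    info = CacheInfo(CacheSettings(DictCache(), cache_key))
    # link: a QueryLink obtained from a parsed query (placeholder here).
    key_en = info.query_hash({'locale': 'en'}, link, 10)
    key_de = info.query_hash({'locale': 'de'}, link, 10)
    # The locale participates in the key, so entries never collide...
    assert key_en != key_de
    # ...while the same context and request always reproduce the key.
    assert key_en == info.query_hash({'locale': 'en'}, link, 10)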
41 changes: 29 additions & 12 deletions tests/test_cache.py
@@ -42,7 +42,10 @@
 )
 from hiku.engine import Engine
 from hiku.readers.graphql import read
-from hiku.cache import BaseCache, get_query_hash
+from hiku.cache import (
+    BaseCache,
+    CacheSettings,
+)


 class InMemoryCache(BaseCache):
@@ -439,7 +442,7 @@ def get_product_query(product_id: int) -> str:
                 name
             }
         }
-        company @cached(ttl: 10) {
+        company @cached(ttl: 10) {
             id
             name
             address { city }

cache = InMemoryCache()
cache = Mock(wraps=cache)
engine = Engine(ThreadsExecutor(thread_pool), cache)
engine = Engine(ThreadsExecutor(thread_pool), CacheSettings(cache))
ctx = {
SA_ENGINE_KEY: sa_engine,
'locale': 'en'
}

def execute(q):
proxy = engine.execute(graph, q, {SA_ENGINE_KEY: sa_engine})
proxy = engine.execute(graph, q, ctx)
return DenormalizeGraphQL(graph, proxy, 'query').process(q)

query = read(get_product_query(1))
@@ -513,8 +520,8 @@ def execute(q):
         .node.fields_map['photo']
     )

-    company_key = get_query_hash(company_link, 10)
-    attributes_key = get_query_hash(attributes_link, [11, 12])
+    company_key = engine.cache.query_hash(ctx, company_link, 10)
+    attributes_key = engine.cache.query_hash(ctx, attributes_link, [11, 12])

     company_cache = {
         'User': {
@@ -615,10 +622,19 @@ def test_cached_link_many__sqlalchemy(sync_graph_sqlalchemy):

     cache = InMemoryCache()
     cache = Mock(wraps=cache)
-    engine = Engine(ThreadsExecutor(thread_pool), cache)
+
+    def cache_key(ctx, hasher):
+        hasher.update(ctx['locale'].encode('utf-8'))
+
+    cache_settings = CacheSettings(cache, cache_key)
+    engine = Engine(ThreadsExecutor(thread_pool), cache_settings)
+    ctx = {
+        SA_ENGINE_KEY: sa_engine,
+        'locale': 'en'
+    }

     def execute(q):
-        proxy = engine.execute(graph, q, {SA_ENGINE_KEY: sa_engine})
+        proxy = engine.execute(graph, q, ctx)
         return DenormalizeGraphQL(graph, proxy, 'query').process(q)

     query = read(get_products_query())
@@ -633,10 +649,11 @@ def execute(q):
         .node.fields_map['photo']
     )

-    company10_key = get_query_hash(company_link, 10)
-    company20_key = get_query_hash(company_link, 20)
-    attributes11_12_key = get_query_hash(attributes_link, [11, 12])
-    attributes_none_key = get_query_hash(attributes_link, [])
+    company10_key = engine.cache.query_hash(ctx, company_link, 10)
+    company20_key = engine.cache.query_hash(ctx, company_link, 20)
+    attributes11_12_key = engine.cache.query_hash(
+        ctx, attributes_link, [11, 12])
+    attributes_none_key = engine.cache.query_hash(ctx, attributes_link, [])

     company10_cache = {
         'User': {
