Skip to content

Commit

Permalink
Merge pull request #81 from evo-company/add-result-caching
Browse files Browse the repository at this point in the history
add result caching
  • Loading branch information
kindermax committed Sep 29, 2022
2 parents 15c86f5 + 2b9193f commit 4fd6453
Show file tree
Hide file tree
Showing 16 changed files with 1,240 additions and 50 deletions.
167 changes: 167 additions & 0 deletions hiku/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import abc
import contextlib
import hashlib

from collections import (
defaultdict,
deque,
)
from typing import (
Any,
Dict,
List,
Union,
Tuple,
Deque,
Iterator,
)

from hiku.result import Index
from hiku.graph import (
Many,
Graph,
Node,
Field,
)
from hiku.query import (
QueryVisitor,
Field as QueryField,
Link as QueryLink,
)


CACHE_VERSION = '1'


class BaseCache(abc.ABC):
@abc.abstractmethod
def get_many(self, keys: List[str]) -> Dict[str, Any]:
"""Result must contain only keys which were cached"""
raise NotImplementedError()

@abc.abstractmethod
def set_many(self, items: Dict[str, Any], ttl: int) -> None:
raise NotImplementedError()


class HashVisitor(QueryVisitor):
def __init__(self, hasher) -> None: # type: ignore
self._hasher = hasher

def visit_field(self, obj: QueryField) -> None:
self._hasher.update(obj.index_key.encode())

def visit_link(self, obj: QueryLink) -> None:
self._hasher.update(obj.index_key.encode())
super().visit_link(obj)


class CacheVisitor(QueryVisitor):
"""Visit cached query link to extract all data from index
that needs to be cached
"""
def __init__(self, index: Index, graph: Graph, node: Node) -> None:
self._index = index
self._graph = graph
self._node = deque([node])
self._req: Deque[Any] = deque()
self._data: Deque[Dict] = deque()
self._to_cache: Deque[Dict] = deque()
self._node_idx: Deque[Dict] = deque()

def visit_field(self, field: QueryField) -> None:
self._data[-1][field.index_key] = self._node_idx[-1][field.index_key]
super().visit_field(field)

def visit_link(self, link: QueryLink) -> None:
refs = self._node_idx[-1][link.index_key]

self._data[-1][link.index_key] = refs

graph_obj = self._node[-1].fields_map[link.name]
if isinstance(graph_obj, Field):
# Link as complex field
return

node = self._graph.nodes_map[graph_obj.node]
self._node.append(node)

@contextlib.contextmanager
def _visit_ctx(req: Any) -> Iterator:
self._node_idx.append(self._index[node.name][req])
self._data.append({})

yield

data = self._data.pop()
self._to_cache[-1][node.name][req] = data
self._node_idx.pop()

if graph_obj.type_enum is Many:
for r in refs:
with _visit_ctx(r.ident):
super().visit_link(link)
else:
if refs is None:
self._node.pop()
return

with _visit_ctx(refs.ident):
super().visit_link(link)

self._node.pop()

def process(self, link: QueryLink, ids: List, reqs: List) -> Dict:
to_cache = {}
for i, req in zip(ids, reqs):
node = self._node[-1]
self._node_idx.append(self._index[node.name][i])
self._data.append({})
self._to_cache.append(defaultdict(dict))

self.visit(link)

self._to_cache[-1][node.name] = self._data.pop()
to_cache[get_query_hash(link, req)] = dict(self._to_cache.pop())
self._node_idx.pop()

return to_cache


def get_query_hash(
query_link: Union[QueryLink, QueryField],
req: Any
) -> str:
hasher = hashlib.sha1()
hash_visitor = HashVisitor(hasher)
hash_visitor.visit(query_link)

if isinstance(req, list):
for r in req:
hasher.update(str(hash(r)).encode('utf-8'))
else:
hasher.update(str(hash(req)).encode('utf-8'))
hasher.update(CACHE_VERSION.encode('utf-8'))
return hasher.hexdigest()


def get_cached_data(
cache: BaseCache,
query_link: QueryLink,
ids: List,
reqs: List,
) -> Tuple[List, List]:
req_key = []
for i, req in zip(ids, reqs):
req_key.append((get_query_hash(query_link, req), i, req))

keys = set(info[0] for info in req_key)
cached_data_raw = cache.get_many(list(keys))
cached_data = []
cached_ids = []
for key, i, req in req_key:
if key in cached_data_raw:
cached_ids.append(i)
cached_data.append(cached_data_raw[key])

return cached_ids, cached_data
11 changes: 11 additions & 0 deletions hiku/directives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,14 @@ def get_deprecated(field: Union['Field', 'Link']) -> Optional[Deprecated]:
(d for d in field.directives if isinstance(d, Deprecated)),
None
)


class QueryDirective:
def __init__(self, name: str) -> None:
self.name = name


class Cached(QueryDirective):
def __init__(self, ttl: int):
super().__init__('cached')
self.ttl = ttl

0 comments on commit 4fd6453

Please sign in to comment.