Skip to content

Commit

Permalink
Merge pull request #138 from evo-company/add-extensions-factory
Browse files Browse the repository at this point in the history
add ExtensionFactory, forbid to pass Extension instances to Endpoint …
  • Loading branch information
kindermax committed Oct 11, 2023
2 parents 466c75f + 4f89aef commit 685d98d
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 30 deletions.
2 changes: 2 additions & 0 deletions hiku/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .base_extension import Extension
from .base_extension import ExtensionFactory

__all__ = [
"Extension",
"ExtensionFactory",
]
29 changes: 26 additions & 3 deletions hiku/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ def on_execute(self) -> AsyncIteratorOrIterator[None]:
yield None


class ExtensionFactory:
"""Lazy extension factory.
Remembers arguments and keyword arguments and creates an extension
instance when ExtensionsManager is created.
"""

ext_class: Type[Extension]

def __init__(self, *args: Any, **kwargs: Any):
self._args = args
self._kwargs = kwargs

def create(self, execution_context: ExecutionContext) -> Extension:
extension = self.ext_class(*self._args, **self._kwargs)
extension.execution_context = execution_context
return extension


class ExtensionsManager:
def __init__(
self,
Expand All @@ -120,9 +138,14 @@ def __init__(
init_extensions: List[Extension] = []

for extension in extensions:
if isinstance(extension, Extension):
extension.execution_context = execution_context
init_extensions.append(extension)
if isinstance(extension, ExtensionFactory):
init_extensions.append(extension.create(execution_context))
elif isinstance(extension, Extension):
raise ValueError(
f"Extension {extension} must be a class, "
"not an instance. Use ExtensionFactory if your extension "
"has custom arguments."
)
else:
init_extensions.append(
extension(execution_context=execution_context)
Expand Down
11 changes: 9 additions & 2 deletions hiku/extensions/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Callable, Dict, Iterator

from hiku.context import ExecutionContext
from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory


class CustomContext(Extension):
class _CustomContextImpl(Extension):
def __init__(
self,
get_context: Callable[[ExecutionContext], Dict],
Expand All @@ -16,3 +16,10 @@ def on_execute(self) -> Iterator[None]:
self.execution_context
)
yield


class CustomContext(ExtensionFactory):
ext_class = _CustomContextImpl

def __init__(self, get_context: Callable[[ExecutionContext], Dict]):
super().__init__(get_context)
14 changes: 11 additions & 3 deletions hiku/extensions/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
GraphMetrics,
GraphMetricsBase,
)
from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory


class PrometheusMetrics(Extension):
class _PrometheusMetricsImpl(Extension):
def __init__(
self,
name: str,
Expand Down Expand Up @@ -42,7 +42,11 @@ def on_execute(self) -> Iterator[None]:
self._ctx_var.reset(token)


class PrometheusMetricsAsync(PrometheusMetrics):
class PrometheusMetrics(ExtensionFactory):
ext_class = _PrometheusMetricsImpl


class _PrometheusMetricsAsyncImpl(_PrometheusMetricsImpl):
def __init__(
self,
name: str,
Expand All @@ -56,3 +60,7 @@ def __init__(
ctx_var=ctx_var,
transformer_cls=AsyncGraphMetrics,
)


class PrometheusMetricsAsync(ExtensionFactory):
ext_class = _PrometheusMetricsAsyncImpl
16 changes: 12 additions & 4 deletions hiku/extensions/query_depth_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hiku.graph import Graph

from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory
from hiku.extensions.base_validator import QueryValidator
from hiku.query import Field, Link, Node

Expand Down Expand Up @@ -52,7 +52,7 @@ def visit_node(self, obj: Node) -> None:
self._current_depth -= 1


class QueryDepthValidator(Extension):
class _QueryDepthValidatorImpl(Extension):
"""Use this extension to limit the maximum allowed query depth.
Example:
Expand All @@ -62,12 +62,20 @@ class QueryDepthValidator(Extension):
"""

def __init__(self, max_depth: int):
self._validator = _QueryDepthValidator(max_depth)
def __init__(self, validator: _QueryDepthValidator):
self._validator = validator

def on_dispatch(self) -> Iterator[None]:
self.execution_context.validators = (
self.execution_context.validators + tuple([self._validator])
)

yield


class QueryDepthValidator(ExtensionFactory):
ext_class = _QueryDepthValidatorImpl

def __init__(self, max_depth: int):
self._validator = _QueryDepthValidator(max_depth)
super().__init__(self._validator)
20 changes: 14 additions & 6 deletions hiku/extensions/query_parse_cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from functools import lru_cache
from typing import Iterator, Optional
from typing import Callable, Iterator, Optional

from prometheus_client import Gauge

from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory
from hiku.readers.graphql import parse_query


QUERY_CACHE_HITS = Gauge("hiku_query_cache_hits", "Query cache hits")
QUERY_CACHE_MISSES = Gauge("hiku_query_cache_misses", "Query cache misses")


class QueryParserCache(Extension):
class _QueryParserCacheImpl(Extension):
"""Sets up lru cache for the ast parsing.
Exposes two metrics:
Expand All @@ -21,8 +21,8 @@ class QueryParserCache(Extension):
:param int maxsize: Maximum size of the cache
"""

def __init__(self, maxsize: Optional[int] = None):
self.cached_parser = lru_cache(maxsize=maxsize)(parse_query)
def __init__(self, cached_parser: Callable):
self.cached_parser = cached_parser

def on_parse(self) -> Iterator[None]:
execution_context = self.execution_context
Expand All @@ -31,7 +31,15 @@ def on_parse(self) -> Iterator[None]:
execution_context.query_src,
)

info = self.cached_parser.cache_info()
info = self.cached_parser.cache_info() # type: ignore[attr-defined]
QUERY_CACHE_HITS.set(info.hits)
QUERY_CACHE_MISSES.set(info.misses)
yield


class QueryParserCache(ExtensionFactory):
ext_class = _QueryParserCacheImpl

def __init__(self, maxsize: Optional[int] = None):
self.cached_parser = lru_cache(maxsize=maxsize)(parse_query)
super().__init__(self.cached_parser)
22 changes: 15 additions & 7 deletions hiku/extensions/query_transform_cache.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from functools import lru_cache
from typing import Iterator, Optional
from typing import Callable, Iterator, Optional

from hiku.utils import to_immutable_dict

from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory
from hiku.readers.graphql import read_operation


class QueryTransformCache(Extension):
class _QueryTransformCacheImpl(Extension):
"""Sets up lru cache for the ast to Node transformation.
:param int maxsize: Maximum size of the cache
"""

def __init__(self, maxsize: Optional[int] = None):
self.cached_operation_reader = lru_cache(maxsize=maxsize)(
read_operation
)
def __init__(self, cached_operation_reader: Callable):
self.cached_operation_reader = cached_operation_reader

def on_operation(self) -> Iterator[None]:
execution_context = self.execution_context
Expand All @@ -30,3 +28,13 @@ def on_operation(self) -> Iterator[None]:
)

yield


class QueryTransformCache(ExtensionFactory):
ext_class = _QueryTransformCacheImpl

def __init__(self, maxsize: Optional[int] = None):
self.cached_operation_reader = lru_cache(maxsize=maxsize)(
read_operation
)
super().__init__(self.cached_operation_reader)
18 changes: 13 additions & 5 deletions hiku/extensions/query_validation_cache.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from functools import lru_cache
from typing import Iterator, Optional
from typing import Callable, Iterator, Optional

from hiku.endpoint.graphql import _run_validation
from hiku.extensions.base_extension import Extension
from hiku.extensions.base_extension import Extension, ExtensionFactory


class QueryValidationCache(Extension):
class _QueryValidationCacheImpl(Extension):
"""Sets up lru cache for the query Node validation.
:param int maxsize: Maximum size of the cache
"""

def __init__(self, maxsize: Optional[int] = None):
self.cached_validator = lru_cache(maxsize=maxsize)(_run_validation)
def __init__(self, cached_validator: Callable):
self.cached_validator = cached_validator

def on_validate(self) -> Iterator[None]:
execution_context = self.execution_context
Expand All @@ -24,3 +24,11 @@ def on_validate(self) -> Iterator[None]:
)

yield


class QueryValidationCache(ExtensionFactory):
ext_class = _QueryValidationCacheImpl

def __init__(self, maxsize: Optional[int] = None):
self.cached_validator = lru_cache(maxsize=maxsize)(_run_validation)
super().__init__(self.cached_validator)

0 comments on commit 685d98d

Please sign in to comment.