
Commit

drop on_context hook, use on_execute to implement CustomContext, add docs to hooks
m.kindritskiy committed Sep 5, 2023
1 parent cb139c6 commit 0447d4a
Showing 14 changed files with 165 additions and 68 deletions.
1 change: 0 additions & 1 deletion docs/extensions.rst
@@ -14,7 +14,6 @@ Here are all the methods that can be implemented:

- :meth:`~hiku.extensions.Extension.on_graph` - when endpoint is created and transformations applied to graph
- :meth:`~hiku.extensions.Extension.on_dispatch` - when query is dispatched to the endpoint
- :meth:`~hiku.extensions.Extension.on_context` - after query is dispatched to the endpoint but before query is executed
- :meth:`~hiku.extensions.Extension.on_parse` - when query string is parsed into ast
- :meth:`~hiku.extensions.Extension.on_operation` - when query ast parsed into query Node
- :meth:`~hiku.extensions.Extension.on_validate` - when query is validated
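For illustration, a minimal sketch of a custom extension using the hook API listed above (the class name and the timing logic are hypothetical, not part of this commit); code before the yield runs before the step, code after it runs once the step has finished:

import time
from typing import Iterator

from hiku.extensions.base_extension import Extension


class TimingExtension(Extension):  # hypothetical example, not part of hiku
    def on_execute(self) -> Iterator[None]:
        start = time.perf_counter()  # before the query is executed
        yield
        # after the engine has produced execution_context.result
        print("query executed in {:.3f}s".format(time.perf_counter() - start))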
3 changes: 3 additions & 0 deletions hiku/cache.py
@@ -168,6 +168,9 @@ def __init__(
self._node_idx: Deque[Dict] = deque()

def visit_field(self, field: QueryField) -> None:
if field.name == "__typename":
return

self._data[-1][field.index_key] = self._node_idx[-1][field.index_key]
super().visit_field(field)

5 changes: 2 additions & 3 deletions hiku/context.py
@@ -77,9 +77,8 @@ def graph(self) -> Graph:
):
return self.mutation_graph

raise ValueError(
"Unsupported operation type: {!r}".format(self.operation.type)
)
assert self.query_graph is not None
return self.query_graph


@dataclass
31 changes: 16 additions & 15 deletions hiku/endpoint/graphql.py
@@ -18,12 +18,12 @@
from ..extensions.base_extension import Extension, ExtensionsManager
from ..extensions.base_validator import QueryValidator
from ..graph import apply, Graph
from ..operation import OperationType
from ..query import Node
from ..validate.query import validate
from ..readers.graphql import (
parse_query,
read_operation,
Operation,
)
from ..denormalize.graphql import DenormalizeGraphQL
from ..introspection.graphql import AsyncGraphQLIntrospection
@@ -136,7 +136,7 @@ def _init_execution_context(
self,
execution_context: ExecutionContext,
extensions_manager: ExtensionsManager,
) -> t.Tuple[Graph, Operation]:
) -> None:
with extensions_manager.parsing():
if execution_context.graphql_document is None:
execution_context.graphql_document = parse_query(
@@ -158,14 +158,13 @@
]
)

execution_context.query = execution_context.operation.query
execution_context.query_graph = self.query_graph
execution_context.mutation_graph = self.mutation_graph
execution_context.query = execution_context.operation.query

try:
return execution_context.graph, execution_context.operation
except ValueError as err:
raise GraphQLError(errors=[str(err)])
op = execution_context.operation
if op.type not in (OperationType.QUERY, OperationType.MUTATION):
raise GraphQLError(
errors=["Unsupported operation type: {!r}".format(op.type)]
)


class BaseSyncGraphQLEndpoint(BaseGraphQLEndpoint):
@@ -193,6 +192,8 @@ def dispatch(
variables=data.get("variables"),
operation_name=data.get("operationName"),
context=context,
query_graph=self.query_graph,
mutation_graph=self.mutation_graph,
)

extensions_manager = ExtensionsManager(
@@ -206,8 +207,7 @@
execution_context, extensions_manager
)

with extensions_manager.context():
result = self.execute(execution_context, extensions_manager)
result = self.execute(execution_context, extensions_manager)
return {"data": result}
except GraphQLError as e:
return {"errors": [{"message": e} for e in e.errors]}
@@ -230,6 +230,8 @@ async def dispatch(
variables=data.get("variables"),
operation_name=data.get("operationName"),
context=context,
query_graph=self.query_graph,
mutation_graph=self.mutation_graph,
)

extensions_manager = ExtensionsManager(
@@ -241,10 +243,9 @@
self._init_execution_context(
execution_context, extensions_manager
)
with extensions_manager.context():
result = await self.execute(
execution_context, extensions_manager
)
result = await self.execute(
execution_context, extensions_manager
)
return {"data": result}
except GraphQLError as e:
return {"errors": [{"message": e} for e in e.errors]}
63 changes: 47 additions & 16 deletions hiku/extensions/base_extension.py
@@ -41,31 +41,68 @@ def __init__(self, *, execution_context: Optional[ExecutionContext] = None):
self.execution_context = execution_context # type: ignore[assignment]

def on_graph(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the graph (transformation) step"""
yield None
"""Called before and after the graph (transformation) step.
def on_dispatch(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the dispatch step"""
Graph transformation step is a step where we apply transformations
to the graph, such as introspection.
Note: unlike other hooks, this hook is called only once during endpoint
creation.
"""
yield None

def on_context(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the context step"""
def on_dispatch(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the dispatch step.
Dispatch step is a step where the query is dispatched to the endpoint
to be parsed, validated and executed.
At this step:
- execution_context.query_src is set to the query string
from the request
- execution_context.variables is set to the query variables
from the request
- execution_context.operation_name is set to the query operation name
from the request
- execution_context.query_graph is set to the query graph
- execution_context.mutation_graph is set to the mutation graph
- execution_context.context is set to the context from the dispatch argument
"""
yield None

def on_operation(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the operation step"""
"""Called before and after the operation step.
Operation step is a step where the GraphQL AST is transformed into
hiku's query AST and an Operation object is created and assigned to
execution_context.operation.
"""
yield None

def on_parse(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the parsing step"""
"""Called before and after the parsing step.
Parse step is when the query string is parsed into the GraphQL AST,
which is then assigned to execution_context.graphql_document.
"""
yield None

def on_validate(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the validation step"""
"""Called before and after the validation step.
Validation step is when the hiku query is validated.
If validation produces errors, they are assigned to
execution_context.errors.
"""
yield None

def on_execute(self) -> AsyncIteratorOrIterator[None]:
"""Called before and after the execution step"""
"""Called before and after the execution step.
Execution step is when the hiku query is executed by the hiku engine.
After execution, the normalized result (a Proxy) is assigned to
execution_context.result."""
yield None


@@ -105,12 +142,6 @@ def dispatch(self) -> "ExtensionContextManagerBase":
self.extensions,
)

def context(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_context.__name__,
self.extensions,
)

def parsing(self) -> "ExtensionContextManagerBase":
return ExtensionContextManagerBase(
Extension.on_parse.__name__,
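To make the before/after semantics documented in the docstrings above concrete, a hedged sketch of an extension that inspects the execution context after individual steps (class name and print calls are illustrative only):

from typing import Iterator

from hiku.extensions.base_extension import Extension


class InspectExtension(Extension):  # illustrative only
    def on_parse(self) -> Iterator[None]:
        yield
        # after parsing, the GraphQL document is available
        print("parsed:", self.execution_context.graphql_document is not None)

    def on_validate(self) -> Iterator[None]:
        yield
        # after validation, any errors are collected on the execution context
        print("validation errors:", self.execution_context.errors)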
2 changes: 1 addition & 1 deletion hiku/extensions/context.py
@@ -11,7 +11,7 @@ def __init__(
):
self.get_context = get_context

def on_context(self) -> Iterator[None]:
def on_execute(self) -> Iterator[None]:
self.execution_context.context = self.get_context(
self.execution_context
)
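Judging by the commit message, this is the CustomContext extension; with the on_context hook gone, the user-supplied get_context callable now runs inside on_execute. A hedged usage sketch (the class name is taken from the commit message, the context contents are illustrative):

from hiku.extensions.context import CustomContext  # name from the commit message


def get_context(execution_context):
    # build the per-request context; execution_context exposes the query,
    # variables, operation, etc.
    return {"locale": "en"}


extension = CustomContext(get_context)
# During the execute step the ExtensionsManager drives on_execute(), which
# assigns get_context(execution_context) to execution_context.context right
# before the engine runs the query.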
27 changes: 21 additions & 6 deletions hiku/extensions/prometheus.py
@@ -1,3 +1,4 @@
from contextvars import ContextVar
from typing import Iterator, Optional, Type

from prometheus_client.metrics import MetricWrapperBase
@@ -6,7 +7,6 @@
AsyncGraphMetrics,
GraphMetrics,
GraphMetricsBase,
metrics_ctx,
)
from hiku.extensions.base_extension import Extension

@@ -15,12 +15,17 @@ class PrometheusMetrics(Extension):
def __init__(
self,
name: str,
*,
metric: Optional[MetricWrapperBase] = None,
ctx_var: Optional[ContextVar] = None,
transformer_cls: Type[GraphMetricsBase] = GraphMetrics,
):
self._name = name
self._metric = metric
self._transformer = transformer_cls(self._name, metric=self._metric)
self._ctx_var = ctx_var
self._transformer = transformer_cls(
self._name, metric=self._metric, ctx_var=ctx_var
)

def on_graph(self) -> Iterator[None]:
self.execution_context.transformers = (
@@ -29,15 +34,25 @@ def on_graph(self) -> Iterator[None]:
yield

def on_execute(self) -> Iterator[None]:
token = metrics_ctx.set(self.execution_context.context)
yield
metrics_ctx.reset(token)
if self._ctx_var is None:
yield
else:
token = self._ctx_var.set(self.execution_context.context)
yield
self._ctx_var.reset(token)


class PrometheusMetricsAsync(PrometheusMetrics):
def __init__(
self,
name: str,
*,
metric: Optional[MetricWrapperBase] = None,
ctx_var: Optional[ContextVar] = None,
):
super().__init__(name, metric=metric, transformer_cls=AsyncGraphMetrics)
super().__init__(
name,
metric=metric,
ctx_var=ctx_var,
transformer_cls=AsyncGraphMetrics,
)
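A hedged usage sketch for the new keyword-only ctx_var argument (the metric name and ContextVar below are illustrative): when a ContextVar is supplied, on_execute stores the request context in it for the duration of execution and resets it afterwards; without one, the hook simply yields.

from contextvars import ContextVar

from hiku.extensions.prometheus import PrometheusMetrics

metrics_ctx: ContextVar = ContextVar("metrics_ctx")

metrics = PrometheusMetrics("graphql_query", ctx_var=metrics_ctx)
# metrics_ctx.get() resolves to execution_context.context while the query is
# being executed; the transformer receives the same ContextVar via ctx_var.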
19 changes: 11 additions & 8 deletions hiku/federation/endpoint.py
@@ -182,6 +182,8 @@ def dispatch(self, data: Dict, context: Optional[Dict] = None) -> Dict:
variables=data.get("variables"),
operation_name=data.get("operationName"),
context=context,
query_graph=self.query_graph,
mutation_graph=self.mutation_graph,
)

extensions_manager = ExtensionsManager(
@@ -194,8 +196,7 @@ def dispatch(self, data: Dict, context: Optional[Dict] = None) -> Dict:
self._init_execution_context(
execution_context, extensions_manager
)
with extensions_manager.context():
result = self.execute(execution_context, extensions_manager)
result = self.execute(execution_context, extensions_manager)
return {"data": result}
except GraphQLError as e:
return {"errors": [{"message": e} for e in e.errors]}
@@ -218,6 +219,8 @@ async def dispatch(
variables=data.get("variables"),
operation_name=data.get("operationName"),
context=context,
query_graph=self.query_graph,
mutation_graph=self.mutation_graph,
)

extensions_manager = ExtensionsManager(
@@ -230,10 +233,9 @@ async def dispatch(
self._init_execution_context(
execution_context, extensions_manager
)
with extensions_manager.context():
result = await self.execute(
execution_context, extensions_manager
)
result = await self.execute(
execution_context, extensions_manager
)
return {"data": result}
except GraphQLError as e:
return {"errors": [{"message": e} for e in e.errors]}
@@ -292,6 +294,8 @@ def dispatch(self, data: Dict, context: Optional[Dict] = None) -> Dict:
variables=data.get("variables"),
operation_name=data.get("operationName"),
context=context,
query_graph=self.query_graph,
mutation_graph=self.mutation_graph,
)

extensions_manager = ExtensionsManager(
@@ -304,8 +308,7 @@ def dispatch(self, data: Dict, context: Optional[Dict] = None) -> Dict:
self._init_execution_context(
execution_context, extensions_manager
)
with extensions_manager.context():
result = self.execute(execution_context, extensions_manager)
result = self.execute(execution_context, extensions_manager)
return {"data": result}
except GraphQLError as e:
return {"errors": [{"message": e} for e in e.errors]}