Skip to content

Commit

Permalink
Merge pull request #123 from evo-company/fix-resolvable-false-for-entity
Browse files Browse the repository at this point in the history
[fix] do not expose entity in _Entity union if @key is resolvable: false
  • Loading branch information
kindermax committed Aug 21, 2023
2 parents fdffa57 + ad12721 commit d523817
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 66 deletions.
3 changes: 2 additions & 1 deletion examples/graphql_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def direct_link(ids):
app = Flask(__name__)

graphql_endpoint = FederatedGraphQLEndpoint(
Engine(SyncExecutor(), federation_version=1),
Engine(SyncExecutor()),
QUERY_GRAPH,
federation_version=1,
)


Expand Down
41 changes: 36 additions & 5 deletions hiku/federation/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
Type,
)

from hiku.federation.utils import get_entity_types
from hiku.federation.version import DEFAULT_FEDERATION_VERSION

from hiku.federation.sdl import print_sdl

from hiku.federation.graph import Graph
Expand All @@ -35,7 +38,7 @@
GraphQLError,
_StripQuery,
)
from hiku.graph import apply
from hiku.graph import GraphTransformer, apply, Union as GraphUnion
from hiku.query import Node
from hiku.result import Proxy
from hiku.readers.graphql import Operation
Expand Down Expand Up @@ -66,9 +69,31 @@ def denormalize_entities(
return data["_entities"]


class FederationV1EntityTransformer(GraphTransformer):
def visit_graph(self, obj: Graph) -> Graph: # type: ignore
entities = get_entity_types(obj.nodes, federation_version=1)
unions = []
for u in obj.unions:
if u.name == "_Entity" and entities:
unions.append(GraphUnion("_Entity", entities))
else:
unions.append(u)

return Graph(
[self.visit(node) for node in obj.items],
obj.data_types,
obj.directives,
unions,
obj.interfaces,
obj.enums,
obj.scalars,
)


class BaseFederatedGraphEndpoint(ABC):
query_graph: Graph
mutation_graph: Optional[Graph]
federation_version: int

@property
@abstractmethod
Expand All @@ -80,13 +105,19 @@ def __init__(
engine: Engine,
query_graph: Graph,
mutation_graph: Optional[Graph] = None,
federation_version: int = DEFAULT_FEDERATION_VERSION,
):
self.engine = engine
self.federation_version = federation_version

introspection = self.introspection_cls(query_graph, mutation_graph)
self.query_graph = apply(query_graph, [introspection])
transformers: List[GraphTransformer] = [introspection]
if federation_version == 1:
transformers.insert(0, FederationV1EntityTransformer())

self.query_graph = apply(query_graph, transformers)
if mutation_graph is not None:
self.mutation_graph = apply(mutation_graph, [introspection])
self.mutation_graph = apply(mutation_graph, transformers)
else:
self.mutation_graph = None

Expand Down Expand Up @@ -135,7 +166,7 @@ def execute(self, graph: Graph, op: Operation, ctx: Optional[Dict]) -> Dict:
ctx["__sdl__"] = print_sdl(
self.query_graph,
self.mutation_graph,
federation_version=self.engine.federation_version,
self.federation_version,
)

result = self.engine.execute(graph, stripped_query, ctx, op)
Expand Down Expand Up @@ -174,7 +205,7 @@ async def execute(
ctx["__sdl__"] = print_sdl(
self.query_graph,
self.mutation_graph,
federation_version=self.engine.federation_version,
self.federation_version,
)

coro = self.engine.execute(graph, stripped_query, ctx)
Expand Down
5 changes: 0 additions & 5 deletions hiku/federation/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Context,
)
from hiku.executors.queue import Queue
from hiku.federation.version import DEFAULT_FEDERATION_VERSION
from hiku.graph import (
Graph,
GraphTransformer,
Expand Down Expand Up @@ -61,12 +60,8 @@ def __init__(
self,
executor: SyncAsyncExecutor,
cache: Optional[CacheSettings] = None,
federation_version: int = DEFAULT_FEDERATION_VERSION,
) -> None:
super().__init__(executor, cache)
if federation_version not in (1, 2):
raise ValueError("federation_version must be 1 or 2")
self.federation_version = federation_version

def execute(
self,
Expand Down
58 changes: 45 additions & 13 deletions hiku/federation/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing as t
from inspect import isawaitable

from hiku.federation.version import DEFAULT_FEDERATION_VERSION

from hiku.engine import pass_context

from hiku.directives import SchemaDirective
Expand Down Expand Up @@ -62,13 +64,13 @@ def visit_root(self, root: Root) -> Root:
if not self.has_entity_types:
return root

return Root(
root.fields
+ [
self.entities_link(),
self.service_field(),
]
)
fields = root.fields[:]
if "_entities" not in root.fields_map:
fields.append(self.entities_link())
if "_service" not in root.fields_map:
fields.append(self.service_field())

return Root(fields)

def visit_node(self, obj: Node) -> Node:
if hasattr(obj, "resolve_reference") and obj.name is not None:
Expand Down Expand Up @@ -152,24 +154,54 @@ def __init__(
scalars: t.Optional[t.List[t.Type[Scalar]]] = None,
is_async: bool = False,
):
self.is_async = is_async

if unions is None:
unions = []

entity_types = get_entity_types(items)
if entity_types:
unions.append(Union("_Entity", entity_types))

if scalars is None:
scalars = []

scalars.extend([_Any, FieldSet, LinkImport])
unions_map = {union.name: union for union in unions}
scalars_map = {scalar.__type_name__: scalar for scalar in scalars}

entity_types = get_entity_types(items, DEFAULT_FEDERATION_VERSION)
if entity_types:
if "_Entity" not in unions_map:
unions.append(Union("_Entity", entity_types))

for scalar in [_Any, FieldSet, LinkImport]:
if scalar.__type_name__ not in scalars_map:
scalars.append(scalar)

if data_types is None:
data_types = {}

data_types["_Service"] = Record[{"sdl": String}]
if "_Service" not in data_types:
data_types["_Service"] = Record[{"sdl": String}]

items = GraphInit.init(items, is_async, bool(entity_types))

super().__init__(
items, data_types, directives, unions, interfaces, enums, scalars
)

@classmethod
def from_graph(
cls,
other: "Graph",
root: Root,
) -> "Graph":
"""Create graph from other graph, with new root node.
Useful for creating mutation graph from query graph.
"""
return cls(
other.nodes + [root],
data_types=other.data_types,
directives=other.directives,
unions=other.unions,
interfaces=other.interfaces,
enums=other.enums,
scalars=other.scalars,
is_async=other.is_async,
)
46 changes: 30 additions & 16 deletions hiku/federation/sdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
from graphql.language import ast
from graphql.pyutils import inspect

from hiku.federation.utils import get_entity_types
from hiku.federation.version import DEFAULT_FEDERATION_VERSION

from hiku.directives import SchemaDirective
from hiku.federation.graph import Graph as FederationGraph
from hiku.federation.directive import (
ComposeDirective,
Extends,
FederationSchemaDirective,
Key,
Link as LinkDirective,
)
from hiku.federation.version import DEFAULT_FEDERATION_VERSION
from hiku.introspection.graphql import _BUILTIN_DIRECTIVES
from hiku.graph import (
Link,
Expand All @@ -29,9 +32,11 @@
Node,
Root,
GraphTransformer,
Graph,
Option,
Graph,
Union,
)
from hiku.scalar import ScalarMeta
from hiku.types import (
IDMeta,
IntegerMeta,
Expand Down Expand Up @@ -90,6 +95,8 @@ def _encode(
if input_type:
return f"IO{val.__type_name__}"
return val.__type_name__
elif isinstance(val, ScalarMeta):
return val.__type_name__
elif isinstance(val, IntegerMeta):
return "Int"
elif isinstance(val, StringMeta):
Expand Down Expand Up @@ -187,15 +194,6 @@ def __init__(
self.mutation_graph = mutation_graph
self.federation_version = federation_version

def get_entity_types(self) -> t.List[str]:
entity_nodes = set()
for node in self.graph.nodes:
for directive in node.directives:
if isinstance(directive, Key):
entity_nodes.add(node.name)

return list(sorted(entity_nodes))

def export_data_types(
self,
) -> t.Iterator[
Expand Down Expand Up @@ -466,11 +464,13 @@ def visit_schema_directive(


def get_ast(
graph: Graph, mutation_graph: Optional[Graph], federation_version: int
graph: t.Union[Graph, FederationGraph],
mutation_graph: Optional[t.Union[Graph, FederationGraph]],
federation_version: int,
) -> ast.DocumentNode:
graph = _StripGraph().visit(graph)
graph = _StripGraph(federation_version).visit(graph)
if mutation_graph is not None:
mutation_graph = _StripGraph().visit(mutation_graph)
mutation_graph = _StripGraph(federation_version).visit(mutation_graph)
return ast.DocumentNode(
definitions=Exporter(graph, mutation_graph, federation_version).visit(
graph
Expand All @@ -479,6 +479,9 @@ def get_ast(


class _StripGraph(GraphTransformer):
def __init__(self, federation_version: int):
self.federation_version = federation_version

def visit_root(self, obj: Root) -> Root:
def skip(field: t.Union[Field, Link]) -> bool:
return field.name in ["__typename", "_entities", "_service"]
Expand All @@ -498,11 +501,22 @@ def skip(node: Node) -> bool:
for name, type_ in obj.data_types.items()
if not name.startswith("_Service")
}

entities = get_entity_types(
obj.nodes, federation_version=self.federation_version
)
unions = []
for u in obj.unions:
if u.name == "_Entity" and entities:
unions.append(Union("_Entity", entities))
else:
unions.append(u)

return Graph(
[self.visit(node) for node in obj.items if not skip(node)],
data_types,
obj.directives,
obj.unions,
unions,
obj.interfaces,
obj.enums,
obj.scalars,
Expand All @@ -522,7 +536,7 @@ def skip(field: t.Union[Field, Link]) -> bool:


def print_sdl(
graph: Graph,
graph: FederationGraph,
mutation_graph: Optional[Graph] = None,
federation_version: int = DEFAULT_FEDERATION_VERSION,
) -> str:
Expand Down
4 changes: 3 additions & 1 deletion hiku/federation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ def get_keys(graph: Graph, typename: str) -> List[FieldSet]:
return [d.fields for d in node.directives if isinstance(d, Key)]


def get_entity_types(nodes: List[Node]) -> List[str]:
def get_entity_types(nodes: List[Node], federation_version: int) -> List[str]:
entity_nodes = set()
for node in nodes:
if node.name is not None:
for directive in node.directives:
if isinstance(directive, Key):
if not directive.resolvable and federation_version == 2:
continue
entity_nodes.add(node.name)

return list(sorted(entity_nodes))
23 changes: 20 additions & 3 deletions hiku/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,9 @@ class AbstractGraph(AbstractBase, ABC):
pass


G = t.TypeVar("G", bound="Graph")


class Graph(AbstractGraph):
"""Collection of nodes - definition of the graph
Expand Down Expand Up @@ -857,6 +860,23 @@ def scalars_map(self) -> "OrderedDict[str, t.Type[Scalar]]":
def accept(self, visitor: "AbstractGraphVisitor") -> t.Any:
return visitor.visit_graph(self)

@classmethod
def from_graph(cls: t.Type[G], other: G, root: Root) -> G:
"""Create graph from other graph, with new root node.
Useful for creating mutation graph from query graph.
Example:
MUTATION_GRAPH = Graph.from_graph(QUERY_GRAPH, Root([...]))
"""
return cls(
other.nodes + [root],
data_types=other.data_types,
directives=other.directives,
unions=other.unions,
interfaces=other.interfaces,
enums=other.enums,
scalars=other.scalars,
)


class AbstractGraphVisitor(ABC):
@abstractmethod
Expand Down Expand Up @@ -1018,9 +1038,6 @@ def visit_graph(self, obj: Graph) -> Graph:
)


G = t.TypeVar("G", bound=Graph)


def apply(graph: G, transformers: List[GraphTransformer]) -> G:
"""Helper function to apply graph transformations
Expand Down

0 comments on commit d523817

Please sign in to comment.