Skip to content

Commit

Permalink
Merge pull request #119 from evo-company/refactor-fragments-parsing-again
Browse files Browse the repository at this point in the history

move node fragments from fields into fragments attribute, rewrite fra…
  • Loading branch information
kindermax committed Aug 17, 2023
2 parents 61b15c2 + 108676c commit fdffa57
Show file tree
Hide file tree
Showing 24 changed files with 613 additions and 374 deletions.
17 changes: 16 additions & 1 deletion hiku/denormalize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ def process(self, query: Node) -> t.Dict:
self.visit(query)
return self._res.pop()

def visit_node(self, obj: Node) -> t.Any:
    """Denormalize a query node: visit its plain fields first, then
    descend into each fragment whose type condition matches the
    concrete type of the value currently on the data stack.
    """
    for field in obj.fields:
        self.visit(field)

    # When the top of the data stack is a Proxy we know the concrete
    # node type; use it to filter fragments by their type condition.
    concrete_type = None
    top = self._data[-1]
    if isinstance(top, Proxy):
        concrete_type = top.__ref__.node

    for fragment in obj.fragments:
        if concrete_type is not None and fragment.type_name != concrete_type:
            # Fragment is conditioned on a different type — skip it.
            continue
        self.visit(fragment)

def visit_field(self, obj: Field) -> None:
if isinstance(self._data[-1], Proxy):
type_name = self._data[-1].__ref__.node
Expand All @@ -86,7 +101,7 @@ def visit_field(self, obj: Field) -> None:
self._res[-1][obj.result_key] = self._data[-1][obj.result_key]

def visit_link(self, obj: Link) -> None:
if isinstance(self._type[-1], Union):
if isinstance(self._type[-1], (Union, Interface)):
type_ = self._types[self._data[-1].__ref__.node].__field_types__[
obj.name
]
Expand Down
7 changes: 1 addition & 6 deletions hiku/denormalize/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,10 @@ def visit_field(self, obj: Field) -> None:
type_name = self._data[-1].__ref__.node
self._res[-1][obj.result_key] = type_name
else:
if isinstance(self._type[-1], (Union, Interface)):
type_name = self._data[-1].__ref__.node

if obj.name not in self._types[type_name].__field_types__:
return
super().visit_field(obj)

def visit_link(self, obj: Link) -> None:
if isinstance(self._type[-1], Union):
if isinstance(self._type[-1], (Union, Interface)):
type_ = self._types[self._data[-1].__ref__.node].__field_types__[
obj.name
]
Expand Down
13 changes: 10 additions & 3 deletions hiku/endpoint/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@ class _StripQuery(QueryTransformer):

def visit_node(self, obj: Node) -> Node:
return obj.copy(
fields=[self.visit(f) for f in obj.fields if f.name != "__typename"]
fields=[
self.visit(f) for f in obj.fields if f.name != "__typename"
],
fragments=[self.visit(f) for f in obj.fragments],
)

def visit_fragment(self, obj: Fragment) -> Fragment:
    """Return a copy of the fragment with its inner node transformed."""
    stripped_node = self.visit(obj.node)
    return obj.copy(node=stripped_node)


G = t.TypeVar("G", bound=Graph)


def _switch_graph(
data: t.Dict, query_graph: Graph, mutation_graph: t.Optional[Graph] = None
) -> t.Tuple[Graph, Operation]:
data: t.Dict, query_graph: G, mutation_graph: t.Optional[G] = None
) -> t.Tuple[G, Operation]:
try:
op = read_operation(
data["query"],
Expand All @@ -75,6 +81,7 @@ def _switch_graph(
"Unsupported operation type: {!r}".format(op.type),
]
)

return graph, op


Expand Down
93 changes: 24 additions & 69 deletions hiku/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from .compat import Concatenate, ParamSpec
from .executors.base import SyncAsyncExecutor
from .query import (
Fragment,
Node as QueryNode,
Field as QueryField,
Link as QueryLink,
Expand Down Expand Up @@ -127,20 +126,19 @@ def enter_path(self, type_: Any) -> Iterator[None]:

def visit_node(self, obj: QueryNode) -> QueryNode:
fields = []
fragments = []

for f in obj.fields:
if f.name == "__typename":
fields.append(f)
continue

type_ = None
if isinstance(f, Fragment):
type_ = self._graph.nodes_map[f.name]

with self.enter_path(type_):
else:
fields.append(self.visit(f))

return obj.copy(fields=fields)
for fr in obj.fragments:
with self.enter_path(self._graph.nodes_map[fr.type_name]):
fragments.append(self.visit(fr))

return obj.copy(fields=fields, fragments=fragments)

def visit_field(self, obj: QueryField) -> QueryField:
graph_obj = self._path[-1].fields_map[obj.name]
Expand All @@ -154,15 +152,9 @@ def visit_link(self, obj: QueryLink) -> QueryLink:

if isinstance(graph_obj, Link):
if graph_obj.type_info.type_enum is LinkType.UNION:
if obj.fragment_type:
self._path.append(self._graph.nodes_map[obj.fragment_type])
self._path.append(self._graph.unions_map[graph_obj.node])
elif graph_obj.type_info.type_enum is LinkType.INTERFACE:
if obj.fragment_type:
self._path.append(self._graph.nodes_map[obj.fragment_type])
else:
self._path.append(
self._graph.interfaces_map[graph_obj.node]
)
self._path.append(self._graph.interfaces_map[graph_obj.node])
else:
self._path.append(self._graph.nodes_map[graph_obj.node])
try:
Expand Down Expand Up @@ -202,6 +194,12 @@ def split(
) -> Tuple[List[CallableFieldGroup], List[LinkGroup]]:
for item in query_node.fields:
self.visit(item)

for fr in query_node.fragments:
if fr.type_name != self._node.name:
continue
self.visit(fr)

return self._fields, self._links

def visit_node(self, obj: QueryNode) -> None:
Expand All @@ -222,7 +220,6 @@ def visit_link(self, obj: QueryLink) -> None:
self.visit(QueryField(r))
else:
self.visit(QueryField(graph_obj.requires))

self._links.append((graph_obj, obj))
else:
assert isinstance(graph_obj, Field), type(graph_obj)
Expand Down Expand Up @@ -748,72 +745,30 @@ def process_link(
from_list = ids is not None and graph_link.requires is not None
to_ids = link_result_to_ids(from_list, graph_link.type_enum, result)
if to_ids:
if graph_link.type_info.type_enum is LinkType.UNION and isinstance(
to_ids, list
):
if graph_link.type_info.type_enum in (
LinkType.UNION,
LinkType.INTERFACE,
) and isinstance(to_ids, list):
grouped_ids = defaultdict(list)
for id_, type_ref in to_ids:
grouped_ids[type_ref.__type_name__].append(id_)

# FIXME: call track len(ids) - 1 times because first track was
# already called by process_node for this link

fragments_map: Dict[str, Fragment] = {}
for f in query_link.node.fields:
if isinstance(f, Fragment):
fragments_map[f.name] = f

track_times = len(grouped_ids) - 1

for type_name, type_ids in grouped_ids.items():
query_node = query_link.node

if type_name in fragments_map:
query_node = fragments_map[type_name].node

self.process_node(
path,
self._graph.nodes_map[type_name],
query_node,
query_link.node,
list(type_ids),
)
for _ in range(track_times):
self._track(path)
elif (
graph_link.type_info.type_enum is LinkType.INTERFACE
and isinstance(to_ids, list)
):
grouped_ids = defaultdict(list)
for id_, type_ref in to_ids:
grouped_ids[type_ref.__type_name__].append(id_)

fragments_map = {}
for f in query_link.node.fields:
if isinstance(f, Fragment):
fragments_map[f.name] = f

track_times = len(grouped_ids) - 1

for type_name, type_ids in grouped_ids.items():
query_node = query_link.node

if type_name in fragments_map:
query_node = fragments_map[type_name].result_node(
query_link.node
)

self.process_node(
path,
self._graph.nodes_map[type_name],
query_node,
list(type_ids),
)

for _ in range(track_times):
# FIXME: call track len(ids) - 1 times because first track was
# already called by process_node for this link
for _ in range(len(grouped_ids) - 1):
self._track(path)
else:
if graph_link.type_enum is MaybeMany:
to_ids = [id_ for id_ in to_ids if id_ is not Nothing]

self.process_node(
path,
self._graph.nodes_map[graph_link.node],
Expand Down
51 changes: 28 additions & 23 deletions hiku/federation/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
Type,
)

from hiku.federation.sdl import print_sdl

from hiku.federation.graph import Graph
from hiku.federation.engine import Engine
from hiku.federation.introspection import (
BaseFederatedGraphQLIntrospection,
Expand All @@ -32,10 +35,7 @@
GraphQLError,
_StripQuery,
)
from hiku.graph import (
Graph,
apply,
)
from hiku.graph import apply
from hiku.query import Node
from hiku.result import Proxy
from hiku.readers.graphql import Operation
Expand Down Expand Up @@ -94,17 +94,6 @@ def __init__(
def context(self, op: Operation) -> Iterator[Dict]:
yield {}

def postprocess_result(
self, result: Proxy, graph: Graph, op: Operation
) -> Dict:
if "_service" in op.query.fields_map:
return {"_service": {"sdl": result["sdl"]}}

type_name = _type_names[op.type]

data = DenormalizeGraphQL(graph, result, type_name).process(op.query)
return data


class BaseSyncFederatedGraphQLEndpoint(BaseFederatedGraphEndpoint):
@abstractmethod
Expand Down Expand Up @@ -138,14 +127,21 @@ def introspection_cls(self) -> Type[BaseFederatedGraphQLIntrospection]:
return FederatedGraphQLIntrospection

def execute(self, graph: Graph, op: Operation, ctx: Optional[Dict]) -> Dict:
if ctx is None:
ctx = {}

stripped_query = _process_query(graph, op.query)
if "_service" in stripped_query.fields_map:
result = self.engine.execute_service(graph, self.mutation_graph)
else:
result = self.engine.execute(graph, stripped_query, ctx or {}, op)
ctx["__sdl__"] = print_sdl(
self.query_graph,
self.mutation_graph,
federation_version=self.engine.federation_version,
)

result = self.engine.execute(graph, stripped_query, ctx, op)
assert isinstance(result, Proxy)
return self.postprocess_result(result, graph, op)
type_name = _type_names[op.type]
return DenormalizeGraphQL(graph, result, type_name).process(op.query)

def dispatch(self, data: Dict) -> Dict:
try:
Expand All @@ -169,15 +165,24 @@ def introspection_cls(self) -> Type[BaseFederatedGraphQLIntrospection]:
async def execute(
self, graph: Graph, op: Operation, ctx: Optional[Dict]
) -> Dict:
if ctx is None:
ctx = {}

stripped_query = _process_query(graph, op.query)

if "_service" in stripped_query.fields_map:
coro = self.engine.execute_service(graph, self.mutation_graph)
else:
coro = self.engine.execute(graph, stripped_query, ctx)
ctx["__sdl__"] = print_sdl(
self.query_graph,
self.mutation_graph,
federation_version=self.engine.federation_version,
)

coro = self.engine.execute(graph, stripped_query, ctx)

assert isawaitable(coro)
result = await coro
return self.postprocess_result(result, graph, op)
type_name = _type_names[op.type]
return DenormalizeGraphQL(graph, result, type_name).process(op.query)

async def dispatch(self, data: Dict) -> Dict:
try:
Expand Down

0 comments on commit fdffa57

Please sign in to comment.