Skip to content

Commit

Permalink
Merge pull request #116 from evo-company/refactor-fragments-parsing
Browse files Browse the repository at this point in the history
refactor fragments parsing, introduce Fragment type
  • Loading branch information
kindermax committed Aug 9, 2023
2 parents 0476de6 + 4277491 commit 0ccd628
Show file tree
Hide file tree
Showing 14 changed files with 591 additions and 331 deletions.
5 changes: 3 additions & 2 deletions hiku/denormalize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def visit_field(self, obj: Field) -> None:

def visit_link(self, obj: Link) -> None:
if isinstance(self._type[-1], Union):
assert obj.parent_type in self._type[-1].types
type_ = self._types[obj.parent_type].__field_types__[obj.name]
type_ = self._types[self._data[-1].__ref__.node].__field_types__[
obj.name
]
elif isinstance(self._type[-1], RecordMeta):
type_ = self._type[-1].__field_types__[obj.name]
else:
Expand Down
5 changes: 3 additions & 2 deletions hiku/denormalize/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def visit_field(self, obj: Field) -> None:

def visit_link(self, obj: Link) -> None:
if isinstance(self._type[-1], Union):
assert obj.parent_type in self._type[-1].types
type_ = self._types[obj.parent_type].__field_types__[obj.name]
type_ = self._types[self._data[-1].__ref__.node].__field_types__[
obj.name
]
elif isinstance(self._type[-1], RecordMeta):
type_ = self._type[-1].__field_types__[obj.name]
else:
Expand Down
4 changes: 4 additions & 0 deletions hiku/endpoint/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Graph,
)
from ..query import (
Fragment,
QueryTransformer,
Node,
)
Expand Down Expand Up @@ -45,6 +46,9 @@ def visit_node(self, obj: Node) -> Node:
fields=[self.visit(f) for f in obj.fields if f.name != "__typename"]
)

def visit_fragment(self, obj: Fragment) -> Fragment:
return obj.copy(node=self.visit(obj.node))


def _switch_graph(
data: t.Dict, query_graph: Graph, mutation_graph: t.Optional[Graph] = None
Expand Down
70 changes: 46 additions & 24 deletions hiku/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
)
from .compat import Concatenate, ParamSpec
from .executors.base import SyncAsyncExecutor
from .interface import SplitInterfaceQueryByNodes
from .query import (
Fragment,
Node as QueryNode,
Field as QueryField,
Link as QueryLink,
Expand All @@ -42,7 +42,6 @@
)
from .graph import (
FieldType,
Interface,
Link,
LinkType,
Maybe,
Expand All @@ -53,7 +52,6 @@
Field,
Graph,
Node,
Union as GraphUnion,
)
from .result import (
Proxy,
Expand All @@ -67,7 +65,6 @@
TaskSet,
SubmitRes,
)
from .union import SplitUnionQueryByNodes
from .utils import ImmutableDict
from .utils.serialize import serialize

Expand Down Expand Up @@ -130,21 +127,15 @@ def enter_path(self, type_: Any) -> Iterator[None]:

def visit_node(self, obj: QueryNode) -> QueryNode:
fields = []
is_union = isinstance(self._path[-1], GraphUnion)
is_interface = isinstance(self._path[-1], Interface)

for f in obj.fields:
if f.name == "__typename":
fields.append(f)
continue

type_ = None

if is_union:
type_ = self._graph.nodes_map[f.parent_type]
elif is_interface:
if f.parent_type is not None:
type_ = self._graph.nodes_map[f.parent_type]
if isinstance(f, Fragment):
type_ = self._graph.nodes_map[f.name]

with self.enter_path(type_):
fields.append(self.visit(f))
Expand All @@ -163,9 +154,15 @@ def visit_link(self, obj: QueryLink) -> QueryLink:

if isinstance(graph_obj, Link):
if graph_obj.type_info.type_enum is LinkType.UNION:
self._path.append(self._graph.unions_map[graph_obj.node])
if obj.fragment_type:
self._path.append(self._graph.nodes_map[obj.fragment_type])
elif graph_obj.type_info.type_enum is LinkType.INTERFACE:
self._path.append(self._graph.interfaces_map[graph_obj.node])
if obj.fragment_type:
self._path.append(self._graph.nodes_map[obj.fragment_type])
else:
self._path.append(
self._graph.interfaces_map[graph_obj.node]
)
else:
self._path.append(self._graph.nodes_map[graph_obj.node])
try:
Expand All @@ -191,6 +188,10 @@ def visit_link(self, obj: QueryLink) -> QueryLink:


class SplitQuery(QueryVisitor):
"""Splits query into two groups: fields and links.
This is needed because we execute fields and links separately.
"""

def __init__(self, graph_node: Node) -> None:
self._node = graph_node
self._fields: List[CallableFieldGroup] = []
Expand All @@ -203,8 +204,9 @@ def split(
self.visit(item)
return self._fields, self._links

def visit_node(self, obj: QueryNode) -> NoReturn:
raise ValueError("Unexpected value: {!r}".format(obj))
def visit_node(self, obj: QueryNode) -> None:
for item in obj.fields:
self.visit(item)

def visit_field(self, obj: QueryField) -> None:
graph_obj = self._node.fields_map[obj.name]
Expand All @@ -220,6 +222,7 @@ def visit_link(self, obj: QueryLink) -> None:
self.visit(QueryField(r))
else:
self.visit(QueryField(graph_obj.requires))

self._links.append((graph_obj, obj))
else:
assert isinstance(graph_obj, Field), type(graph_obj)
Expand Down Expand Up @@ -754,15 +757,24 @@ def process_link(

# FIXME: call track len(ids) - 1 times because first track was
# already called by process_node for this link

fragments_map: Dict[str, Fragment] = {}
for f in query_link.node.fields:
if isinstance(f, Fragment):
fragments_map[f.name] = f

track_times = len(grouped_ids) - 1
union_nodes = SplitUnionQueryByNodes(
self._graph, self._graph.unions_map[graph_link.node]
).split(query_link.node)

for type_name, type_ids in grouped_ids.items():
query_node = query_link.node

if type_name in fragments_map:
query_node = fragments_map[type_name].node

self.process_node(
path,
self._graph.nodes_map[type_name],
union_nodes[type_name],
query_node,
list(type_ids),
)
for _ in range(track_times):
Expand All @@ -775,18 +787,28 @@ def process_link(
for id_, type_ref in to_ids:
grouped_ids[type_ref.__type_name__].append(id_)

fragments_map = {}
for f in query_link.node.fields:
if isinstance(f, Fragment):
fragments_map[f.name] = f

track_times = len(grouped_ids) - 1

interface_nodes = SplitInterfaceQueryByNodes(self._graph).split(
query_link.node
)
for type_name, type_ids in grouped_ids.items():
query_node = query_link.node

if type_name in fragments_map:
query_node = fragments_map[type_name].result_node(
query_link.node
)

self.process_node(
path,
self._graph.nodes_map[type_name],
interface_nodes[type_name],
query_node,
list(type_ids),
)

for _ in range(track_times):
self._track(path)
else:
Expand Down
40 changes: 0 additions & 40 deletions hiku/interface.py

This file was deleted.

0 comments on commit 0ccd628

Please sign in to comment.