GraphQL: Improve schema error handling & tidy (#6026)
MetRonnie committed May 21, 2024
1 parent 095fa7e commit 1a7f797
Showing 3 changed files with 115 additions and 100 deletions.
50 changes: 23 additions & 27 deletions cylc/flow/network/graphql.py
@@ -20,11 +20,10 @@
"""

from functools import partial
from inspect import isclass, iscoroutinefunction
import logging
from typing import TYPE_CHECKING, Any, Tuple, Union

from inspect import isclass, iscoroutinefunction

from graphene.utils.str_converters import to_snake_case
from graphql.execution.utils import (
get_operation_root_type, get_field_def
@@ -35,16 +34,16 @@
 from graphql.backend.base import GraphQLBackend, GraphQLDocument
 from graphql.backend.core import execute_and_validate
 from graphql.utils.base import type_from_ast
-from graphql.type import get_named_type
+from graphql.type.definition import get_named_type
 from promise import Promise
 from rx import Observable
 
-from cylc.flow.network.schema import NODE_MAP, get_type_str
+from cylc.flow.network.schema import NODE_MAP
 
 if TYPE_CHECKING:
     from graphql.execution import ExecutionResult
     from graphql.language.ast import Document
-    from graphql.type import GraphQLSchema
+    from graphql.type.schema import GraphQLSchema
 
 
 logger = logging.getLogger(__name__)
@@ -376,18 +375,18 @@ def resolve(self, next_, root, info, **args):

         # Avoid using the protobuf default if field isn't set.
         if (
-                hasattr(root, 'ListFields')
-                and hasattr(root, field_name)
-                and get_type_str(info.return_type) not in NODE_MAP
+            hasattr(root, 'ListFields')
+            and hasattr(root, field_name)
+            and get_named_type(info.return_type).name not in NODE_MAP
         ):
 
             # Gather fields set in root
             parent_path_string = f'{info.path[:-1:]}'
             stamp = getattr(root, 'stamp', '')
             if (
-                    parent_path_string not in self.field_sets
-                    or self.field_sets[
-                        parent_path_string]['stamp'] != stamp
+                parent_path_string not in self.field_sets
+                or self.field_sets[
+                    parent_path_string]['stamp'] != stamp
             ):
                 self.field_sets[parent_path_string] = {
                     'stamp': stamp,
@@ -398,36 +397,33 @@
                 }
 
             if (
-                    parent_path_string in self.field_sets
-                    and field_name not in self.field_sets[
-                        parent_path_string]['fields']
+                parent_path_string in self.field_sets
+                and field_name not in self.field_sets[
+                    parent_path_string]['fields']
             ):
                 return None
         # Do not resolve subfields of an empty type
         # by setting as null in parent/root.
-        elif (
-                isinstance(root, dict)
-                and field_name in root
-        ):
+        elif isinstance(root, dict) and field_name in root:
             field_value = root[field_name]
             if (
-                    field_value in EMPTY_VALUES
-                    or (
-                        hasattr(field_value, 'ListFields')
-                        and not field_value.ListFields()
-                    )
+                field_value in EMPTY_VALUES
+                or (
+                    hasattr(field_value, 'ListFields')
+                    and not field_value.ListFields()
+                )
             ):
                 return None
             if (
-                    info.operation.operation in self.ASYNC_OPS
-                    or iscoroutinefunction(next_)
+                info.operation.operation in self.ASYNC_OPS
+                or iscoroutinefunction(next_)
             ):
                 return self.async_null_setter(next_, root, info, **args)
             return null_setter(next_(root, info, **args))
 
         if (
-                info.operation.operation in self.ASYNC_OPS
-                or iscoroutinefunction(next_)
+            info.operation.operation in self.ASYNC_OPS
+            or iscoroutinefunction(next_)
         ):
             return self.async_resolve(next_, root, info, **args)
         return next_(root, info, **args)
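
The one behavioural change in this file swaps the removed cylc helper get_type_str for graphql-core's get_named_type, which unwraps NonNull/List wrappers to reach the underlying named type, whose .name can then be checked against NODE_MAP. A standalone sketch of that behaviour (assuming graphql-core 2.x, which these imports target; MyNode is an invented type, not one of cylc's):

# Standalone sketch, not part of the commit: shows why
# get_named_type(info.return_type).name is a safe NODE_MAP lookup key
# even when the field's return type is wrapped.
from graphql.type.definition import (
    GraphQLField,
    GraphQLList,
    GraphQLNonNull,
    GraphQLObjectType,
    get_named_type,
)
from graphql.type.scalars import GraphQLString

# A made-up node type standing in for the types registered in NODE_MAP.
my_node = GraphQLObjectType(
    'MyNode', fields={'id': GraphQLField(GraphQLString)}
)

# A field typed [MyNode]! still resolves to the named type 'MyNode'.
wrapped = GraphQLNonNull(GraphQLList(my_node))
assert get_named_type(wrapped).name == 'MyNode'
assert get_named_type(my_node) is my_node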
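The repeated "info.operation.operation in self.ASYNC_OPS or iscoroutinefunction(next_)" guards decide whether the next resolver in the middleware chain is a coroutine that must be awaited or a plain callable to invoke directly. A rough standalone sketch of that dispatch idiom (all names invented for illustration):

# Standalone sketch of the sync/async dispatch idiom above; not the
# middleware itself. Coroutine resolvers are awaited, plain callables
# are called directly.
import asyncio
from inspect import iscoroutinefunction


async def dispatch(resolver, value):
    if iscoroutinefunction(resolver):
        return await resolver(value)
    return resolver(value)


async def main():
    def sync_resolver(x):
        return x + 1

    async def async_resolver(x):
        return x + 2

    print(await dispatch(sync_resolver, 1))   # 2
    print(await dispatch(async_resolver, 1))  # 3


asyncio.run(main())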
26 changes: 16 additions & 10 deletions cylc/flow/network/resolvers.py
@@ -25,13 +25,15 @@
 from time import time
 from typing import (
     Any,
+    AsyncGenerator,
     Dict,
     List,
     NamedTuple,
     Optional,
     Tuple,
     TYPE_CHECKING,
     Union,
+    cast,
 )
 from uuid import uuid4

@@ -58,6 +60,8 @@
     from cylc.flow.data_store_mgr import DataStoreMgr
     from cylc.flow.scheduler import Scheduler
 
+    DeltaQueue = queue.Queue[Tuple[str, str, dict]]
+
 
 class TaskMsg(NamedTuple):
     """Tuple for Scheduler.message_queue"""
@@ -395,7 +399,7 @@ async def get_nodes_all(self, node_type, args):
             [
                 node
                 for flow in await self.get_workflows_data(args)
-                for node in flow.get(node_type).values()
+                for node in flow[node_type].values()
                 if node_filter(
                     node,
                     node_type,
@@ -538,7 +542,9 @@ async def get_nodes_edges(self, root_nodes, args):
             nodes=sort_elements(nodes, args),
             edges=sort_elements(edges, args))
 
-    async def subscribe_delta(self, root, info, args):
+    async def subscribe_delta(
+        self, root, info: 'ResolveInfo', args
+    ) -> AsyncGenerator[Any, None]:
         """Delta subscription async generator.
 
         Async generator mapping the incoming protobuf deltas to
@@ -553,19 +559,19 @@
         self.delta_store[sub_id] = {}
 
         op_id = root
-        if 'ops_queue' not in info.context:
-            info.context['ops_queue'] = {}
-        info.context['ops_queue'][op_id] = queue.Queue()
-        op_queue = info.context['ops_queue'][op_id]
+        op_queue: queue.Queue[Tuple[UUID, str]] = queue.Queue()
+        cast('dict', info.context).setdefault(
+            'ops_queue', {}
+        )[op_id] = op_queue
         self.delta_processing_flows[sub_id] = set()
         delta_processing_flows = self.delta_processing_flows[sub_id]
 
         delta_queues = self.data_store_mgr.delta_queues
-        deltas_queue = queue.Queue()
+        deltas_queue: DeltaQueue = queue.Queue()
 
-        counters = {}
-        delta_yield_queue = queue.Queue()
-        flow_delta_queues = {}
+        counters: Dict[str, int] = {}
+        delta_yield_queue: DeltaQueue = queue.Queue()
+        flow_delta_queues: Dict[str, queue.Queue[Tuple[str, dict]]] = {}
         try:
             # Iterate over the queue yielding deltas
             w_ids = workflow_ids
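
Most of the changes to this file are typing tidy-ups: DeltaQueue is declared only under TYPE_CHECKING, and the new annotations on op_queue, deltas_queue, delta_yield_queue and flow_delta_queues are function-local, which Python never evaluates at runtime, so the subscripted queue.Queue[...] forms are safe even on interpreters where subscripting queue.Queue would raise a TypeError; cast('dict', info.context) is likewise a runtime no-op that exists only to satisfy the type checker. A standalone sketch of the checker-only alias pattern (payload types invented):

# Standalone sketch of the TYPE_CHECKING-only alias pattern used above;
# the payload types are made up for illustration.
import queue
from typing import TYPE_CHECKING, Dict, Tuple

if TYPE_CHECKING:
    # Only type checkers evaluate this block, so the subscripted
    # Queue never runs (and never fails) at runtime.
    DeltaQueue = queue.Queue[Tuple[str, str, dict]]


def pump() -> None:
    # Function-local annotations are never evaluated at runtime,
    # so the checker-only alias can be referenced unquoted.
    deltas_queue: DeltaQueue = queue.Queue()
    counters: Dict[str, int] = {}
    deltas_queue.put(('w_id', 'topic', {}))
    w_id, topic, delta = deltas_queue.get()
    counters[w_id] = counters.get(w_id, 0) + 1
    print(w_id, topic, delta, counters)


pump()

Separately, the flow.get(node_type) to flow[node_type] change in get_nodes_all fits the commit's error-handling theme: a missing key now fails immediately with a KeyError naming the key, rather than an AttributeError about 'NoneType' object has no attribute 'values'.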
139 changes: 76 additions & 63 deletions (third changed file; diff did not load)
