diff --git a/graphql/execution/executor.py b/graphql/execution/executor.py index f23f59c4..d4043612 100644 --- a/graphql/execution/executor.py +++ b/graphql/execution/executor.py @@ -47,7 +47,9 @@ def subscribe(*args, **kwargs): # type: (*Any, **Any) -> Union[ExecutionResult, Observable] allow_subscriptions = kwargs.pop("allow_subscriptions", True) - return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs) + return execute( # type: ignore + *args, allow_subscriptions=allow_subscriptions, **kwargs + ) def execute( @@ -116,7 +118,7 @@ def execute( allow_subscriptions, ) - def executor(v): + def promise_executor(v): # type: (Optional[Any]) -> Union[OrderedDict, Promise, Observable] return execute_operation(exe_context, exe_context.operation, root) @@ -135,7 +137,9 @@ def on_resolve(data): return ExecutionResult(data=data, errors=exe_context.errors) - promise = Promise.resolve(None).then(executor).catch(on_rejected).then(on_resolve) + promise = ( + Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve) + ) if not return_promise: exe_context.executor.wait_until_finished() @@ -151,7 +155,7 @@ def on_resolve(data): def execute_operation( exe_context, # type: ExecutionContext operation, # type: OperationDefinition - root_value, # type: Union[None, Data, type] + root_value, # type: Any ): # type: (...) -> Union[OrderedDict, Promise] type = get_operation_root_type(exe_context.schema, operation) @@ -224,7 +228,7 @@ def execute_fields( parent_type, # type: GraphQLObjectType source_value, # type: Any fields, # type: DefaultOrderedDict - path, # type: Union[List[Union[int, str]], List[str]] + path, # type: List[Union[int, str]] info, # type: Optional[ResolveInfo] ): # type: (...) -> Union[OrderedDict, Promise] @@ -257,39 +261,43 @@ def execute_fields( def subscribe_fields( exe_context, # type: ExecutionContext parent_type, # type: GraphQLObjectType - source_value, # type: Union[None, Data, type] + source_value, # type: Any fields, # type: DefaultOrderedDict ): # type: (...) -> Observable - exe_context = SubscriberExecutionContext(exe_context) + subscriber_exe_context = SubscriberExecutionContext(exe_context) def on_error(error): - exe_context.report_error(error) + subscriber_exe_context.report_error(error) def map_result( data # type: Union[Dict[str, None], Dict[str, OrderedDict], Dict[str, str]] ): # type: (...) -> ExecutionResult - if exe_context.errors: - result = ExecutionResult(data=data, errors=exe_context.errors) + if subscriber_exe_context.errors: + result = ExecutionResult(data=data, errors=subscriber_exe_context.errors) else: result = ExecutionResult(data=data) - exe_context.reset() + subscriber_exe_context.reset() return result - observables = [] + observables = [] # type: List[Observable] # assert len(fields) == 1, "Can only subscribe one element at a time." for response_name, field_asts in fields.items(): result = subscribe_field( - exe_context, parent_type, source_value, field_asts, [response_name] + subscriber_exe_context, + parent_type, + source_value, + field_asts, + [response_name], ) if result is Undefined: continue def catch_error(error): - exe_context.errors.append(error) + subscriber_exe_context.errors.append(error) return Observable.just(None) # Map observable results @@ -305,10 +313,10 @@ def catch_error(error): def resolve_field( exe_context, # type: ExecutionContext parent_type, # type: GraphQLObjectType - source, # type: Union[None, Cat, Dog] + source, # type: Any field_asts, # type: List[Field] parent_info, # type: Optional[ResolveInfo] - field_path, # type: Union[List[Union[int, str]], List[str]] + field_path, # type: List[Union[int, str]] ): # type: (...) -> Any field_ast = field_asts[0] @@ -360,7 +368,7 @@ def resolve_field( def subscribe_field( exe_context, # type: SubscriberExecutionContext parent_type, # type: GraphQLObjectType - source, # type: Union[None, Data, type] + source, # type: Any field_asts, # type: List[Field] path, # type: List[str] ): @@ -430,12 +438,12 @@ def subscribe_field( def resolve_or_error( resolve_fn, # type: Callable - source, # type: Union[None, Cat, Dog] + source, # type: Any info, # type: ResolveInfo args, # type: Dict - executor, # type: Union[BaseExecutor, SyncExecutor] + executor, # type: Any ): - # type: (...) -> Union[List[Union[Cat, Dog]], bool, str] + # type: (...) -> Any try: return executor.execute(resolve_fn, source, info, **args) except Exception as e: @@ -444,7 +452,7 @@ def resolve_or_error( info.parent_type.name, info.field_name ) ) - e.stack = sys.exc_info()[2] + e.stack = sys.exc_info()[2] # type: ignore return e @@ -453,10 +461,10 @@ def complete_value_catching_error( return_type, # type: Any field_asts, # type: List[Field] info, # type: ResolveInfo - path, # type: Union[List[Union[int, str]], List[str]] + path, # type: List[Union[int, str]] result, # type: Any ): - # type: (...) -> Union[bool, str] + # type: (...) -> Any # If the field type is non-nullable, then it is resolved without any # protection from errors. if isinstance(return_type, GraphQLNonNull): @@ -472,7 +480,7 @@ def complete_value_catching_error( def handle_error(error): # type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any] - traceback = completed._traceback + traceback = completed._traceback # type: ignore exe_context.report_error(error, traceback) return None @@ -490,10 +498,10 @@ def complete_value( return_type, # type: Any field_asts, # type: List[Field] info, # type: ResolveInfo - path, # type: Union[List[Union[int, str]], List[str]] + path, # type: List[Union[int, str]] result, # type: Any ): - # type: (...) -> Union[bool, str] + # type: (...) -> Any """ Implements the instructions for completeValue as defined in the "Field entries" section of the spec. @@ -566,10 +574,10 @@ def complete_list_value( return_type, # type: GraphQLList field_asts, # type: List[Field] info, # type: ResolveInfo - path, # type: List[str] + path, # type: List[Union[int, str]] result, # type: Any ): - # type: (...) -> Any + # type: (...) -> List[Any] """ Complete a list value by completing each item in the list with the inner type """ @@ -597,10 +605,10 @@ def complete_list_value( def complete_leaf_value( return_type, # type: Union[GraphQLEnumType, GraphQLScalarType] - path, # type: Union[List[Union[int, str]], List[str]] - result, # type: Union[int, str] + path, # type: List[Union[int, str]] + result, # type: Any ): - # type: (...) -> Union[int, str] + # type: (...) -> Union[int, str, float, bool] """ Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. """ @@ -625,12 +633,12 @@ def complete_abstract_value( path, # type: List[Union[int, str]] result, # type: Any ): - # type: (...) -> OrderedDict + # type: (...) -> Dict[str, Any] """ Complete an value of an abstract type by determining the runtime type of that value, then completing based on that type. """ - runtime_type = None + runtime_type = None # type: Union[str, GraphQLObjectType, None] # Field type must be Object, Interface or Union and expect sub-selections. if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): @@ -640,7 +648,7 @@ def complete_abstract_value( runtime_type = get_default_resolve_type_fn(result, info, return_type) if isinstance(runtime_type, string_types): - runtime_type = info.schema.get_type(runtime_type) + runtime_type = info.schema.get_type(runtime_type) # type: ignore if not isinstance(runtime_type, GraphQLObjectType): raise GraphQLError( @@ -671,11 +679,12 @@ def get_default_resolve_type_fn( info, # type: ResolveInfo abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] ): - # type: (...) -> GraphQLObjectType + # type: (...) -> Optional[GraphQLObjectType] possible_types = info.schema.get_possible_types(abstract_type) for type in possible_types: if callable(type.is_type_of) and type.is_type_of(value, info): return type + return None def complete_object_value( @@ -683,10 +692,10 @@ def complete_object_value( return_type, # type: GraphQLObjectType field_asts, # type: List[Field] info, # type: ResolveInfo - path, # type: Union[List[Union[int, str]], List[str]] + path, # type: List[Union[int, str]] result, # type: Any ): - # type: (...) -> Union[OrderedDict, Promise] + # type: (...) -> Dict[str, Any] """ Complete an Object value by evaluating all sub-selections. """ @@ -708,7 +717,7 @@ def complete_nonnull_value( return_type, # type: GraphQLNonNull field_asts, # type: List[Field] info, # type: ResolveInfo - path, # type: Union[List[Union[int, str]], List[str]] + path, # type: List[Union[int, str]] result, # type: Any ): # type: (...) -> Any diff --git a/graphql/execution/executors/asyncio.py b/graphql/execution/executors/asyncio.py index 27637c23..8dc183eb 100644 --- a/graphql/execution/executors/asyncio.py +++ b/graphql/execution/executors/asyncio.py @@ -6,13 +6,13 @@ if False: # flake8: noqa from asyncio.unix_events import _UnixSelectorEventLoop - from typing import Optional, Any, Callable + from typing import Optional, Any, Callable, List try: from asyncio import ensure_future except ImportError: # ensure_future is only implemented in Python 3.4.4+ - def ensure_future(coro_or_future, loop=None): + def ensure_future(coro_or_future, loop=None): # type: ignore """Wrap a coroutine or an awaitable in a future. If the argument is a Future, it is returned directly. @@ -39,7 +39,7 @@ def ensure_future(coro_or_future, loop=None): def isasyncgen(obj): False - def asyncgen_to_observable(asyncgen): + def asyncgen_to_observable(asyncgen, loop=None): pass @@ -49,7 +49,7 @@ def __init__(self, loop=None): if loop is None: loop = get_event_loop() self.loop = loop - self.futures = [] + self.futures = [] # type: List[Future] def wait_until_finished(self): # type: () -> None diff --git a/graphql/execution/executors/thread.py b/graphql/execution/executors/thread.py index 35210d99..1544ec98 100644 --- a/graphql/execution/executors/thread.py +++ b/graphql/execution/executors/thread.py @@ -5,7 +5,7 @@ from .utils import process if False: # flake8: noqa - from typing import Any, Callable + from typing import Any, Callable, List class ThreadExecutor(object): @@ -14,7 +14,7 @@ class ThreadExecutor(object): def __init__(self, pool=False): # type: (bool) -> None - self.threads = [] + self.threads = [] # type: List[Thread] if pool: self.execute = self.execute_in_pool self.pool = ThreadPool(processes=pool) @@ -26,7 +26,8 @@ def wait_until_finished(self): while self.threads: threads = self.threads self.threads = [] - [thread.join() for thread in threads] + for thread in threads: + thread.join() def clean(self): self.threads = [] diff --git a/graphql/execution/executors/utils.py b/graphql/execution/executors/utils.py index 41cd29d1..d710e4c7 100644 --- a/graphql/execution/executors/utils.py +++ b/graphql/execution/executors/utils.py @@ -3,7 +3,7 @@ if False: # flake8: noqa from ..base import ResolveInfo from promise import Promise - from typing import Callable, Dict, Tuple, Union + from typing import Callable, Dict, Tuple, Union, Any def process( @@ -18,5 +18,5 @@ def process( p.do_resolve(val) except Exception as e: traceback = exc_info()[2] - e.stack = traceback + e.stack = traceback # type: ignore p.do_reject(e, traceback=traceback) diff --git a/graphql/execution/middleware.py b/graphql/execution/middleware.py index 0880858b..97cbec11 100644 --- a/graphql/execution/middleware.py +++ b/graphql/execution/middleware.py @@ -2,11 +2,11 @@ from functools import partial from itertools import chain -from promise import Promise +from promise import Promise, promisify if False: # flake8: noqa from .base import ResolveInfo - from typing import Any, Callable, Iterator, Tuple, Union, List, Dict + from typing import Any, Callable, Iterator, Tuple, Union, List, Dict, Iterable MIDDLEWARE_RESOLVER_FUNCTION = "resolve" @@ -20,16 +20,16 @@ class MiddlewareManager(object): ) def __init__(self, *middlewares, **kwargs): - # type: (*Any, **Dict[str, bool]) -> None + # type: (*Callable, **bool) -> None self.middlewares = middlewares self.wrap_in_promise = kwargs.get("wrap_in_promise", True) self._middleware_resolvers = ( list(get_middleware_resolvers(middlewares)) if middlewares else [] ) - self._cached_resolvers = {} + self._cached_resolvers = {} # type: Dict[Callable, Callable] def get_field_resolver(self, field_resolver): - # type: (Callable[[Any, ResolveInfo, ...], Any]) -> Callable[[Any, ResolveInfo, ...], Any] + # type: (Callable) -> Callable if field_resolver not in self._cached_resolvers: self._cached_resolvers[field_resolver] = middleware_chain( field_resolver, @@ -44,7 +44,7 @@ def get_field_resolver(self, field_resolver): def get_middleware_resolvers(middlewares): - # type: (Tuple[Any]) -> Iterator[Callable] + # type: (Tuple[Any, ...]) -> Iterator[Callable] for middleware in middlewares: # If the middleware is a function instead of a class if inspect.isfunction(middleware): @@ -55,7 +55,7 @@ def get_middleware_resolvers(middlewares): def middleware_chain(func, middlewares, wrap_in_promise): - # type: (Callable, List[Callable], bool) -> Callable + # type: (Callable, Iterable[Callable], bool) -> Callable if not middlewares: return func if wrap_in_promise: @@ -66,9 +66,10 @@ def middleware_chain(func, middlewares, wrap_in_promise): for middleware in middlewares: last_func = partial(middleware, last_func) if last_func else middleware - return last_func + return last_func # type: ignore -def make_it_promise(next, *a, **b): - # type: (Callable, *Any, **Any) -> Promise - return Promise.resolve(next(*a, **b)) +@promisify +def make_it_promise(next, *args, **kwargs): + # type: (Callable, *Any, **Any) -> Any + return next(*args, **kwargs) diff --git a/graphql/execution/utils.py b/graphql/execution/utils.py index bbab86c3..a44cd503 100644 --- a/graphql/execution/utils.py +++ b/graphql/execution/utils.py @@ -29,7 +29,7 @@ ) from .base import ResolveInfo from types import TracebackType - from typing import Any, List, Dict, Optional, Union, Callable, Set + from typing import Any, List, Dict, Optional, Union, Callable, Set, Tuple logger = logging.getLogger(__name__) @@ -59,9 +59,9 @@ def __init__( self, schema, # type: GraphQLSchema document_ast, # type: Document - root_value, # type: Union[None, Data, type] - context_value, # type: Optional[Context] - variable_values, # type: Dict[str, int] + root_value, # type: Any + context_value, # type: Any + variable_values, # type: Optional[Dict[str, Any]] operation_name, # type: Optional[str] executor, # type: Any middleware, # type: Optional[Any] @@ -71,9 +71,9 @@ def __init__( """Constructs a ExecutionContext object from the arguments passed to execute, which we will pass throughout the other execution methods.""" - errors = [] + errors = [] # type: List[Exception] operation = None - fragments = {} + fragments = {} # type: Dict[str, FragmentDefinition] for definition in document_ast.definitions: if isinstance(definition, ast.OperationDefinition): @@ -120,11 +120,11 @@ def __init__( self.variable_values = variable_values self.errors = errors self.context_value = context_value - self.argument_values_cache = {} + self.argument_values_cache = {} # type: Dict[Tuple[GraphQLField, Field], Dict[str, Any]] self.executor = executor self.middleware = middleware self.allow_subscriptions = allow_subscriptions - self._subfields_cache = {} + self._subfields_cache = {} # type: Dict[Tuple[GraphQLObjectType, Tuple[Field, ...]], DefaultOrderedDict] def get_field_resolver(self, field_resolver): # type: (Callable) -> Callable @@ -144,7 +144,7 @@ def get_argument_values(self, field_def, field_ast): return result def report_error(self, error, traceback=None): - # type: (GraphQLError, Optional[TracebackType]) -> None + # type: (Exception, Optional[TracebackType]) -> None exception = format_exception( type(error), error, getattr(error, "stack", None) or traceback ) @@ -156,7 +156,7 @@ def get_sub_fields(self, return_type, field_asts): k = return_type, tuple(field_asts) if k not in self._subfields_cache: subfield_asts = DefaultOrderedDict(list) - visited_fragment_names = set() + visited_fragment_names = set() # type: Set[str] for field_ast in field_asts: selection_set = field_ast.selection_set if selection_set: @@ -177,7 +177,7 @@ class SubscriberExecutionContext(object): def __init__(self, exe_context): # type: (ExecutionContext) -> None self.exe_context = exe_context - self.errors = [] + self.errors = [] # type: List[Exception] def reset(self): # type: () -> None @@ -262,7 +262,7 @@ def collect_fields( continue prev_fragment_names.add(frag_name) - fragment = ctx.fragments.get(frag_name) + fragment = ctx.fragments[frag_name] frag_directives = fragment.directives if ( not fragment @@ -279,7 +279,7 @@ def collect_fields( def should_include_node(ctx, directives): - # type: (ExecutionContext, List[Directive]) -> bool + # type: (ExecutionContext, Optional[List[Directive]]) -> bool """Determines if a field should be included based on the @include and @skip directives, where @skip has higher precidence than @include.""" # TODO: Refactor based on latest code diff --git a/graphql/execution/values.py b/graphql/execution/values.py index a0568321..c1a20f22 100644 --- a/graphql/execution/values.py +++ b/graphql/execution/values.py @@ -22,7 +22,7 @@ from ..language.ast import VariableDefinition, Argument from ..type.schema import GraphQLSchema from ..type.definition import GraphQLArgument - from typing import Any, Dict, List, Union, Dict + from typing import Any, Dict, List, Union, Dict, Optional __all__ = ["get_variable_values", "get_argument_values"] @@ -32,7 +32,7 @@ def get_variable_values( definition_asts, # type: List[VariableDefinition] inputs, # type: Any ): - # type: (...) -> Union[Dict[str, Dict[str, Any]], Dict[str, Dict[str, str]], Dict[str, int]] + # type: (...) -> Dict[str, Any] """Prepares an object map of variables of the correct type based on the provided variable definitions and arbitrary input. If the input cannot be parsed to match the variable definitions, a GraphQLError will be thrown.""" if inputs is None: @@ -53,7 +53,9 @@ def get_variable_values( ) elif value is None: if def_ast.default_value is not None: - values[var_name] = value_from_ast(def_ast.default_value, var_type) + values[var_name] = value_from_ast( + def_ast.default_value, var_type + ) # type: ignore if isinstance(var_type, GraphQLNonNull): raise GraphQLError( 'Variable "${var_name}" of required type "{var_type}" was not provided.'.format( @@ -82,8 +84,8 @@ def get_variable_values( def get_argument_values( arg_defs, # type: Union[Dict[str, GraphQLArgument], Dict] - arg_asts, # type: List[Argument] - variables=None, # type: Dict[str, int] + arg_asts, # type: Optional[List[Argument]] + variables=None, # type: Optional[Dict[str, Union[List, Dict, int, float, bool, str, None]]] ): # type: (...) -> Dict[str, Any] """Prepares an object map of argument values given a list of argument @@ -92,14 +94,16 @@ def get_argument_values( return {} if arg_asts: - arg_ast_map = {arg.name.value: arg for arg in arg_asts} + arg_ast_map = { + arg.name.value: arg for arg in arg_asts + } # type: Dict[str, Argument] else: arg_ast_map = {} result = {} for name, arg_def in arg_defs.items(): arg_type = arg_def.type - value_ast = arg_ast_map.get(name) + arg_ast = arg_ast_map.get(name) if name not in arg_ast_map: if arg_def.default_value is not None: result[arg_def.out_name or name] = arg_def.default_value @@ -111,11 +115,10 @@ def get_argument_values( ), arg_asts, ) - elif isinstance(value_ast.value, ast.Variable): - variable_name = value_ast.value.name.value - variable_value = variables.get(variable_name) + elif isinstance(arg_ast.value, ast.Variable): # type: ignore + variable_name = arg_ast.value.name.value # type: ignore if variables and variable_name in variables: - result[arg_def.out_name or name] = variable_value + result[arg_def.out_name or name] = variables[variable_name] elif arg_def.default_value is not None: result[arg_def.out_name or name] = arg_def.default_value elif isinstance(arg_type, GraphQLNonNull): @@ -128,9 +131,7 @@ def get_argument_values( continue else: - value_ast = value_ast.value - - value = value_from_ast(value_ast, arg_type, variables) + value = value_from_ast(arg_ast.value, arg_type, variables) # type: ignore if value is None: if arg_def.default_value is not None: value = arg_def.default_value @@ -144,7 +145,7 @@ def get_argument_values( def coerce_value(type, value): - # type: (Any, Any) -> Union[int, str] + # type: (Any, Any) -> Union[List, Dict, int, float, bool, str, None] """Given a type and any value, return a runtime value coerced to match the type.""" if isinstance(type, GraphQLNonNull): # Note: we're not checking that the result of coerceValue is diff --git a/graphql/language/printer.py b/graphql/language/printer.py index cedd8d73..02e15ee9 100644 --- a/graphql/language/printer.py +++ b/graphql/language/printer.py @@ -55,19 +55,19 @@ class PrintingVisitor(Visitor): __slots__ = () def leave_Name(self, node, *args): - # type: (Name, *Any) -> str + # type: (Any, *Any) -> str return node.value # type: ignore def leave_Variable(self, node, *args): - # type: (Variable, *Any) -> str + # type: (Any, *Any) -> str return "$" + node.name # type: ignore def leave_Document(self, node, *args): - # type: (Document, *Any) -> str + # type: (Any, *Any) -> str return join(node.definitions, "\n\n") + "\n" # type: ignore def leave_OperationDefinition(self, node, *args): - # type: (OperationDefinition, *Any) -> str + # type: (Any, *Any) -> str name = node.name selection_set = node.selection_set op = node.operation @@ -80,15 +80,15 @@ def leave_OperationDefinition(self, node, *args): return join([op, join([name, var_defs]), directives, selection_set], " ") def leave_VariableDefinition(self, node, *args): - # type: (VariableDefinition, *Any) -> str + # type: (Any, *Any) -> str return node.variable + ": " + node.type + wrap(" = ", node.default_value) def leave_SelectionSet(self, node, *args): - # type: (SelectionSet, *Any) -> str + # type: (Any, *Any) -> str return block(node.selections) def leave_Field(self, node, *args): - # type: (Field, *Any) -> str + # type: (Any, *Any) -> str return join( [ wrap("", node.alias, ": ") @@ -101,17 +101,17 @@ def leave_Field(self, node, *args): ) def leave_Argument(self, node, *args): - # type: (Argument, *Any) -> str + # type: (Any, *Any) -> str return "{0.name}: {0.value}".format(node) # Fragments def leave_FragmentSpread(self, node, *args): - # type: (FragmentSpread, *Any) -> str + # type: (Any, *Any) -> str return "..." + node.name + wrap(" ", join(node.directives, " ")) def leave_InlineFragment(self, node, *args): - # type: (InlineFragment, *Any) -> str + # type: (Any, *Any) -> str return join( [ "...", @@ -123,7 +123,7 @@ def leave_InlineFragment(self, node, *args): ) def leave_FragmentDefinition(self, node, *args): - # type: (FragmentDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( "fragment {} on {} ".format(node.name, node.type_condition) + wrap("", join(node.directives, " "), " ") @@ -133,74 +133,74 @@ def leave_FragmentDefinition(self, node, *args): # Value def leave_IntValue(self, node, *args): - # type: (IntValue, *Any) -> str + # type: (Any, *Any) -> str return node.value def leave_FloatValue(self, node, *args): return node.value def leave_StringValue(self, node, *args): - # type: (StringValue, *Any) -> str + # type: (Any, *Any) -> str return json.dumps(node.value) def leave_BooleanValue(self, node, *args): - # type: (BooleanValue, *Any) -> str + # type: (Any, *Any) -> str return json.dumps(node.value) def leave_EnumValue(self, node, *args): - # type: (EnumValue, *Any) -> str + # type: (Any, *Any) -> str return node.value def leave_ListValue(self, node, *args): - # type: (ListValue, *Any) -> str + # type: (Any, *Any) -> str return "[" + join(node.values, ", ") + "]" def leave_ObjectValue(self, node, *args): - # type: (ObjectValue, *Any) -> str + # type: (Any, *Any) -> str return "{" + join(node.fields, ", ") + "}" def leave_ObjectField(self, node, *args): - # type: (ObjectField, *Any) -> str + # type: (Any, *Any) -> str return node.name + ": " + node.value # Directive def leave_Directive(self, node, *args): - # type: (Directive, *Any) -> str + # type: (Any, *Any) -> str return "@" + node.name + wrap("(", join(node.arguments, ", "), ")") # Type def leave_NamedType(self, node, *args): - # type: (NamedType, *Any) -> str + # type: (Any, *Any) -> str return node.name def leave_ListType(self, node, *args): - # type: (ListType, *Any) -> str + # type: (Any, *Any) -> str return "[" + node.type + "]" def leave_NonNullType(self, node, *args): - # type: (NonNullType, *Any) -> str + # type: (Any, *Any) -> str return node.type + "!" # Type Definitions: def leave_SchemaDefinition(self, node, *args): - # type: (SchemaDefinition, *Any) -> str + # type: (Any, *Any) -> str return join( ["schema", join(node.directives, " "), block(node.operation_types)], " " ) def leave_OperationTypeDefinition(self, node, *args): - # type: (OperationTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return "{}: {}".format(node.operation, node.type) def leave_ScalarTypeDefinition(self, node, *args): - # type: (ScalarTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return "scalar " + node.name + wrap(" ", join(node.directives, " ")) def leave_ObjectTypeDefinition(self, node, *args): - # type: (ObjectTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return join( [ "type", @@ -213,7 +213,7 @@ def leave_ObjectTypeDefinition(self, node, *args): ) def leave_FieldDefinition(self, node, *args): - # type: (FieldDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( node.name + wrap("(", join(node.arguments, ", "), ")") @@ -223,7 +223,7 @@ def leave_FieldDefinition(self, node, *args): ) def leave_InputValueDefinition(self, node, *args): - # type: (InputValueDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( node.name + ": " @@ -233,7 +233,7 @@ def leave_InputValueDefinition(self, node, *args): ) def leave_InterfaceTypeDefinition(self, node, *args): - # type: (InterfaceTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( "interface " + node.name @@ -243,7 +243,7 @@ def leave_InterfaceTypeDefinition(self, node, *args): ) def leave_UnionTypeDefinition(self, node, *args): - # type: (UnionTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( "union " + node.name @@ -253,7 +253,7 @@ def leave_UnionTypeDefinition(self, node, *args): ) def leave_EnumTypeDefinition(self, node, *args): - # type: (EnumTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( "enum " + node.name @@ -263,11 +263,11 @@ def leave_EnumTypeDefinition(self, node, *args): ) def leave_EnumValueDefinition(self, node, *args): - # type: (EnumValueDefinition, *Any) -> str + # type: (Any, *Any) -> str return node.name + wrap(" ", join(node.directives, " ")) def leave_InputObjectTypeDefinition(self, node, *args): - # type: (InputObjectTypeDefinition, *Any) -> str + # type: (Any, *Any) -> str return ( "input " + node.name @@ -277,11 +277,11 @@ def leave_InputObjectTypeDefinition(self, node, *args): ) def leave_TypeExtensionDefinition(self, node, *args): - # type: (TypeExtensionDefinition, *Any) -> str + # type: (Any, *Any) -> str return "extend " + node.definition def leave_DirectiveDefinition(self, node, *args): - # type: (DirectiveDefinition, *Any) -> str + # type: (Any, *Any) -> str return "directive @{}{} on {}".format( node.name, wrap("(", join(node.arguments, ", "), ")"), @@ -290,7 +290,7 @@ def leave_DirectiveDefinition(self, node, *args): def join(maybe_list, separator=""): - # type: (Optional[List[Optional[str]]], str) -> str + # type: (Optional[List[str]], str) -> str if maybe_list: return separator.join(filter(None, maybe_list)) return "" diff --git a/graphql/language/visitor.py b/graphql/language/visitor.py index ba193046..c2a7bbac 100644 --- a/graphql/language/visitor.py +++ b/graphql/language/visitor.py @@ -150,11 +150,11 @@ def visit(root, visitor, key_map=None): if not is_leaving: stack = Stack(in_array, index, keys, edits, stack) in_array = isinstance(node, list) - keys = ( + keys = ( # type: ignore node if in_array else visitor_keys.get(type(node), None) or [] # type: ignore - ) # type: ignore + ) index = -1 edits = [] diff --git a/graphql/type/introspection.py b/graphql/type/introspection.py index 1dafd80d..1f5d1292 100644 --- a/graphql/type/introspection.py +++ b/graphql/type/introspection.py @@ -20,7 +20,8 @@ if False: # flake8: noqa from ..execution.base import ResolveInfo - from typing import Union, List, Optional, Any + from .definition import GraphQLInputObjectField + from typing import Union, List, Optional, Any, Dict InputField = namedtuple("InputField", ["name", "description", "type", "default_value"]) Field = namedtuple( @@ -29,6 +30,7 @@ def input_fields_to_list(input_fields): + # type: (Dict[str, GraphQLInputObjectField]) -> List[InputField] fields = [] for field_name, field in input_fields.items(): fields.append( @@ -53,7 +55,9 @@ def input_fields_to_list(input_fields): "types", GraphQLField( description="A list of all types supported by this server.", - type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__Type))), + type=GraphQLNonNull( + GraphQLList(GraphQLNonNull(__Type)) # type: ignore + ), resolver=lambda schema, *_: schema.get_type_map().values(), ), ), @@ -61,7 +65,7 @@ def input_fields_to_list(input_fields): "queryType", GraphQLField( description="The type that query operations will be rooted at.", - type=GraphQLNonNull(__Type), + type=GraphQLNonNull(__Type), # type: ignore resolver=lambda schema, *_: schema.get_query_type(), ), ), @@ -70,7 +74,7 @@ def input_fields_to_list(input_fields): GraphQLField( description="If this server supports mutation, the type that " "mutation operations will be rooted at.", - type=__Type, + type=__Type, # type: ignore resolver=lambda schema, *_: schema.get_mutation_type(), ), ), @@ -79,7 +83,7 @@ def input_fields_to_list(input_fields): GraphQLField( description="If this server support subscription, the type " "that subscription operations will be rooted at.", - type=__Type, + type=__Type, # type: ignore resolver=lambda schema, *_: schema.get_subscription_type(), ), ), @@ -87,7 +91,9 @@ def input_fields_to_list(input_fields): "directives", GraphQLField( description="A list of all directives supported by this server.", - type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__Directive))), + type=GraphQLNonNull( + GraphQLList(GraphQLNonNull(__Directive)) # type: ignore + ), resolver=lambda schema, *_: schema.get_directives(), ), ), @@ -115,14 +121,16 @@ def input_fields_to_list(input_fields): "locations", GraphQLField( type=GraphQLNonNull( - GraphQLList(GraphQLNonNull(__DirectiveLocation)) + GraphQLList(GraphQLNonNull(__DirectiveLocation)) # type: ignore ) ), ), ( "args", GraphQLField( - type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), + type=GraphQLNonNull( + GraphQLList(GraphQLNonNull(__InputValue)) # type: ignore + ), resolver=lambda directive, *args: input_fields_to_list( directive.args ), @@ -359,13 +367,11 @@ def fields( return None @staticmethod - def interfaces( - type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] - info, # type: ResolveInfo - ): - # type: (...) -> Optional[Any] + def interfaces(type, info): + # type: (Optional[GraphQLObjectType], ResolveInfo) -> Optional[List[GraphQLInterfaceType]] if isinstance(type, GraphQLObjectType): return type.interfaces + return None @staticmethod def possible_types( @@ -379,24 +385,22 @@ def possible_types( @staticmethod def enum_values( - type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] + type, # type: GraphQLEnumType info, # type: ResolveInfo includeDeprecated=None, # type: bool ): - # type: (...) -> Optional[Any] + # type: (...) -> Optional[List[GraphQLEnumValue]] if isinstance(type, GraphQLEnumType): values = type.values if not includeDeprecated: values = [v for v in values if not v.deprecation_reason] return values + return None @staticmethod - def input_fields( - type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] - info, # type: ResolveInfo - ): - # type: (...) -> Optional[Any] + def input_fields(type, info): + # type: (GraphQLInputObjectType, ResolveInfo) -> List[InputField] if isinstance(type, GraphQLInputObjectType): return input_fields_to_list(type.fields) @@ -416,7 +420,8 @@ def input_fields( ( "kind", GraphQLField( - type=GraphQLNonNull(__TypeKind), resolver=TypeFieldResolvers.kind + type=GraphQLNonNull(__TypeKind), # type: ignore + resolver=TypeFieldResolvers.kind, ), ), ("name", GraphQLField(GraphQLString)), @@ -424,7 +429,7 @@ def input_fields( ( "fields", GraphQLField( - type=GraphQLList(GraphQLNonNull(__Field)), + type=GraphQLList(GraphQLNonNull(__Field)), # type: ignore args={ "includeDeprecated": GraphQLArgument( GraphQLBoolean, default_value=False @@ -436,21 +441,21 @@ def input_fields( ( "interfaces", GraphQLField( - type=GraphQLList(GraphQLNonNull(__Type)), + type=GraphQLList(GraphQLNonNull(__Type)), # type: ignore resolver=TypeFieldResolvers.interfaces, ), ), ( "possibleTypes", GraphQLField( - type=GraphQLList(GraphQLNonNull(__Type)), + type=GraphQLList(GraphQLNonNull(__Type)), # type: ignore resolver=TypeFieldResolvers.possible_types, ), ), ( "enumValues", GraphQLField( - type=GraphQLList(GraphQLNonNull(__EnumValue)), + type=GraphQLList(GraphQLNonNull(__EnumValue)), # type: ignore args={ "includeDeprecated": GraphQLArgument( GraphQLBoolean, default_value=False @@ -462,14 +467,14 @@ def input_fields( ( "inputFields", GraphQLField( - type=GraphQLList(GraphQLNonNull(__InputValue)), + type=GraphQLList(GraphQLNonNull(__InputValue)), # type: ignore resolver=TypeFieldResolvers.input_fields, ), ), ( "ofType", GraphQLField( - type=__Type, + type=__Type, # type: ignore resolver=lambda type, *_: getattr(type, "of_type", None), ), ), @@ -488,11 +493,13 @@ def input_fields( ( "args", GraphQLField( - type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), + type=GraphQLNonNull( + GraphQLList(GraphQLNonNull(__InputValue)) # type: ignore + ), resolver=lambda field, *_: input_fields_to_list(field.args), ), ), - ("type", GraphQLField(GraphQLNonNull(__Type))), + ("type", GraphQLField(GraphQLNonNull(__Type))), # type: ignore ( "isDeprecated", GraphQLField( diff --git a/graphql/type/schema.py b/graphql/type/schema.py index 06cf5be6..54e0a09f 100644 --- a/graphql/type/schema.py +++ b/graphql/type/schema.py @@ -6,7 +6,12 @@ from .typemap import GraphQLTypeMap if False: # flake8: noqa - from .definition import GraphQLInterfaceType, GraphQLUnionType, GraphQLType + from .definition import ( + GraphQLNamedType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLType, + ) from typing import Dict, Union, Any, List, Optional @@ -51,7 +56,7 @@ def __init__( mutation=None, # type: Optional[GraphQLObjectType] subscription=None, # type: Optional[GraphQLObjectType] directives=None, # type: Optional[List[GraphQLDirective]] - types=None, # type: Optional[List[GraphQLObjectType]] + types=None, # type: Optional[List[GraphQLNamedType]] ): # type: (...) -> None assert isinstance( @@ -87,10 +92,12 @@ def __init__( ) self._directives = directives - initial_types = [query, mutation, subscription, IntrospectionSchema] + initial_types = list( + filter(None, [query, mutation, subscription, IntrospectionSchema]) + ) # type: List[GraphQLNamedType] if types: initial_types += types - self._type_map = GraphQLTypeMap(initial_types) # type: Dict[str, GraphQLType] + self._type_map = GraphQLTypeMap(initial_types) # type: GraphQLTypeMap def get_query_type(self): # type: () -> GraphQLObjectType @@ -109,8 +116,9 @@ def get_type_map(self): return self._type_map def get_type(self, name): - # type: (str) -> Optional[GraphQLType] + # type: (str) -> Optional[GraphQLNamedType] return self._type_map.get(name) + # raise Exception("Type {name} not found in schema.".format(name=name)) def get_directives(self): # type: () -> List[GraphQLDirective] @@ -124,18 +132,13 @@ def get_directive(self, name): return None - def get_possible_types( - self, - # type: Union[GraphQLInterfaceType, GraphQLUnionType] - abstract_type, - ): - # type: (...) -> List[GraphQLObjectType] + def get_possible_types(self, abstract_type): + # type: (Union[GraphQLInterfaceType, GraphQLUnionType]) -> List[GraphQLObjectType] return self._type_map.get_possible_types(abstract_type) def is_possible_type( self, - # type: Union[GraphQLInterfaceType, GraphQLUnionType] - abstract_type, + abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] possible_type, # type: GraphQLObjectType ): # type: (...) -> bool diff --git a/graphql/type/tests/test_schema.py b/graphql/type/tests/test_schema.py index 96173f29..5f7b56dc 100644 --- a/graphql/type/tests/test_schema.py +++ b/graphql/type/tests/test_schema.py @@ -39,6 +39,6 @@ def test_throws_human_readable_error_if_schematypes_not_defined(): schema.is_possible_type(interface_type, implementing_type) assert str(exci.value) == ( - "Could not find possible implementing types for $Interface in schema. Check that " + "Could not find possible implementing types for Interface in schema. Check that " "schema.types is defined and is an array ofall possible types in the schema." ) diff --git a/graphql/type/typemap.py b/graphql/type/typemap.py index aebff6f0..146bd129 100644 --- a/graphql/type/typemap.py +++ b/graphql/type/typemap.py @@ -16,28 +16,25 @@ ) if False: # flake8: noqa - from typing import Any, List, Optional, Union + from ..type.definition import GraphQLNamedType + from typing import Any, List, Optional, Union, Dict, Set, Mapping, DefaultDict class GraphQLTypeMap(OrderedDict): - def __init__( - self, - # type: Union[List[Optional[GraphQLObjectType]], List[GraphQLObjectType]] - types, - ): - # type: (...) -> None + def __init__(self, types): + # type: (List[GraphQLNamedType]) -> None super(GraphQLTypeMap, self).__init__() - self.update(reduce(self.reducer, types, OrderedDict())) - self._possible_type_map = defaultdict(set) + self.update(reduce(self.reducer, types, OrderedDict())) # type: ignore + self._possible_type_map = defaultdict(set) # type: DefaultDict[str, Set[str]] # Keep track of all implementations by interface name. - self._implementations = {} + self._implementations = defaultdict( + list + ) # type: DefaultDict[str, List[GraphQLObjectType]] for gql_type in self.values(): if isinstance(gql_type, GraphQLObjectType): for interface in gql_type.interfaces: - self._implementations.setdefault(interface.name, []).append( - gql_type - ) + self._implementations[interface.name].append(gql_type) # Enforce correct interface implementations. for type in self.values(): @@ -45,32 +42,29 @@ def __init__( for interface in type.interfaces: self.assert_object_implements_interface(self, type, interface) - def get_possible_types( - self, - # type: Union[GraphQLInterfaceType, GraphQLUnionType] - abstract_type, - ): - # type: (...) -> List[GraphQLObjectType] + def get_possible_types(self, abstract_type): + # type: (Union[GraphQLInterfaceType, GraphQLUnionType]) -> List[GraphQLObjectType] if isinstance(abstract_type, GraphQLUnionType): return abstract_type.types assert isinstance(abstract_type, GraphQLInterfaceType) - return self._implementations.get(abstract_type.name, None) + if abstract_type.name not in self._implementations: + return [] + return self._implementations[abstract_type.name] def is_possible_type( self, - # type: Union[GraphQLInterfaceType, GraphQLUnionType] - abstract_type, + abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType] possible_type, # type: GraphQLObjectType ): # type: (...) -> bool possible_types = self.get_possible_types(abstract_type) - assert isinstance(possible_types, Sequence), ( - "Could not find possible implementing types for ${} in " + assert possible_types, ( + "Could not find possible implementing types for {} in " + "schema. Check that schema.types is defined and is an array of" + "all possible types in the schema." ).format(abstract_type) - if not self._possible_type_map[abstract_type.name]: + if abstract_type.name not in self._possible_type_map: self._possible_type_map[abstract_type.name].update( [p.name for p in possible_types] ) @@ -79,11 +73,11 @@ def is_possible_type( @classmethod def reducer(cls, map, type): - # type: (OrderedDict, Any) -> OrderedDict + # type: (Dict, Union[GraphQLNamedType, GraphQLList, GraphQLNonNull]) -> Dict if not type: return map - if isinstance(type, GraphQLList) or isinstance(type, GraphQLNonNull): + if isinstance(type, (GraphQLList, GraphQLNonNull)): return cls.reducer(map, type.of_type) if type.name in map: @@ -93,7 +87,7 @@ def reducer(cls, map, type): return map - map[type.name] = type + map[type.name] = type # type: ignore reduced_map = map diff --git a/graphql/utils/is_valid_literal_value.py b/graphql/utils/is_valid_literal_value.py index 9b9f3b41..9e99d587 100644 --- a/graphql/utils/is_valid_literal_value.py +++ b/graphql/utils/is_valid_literal_value.py @@ -17,13 +17,13 @@ def is_valid_literal_value(type, value_ast): - # type: (Union[GraphQLInputObjectType, GraphQLScalarType, GraphQLNonNull], Any) -> List + # type: (Union[GraphQLInputObjectType, GraphQLScalarType, GraphQLNonNull, GraphQLList], Any) -> List if isinstance(type, GraphQLNonNull): of_type = type.of_type if not value_ast: return [u'Expected "{}", found null.'.format(type)] - return is_valid_literal_value(of_type, value_ast) + return is_valid_literal_value(of_type, value_ast) # type: ignore if not value_ast: return _empty_list diff --git a/graphql/utils/type_comparators.py b/graphql/utils/type_comparators.py index b5a4952f..4789df26 100644 --- a/graphql/utils/type_comparators.py +++ b/graphql/utils/type_comparators.py @@ -14,6 +14,7 @@ GraphQLObjectType, GraphQLUnionType, ) + from ..type.typemap import GraphQLTypeMap from ..type.schema import GraphQLSchema from typing import Union @@ -32,7 +33,7 @@ def is_equal_type(type_a, type_b): def is_type_sub_type_of(schema, maybe_subtype, super_type): - # type: (GraphQLSchema, GraphQLScalarType, GraphQLScalarType) -> bool + # type: (Union[GraphQLSchema, GraphQLTypeMap], GraphQLScalarType, GraphQLScalarType) -> bool if maybe_subtype is super_type: return True diff --git a/graphql/utils/type_from_ast.py b/graphql/utils/type_from_ast.py index 51c0bbd6..c76a2d28 100644 --- a/graphql/utils/type_from_ast.py +++ b/graphql/utils/type_from_ast.py @@ -8,21 +8,18 @@ from typing import Any, Union -def type_from_ast(schema, input_type_ast): +def type_from_ast(schema, type_node): # type: (GraphQLSchema, Union[ListType, NamedType, NonNullType]) -> Union[GraphQLList, GraphQLNonNull, GraphQLNamedType] - if isinstance(input_type_ast, ast.ListType): - inner_type = type_from_ast(schema, input_type_ast.type) - if inner_type: - return GraphQLList(inner_type) - else: - return None + if isinstance(type_node, ast.ListType): + inner_type = type_from_ast(schema, type_node.type) + return inner_type and GraphQLList(inner_type) - if isinstance(input_type_ast, ast.NonNullType): - inner_type = type_from_ast(schema, input_type_ast.type) - if inner_type: - return GraphQLNonNull(inner_type) - else: - return None + elif isinstance(type_node, ast.NonNullType): + inner_type = type_from_ast(schema, type_node.type) + return inner_type and GraphQLNonNull(inner_type) # type: ignore - assert isinstance(input_type_ast, ast.NamedType), "Must be a type name." - return schema.get_type(input_type_ast.name.value) + elif isinstance(type_node, ast.NamedType): + schema_type = schema.get_type(type_node.name.value) + return schema_type # type: ignore + + raise Exception("Unexpected type kind: {type_kind}".format(type_kind=type_node)) diff --git a/graphql/validation/rules/__init__.py b/graphql/validation/rules/__init__.py index 329dabf8..e9f19cbf 100644 --- a/graphql/validation/rules/__init__.py +++ b/graphql/validation/rules/__init__.py @@ -25,7 +25,7 @@ if False: # flake8: noqa from typing import List, Type - from ...language.visitor import Visitor + from .base import ValidationRule specified_rules = [ @@ -53,7 +53,7 @@ OverlappingFieldsCanBeMerged, UniqueInputFieldNames, UniqueVariableNames, -] # type: List[Type[Visitor]] +] # type: List[Type[ValidationRule]] __all__ = [ "ArgumentsOfCorrectType", diff --git a/graphql/validation/rules/overlapping_fields_can_be_merged.py b/graphql/validation/rules/overlapping_fields_can_be_merged.py index e427be8c..ae8ee6c5 100644 --- a/graphql/validation/rules/overlapping_fields_can_be_merged.py +++ b/graphql/validation/rules/overlapping_fields_can_be_merged.py @@ -623,7 +623,7 @@ def _get_referenced_fields_and_fragment_names( context.get_schema(), fragment.type_condition ) - return _get_fields_and_fragments_names( + return _get_fields_and_fragments_names( # type: ignore context, cached_fields_and_fragment_names, fragment_type, fragment.selection_set ) @@ -661,9 +661,9 @@ def _collect_fields_and_fragment_names( context.get_schema(), selection.type_condition ) else: - inline_fragment_type = parent_type + inline_fragment_type = parent_type # type: ignore - _collect_fields_and_fragment_names( + _collect_fields_and_fragment_names( # type: ignore context, inline_fragment_type, selection.selection_set, @@ -681,11 +681,11 @@ def _subfield_conflicts( # type: (...) -> Optional[Tuple[Tuple[str, str], List[Node], List[Node]]] """Given a series of Conflicts which occurred between two sub-fields, generate a single Conflict.""" if conflicts: - return ( + return ( # type: ignore (response_name, [conflict[0] for conflict in conflicts]), tuple(itertools.chain([ast1], *[conflict[1] for conflict in conflicts])), tuple(itertools.chain([ast2], *[conflict[2] for conflict in conflicts])), - ) # type: ignore + ) return None diff --git a/graphql/validation/validation.py b/graphql/validation/validation.py index ac890516..9422ad56 100644 --- a/graphql/validation/validation.py +++ b/graphql/validation/validation.py @@ -185,7 +185,7 @@ def get_parent_type(self): def get_input_type(self): # type: () -> Optional[GraphQLInputObjectType] - return self._type_info.get_input_type() + return self._type_info.get_input_type() # type: ignore def get_field_def(self): # type: () -> Optional[GraphQLField]