Skip to content

Commit

Permalink
Improved typing
Browse files Browse the repository at this point in the history
  • Loading branch information
syrusakbary committed Jul 2, 2018
1 parent 697ce76 commit 6a3f900
Show file tree
Hide file tree
Showing 19 changed files with 235 additions and 221 deletions.
85 changes: 47 additions & 38 deletions graphql/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
def subscribe(*args, **kwargs):
# type: (*Any, **Any) -> Union[ExecutionResult, Observable]
allow_subscriptions = kwargs.pop("allow_subscriptions", True)
return execute(*args, allow_subscriptions=allow_subscriptions, **kwargs)
return execute( # type: ignore
*args, allow_subscriptions=allow_subscriptions, **kwargs
)


def execute(
Expand Down Expand Up @@ -116,7 +118,7 @@ def execute(
allow_subscriptions,
)

def executor(v):
def promise_executor(v):
# type: (Optional[Any]) -> Union[OrderedDict, Promise, Observable]
return execute_operation(exe_context, exe_context.operation, root)

Expand All @@ -135,7 +137,9 @@ def on_resolve(data):

return ExecutionResult(data=data, errors=exe_context.errors)

promise = Promise.resolve(None).then(executor).catch(on_rejected).then(on_resolve)
promise = (
Promise.resolve(None).then(promise_executor).catch(on_rejected).then(on_resolve)
)

if not return_promise:
exe_context.executor.wait_until_finished()
Expand All @@ -151,7 +155,7 @@ def on_resolve(data):
def execute_operation(
exe_context, # type: ExecutionContext
operation, # type: OperationDefinition
root_value, # type: Union[None, Data, type]
root_value, # type: Any
):
# type: (...) -> Union[OrderedDict, Promise]
type = get_operation_root_type(exe_context.schema, operation)
Expand Down Expand Up @@ -224,7 +228,7 @@ def execute_fields(
parent_type, # type: GraphQLObjectType
source_value, # type: Any
fields, # type: DefaultOrderedDict
path, # type: Union[List[Union[int, str]], List[str]]
path, # type: List[Union[int, str]]
info, # type: Optional[ResolveInfo]
):
# type: (...) -> Union[OrderedDict, Promise]
Expand Down Expand Up @@ -257,39 +261,43 @@ def execute_fields(
def subscribe_fields(
exe_context, # type: ExecutionContext
parent_type, # type: GraphQLObjectType
source_value, # type: Union[None, Data, type]
source_value, # type: Any
fields, # type: DefaultOrderedDict
):
# type: (...) -> Observable
exe_context = SubscriberExecutionContext(exe_context)
subscriber_exe_context = SubscriberExecutionContext(exe_context)

def on_error(error):
exe_context.report_error(error)
subscriber_exe_context.report_error(error)

def map_result(
data # type: Union[Dict[str, None], Dict[str, OrderedDict], Dict[str, str]]
):
# type: (...) -> ExecutionResult
if exe_context.errors:
result = ExecutionResult(data=data, errors=exe_context.errors)
if subscriber_exe_context.errors:
result = ExecutionResult(data=data, errors=subscriber_exe_context.errors)
else:
result = ExecutionResult(data=data)
exe_context.reset()
subscriber_exe_context.reset()
return result

observables = []
observables = [] # type: List[Observable]

# assert len(fields) == 1, "Can only subscribe one element at a time."

for response_name, field_asts in fields.items():
result = subscribe_field(
exe_context, parent_type, source_value, field_asts, [response_name]
subscriber_exe_context,
parent_type,
source_value,
field_asts,
[response_name],
)
if result is Undefined:
continue

def catch_error(error):
exe_context.errors.append(error)
subscriber_exe_context.errors.append(error)
return Observable.just(None)

# Map observable results
Expand All @@ -305,10 +313,10 @@ def catch_error(error):
def resolve_field(
exe_context, # type: ExecutionContext
parent_type, # type: GraphQLObjectType
source, # type: Union[None, Cat, Dog]
source, # type: Any
field_asts, # type: List[Field]
parent_info, # type: Optional[ResolveInfo]
field_path, # type: Union[List[Union[int, str]], List[str]]
field_path, # type: List[Union[int, str]]
):
# type: (...) -> Any
field_ast = field_asts[0]
Expand Down Expand Up @@ -360,7 +368,7 @@ def resolve_field(
def subscribe_field(
exe_context, # type: SubscriberExecutionContext
parent_type, # type: GraphQLObjectType
source, # type: Union[None, Data, type]
source, # type: Any
field_asts, # type: List[Field]
path, # type: List[str]
):
Expand Down Expand Up @@ -430,12 +438,12 @@ def subscribe_field(

def resolve_or_error(
resolve_fn, # type: Callable
source, # type: Union[None, Cat, Dog]
source, # type: Any
info, # type: ResolveInfo
args, # type: Dict
executor, # type: Union[BaseExecutor, SyncExecutor]
executor, # type: Any
):
# type: (...) -> Union[List[Union[Cat, Dog]], bool, str]
# type: (...) -> Any
try:
return executor.execute(resolve_fn, source, info, **args)
except Exception as e:
Expand All @@ -444,7 +452,7 @@ def resolve_or_error(
info.parent_type.name, info.field_name
)
)
e.stack = sys.exc_info()[2]
e.stack = sys.exc_info()[2] # type: ignore
return e


Expand All @@ -453,10 +461,10 @@ def complete_value_catching_error(
return_type, # type: Any
field_asts, # type: List[Field]
info, # type: ResolveInfo
path, # type: Union[List[Union[int, str]], List[str]]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Union[bool, str]
# type: (...) -> Any
# If the field type is non-nullable, then it is resolved without any
# protection from errors.
if isinstance(return_type, GraphQLNonNull):
Expand All @@ -472,7 +480,7 @@ def complete_value_catching_error(

def handle_error(error):
# type: (Union[GraphQLError, GraphQLLocatedError]) -> Optional[Any]
traceback = completed._traceback
traceback = completed._traceback # type: ignore
exe_context.report_error(error, traceback)
return None

Expand All @@ -490,10 +498,10 @@ def complete_value(
return_type, # type: Any
field_asts, # type: List[Field]
info, # type: ResolveInfo
path, # type: Union[List[Union[int, str]], List[str]]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Union[bool, str]
# type: (...) -> Any
"""
Implements the instructions for completeValue as defined in the
"Field entries" section of the spec.
Expand Down Expand Up @@ -566,10 +574,10 @@ def complete_list_value(
return_type, # type: GraphQLList
field_asts, # type: List[Field]
info, # type: ResolveInfo
path, # type: List[str]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Any
# type: (...) -> List[Any]
"""
Complete a list value by completing each item in the list with the inner type
"""
Expand Down Expand Up @@ -597,10 +605,10 @@ def complete_list_value(

def complete_leaf_value(
return_type, # type: Union[GraphQLEnumType, GraphQLScalarType]
path, # type: Union[List[Union[int, str]], List[str]]
result, # type: Union[int, str]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Union[int, str]
# type: (...) -> Union[int, str, float, bool]
"""
Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible.
"""
Expand All @@ -625,12 +633,12 @@ def complete_abstract_value(
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> OrderedDict
# type: (...) -> Dict[str, Any]
"""
Complete an value of an abstract type by determining the runtime type of that value, then completing based
on that type.
"""
runtime_type = None
runtime_type = None # type: Union[str, GraphQLObjectType, None]

# Field type must be Object, Interface or Union and expect sub-selections.
if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)):
Expand All @@ -640,7 +648,7 @@ def complete_abstract_value(
runtime_type = get_default_resolve_type_fn(result, info, return_type)

if isinstance(runtime_type, string_types):
runtime_type = info.schema.get_type(runtime_type)
runtime_type = info.schema.get_type(runtime_type) # type: ignore

if not isinstance(runtime_type, GraphQLObjectType):
raise GraphQLError(
Expand Down Expand Up @@ -671,22 +679,23 @@ def get_default_resolve_type_fn(
info, # type: ResolveInfo
abstract_type, # type: Union[GraphQLInterfaceType, GraphQLUnionType]
):
# type: (...) -> GraphQLObjectType
# type: (...) -> Optional[GraphQLObjectType]
possible_types = info.schema.get_possible_types(abstract_type)
for type in possible_types:
if callable(type.is_type_of) and type.is_type_of(value, info):
return type
return None


def complete_object_value(
exe_context, # type: ExecutionContext
return_type, # type: GraphQLObjectType
field_asts, # type: List[Field]
info, # type: ResolveInfo
path, # type: Union[List[Union[int, str]], List[str]]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Union[OrderedDict, Promise]
# type: (...) -> Dict[str, Any]
"""
Complete an Object value by evaluating all sub-selections.
"""
Expand All @@ -708,7 +717,7 @@ def complete_nonnull_value(
return_type, # type: GraphQLNonNull
field_asts, # type: List[Field]
info, # type: ResolveInfo
path, # type: Union[List[Union[int, str]], List[str]]
path, # type: List[Union[int, str]]
result, # type: Any
):
# type: (...) -> Any
Expand Down
8 changes: 4 additions & 4 deletions graphql/execution/executors/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

if False: # flake8: noqa
from asyncio.unix_events import _UnixSelectorEventLoop
from typing import Optional, Any, Callable
from typing import Optional, Any, Callable, List

try:
from asyncio import ensure_future
except ImportError:
# ensure_future is only implemented in Python 3.4.4+
def ensure_future(coro_or_future, loop=None):
def ensure_future(coro_or_future, loop=None): # type: ignore
"""Wrap a coroutine or an awaitable in a future.
If the argument is a Future, it is returned directly.
Expand All @@ -39,7 +39,7 @@ def ensure_future(coro_or_future, loop=None):
def isasyncgen(obj):
False

def asyncgen_to_observable(asyncgen):
def asyncgen_to_observable(asyncgen, loop=None):
pass


Expand All @@ -49,7 +49,7 @@ def __init__(self, loop=None):
if loop is None:
loop = get_event_loop()
self.loop = loop
self.futures = []
self.futures = [] # type: List[Future]

def wait_until_finished(self):
# type: () -> None
Expand Down
7 changes: 4 additions & 3 deletions graphql/execution/executors/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .utils import process

if False: # flake8: noqa
from typing import Any, Callable
from typing import Any, Callable, List


class ThreadExecutor(object):
Expand All @@ -14,7 +14,7 @@ class ThreadExecutor(object):

def __init__(self, pool=False):
# type: (bool) -> None
self.threads = []
self.threads = [] # type: List[Thread]
if pool:
self.execute = self.execute_in_pool
self.pool = ThreadPool(processes=pool)
Expand All @@ -26,7 +26,8 @@ def wait_until_finished(self):
while self.threads:
threads = self.threads
self.threads = []
[thread.join() for thread in threads]
for thread in threads:
thread.join()

def clean(self):
self.threads = []
Expand Down
4 changes: 2 additions & 2 deletions graphql/execution/executors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
if False: # flake8: noqa
from ..base import ResolveInfo
from promise import Promise
from typing import Callable, Dict, Tuple, Union
from typing import Callable, Dict, Tuple, Union, Any


def process(
Expand All @@ -18,5 +18,5 @@ def process(
p.do_resolve(val)
except Exception as e:
traceback = exc_info()[2]
e.stack = traceback
e.stack = traceback # type: ignore
p.do_reject(e, traceback=traceback)
Loading

0 comments on commit 6a3f900

Please sign in to comment.