Skip to content

Commit

Permalink
fix: enable async cleaned up with hook points
Browse files Browse the repository at this point in the history
  • Loading branch information
Cameron Hurst committed Jan 6, 2021
1 parent c03e1a4 commit bdb41ea
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 40 deletions.
33 changes: 9 additions & 24 deletions graphql_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from graphql.error import format_error as format_error_default
from graphql.execution import ExecutionResult, execute
from graphql.language import OperationType, parse
from graphql.pyutils import AwaitableOrValue
from graphql.pyutils import AwaitableOrValue, is_awaitable
from graphql.type import GraphQLSchema, validate_schema
from graphql.utilities import get_operation_ast
from graphql.validation import ASTValidationRule, validate
Expand Down Expand Up @@ -99,9 +99,7 @@ def run_http_query(

if not is_batch:
if not isinstance(data, (dict, MutableMapping)):
raise HttpQueryError(
400, f"GraphQL params should be a dict. Received {data!r}."
)
raise HttpQueryError(400, f"GraphQL params should be a dict. Received {data!r}.")
data = [data]
elif not batch_enabled:
raise HttpQueryError(400, "Batch GraphQL requests are not enabled.")
Expand All @@ -114,15 +112,10 @@ def run_http_query(
if not is_batch:
extra_data = query_data or {}

all_params: List[GraphQLParams] = [
get_graphql_params(entry, extra_data) for entry in data
]
all_params: List[GraphQLParams] = [get_graphql_params(entry, extra_data) for entry in data]

results: List[Optional[AwaitableOrValue[ExecutionResult]]] = [
get_response(
schema, params, catch_exc, allow_only_query, run_sync, **execute_options
)
for params in all_params
get_response(schema, params, catch_exc, allow_only_query, run_sync, **execute_options) for params in all_params
]
return GraphQLResponse(results, all_params)

Expand Down Expand Up @@ -160,10 +153,7 @@ def encode_execution_results(
Returns a ServerResponse tuple with the serialized response as the first item and
a status code of 200 or 400 in case any result was invalid as the second item.
"""
results = [
format_execution_result(execution_result, format_error)
for execution_result in execution_results
]
results = [format_execution_result(execution_result, format_error) for execution_result in execution_results]
result, status_codes = zip(*results)
status_code = max(status_codes)

Expand Down Expand Up @@ -274,14 +264,11 @@ def get_response(
if operation != OperationType.QUERY.value:
raise HttpQueryError(
405,
f"Can only perform a {operation} operation"
" from a POST request.",
f"Can only perform a {operation} operation" " from a POST request.",
headers={"Allow": "POST"},
)

validation_errors = validate(
schema, document, rules=validation_rules, max_errors=max_errors
)
validation_errors = validate(schema, document, rules=validation_rules, max_errors=max_errors)
if validation_errors:
return ExecutionResult(data=None, errors=validation_errors)

Expand All @@ -290,7 +277,7 @@ def get_response(
document,
variable_values=params.variables,
operation_name=params.operation_name,
is_awaitable=assume_not_awaitable if run_sync else None,
is_awaitable=assume_not_awaitable if run_sync else is_awaitable,
**kwargs,
)

Expand All @@ -317,9 +304,7 @@ def format_execution_result(
fe = [format_error(e) for e in execution_result.errors] # type: ignore
response = {"errors": fe}

if execution_result.errors and any(
not getattr(e, "path", None) for e in execution_result.errors
):
if execution_result.errors and any(not getattr(e, "path", None) for e in execution_result.errors):
status_code = 400
else:
response["data"] = execution_result.data
Expand Down
36 changes: 20 additions & 16 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
from typing import List

from flask import Response, render_template_string, request
from flask.views import View
from graphql import ExecutionResult
from graphql.error import GraphQLError
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema

from graphql_server import (
Expand Down Expand Up @@ -41,6 +44,7 @@ class GraphQLView(View):
default_query = None
header_editor_enabled = None
should_persist_headers = None
enable_async = True

methods = ["GET", "POST", "PUT", "DELETE"]

Expand All @@ -53,26 +57,27 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

@staticmethod
def get_async_execution_results(execution_results):
async def await_execution_results(execution_results):
return [ex if ex is None or is_awaitable(ex) else await ex for ex in execution_results]

return asyncio.run(await_execution_results(execution_results))

def dispatch_request(self):
try:
request_method = request.method.lower()
Expand All @@ -96,6 +101,11 @@ def dispatch_request(self):
context_value=self.get_context(),
middleware=self.get_middleware(),
)

if self.enable_async:
if any(is_awaitable(ex) for ex in execution_results):
execution_results = self.get_async_execution_results(execution_results)

result, status_code = encode_execution_results(
execution_results,
is_batch=isinstance(data, list),
Expand Down Expand Up @@ -123,9 +133,7 @@ def dispatch_request(self):
header_editor_enabled=self.header_editor_enabled,
should_persist_headers=self.should_persist_headers,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config, options=graphiql_options
)
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
return render_template_string(source)

return Response(result, status=status_code, content_type="application/json")
Expand Down Expand Up @@ -167,8 +175,4 @@ def should_display_graphiql(self):
@staticmethod
def request_wants_html():
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
return (
best == "text/html"
and request.accept_mimetypes[best]
> request.accept_mimetypes["application/json"]
)
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]

0 comments on commit bdb41ea

Please sign in to comment.