Skip to content

Commit

Permalink
feat: flask asyncio support for dataloaders
Browse files Browse the repository at this point in the history
  • Loading branch information
Cameron Hurst committed Oct 8, 2020
1 parent 482f21b commit 734b3f0
Showing 1 changed file with 44 additions and 32 deletions.
76 changes: 44 additions & 32 deletions graphql_server/flask/graphqlview.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
from collections.abc import MutableMapping
from functools import partial
Expand All @@ -6,6 +7,7 @@
from flask import Response, render_template_string, request
from flask.views import View
from graphql.error import GraphQLError
from graphql.pyutils import is_awaitable
from graphql.type.schema import GraphQLSchema

from graphql_server import (
Expand Down Expand Up @@ -41,6 +43,7 @@ class GraphQLView(View):
default_query = None
header_editor_enabled = None
should_persist_headers = None
enable_async = False

methods = ["GET", "POST", "PUT", "DELETE"]

Expand All @@ -53,26 +56,51 @@ def __init__(self, **kwargs):
if hasattr(self, key):
setattr(self, key, value)

assert isinstance(
self.schema, GraphQLSchema
), "A Schema is required to be provided to GraphQLView."
assert isinstance(self.schema, GraphQLSchema), "A Schema is required to be provided to GraphQLView."

def get_root_value(self):
return self.root_value

def get_context(self):
context = (
copy.copy(self.context)
if self.context and isinstance(self.context, MutableMapping)
else {}
)
context = copy.copy(self.context) if self.context and isinstance(self.context, MutableMapping) else {}
if isinstance(context, MutableMapping) and "request" not in context:
context.update({"request": request})
return context

def get_middleware(self):
return self.middleware

def result_results(self, request_method, data, catch):
return run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
run_sync=not self.enable_async,
)

async def resolve_results_async(self, request_method, data, catch):
execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
run_sync=not self.enable_async,
)
return [await ex if is_awaitable(ex) else ex for ex in execution_results], all_params

def dispatch_request(self):
try:
request_method = request.method.lower()
Expand All @@ -84,18 +112,11 @@ def dispatch_request(self):
pretty = self.pretty or show_graphiql or request.args.get("pretty")

all_params: List[GraphQLParams]
execution_results, all_params = run_http_query(
self.schema,
request_method,
data,
query_data=request.args,
batch_enabled=self.batch,
catch=catch,
# Execute options
root_value=self.get_root_value(),
context_value=self.get_context(),
middleware=self.get_middleware(),
)
if self.enable_async:
execution_results, all_params = asyncio.run(self.resolve_results_async(request_method, data, catch))
else:
execution_results, all_params = self.result_results(request_method, data, catch)

result, status_code = encode_execution_results(
execution_results,
is_batch=isinstance(data, list),
Expand Down Expand Up @@ -123,9 +144,7 @@ def dispatch_request(self):
header_editor_enabled=self.header_editor_enabled,
should_persist_headers=self.should_persist_headers,
)
source = render_graphiql_sync(
data=graphiql_data, config=graphiql_config, options=graphiql_options
)
source = render_graphiql_sync(data=graphiql_data, config=graphiql_config, options=graphiql_options)
return render_template_string(source)

return Response(result, status=status_code, content_type="application/json")
Expand All @@ -150,10 +169,7 @@ def parse_body(self):
elif content_type == "application/json":
return load_json_body(request.data.decode("utf8"))

elif content_type in (
"application/x-www-form-urlencoded",
"multipart/form-data",
):
elif content_type in ("application/x-www-form-urlencoded", "multipart/form-data",):
return request.form

return {}
Expand All @@ -166,8 +182,4 @@ def should_display_graphiql(self):

def request_wants_html(self):
best = request.accept_mimetypes.best_match(["application/json", "text/html"])
return (
best == "text/html"
and request.accept_mimetypes[best]
> request.accept_mimetypes["application/json"]
)
return best == "text/html" and request.accept_mimetypes[best] > request.accept_mimetypes["application/json"]

0 comments on commit 734b3f0

Please sign in to comment.