Skip to content

Commit

Permalink
[dagit] refactor graphql execute methods (#7065)
Browse files Browse the repository at this point in the history
decompose the graphql server impl a bit

## Test Plan

existing bk suite
  • Loading branch information
alangenfeld committed Mar 16, 2022
1 parent 61c43eb commit 15b5ad9
Showing 1 changed file with 72 additions and 52 deletions.
124 changes: 72 additions & 52 deletions python_modules/dagit/dagit/graphql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from asyncio import Queue, get_event_loop
from asyncio import Queue, Task, get_event_loop
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

from dagit.templates.playground import TEMPLATE
from graphene import Schema
Expand Down Expand Up @@ -119,14 +119,7 @@ async def graphql_http_endpoint(self, request: Request):
status_code=status.HTTP_400_BAD_REQUEST,
)

result = await run_in_threadpool(
self._graphql_schema.execute,
query,
variables=variables,
operation_name=operation_name,
context=self.make_request_context(request),
middleware=self._graphql_middleware,
)
result = await self.execute_graphql_request(request, query, variables, operation_name)

response_data = {"data": result.data}
status_code = status.HTTP_200_OK
Expand All @@ -143,9 +136,7 @@ async def graphql_ws_endpoint(self, websocket: WebSocket):
Once we are free of conflicting deps, we should be able to use an impl from
strawberry-graphql or the like.
"""
observables = {}
tasks = {}
event_loop = get_event_loop()
tasks: Dict[str, Task] = {}

await websocket.accept(subprotocol=GraphQLWS.PROTOCOL)

Expand All @@ -164,60 +155,89 @@ async def graphql_ws_endpoint(self, websocket: WebSocket):
elif message_type == GraphQLWS.CONNECTION_TERMINATE:
await websocket.close()
elif message_type == GraphQLWS.START:
try:
data = message["payload"]
query = data["query"]
variables = data.get("variables")
operation_name = data.get("operation_name")

request_context = self.make_request_context(websocket)
async_result = self._graphql_schema.execute(
query,
variables=variables,
operation_name=operation_name,
context=request_context,
allow_subscriptions=True,
)
except GraphQLError as error:
payload = format_graphql_error(error)
await _send_message(websocket, GraphQLWS.ERROR, payload, operation_id)
data = message["payload"]

task, error_payload = self.execute_graphql_subscription(
websocket=websocket,
operation_id=operation_id,
query=data["query"],
variables=data.get("variables"),
operation_name=data.get("operation_name"),
)
if error_payload:
await _send_message(websocket, GraphQLWS.ERROR, error_payload, operation_id)
continue

if isinstance(async_result, ExecutionResult):
if not async_result.errors:
check.failed(
f"Only expect non-async result on error, got {async_result}"
)
payload = format_graphql_error(async_result.errors[0]) # type: ignore
await _send_message(websocket, GraphQLWS.ERROR, payload, operation_id)
continue
assert task is not None

# in the future we should get back async gen directly, back compat for now
disposable, async_gen = _disposable_and_async_gen_from_obs(
async_result, event_loop
)
tasks[operation_id] = task

observables[operation_id] = disposable
tasks[operation_id] = event_loop.create_task(
_handle_async_results(async_gen, operation_id, websocket)
)
elif message_type == GraphQLWS.STOP:
if operation_id not in observables:
if operation_id not in tasks:
return

observables[operation_id].dispose()
del observables[operation_id]

tasks[operation_id].cancel()
del tasks[operation_id]

except WebSocketDisconnect:
pass
finally:
for operation_id in observables:
observables[operation_id].dispose()
for operation_id in tasks:
tasks[operation_id].cancel()

async def execute_graphql_request(
self,
request: Request,
query: str,
variables: Optional[Dict[str, Any]],
operation_name: Optional[str],
) -> ExecutionResult:
# use run_in_threadpool since underlying schema is sync
return await run_in_threadpool(
self._graphql_schema.execute,
query,
variables=variables,
operation_name=operation_name,
context=self.make_request_context(request),
middleware=self._graphql_middleware,
)

def execute_graphql_subscription(
self,
websocket: WebSocket,
operation_id: str,
query: str,
variables: Optional[Dict[str, Any]],
operation_name: Optional[str],
) -> Tuple[Optional[Task], Optional[Dict[str, Any]]]:
request_context = self.make_request_context(websocket)
try:
async_result = self._graphql_schema.execute(
query,
variables=variables,
operation_name=operation_name,
context=request_context,
allow_subscriptions=True,
)
except GraphQLError as error:
error_payload = format_graphql_error(error)
return None, error_payload

if isinstance(async_result, ExecutionResult):
if not async_result.errors:
check.failed(f"Only expect non-async result on error, got {async_result}")
error_payload = format_graphql_error(async_result.errors[0]) # type: ignore
return None, error_payload

# in the future we should get back async gen directly, back compat for now
disposable, async_gen = _disposable_and_async_gen_from_obs(async_result, get_event_loop())
task = get_event_loop().create_task(
_handle_async_results(async_gen, operation_id, websocket)
)
task.add_done_callback(lambda _: disposable.dispose())

return task, None

def create_asgi_app(
self,
**kwargs,
Expand Down

0 comments on commit 15b5ad9

Please sign in to comment.