266 changes: 238 additions & 28 deletions gql/client.py
@@ -184,6 +184,24 @@ def _build_schema_from_introspection(
self.introspection = cast(IntrospectionQuery, execution_result.data)
self.schema = build_client_schema(self.introspection)

@staticmethod
def _get_event_loop() -> asyncio.AbstractEventLoop:
"""Get the current asyncio event loop.

Or create and set a new event loop if there isn't one yet (e.g. when called from a new thread).
"""
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

return loop
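
As an aside, here is a minimal sketch (not part of this diff) of the situation this helper covers: calling the synchronous API from a freshly spawned thread, where no event loop exists yet. The endpoint URL and query are illustrative placeholders.

import threading

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport


def worker():
    # In a new thread there is no current event loop, so _get_event_loop()
    # creates and installs one before execute() runs the coroutine on it.
    transport = AIOHTTPTransport(url="https://example.com/graphql")  # placeholder URL
    client = Client(transport=transport)
    print(client.execute(gql("{ hero { name } }")))  # illustrative query


t = threading.Thread(target=worker)
t.start()
t.join()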

@overload
def execute_sync(
self,
@@ -358,6 +376,58 @@ async def execute_async(
**kwargs,
)

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[False] = ...,
**kwargs: Any,
) -> List[Dict[str, Any]]: ... # pragma: no cover

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[True],
**kwargs: Any,
) -> List[ExecutionResult]: ... # pragma: no cover

@overload
async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover

async def execute_batch_async(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
""":meta private:"""
async with self as session:
return await session.execute_batch(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)
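
Roughly speaking, this private wrapper only adds connection management around the session call; assuming client and requests are already defined, it behaves like the following sketch.

async with client as session:
    results = await session.execute_batch(requests)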

@overload
def execute(
self,
@@ -430,17 +500,7 @@ def execute(
"""

if isinstance(self.transport, AsyncTransport):
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.execute(query) if an asyncio loop is running."
@@ -537,7 +597,24 @@ def execute_batch(
"""

if isinstance(self.transport, AsyncTransport):
raise NotImplementedError("Batching is not implemented for async yet.")
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.execute_batch(query) if an asyncio loop is running."
" Use 'await client.execute_batch(query)' instead."
)

data = loop.run_until_complete(
self.execute_batch_async(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
get_execution_result=get_execution_result,
**kwargs,
)
)

return data

else: # Sync transports
return self.execute_batch_sync(
@@ -675,17 +752,12 @@ def subscribe(
We need an async transport for this functionality.
"""

# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="There is no current event loop"
)
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop = self._get_event_loop()

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
" Use 'await client.subscribe_async(query)' instead."
)

async_generator: Union[
AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None]
@@ -699,11 +771,6 @@
**kwargs,
)

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
" Use 'await client.subscribe_async(query)' instead."
)

try:
while True:
# Note: we need to create a task here in order to be able to close
@@ -1626,6 +1693,149 @@ async def execute(

return result.data

async def _execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
validate_document: Optional[bool] = True,
**kwargs: Any,
) -> List[ExecutionResult]:
"""Execute multiple GraphQL requests in a batch, using
the async transport, returning a list of ExecutionResult objects.

:param requests: List of requests that will be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param validate_document: Whether we still need to validate the document.

The extra arguments are passed to the transport execute_batch method."""

# Validate document
if self.client.schema:

if validate_document:
for req in requests:
self.client.validate(req.document)

# Parse variable values for custom scalars if requested
if serialize_variables or (
serialize_variables is None and self.client.serialize_variables
):
requests = [
(
req.serialize_variable_values(self.client.schema)
if req.variable_values is not None
else req
)
for req in requests
]

results = await self.transport.execute_batch(requests, **kwargs)

# Unserialize the result if requested
if self.client.schema:
if parse_result or (parse_result is None and self.client.parse_results):
for result, request in zip(results, requests):
result.data = parse_result_fn(
self.client.schema,
request.document,
result.data,
operation_name=request.operation_name,
)

return results
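
For orientation, a hypothetical transport stub showing the contract this method relies on: transport.execute_batch() returns one ExecutionResult per request, in request order. The class name and returned data are made up, and it assumes AsyncTransport keeps its usual abstract methods (connect, close, execute, subscribe).

from typing import Any, AsyncGenerator, List

from graphql import DocumentNode, ExecutionResult

from gql.transport import AsyncTransport


class DummyBatchTransport(AsyncTransport):
    """Illustrative async transport answering every request with empty data."""

    async def connect(self) -> None:
        pass

    async def close(self) -> None:
        pass

    async def execute(
        self, document: DocumentNode, *args: Any, **kwargs: Any
    ) -> ExecutionResult:
        return ExecutionResult(data={}, errors=None)

    async def execute_batch(self, requests, **kwargs: Any) -> List[ExecutionResult]:
        # One result per request, preserving request order.
        return [ExecutionResult(data={}, errors=None) for _ in requests]

    async def subscribe(
        self, document: DocumentNode, *args: Any, **kwargs: Any
    ) -> AsyncGenerator[ExecutionResult, None]:
        yield ExecutionResult(data={}, errors=None)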

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[False] = ...,
**kwargs: Any,
) -> List[Dict[str, Any]]: ... # pragma: no cover

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: Literal[True],
**kwargs: Any,
) -> List[ExecutionResult]: ... # pragma: no cover

@overload
async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover

async def execute_batch(
self,
requests: List[GraphQLRequest],
*,
serialize_variables: Optional[bool] = None,
parse_result: Optional[bool] = None,
get_execution_result: bool = False,
**kwargs: Any,
) -> Union[List[Dict[str, Any]], List[ExecutionResult]]:
"""Execute multiple GraphQL requests in a batch, using
the async transport. This method sends the requests to the server all at once.

Raises a TransportQueryError if an error has been returned in any
ExecutionResult.

:param requests: List of requests that will be executed.
:param serialize_variables: whether the variable values should be
serialized. Used for custom scalars and/or enums.
By default use the serialize_variables argument of the client.
:param parse_result: Whether gql will deserialize the result.
By default use the parse_results argument of the client.
:param get_execution_result: return the full ExecutionResult instance instead of
only the "data" field. Necessary if you want to get the "extensions" field.

The extra arguments are passed to the transport execute_batch method."""

# Validate and execute on the transport
results = await self._execute_batch(
requests,
serialize_variables=serialize_variables,
parse_result=parse_result,
**kwargs,
)

for result in results:
# Raise an error if an error is returned in the ExecutionResult object
if result.errors:
raise TransportQueryError(
str_first_element(result.errors),
errors=result.errors,
data=result.data,
extensions=result.extensions,
)

assert (
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

if get_execution_result:
return results

return cast(List[Dict[str, Any]], [result.data for result in results])
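
End to end, the new session method could be used roughly like the sketch below. The endpoint and queries are placeholders, it assumes the chosen async transport implements execute_batch, and GraphQLRequest is constructed the same way as in the existing sync batching API.

import asyncio

from gql import Client, GraphQLRequest, gql
from gql.transport.aiohttp import AIOHTTPTransport


async def main():
    transport = AIOHTTPTransport(url="https://example.com/graphql")  # placeholder URL
    client = Client(transport=transport)

    requests = [
        GraphQLRequest(gql("{ continents { code } }")),
        GraphQLRequest(gql("{ countries { name } }")),
    ]

    async with client as session:
        # One "data" dict per request, in the same order as `requests`.
        data = await session.execute_batch(requests)

        # Or keep the full ExecutionResult objects (errors still raise).
        results = await session.execute_batch(requests, get_execution_result=True)

    print(data, results)


asyncio.run(main())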

async def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.
