Skip to content

Commit

Permalink
Merge branch 'master' into feature_add_make_check
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed May 16, 2020
2 parents a3ff05a + 70e54a4 commit 3f7e0b0
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 57 deletions.
97 changes: 64 additions & 33 deletions gql/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from inspect import isawaitable
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast

from graphql import (
Expand Down Expand Up @@ -35,11 +36,6 @@ def __init__(
assert (
not schema
), "Cant fetch the schema from transport if is already provided"
if isinstance(transport, Transport):
# For sync transports, we fetch the schema directly
execution_result = transport.execute(parse(get_introspection_query()))
execution_result = cast(ExecutionResult, execution_result)
introspection = execution_result.data
if introspection:
assert not schema, "Cant provide introspection and schema at the same time"
schema = build_client_schema(introspection)
Expand Down Expand Up @@ -68,6 +64,10 @@ def __init__(
# Enforced timeout of the execute function
self.execute_timeout = execute_timeout

if isinstance(transport, Transport) and fetch_schema_from_transport:
with self as session:
session.fetch_schema()

def validate(self, document):
if not self.schema:
raise Exception(
Expand All @@ -77,6 +77,10 @@ def validate(self, document):
if validation_errors:
raise validation_errors[0]

def execute_sync(self, document: DocumentNode, *args, **kwargs) -> Dict:
with self as session:
return session.execute(document, *args, **kwargs)

async def execute_async(self, document: DocumentNode, *args, **kwargs) -> Dict:
async with self as session:
return await session.execute(document, *args, **kwargs)
Expand Down Expand Up @@ -107,22 +111,7 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
return data

else: # Sync transports

if self.schema:
self.validate(document)

assert self.transport is not None, "Cannot execute without a transport"

result = self.transport.execute(document, *args, **kwargs)

if result.errors:
raise TransportQueryError(str(result.errors[0]))

assert (
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

return result.data
return self.execute_sync(document, *args, **kwargs)

async def subscribe_async(
self, document: DocumentNode, *args, **kwargs
Expand Down Expand Up @@ -170,30 +159,72 @@ async def __aenter__(self):
await self.transport.connect()

if not hasattr(self, "session"):
self.session = ClientSession(client=self)
self.session = AsyncClientSession(client=self)

return self.session

async def __aexit__(self, exc_type, exc, tb):

await self.transport.close()

def close(self):
"""Close the client and it's underlying transport (only for Sync transports)"""
if not isinstance(self.transport, AsyncTransport):
self.transport.close()

def __enter__(self):

assert not isinstance(
self.transport, AsyncTransport
), "Only a sync transport can be use. Use 'async with Client(...)' instead"
return self

self.transport.connect()

if not hasattr(self, "session"):
self.session = SyncClientSession(client=self)

return self.session

def __exit__(self, *args):
self.close()
self.transport.close()


class SyncClientSession:
"""An instance of this class is created when using 'with' on the client.
It contains the sync method execute to send queries
with the sync transports.
"""

def __init__(self, client: Client):
self.client = client

def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

# Validate document
if self.client.schema:
self.client.validate(document)

result = self.transport.execute(document, *args, **kwargs)

assert not isawaitable(result), "Transport returned an awaitable result."
result = cast(ExecutionResult, result)

if result.errors:
raise TransportQueryError(str(result.errors[0]))

assert (
result.data is not None
), "Transport returned an ExecutionResult without data or errors"

return result.data

def fetch_schema(self) -> None:
execution_result = self.transport.execute(parse(get_introspection_query()))
self.client.introspection = execution_result.data
self.client.schema = build_client_schema(self.client.introspection)

@property
def transport(self):
return self.client.transport


class ClientSession:
class AsyncClientSession:
"""An instance of this class is created when using 'async with' on the client.
It contains the async methods (execute, subscribe) to send queries
Expand All @@ -203,7 +234,7 @@ class ClientSession:
def __init__(self, client: Client):
self.client = client

async def validate(self, document: DocumentNode):
async def fetch_and_validate(self, document: DocumentNode):
"""Fetch schema from transport if needed and validate document.
If no schema is present, the validation will be skipped.
Expand All @@ -222,7 +253,7 @@ async def subscribe(
) -> AsyncGenerator[Dict, None]:

# Fetch schema from transport if needed and validate document if possible
await self.validate(document)
await self.fetch_and_validate(document)

# Subscribe to the transport and yield data or raise error
self._generator: AsyncGenerator[
Expand All @@ -243,7 +274,7 @@ async def subscribe(
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

# Fetch schema from transport if needed and validate document if possible
await self.validate(document)
await self.fetch_and_validate(document)

# Execute the query with the transport with a timeout
result = await asyncio.wait_for(
Expand Down
70 changes: 49 additions & 21 deletions gql/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@

from gql.transport import Transport

from .exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportServerError,
)


class RequestsHTTPTransport(Transport):
"""Transport to execute GraphQL queries on remote servers.
Expand Down Expand Up @@ -58,23 +65,32 @@ def __init__(
self.use_json = use_json
self.default_timeout = timeout
self.verify = verify
self.retries = retries
self.method = method
self.kwargs = kwargs

# Creating a session that can later be re-use to configure custom mechanisms
self.session = requests.Session()
self.session = None

def connect(self):

if self.session is None:

# If we specified some retries, we provide a predefined retry-logic
if retries > 0:
adapter = HTTPAdapter(
max_retries=Retry(
total=retries,
backoff_factor=0.1,
status_forcelist=[500, 502, 503, 504],
# Creating a session that can later be re-use to configure custom mechanisms
self.session = requests.Session()

# If we specified some retries, we provide a predefined retry-logic
if self.retries > 0:
adapter = HTTPAdapter(
max_retries=Retry(
total=self.retries,
backoff_factor=0.1,
status_forcelist=[500, 502, 503, 504],
)
)
)
for prefix in "http://", "https://":
self.session.mount(prefix, adapter)
for prefix in "http://", "https://":
self.session.mount(prefix, adapter)
else:
raise TransportAlreadyConnected("Transport is already connected")

def execute( # type: ignore
self,
Expand All @@ -94,6 +110,10 @@ def execute( # type: ignore
`data` is the result of executing the query, `errors` is null
if no errors occurred, and is a non-empty array if an error occurred.
"""

if not self.session:
raise TransportClosed("Transport is not connected")

query_str = print_ast(document)
payload = {"query": query_str, "variables": variable_values or {}}

Expand All @@ -116,18 +136,26 @@ def execute( # type: ignore
)
try:
result = response.json()
if not isinstance(result, dict):
raise ValueError
except ValueError:
result = {}
except Exception:
# We raise a TransportServerError if the status code is 400 or higher
# We raise a TransportProtocolError in the other cases

try:
# Raise a requests.HTTPerror if response status is 400 or higher
response.raise_for_status()

except requests.HTTPError as e:
raise TransportServerError(str(e))

raise TransportProtocolError("Server did not return a GraphQL result")

if "errors" not in result and "data" not in result:
response.raise_for_status()
raise requests.HTTPError(
"Server did not return a GraphQL result", response=response
)
raise TransportProtocolError("Server did not return a GraphQL result")

return ExecutionResult(errors=result.get("errors"), data=result.get("data"))

def close(self):
"""Closing the transport by closing the inner session"""
self.session.close()
if self.session:
self.session.close()
self.session = None
5 changes: 5 additions & 0 deletions gql/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
"Any Transport subclass must implement execute method"
) # pragma: no cover

def connect(self):
"""Establish a session with the transport.
"""
pass

def close(self):
"""Close the transport
Expand Down
5 changes: 2 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def test_retries_on_transport(execute_mock):
}
"""
)
with client: # We're using the client as context manager
with client as session: # We're using the client as context manager
with pytest.raises(Exception):
client.execute(query)
session.execute(query)

# This might look strange compared to the previous test, but making 3 retries
# means you're actually doing 4 calls.
Expand Down Expand Up @@ -98,7 +98,6 @@ def test_execute_result_error():

with pytest.raises(Exception) as exc_info:
client.execute(failing_query)
client.close()
assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value)


Expand Down

0 comments on commit 3f7e0b0

Please sign in to comment.