Skip to content

Commit

Permalink
add tests, fix testing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
drice committed Aug 21, 2019
1 parent 7c67db0 commit ed8ff5b
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 6 deletions.
62 changes: 61 additions & 1 deletion ariadne/contrib/tracing/apollotracing.py
Expand Up @@ -4,7 +4,7 @@

from graphql import GraphQLResolveInfo

from ...types import ContextValue, Extension, Resolver
from ...types import ContextValue, Extension, ExtensionSync, Resolver
from .utils import format_path, should_trace

try:
Expand Down Expand Up @@ -86,3 +86,63 @@ def format(self):
"execution": {"resolvers": totals["resolvers"]},
}
}


class ApolloTracingExtensionSync(ExtensionSync):
def __init__(self):
self.start_date = None
self.start_timestamp = None
self.resolvers = []

self._totals = None

def request_started(self, context: ContextValue):
self.start_date = datetime.datetime.utcnow()
self.start_timestamp = perf_counter_ns()

def resolve(self, next_: Resolver, parent: Any, info: GraphQLResolveInfo, **kwargs):
if not should_trace(info):
result = next_(parent, info, **kwargs)
return result

start_timestamp = perf_counter_ns()
record = {
"path": format_path(info.path),
"parentType": str(info.parent_type),
"fieldName": info.field_name,
"returnType": str(info.return_type),
"startOffset": start_timestamp - self.start_timestamp,
}
self.resolvers.append(record)
try:
result = next_(parent, info, **kwargs)
return result
finally:
end_timestamp = perf_counter_ns()
record["duration"] = end_timestamp - start_timestamp

def get_totals(self):
if self._totals is None:
self._totals = self._get_totals()
return self._totals

def _get_totals(self):
return {
"start": self.start_date,
"end": datetime.datetime.utcnow(),
"duration": perf_counter_ns() - self.start_timestamp,
"resolvers": self.resolvers,
}

def format(self):
totals = self.get_totals()

return {
"tracing": {
"version": 1,
"startTime": totals["start"].strftime(TIMESTAMP_FORMAT),
"endTime": totals["end"].strftime(TIMESTAMP_FORMAT),
"duration": totals["duration"],
"execution": {"resolvers": totals["resolvers"]},
}
}
58 changes: 56 additions & 2 deletions ariadne/contrib/tracing/opentracing.py
Expand Up @@ -7,10 +7,9 @@
from opentracing import Scope, Tracer, global_tracer
from opentracing.ext import tags

from ...types import ContextValue, Extension, Resolver
from ...types import ContextValue, Extension, ExtensionSync, Resolver
from .utils import format_path, should_trace


ArgFilter = Callable[[Dict[str, Any], GraphQLResolveInfo], Dict[str, Any]]


Expand Down Expand Up @@ -71,5 +70,60 @@ def filter_resolver_args(
return self._arg_filter(deepcopy(args), info)


class OpenTracingExtensionSync(ExtensionSync):
_arg_filter: Optional[ArgFilter]
_root_scope: Scope
_tracer: Tracer

def __init__(self, *, arg_filter: Optional[ArgFilter] = None):
self._arg_filter = arg_filter
self._tracer = global_tracer()
self._root_scope = None

def request_started(self, context: ContextValue):
self._root_scope = self._tracer.start_active_span("GraphQL Query")
self._root_scope.span.set_tag(tags.COMPONENT, "graphql")

def request_finished(
self, context: ContextValue, error: Optional[Exception] = None
):
self._root_scope.close()

def resolve(self, next_: Resolver, parent: Any, info: GraphQLResolveInfo, **kwargs):
if not should_trace(info):
result = next_(parent, info, **kwargs)
return result

with self._tracer.start_active_span(info.field_name) as scope:
span = scope.span
span.set_tag(tags.COMPONENT, "graphql")
span.set_tag("graphql.parentType", info.parent_type.name)

graphql_path = ".".join(
map(str, format_path(info.path)) # pylint: disable=bad-builtin
)
span.set_tag("graphql.path", graphql_path)

if kwargs:
filtered_kwargs = self.filter_resolver_args(kwargs, info)
for kwarg, value in filtered_kwargs.items():
span.set_tag(f"graphql.param.{kwarg}", value)

result = next_(parent, info, **kwargs)
return result

def filter_resolver_args(
self, args: Dict[str, Any], info: GraphQLResolveInfo
) -> Dict[str, Any]:
if not self._arg_filter:
return args

return self._arg_filter(deepcopy(args), info)


def opentracing_extension(*, arg_filter: Optional[ArgFilter] = None):
return partial(OpenTracingExtension, arg_filter=arg_filter)


def opentracing_extension_sync(*, arg_filter: Optional[ArgFilter] = None):
return partial(OpenTracingExtensionSync, arg_filter=arg_filter)
3 changes: 2 additions & 1 deletion ariadne/graphql.py
Expand Up @@ -23,6 +23,7 @@
from .types import (
ErrorFormatter,
Extension,
ExtensionSync,
GraphQLResult,
RootValue,
SubscriptionResult,
Expand Down Expand Up @@ -114,7 +115,7 @@ def graphql_sync(
validation_rules: Optional[Sequence[RuleType]] = None,
error_formatter: ErrorFormatter = format_error,
middleware: Optional[MiddlewareManager] = None,
extensions: Optional[List[Type[Extension]]] = None,
extensions: Optional[List[Type[ExtensionSync]]] = None,
**kwargs,
) -> GraphQLResult:
extension_manager = ExtensionManager(extensions)
Expand Down
20 changes: 18 additions & 2 deletions ariadne/wsgi.py
@@ -1,6 +1,6 @@
import json
from cgi import FieldStorage
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Type, Union

from graphql import GraphQLError, GraphQLSchema
from graphql.execution import Middleware, MiddlewareManager
Expand All @@ -19,8 +19,12 @@
from .file_uploads import combine_multipart_data
from .format_error import format_error
from .graphql import graphql_sync
from .types import ContextValue, ErrorFormatter, GraphQLResult, RootValue
from .types import ContextValue, ErrorFormatter, ExtensionSync, GraphQLResult, RootValue

ExtensionList = Optional[List[Type[ExtensionSync]]]
Extensions = Union[
Callable[[Any, Optional[ContextValue]], ExtensionList], ExtensionList
]
MiddlewareList = Optional[List[Middleware]]
Middlewares = Union[
Callable[[Any, Optional[ContextValue]], MiddlewareList], MiddlewareList
Expand All @@ -37,13 +41,15 @@ def __init__(
debug: bool = False,
logger: Optional[str] = None,
error_formatter: ErrorFormatter = format_error,
extensions: Optional[Extensions] = None,
middleware: Optional[Middlewares] = None,
) -> None:
self.context_value = context_value
self.root_value = root_value
self.debug = debug
self.logger = logger
self.error_formatter = error_formatter
self.extensions = extensions
self.middleware = middleware
self.schema = schema

Expand Down Expand Up @@ -155,6 +161,7 @@ def extract_data_from_multipart_request(self, environ: dict) -> Any:

def execute_query(self, environ: dict, data: dict) -> GraphQLResult:
context_value = self.get_context_for_request(environ)
extensions = self.get_extensions_for_request(environ, context_value)
middleware = self.get_middleware_for_request(environ, context_value)

return graphql_sync(
Expand All @@ -165,6 +172,7 @@ def execute_query(self, environ: dict, data: dict) -> GraphQLResult:
debug=self.debug,
logger=self.logger,
error_formatter=self.error_formatter,
extensions=extensions,
middleware=middleware,
)

Expand All @@ -173,6 +181,14 @@ def get_context_for_request(self, environ: dict) -> Optional[ContextValue]:
return self.context_value(environ)
return self.context_value or environ

def get_extensions_for_request(
self, environ: dict, context: Optional[ContextValue]
) -> ExtensionList:
if callable(self.extensions):
extensions = self.extensions(environ, context)
return extensions
return self.extensions

def get_middleware_for_request(
self, environ: dict, context: Optional[ContextValue]
) -> Optional[MiddlewareManager]:
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.in
Expand Up @@ -13,5 +13,6 @@ pytest-mock
python-dateutil
snapshottest
requests
werkzeug

pip-tools
1 change: 1 addition & 0 deletions requirements-dev.txt
Expand Up @@ -50,5 +50,6 @@ typing-extensions==3.7.4 # via mypy
urllib3==1.25.3 # via requests
wasmer==0.3.0 # via fastdiff
wcwidth==0.1.7 # via pytest
werkzeug==0.15.5
wrapt==1.11.2 # via astroid
zipp==0.5.2 # via importlib-metadata
55 changes: 55 additions & 0 deletions tests/test_extensions_sync.py
@@ -0,0 +1,55 @@
from unittest.mock import Mock

from ariadne import ExtensionManager
from ariadne.types import ExtensionSync as Extension


context: dict = {}
exception = ValueError()
query = "{ test }"


def test_request_started_event_is_called_by_extension_manager():
extension = Mock(spec=Extension)
manager = ExtensionManager([Mock(return_value=extension)])
with manager.request(context):
pass

extension.request_started.assert_called_once_with(context)


def test_request_finished_event_is_called_by_extension_manager():
extension = Mock(spec=Extension)
manager = ExtensionManager([Mock(return_value=extension)])
with manager.request(context):
pass

extension.request_finished.assert_called_once_with(context)


def test_request_finished_event_is_called_with_error():
extension = Mock(spec=Extension)
manager = ExtensionManager([Mock(return_value=extension)])
try:
with manager.request(context):
raise exception
except: # pylint: disable=bare-except
pass

extension.request_finished.assert_called_once_with(context, exception)


def test_has_errors_event_is_called_with_errors_list():
extension = Mock(spec=Extension)
manager = ExtensionManager([Mock(return_value=extension)])
manager.has_errors([exception])
extension.has_errors.assert_called_once_with([exception])


def test_extensions_are_formatted():
extensions = [
Mock(spec=Extension, format=Mock(return_value={"a": 1})),
Mock(spec=Extension, format=Mock(return_value={"b": 2})),
]
manager = ExtensionManager([Mock(return_value=ext) for ext in extensions])
assert manager.format() == {"a": 1, "b": 2}
54 changes: 54 additions & 0 deletions tests/tracing/test_apollotracing_sync.py
@@ -0,0 +1,54 @@
import pytest
from freezegun import freeze_time
from graphql import get_introspection_query

from ariadne import graphql_sync as graphql
from ariadne.contrib.tracing.apollotracing import (
ApolloTracingExtensionSync as ApolloTracingExtension,
)


@pytest.mark.asyncio
async def test_apollotracing_extension_causes_no_errors_in_query_execution(schema):
_, result = graphql(
schema, {"query": "{ status }"}, extensions=[ApolloTracingExtension]
)
assert result["data"] == {"status": True}


@pytest.fixture
def freeze_microtime(mocker):
mocker.patch(
"ariadne.contrib.tracing.apollotracing.perf_counter_ns", return_value=2
)


@freeze_time("2012-01-14 03:21:34")
@pytest.mark.asyncio
async def test_apollotracing_extension_adds_tracing_data_to_result_extensions(
schema, freeze_microtime, snapshot # pylint: disable=unused-argument
):
_, result = graphql(
schema, {"query": "{ status }"}, extensions=[ApolloTracingExtension]
)
snapshot.assert_match(result)


@freeze_time("2012-01-14 03:21:34")
@pytest.mark.asyncio
async def test_apollotracing_extension_handles_exceptions_in_resolvers(
schema, freeze_microtime, snapshot # pylint: disable=unused-argument
):
_, result = graphql(
schema, {"query": "{ testError }"}, extensions=[ApolloTracingExtension]
)
snapshot.assert_match(result["data"])


@pytest.mark.asyncio
async def test_apollotracing_extension_doesnt_break_introspection(schema):
introspection_query = get_introspection_query(descriptions=True)
_, result = graphql(
schema, {"query": introspection_query}, extensions=[ApolloTracingExtension]
)
assert "errors" not in result

0 comments on commit ed8ff5b

Please sign in to comment.