Skip to content

Commit

Permalink
Check macro argument types
Browse files Browse the repository at this point in the history
  • Loading branch information
Bojan Serafimov committed Feb 27, 2019
1 parent 281580a commit 7b2280e
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 9 deletions.
2 changes: 1 addition & 1 deletion graphql_compiler/macros/macro_edge/validation.py
Expand Up @@ -234,7 +234,7 @@ def get_and_validate_macro_edge_info(schema, ast, macro_edge_args,
# Check that the macro successfully compiles to IR
_, input_metadata, _, _ = ast_to_ir(schema, _get_minimal_query_ast_from_macro_ast(ast),
type_equivalence_hints=type_equivalence_hints)
ensure_arguments_are_provided(input_metadata, macro_edge_args)
ensure_arguments_are_provided(input_metadata, macro_edge_args, check_types=True)
# TODO(bojanserafimov): Check all the provided arguments were necessary
# TODO(bojanserafimov): Check the arguments have the correct types
# TODO(bojanserafimov): @macro_edge_target is not on a union type
Expand Down
69 changes: 66 additions & 3 deletions graphql_compiler/query_formatting/common.py
@@ -1,9 +1,15 @@
# Copyright 2017-present Kensho Technologies, LLC.
"""Safely insert runtime arguments into compiled GraphQL queries."""
import datetime
import decimal

import arrow
from graphql import GraphQLBoolean, GraphQLFloat, GraphQLID, GraphQLInt, GraphQLList, GraphQLString
import six

from ..compiler import GREMLIN_LANGUAGE, MATCH_LANGUAGE, SQL_LANGUAGE
from ..exceptions import GraphQLInvalidArgumentError
from ..schema import GraphQLDate, GraphQLDateTime, GraphQLDecimal
from .gremlin_formatting import insert_arguments_into_gremlin_query
from .match_formatting import insert_arguments_into_match_query
from .sql_formatting import insert_arguments_into_sql_query
Expand All @@ -13,10 +19,64 @@
# Public API
######

def ensure_arguments_are_provided(expected_types, arguments):
def _check_is_string_value(value):
"""Raise if the value is not a proper utf-8 string."""
if not isinstance(value, six.string_types):
if isinstance(value, bytes): # should only happen in py3
value.decode('utf-8') # decoding should not raise errors
else:
raise GraphQLInvalidArgumentError(u'Attempting to convert a non-string into a string: '
u'{}'.format(value))


# TODO(bojanserafimov): test this function
def _validate_argument_type(expected_type, value):
"""Check if the value is appropriate for the type and usable in any of our backends."""
if GraphQLString.is_same_type(expected_type):
_check_is_string_value(value)
elif GraphQLID.is_same_type(expected_type):
# IDs can be strings or numbers, but the GraphQL library coerces them to strings.
# We will follow suit and treat them as strings.
_check_is_string_value(value)
elif GraphQLFloat.is_same_type(expected_type):
if not isinstance(value, float):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-float as a float: '
u'{}'.format(value))
elif GraphQLInt.is_same_type(expected_type):
# Special case: in Python, isinstance(True, int) returns True.
# Safeguard against this with an explicit check against bool type.
if isinstance(value, bool) or not isinstance(value, six.integer_types):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-int as an int: '
u'{}'.format(value))
elif GraphQLBoolean.is_same_type(expected_type):
if not isinstance(value, bool):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-bool as a bool: '
u'{}'.format(value))
elif GraphQLDecimal.is_same_type(expected_type):
if not isinstance(value, decimal.Decimal):
try:
decimal.Decimal(value)
except decimal.InvalidOperation as e:
raise GraphQLInvalidArgumentError(e)
elif GraphQLDate.is_same_type(expected_type):
if not isinstance(value, datetime.date):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-date as a date: '
u'{}'.format(value))
elif GraphQLDateTime.is_same_type(expected_type):
if not isinstance(value, (datetime.date, arrow.Arrow)):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-date as a date: '
u'{}'.format(value))
elif isinstance(expected_type, GraphQLList):
if not isinstance(value, list):
raise GraphQLInvalidArgumentError(u'Attempting to represent a non-list as a list: '
u'{}'.format(value))
else:
raise AssertionError(u'Could not safely represent the requested GraphQL type: '
u'{} {}'.format(expected_type, value))


def ensure_arguments_are_provided(expected_types, arguments, check_types=False):
"""Ensure that all arguments expected by the query were actually provided."""
# This function only checks that the arguments were specified,
# and does not check types. Type checking is done as part of the actual formatting step.
expected_arg_names = set(six.iterkeys(expected_types))
provided_arg_names = set(six.iterkeys(arguments))

Expand All @@ -26,6 +86,9 @@ def ensure_arguments_are_provided(expected_types, arguments):
raise GraphQLInvalidArgumentError(u'Missing or unexpected arguments found: '
u'missing {}, unexpected '
u'{}'.format(missing_args, unexpected_args))
if check_types:
for name in expected_arg_names:
_validate_argument_type(expected_types[name], arguments[name])


def insert_arguments_into_query(compilation_result, arguments):
Expand Down
32 changes: 27 additions & 5 deletions graphql_compiler/tests/test_macro_validation.py
Expand Up @@ -3,7 +3,7 @@

import pytest

from ..exceptions import GraphQLInvalidMacroError
from ..exceptions import GraphQLInvalidArgumentError, GraphQLInvalidMacroError
from ..macros import create_macro_registry, register_macro_edge
from .test_helpers import get_schema

Expand Down Expand Up @@ -243,7 +243,7 @@ def test_macro_edge_invalid_no_op_2(self):

def test_macro_edge_missing_args(self):
query = '''{
Animal @macro_edge_definition {
Animal @macro_edge_definition(name: "out_Animal_GrandparentOf") {
net_worth @filter(op_name: "=", value: ["$net_worth"])
color @filter(op_name: "=", value: ["$color"])
out_Animal_ParentOf {
Expand All @@ -258,13 +258,13 @@ def test_macro_edge_missing_args(self):
}

macro_registry = create_macro_registry()
with self.assertRaises(GraphQLInvalidMacroError):
with self.assertRaises(GraphQLInvalidArgumentError):
register_macro_edge(macro_registry, self.schema, query,
args, self.type_equivalence_hints)

def test_macro_edge_extra_args(self):
query = '''{
Animal @macro_edge_definition {
Animal @macro_edge_definition(name: "out_Animal_GrandparentOf") {
net_worth @filter(op_name: "=", value: ["$net_worth"])
color @filter(op_name: "=", value: ["$color"])
out_Animal_ParentOf {
Expand All @@ -281,7 +281,29 @@ def test_macro_edge_extra_args(self):
}

macro_registry = create_macro_registry()
with self.assertRaises(GraphQLInvalidMacroError):
with self.assertRaises(GraphQLInvalidArgumentError):
register_macro_edge(macro_registry, self.schema, query,
args, self.type_equivalence_hints)

def test_macro_edge_incorrect_arg_types(self):
query = '''{
Animal @macro_edge_definition(name: "out_Animal_GrandparentOf") {
net_worth @filter(op_name: "=", value: ["$net_worth"])
color @filter(op_name: "=", value: ["$color"])
out_Animal_ParentOf {
out_Animal_ParentOf @macro_edge_target {
uuid
}
}
}
}'''
args = {
'net_worth': 'five_cows',
'color': 'green',
}

macro_registry = create_macro_registry()
with self.assertRaises(GraphQLInvalidArgumentError):
register_macro_edge(macro_registry, self.schema, query,
args, self.type_equivalence_hints)

Expand Down

0 comments on commit 7b2280e

Please sign in to comment.