Skip to content

Commit

Permalink
Avoid polluting the context argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Bojan Serafimov committed Feb 4, 2019
1 parent 6302e02 commit 3dce8ff
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 31 deletions.
48 changes: 23 additions & 25 deletions graphql_compiler/compiler/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ def takes_parameters(count):
def decorator(f):
"""Decorate the supplied function with the "takes_parameters" logic."""
@wraps(f)
def wrapper(filter_operation_info, context, parameters, *args, **kwargs):
def wrapper(filter_operation_info, location, context, parameters, *args, **kwargs):
"""Check that the supplied number of parameters equals the expected number."""
if len(parameters) != count:
raise GraphQLCompilationError(u'Incorrect number of parameters, expected {} got '
u'{}: {}'.format(count, len(parameters), parameters))

return f(filter_operation_info, context, parameters, *args, **kwargs)
return f(filter_operation_info, location, context, parameters, *args, **kwargs)

return wrapper

Expand All @@ -89,7 +89,7 @@ def _is_tag_argument(argument_name):
return argument_name.startswith('%')


def _represent_argument(context, argument, inferred_type):
def _represent_argument(directive_location, context, argument, inferred_type):
"""Return a two-element tuple that represents the argument to the directive being processed.
Args:
Expand Down Expand Up @@ -142,9 +142,7 @@ def _represent_argument(context, argument, inferred_type):
u'{} vs {}'.format(tag_inferred_type, inferred_type))

# See if the argument is colocated with the directive
directive_location_without_field = context['directive_location'].remove_field()
argument_location_without_field = argument_context['location'].remove_field()
colocated = directive_location_without_field == argument_location_without_field
colocated = directive_location.at_vertex() == location.at_vertex()

non_existence_expression = None
if optional:
Expand All @@ -169,7 +167,8 @@ def _represent_argument(context, argument, inferred_type):

@scalar_leaf_only(u'comparison operator')
@takes_parameters(1)
def _process_comparison_filter_directive(filter_operation_info, context, parameters, operator=None):
def _process_comparison_filter_directive(filter_operation_info, location,
context, parameters, operator=None):
"""Return a Filter basic block that performs the given comparison against the property field.
Args:
Expand All @@ -196,7 +195,7 @@ def _process_comparison_filter_directive(filter_operation_info, context, paramet

argument_inferred_type = strip_non_null_from_type(filtered_field_type)
argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

comparison_expression = expressions.BinaryComposition(
operator, expressions.LocalField(filtered_field_name), argument_expression)
Expand All @@ -215,7 +214,7 @@ def _process_comparison_filter_directive(filter_operation_info, context, paramet

@vertex_field_only(u'has_edge_degree')
@takes_parameters(1)
def _process_has_edge_degree_filter_directive(filter_operation_info, context, parameters):
def _process_has_edge_degree_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks the degree of the edge to the given vertex field.
Args:
Expand Down Expand Up @@ -251,7 +250,7 @@ def _process_has_edge_degree_filter_directive(filter_operation_info, context, pa

argument_inferred_type = GraphQLInt
argument_expression, non_existence_expression = _represent_argument(
context, argument, argument_inferred_type)
location, context, argument, argument_inferred_type)

if non_existence_expression is not None:
raise AssertionError(u'Since we do not support tagged values, non_existence_expression '
Expand Down Expand Up @@ -289,7 +288,7 @@ def _process_has_edge_degree_filter_directive(filter_operation_info, context, pa

@vertex_field_only(u'name_or_alias')
@takes_parameters(1)
def _process_name_or_alias_filter_directive(filter_operation_info, context, parameters):
def _process_name_or_alias_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks for a match against an Entity's name or alias.
Args:
Expand Down Expand Up @@ -335,7 +334,7 @@ def _process_name_or_alias_filter_directive(filter_operation_info, context, para

argument_inferred_type = name_field_type
argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

check_against_name = expressions.BinaryComposition(
u'=', expressions.LocalField('name'), argument_expression)
Expand All @@ -355,7 +354,7 @@ def _process_name_or_alias_filter_directive(filter_operation_info, context, para

@scalar_leaf_only(u'between')
@takes_parameters(2)
def _process_between_filter_directive(filter_operation_info, context, parameters):
def _process_between_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks that a field is between two values, inclusive.
Args:
Expand All @@ -375,9 +374,9 @@ def _process_between_filter_directive(filter_operation_info, context, parameters

argument_inferred_type = strip_non_null_from_type(filtered_field_type)
arg1_expression, arg1_non_existence = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)
arg2_expression, arg2_non_existence = _represent_argument(
context, parameters[1], argument_inferred_type)
location, context, parameters[1], argument_inferred_type)

lower_bound_clause = expressions.BinaryComposition(
u'>=', expressions.LocalField(filtered_field_name), arg1_expression)
Expand All @@ -400,7 +399,7 @@ def _process_between_filter_directive(filter_operation_info, context, parameters

@scalar_leaf_only(u'in_collection')
@takes_parameters(1)
def _process_in_collection_filter_directive(filter_operation_info, context, parameters):
def _process_in_collection_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks for a value's existence in a collection.
Args:
Expand All @@ -419,7 +418,7 @@ def _process_in_collection_filter_directive(filter_operation_info, context, para

argument_inferred_type = GraphQLList(strip_non_null_from_type(filtered_field_type))
argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

filter_predicate = expressions.BinaryComposition(
u'contains', argument_expression, expressions.LocalField(filtered_field_name))
Expand All @@ -434,7 +433,7 @@ def _process_in_collection_filter_directive(filter_operation_info, context, para

@scalar_leaf_only(u'has_substring')
@takes_parameters(1)
def _process_has_substring_filter_directive(filter_operation_info, context, parameters):
def _process_has_substring_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks if the directive arg is a substring of the field.
Args:
Expand All @@ -457,7 +456,7 @@ def _process_has_substring_filter_directive(filter_operation_info, context, para
argument_inferred_type = GraphQLString

argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

filter_predicate = expressions.BinaryComposition(
u'has_substring', expressions.LocalField(filtered_field_name), argument_expression)
Expand All @@ -471,7 +470,7 @@ def _process_has_substring_filter_directive(filter_operation_info, context, para


@takes_parameters(1)
def _process_contains_filter_directive(filter_operation_info, context, parameters):
def _process_contains_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks if the directive arg is contained in the field.
Args:
Expand All @@ -495,7 +494,7 @@ def _process_contains_filter_directive(filter_operation_info, context, parameter

argument_inferred_type = strip_non_null_from_type(base_field_type.of_type)
argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

filter_predicate = expressions.BinaryComposition(
u'contains', expressions.LocalField(filtered_field_name), argument_expression)
Expand All @@ -509,7 +508,7 @@ def _process_contains_filter_directive(filter_operation_info, context, parameter


@takes_parameters(1)
def _process_intersects_filter_directive(filter_operation_info, context, parameters):
def _process_intersects_filter_directive(filter_operation_info, location, context, parameters):
"""Return a Filter basic block that checks if the directive arg and the field intersect.
Args:
Expand All @@ -532,7 +531,7 @@ def _process_intersects_filter_directive(filter_operation_info, context, paramet
u'type {}'.format(filtered_field_type))

argument_expression, non_existence_expression = _represent_argument(
context, parameters[0], argument_inferred_type)
location, context, parameters[0], argument_inferred_type)

filter_predicate = expressions.BinaryComposition(
u'intersects', expressions.LocalField(filtered_field_name), argument_expression)
Expand Down Expand Up @@ -662,5 +661,4 @@ def process_filter_directive(filter_operation_info, location, context):
FilterInfo(fields=fields, op_name=op_name, args=tuple(operator_params))
)

context['directive_location'] = location
return process_func(filter_operation_info, context, operator_params)
return process_func(filter_operation_info, location, context, operator_params)
11 changes: 5 additions & 6 deletions graphql_compiler/compiler/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,6 @@ def __init__(self, query_path, field=None, visit_counter=1):
# visit 'Y' in two different ways to generate colliding 'X__Y___1' identifiers.
self.visit_counter = visit_counter

def remove_field(self):
"""Return a new location object with field set to none."""
return Location(self.query_path, field=None, visit_counter=self.visit_counter)

def navigate_to_field(self, field):
"""Return a new Location object at the specified field of the current Location's vertex."""
if self.field:
Expand Down Expand Up @@ -502,8 +498,11 @@ def get_first_folded_edge(self):
first_folded_edge_direction, first_folded_edge_name = self.fold_path[0]
return first_folded_edge_direction, first_folded_edge_name

def remove_field(self):
"""Return a new location object with field set to none."""
def at_vertex(self):
"""Get the Location ignoring its field component."""
if not self.field:
return self

return FoldScopeLocation(self.base_location, self.fold_path, field=None)

def navigate_to_field(self, field):
Expand Down

0 comments on commit 3dce8ff

Please sign in to comment.