diff --git a/graphql_compiler/compiler/blocks.py b/graphql_compiler/compiler/blocks.py index 293efa2d0..2f6bc1d76 100644 --- a/graphql_compiler/compiler/blocks.py +++ b/graphql_compiler/compiler/blocks.py @@ -444,3 +444,11 @@ class EndOptional(MarkerBlock): def validate(self): """In isolation, EndOptional blocks are always valid.""" pass + + +class GlobalOperationsStart(MarkerBlock): + """Marker block for the end of MATCH traversals, and the beginning of global operations.""" + + def validate(self): + """In isolation, GlobalOperationsStart blocks are always valid.""" + pass diff --git a/graphql_compiler/compiler/emit_match.py b/graphql_compiler/compiler/emit_match.py index 1c2e3b65f..0285bbf3a 100644 --- a/graphql_compiler/compiler/emit_match.py +++ b/graphql_compiler/compiler/emit_match.py @@ -5,6 +5,7 @@ import six from .blocks import Filter, QueryRoot, Recurse, Traverse +from .expressions import TrueLiteral from .helpers import validate_safe_string @@ -156,6 +157,14 @@ def _construct_output_to_match(output_block): return u'SELECT %s FROM' % (u', '.join(selections),) +def _construct_where_to_match(where_block): + """Transform a Filter block into a MATCH query string.""" + if where_block.predicate == TrueLiteral: + raise AssertionError(u'Received WHERE block with TrueLiteral predicate: {}' + .format(where_block)) + return u'WHERE ' + where_block.predicate.to_match() + + ############## # Public API # ############## @@ -197,6 +206,10 @@ def emit_code_from_single_match_query(match_query): # Represent and add the SELECT clauses with the proper output data. query_data.appendleft(_construct_output_to_match(match_query.output_block)) + # Represent and add the WHERE clause with the proper filters. 
+ if match_query.where_block is not None: + query_data.append(_construct_where_to_match(match_query.where_block)) + return u' '.join(query_data) diff --git a/graphql_compiler/compiler/expressions.py b/graphql_compiler/compiler/expressions.py index 09760d9e7..e419e1cc8 100644 --- a/graphql_compiler/compiler/expressions.py +++ b/graphql_compiler/compiler/expressions.py @@ -6,8 +6,8 @@ from ..schema import GraphQLDate, GraphQLDateTime from .compiler_entities import Expression from .helpers import (STANDARD_DATE_FORMAT, STANDARD_DATETIME_FORMAT, FoldScopeLocation, Location, - ensure_unicode_string, is_graphql_type, safe_quoted_string, - strip_non_null_from_type, validate_safe_string) + ensure_unicode_string, is_graphql_type, is_vertex_field_name, + safe_quoted_string, strip_non_null_from_type, validate_safe_string) # Since MATCH uses $-prefixed keywords to indicate special values, @@ -235,6 +235,53 @@ def to_gremlin(self): return u'{}.{}'.format(local_object_name, self.field_name) +class SelectEdgeContextField(Expression): + """An edge field drawn from the global context, for use in a SELECT WHERE statement.""" + + def __init__(self, location): + """Construct a new SelectEdgeContextField object that references an edge field. + + Args: + location: Location, specifying where the field was declared. + The Location object must contain an edge field. 
+ + Returns: + new SelectEdgeContextField object + """ + super(SelectEdgeContextField, self).__init__(location) + self.location = location + self.validate() + + def validate(self): + """Validate that the SelectEdgeContextField is correctly representable.""" + if not isinstance(self.location, Location): + raise TypeError(u'Expected Location location, got: {} {}' + .format(type(self.location).__name__, self.location)) + + if self.location.field is None: + raise AssertionError(u'Received Location without a field: {}' + .format(self.location)) + + if not is_vertex_field_name(self.location.field): + raise AssertionError(u'Received Location with a non-edge field: {}' + .format(self.location)) + + def to_match(self): + """Return a unicode object with the MATCH representation of this SelectEdgeContextField.""" + self.validate() + + mark_name, field_name = self.location.get_location_name() + validate_safe_string(mark_name) + validate_safe_string(field_name) + + return u'%s.%s' % (mark_name, field_name) + + def to_gremlin(self): + """Not implemented, should not be used.""" + raise AssertionError(u'SelectEdgeContextField is only used for the WHERE statement in ' + u'MATCH. This function should not be called.') + + class ContextField(Expression): """A field drawn from the global context, e.g. if selected earlier in the query.""" diff --git a/graphql_compiler/compiler/ir_lowering_common.py b/graphql_compiler/compiler/ir_lowering_common.py index 1627c47bf..b25fe96d9 100644 --- a/graphql_compiler/compiler/ir_lowering_common.py +++ b/graphql_compiler/compiler/ir_lowering_common.py @@ -1,5 +1,7 @@ # Copyright 2017-present Kensho Technologies, LLC. 
"""Language-independent IR lowering and optimization functions.""" +import six + from .blocks import (ConstructResult, EndOptional, Filter, Fold, MarkLocation, Recurse, Traverse, Unfold) from .expressions import (BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral, @@ -265,6 +267,56 @@ def extract_optional_location_root_info(ir_blocks): return complex_optional_roots, location_to_optional_root +def extract_simple_optional_location_info( + ir_blocks, complex_optional_roots, location_to_optional_root): + """Construct a map from simple optional locations to their inner location and traversed edge. + + Args: + ir_blocks: list of IR blocks to extract optional data from + complex_optional_roots: list of @optional locations (location immediately preceding + an @optional traverse) that expand vertex fields + location_to_optional_root: dict mapping from location -> optional_root where location is + within @optional (not necessarily one that expands vertex fields) + and optional_root is the location preceding the corresponding + @optional scope + + Returns: + dict mapping from simple_optional_root_location -> dict containing keys + - 'inner_location_name': string name corresponding to the unique MarkLocation present + within a simple optional (one that does not expand vertex fields) + scope + - 'edge_field': string representing the optional edge being traversed + where simple_optional_root_location is the location preceding the @optional scope + """ + # Simple optional roots are a subset of location_to_optional_root.values() (all optional roots). + # We filter out the ones that are also present in complex_optional_roots. 
+ simple_optional_root_to_inner_location = { + optional_root_location: inner_location + for inner_location, optional_root_location in six.iteritems(location_to_optional_root) + if optional_root_location not in complex_optional_roots + } + simple_optional_root_locations = set(simple_optional_root_to_inner_location.keys()) + + simple_optional_root_info = {} + preceding_location = None + for current_block in ir_blocks: + if isinstance(current_block, MarkLocation): + preceding_location = current_block.location + elif isinstance(current_block, Traverse) and current_block.optional: + if preceding_location in simple_optional_root_locations: + # The current optional Traverse is "simple" + # i.e. it does not contain any Traverses within. + inner_location = simple_optional_root_to_inner_location[preceding_location] + inner_location_name, _ = inner_location.get_location_name() + simple_optional_info_dict = { + 'inner_location_name': inner_location_name, + 'edge_field': current_block.get_field_name(), + } + simple_optional_root_info[preceding_location] = simple_optional_info_dict + + return simple_optional_root_info + + def remove_end_optionals(ir_blocks): """Return a list of IR blocks as a copy of the original, with EndOptional blocks removed.""" new_ir_blocks = [] diff --git a/graphql_compiler/compiler/ir_lowering_match/__init__.py b/graphql_compiler/compiler/ir_lowering_match/__init__.py index 7108192a7..c0e84cb16 100644 --- a/graphql_compiler/compiler/ir_lowering_match/__init__.py +++ b/graphql_compiler/compiler/ir_lowering_match/__init__.py @@ -1,7 +1,9 @@ # Copyright 2018-present Kensho Technologies, LLC. 
import six +from ..blocks import Filter, GlobalOperationsStart from ..ir_lowering_common import (extract_optional_location_root_info, + extract_simple_optional_location_info, lower_context_field_existence, merge_consecutive_filter_clauses, optimize_boolean_expression_comparisons, remove_end_optionals) from .ir_lowering import (lower_backtrack_blocks, @@ -17,12 +19,13 @@ lower_context_field_expressions, prune_non_existent_outputs) from ..match_query import convert_to_match_query from ..workarounds import orientdb_class_with_while, orientdb_eval_scheduling - +from .utils import construct_where_filter_predicate ############## # Public API # ############## + def lower_ir(ir_blocks, location_types, type_equivalence_hints=None): """Lower the IR into an IR form that can be represented in MATCH queries. @@ -50,11 +53,21 @@ def lower_ir(ir_blocks, location_types, type_equivalence_hints=None): """ sanity_check_ir_blocks_from_frontend(ir_blocks) - # These lowering / optimization passes work on IR blocks. + # Extract information for both simple and complex @optional traverses location_to_optional_results = extract_optional_location_root_info(ir_blocks) complex_optional_roots, location_to_optional_root = location_to_optional_results + simple_optional_root_info = extract_simple_optional_location_info( + ir_blocks, complex_optional_roots, location_to_optional_root) ir_blocks = remove_end_optionals(ir_blocks) + # Append global operation block(s) to filter out incorrect results + # from simple optional match traverses (using a WHERE statement) + if len(simple_optional_root_info) > 0: + where_filter_predicate = construct_where_filter_predicate(simple_optional_root_info) + ir_blocks.insert(-1, GlobalOperationsStart()) + ir_blocks.insert(-1, Filter(where_filter_predicate)) + + # These lowering / optimization passes work on IR blocks. 
ir_blocks = lower_context_field_existence(ir_blocks) ir_blocks = optimize_boolean_expression_comparisons(ir_blocks) ir_blocks = rewrite_binary_composition_inside_ternary_conditional(ir_blocks) diff --git a/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py b/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py index 4ed98de81..3954475f8 100644 --- a/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py +++ b/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py @@ -23,6 +23,7 @@ # Optimization / lowering passes # ################################## + def rewrite_binary_composition_inside_ternary_conditional(ir_blocks): """Rewrite BinaryConditional expressions in the true/false values of TernaryConditionals.""" def visitor_fn(expression): diff --git a/graphql_compiler/compiler/ir_lowering_match/optional_traversal.py b/graphql_compiler/compiler/ir_lowering_match/optional_traversal.py index 715725e3c..3d4f22ab2 100644 --- a/graphql_compiler/compiler/ir_lowering_match/optional_traversal.py +++ b/graphql_compiler/compiler/ir_lowering_match/optional_traversal.py @@ -6,26 +6,11 @@ from ..blocks import ConstructResult, Filter, Traverse from ..expressions import (BinaryComposition, ContextField, FoldedOutputContextField, Literal, - LocalField, NullLiteral, OutputContextField, TernaryConditional, - TrueLiteral, UnaryTransformation, Variable, ZeroLiteral) + LocalField, OutputContextField, TernaryConditional, TrueLiteral, + UnaryTransformation, Variable) from ..match_query import MatchQuery, MatchStep -from .utils import BetweenClause, CompoundMatchQuery - - -def _filter_local_edge_field_non_existence(field_name): - """Return an Expression that is True iff the specified edge (field_name) does not exist.""" - # When an edge does not exist at a given vertex, OrientDB represents that in one of two ways: - # - the edge's field does not exist (is null) on the vertex document, or - # - the edge's field does exist, but is an empty list. 
- # We check both of these possibilities. - local_field = LocalField(field_name) - - field_null_check = BinaryComposition(u'=', local_field, NullLiteral) - - local_field_size = UnaryTransformation(u'size', local_field) - field_size_check = BinaryComposition(u'=', local_field_size, ZeroLiteral) - - return BinaryComposition(u'||', field_null_check, field_size_check) +from .utils import (BetweenClause, CompoundMatchQuery, expression_list_to_conjunction, + filter_edge_field_non_existence) def _prune_traverse_using_omitted_locations(match_traversal, omitted_locations, @@ -63,7 +48,7 @@ def _prune_traverse_using_omitted_locations(match_traversal, omitted_locations, elif optional_root_location in omitted_locations: # Add filter to indicate that the omitted edge(s) shoud not exist field_name = step.root_block.get_field_name() - new_predicate = _filter_local_edge_field_non_existence(field_name) + new_predicate = filter_edge_field_non_existence(LocalField(field_name)) old_filter = new_match_traversal[-1].where_block if old_filter is not None: new_predicate = BinaryComposition(u'&&', old_filter.predicate, new_predicate) @@ -146,6 +131,7 @@ def convert_optional_traversals_to_compound_match_query( match_traversals=match_traversals, folds=match_query.folds, output_block=match_query.output_block, + where_block=match_query.where_block, ) for match_traversals in compound_match_traversals ] @@ -216,7 +202,6 @@ def prune_non_existent_outputs(compound_match_query): for match_query in compound_match_query.match_queries: match_traversals = match_query.match_traversals output_block = match_query.output_block - folds = match_query.folds present_locations_tuple = _get_present_locations(match_traversals) present_locations, present_non_optional_locations = present_locations_tuple @@ -260,8 +245,9 @@ def prune_non_existent_outputs(compound_match_query): match_queries.append( MatchQuery( match_traversals=match_traversals, - folds=folds, - output_block=ConstructResult(new_output_fields) + 
folds=match_query.folds, + output_block=ConstructResult(new_output_fields), + where_block=match_query.where_block, ) ) @@ -296,18 +282,12 @@ def _construct_location_to_filter_list(match_query): def _filter_list_to_conjunction_expression(filter_list): """Convert a list of filters to an Expression that is the conjunction of all of them.""" if not isinstance(filter_list, list): - raise AssertionError(u'Expected `list`, Received {}.'.format(filter_list)) + raise AssertionError(u'Expected `list`, Received: {}.'.format(filter_list)) + if any((not isinstance(filter_block, Filter) for filter_block in filter_list)): + raise AssertionError(u'Expected list of Filter objects. Received: {}'.format(filter_list)) - if not isinstance(filter_list[0], Filter): - raise AssertionError(u'Non-Filter object {} found in filter_list' - .format(filter_list[0])) - - if len(filter_list) == 1: - return filter_list[0].predicate - else: - return BinaryComposition(u'&&', - _filter_list_to_conjunction_expression(filter_list[1:]), - filter_list[0].predicate) + expression_list = [filter_block.predicate for filter_block in filter_list] + return expression_list_to_conjunction(expression_list) def _apply_filters_to_first_location_occurrence(match_traversal, location_to_filters, @@ -402,7 +382,8 @@ def collect_filters_to_first_location_occurrence(compound_match_query): MatchQuery( match_traversals=new_match_traversals, folds=match_query.folds, - output_block=match_query.output_block + output_block=match_query.output_block, + where_block=match_query.where_block, ) ) @@ -572,7 +553,8 @@ def lower_context_field_expressions(compound_match_query): MatchQuery( match_traversals=new_match_traversals, folds=match_query.folds, - output_block=match_query.output_block + output_block=match_query.output_block, + where_block=match_query.where_block, ) ) diff --git a/graphql_compiler/compiler/ir_lowering_match/utils.py b/graphql_compiler/compiler/ir_lowering_match/utils.py index 2be7bc18f..b7024f51a 100644 --- 
a/graphql_compiler/compiler/ir_lowering_match/utils.py +++ b/graphql_compiler/compiler/ir_lowering_match/utils.py @@ -1,7 +1,30 @@ # Copyright 2018-present Kensho Technologies, LLC. from collections import namedtuple -from ..expressions import Expression, LocalField +import six + +from ..expressions import (BinaryComposition, Expression, LocalField, NullLiteral, + SelectEdgeContextField, TrueLiteral, UnaryTransformation, ZeroLiteral) +from ..helpers import Location, is_vertex_field_name + + +def expression_list_to_conjunction(expression_list): + """Convert a list of expressions to an Expression that is the conjunction of all of them.""" + if not isinstance(expression_list, list): + raise AssertionError(u'Expected `list`, Received {}.'.format(expression_list)) + + if len(expression_list) == 0: + return TrueLiteral + + if not isinstance(expression_list[0], Expression): + raise AssertionError(u'Non-Expression object {} found in expression_list' + .format(expression_list[0])) + if len(expression_list) == 1: + return expression_list[0] + else: + return BinaryComposition(u'&&', + expression_list_to_conjunction(expression_list[1:]), + expression_list[0]) class BetweenClause(Expression): @@ -22,6 +45,7 @@ def __init__(self, field, lower_bound, upper_bound): self.field = field self.lower_bound = lower_bound self.upper_bound = upper_bound + self.validate() def validate(self): """Validate that the Between Expression is correctly representable.""" @@ -60,6 +84,113 @@ def to_gremlin(self): raise NotImplementedError() +def filter_edge_field_non_existence(edge_expression): + """Return an Expression that is True iff the specified edge (edge_expression) does not exist.""" + # When an edge does not exist at a given vertex, OrientDB represents that in one of two ways: + # - the edge's field does not exist (is null) on the vertex document, or + # - the edge's field does exist, but is an empty list. + # We check both of these possibilities. 
+ if not isinstance(edge_expression, (LocalField, SelectEdgeContextField)): + raise AssertionError(u'Received invalid edge_expression {} of type {}.' + u'Expected LocalField or SelectEdgeContextField.' + .format(edge_expression, type(edge_expression).__name__)) + if isinstance(edge_expression, LocalField): + if not is_vertex_field_name(edge_expression.field_name): + raise AssertionError(u'Received LocalField edge_expression {} with non-edge field_name ' + u'{}.'.format(edge_expression, edge_expression.field_name)) + + field_null_check = BinaryComposition(u'=', edge_expression, NullLiteral) + + local_field_size = UnaryTransformation(u'size', edge_expression) + field_size_check = BinaryComposition(u'=', local_field_size, ZeroLiteral) + + return BinaryComposition(u'||', field_null_check, field_size_check) + + +def _filter_orientdb_simple_optional_edge(optional_edge_location, inner_location_name): + """Return an Expression that is False for rows that don't follow the @optional specification. + + OrientDB does not filter correctly within optionals. Namely, a result where the optional edge + DOES EXIST will be returned regardless of whether the inner filter is satisfied. + To mitigate this, we add a final filter to reject such results. + A valid result must satisfy either of the following: + - The location within the optional exists (the filter will have been applied in this case) + - The optional edge field does not exist at the root location of the optional traverse + So, if the inner location within the optional was never visited, it must be the case that + the corresponding edge field does not exist at all. 
+ + Example: + A MATCH traversal which starts at location `Animal___1`, and follows the optional edge + `out_Animal_ParentOf` to the location `Animal__out_Animal_ParentOf___1` + results in the following filtering Expression: + ( + ( + (Animal___1.out_Animal_ParentOf IS null) + OR + (Animal___1.out_Animal_ParentOf.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf___1 IS NOT null) + ) + Here, the `optional_edge_location` is `Animal___1.out_Animal_ParentOf`. + + Args: + optional_edge_location: Location object representing the optional edge field + inner_location_name: string representing location within the corresponding optional traverse + + Returns: + Expression that evaluates to False for rows that do not follow the @optional specification + """ + inner_local_field = LocalField(inner_location_name) + inner_location_existence = BinaryComposition(u'!=', inner_local_field, NullLiteral) + + select_edge_context_field = SelectEdgeContextField(optional_edge_location) + edge_field_non_existence = filter_edge_field_non_existence(select_edge_context_field) + + return BinaryComposition(u'||', edge_field_non_existence, inner_location_existence) + + +def construct_where_filter_predicate(simple_optional_root_info): + """Return an Expression that is True if and only if each simple optional filter is True. + + Construct filters for each simple optional, that are True if and only if `edge_field` does + not exist in the `simple_optional_root_location` OR the `inner_location` is defined. + Return an Expression that evaluates to True if and only if *all* of the aforementioned filters + evaluate to True (conjunction). 
+ + Args: + simple_optional_root_info: dict mapping from simple_optional_root_location -> dict + containing keys + - 'inner_location_name': string name corresponding to the + unique MarkLocation present within a + simple @optional (one that does not + expand vertex fields) scope + - 'edge_field': string representing the optional edge being + traversed + where simple_optional_root_location is the location + preceding the @optional scope + Returns: + a new Expression object + """ + inner_location_name_to_where_filter = {} + for root_location, root_info_dict in six.iteritems(simple_optional_root_info): + inner_location_name = root_info_dict['inner_location_name'] + edge_field = root_info_dict['edge_field'] + + optional_edge_location = Location(root_location.query_path, field=edge_field) + optional_edge_where_filter = _filter_orientdb_simple_optional_edge( + optional_edge_location, inner_location_name) + inner_location_name_to_where_filter[inner_location_name] = optional_edge_where_filter + + # Sort expressions by inner_location_name to obtain deterministic order + where_filter_expressions = [ + inner_location_name_to_where_filter[key] + for key in sorted(inner_location_name_to_where_filter.keys()) + ] + + return expression_list_to_conjunction(where_filter_expressions) + + ### # A CompoundMatchQuery is a representation of several MatchQuery objects containing # - match_queries: a list MatchQuery objects diff --git a/graphql_compiler/compiler/match_query.py b/graphql_compiler/compiler/match_query.py index df2054403..25ce9dade 100644 --- a/graphql_compiler/compiler/match_query.py +++ b/graphql_compiler/compiler/match_query.py @@ -3,8 +3,8 @@ from collections import namedtuple -from .blocks import (Backtrack, CoerceType, ConstructResult, Filter, MarkLocation, OutputSource, - QueryRoot, Recurse, Traverse) +from .blocks import (Backtrack, CoerceType, ConstructResult, Filter, Fold, GlobalOperationsStart, + MarkLocation, OutputSource, QueryRoot, Recurse, 
Traverse, Unfold) from .ir_lowering_common import extract_folds_from_ir_blocks @@ -16,7 +16,8 @@ # - folds: a dict of FoldScopeLocation -> list of IR blocks defining that @fold scope, # not including the Fold and Unfold blocks that signal the start and end of the @fold. # - output_block: a ConstructResult IR block, which defines how the query's results are returned. -MatchQuery = namedtuple('MatchQuery', ('match_traversals', 'folds', 'output_block')) +# - where_block: an optional Filter block, which determines the WHERE statement for the query. +MatchQuery = namedtuple('MatchQuery', ('match_traversals', 'folds', 'output_block', 'where_block')) ### @@ -82,11 +83,11 @@ def _per_location_tuple_to_step(ir_tuple): return step -def _split_ir_into_match_steps(ir_blocks): +def _split_ir_into_match_steps(pruned_ir_blocks): """Split a list of IR blocks into per-location MATCH steps. Args: - ir_blocks: list of IR basic block objects that have gone through a lowering step. + pruned_ir_blocks: list of IR basic block objects that have gone through a lowering step. Returns: list of MatchStep namedtuples, each of which contains all basic blocks that correspond @@ -94,7 +95,7 @@ def _split_ir_into_match_steps(ir_blocks): """ output = [] current_tuple = None - for block in ir_blocks: + for block in pruned_ir_blocks: if isinstance(block, OutputSource): # OutputSource blocks do not require any MATCH code, and only serve to help # optimizations and debugging. Simply omit them at this stage. 
@@ -107,10 +108,10 @@ def _split_ir_into_match_steps(ir_blocks): current_tuple += (block,) else: raise AssertionError(u'Unexpected block type when converting to MATCH query: ' - u'{} {}'.format(block, ir_blocks)) + u'{} {}'.format(block, pruned_ir_blocks)) if current_tuple is None: - raise AssertionError(u'current_tuple was unexpectedly None: {}'.format(ir_blocks)) + raise AssertionError(u'current_tuple was unexpectedly None: {}'.format(pruned_ir_blocks)) output.append(current_tuple) return [_per_location_tuple_to_step(x) for x in output] @@ -135,6 +136,39 @@ def _split_match_steps_into_match_traversals(match_steps): return output +def _extract_global_operations(ir_blocks_except_output_and_folds): + """Extract all global operation blocks (all blocks following GlobalOperationsStart). + + Args: + ir_blocks_except_output_and_folds: list of IR blocks (excluding ConstructResult and all + fold blocks), to extract global operations from + + Returns: + tuple (global_operation_blocks, remaining_ir_blocks): + - global_operation_blocks: list of IR blocks following a GlobalOperationsStart block if it + exists, and an empty list otherwise + - remaining_ir_blocks: list of IR blocks excluding GlobalOperationsStart and all global + operation blocks + """ + global_operation_blocks = [] + remaining_ir_blocks = [] + in_global_operations_scope = False + + for block in ir_blocks_except_output_and_folds: + if isinstance(block, (ConstructResult, Fold, Unfold)): + raise AssertionError(u'Received unexpected block of type {}. 
No ConstructResult or ' + u'Fold/Unfold blocks should be present: {}' + .format(type(block).__name__, ir_blocks_except_output_and_folds)) + elif isinstance(block, GlobalOperationsStart): + in_global_operations_scope = True + elif in_global_operations_scope: + global_operation_blocks.append(block) + else: + remaining_ir_blocks.append(block) + + return global_operation_blocks, remaining_ir_blocks + + ############## # Public API # ############## @@ -145,12 +179,31 @@ def convert_to_match_query(ir_blocks): if not isinstance(output_block, ConstructResult): raise AssertionError(u'Expected last IR block to be ConstructResult, found: ' u'{} {}'.format(output_block, ir_blocks)) - ir_except_output = ir_blocks[:-1] + folds, ir_except_output_and_folds = extract_folds_from_ir_blocks(ir_except_output) - match_steps = _split_ir_into_match_steps(ir_except_output_and_folds) + # Extract WHERE Filter + global_operation_ir_blocks_tuple = _extract_global_operations(ir_except_output_and_folds) + global_operation_blocks, pruned_ir_blocks = global_operation_ir_blocks_tuple + if len(global_operation_blocks) > 1: + raise AssertionError(u'Received IR blocks with multiple global operation blocks. Only one ' + u'is allowed: {} {}'.format(global_operation_blocks, ir_blocks)) + if len(global_operation_blocks) == 1: + if not isinstance(global_operation_blocks[0], Filter): + raise AssertionError(u'Received non-Filter global operation block. 
{}' + .format(global_operation_blocks[0])) + where_block = global_operation_blocks[0] + else: + where_block = None + + match_steps = _split_ir_into_match_steps(pruned_ir_blocks) match_traversals = _split_match_steps_into_match_traversals(match_steps) - return MatchQuery(match_traversals=match_traversals, folds=folds, output_block=output_block) + return MatchQuery( + match_traversals=match_traversals, + folds=folds, + output_block=output_block, + where_block=where_block, + ) diff --git a/graphql_compiler/tests/test_compiler.py b/graphql_compiler/tests/test_compiler.py index 139c3c3f1..548e82caf 100644 --- a/graphql_compiler/tests/test_compiler.py +++ b/graphql_compiler/tests/test_compiler.py @@ -247,6 +247,14 @@ def test_optional_traverse_after_mandatory_traverse(self): }} RETURN $matches ) + WHERE ( ( + (Animal___1.out_Animal_ParentOf IS null) + OR + (Animal___1.out_Animal_ParentOf.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -390,6 +398,15 @@ def test_filter_on_optional_variable_equality(self): }} RETURN $matches ) + WHERE ( + ( + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt IS null) + OR + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf__out_Animal_FedAt___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -443,6 +460,15 @@ def test_filter_on_optional_variable_name_or_alias(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -485,6 +511,15 @@ def test_filter_in_optional_block(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.out_Animal_FedAt IS null) + OR + (Animal___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_FedAt___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') 
@@ -770,6 +805,30 @@ def test_complex_optional_variables(self): }} RETURN $matches ) + WHERE ( + ( + ( + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt IS null) + OR + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf__out_Animal_FedAt___1 IS NOT null) + ) + AND + ( + ( + (Animal__out_Animal_ParentOf__in_Animal_ParentOf___1 + .out_Animal_FedAt IS null) + OR + (Animal__out_Animal_ParentOf__in_Animal_ParentOf___1 + .out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf__in_Animal_ParentOf__out_Animal_FedAt___1 + IS NOT null) + ) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -1218,6 +1277,15 @@ def test_in_collection_op_filter_with_optional_tag(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -1323,6 +1391,15 @@ def test_contains_op_filter_with_optional_tag(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -1499,6 +1576,15 @@ def test_has_substring_op_filter_with_optional_tag(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -1599,6 +1685,15 @@ def test_has_edge_degree_op_filter_with_optional(self): }} RETURN $matches ) + WHERE ( + ( + (Species__in_Animal_OfSpecies___1.out_Animal_ParentOf IS null) + OR + (Species__in_Animal_OfSpecies___1.out_Animal_ParentOf.size() = 0) + ) + OR + (Species__in_Animal_OfSpecies__out_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Species') @@ 
-1764,6 +1859,15 @@ def test_optional_on_union(self): }} RETURN $matches ) + WHERE ( + ( + (Species___1.out_Species_Eats IS null) + OR + (Species___1.out_Species_Eats.size() = 0) + ) + OR + (Species__out_Species_Eats___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Species') @@ -1895,6 +1999,39 @@ def test_unnecessary_traversal_elimination(self): }} RETURN $matches ) + WHERE ( + ( + ( + ( + (Animal___1.out_Animal_ParentOf IS null) + OR + (Animal___1.out_Animal_ParentOf.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf___1 IS NOT null) + ) + AND + ( + ( + (Animal___1.out_Animal_OfSpecies IS null) + OR + (Animal___1.out_Animal_OfSpecies.size() = 0) + ) + OR + (Animal__out_Animal_OfSpecies___1 IS NOT null) + ) + ) + AND + ( + ( + (Animal___1.out_Animal_FedAt IS null) + OR + (Animal___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_FedAt___1 IS NOT null) + ) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -3306,6 +3443,15 @@ def test_optional_traversal_and_optional_without_traversal(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ), $optional__1 = ( SELECT @@ -3337,6 +3483,15 @@ def test_optional_traversal_and_optional_without_traversal(self): }} RETURN $matches ) + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ), $result = UNIONALL($optional__0, $optional__1) ''' @@ -3700,6 +3855,15 @@ def test_complex_optional_traversal_variables(self): }} RETURN $matches ) + WHERE ( + ( + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt IS null) + OR + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf__out_Animal_FedAt___1 IS NOT null) + ) ), $optional__1 = ( SELECT @@ -3769,6 +3933,15 @@ def 
test_complex_optional_traversal_variables(self): }} RETURN $matches ) + WHERE ( + ( + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt IS null) + OR + (Animal__out_Animal_ParentOf___1.out_Animal_FedAt.size() = 0) + ) + OR + (Animal__out_Animal_ParentOf__out_Animal_FedAt___1 IS NOT null) + ) ), $result = UNIONALL($optional__0, $optional__1) ''' @@ -4027,6 +4200,15 @@ def test_optional_and_fold(self): ) LET $Animal___1___out_Animal_ParentOf = Animal___1.out("Animal_ParentOf").asList() + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') @@ -4074,6 +4256,15 @@ def test_fold_and_optional(self): ) LET $Animal___1___out_Animal_ParentOf = Animal___1.out("Animal_ParentOf").asList() + WHERE ( + ( + (Animal___1.in_Animal_ParentOf IS null) + OR + (Animal___1.in_Animal_ParentOf.size() = 0) + ) + OR + (Animal__in_Animal_ParentOf___1 IS NOT null) + ) ''' expected_gremlin = ''' g.V('@class', 'Animal') diff --git a/graphql_compiler/tests/test_emit_output.py b/graphql_compiler/tests/test_emit_output.py index 358a504e4..0ac4eb8ad 100644 --- a/graphql_compiler/tests/test_emit_output.py +++ b/graphql_compiler/tests/test_emit_output.py @@ -4,12 +4,13 @@ from graphql import GraphQLString from ..compiler import emit_gremlin, emit_match -from ..compiler.blocks import Backtrack, ConstructResult, Filter, MarkLocation, QueryRoot, Traverse +from ..compiler.blocks import (Backtrack, ConstructResult, Filter, GlobalOperationsStart, + MarkLocation, QueryRoot, Traverse) from ..compiler.expressions import (BinaryComposition, ContextField, LocalField, NullLiteral, OutputContextField, TernaryConditional, Variable) from ..compiler.helpers import Location from ..compiler.ir_lowering_common import OutputContextVertex -from ..compiler.ir_lowering_match.optional_traversal import CompoundMatchQuery +from ..compiler.ir_lowering_match.utils 
import CompoundMatchQuery, construct_where_filter_predicate from ..compiler.match_query import convert_to_match_query from ..schema import GraphQLDateTime from .test_helpers import compare_gremlin, compare_match @@ -29,7 +30,7 @@ def test_simple_immediate_output(self): MarkLocation(base_location), ConstructResult({ 'foo_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] match_query = convert_to_match_query(ir_blocks) compound_match_query = CompoundMatchQuery(match_queries=[match_query]) @@ -66,7 +67,7 @@ def test_simple_traverse_filter_output(self): MarkLocation(base_location), ConstructResult({ 'foo_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] match_query = convert_to_match_query(ir_blocks) compound_match_query = CompoundMatchQuery(match_queries=[match_query]) @@ -92,9 +93,16 @@ def test_simple_traverse_filter_output(self): def test_output_inside_optional_traversal(self): base_location = Location(('Foo',)) + child_location = base_location.navigate_to_subpath('out_Foo_Bar') + child_location_name, _ = child_location.get_location_name() + child_name_location = child_location.navigate_to_field('name') + simple_optional_root_info = { + base_location: {'inner_location_name': child_location_name, 'edge_field': 'out_Foo_Bar'} + } + ir_blocks = [ QueryRoot({'Foo'}), MarkLocation(base_location), @@ -103,6 +111,8 @@ def test_output_inside_optional_traversal(self): QueryRoot({'Foo'}), MarkLocation(base_location), + GlobalOperationsStart(), + Filter(construct_where_filter_predicate(simple_optional_root_info)), ConstructResult({ 'bar_name': TernaryConditional( BinaryComposition( @@ -112,7 +122,7 @@ def test_output_inside_optional_traversal(self): ), OutputContextField(child_name_location, GraphQLString), NullLiteral) - }) + }), ] match_query = convert_to_match_query(ir_blocks) compound_match_query = CompoundMatchQuery(match_queries=[match_query]) @@ -135,6 +145,15 @@ def test_output_inside_optional_traversal(self): }} RETURN 
$matches ) + WHERE ( + ( + (Foo___1.out_Foo_Bar IS null) + OR + (Foo___1.out_Foo_Bar.size() = 0) + ) + OR + (Foo__out_Foo_Bar___1 IS NOT null) + ) ''' received_match = emit_match.emit_code_from_ir(compound_match_query) @@ -160,7 +179,7 @@ def test_datetime_variable_representation(self): MarkLocation(base_location), ConstructResult({ 'name': OutputContextField(base_name_location, GraphQLString) - }) + }), ] match_query = convert_to_match_query(ir_blocks) compound_match_query = CompoundMatchQuery(match_queries=[match_query]) @@ -191,7 +210,7 @@ def test_datetime_output_representation(self): MarkLocation(base_location), ConstructResult({ 'event_date': OutputContextField(base_event_date_location, GraphQLDateTime) - }) + }), ] match_query = convert_to_match_query(ir_blocks) compound_match_query = CompoundMatchQuery(match_queries=[match_query]) diff --git a/graphql_compiler/tests/test_ir_lowering.py b/graphql_compiler/tests/test_ir_lowering.py index 951ae7ea9..f65d4c1c0 100644 --- a/graphql_compiler/tests/test_ir_lowering.py +++ b/graphql_compiler/tests/test_ir_lowering.py @@ -100,23 +100,21 @@ def test_context_field_existence_lowering_in_output(self): OutputContextField(child_name_location, GraphQLString), NullLiteral ) - }) + }), ] ir_sanity_checks.sanity_check_ir_blocks_from_frontend(ir_blocks) # The expected final blocks just have a rewritten ConstructResult block, # where the ContextFieldExistence expression is replaced with a null check. 
- expected_final_blocks = ir_blocks[:-1] - expected_final_blocks.append( - ConstructResult({ - 'child_name': TernaryConditional( - BinaryComposition(u'!=', - OutputContextVertex(child_location), - NullLiteral), - OutputContextField(child_name_location, GraphQLString), - NullLiteral) - }) - ) + expected_final_blocks = ir_blocks[:] + expected_final_blocks[-1] = ConstructResult({ + 'child_name': TernaryConditional( + BinaryComposition(u'!=', + OutputContextVertex(child_location), + NullLiteral), + OutputContextField(child_name_location, GraphQLString), + NullLiteral) + }) final_blocks = ir_lowering_match.lower_context_field_existence(ir_blocks) check_test_data(self, expected_final_blocks, final_blocks) @@ -202,7 +200,7 @@ def test_backtrack_block_lowering_simple(self): Backtrack(base_location), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] ir_sanity_checks.sanity_check_ir_blocks_from_frontend(ir_blocks) @@ -225,7 +223,7 @@ def test_backtrack_block_lowering_simple(self): MarkLocation(base_location), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] expected_final_query = convert_to_match_query(expected_final_blocks) @@ -249,7 +247,7 @@ def test_backtrack_block_lowering_revisiting_root(self): Backtrack(base_location), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] ir_sanity_checks.sanity_check_ir_blocks_from_frontend(ir_blocks) @@ -278,7 +276,7 @@ def test_backtrack_block_lowering_revisiting_root(self): MarkLocation(base_location), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] expected_final_query = convert_to_match_query(expected_final_blocks) @@ -300,7 +298,7 @@ def test_optional_backtrack_block_lowering(self): MarkLocation(base_location_revisited), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] 
ir_sanity_checks.sanity_check_ir_blocks_from_frontend(ir_blocks) @@ -325,7 +323,7 @@ def test_optional_backtrack_block_lowering(self): MarkLocation(base_location), ConstructResult({ 'animal_name': OutputContextField(base_name_location, GraphQLString), - }) + }), ] expected_final_query = convert_to_match_query(expected_final_blocks)