From fb31516f0b6da5f9b213ba0263088e0ecff5b687 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Fri, 27 Jul 2018 18:10:49 -0400 Subject: [PATCH] Ensure type coercions are preserved when hiding non-preferred locations. (#110) When exposing only preferred locations as start points for query processing, we have to hide all other locations that are eligible to be start points. There existed an edge case where in this process, we'd accidentally drop type coercions at such locations during the hiding process. This diff fixes that problem. --- graphql_compiler/compiler/common.py | 1 + .../compiler/compiler_frontend.py | 9 ++++- .../compiler/ir_lowering_gremlin/__init__.py | 6 ++-- .../compiler/ir_lowering_match/__init__.py | 8 +++-- .../compiler/ir_lowering_match/ir_lowering.py | 5 +++ .../workarounds/orientdb_query_execution.py | 10 +++--- graphql_compiler/tests/test_compiler.py | 35 +++++++++++++++++++ graphql_compiler/tests/test_input_data.py | 27 ++++++++++++++ graphql_compiler/tests/test_ir_generation.py | 34 ++++++++++++++++++ graphql_compiler/tests/test_ir_lowering.py | 3 +- 10 files changed, 127 insertions(+), 11 deletions(-) diff --git a/graphql_compiler/compiler/common.py b/graphql_compiler/compiler/common.py index c4661bb59..91ebd7f33 100644 --- a/graphql_compiler/compiler/common.py +++ b/graphql_compiler/compiler/common.py @@ -91,6 +91,7 @@ def _compile_graphql_generic(language, lowering_func, query_emitter_func, lowered_ir_blocks = lowering_func( ir_and_metadata.ir_blocks, ir_and_metadata.location_types, + ir_and_metadata.coerced_locations, type_equivalence_hints=type_equivalence_hints) query = query_emitter_func(lowered_ir_blocks) diff --git a/graphql_compiler/compiler/compiler_frontend.py b/graphql_compiler/compiler/compiler_frontend.py index 2e2afd6b7..df7efe330 100644 --- a/graphql_compiler/compiler/compiler_frontend.py +++ b/graphql_compiler/compiler/compiler_frontend.py @@ -124,6 +124,7 @@ def __ne__(self, other): 'input_metadata', 'output_metadata', 'location_types', + 'coerced_locations', ) ) @@ -569,6 +570,7 @@ def _compile_fragment_ast(schema, current_schema_type, ast, location, context): if not (is_same_type_as_scope or is_base_type_of_union): # Coercion is required. + context['coerced_locations'].add(location) basic_blocks.append(blocks.CoerceType({coerces_to_type_name})) inner_basic_blocks = _compile_ast_node_to_ir( @@ -665,6 +667,7 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None): - input_metadata: a dict of expected input parameters (string) -> inferred GraphQL type - output_metadata: a dict of output name (string) -> OutputMetadata object - location_types: a dict of location objects -> GraphQL type objects at that location + - coerced_locations: a set of location objects indicating where type coercions have happened """ if len(ast.selection_set.selections) != 1: raise GraphQLCompilationError(u'Cannot process AST with more than one root selection!') @@ -696,6 +699,9 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None): # 'location_types' is a dict mapping each Location to its GraphQLType # (schema type of the location) 'location_types': dict(), + # 'coerced_locations' is the set of all locations whose type was coerced to a subtype + # of the type already implied by the GraphQL schema for that vertex field. + 'coerced_locations': set(), # 'type_equivalence_hints' is a dict mapping GraphQL types to equivalent GraphQL unions 'type_equivalence_hints': type_equivalence_hints or dict(), # The marked_location_stack explicitly maintains a stack (implemented as list) @@ -746,7 +752,8 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None): ir_blocks=basic_blocks, input_metadata=context['inputs'], output_metadata=output_metadata, - location_types=context['location_types']) + location_types=context['location_types'], + coerced_locations=context['coerced_locations']) def _compile_output_step(outputs): diff --git a/graphql_compiler/compiler/ir_lowering_gremlin/__init__.py b/graphql_compiler/compiler/ir_lowering_gremlin/__init__.py index 62120b7df..530df40e3 100644 --- a/graphql_compiler/compiler/ir_lowering_gremlin/__init__.py +++ b/graphql_compiler/compiler/ir_lowering_gremlin/__init__.py @@ -10,12 +10,14 @@ # Public API # ############## -def lower_ir(ir_blocks, location_types, type_equivalence_hints=None): +def lower_ir(ir_blocks, location_types, coerced_locations, type_equivalence_hints=None): """Lower the IR into an IR form that can be represented in Gremlin queries. Args: ir_blocks: list of IR blocks to lower into Gremlin-compatible form - location_types: a dict of location objects -> GraphQL type objects at that location + location_types: dict of location objects -> GraphQL type objects at that location + coerced_locations: set of locations where type coercions were applied to constrain the type + relative to the type inferred by the GraphQL schema and the given field type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. Used as a workaround for GraphQL's lack of support for inheritance across "types" (i.e. non-interfaces), as well as a diff --git a/graphql_compiler/compiler/ir_lowering_match/__init__.py b/graphql_compiler/compiler/ir_lowering_match/__init__.py index 636314caa..62f59a118 100644 --- a/graphql_compiler/compiler/ir_lowering_match/__init__.py +++ b/graphql_compiler/compiler/ir_lowering_match/__init__.py @@ -28,12 +28,14 @@ ############## -def lower_ir(ir_blocks, location_types, type_equivalence_hints=None): +def lower_ir(ir_blocks, location_types, coerced_locations, type_equivalence_hints=None): """Lower the IR into an IR form that can be represented in MATCH queries. Args: ir_blocks: list of IR blocks to lower into MATCH-compatible form - location_types: a dict of location objects -> GraphQL type objects at that location + location_types: dict of location objects -> GraphQL type objects at that location + coerced_locations: set of locations where type coercions were applied to constrain the type + relative to the type inferred by the GraphQL schema and the given field type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. Used as a workaround for GraphQL's lack of support for inheritance across "types" (i.e. non-interfaces), as well as a @@ -108,6 +110,6 @@ def lower_ir(ir_blocks, location_types, type_equivalence_hints=None): compound_match_query = truncate_repeated_single_step_traversals_in_sub_queries( compound_match_query) compound_match_query = orientdb_query_execution.expose_ideal_query_execution_start_points( - compound_match_query, location_types) + compound_match_query, location_types, coerced_locations) return compound_match_query diff --git a/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py b/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py index 091738835..ab5649973 100644 --- a/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py +++ b/graphql_compiler/compiler/ir_lowering_match/ir_lowering.py @@ -203,6 +203,11 @@ def lower_backtrack_blocks(match_query, location_types): if step.as_block is not None: location_translations[step.as_block.location] = backtrack_location + if step.coerce_type_block is not None: + raise AssertionError(u'Encountered type coercion in a MatchStep with ' + u'a Backtrack root block, this is unexpected: {} {}' + .format(step, match_query)) + new_step = step._replace(root_block=new_root_block, as_block=new_as_block) new_traversal.append(new_step) diff --git a/graphql_compiler/compiler/workarounds/orientdb_query_execution.py b/graphql_compiler/compiler/workarounds/orientdb_query_execution.py index 0f0e003f9..975563724 100644 --- a/graphql_compiler/compiler/workarounds/orientdb_query_execution.py +++ b/graphql_compiler/compiler/workarounds/orientdb_query_execution.py @@ -216,7 +216,7 @@ def _assert_type_bounds_are_not_conflicting(current_type_bound, previous_type_bo u'for query {}'.format(location, previous_type_bound, current_type_bound, match_query)) -def _expose_only_preferred_locations(match_query, location_types, +def _expose_only_preferred_locations(match_query, location_types, coerced_locations, preferred_locations, eligible_locations): """Return a MATCH query where only preferred locations are valid as query start locations.""" preferred_location_types = dict() @@ -267,7 +267,7 @@ def _expose_only_preferred_locations(match_query, location_types, # we ensure that we again infer the same type bound. eligible_location_types[current_step_location] = current_type_bound - if current_type_bound == location_types[current_step_location].name: + if current_step_location not in coerced_locations: # The type bound here is already implied by the GraphQL query structure. # We can simply delete the QueryRoot / CoerceType blocks that impart it. if isinstance(match_step.root_block, QueryRoot): @@ -348,7 +348,8 @@ def _expose_all_eligible_locations(match_query, location_types, eligible_locatio return match_query._replace(match_traversals=new_match_traversals) -def expose_ideal_query_execution_start_points(compound_match_query, location_types): +def expose_ideal_query_execution_start_points(compound_match_query, location_types, + coerced_locations): """Ensure that OrientDB only considers desirable query start points in query planning.""" new_queries = [] @@ -363,7 +364,8 @@ def expose_ideal_query_execution_start_points(compound_match_query, location_typ # to the location. We remove it by converting the class check into # an "INSTANCEOF" Filter block, which OrientDB is unable to optimize away. new_query = _expose_only_preferred_locations( - match_query, location_types, preferred_locations, eligible_locations) + match_query, location_types, coerced_locations, + preferred_locations, eligible_locations) elif eligible_locations: # Make sure that all eligible locations have a "class:" clause by adding # a CoerceType block that is a no-op as guaranteed by the schema. This merely diff --git a/graphql_compiler/tests/test_compiler.py b/graphql_compiler/tests/test_compiler.py index 0887dc41d..d7f106787 100644 --- a/graphql_compiler/tests/test_compiler.py +++ b/graphql_compiler/tests/test_compiler.py @@ -1929,6 +1929,41 @@ def test_simple_union(self): check_test_data(self, test_data, expected_match, expected_gremlin) + def test_filter_then_apply_fragment(self): + test_data = test_input_data.filter_then_apply_fragment() + + expected_match = ''' + SELECT + Species__out_Species_Eats___1.name AS `food_name`, + Species___1.name AS `species_name` + FROM ( + MATCH {{ + class: Species, + where: (({species} CONTAINS name)), + as: Species___1 + }}.out('Species_Eats') {{ + where: ((@this INSTANCEOF 'Food')), + as: Species__out_Species_Eats___1 + }} + RETURN $matches + ) + ''' + expected_gremlin = ''' + g.V('@class', 'Species') + .filter{it, m -> $species.contains(it.name)} + .as('Species___1') + .out('Species_Eats') + .filter{it, m -> ['Food'].contains(it['@class'])} + .as('Species__out_Species_Eats___1') + .back('Species___1') + .transform{it, m -> new com.orientechnologies.orient.core.record.impl.ODocument([ + food_name: m.Species__out_Species_Eats___1.name, + species_name: m.Species___1.name + ])} + ''' + + check_test_data(self, test_data, expected_match, expected_gremlin) + def test_filter_on_fragment_in_union(self): test_data = test_input_data.filter_on_fragment_in_union() diff --git a/graphql_compiler/tests/test_input_data.py b/graphql_compiler/tests/test_input_data.py index b61f7b409..78e90a2e0 100644 --- a/graphql_compiler/tests/test_input_data.py +++ b/graphql_compiler/tests/test_input_data.py @@ -540,6 +540,33 @@ def simple_union(): type_equivalence_hints=None) +def filter_then_apply_fragment(): + graphql_input = '''{ + Species { + name @filter(op_name: "in_collection", value: ["$species"]) + @output(out_name: "species_name") + out_Species_Eats { + ... on Food { + name @output(out_name: "food_name") + } + } + } + }''' + expected_output_metadata = { + 'species_name': OutputMetadata(type=GraphQLString, optional=False), + 'food_name': OutputMetadata(type=GraphQLString, optional=False), + } + expected_input_metadata = { + 'species': GraphQLList(GraphQLString), + } + + return CommonTestData( + graphql_input=graphql_input, + expected_output_metadata=expected_output_metadata, + expected_input_metadata=expected_input_metadata, + type_equivalence_hints=None) + + def filter_on_fragment_in_union(): graphql_input = '''{ Species { diff --git a/graphql_compiler/tests/test_ir_generation.py b/graphql_compiler/tests/test_ir_generation.py index d30fc037d..f81b13927 100644 --- a/graphql_compiler/tests/test_ir_generation.py +++ b/graphql_compiler/tests/test_ir_generation.py @@ -942,6 +942,40 @@ def test_simple_union(self): check_test_data(self, test_data, expected_blocks, expected_location_types) + def test_filter_then_apply_fragment(self): + test_data = test_input_data.filter_then_apply_fragment() + + base_location = helpers.Location(('Species',)) + food_location = base_location.navigate_to_subpath('out_Species_Eats') + + expected_blocks = [ + blocks.QueryRoot({'Species'}), + blocks.Filter( + expressions.BinaryComposition( + u'contains', + expressions.Variable('$species', GraphQLList(GraphQLString)), + expressions.LocalField('name') + ) + ), + blocks.MarkLocation(base_location), + blocks.Traverse('out', 'Species_Eats'), + blocks.CoerceType({'Food'}), + blocks.MarkLocation(food_location), + blocks.Backtrack(base_location), + blocks.ConstructResult({ + 'species_name': expressions.OutputContextField( + base_location.navigate_to_field('name'), GraphQLString), + 'food_name': expressions.OutputContextField( + food_location.navigate_to_field('name'), GraphQLString), + }), + ] + expected_location_types = { + base_location: 'Species', + food_location: 'Food', + } + + check_test_data(self, test_data, expected_blocks, expected_location_types) + def test_filter_on_fragment_in_union(self): test_data = test_input_data.filter_on_fragment_in_union() diff --git a/graphql_compiler/tests/test_ir_lowering.py b/graphql_compiler/tests/test_ir_lowering.py index 11e4449de..058873e87 100644 --- a/graphql_compiler/tests/test_ir_lowering.py +++ b/graphql_compiler/tests/test_ir_lowering.py @@ -758,6 +758,7 @@ def test_optional_traversal_edge_case(self): child_fed_at_location: 'Event', revisited_base_location: 'Animal', }) + coerced_locations = set() expected_final_blocks_without_optional_traverse = [ QueryRoot({'Animal'}), @@ -810,7 +811,7 @@ def test_optional_traversal_edge_case(self): ] ) - final_query = ir_lowering_match.lower_ir(ir_blocks, location_types) + final_query = ir_lowering_match.lower_ir(ir_blocks, location_types, coerced_locations) self.assertEqual( expected_compound_match_query, final_query,