Skip to content

Commit

Permalink
Ensure type coercions are preserved when hiding non-preferred locatio…
Browse files Browse the repository at this point in the history
…ns. (#110)

When exposing only preferred locations as start points for query processing, we have to hide all other locations that are eligible to be start points. There existed an edge case where in this process, we'd accidentally drop type coercions at such locations during the hiding process. This diff fixes that problem.
  • Loading branch information
obi1kenobi committed Jul 27, 2018
1 parent fa2e4d7 commit fb31516
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 11 deletions.
1 change: 1 addition & 0 deletions graphql_compiler/compiler/common.py
Expand Up @@ -91,6 +91,7 @@ def _compile_graphql_generic(language, lowering_func, query_emitter_func,

lowered_ir_blocks = lowering_func(
ir_and_metadata.ir_blocks, ir_and_metadata.location_types,
ir_and_metadata.coerced_locations,
type_equivalence_hints=type_equivalence_hints)

query = query_emitter_func(lowered_ir_blocks)
Expand Down
9 changes: 8 additions & 1 deletion graphql_compiler/compiler/compiler_frontend.py
Expand Up @@ -124,6 +124,7 @@ def __ne__(self, other):
'input_metadata',
'output_metadata',
'location_types',
'coerced_locations',
)
)

Expand Down Expand Up @@ -569,6 +570,7 @@ def _compile_fragment_ast(schema, current_schema_type, ast, location, context):

if not (is_same_type_as_scope or is_base_type_of_union):
# Coercion is required.
context['coerced_locations'].add(location)
basic_blocks.append(blocks.CoerceType({coerces_to_type_name}))

inner_basic_blocks = _compile_ast_node_to_ir(
Expand Down Expand Up @@ -665,6 +667,7 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
- input_metadata: a dict of expected input parameters (string) -> inferred GraphQL type
- output_metadata: a dict of output name (string) -> OutputMetadata object
- location_types: a dict of location objects -> GraphQL type objects at that location
- coerced_locations: a set of location objects indicating where type coercions have happened
"""
if len(ast.selection_set.selections) != 1:
raise GraphQLCompilationError(u'Cannot process AST with more than one root selection!')
Expand Down Expand Up @@ -696,6 +699,9 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
# 'location_types' is a dict mapping each Location to its GraphQLType
# (schema type of the location)
'location_types': dict(),
# 'coerced_locations' is the set of all locations whose type was coerced to a subtype
# of the type already implied by the GraphQL schema for that vertex field.
'coerced_locations': set(),
# 'type_equivalence_hints' is a dict mapping GraphQL types to equivalent GraphQL unions
'type_equivalence_hints': type_equivalence_hints or dict(),
# The marked_location_stack explicitly maintains a stack (implemented as list)
Expand Down Expand Up @@ -746,7 +752,8 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
ir_blocks=basic_blocks,
input_metadata=context['inputs'],
output_metadata=output_metadata,
location_types=context['location_types'])
location_types=context['location_types'],
coerced_locations=context['coerced_locations'])


def _compile_output_step(outputs):
Expand Down
6 changes: 4 additions & 2 deletions graphql_compiler/compiler/ir_lowering_gremlin/__init__.py
Expand Up @@ -10,12 +10,14 @@
# Public API #
##############

def lower_ir(ir_blocks, location_types, type_equivalence_hints=None):
def lower_ir(ir_blocks, location_types, coerced_locations, type_equivalence_hints=None):
"""Lower the IR into an IR form that can be represented in Gremlin queries.
Args:
ir_blocks: list of IR blocks to lower into Gremlin-compatible form
location_types: a dict of location objects -> GraphQL type objects at that location
location_types: dict of location objects -> GraphQL type objects at that location
coerced_locations: set of locations where type coercions were applied to constrain the type
relative to the type inferred by the GraphQL schema and the given field
type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union.
Used as a workaround for GraphQL's lack of support for
inheritance across "types" (i.e. non-interfaces), as well as a
Expand Down
8 changes: 5 additions & 3 deletions graphql_compiler/compiler/ir_lowering_match/__init__.py
Expand Up @@ -28,12 +28,14 @@
##############


def lower_ir(ir_blocks, location_types, type_equivalence_hints=None):
def lower_ir(ir_blocks, location_types, coerced_locations, type_equivalence_hints=None):
"""Lower the IR into an IR form that can be represented in MATCH queries.
Args:
ir_blocks: list of IR blocks to lower into MATCH-compatible form
location_types: a dict of location objects -> GraphQL type objects at that location
location_types: dict of location objects -> GraphQL type objects at that location
coerced_locations: set of locations where type coercions were applied to constrain the type
relative to the type inferred by the GraphQL schema and the given field
type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union.
Used as a workaround for GraphQL's lack of support for
inheritance across "types" (i.e. non-interfaces), as well as a
Expand Down Expand Up @@ -108,6 +110,6 @@ def lower_ir(ir_blocks, location_types, type_equivalence_hints=None):
compound_match_query = truncate_repeated_single_step_traversals_in_sub_queries(
compound_match_query)
compound_match_query = orientdb_query_execution.expose_ideal_query_execution_start_points(
compound_match_query, location_types)
compound_match_query, location_types, coerced_locations)

return compound_match_query
5 changes: 5 additions & 0 deletions graphql_compiler/compiler/ir_lowering_match/ir_lowering.py
Expand Up @@ -203,6 +203,11 @@ def lower_backtrack_blocks(match_query, location_types):
if step.as_block is not None:
location_translations[step.as_block.location] = backtrack_location

if step.coerce_type_block is not None:
raise AssertionError(u'Encountered type coercion in a MatchStep with '
u'a Backtrack root block, this is unexpected: {} {}'
.format(step, match_query))

new_step = step._replace(root_block=new_root_block, as_block=new_as_block)
new_traversal.append(new_step)

Expand Down
Expand Up @@ -216,7 +216,7 @@ def _assert_type_bounds_are_not_conflicting(current_type_bound, previous_type_bo
u'for query {}'.format(location, previous_type_bound, current_type_bound, match_query))


def _expose_only_preferred_locations(match_query, location_types,
def _expose_only_preferred_locations(match_query, location_types, coerced_locations,
preferred_locations, eligible_locations):
"""Return a MATCH query where only preferred locations are valid as query start locations."""
preferred_location_types = dict()
Expand Down Expand Up @@ -267,7 +267,7 @@ def _expose_only_preferred_locations(match_query, location_types,
# we ensure that we again infer the same type bound.
eligible_location_types[current_step_location] = current_type_bound

if current_type_bound == location_types[current_step_location].name:
if current_step_location not in coerced_locations:
# The type bound here is already implied by the GraphQL query structure.
# We can simply delete the QueryRoot / CoerceType blocks that impart it.
if isinstance(match_step.root_block, QueryRoot):
Expand Down Expand Up @@ -348,7 +348,8 @@ def _expose_all_eligible_locations(match_query, location_types, eligible_locatio
return match_query._replace(match_traversals=new_match_traversals)


def expose_ideal_query_execution_start_points(compound_match_query, location_types):
def expose_ideal_query_execution_start_points(compound_match_query, location_types,
coerced_locations):
"""Ensure that OrientDB only considers desirable query start points in query planning."""
new_queries = []

Expand All @@ -363,7 +364,8 @@ def expose_ideal_query_execution_start_points(compound_match_query, location_typ
# to the location. We remove it by converting the class check into
# an "INSTANCEOF" Filter block, which OrientDB is unable to optimize away.
new_query = _expose_only_preferred_locations(
match_query, location_types, preferred_locations, eligible_locations)
match_query, location_types, coerced_locations,
preferred_locations, eligible_locations)
elif eligible_locations:
# Make sure that all eligible locations have a "class:" clause by adding
# a CoerceType block that is a no-op as guaranteed by the schema. This merely
Expand Down
35 changes: 35 additions & 0 deletions graphql_compiler/tests/test_compiler.py
Expand Up @@ -1929,6 +1929,41 @@ def test_simple_union(self):

check_test_data(self, test_data, expected_match, expected_gremlin)

def test_filter_then_apply_fragment(self):
test_data = test_input_data.filter_then_apply_fragment()

expected_match = '''
SELECT
Species__out_Species_Eats___1.name AS `food_name`,
Species___1.name AS `species_name`
FROM (
MATCH {{
class: Species,
where: (({species} CONTAINS name)),
as: Species___1
}}.out('Species_Eats') {{
where: ((@this INSTANCEOF 'Food')),
as: Species__out_Species_Eats___1
}}
RETURN $matches
)
'''
expected_gremlin = '''
g.V('@class', 'Species')
.filter{it, m -> $species.contains(it.name)}
.as('Species___1')
.out('Species_Eats')
.filter{it, m -> ['Food'].contains(it['@class'])}
.as('Species__out_Species_Eats___1')
.back('Species___1')
.transform{it, m -> new com.orientechnologies.orient.core.record.impl.ODocument([
food_name: m.Species__out_Species_Eats___1.name,
species_name: m.Species___1.name
])}
'''

check_test_data(self, test_data, expected_match, expected_gremlin)

def test_filter_on_fragment_in_union(self):
test_data = test_input_data.filter_on_fragment_in_union()

Expand Down
27 changes: 27 additions & 0 deletions graphql_compiler/tests/test_input_data.py
Expand Up @@ -540,6 +540,33 @@ def simple_union():
type_equivalence_hints=None)


def filter_then_apply_fragment():
graphql_input = '''{
Species {
name @filter(op_name: "in_collection", value: ["$species"])
@output(out_name: "species_name")
out_Species_Eats {
... on Food {
name @output(out_name: "food_name")
}
}
}
}'''
expected_output_metadata = {
'species_name': OutputMetadata(type=GraphQLString, optional=False),
'food_name': OutputMetadata(type=GraphQLString, optional=False),
}
expected_input_metadata = {
'species': GraphQLList(GraphQLString),
}

return CommonTestData(
graphql_input=graphql_input,
expected_output_metadata=expected_output_metadata,
expected_input_metadata=expected_input_metadata,
type_equivalence_hints=None)


def filter_on_fragment_in_union():
graphql_input = '''{
Species {
Expand Down
34 changes: 34 additions & 0 deletions graphql_compiler/tests/test_ir_generation.py
Expand Up @@ -942,6 +942,40 @@ def test_simple_union(self):

check_test_data(self, test_data, expected_blocks, expected_location_types)

def test_filter_then_apply_fragment(self):
test_data = test_input_data.filter_then_apply_fragment()

base_location = helpers.Location(('Species',))
food_location = base_location.navigate_to_subpath('out_Species_Eats')

expected_blocks = [
blocks.QueryRoot({'Species'}),
blocks.Filter(
expressions.BinaryComposition(
u'contains',
expressions.Variable('$species', GraphQLList(GraphQLString)),
expressions.LocalField('name')
)
),
blocks.MarkLocation(base_location),
blocks.Traverse('out', 'Species_Eats'),
blocks.CoerceType({'Food'}),
blocks.MarkLocation(food_location),
blocks.Backtrack(base_location),
blocks.ConstructResult({
'species_name': expressions.OutputContextField(
base_location.navigate_to_field('name'), GraphQLString),
'food_name': expressions.OutputContextField(
food_location.navigate_to_field('name'), GraphQLString),
}),
]
expected_location_types = {
base_location: 'Species',
food_location: 'Food',
}

check_test_data(self, test_data, expected_blocks, expected_location_types)

def test_filter_on_fragment_in_union(self):
test_data = test_input_data.filter_on_fragment_in_union()

Expand Down
3 changes: 2 additions & 1 deletion graphql_compiler/tests/test_ir_lowering.py
Expand Up @@ -758,6 +758,7 @@ def test_optional_traversal_edge_case(self):
child_fed_at_location: 'Event',
revisited_base_location: 'Animal',
})
coerced_locations = set()

expected_final_blocks_without_optional_traverse = [
QueryRoot({'Animal'}),
Expand Down Expand Up @@ -810,7 +811,7 @@ def test_optional_traversal_edge_case(self):
]
)

final_query = ir_lowering_match.lower_ir(ir_blocks, location_types)
final_query = ir_lowering_match.lower_ir(ir_blocks, location_types, coerced_locations)

self.assertEqual(
expected_compound_match_query, final_query,
Expand Down

0 comments on commit fb31516

Please sign in to comment.