Skip to content

Commit

Permalink
Add MarkLocation blocks inside fold scopes.
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi committed Aug 20, 2018
1 parent 7f54185 commit abb266a
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 13 deletions.
3 changes: 1 addition & 2 deletions graphql_compiler/compiler/compiler_frontend.py
Expand Up @@ -373,9 +373,8 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
initial_marked_location_stack_size = len(context['marked_location_stack'])

# step V-3: mark the graph position, and process output_source directive
basic_blocks.append(_mark_location(location))
if not is_in_fold_scope(context):
# We only mark the position if we aren't in a folded scope.
basic_blocks.append(_mark_location(location))
# The following append is the Location corresponding to the initial MarkLocation
# for the current vertex and the `num_traverses` counter set to 0.
context['marked_location_stack'].append(_construct_location_stack_entry(location, 0))
Expand Down
7 changes: 5 additions & 2 deletions graphql_compiler/compiler/emit_match.py
Expand Up @@ -4,7 +4,7 @@

import six

from .blocks import Filter, QueryRoot, Recurse, Traverse
from .blocks import Filter, MarkLocation, QueryRoot, Recurse, Traverse
from .expressions import TrueLiteral
from .helpers import get_only_element_from_collection, validate_safe_string

Expand Down Expand Up @@ -133,8 +133,11 @@ def _represent_fold(fold_location, fold_ir_blocks):
'edge_name': block.edge_name,
}
final_string += traverse_edge_template % template_data
elif isinstance(block, MarkLocation):
# MarkLocation blocks inside a fold do not result in any MATCH output.
pass
else:
raise AssertionError(u'Found a non-Filter/Traverse IR block in the folded IR blocks: '
raise AssertionError(u'Found an unexpected IR block in the folded IR blocks: '
u'{} {} {}'.format(type(block), block, fold_ir_blocks))

# Workaround for OrientDB's inconsistent return type when filtering a list.
Expand Down
4 changes: 2 additions & 2 deletions graphql_compiler/compiler/helpers.py
Expand Up @@ -204,8 +204,8 @@ def validate_edge_direction(edge_direction):

def validate_marked_location(location):
"""Validate that a Location object is safe for marking, and not at a field."""
if not isinstance(location, Location):
raise TypeError(u'Expected Location location, got: {} {}'.format(
if not isinstance(location, (Location, FoldScopeLocation)):
raise TypeError(u'Expected Location or FoldScopeLocation location, got: {} {}'.format(
type(location).__name__, location))

if location.field is not None:
Expand Down
16 changes: 11 additions & 5 deletions graphql_compiler/compiler/ir_lowering_common.py
Expand Up @@ -228,16 +228,19 @@ def extract_optional_location_root_info(ir_blocks):
in_optional_root_location = None
encountered_traverse_within_optional = False

# Blocks within folded scopes should not be taken into account in this function.
_, non_folded_ir_blocks = extract_folds_from_ir_blocks(ir_blocks)

preceding_location = None
for current_block in ir_blocks:
for current_block in non_folded_ir_blocks:
if isinstance(current_block, Traverse) and current_block.optional:
if in_optional_root_location is not None:
raise AssertionError(u'in_optional_root_location was not None at an optional '
u'Traverse: {} {}'.format(current_block, ir_blocks))
u'Traverse: {} {}'.format(current_block, non_folded_ir_blocks))

if preceding_location is None:
raise AssertionError(u'No MarkLocation found before an optional Traverse: {} {}'
.format(current_block, ir_blocks))
.format(current_block, non_folded_ir_blocks))

in_optional_root_location = preceding_location
elif all((in_optional_root_location is not None,
Expand All @@ -246,7 +249,7 @@ def extract_optional_location_root_info(ir_blocks):
elif isinstance(current_block, EndOptional):
if in_optional_root_location is None:
raise AssertionError(u'in_optional_root_location was None at an EndOptional block: '
u'{}'.format(ir_blocks))
u'{}'.format(non_folded_ir_blocks))

if encountered_traverse_within_optional:
complex_optional_roots.append(in_optional_root_location)
Expand Down Expand Up @@ -297,9 +300,12 @@ def extract_simple_optional_location_info(
}
simple_optional_root_locations = set(simple_optional_root_to_inner_location.keys())

# Blocks within folded scopes should not be taken into account in this function.
_, non_folded_ir_blocks = extract_folds_from_ir_blocks(ir_blocks)

simple_optional_root_info = {}
preceding_location = None
for current_block in ir_blocks:
for current_block in non_folded_ir_blocks:
if isinstance(current_block, MarkLocation):
preceding_location = current_block.location
elif isinstance(current_block, Traverse) and current_block.optional:
Expand Down
9 changes: 7 additions & 2 deletions graphql_compiler/compiler/ir_lowering_gremlin/ir_lowering.py
Expand Up @@ -14,7 +14,7 @@

from ...exceptions import GraphQLCompilationError
from ...schema import GraphQLDate, GraphQLDateTime
from ..blocks import Backtrack, CoerceType, ConstructResult, Filter, Traverse
from ..blocks import Backtrack, CoerceType, ConstructResult, Filter, MarkLocation, Traverse
from ..compiler_entities import Expression
from ..expressions import (BinaryComposition, FoldedOutputContextField, Literal, LocalField,
NullLiteral)
Expand Down Expand Up @@ -303,8 +303,13 @@ def folded_context_visitor(expression):
new_block = GremlinFoldedFilter(new_predicate)
elif isinstance(block, Traverse):
new_block = GremlinFoldedTraverse.from_traverse(block)
else:
elif isinstance(block, MarkLocation):
# We remove MarkLocation blocks from the folded blocks output,
# since they do not produce any Gremlin output code.
continue
else:
raise AssertionError(u'Found an unexpected IR block in the folded IR blocks: '
u'{} {} {}'.format(type(block), block, folded_ir_blocks))

new_folded_ir_blocks.append(new_block)

Expand Down
39 changes: 39 additions & 0 deletions graphql_compiler/tests/test_ir_generation.py
Expand Up @@ -2146,6 +2146,7 @@ def test_has_edge_degree_op_filter_with_fold(self):
),
blocks.MarkLocation(animal_location),
blocks.Fold(animal_fold),
blocks.MarkLocation(animal_fold),
blocks.Unfold(),
blocks.Backtrack(base_location),
blocks.ConstructResult({
Expand Down Expand Up @@ -2175,6 +2176,7 @@ def test_fold_on_output_variable(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_fold),
blocks.MarkLocation(base_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2203,6 +2205,7 @@ def test_fold_after_traverse(self):
blocks.Traverse('in', 'Animal_ParentOf'),
blocks.MarkLocation(parent_location),
blocks.Fold(parent_fold),
blocks.MarkLocation(parent_fold),
blocks.Unfold(),
blocks.Backtrack(base_location),
blocks.ConstructResult({
Expand Down Expand Up @@ -2231,7 +2234,9 @@ def test_fold_and_traverse(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(parent_fold),
blocks.MarkLocation(parent_fold),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(first_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2260,8 +2265,11 @@ def test_fold_and_deep_traverse(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(parent_fold),
blocks.MarkLocation(parent_fold),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(first_traversed_fold),
blocks.Traverse('out', 'Animal_OfSpecies'),
blocks.MarkLocation(second_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2293,7 +2301,9 @@ def test_traverse_and_fold_and_traverse(self):
blocks.Traverse('in', 'Animal_ParentOf'),
blocks.MarkLocation(parent_location),
blocks.Fold(sibling_fold),
blocks.MarkLocation(sibling_fold),
blocks.Traverse('out', 'Animal_OfSpecies'),
blocks.MarkLocation(sibling_species_fold),
blocks.Unfold(),
blocks.Backtrack(base_location),
blocks.ConstructResult({
Expand Down Expand Up @@ -2322,6 +2332,7 @@ def test_multiple_outputs_in_same_fold(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_fold),
blocks.MarkLocation(base_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2350,7 +2361,9 @@ def test_multiple_outputs_in_same_fold_and_traverse(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_fold),
blocks.MarkLocation(base_fold),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(first_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2380,8 +2393,10 @@ def test_multiple_folds(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_out_fold),
blocks.MarkLocation(base_out_fold),
blocks.Unfold(),
blocks.Fold(base_in_fold),
blocks.MarkLocation(base_in_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2416,10 +2431,14 @@ def test_multiple_folds_and_traverse(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_out_fold),
blocks.MarkLocation(base_out_fold),
blocks.Traverse('in', 'Animal_ParentOf'),
blocks.MarkLocation(base_out_traversed_fold),
blocks.Unfold(),
blocks.Fold(base_in_fold),
blocks.MarkLocation(base_in_fold),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(base_in_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2455,8 +2474,10 @@ def test_fold_date_and_datetime_fields(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_parent_fold),
blocks.MarkLocation(base_parent_fold),
blocks.Unfold(),
blocks.Fold(base_fed_at_fold),
blocks.MarkLocation(base_fed_at_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2487,6 +2508,7 @@ def test_coercion_to_union_base_type_inside_fold(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(important_event_fold),
blocks.MarkLocation(important_event_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2525,6 +2547,7 @@ def test_coercion_filters_and_multiple_outputs_within_fold_scope(self):
expressions.Variable('$latest', GraphQLDate)
)
),
blocks.MarkLocation(related_entity_fold),
blocks.Unfold(),
blocks.ConstructResult({
'related_animals': expressions.FoldedOutputContextField(
Expand Down Expand Up @@ -2553,6 +2576,7 @@ def test_coercion_filters_and_multiple_outputs_within_fold_traversal(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(parent_fold),
blocks.MarkLocation(parent_fold),
blocks.Traverse('out', 'Entity_Related'),
blocks.CoerceType({'Animal'}),
blocks.Filter(expressions.BinaryComposition(
Expand All @@ -2567,6 +2591,7 @@ def test_coercion_filters_and_multiple_outputs_within_fold_traversal(self):
expressions.Variable('$latest', GraphQLDate)
)
),
blocks.MarkLocation(inner_fold),
blocks.Unfold(),
blocks.ConstructResult({
'name': expressions.OutputContextField(
Expand Down Expand Up @@ -2596,6 +2621,7 @@ def test_no_op_coercion_inside_fold(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(related_entity_fold),
blocks.MarkLocation(related_entity_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2628,6 +2654,7 @@ def test_filter_within_fold_scope(self):
expressions.Variable('$desired', GraphQLString)
)
),
blocks.MarkLocation(base_parent_fold),
blocks.Unfold(),
blocks.ConstructResult({
'name': expressions.OutputContextField(
Expand Down Expand Up @@ -2670,6 +2697,7 @@ def test_filter_on_fold_scope(self):
)
)
),
blocks.MarkLocation(base_parent_fold),
blocks.Unfold(),
blocks.ConstructResult({
'name': expressions.OutputContextField(
Expand All @@ -2696,6 +2724,7 @@ def test_coercion_on_interface_within_fold_scope(self):
blocks.MarkLocation(base_location),
blocks.Fold(related_entity_fold),
blocks.CoerceType({'Animal'}),
blocks.MarkLocation(related_entity_fold),
blocks.Unfold(),
blocks.ConstructResult({
'name': expressions.OutputContextField(
Expand Down Expand Up @@ -2723,9 +2752,12 @@ def test_coercion_on_interface_within_fold_traversal(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_parent_fold),
blocks.MarkLocation(base_parent_fold),
blocks.Traverse('out', 'Entity_Related'),
blocks.CoerceType({'Animal'}),
blocks.MarkLocation(first_traversed_fold),
blocks.Traverse('out', 'Animal_OfSpecies'),
blocks.MarkLocation(second_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -2754,6 +2786,7 @@ def test_coercion_on_union_within_fold_scope(self):
blocks.MarkLocation(base_location),
blocks.Fold(important_event_fold),
blocks.CoerceType({'BirthEvent'}),
blocks.MarkLocation(important_event_fold),
blocks.Unfold(),
blocks.ConstructResult({
'name': expressions.OutputContextField(
Expand Down Expand Up @@ -3539,6 +3572,7 @@ def test_optional_and_fold(self):
blocks.Backtrack(base_location, optional=True),
blocks.MarkLocation(revisited_base_location),
blocks.Fold(fold_scope),
blocks.MarkLocation(fold_scope),
blocks.Unfold(),
blocks.ConstructResult({
'animal_name': expressions.OutputContextField(
Expand Down Expand Up @@ -3574,6 +3608,7 @@ def test_fold_and_optional(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_fold),
blocks.MarkLocation(base_fold),
blocks.Unfold(),
blocks.Traverse('in', 'Animal_ParentOf', optional=True),
blocks.MarkLocation(parent_location),
Expand Down Expand Up @@ -3624,7 +3659,9 @@ def test_optional_traversal_and_fold_traversal(self):
blocks.Backtrack(base_location, optional=True),
blocks.MarkLocation(revisited_base_location),
blocks.Fold(fold_scope),
blocks.MarkLocation(fold_scope),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(first_traversed_fold),
blocks.Unfold(),
blocks.ConstructResult({
'grandparent_name': expressions.TernaryConditional(
Expand Down Expand Up @@ -3664,7 +3701,9 @@ def test_fold_traversal_and_optional_traversal(self):
blocks.QueryRoot({'Animal'}),
blocks.MarkLocation(base_location),
blocks.Fold(base_fold),
blocks.MarkLocation(base_fold),
blocks.Traverse('out', 'Animal_ParentOf'),
blocks.MarkLocation(first_traversed_fold),
blocks.Unfold(),
blocks.Traverse('in', 'Animal_ParentOf', optional=True),
blocks.MarkLocation(parent_location),
Expand Down

0 comments on commit abb266a

Please sign in to comment.