Skip to content

Commit

Permalink
Orientdb @optional bugfix (#95)
Browse files Browse the repository at this point in the history
* Extracting simple optional info

* Adding WHERE clause

* Continuing WhereBlock, now using Filter

* Added SelectEdgeContextField and SelectWhereFilter

* Added docstrings

* Adding docstrings

* Updating MATCH in test_compiler

* Fixing all test cases

* Adding comments and fixing lint errors

* Expanding docstring and removing default argument

* Removing TODO

* Deterministic order for WHERE filters

* reformatting

* Removing SelectWhereFilter and addressing comments.

* Addressing comments.

* Renaming input for _split_ir_into_match_steps to pruned_ir_blocks

* Minor changes.

* Minor changes.

* Addressing comments.

* Addressing comments.

* Addressing comments.
  • Loading branch information
amartyashankha authored and obi1kenobi committed Jun 27, 2018
1 parent 1caafae commit 7b9e4b9
Show file tree
Hide file tree
Showing 12 changed files with 585 additions and 77 deletions.
8 changes: 8 additions & 0 deletions graphql_compiler/compiler/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,11 @@ class EndOptional(MarkerBlock):
def validate(self):
"""In isolation, EndOptional blocks are always valid."""
pass


class GlobalOperationsStart(MarkerBlock):
"""Marker block for the end of MATCH traversals, and the beginning of global operations."""

def validate(self):
"""In isolation, GlobalOperationsStart blocks are always valid."""
pass
13 changes: 13 additions & 0 deletions graphql_compiler/compiler/emit_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import six

from .blocks import Filter, QueryRoot, Recurse, Traverse
from .expressions import TrueLiteral
from .helpers import validate_safe_string


Expand Down Expand Up @@ -156,6 +157,14 @@ def _construct_output_to_match(output_block):
return u'SELECT %s FROM' % (u', '.join(selections),)


def _construct_where_to_match(where_block):
"""Transform a Filter block into a MATCH query string."""
if where_block.predicate == TrueLiteral:
raise AssertionError(u'Received WHERE block with TrueLiteral predicate: {}'
.format(where_block))
return u'WHERE ' + where_block.predicate.to_match()


##############
# Public API #
##############
Expand Down Expand Up @@ -197,6 +206,10 @@ def emit_code_from_single_match_query(match_query):
# Represent and add the SELECT clauses with the proper output data.
query_data.appendleft(_construct_output_to_match(match_query.output_block))

# Represent and add the WHERE clause with the proper filters.
if match_query.where_block is not None:
query_data.append(_construct_where_to_match(match_query.where_block))

return u' '.join(query_data)


Expand Down
51 changes: 49 additions & 2 deletions graphql_compiler/compiler/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from ..schema import GraphQLDate, GraphQLDateTime
from .compiler_entities import Expression
from .helpers import (STANDARD_DATE_FORMAT, STANDARD_DATETIME_FORMAT, FoldScopeLocation, Location,
ensure_unicode_string, is_graphql_type, safe_quoted_string,
strip_non_null_from_type, validate_safe_string)
ensure_unicode_string, is_graphql_type, is_vertex_field_name,
safe_quoted_string, strip_non_null_from_type, validate_safe_string)


# Since MATCH uses $-prefixed keywords to indicate special values,
Expand Down Expand Up @@ -235,6 +235,53 @@ def to_gremlin(self):
return u'{}.{}'.format(local_object_name, self.field_name)


class SelectEdgeContextField(Expression):
"""An edge field drawn from the global context, for use in a SELECT WHERE statement."""

def __init__(self, location):
"""Construct a new SelectEdgeContextField object that references an edge field.
Args:
location: Location, specifying where the field was declared.
The Location object must contain an edge field.
Returns:
new SelectEdgeContextField object
"""
super(SelectEdgeContextField, self).__init__(location)
self.location = location
self.validate()

def validate(self):
"""Validate that the SelectEdgeContextField is correctly representable."""
if not isinstance(self.location, Location):
raise TypeError(u'Expected Location location, got: {} {}'
.format(type(self.location).__name__, self.location))

if self.location.field is None:
raise AssertionError(u'Received Location without a field: {}'
.format(self.location))

if not is_vertex_field_name(self.location.field):
raise AssertionError(u'Received Location with a non-edge field: {}'
.format(self.location))

def to_match(self):
"""Return a unicode object with the MATCH representation of this SelectEdgeContextField."""
self.validate()

mark_name, field_name = self.location.get_location_name()
validate_safe_string(mark_name)
validate_safe_string(field_name)

return u'%s.%s' % (mark_name, field_name)

def to_gremlin(self):
"""Not implemented, should not be used."""
raise AssertionError(u'SelectEdgeContextField is only used for the WHERE statement in '
u'MATCH. This function should not be called.')


class ContextField(Expression):
"""A field drawn from the global context, e.g. if selected earlier in the query."""

Expand Down
52 changes: 52 additions & 0 deletions graphql_compiler/compiler/ir_lowering_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright 2017-present Kensho Technologies, LLC.
"""Language-independent IR lowering and optimization functions."""
import six

from .blocks import (ConstructResult, EndOptional, Filter, Fold, MarkLocation, Recurse, Traverse,
Unfold)
from .expressions import (BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral,
Expand Down Expand Up @@ -265,6 +267,56 @@ def extract_optional_location_root_info(ir_blocks):
return complex_optional_roots, location_to_optional_root


def extract_simple_optional_location_info(
ir_blocks, complex_optional_roots, location_to_optional_root):
"""Construct a map from simple optional locations to their inner location and traversed edge.
Args:
ir_blocks: list of IR blocks to extract optional data from
complex_optional_roots: list of @optional locations (location immmediately preceding
an @optional traverse) that expand vertex fields
location_to_optional_root: dict mapping from location -> optional_root where location is
within @optional (not necessarily one that expands vertex fields)
and optional_root is the location preceding the corresponding
@optional scope
Returns:
dict mapping from simple_optional_root_location -> dict containing keys
- 'inner_location_name': Location object correspoding to the unique MarkLocation present
within a simple optional (one that does not expand vertex fields)
scope
- 'edge_field': string representing the optional edge being traversed
where simple_optional_root_to_inner_location is the location preceding the @optional scope
"""
# Simple optional roots are a subset of location_to_optional_root.values() (all optional roots).
# We filter out the ones that are also present in complex_optional_roots.
simple_optional_root_to_inner_location = {
optional_root_location: inner_location
for inner_location, optional_root_location in six.iteritems(location_to_optional_root)
if optional_root_location not in complex_optional_roots
}
simple_optional_root_locations = set(simple_optional_root_to_inner_location.keys())

simple_optional_root_info = {}
preceding_location = None
for current_block in ir_blocks:
if isinstance(current_block, MarkLocation):
preceding_location = current_block.location
elif isinstance(current_block, Traverse) and current_block.optional:
if preceding_location in simple_optional_root_locations:
# The current optional Traverse is "simple"
# i.e. it does not contain any Traverses within.
inner_location = simple_optional_root_to_inner_location[preceding_location]
inner_location_name, _ = inner_location.get_location_name()
simple_optional_info_dict = {
'inner_location_name': inner_location_name,
'edge_field': current_block.get_field_name(),
}
simple_optional_root_info[preceding_location] = simple_optional_info_dict

return simple_optional_root_info


def remove_end_optionals(ir_blocks):
"""Return a list of IR blocks as a copy of the original, with EndOptional blocks removed."""
new_ir_blocks = []
Expand Down
17 changes: 15 additions & 2 deletions graphql_compiler/compiler/ir_lowering_match/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright 2018-present Kensho Technologies, LLC.
import six

from ..blocks import Filter, GlobalOperationsStart
from ..ir_lowering_common import (extract_optional_location_root_info,
extract_simple_optional_location_info,
lower_context_field_existence, merge_consecutive_filter_clauses,
optimize_boolean_expression_comparisons, remove_end_optionals)
from .ir_lowering import (lower_backtrack_blocks,
Expand All @@ -17,12 +19,13 @@
lower_context_field_expressions, prune_non_existent_outputs)
from ..match_query import convert_to_match_query
from ..workarounds import orientdb_class_with_while, orientdb_eval_scheduling

from .utils import construct_where_filter_predicate

##############
# Public API #
##############


def lower_ir(ir_blocks, location_types, type_equivalence_hints=None):
"""Lower the IR into an IR form that can be represented in MATCH queries.
Expand Down Expand Up @@ -50,11 +53,21 @@ def lower_ir(ir_blocks, location_types, type_equivalence_hints=None):
"""
sanity_check_ir_blocks_from_frontend(ir_blocks)

# These lowering / optimization passes work on IR blocks.
# Extract information for both simple and complex @optional traverses
location_to_optional_results = extract_optional_location_root_info(ir_blocks)
complex_optional_roots, location_to_optional_root = location_to_optional_results
simple_optional_root_info = extract_simple_optional_location_info(
ir_blocks, complex_optional_roots, location_to_optional_root)
ir_blocks = remove_end_optionals(ir_blocks)

# Append global operation block(s) to filter out incorrect results
# from simple optional match traverses (using a WHERE statement)
if len(simple_optional_root_info) > 0:
where_filter_predicate = construct_where_filter_predicate(simple_optional_root_info)
ir_blocks.insert(-1, GlobalOperationsStart())
ir_blocks.insert(-1, Filter(where_filter_predicate))

# These lowering / optimization passes work on IR blocks.
ir_blocks = lower_context_field_existence(ir_blocks)
ir_blocks = optimize_boolean_expression_comparisons(ir_blocks)
ir_blocks = rewrite_binary_composition_inside_ternary_conditional(ir_blocks)
Expand Down
1 change: 1 addition & 0 deletions graphql_compiler/compiler/ir_lowering_match/ir_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# Optimization / lowering passes #
##################################


def rewrite_binary_composition_inside_ternary_conditional(ir_blocks):
"""Rewrite BinaryConditional expressions in the true/false values of TernaryConditionals."""
def visitor_fn(expression):
Expand Down
54 changes: 18 additions & 36 deletions graphql_compiler/compiler/ir_lowering_match/optional_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,11 @@

from ..blocks import ConstructResult, Filter, Traverse
from ..expressions import (BinaryComposition, ContextField, FoldedOutputContextField, Literal,
LocalField, NullLiteral, OutputContextField, TernaryConditional,
TrueLiteral, UnaryTransformation, Variable, ZeroLiteral)
LocalField, OutputContextField, TernaryConditional, TrueLiteral,
UnaryTransformation, Variable)
from ..match_query import MatchQuery, MatchStep
from .utils import BetweenClause, CompoundMatchQuery


def _filter_local_edge_field_non_existence(field_name):
"""Return an Expression that is True iff the specified edge (field_name) does not exist."""
# When an edge does not exist at a given vertex, OrientDB represents that in one of two ways:
# - the edge's field does not exist (is null) on the vertex document, or
# - the edge's field does exist, but is an empty list.
# We check both of these possibilities.
local_field = LocalField(field_name)

field_null_check = BinaryComposition(u'=', local_field, NullLiteral)

local_field_size = UnaryTransformation(u'size', local_field)
field_size_check = BinaryComposition(u'=', local_field_size, ZeroLiteral)

return BinaryComposition(u'||', field_null_check, field_size_check)
from .utils import (BetweenClause, CompoundMatchQuery, expression_list_to_conjunction,
filter_edge_field_non_existence)


def _prune_traverse_using_omitted_locations(match_traversal, omitted_locations,
Expand Down Expand Up @@ -63,7 +48,7 @@ def _prune_traverse_using_omitted_locations(match_traversal, omitted_locations,
elif optional_root_location in omitted_locations:
# Add filter to indicate that the omitted edge(s) shoud not exist
field_name = step.root_block.get_field_name()
new_predicate = _filter_local_edge_field_non_existence(field_name)
new_predicate = filter_edge_field_non_existence(LocalField(field_name))
old_filter = new_match_traversal[-1].where_block
if old_filter is not None:
new_predicate = BinaryComposition(u'&&', old_filter.predicate, new_predicate)
Expand Down Expand Up @@ -146,6 +131,7 @@ def convert_optional_traversals_to_compound_match_query(
match_traversals=match_traversals,
folds=match_query.folds,
output_block=match_query.output_block,
where_block=match_query.where_block,
)
for match_traversals in compound_match_traversals
]
Expand Down Expand Up @@ -216,7 +202,6 @@ def prune_non_existent_outputs(compound_match_query):
for match_query in compound_match_query.match_queries:
match_traversals = match_query.match_traversals
output_block = match_query.output_block
folds = match_query.folds

present_locations_tuple = _get_present_locations(match_traversals)
present_locations, present_non_optional_locations = present_locations_tuple
Expand Down Expand Up @@ -260,8 +245,9 @@ def prune_non_existent_outputs(compound_match_query):
match_queries.append(
MatchQuery(
match_traversals=match_traversals,
folds=folds,
output_block=ConstructResult(new_output_fields)
folds=match_query.folds,
output_block=ConstructResult(new_output_fields),
where_block=match_query.where_block,
)
)

Expand Down Expand Up @@ -296,18 +282,12 @@ def _construct_location_to_filter_list(match_query):
def _filter_list_to_conjunction_expression(filter_list):
"""Convert a list of filters to an Expression that is the conjunction of all of them."""
if not isinstance(filter_list, list):
raise AssertionError(u'Expected `list`, Received {}.'.format(filter_list))
raise AssertionError(u'Expected `list`, Received: {}.'.format(filter_list))
if any((not isinstance(filter_block, Filter) for filter_block in filter_list)):
raise AssertionError(u'Expected list of Filter objects. Received: {}'.format(filter_list))

if not isinstance(filter_list[0], Filter):
raise AssertionError(u'Non-Filter object {} found in filter_list'
.format(filter_list[0]))

if len(filter_list) == 1:
return filter_list[0].predicate
else:
return BinaryComposition(u'&&',
_filter_list_to_conjunction_expression(filter_list[1:]),
filter_list[0].predicate)
expression_list = [filter_block.predicate for filter_block in filter_list]
return expression_list_to_conjunction(expression_list)


def _apply_filters_to_first_location_occurrence(match_traversal, location_to_filters,
Expand Down Expand Up @@ -402,7 +382,8 @@ def collect_filters_to_first_location_occurrence(compound_match_query):
MatchQuery(
match_traversals=new_match_traversals,
folds=match_query.folds,
output_block=match_query.output_block
output_block=match_query.output_block,
where_block=match_query.where_block,
)
)

Expand Down Expand Up @@ -572,7 +553,8 @@ def lower_context_field_expressions(compound_match_query):
MatchQuery(
match_traversals=new_match_traversals,
folds=match_query.folds,
output_block=match_query.output_block
output_block=match_query.output_block,
where_block=match_query.where_block,
)
)

Expand Down
Loading

0 comments on commit 7b9e4b9

Please sign in to comment.