Skip to content

Commit

Permalink
Update sqlalchemy schema generation branch (#333)
Browse files Browse the repository at this point in the history
* Add indexes (#312)

* Added indexes to the SchemaGraph

* Add additional index methods and remove type field

* Nits

* Nits

* More nits

* More nits

* Changed all_indexes to indexes

* Revert "Changed all_indexes to indexes"

This reverts commit 44ffd00.

* Nit

* Addressed code review comments

* Prohibit null value ignoring for edge indexes

* Fix typo

* Stop overriding ignore_nulls for edge indexes to true

* Fix typo

* Refactor inheritance structure constructor (#314)

* Add immediate superclasses dict

* Make toposorting not use OrientDB constructs.

* Refactored InheritanceStructure constructor

* Move inheritance structure class and related methods

* Fixed docstring

* Added documentation for InheritanceStructure class

* Nits

* Return validation to its original place

* Fix terminology

* More nits

* More nits

* More nits

* Re-add list coercion

* Add indexes to test SchemaGraph (#315)

* Add indexes to test SchemaGraph

* Changed sets arguments to lists and remove is None check

* Add output info to query metadata (#311)

* Add output info to query metadata

* Update explain info tests

* Add output tests to explain_info

* Remove context output

* Add docstring metadata outputs property

* Use query_metadata_table variable and name

* Replace more names with query_metadata_table

* Remove fold from OutputInfo

* Change set input for list input (#318)

* Fix index documentation (#321)

* Fix index documentation

* Nit

* Refactor and clean up IR lowering code. Add common IR lowering module. (#324)

* Refactor and clean up IR lowering code. Add common IR lowering operations module.

* Delint.

* Improve expression comment around null values.

* Add defensive strip_non_null_from_type() call.

* Fix inheritance of SQLAlchemy vertices
  • Loading branch information
pmantica1 authored and bojanserafimov committed May 30, 2019
1 parent 258577e commit 1f91f84
Show file tree
Hide file tree
Showing 23 changed files with 811 additions and 403 deletions.
2 changes: 1 addition & 1 deletion graphql_compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,6 @@ def get_graphql_schema_from_orientdb_schema_data(schema_data, class_to_field_typ
if hidden_classes is None:
hidden_classes = set()

schema_graph = get_orientdb_schema_graph(schema_data)
schema_graph = get_orientdb_schema_graph(schema_data, [])
return get_graphql_schema_from_schema_graph(schema_graph, class_to_field_type_overrides,
hidden_classes)
71 changes: 38 additions & 33 deletions graphql_compiler/compiler/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
get_vertex_field_type, invert_dict, is_tagged_parameter, is_vertex_field_name,
strip_non_null_from_type, validate_output_name, validate_safe_string
)
from .metadata import LocationInfo, QueryMetadataTable, RecurseInfo, TagInfo
from .metadata import LocationInfo, OutputInfo, QueryMetadataTable, RecurseInfo, TagInfo


# LocationStackEntry contains the following:
Expand Down Expand Up @@ -298,7 +298,7 @@ def _compile_property_ast(schema, current_schema_type, ast, location,
if output_directive:
# Schema validation has ensured that the fields below exist.
output_name = output_directive.arguments[0].value.value
if output_name in context['outputs']:
if context['metadata'].get_output_info(output_name):
raise GraphQLCompilationError(u'Cannot reuse output name: '
u'{}, {}'.format(output_name, context))
validate_safe_string(output_name)
Expand All @@ -312,12 +312,12 @@ def _compile_property_ast(schema, current_schema_type, ast, location,
if location.field != COUNT_META_FIELD_NAME:
graphql_type = GraphQLList(graphql_type)

context['outputs'][output_name] = {
'location': location,
'optional': is_in_optional_scope(context),
'type': graphql_type,
'fold': context.get('fold', None),
}
output_info = OutputInfo(
location=location,
type=graphql_type,
optional=is_in_optional_scope(context),
)
context['metadata'].record_output_info(output_name, output_info)


def _get_recurse_directive_depth(field_name, field_directives):
Expand Down Expand Up @@ -524,7 +524,7 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
if edge_traversal_is_folded:
has_count_filter = has_fold_count_filter(context)
_validate_fold_has_outputs_or_count_filter(
get_context_fold_info(context), has_count_filter, context['outputs'])
get_context_fold_info(context), has_count_filter, query_metadata_table)
basic_blocks.append(blocks.Unfold())
unmark_context_fold_scope(context)
if has_count_filter:
Expand Down Expand Up @@ -560,7 +560,19 @@ def _compile_vertex_ast(schema, current_schema_type, ast,
return basic_blocks


def _validate_fold_has_outputs_or_count_filter(fold_scope_location, fold_has_count_filter, outputs):
def _are_locations_in_same_fold(first_location, second_location):
    """Return True if both locations are contained within the same fold scope."""
    # Anything that is not a FoldScopeLocation cannot be inside a fold scope at all.
    if not isinstance(first_location, FoldScopeLocation):
        return False
    if not isinstance(second_location, FoldScopeLocation):
        return False

    # Two fold locations are in the same fold scope exactly when they share
    # both the base location and the first folded edge.
    same_base = first_location.base_location == second_location.base_location
    same_first_edge = (
        first_location.get_first_folded_edge() == second_location.get_first_folded_edge()
    )
    return same_base and same_first_edge


def _validate_fold_has_outputs_or_count_filter(
fold_scope_location, fold_has_count_filter, query_metadata_table
):
"""Ensure the @fold scope has at least one output, or filters on the size of the fold."""
# This function makes sure that the @fold scope has an effect.
# Folds either output data, or filter the data enclosing the fold based on the size of the fold.
Expand All @@ -570,8 +582,8 @@ def _validate_fold_has_outputs_or_count_filter(fold_scope_location, fold_has_cou

# At least one output in the outputs list must point to the fold_scope_location,
# or the scope corresponding to fold_scope_location had no @outputs and is illegal.
for output in six.itervalues(outputs):
if output['fold'] == fold_scope_location:
for _, output_info in query_metadata_table.outputs:
if _are_locations_in_same_fold(output_info.location, fold_scope_location):
return True

raise GraphQLCompilationError(u'Found a @fold scope that has no effect on the query. '
Expand Down Expand Up @@ -798,13 +810,6 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
# query processing, but apply to the global query scope and should be appended to the
# IR blocks only after the GlobalOperationsStart block has been emitted.
'global_filters': [],
# 'outputs' is a dict mapping each output name to another dict which contains
# - location: Location where to output from
# - optional: boolean representing whether the output was defined within an @optional scope
# - type: GraphQLType of the output
# - fold: FoldScopeLocation object if the current output was defined within a fold scope,
# and None otherwise
'outputs': dict(),
# 'inputs' is a dict mapping input parameter names to their respective expected GraphQL
# types, as automatically inferred by inspecting the query structure
'inputs': dict(),
Expand Down Expand Up @@ -853,11 +858,10 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
basic_blocks.extend(context['global_filters'])

# Based on the outputs context data, add an output step and construct the output metadata.
outputs_context = context['outputs']
basic_blocks.append(_compile_output_step(outputs_context))
basic_blocks.append(_compile_output_step(query_metadata_table))
output_metadata = {
name: OutputMetadata(type=value['type'], optional=value['optional'])
for name, value in six.iteritems(outputs_context)
name: OutputMetadata(type=info.type, optional=info.optional)
for name, info in query_metadata_table.outputs
}

return IrAndMetadata(
Expand All @@ -867,34 +871,35 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None):
query_metadata_table=context['metadata'])


def _compile_output_step(outputs):
def _compile_output_step(query_metadata_table):
"""Construct the final ConstructResult basic block that defines the output format of the query.
Args:
outputs: dict, output name (string) -> output data dict, specifying the location
from where to get the data, and whether the data is optional (and therefore
may be missing); missing optional data is replaced with 'null'
query_metadata_table: QueryMetadataTable object, part of which specifies the location from
where to get the output, and whether the output is optional (and
therefore may be missing); missing optional data is replaced with
'null'
Returns:
a ConstructResult basic block that constructs appropriate outputs for the query
"""
if not outputs:
if next(query_metadata_table.outputs, None) is None:
raise GraphQLCompilationError(u'No fields were selected for output! Please mark at least '
u'one field with the @output directive.')

output_fields = {}
for output_name, output_context in six.iteritems(outputs):
location = output_context['location']
optional = output_context['optional']
graphql_type = output_context['type']
for output_name, output_info in query_metadata_table.outputs:
location = output_info.location
optional = output_info.optional
graphql_type = output_info.type

expression = None
existence_check = None
# pylint: disable=redefined-variable-type
if isinstance(location, FoldScopeLocation):
if optional:
raise AssertionError(u'Unreachable state reached, optional in fold: '
u'{}'.format(output_context))
u'{}'.format(output_info))

if location.field == COUNT_META_FIELD_NAME:
expression = expressions.FoldCountContextField(location)
Expand Down
16 changes: 10 additions & 6 deletions graphql_compiler/compiler/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def to_match(self):

def to_gremlin(self):
"""Return a unicode object with the Gremlin representation of this expression."""
self.validate()

# We can't directly pass a Date or a DateTime object, so we have to pass it as a string
# and then parse it inline. For date format parameter meanings, see:
# http://docs.oracle.com/javase/7/docs/api/java/text/SimpleDateFormat.html
Expand Down Expand Up @@ -559,8 +561,9 @@ def to_match(self):
return template % template_data

def to_gremlin(self):
"""Must never be called."""
raise NotImplementedError()
"""Not implemented, should not be used."""
raise AssertionError(u'FoldedContextField are not used during the query emission process '
u'in Gremlin, so this is a bug. This function should not be called.')

def __eq__(self, other):
"""Return True if the given object is equal to this one, and False otherwise."""
Expand Down Expand Up @@ -618,7 +621,7 @@ def to_match(self):
return template % template_data

def to_gremlin(self):
"""Must never be called."""
"""Not supported yet."""
raise NotImplementedError()


Expand Down Expand Up @@ -799,9 +802,9 @@ def to_match(self):
intersects_operator_format = '(%(operator)s(%(left)s, %(right)s).asList().size() > 0)'
# pylint: enable=unused-variable

# Null literals use 'is/is not' as (in)equality operators, while other values use '=/<>'.
if any((isinstance(self.left, Literal) and self.left.value is None,
isinstance(self.right, Literal) and self.right.value is None)):
# Null literals use the OrientDB 'IS/IS NOT' (in)equality operators,
# while other values use the OrientDB '=/<>' operators.
if self.left == NullLiteral or self.right == NullLiteral:
translation_table = {
u'=': (u'IS', regular_operator_format),
u'!=': (u'IS NOT', regular_operator_format),
Expand Down Expand Up @@ -947,6 +950,7 @@ def visitor_fn(expression):
def to_gremlin(self):
"""Return a unicode object with the Gremlin representation of this expression."""
self.validate()

return u'({predicate} ? {if_true} : {if_false})'.format(
predicate=self.predicate.to_gremlin(),
if_true=self.if_true.to_gremlin(),
Expand Down
1 change: 1 addition & 0 deletions graphql_compiler/compiler/ir_lowering_common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Copyright 2019-present Kensho Technologies, LLC.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
"""Language-independent IR lowering and optimization functions."""
import six

from .blocks import (
from ..blocks import (
ConstructResult, EndOptional, Filter, Fold, MarkLocation, Recurse, Traverse, Unfold
)
from .expressions import (
from ..expressions import (
BinaryComposition, ContextField, ContextFieldExistence, FalseLiteral, NullLiteral, TrueLiteral
)
from .helpers import validate_safe_string
from ..helpers import validate_safe_string


def merge_consecutive_filter_clauses(ir_blocks):
Expand Down
67 changes: 67 additions & 0 deletions graphql_compiler/compiler/ir_lowering_common/location_renaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2019-present Kensho Technologies, LLC.
"""Utilities for rewriting IR to replace one set of locations with another."""
import six

from ..helpers import FoldScopeLocation, Location


def make_revisit_location_translations(query_metadata_table):
    """Return a dict mapping location revisits to the location being revisited, for rewriting."""
    # Pair every registered location with its revisit origin, then keep only the
    # pairs where the two differ -- those are the revisits that need rewriting.
    origin_pairs = (
        (location, query_metadata_table.get_revisit_origin(location))
        for location, _ in query_metadata_table.registered_locations
    )
    return {
        location: origin
        for location, origin in origin_pairs
        if origin != location
    }


def translate_potential_location(location_translations, potential_location):
    """If the input is a BaseLocation object, translate it, otherwise return it as-is."""
    if isinstance(potential_location, Location):
        vertex_location = potential_location.at_vertex()
        translated = location_translations.get(vertex_location, None)
        if translated is None:
            # No translation needed.
            return potential_location

        # If necessary, add the field component to the new location before returning it.
        field = potential_location.field
        if field is None:
            return translated
        return translated.navigate_to_field(field)

    if isinstance(potential_location, FoldScopeLocation):
        # Only the base location of a fold scope can require translation;
        # the fold path and field are carried over unchanged.
        base_location = potential_location.base_location
        translated_base = location_translations.get(base_location, base_location)
        return FoldScopeLocation(
            translated_base, potential_location.fold_path, field=potential_location.field)

    # Not a location object of any kind -- pass it through untouched.
    return potential_location


def make_location_rewriter_visitor_fn(location_translations):
    """Return a visitor function that is able to replace locations with equivalent locations."""
    def visitor_fn(expression):
        """Expression visitor function used to rewrite expressions with updated Location data."""
        def _translate(value):
            """Translate a single constructor argument, if it is a location."""
            return translate_potential_location(location_translations, value)

        # All CompilerEntity objects store their exact constructor input args/kwargs.
        # To minimize the chances that we forget to update a location somewhere in an expression,
        # we rewrite all locations that we find as arguments to expression constructors.
        # pylint: disable=protected-access
        rewritten_args = [_translate(arg) for arg in expression._print_args]
        rewritten_kwargs = dict(
            (kwarg_name, _translate(kwarg_value))
            for kwarg_name, kwarg_value in six.iteritems(expression._print_kwargs)
        )
        # pylint: enable=protected-access

        return type(expression)(*rewritten_args, **rewritten_kwargs)

    return visitor_fn
10 changes: 6 additions & 4 deletions graphql_compiler/compiler/ir_lowering_gremlin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright 2018-present Kensho Technologies, LLC.
from .ir_lowering import (lower_coerce_type_block_type_data, lower_coerce_type_blocks,
lower_folded_outputs, rewrite_filters_in_optional_blocks)
lower_folded_outputs_and_context_fields,
rewrite_filters_in_optional_blocks)
from ..ir_sanity_checks import sanity_check_ir_blocks_from_frontend
from ..ir_lowering_common import (lower_context_field_existence, merge_consecutive_filter_clauses,
optimize_boolean_expression_comparisons)
from ..ir_lowering_common.common import (lower_context_field_existence,
merge_consecutive_filter_clauses,
optimize_boolean_expression_comparisons)


##############
Expand Down Expand Up @@ -48,6 +50,6 @@ def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None):
ir_blocks = lower_coerce_type_blocks(ir_blocks)
ir_blocks = rewrite_filters_in_optional_blocks(ir_blocks)
ir_blocks = merge_consecutive_filter_clauses(ir_blocks)
ir_blocks = lower_folded_outputs(ir_blocks)
ir_blocks = lower_folded_outputs_and_context_fields(ir_blocks)

return ir_blocks

0 comments on commit 1f91f84

Please sign in to comment.