diff --git a/graphql_compiler/__init__.py b/graphql_compiler/__init__.py index d1a43c9cc..9c7edf44d 100644 --- a/graphql_compiler/__init__.py +++ b/graphql_compiler/__init__.py @@ -1,7 +1,12 @@ # Copyright 2017-present Kensho Technologies, LLC. """Commonly-used functions and data types from this package.""" -from .compiler import CompilationResult, OutputMetadata # noqa -from .compiler import compile_graphql_to_gremlin, compile_graphql_to_match # noqa +from .compiler import ( # noqa + CompilationResult, + OutputMetadata, + compile_graphql_to_gremlin, + compile_graphql_to_match, + compile_graphql_to_sql, +) from .query_formatting import insert_arguments_into_query # noqa from .query_formatting.graphql_formatting import pretty_print_graphql # noqa @@ -22,7 +27,7 @@ def graphql_to_match(schema, graphql_query, parameters, type_equivalence_hints=N Args: schema: GraphQL schema object describing the schema of the graph to be queried - graphql_string: the GraphQL query to compile to MATCH, as a string + graphql_query: the GraphQL query to compile to MATCH, as a string parameters: dict, mapping argument name to its value, for every parameter the query expects. type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. Used as a workaround for GraphQL's lack of support for @@ -53,12 +58,48 @@ def graphql_to_match(schema, graphql_query, parameters, type_equivalence_hints=N query=insert_arguments_into_query(compilation_result, parameters)) +def graphql_to_sql(schema, graphql_query, parameters, compiler_metadata, + type_equivalence_hints=None): + """Compile the GraphQL input using the schema into a SQL query and associated metadata. + + Args: + schema: GraphQL schema object describing the schema of the graph to be queried + graphql_query: the GraphQL query to compile to SQL, as a string + parameters: dict, mapping argument name to its value, for every parameter the query expects. + compiler_metadata: CompilerMetadata object that provides SQLAlchemy-specific backend + information + type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. + Used as a workaround for GraphQL's lack of support for + inheritance across "types" (i.e. non-interfaces), as well as a + workaround for Gremlin's total lack of inheritance-awareness. + The key-value pairs in the dict specify that the "key" type + is equivalent to the "value" type, i.e. that the GraphQL type or + interface in the key is the most-derived common supertype + of every GraphQL type in the "value" GraphQL union. + Recursive expansion of type equivalence hints is not performed, + and only type-level correctness of this argument is enforced. + See README.md for more details on everything this parameter does. + ***** + Be very careful with this option, as bad input here will + lead to incorrect output queries being generated. + ***** + + Returns: + a CompilationResult object, containing: + - query: string, the resulting compiled and parameterized query string + - language: string, specifying the language to which the query was compiled + - output_metadata: dict, output name -> OutputMetadata namedtuple object + - input_metadata: dict, name of input variables -> inferred GraphQL type, based on use + """ + raise NotImplementedError(u'Compiling GraphQL to SQL is not yet supported.') + + def graphql_to_gremlin(schema, graphql_query, parameters, type_equivalence_hints=None): """Compile the GraphQL input using the schema into a Gremlin query and associated metadata.
Args: schema: GraphQL schema object describing the schema of the graph to be queried - graphql_string: the GraphQL query to compile to Gremlin, as a string + graphql_query: the GraphQL query to compile to Gremlin, as a string parameters: dict, mapping argument name to its value, for every parameter the query expects. type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. Used as a workaround for GraphQL's lack of support for diff --git a/graphql_compiler/compiler/__init__.py b/graphql_compiler/compiler/__init__.py index 8321707d4..0521570fb 100644 --- a/graphql_compiler/compiler/__init__.py +++ b/graphql_compiler/compiler/__init__.py @@ -1,4 +1,9 @@ # Copyright 2017-present Kensho Technologies, LLC. -from .common import CompilationResult, compile_graphql_to_gremlin, compile_graphql_to_match # noqa -from .common import GREMLIN_LANGUAGE, MATCH_LANGUAGE # noqa +from .common import ( # noqa + CompilationResult, + compile_graphql_to_gremlin, + compile_graphql_to_match, + compile_graphql_to_sql, +) +from .common import GREMLIN_LANGUAGE, MATCH_LANGUAGE, SQL_LANGUAGE # noqa from .compiler_frontend import OutputMetadata # noqa diff --git a/graphql_compiler/compiler/blocks.py b/graphql_compiler/compiler/blocks.py index ce9286b71..d25927113 100644 --- a/graphql_compiler/compiler/blocks.py +++ b/graphql_compiler/compiler/blocks.py @@ -13,6 +13,8 @@ class QueryRoot(BasicBlock): """The starting object of the query to be compiled.""" + __slots__ = ('start_class',) + def __init__(self, start_class): """Construct a QueryRoot object that starts querying at the specified class name. @@ -56,6 +58,8 @@ def to_gremlin(self): class CoerceType(BasicBlock): """A special type of filter that discards any data that is not of the specified set of types.""" + __slots__ = ('target_class',) + def __init__(self, target_class): """Construct a CoerceType object that filters out any data that is not of the given types. @@ -91,6 +95,8 @@ def to_gremlin(self): class ConstructResult(BasicBlock): """A transformation of the data into a new form, for output.""" + __slots__ = ('fields',) + def __init__(self, fields): """Construct a ConstructResult object that maps the given field names to their expressions. @@ -157,6 +163,8 @@ def to_gremlin(self): class Filter(BasicBlock): """A filter that ensures data matches a predicate expression, and discards all other data.""" + __slots__ = ('predicate',) + def __init__(self, predicate): """Create a new Filter with the specified Expression as a predicate.""" super(Filter, self).__init__(predicate) @@ -186,6 +194,8 @@ def to_gremlin(self): class MarkLocation(BasicBlock): """A block that assigns a name to a given location in the query.""" + __slots__ = ('location',) + def __init__(self, location): """Create a new MarkLocation at the specified Location. @@ -213,6 +223,8 @@ def to_gremlin(self): class Traverse(BasicBlock): """A block that encodes a traversal across an edge, in either direction.""" + __slots__ = ('direction', 'edge_name', 'optional', 'within_optional_scope') + def __init__(self, direction, edge_name, optional=False, within_optional_scope=False): """Create a new Traverse block in the given direction and across the given edge. 
@@ -296,6 +308,8 @@ def to_gremlin(self): class Recurse(BasicBlock): """A block for recursive traversal of an edge, collecting all endpoints along the way.""" + __slots__ = ('direction', 'edge_name', 'depth', 'within_optional_scope') + def __init__(self, direction, edge_name, depth, within_optional_scope=False): """Create a new Recurse block which traverses the given edge up to "depth" times. @@ -360,6 +374,8 @@ def to_gremlin(self): class Backtrack(BasicBlock): """A block that specifies a return to a given Location in the query.""" + __slots__ = ('location', 'optional') + def __init__(self, location, optional=False): """Create a new Backtrack block, returning to the given location in the query. @@ -409,6 +425,8 @@ class OutputSource(MarkerBlock): See the comment on the @output_source directive in schema.py on why this is necessary. """ + __slots__ = () + def validate(self): """Validate the OutputSource block. An OutputSource block is always valid in isolation.""" pass @@ -417,6 +435,8 @@ def validate(self): class Fold(MarkerBlock): """A marker for the start of a @fold context.""" + __slots__ = ('fold_scope_location',) + def __init__(self, fold_scope_location): """Create a new Fold block rooted at the given location.""" super(Fold, self).__init__(fold_scope_location) @@ -433,6 +453,8 @@ def validate(self): class Unfold(MarkerBlock): """A marker for the end of a @fold context.""" + __slots__ = () + def validate(self): """Unfold blocks are always valid in isolation.""" pass @@ -444,6 +466,8 @@ class EndOptional(MarkerBlock): Optional scope is entered through an optional Traverse Block. """ + __slots__ = () + def validate(self): """In isolation, EndOptional blocks are always valid.""" pass @@ -452,6 +476,8 @@ def validate(self): class GlobalOperationsStart(MarkerBlock): """Marker block for the end of MATCH traversals, and the beginning of global operations.""" + __slots__ = () + def validate(self): """In isolation, GlobalOperationsStart blocks are always valid.""" pass diff --git a/graphql_compiler/compiler/common.py b/graphql_compiler/compiler/common.py index 380667ddc..18938cad8 100644 --- a/graphql_compiler/compiler/common.py +++ b/graphql_compiler/compiler/common.py @@ -15,6 +15,7 @@ MATCH_LANGUAGE = 'MATCH' GREMLIN_LANGUAGE = 'Gremlin' +SQL_LANGUAGE = 'SQL' def compile_graphql_to_match(schema, graphql_string, type_equivalence_hints=None): @@ -47,7 +48,7 @@ def compile_graphql_to_match(schema, graphql_string, type_equivalence_hints=None return _compile_graphql_generic( MATCH_LANGUAGE, lowering_func, query_emitter_func, - schema, graphql_string, type_equivalence_hints) + schema, graphql_string, type_equivalence_hints, None) def compile_graphql_to_gremlin(schema, graphql_string, type_equivalence_hints=None): @@ -80,12 +81,55 @@ def compile_graphql_to_gremlin(schema, graphql_string, type_equivalence_hints=No return _compile_graphql_generic( GREMLIN_LANGUAGE, lowering_func, query_emitter_func, - schema, graphql_string, type_equivalence_hints) + schema, graphql_string, type_equivalence_hints, None) + + +def compile_graphql_to_sql(schema, graphql_string, compiler_metadata, type_equivalence_hints=None): + """Compile the GraphQL input using the schema into a SQL query and associated metadata. + + Args: + schema: GraphQL schema object describing the schema of the graph to be queried + graphql_string: the GraphQL query to compile to SQL, as a string + compiler_metadata: SQLAlchemy metadata containing tables for use during compilation. 
+ type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. + Used as a workaround for GraphQL's lack of support for + inheritance across "types" (i.e. non-interfaces), as well as a + workaround for Gremlin's total lack of inheritance-awareness. + The key-value pairs in the dict specify that the "key" type + is equivalent to the "value" type, i.e. that the GraphQL type or + interface in the key is the most-derived common supertype + of every GraphQL type in the "value" GraphQL union. + Recursive expansion of type equivalence hints is not performed, + and only type-level correctness of this argument is enforced. + See README.md for more details on everything this parameter does. + ***** + Be very careful with this option, as bad input here will + lead to incorrect output queries being generated. + ***** + + Returns: + a CompilationResult object + """ + raise NotImplementedError(u'Compiling GraphQL to SQL is not yet supported.') def _compile_graphql_generic(language, lowering_func, query_emitter_func, - schema, graphql_string, type_equivalence_hints): - """Compile the GraphQL input, lowering and emitting the query using the given functions.""" + schema, graphql_string, type_equivalence_hints, compiler_metadata): + """Compile the GraphQL input, lowering and emitting the query using the given functions. + + Args: + language: string indicating the target language to compile to. + lowering_func: function that lowers the compiler IR into a form compatible with the target + language backend. + query_emitter_func: function that emits a query in the target language from the lowered IR. + schema: GraphQL schema object describing the schema of the graph to be queried. + graphql_string: the GraphQL query to compile to the target language, as a string. + type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. + compiler_metadata: optional target-specific metadata for use by the query_emitter_func. + + Returns: + a CompilationResult object + """ ir_and_metadata = graphql_to_ir( schema, graphql_string, type_equivalence_hints=type_equivalence_hints) @@ -93,7 +137,7 @@ def _compile_graphql_generic(language, lowering_func, query_emitter_func, ir_and_metadata.ir_blocks, ir_and_metadata.query_metadata_table, type_equivalence_hints=type_equivalence_hints) - query = query_emitter_func(lowered_ir_blocks) + query = query_emitter_func(lowered_ir_blocks, compiler_metadata) return CompilationResult( query=query, diff --git a/graphql_compiler/compiler/compiler_entities.py b/graphql_compiler/compiler/compiler_entities.py index 64603b901..97cee556b 100644 --- a/graphql_compiler/compiler/compiler_entities.py +++ b/graphql_compiler/compiler/compiler_entities.py @@ -11,6 +11,8 @@ class CompilerEntity(object): """An abstract compiler entity. Can represent things like basic blocks and expressions.""" + __slots__ = ('_print_args', '_print_kwargs') + def __init__(self, *args, **kwargs): """Construct a new CompilerEntity.""" self._print_args = args @@ -60,6 +62,8 @@ def to_gremlin(self): class Expression(CompilerEntity): """An expression that produces a value in the GraphQL compiler.""" + __slots__ = () + def visit_and_update(self, visitor_fn): """Create an updated version (if needed) of the Expression via the visitor pattern.
@@ -86,6 +90,8 @@ def visit_and_update(self, visitor_fn): class BasicBlock(CompilerEntity): """A basic operation block of the GraphQL compiler.""" + __slots__ = () + def visit_and_update_expressions(self, visitor_fn): """Create an updated version (if needed) of the BasicBlock via the visitor pattern. @@ -113,6 +119,8 @@ def visit_and_update_expressions(self, visitor_fn): class MarkerBlock(BasicBlock): """A block that is used to mark that a context-affecting operation with no output happened.""" + __slots__ = () + def to_gremlin(self): """Return the Gremlin representation of the block, which should almost always be empty. diff --git a/graphql_compiler/compiler/compiler_frontend.py b/graphql_compiler/compiler/compiler_frontend.py index 85e7a05d6..55905cdbe 100644 --- a/graphql_compiler/compiler/compiler_frontend.py +++ b/graphql_compiler/compiler/compiler_frontend.py @@ -87,7 +87,7 @@ invert_dict, is_vertex_field_name, strip_non_null_from_type, validate_output_name, validate_safe_string ) -from .metadata import LocationInfo, QueryMetadataTable +from .metadata import LocationInfo, QueryMetadataTable, RecurseInfo # LocationStackEntry contains the following: @@ -393,15 +393,8 @@ def _compile_vertex_ast(schema, current_schema_type, ast, inner_location, context) basic_blocks.extend(inner_basic_blocks) - # The length of the stack should be the same before exiting this function - initial_marked_location_stack_size = len(context['marked_location_stack']) - # step V-3: mark the graph position, and process output_source directive basic_blocks.append(_mark_location(location)) - if not is_in_fold_scope(context): - # The following append is the Location corresponding to the initial MarkLocation - # for the current vertex and the `num_traverses` counter set to 0. - context['marked_location_stack'].append(_construct_location_stack_entry(location, 0)) output_source = _process_output_source_directive(schema, current_schema_type, ast, location, context, unique_local_directives) @@ -436,17 +429,27 @@ def _compile_vertex_ast(schema, current_schema_type, ast, within_optional_scope = 'optional' in context if edge_traversal_is_optional: - # Entering an optional block! - # Make sure there's a marked location right before it for the optional Backtrack - # to jump back to. Otherwise, the traversal could rewind to an old marked location - # and might ignore entire stretches of applied filtering. - if context['marked_location_stack'][-1].num_traverses > 0: + # Invariant: There must always be a marked location corresponding to the query position + # immediately before any optional Traverse. + # + # This invariant is verified in the IR sanity checks module (ir_sanity_checks.py), + # in the function named _sanity_check_mark_location_preceding_optional_traverse(). + # + # This marked location is the one that the @optional directive's corresponding + # optional Backtrack will jump back to. If such a marked location isn't present, + # the backtrack could rewind to an old marked location and might ignore + # entire stretches of applied filtering. + # + # Assumption: The only way there might not be a marked location here is + # if the current location already traversed into child locations, not including folds. 
+ non_fold_child_locations = { + child_location + for child_location in query_metadata_table.get_child_locations(location) + if not isinstance(child_location, FoldScopeLocation) + } + if non_fold_child_locations: location = query_metadata_table.revisit_location(location) - basic_blocks.append(_mark_location(location)) - context['marked_location_stack'].pop() - new_stack_entry = _construct_location_stack_entry(location, 0) - context['marked_location_stack'].append(new_stack_entry) if fold_directive: inner_location = location.navigate_to_fold(field_name) @@ -485,19 +488,15 @@ def _compile_vertex_ast(schema, current_schema_type, ast, edge_name, recurse_depth, within_optional_scope=within_optional_scope)) + query_metadata_table.record_recurse_info(location, + RecurseInfo(edge_direction=edge_direction, + edge_name=edge_name, + depth=recurse_depth)) else: basic_blocks.append(blocks.Traverse(edge_direction, edge_name, optional=edge_traversal_is_optional, within_optional_scope=within_optional_scope)) - if not edge_traversal_is_folded and not is_in_fold_scope(context): - # Current block is either a Traverse or a Recurse that is not within any fold context. - # Increment the `num_traverses` counter. - old_location_stack_entry = context['marked_location_stack'][-1] - new_location_stack_entry = _construct_location_stack_entry( - old_location_stack_entry.location, old_location_stack_entry.num_traverses + 1) - context['marked_location_stack'][-1] = new_location_stack_entry - inner_basic_blocks = _compile_ast_node_to_ir(schema, field_schema_type, field_ast, inner_location, context) basic_blocks.extend(inner_basic_blocks) @@ -535,25 +534,9 @@ def _compile_vertex_ast(schema, current_schema_type, ast, location = query_metadata_table.revisit_location(location) basic_blocks.append(_mark_location(location)) - context['marked_location_stack'].pop() - new_stack_entry = _construct_location_stack_entry(location, 0) - context['marked_location_stack'].append(new_stack_entry) else: basic_blocks.append(blocks.Backtrack(location)) - # Pop off the initial Location for the current vertex. - if not is_in_fold_scope(context): - context['marked_location_stack'].pop() - - # Check that the length of the stack remains the same as when control entered this function. - final_marked_location_stack_size = len(context['marked_location_stack']) - if initial_marked_location_stack_size != final_marked_location_stack_size: - raise AssertionError(u'Size of stack changed from {} to {} after executing this function.' - u'This should never happen : {}' - .format(initial_marked_location_stack_size, - final_marked_location_stack_size, - context['marked_location_stack'])) - return basic_blocks @@ -663,7 +646,7 @@ def _compile_ast_node_to_ir(schema, current_schema_type, ast, location, context) # step 1: apply local filter, if any for filter_operation_info in filter_operations: basic_blocks.append( - process_filter_directive(filter_operation_info, context)) + process_filter_directive(filter_operation_info, location, context)) if location.field is not None: # The location is at a property, compile the property data following P-steps. @@ -755,12 +738,6 @@ def _compile_root_ast_to_ir(schema, ast, type_equivalence_hints=None): # 'type_equivalence_hints_inverse' is the inverse of type_equivalence_hints, # which is always invertible. 
'type_equivalence_hints_inverse': invert_dict(type_equivalence_hints), - # The marked_location_stack explicitly maintains a stack (implemented as list) - # of namedtuples (each corresponding to a MarkLocation) containing: - # - location: the location within the corresponding MarkLocation object - # - num_traverses: the number of Recurse and Traverse blocks created - # after the corresponding MarkLocation - 'marked_location_stack': [] } # Add the query root basic block to the output. diff --git a/graphql_compiler/compiler/emit_gremlin.py b/graphql_compiler/compiler/emit_gremlin.py index fc879d707..c0daf5618 100644 --- a/graphql_compiler/compiler/emit_gremlin.py +++ b/graphql_compiler/compiler/emit_gremlin.py @@ -6,7 +6,7 @@ # Public API # ############## -def emit_code_from_ir(ir_blocks): +def emit_code_from_ir(ir_blocks, compiler_metadata): """Return a MATCH query string from a list of IR blocks.""" gremlin_steps = ( block.to_gremlin() diff --git a/graphql_compiler/compiler/emit_match.py b/graphql_compiler/compiler/emit_match.py index d854af690..ab20ae5cb 100644 --- a/graphql_compiler/compiler/emit_match.py +++ b/graphql_compiler/compiler/emit_match.py @@ -241,7 +241,7 @@ def emit_code_from_multiple_match_queries(match_queries): return u' '.join(query_data) -def emit_code_from_ir(compound_match_query): +def emit_code_from_ir(compound_match_query, compiler_metadata): """Return a MATCH query string from a CompoundMatchQuery.""" # If the compound match query contains only one match query, # just call `emit_code_from_single_match_query` diff --git a/graphql_compiler/compiler/emit_sql.py b/graphql_compiler/compiler/emit_sql.py new file mode 100644 index 000000000..fecf50faa --- /dev/null +++ b/graphql_compiler/compiler/emit_sql.py @@ -0,0 +1,7 @@ +# Copyright 2018-present Kensho Technologies, LLC. +"""Transform a tree representation of an SQL query into an executable SQLAlchemy query.""" + + +def emit_code_from_ir(sql_query_tree, compiler_metadata): + """Return a SQLAlchemy Query from the given tree representation of an SQL query.""" + raise NotImplementedError(u'SQL query emitting is not yet supported.') diff --git a/graphql_compiler/compiler/expressions.py b/graphql_compiler/compiler/expressions.py index e1cd2b068..317fb9a25 100644 --- a/graphql_compiler/compiler/expressions.py +++ b/graphql_compiler/compiler/expressions.py @@ -37,6 +37,8 @@ class Literal(Expression): Think long and hard about the above before allowing literals in user-supplied GraphQL! """ + __slots__ = ('value',) + def __init__(self, value): """Construct a new Literal object with the given value.""" super(Literal, self).__init__(value) @@ -106,6 +108,8 @@ def _to_output_code(self): class Variable(Expression): """A variable for a parameterized query, to be filled in at runtime.""" + __slots__ = ('variable_name', 'inferred_type') + def __init__(self, variable_name, inferred_type): """Construct a new Variable object for the given variable name.
@@ -203,6 +207,8 @@ def __ne__(self, other): class LocalField(Expression): """A field at the current position in the query.""" + __slots__ = ('field_name',) + def __init__(self, field_name): """Construct a new LocalField object that references a field at the current position.""" super(LocalField, self).__init__(field_name) @@ -240,6 +246,8 @@ def to_gremlin(self): class SelectEdgeContextField(Expression): """An edge field drawn from the global context, for use in a SELECT WHERE statement.""" + __slots__ = ('location',) + def __init__(self, location): """Construct a new SelectEdgeContextField object that references an edge field. @@ -287,6 +295,8 @@ def to_gremlin(self): class ContextField(Expression): """A field drawn from the global context, e.g. if selected earlier in the query.""" + __slots__ = ('location',) + def __init__(self, location): """Construct a new ContextField object that references a field from the global context. @@ -345,6 +355,8 @@ def to_gremlin(self): class OutputContextField(Expression): """A field used in ConstructResult blocks to output data from the global context.""" + __slots__ = ('location', 'field_type') + def __init__(self, location, field_type): """Construct a new OutputContextField object for the field at the given location. @@ -443,6 +455,8 @@ def __ne__(self, other): class FoldedOutputContextField(Expression): """An expression used to output data captured in a @fold scope.""" + __slots__ = ('fold_scope_location', 'field_type') + def __init__(self, fold_scope_location, field_type): """Construct a new FoldedOutputContextField object for this folded field. @@ -527,6 +541,8 @@ class ContextFieldExistence(Expression): Useful to determine whether e.g. a field at the end of an optional edge is defined or not. """ + __slots__ = ('location',) + def __init__(self, location): """Construct a new ContextFieldExistence object for a vertex field from the global context. @@ -572,8 +588,9 @@ def _validate_operator_name(operator, supported_operators): class UnaryTransformation(Expression): """An expression that modifies an underlying expression with a unary operator.""" - SUPPORTED_OPERATORS = frozenset( - {u'size'}) + SUPPORTED_OPERATORS = frozenset({u'size'}) + + __slots__ = ('operator', 'inner_expression') def __init__(self, operator, inner_expression): """Construct a UnaryExpression that modifies the given inner expression.""" @@ -638,9 +655,12 @@ def to_gremlin(self): class BinaryComposition(Expression): """An expression created by composing two expressions together.""" - SUPPORTED_OPERATORS = frozenset( - {u'=', u'!=', u'>=', u'<=', u'>', u'<', u'+', u'||', u'&&', u'contains', u'intersects', - u'has_substring', u'LIKE', u'INSTANCEOF'}) + SUPPORTED_OPERATORS = frozenset({ + u'=', u'!=', u'>=', u'<=', u'>', u'<', u'+', u'||', u'&&', + u'contains', u'intersects', u'has_substring', u'LIKE', u'INSTANCEOF', + }) + + __slots__ = ('operator', 'left', 'right') def __init__(self, operator, left, right): """Construct an expression that connects two expressions with an operator. @@ -764,6 +784,8 @@ def to_gremlin(self): class TernaryConditional(Expression): """A ternary conditional expression, returning one of two expressions depending on a third.""" + __slots__ = ('predicate', 'if_true', 'if_false') + def __init__(self, predicate, if_true, if_false): """Construct an expression that evaluates a predicate and returns one of two results. 
diff --git a/graphql_compiler/compiler/filters.py b/graphql_compiler/compiler/filters.py index 3942077b7..0eda9069e 100644 --- a/graphql_compiler/compiler/filters.py +++ b/graphql_compiler/compiler/filters.py @@ -11,6 +11,7 @@ get_uniquely_named_objects_by_name, is_vertex_field_name, is_vertex_field_type, strip_non_null_from_type, validate_safe_string ) +from .metadata import FilterInfo def scalar_leaf_only(operator): @@ -593,15 +594,15 @@ def is_filter_with_outer_scope_vertex_field_operator(directive): return op_name in OUTER_SCOPE_VERTEX_FIELD_OPERATORS -def process_filter_directive(filter_operation_info, context): +def process_filter_directive(filter_operation_info, location, context): """Return a Filter basic block that corresponds to the filter operation in the directive. Args: filter_operation_info: FilterOperationInfo object, containing the directive and field info of the field where the filter is to be applied. + location: Location where this filter is used. context: dict, various per-compilation data (e.g. declared tags, whether the current block is optional, etc.). May be mutated in-place in this function! - directive: GraphQL @filter directive object, obtained from the AST node Returns: a Filter basic block that performs the requested filtering operation @@ -641,4 +642,12 @@ def process_filter_directive(filter_operation_info, context): raise GraphQLCompilationError(u'The filter with op_name "{}" must be applied on a field. ' u'It may not be applied on a type coercion.'.format(op_name)) + fields = ((filter_operation_info.field_name,) if op_name != 'name_or_alias' + else ('name', 'alias')) + + context['metadata'].record_filter_info( + location, + FilterInfo(fields=fields, op_name=op_name, args=tuple(operator_params)) + ) + return process_func(filter_operation_info, context, operator_params) diff --git a/graphql_compiler/compiler/ir_lowering_sql/__init__.py b/graphql_compiler/compiler/ir_lowering_sql/__init__.py new file mode 100644 index 000000000..d29f6b6bf --- /dev/null +++ b/graphql_compiler/compiler/ir_lowering_sql/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2018-present Kensho Technologies, LLC. + +############## +# Public API # +############## + + +def lower_ir(ir_blocks, query_metadata_table, type_equivalence_hints=None): + """Lower the IR into a form that can be represented by a SQL query. + + Args: + ir_blocks: list of IR blocks to lower into SQL-compatible form + query_metadata_table: QueryMetadataTable object containing all metadata collected during + query processing, including location metadata (e.g. which locations + are folded or optional). + type_equivalence_hints: optional dict of GraphQL interface or type -> GraphQL union. + Used as a workaround for GraphQL's lack of support for + inheritance across "types" (i.e. non-interfaces), as well as a + workaround for Gremlin's total lack of inheritance-awareness. + The key-value pairs in the dict specify that the "key" type + is equivalent to the "value" type, i.e. that the GraphQL type or + interface in the key is the most-derived common supertype + of every GraphQL type in the "value" GraphQL union. + Recursive expansion of type equivalence hints is not performed, + and only type-level correctness of this argument is enforced. + See README.md for more details on everything this parameter does. + ***** + Be very careful with this option, as bad input here will + lead to incorrect output queries being generated. + ***** + + Returns: + a tree representation of the IR blocks, for recursive traversal by the SQL backend.
+ """ + raise NotImplementedError(u'SQL IR lowering is not yet implemented.') diff --git a/graphql_compiler/compiler/metadata.py b/graphql_compiler/compiler/metadata.py index 2c24d6137..f92e6ac6d 100644 --- a/graphql_compiler/compiler/metadata.py +++ b/graphql_compiler/compiler/metadata.py @@ -4,7 +4,7 @@ import six -from .helpers import Location +from .helpers import FoldScopeLocation, Location LocationInfo = namedtuple( @@ -24,6 +24,25 @@ ) +FilterInfo = namedtuple( + 'FilterInfo', + ( + 'fields', + 'op_name', + 'args', + ) +) + +RecurseInfo = namedtuple( + 'RecurseInfo', + ( + 'edge_direction', + 'edge_name', + 'depth', + ) +) + + @six.python_2_unicode_compatible class QueryMetadataTable(object): """Query metadata container with info on locations, inputs, outputs, and tags in the query.""" @@ -45,6 +64,9 @@ def __init__(self, root_location, root_location_info): self._outputs = dict() # dict, output name -> output info namedtuple self._tags = dict() # dict, tag name -> tag info namedtuple + self._filter_infos = dict() # Location -> FilterInfo array + self._recurse_infos = dict() # Location -> RecurseInfo array + # dict, revisiting Location -> revisit origin, i.e. the first Location with that query path self._revisit_origins = dict() @@ -132,6 +154,27 @@ def get_location_info(self, location): u'{}'.format(location)) return location_info + def record_filter_info(self, location, filter_info): + """Record filter information about the location.""" + if isinstance(location, FoldScopeLocation): + # NOTE(gurer): ignore filters inside the fold for now + return + record_location = location.at_vertex() + self._filter_infos.setdefault(record_location, []).append(filter_info) + + def get_filter_infos(self, location): + """Get information about filters at the location.""" + return self._filter_infos.get(location, []) + + def record_recurse_info(self, location, recurse_info): + """Record recursion information about the location.""" + record_location = location.at_vertex() + self._recurse_infos.setdefault(record_location, []).append(recurse_info) + + def get_recurse_infos(self, location): + """Get information about recursions at the location.""" + return self._recurse_infos.get(location, []) + def get_child_locations(self, location): """Yield an iterable of child locations for a given Location/FoldScopeLocation object.""" self.get_location_info(location) # purely to check for location validity diff --git a/graphql_compiler/tests/test_emit_output.py b/graphql_compiler/tests/test_emit_output.py index feebd9083..87f3545d4 100644 --- a/graphql_compiler/tests/test_emit_output.py +++ b/graphql_compiler/tests/test_emit_output.py @@ -48,7 +48,7 @@ def test_simple_immediate_output(self): ) ''' - received_match = emit_match.emit_code_from_ir(compound_match_query) + received_match = emit_match.emit_code_from_ir(compound_match_query, None) compare_match(self, expected_match, received_match) def test_simple_traverse_filter_output(self): @@ -91,7 +91,7 @@ def test_simple_traverse_filter_output(self): ) ''' - received_match = emit_match.emit_code_from_ir(compound_match_query) + received_match = emit_match.emit_code_from_ir(compound_match_query, None) compare_match(self, expected_match, received_match) def test_output_inside_optional_traversal(self): @@ -159,7 +159,7 @@ def test_output_inside_optional_traversal(self): ) ''' - received_match = emit_match.emit_code_from_ir(compound_match_query) + received_match = emit_match.emit_code_from_ir(compound_match_query, None) compare_match(self, expected_match, received_match) 
def test_datetime_variable_representation(self): @@ -201,7 +201,7 @@ ) ''' - received_match = emit_match.emit_code_from_ir(compound_match_query) + received_match = emit_match.emit_code_from_ir(compound_match_query, None) compare_match(self, expected_match, received_match) def test_datetime_output_representation(self): @@ -228,7 +228,7 @@ ) ''' - received_match = emit_match.emit_code_from_ir(compound_match_query) + received_match = emit_match.emit_code_from_ir(compound_match_query, None) compare_match(self, expected_match, received_match) @@ -257,7 +257,7 @@ def test_simple_immediate_output(self): ])} ''' - received_match = emit_gremlin.emit_code_from_ir(ir_blocks) + received_match = emit_gremlin.emit_code_from_ir(ir_blocks, None) compare_gremlin(self, expected_gremlin, received_match) def test_simple_traverse_filter_output(self): @@ -292,7 +292,7 @@ def test_simple_traverse_filter_output(self): ])} ''' - received_match = emit_gremlin.emit_code_from_ir(ir_blocks) + received_match = emit_gremlin.emit_code_from_ir(ir_blocks, None) compare_gremlin(self, expected_gremlin, received_match) def test_output_inside_optional_traversal(self): @@ -331,7 +331,7 @@ def test_output_inside_optional_traversal(self): ])} ''' - received_match = emit_gremlin.emit_code_from_ir(ir_blocks) + received_match = emit_gremlin.emit_code_from_ir(ir_blocks, None) compare_gremlin(self, expected_gremlin, received_match) def test_datetime_output_representation(self): @@ -354,5 +354,5 @@ def test_datetime_output_representation(self): ])} ''' - received_match = emit_gremlin.emit_code_from_ir(ir_blocks) + received_match = emit_gremlin.emit_code_from_ir(ir_blocks, None) compare_gremlin(self, expected_gremlin, received_match) diff --git a/graphql_compiler/tests/test_explain_info.py b/graphql_compiler/tests/test_explain_info.py new file mode 100644 index 000000000..1fbd2ac6c --- /dev/null +++ b/graphql_compiler/tests/test_explain_info.py @@ -0,0 +1,132 @@ +# Copyright 2018-present Kensho Technologies, LLC. +import unittest + +from . import test_input_data +from ..compiler.compiler_frontend import graphql_to_ir +from ..compiler.helpers import Location +from ..compiler.metadata import FilterInfo, RecurseInfo +from .test_helpers import get_schema + + +class ExplainInfoTests(unittest.TestCase): + """Ensure we get correct information about filters and recursion.""" + + def setUp(self): + """Initialize the test schema before each test.""" + self.schema = get_schema() + + def check(self, graphql_test, expected_filters, expected_recurses): + """Verify that the query produces the expected filter and recurse infos.""" + ir_and_metadata = graphql_to_ir(self.schema, graphql_test().graphql_input) + meta = ir_and_metadata.query_metadata_table + # Unfortunately, dict literals don't accept Location() as keys + expected_filters = dict(expected_filters) + expected_recurses = dict(expected_recurses) + for location, _ in meta.registered_locations: + # Do the filters at this location match the expected ones? + filters = meta.get_filter_infos(location) + self.assertEqual(expected_filters.get(location, []), filters) + if filters: + del expected_filters[location] + # Do the recurse infos at this location match the expected ones? + recurse = meta.get_recurse_infos(location) + self.assertEqual(expected_recurses.get(location, []), recurse) + if recurse: + del expected_recurses[location] + # Any expected infos missing?
+ self.assertEqual(0, len(expected_filters)) + self.assertEqual(0, len(expected_recurses)) + + def test_traverse_filter_and_output(self): + loc = Location(('Animal', 'out_Animal_ParentOf'), None, 1) + filters = [ + FilterInfo(fields=('name', 'alias'), op_name='name_or_alias', args=('$wanted',)), + ] + + self.check(test_input_data.traverse_filter_and_output, + [(loc, filters)], + []) + + def test_complex_optional_traversal_variables(self): + loc1 = Location(('Animal',), None, 1) + filters1 = [ + FilterInfo(fields=('name',), op_name='=', args=('$animal_name',)), + ] + + loc2 = Location(('Animal', 'in_Animal_ParentOf', 'out_Animal_FedAt'), None, 1) + filters2 = [ + FilterInfo(fields=('name',), op_name='=', args=('%parent_fed_at_event',)), + FilterInfo(fields=('event_date',), + op_name='between', + args=('%other_child_fed_at', '%parent_fed_at')), + ] + + self.check(test_input_data.complex_optional_traversal_variables, + [(loc1, filters1), (loc2, filters2)], + []) + + def test_coercion_filters_and_multiple_outputs_within_fold_scope(self): + self.check(test_input_data.coercion_filters_and_multiple_outputs_within_fold_scope, + [], + []) + + def test_multiple_filters(self): + loc = Location(('Animal',), None, 1) + filters = [ + FilterInfo(fields=('name',), op_name='>=', args=('$lower_bound',)), + FilterInfo(fields=('name',), op_name='<', args=('$upper_bound',)) + ] + + self.check(test_input_data.multiple_filters, + [(loc, filters)], + []) + + def test_has_edge_degree_op_filter(self): + loc = Location(('Animal',), None, 1) + filters = [ + FilterInfo(fields=('in_Animal_ParentOf',), + op_name='has_edge_degree', + args=('$child_count',)) + ] + + self.check(test_input_data.has_edge_degree_op_filter, + [(loc, filters)], + []) + + def test_simple_recurse(self): + loc = Location(('Animal',), None, 1) + recurses = [ + RecurseInfo(edge_direction='out', edge_name='Animal_ParentOf', depth=1) + ] + + self.check(test_input_data.simple_recurse, + [], + [(loc, recurses)]) + + def test_two_consecutive_recurses(self): + loc = Location(('Animal',), None, 1) + filters = [ + FilterInfo(fields=('name', 'alias'), + op_name='name_or_alias', + args=('$animal_name_or_alias',)) + ] + recurses = [ + RecurseInfo(edge_direction='out', edge_name='Animal_ParentOf', depth=2), + RecurseInfo(edge_direction='in', edge_name='Animal_ParentOf', depth=2) + ] + + self.check(test_input_data.two_consecutive_recurses, + [(loc, filters)], + [(loc, recurses)]) + + def test_filter_on_optional_traversal_name_or_alias(self): + loc = Location(('Animal', 'out_Animal_ParentOf'), None, 1) + filters = [ + FilterInfo(fields=('name', 'alias'), + op_name='name_or_alias', + args=('%grandchild_name',)) + ] + + self.check(test_input_data.filter_on_optional_traversal_name_or_alias, + [(loc, filters)], + []) diff --git a/requirements.txt b/requirements.txt index fb43d1f52..21d4447f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ funcy==1.7.3 graphql-core==1.1 pytz==2017.2 six==1.10.0 +sqlalchemy==1.2.9 diff --git a/setup.py b/setup.py index c37dfa811..e49e5f1e3 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def find_long_description(): 'graphql-core==1.1', 'pytz>=2016.10', 'six>=1.10.0', + 'sqlalchemy>=1.2.1,<1.3', ], classifiers=[ 'Development Status :: 5 - Production/Stable',
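
Reviewer note, not part of the patch: the new FilterInfo/RecurseInfo plumbing can be exercised end-to-end through graphql_to_ir, exactly as the new test_explain_info.py does. Below is a minimal sketch, assuming the test suite's get_schema() helper and an illustrative query against the test "Animal" schema.

```python
# Sketch only: inspect the filter/recurse metadata this PR records during compilation.
# Assumes the test schema helper used by test_explain_info.py; the query is illustrative.
from graphql_compiler.compiler.compiler_frontend import graphql_to_ir
from graphql_compiler.tests.test_helpers import get_schema

graphql_query = '''{
    Animal {
        name @filter(op_name: ">=", value: ["$lower_bound"])
             @output(out_name: "animal_name")
        out_Animal_ParentOf @recurse(depth: 1) {
            name @output(out_name: "descendant_name")
        }
    }
}'''

metadata_table = graphql_to_ir(get_schema(), graphql_query).query_metadata_table

# Every registered location now exposes the filters and recursions applied at it.
for location, _ in metadata_table.registered_locations:
    for filter_info in metadata_table.get_filter_infos(location):
        print(location, filter_info.fields, filter_info.op_name, filter_info.args)
    for recurse_info in metadata_table.get_recurse_infos(location):
        print(location, recurse_info.edge_direction, recurse_info.edge_name, recurse_info.depth)
```

For this query, the root Animal location should report FilterInfo(fields=('name',), op_name='>=', args=('$lower_bound',)) and RecurseInfo(edge_direction='out', edge_name='Animal_ParentOf', depth=1), matching the expectations encoded in test_explain_info.py.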
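A second note on the `__slots__` declarations added across compiler_entities.py, blocks.py, and expressions.py: declaring `__slots__` on every class in the hierarchy (including empty tuples on subclasses that add no attributes) removes the per-instance `__dict__`, which shrinks the many small IR objects the compiler allocates and turns attribute-name typos into immediate errors. A generic illustration of these semantics follows; the class names are hypothetical and this is not code from the patch.

```python
# Generic illustration of __slots__ semantics; not code from this patch.
class CompilerEntityDemo(object):
    __slots__ = ('_print_args',)  # fixed attribute set, no per-instance __dict__

    def __init__(self, *args):
        self._print_args = args


class ExpressionDemo(CompilerEntityDemo):
    __slots__ = ()  # empty on purpose, so the subclass stays __dict__-free as well


expr = ExpressionDemo(1, 2)
expr._print_args = (3,)  # allowed: '_print_args' is a declared slot
try:
    expr.print_args = (3,)  # typo: raises AttributeError instead of
except AttributeError:      # silently creating a new attribute
    pass
```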