Skip to content

Commit

Permalink
Sql optionals (#528)
Browse files Browse the repository at this point in the history
* Implement filters and traverse in optional

* Use left join in optional scope

* Replace hack with lesser hack

* Add some guiding comments

* Document approach

* Rewrite (x IS NOT NULL) = 0 to x IS NULL

* Lint

* Remove came_from hack

* Lint

* Lint

* Address comments
  • Loading branch information
bojanserafimov authored and obi1kenobi committed Sep 4, 2019
1 parent eb85f18 commit 7df647f
Show file tree
Hide file tree
Showing 4 changed files with 408 additions and 25 deletions.
29 changes: 24 additions & 5 deletions graphql_compiler/compiler/emit_sql.py
Expand Up @@ -54,6 +54,7 @@ def __init__(self, sql_schema_info, ir):
self._current_alias = None # a sqlalchemy table Alias at the current location
self._aliases = {} # mapping marked query paths to table _Aliases representing them
self._relocate(ir.query_metadata_table.root_location)
self._came_from = {} # mapping aliases to the column used to join into them.

# The query being constructed as the IR is processed
self._from_clause = self._current_alias # the main sqlalchemy Selectable
Expand Down Expand Up @@ -94,15 +95,31 @@ def traverse(self, vertex_field, optional):
edge = self._sql_schema_info.join_descriptors[self._current_classname][vertex_field]
self._relocate(self._current_location.navigate_to_subpath(vertex_field))

self._came_from[self._current_alias] = self._current_alias.c[edge.to_column]
if self._is_in_optional_scope() and not optional:
raise NotImplementedError(u'The SQL backend does not implement mandatory '
u'traversals inside an @optional scope.')
# For mandatory edges in optional scope, we emit LEFT OUTER JOIN and enforce the
# edge being mandatory with additional filters in the WHERE clause.
#
# This is some tricky logic. To prevent regression, here's some caution against
# solutions that might seem simpler, but are not correct:
# 1. You might think it's simpler to just use an INNER JOIN for mandatory edges in
# optional scope. However, if there is a LEFT JOIN miss, the NULL value resulting
# from it will not match anything in this INNER JOIN, and the row will be removed.
# As a result, @optional semantics will not be preserved.
# 2. You might think that a cleaner solution is performing all the mandatory traversals
# first in subqueries, and joining those subqueries with LEFT OUTER JOIN. This
# approach is incorrect because a mandatory edge traversal miss inside an optional
# scope is supposed to invalidate the whole result. However, with this solution the
# result will still appear.
self._filters.append(sqlalchemy.or_(
self._came_from[self._current_alias].isnot(None),
self._came_from[previous_alias].is_(None)))

# Join to where we came from
self._from_clause = self._from_clause.join(
self._current_alias,
onclause=(previous_alias.c[edge.from_column] == self._current_alias.c[edge.to_column]),
isouter=optional)
isouter=self._is_in_optional_scope())

def start_global_operations(self):
"""Execute a GlobalOperationsStart block."""
Expand All @@ -112,9 +129,11 @@ def start_global_operations(self):

def filter(self, predicate):
"""Execute a Filter Block."""
sql_expression = predicate.to_sql(self._aliases, self._current_alias)
if self._is_in_optional_scope():
raise NotImplementedError(u'Filters in @optional are not implemented in SQL')
self._filters.append(predicate.to_sql(self._aliases, self._current_alias))
sql_expression = sqlalchemy.or_(sql_expression,
self._came_from[self._current_alias].is_(None))
self._filters.append(sql_expression)

def mark_location(self):
"""Execute a MarkLocation Block."""
Expand Down
2 changes: 1 addition & 1 deletion graphql_compiler/compiler/expressions.py
Expand Up @@ -835,7 +835,7 @@ def to_cypher(self):

def to_sql(self, aliases, current_alias):
"""Must not be used -- ContextFieldExistence must be lowered during the IR lowering step."""
raise NotImplementedError(u'Filters in @optional are not implemented in SQL')
raise AssertionError(u'ContextFieldExistence.to_sql() was called: {}'.format(self))


def _validate_operator_name(operator, supported_operators):
Expand Down
92 changes: 92 additions & 0 deletions graphql_compiler/compiler/ir_lowering_sql/__init__.py
@@ -1,6 +1,9 @@
# Copyright 2018-present Kensho Technologies, LLC.
import six

from .. import blocks, expressions
from ...compiler.compiler_frontend import IrAndMetadata
from ..helpers import FoldScopeLocation, get_edge_direction_and_name
from ..ir_lowering_common import common


Expand All @@ -21,6 +24,93 @@ def visitor_fn(expression):

return new_ir_blocks


def _find_non_null_columns(schema_info, query_metadata_table):
"""Find a column for each non-root location that's non-null if and only if the vertex exists."""
non_null_column = {}

# Find foreign keys used
for location, location_info in query_metadata_table.registered_locations:
for child_location in query_metadata_table.get_child_locations(location):
if isinstance(child_location, FoldScopeLocation):
raise NotImplementedError()

edge_direction, edge_name = get_edge_direction_and_name(child_location.query_path[-1])
vertex_field_name = '{}_{}'.format(edge_direction, edge_name)
edge = schema_info.join_descriptors[location_info.type.name][vertex_field_name]

# The value of the column used to join to this table is an indicator of whether
# the left join was a hit or a miss.
non_null_column[child_location.query_path] = edge.to_column

return non_null_column


class ContextColumn(expressions.Expression):
"""A column drawn from the global context.
It is different than an expressions.ContextField because it does not reference a property
type in the GraphQL schema, but a column name in the actual SQL table. Some columns are
not even represented in the GraphQL schema as properties. An example is Animals.parent
in the test schema.
"""

def __init__(self, vertex_query_path, column_name):
"""Construct a new ContextColumn."""
super(ContextColumn, self).__init__(vertex_query_path, column_name)
self._vertex_query_path = vertex_query_path
self._column_name = column_name
self.validate()

def validate(self):
"""Validate that the ContextColumn is correctly representable."""
if not isinstance(self._vertex_query_path, tuple):
raise AssertionError(u'vertex_query_path was expected to be a tuple, but was {}: {}'
.format(type(self._vertex_query_path), self._vertex_query_path))

if not isinstance(self._column_name, six.string_types):
raise AssertionError(u'column_name was expected to be a string, but was {}: {}'
.format(type(self._column_name), self._column_name))

def to_match(self):
"""Not implemented, should not be used."""
raise AssertionError(u'ContextColumns are not used during the query emission process '
u'in MATCH, so this is a bug. This function should not be called.')

def to_gremlin(self):
"""Not implemented, should not be used."""
raise AssertionError(u'ContextColumns are not used during the query emission process '
u'in Gremlin, so this is a bug. This function should not be called.')

def to_cypher(self):
"""Not implemented, should not be used."""
raise AssertionError(u'ContextColumns are not used during the query emission process '
u'in cypher, so this is a bug. This function should not be called.')

def to_sql(self, aliases, current_alias):
"""Return a sqlalchemy Column picked from the appropriate alias."""
self.validate()
return aliases[self._vertex_query_path].c[self._column_name]


def _lower_sql_context_field_existence(schema_info, ir_blocks, query_metadata_table):
"""Lower ContextFieldExistence to BinaryComposition."""
non_null_columns = _find_non_null_columns(schema_info, query_metadata_table)

def visitor_fn(expression):
"""Convert ContextFieldExistence expressions to TrueLiteral."""
if not isinstance(expression, expressions.ContextFieldExistence):
return expression

query_path = expression.location.query_path
return expressions.BinaryComposition(
u'!=',
ContextColumn(query_path, non_null_columns[query_path]),
expressions.NullLiteral)

return [block.visit_and_update_expressions(visitor_fn) for block in ir_blocks]


##############
# Public API #
##############
Expand All @@ -38,5 +128,7 @@ def lower_ir(schema_info, ir):
"""
ir_blocks = ir.ir_blocks
ir_blocks = _remove_output_context_field_existence(ir_blocks, ir.query_metadata_table)
ir_blocks = _lower_sql_context_field_existence(schema_info, ir_blocks, ir.query_metadata_table)
ir_blocks = common.short_circuit_ternary_conditionals(ir_blocks, ir.query_metadata_table)
ir_blocks = common.optimize_boolean_expression_comparisons(ir_blocks)
return IrAndMetadata(ir_blocks, ir.input_metadata, ir.output_metadata, ir.query_metadata_table)

0 comments on commit 7df647f

Please sign in to comment.