Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sql optionals #528

Merged
merged 11 commits into from
Sep 4, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions graphql_compiler/compiler/emit_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, sql_schema_info, ir):
self._current_alias = None # a sqlalchemy table Alias at the current location
self._aliases = {} # mapping marked query paths to table _Aliases representing them
self._relocate(ir.query_metadata_table.root_location)
self._came_from = {} # mapping aliases to the column used to join into them.

# The query being constructed as the IR is processed
self._from_clause = self._current_alias # the main sqlalchemy Selectable
Expand Down Expand Up @@ -94,15 +95,31 @@ def traverse(self, vertex_field, optional):
edge = self._sql_schema_info.join_descriptors[self._current_classname][vertex_field]
self._relocate(self._current_location.navigate_to_subpath(vertex_field))

self._came_from[self._current_alias] = self._current_alias.c[edge.to_column]
if self._is_in_optional_scope() and not optional:
raise NotImplementedError(u'The SQL backend does not implement mandatory '
u'traversals inside an @optional scope.')
# For mandatory edges in optional scope, we emit LEFT OUTER JOIN and enforce the
# edge being mandatory with additional filters in the WHERE clause.
#
# This is some tricky logic. To prevent regression, here's some caution against
bojanserafimov marked this conversation as resolved.
Show resolved Hide resolved
# solutions that might seem simpler, but are not correct:
# 1. You might think it's simpler to just use an INNER JOIN for mandatory edges in
# optional scope. However, if there is a LEFT JOIN miss, the NULL value resulting
# from it will not match anything in this INNER JOIN, and the row will be removed.
# As a result, @optional semantics will not be preserved.
# 2. You might think that a cleaner solution is performing all the mandatory traversals
# first in subqueries, and joining those subqueries with LEFT OUTER JOIN. This
# approach is incorrect because a mandatory edge traversal miss inside an optional
# scope is supposed to invalidate the whole result. However, with this solution the
# result will still appear.
self._filters.append(sqlalchemy.or_(
self._came_from[self._current_alias].isnot(None),
self._came_from[previous_alias].is_(None)))

# Join to where we came from
self._from_clause = self._from_clause.join(
self._current_alias,
onclause=(previous_alias.c[edge.from_column] == self._current_alias.c[edge.to_column]),
isouter=optional)
isouter=self._is_in_optional_scope())

def start_global_operations(self):
"""Execute a GlobalOperationsStart block."""
Expand All @@ -112,9 +129,11 @@ def start_global_operations(self):

def filter(self, predicate):
"""Execute a Filter Block."""
sql_expression = predicate.to_sql(self._aliases, self._current_alias)
if self._is_in_optional_scope():
raise NotImplementedError(u'Filters in @optional are not implemented in SQL')
self._filters.append(predicate.to_sql(self._aliases, self._current_alias))
sql_expression = sqlalchemy.or_(sql_expression,
self._came_from[self._current_alias].is_(None))
self._filters.append(sql_expression)

def mark_location(self):
"""Execute a MarkLocation Block."""
Expand Down
2 changes: 1 addition & 1 deletion graphql_compiler/compiler/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def to_cypher(self):

def to_sql(self, aliases, current_alias):
"""Must not be used -- ContextFieldExistence must be lowered during the IR lowering step."""
raise NotImplementedError(u'Filters in @optional are not implemented in SQL')
raise AssertionError(u'ContextFieldExistence.to_sql() was called: {}'.format(self))


def _validate_operator_name(operator, supported_operators):
Expand Down
75 changes: 75 additions & 0 deletions graphql_compiler/compiler/ir_lowering_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2018-present Kensho Technologies, LLC.
from .. import blocks, expressions
from ...compiler.compiler_frontend import IrAndMetadata
from ..helpers import FoldScopeLocation, get_edge_direction_and_name
from ..ir_lowering_common import common


Expand All @@ -21,6 +22,78 @@ def visitor_fn(expression):

return new_ir_blocks


def _find_non_null_columns(schema_info, query_metadata_table):
    """Find a column for each non-root location that's non-null if and only if the vertex exists.

    Args:
        schema_info: object whose join_descriptors attribute is indexed first by vertex type
                     name and then by vertex field name, yielding an edge descriptor that
                     exposes a to_column attribute
        query_metadata_table: object exposing registered_locations, an iterable of
                              (location, location_info) pairs, and a
                              get_child_locations(location) method

    Returns:
        dict mapping the query_path of each child location to the name of the column
        used to join into that location

    Raises:
        NotImplementedError if any child location is a FoldScopeLocation
    """
    non_null_column = {}

    # Find the foreign keys used to reach each child location.
    for location, location_info in query_metadata_table.registered_locations:
        for child_location in query_metadata_table.get_child_locations(location):
            if isinstance(child_location, FoldScopeLocation):
                # Fail with an explicit message instead of an empty NotImplementedError.
                raise NotImplementedError(
                    u'Finding non-null columns for fold scope locations is not implemented '
                    u'in the SQL backend: {}'.format(child_location))

            # The last component of the query path is the vertex field that was traversed.
            edge_direction, edge_name = get_edge_direction_and_name(child_location.query_path[-1])
            vertex_field_name = '{}_{}'.format(edge_direction, edge_name)
            edge = schema_info.join_descriptors[location_info.type.name][vertex_field_name]

            # The value of the column used to join to this table is an indicator of whether
            # the left join was a hit or a miss.
            non_null_column[child_location.query_path] = edge.to_column

    return non_null_column


class ContextColumn(expressions.Expression):
    """An expression referring to a named column of the table alias at a given query path."""

    def __init__(self, vertex_query_path, column_name):
        """Construct a new ContextColumn.

        Args:
            vertex_query_path: key later used to look up the table alias in to_sql()
            column_name: name of the column to select from that alias
        """
        # NOTE(review): no arguments are forwarded to the base constructor -- confirm that
        # the base class does not rely on them (e.g. for equality checks or repr output).
        super(ContextColumn, self).__init__()
        self._vertex_query_path = vertex_query_path
        self._column_name = column_name

    def validate(self):
        """Validate that the ContextColumn is correctly representable (currently a no-op)."""

    def to_match(self):
        """Raise AssertionError: ContextColumn has no MATCH representation."""
        raise AssertionError(u'ContextColumns are not used during the query emission '
                             u'process in MATCH, so this is a bug. '
                             u'This function should not be called.')

    def to_gremlin(self):
        """Raise AssertionError: ContextColumn has no Gremlin representation."""
        raise AssertionError(u'ContextColumns are not used during the query emission '
                             u'process in Gremlin, so this is a bug. '
                             u'This function should not be called.')

    def to_cypher(self):
        """Raise AssertionError: ContextColumn has no Cypher representation."""
        raise AssertionError(u'ContextColumns are not used during the query emission '
                             u'process in cypher, so this is a bug. '
                             u'This function should not be called.')

    def to_sql(self, aliases, current_alias):
        """Return the column of this name from the alias registered for the stored query path.

        Args:
            aliases: mapping indexed by query path, yielding objects with a .c column collection
            current_alias: unused here; accepted to match the to_sql() call signature

        Returns:
            the entry for the stored column name in the selected alias' .c collection
        """
        vertex_alias = aliases[self._vertex_query_path]
        return vertex_alias.c[self._column_name]
bojanserafimov marked this conversation as resolved.
Show resolved Hide resolved


def _lower_sql_context_field_existence(schema_info, ir_blocks, query_metadata_table):
    """Lower ContextFieldExistence expressions to BinaryComposition null checks.

    Each ContextFieldExistence is replaced by a comparison asserting that the join column
    of the referenced location is not null.

    Args:
        schema_info: passed through to _find_non_null_columns
        ir_blocks: list of blocks, each supporting visit_and_update_expressions()
        query_metadata_table: passed through to _find_non_null_columns

    Returns:
        list of blocks produced by applying the expression rewrite to every input block
    """
    non_null_columns = _find_non_null_columns(schema_info, query_metadata_table)

    def visitor_fn(expression):
        """Convert ContextFieldExistence expressions to a "column != null" comparison."""
        if not isinstance(expression, expressions.ContextFieldExistence):
            return expression

        query_path = expression.location.query_path
        return expressions.BinaryComposition(
            u'!=',
            ContextColumn(query_path, non_null_columns[query_path]),
            expressions.NullLiteral)

    return [block.visit_and_update_expressions(visitor_fn) for block in ir_blocks]


##############
# Public API #
##############
Expand All @@ -38,5 +111,7 @@ def lower_ir(schema_info, ir):
"""
ir_blocks = ir.ir_blocks
ir_blocks = _remove_output_context_field_existence(ir_blocks, ir.query_metadata_table)
ir_blocks = _lower_sql_context_field_existence(schema_info, ir_blocks, ir.query_metadata_table)
ir_blocks = common.short_circuit_ternary_conditionals(ir_blocks, ir.query_metadata_table)
ir_blocks = common.optimize_boolean_expression_comparisons(ir_blocks)
return IrAndMetadata(ir_blocks, ir.input_metadata, ir.output_metadata, ir.query_metadata_table)
Loading