Skip to content

Commit

Permalink
Merge 8d6fb66 into 224704f
Browse files Browse the repository at this point in the history
  • Loading branch information
jcd2020 committed Nov 15, 2019
2 parents 224704f + 8d6fb66 commit ff3beba
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 71 deletions.
215 changes: 159 additions & 56 deletions graphql_compiler/compiler/emit_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright 2018-present Kensho Technologies, LLC.
"""Transform a SqlNode tree into an executable SQLAlchemy query."""
from collections import namedtuple

import six
import sqlalchemy

Expand Down Expand Up @@ -124,6 +126,24 @@ def _find_folded_fields(ir):
return folded_fields


# 3-tuple describing the join information for each traversal in a fold.
#
# Contains DirectJoinDescriptor naming the columns used in the join predicate,
# the source/from table, and the destination/to table
SQLFoldTraversalDescriptor = namedtuple('SQLFoldTraversalDescriptor', (
    # DirectJoinDescriptor giving the columns used to join from_table to to_table.
    'join_descriptor',

    # SQLAlchemy table corresponding to the outside vertex of the traversal;
    # appears on the left side of the join.
    'from_table',

    # SQLAlchemy table corresponding to the inside vertex of the traversal;
    # appears on the right side of the join.
    'to_table',
))


class SQLFoldObject(object):
"""Object used to collect info for folds in order to ensure correct code emission."""

Expand All @@ -141,9 +161,11 @@ class SQLFoldObject(object):
#
# The SELECT clause for the fold subquery contains OuterVertex.SOME_COLUMN, a unique
# identifier (the primary key) for the OuterVertex determined by the edge descriptor
# from the vertex immediately outside the fold to the folded vertex.
# from the vertex immediately outside the fold to the folded vertex. This presently
# only supports non-composite primary keys.
#
# SELECT will also contain an ARRAY_AGG for each column labeled for output inside the fold.
# SELECT will also contain an ARRAY_AGG for each column labeled for output inside the fold if
# compiling to PostgreSQL. For compilation to MSSQL an XML PATH-based aggregation is performed.
#
# SELECT will also contain a COUNT(*) if _x_count is referred to by the query.
#
Expand All @@ -152,10 +174,10 @@ class SQLFoldObject(object):
# The FROM and JOIN clauses are constructed during end_fold using info from the
# visit_traversed_vertex function.
#
# The full subquery will look as follows:
# The full subquery will look as follows for PostgreSQL:
#
# SELECT
# OuterVertex.SOME_COLUMN <- this value is determined from the edge descriptor
# OuterVertex.SOME_COLUMN <- this value is the primary key
# ARRAY_AGG(OutputVertex.fold_output_column) AS fold_output
# FROM OuterVertex
# INNER JOIN ... <- INNER JOINs compiled during end_fold
Expand All @@ -164,35 +186,72 @@ class SQLFoldObject(object):
# INNER JOIN OutputVertex <- INNER JOINs compiled during end_fold
# ON ...
# GROUP BY OuterVertex.SOME_COLUMN
#
# and as follows for MSSQL:
#
# SELECT
# OuterVertex.SOME_COLUMN <- this value is the primary key
# COALESCE((SELECT ... FOR XML PATH('')), '~') AS fold_output
# FROM OuterVertex
# INNER JOIN ...
# ON ...
# ...
# INNER JOIN VertexPrecedingOutput
# ON ...

def __init__(self, outer_vertex_table, join_descriptor):
def __init__(self, outer_vertex_table, primary_key):
"""Create an SQLFoldObject with table, type, and join information supplied by the IR.
Args:
outer_vertex_table: SQLAlchemy table alias for vertex outside of fold.
join_descriptor: DirectJoinDescriptor object from the schema, describing the
first join from the outer vertex to the folded vertex.
primary_key: PrimaryKeyConstraint, primary_key of the vertex immediately outside the
fold. Used to set the group by as well as join the fold subquery to the
rest of the query. Composite keys unsupported.
"""
# table containing output columns
# initially None because output table is unknown until call to visit_output_vertex
self._output_vertex_alias = None

# table for vertex immediately outside fold
self._outer_vertex_alias = outer_vertex_table
if len(primary_key.columns) > 1:
raise NotImplementedError(u'Composite keys not supported. '
u'A composite primary key {} was found for table {}. '
u'SQL fold only supports non-composite primary '
u'keys.'.format(primary_key, outer_vertex_table.original))

# name of the field used in the primary key for the vertex outside the fold
# current implementation does not support composite primary keys
#
# we must use the name of the column as opposed to the column itself because
# primary key refers to the column from the original table, while we need the
# identically named column from its alias
self._outer_vertex_primary_key = list(primary_key.columns)[0].description

# group by column for fold subquery
self._group_by = [self._outer_vertex_alias.c[join_descriptor.from_column]]
self._group_by = [self._outer_vertex_alias.c[self._outer_vertex_primary_key]]

# List of 3-tuples describing the join required for each traversal in the fold
# List of SQLFoldTraversalDescriptor namedtuples describing each traversal in the fold
# starting with the join from the vertex immediately outside the fold to the folded vertex:
#
# edge: join descriptor for the columns used to join one vertex to the next in the fold
# from_table: the table on the left side of the join
# to_table: the table on the right side of the join
self._join_info = []

self._traversal_descriptors = []
self._outputs = [] # output columns for fold

self._ended = False # indicates whether `end_fold` has been called on this object

def __str__(self):
"""Produce string used to customize error messages."""
if self._outer_vertex_alias is None:
return u'SQLFoldObject("Invalid fold: no vertex preceding fold.")'
elif self._output_vertex_alias is None:
return (u'SQLFoldObject("Vertex outside fold: {}.'
u'Output vertex for fold: None.")').format(
self._outer_vertex_alias.original
)
else:
return u'SQLFoldObject("Vertex outside fold: {}. Output vertex for fold: {}.")'.format(
self._outer_vertex_alias.original, self._output_vertex_alias.original
)

@property
def outputs(self):
"""Get the output columns for the fold subquery."""
Expand All @@ -213,11 +272,6 @@ def outer_vertex_alias(self):
"""Get the SQLAlchemy table corresponding to vertex immediately outside the fold."""
return self._outer_vertex_alias

@property
def join_info(self):
"""Get a tuple containing edge and table info for the joins within the subquery."""
return self._join_info

def _set_outputs(self, outputs):
    """Set output columns for the fold object.

    Args:
        outputs: value stored as the fold's output columns, replacing the previous value
    """
    self._outputs = outputs
Expand All @@ -226,33 +280,61 @@ def _set_group_by(self, group_by):
"""Set group by columns for the fold object."""
self._group_by = group_by

def _construct_fold_joins(self, edge, from_alias, to_alias):
"""Use the edge descriptors to create the join clause between the tables in the fold."""
join_clause = sqlalchemy.join(
from_alias,
to_alias,
onclause=(from_alias.c[edge.from_column] == to_alias.c[edge.to_column])
)
def _construct_fold_joins(self):
    """Chain every traversal recorded for the fold into one SQLAlchemy join clause.

    Returns:
        join clause that starts at the from_table of the first traversal descriptor and joins
        each descriptor's to_table in the order the traversals were recorded
    """
    descriptors = list(self._traversal_descriptors)

    # The from_table of the first descriptor is the vertex immediately preceding the fold;
    # it seeds the chain that each subsequent join extends.
    joined = descriptors[0].from_table

    # Descriptors are appended in traversal order, so iterating front to back joins the
    # vertices in the order they were visited.
    for descriptor in descriptors:
        left_table = descriptor.from_table
        right_table = descriptor.to_table
        columns = descriptor.join_descriptor
        on_clause = left_table.c[columns.from_column] == right_table.c[columns.to_column]
        joined = sqlalchemy.join(joined, right_table, onclause=on_clause)

    return joined

def _construct_fold_subquery(self, subquery_from_clause):
"""Combine all parts of the fold object to produce the complete fold subquery."""
return sqlalchemy.select(
select_statement = sqlalchemy.select(
self.outputs
).select_from(
subquery_from_clause
).group_by(
)
# Factor out the GROUP BY because MSSQL won't use it
return select_statement.group_by(
*self.group_by
)

def _get_fold_outputs(self, fold_scope_location, join_descriptor, all_folded_outputs):
def _get_array_agg_column(self, intermediate_fold_output_name, fold_output_field):
    """Build an array_agg over one field of the output vertex, labeled as requested.

    Args:
        intermediate_fold_output_name: label applied to the aggregated column
        fold_output_field: name of the column on the output vertex alias to aggregate

    Returns:
        labeled SQLAlchemy array_agg expression
    """
    folded_column = self.output_vertex_alias.c[fold_output_field]
    aggregated = sqlalchemy.func.array_agg(folded_column)
    return aggregated.label(intermediate_fold_output_name)

def _get_fold_outputs(self, fold_scope_location, all_folded_outputs):
"""Generate output columns for innermost fold scope and add them to active SQLFoldObject."""
# find outputs for this fold in all_folded_outputs and add to self._outputs
if fold_scope_location.fold_path in all_folded_outputs:
for fold_output in all_folded_outputs[fold_scope_location.fold_path]:
# distinguish between folds with the same fold path but different query paths
# distinguish folds with the same fold path but different query paths
if (fold_output.base_location, fold_output.fold_path) == (
fold_scope_location.base_location, fold_scope_location.fold_path):

if fold_output.field == COUNT_META_FIELD_NAME:
self._outputs.append(
sqlalchemy.func.coalesce(
Expand All @@ -267,47 +349,59 @@ def _get_fold_outputs(self, fold_scope_location, join_descriptor, all_folded_out
fold_output.field)
# add array aggregated output column to self._outputs
self._outputs.append(
sqlalchemy.func.array_agg(
self.output_vertex_alias.c[fold_output.field]
).label(intermediate_fold_output_name)
self._get_array_agg_column(intermediate_fold_output_name,
fold_output.field)
)

# use to join unique identifier for the fold's outer vertex to the final table
self._outputs.append(self.outer_vertex_alias.c[join_descriptor.from_column])
self._outputs.append(self.outer_vertex_alias.c[self._outer_vertex_primary_key])

# sort to make select order deterministic
return sorted(self._outputs, key=lambda column: column.name, reverse=True)

def visit_output_vertex(self,
output_alias,
fold_scope_location,
join_descriptor,
all_folded_outputs):
"""Update output columns when visiting the vertex containing output directives."""
if self._ended:
raise AssertionError(u'Cannot visit output vertices after end_fold has been called. '
u'Invalid state encountered during fold {}'.format(self))
if self._output_vertex_alias is not None:
raise AssertionError('Cannot visit multiple output vertices in one fold.')
raise AssertionError(u'Cannot visit multiple output vertices in one fold. '
u'Invalid state encountered during fold {}'.format(self))
self._output_vertex_alias = output_alias
self._outputs = self._get_fold_outputs(fold_scope_location,
join_descriptor,
all_folded_outputs)

def visit_traversed_vertex(self, join_descriptor, from_table, to_table):
"""Add a new join descriptor for every vertex traversed in the fold."""
self._join_info.append((join_descriptor, from_table, to_table))
"""Add a new traversal descriptor for every vertex traversed in the fold."""
if self._ended:
raise AssertionError(u'Cannot visit traversed vertices after end_fold has been called.'
u'Invalid state encountered during fold {}'.format(self))
self._traversal_descriptors.append(SQLFoldTraversalDescriptor(join_descriptor,
from_table,
to_table))

def end_fold(self, alias_generator, from_clause, outer_from_table):
"""Produce the final subquery and join it onto the rest of the query."""
if self._ended:
raise AssertionError(u'Cannot call end_fold more than once. '
u'Invalid state encountered during fold {}'.format(self))
if self._output_vertex_alias is None:
raise AssertionError('No output vertex visited.')
if len(self._join_info) == 0:
raise AssertionError('No traversed vertices visited.')
raise AssertionError(u'No output vertex visited. '
u'Invalid state encountered during fold {}'.format(self))
if len(self._traversal_descriptors) == 0:
raise AssertionError(u'No traversed vertices visited. '
u'Invalid state encountered during fold {}'.format(self))

# for now we only handle folds containing one traversal (i.e. join)
if len(self.join_info) > 1:
raise NotImplementedError('Folds containing multiple traversals are not implemented.')
edge, from_alias, to_alias = self.join_info.pop()
if len(self._traversal_descriptors) > 1:
raise NotImplementedError(u'Folds containing multiple traversals are not '
u'implemented {}.'.format(self))

# produce the from clause/joins for the subquery
subquery_from_clause = self._construct_fold_joins(edge, from_alias, to_alias)
# join together all vertices traversed
subquery_from_clause = self._construct_fold_joins()

# produce full subquery
fold_subquery = self._construct_fold_subquery(subquery_from_clause).alias(
Expand All @@ -318,10 +412,14 @@ def end_fold(self, alias_generator, from_clause, outer_from_table):
joined_from_clause = sqlalchemy.join(
from_clause,
fold_subquery,
onclause=(outer_from_table.c[edge.from_column] == fold_subquery.c[edge.from_column]),
onclause=(
outer_from_table.c[self._outer_vertex_primary_key] == fold_subquery.c[
self._outer_vertex_primary_key
] # only support a single primary key field, no composite keys
),
isouter=False
)

self._ended = True # prevent any more functions being called on this fold
return fold_subquery, joined_from_clause, outer_from_table


Expand Down Expand Up @@ -435,7 +533,8 @@ def backtrack(self, previous_location):
def traverse(self, vertex_field, optional):
"""Execute a Traverse Block."""
if self._current_fold is not None:
raise NotImplementedError('Traversals inside a fold are not implemented yet.')
raise NotImplementedError('Traversals inside a fold are not implemented '
'yet {}.'.format(self))
# Follow the edge
previous_alias = self._current_alias
edge = self._sql_schema_info.join_descriptors[self._current_classname][vertex_field]
Expand Down Expand Up @@ -528,14 +627,14 @@ def fold(self, fold_scope_location):
u'fold block at current location {}.'
.format(fold_scope_location, self._current_location_info))
# begin the fold

# 1. get fold metadata
# location of vertex that is folded on
self._fold_vertex_location = fold_scope_location
outer_alias = self._current_alias.alias()

outer_vertex_primary_key = self._sql_schema_info.vertex_name_to_table[
self._current_classname
].primary_key
# 2. get information on the folded vertex and its edge to the outer vertex

# basic info about the folded vertex
fold_vertex_alias = self._sql_schema_info.vertex_name_to_table[
self._ir.query_metadata_table.get_location_info(fold_scope_location).type.name
Expand All @@ -550,15 +649,14 @@ def fold(self, fold_scope_location):
][full_edge_name]

# 3. initialize fold object
self._current_fold = SQLFoldObject(outer_alias, join_descriptor)
self._current_fold = SQLFoldObject(outer_alias, outer_vertex_primary_key)

# 4. add join information for this traversal to the fold object
self._current_fold.visit_traversed_vertex(join_descriptor, outer_alias, fold_vertex_alias)

# 5. add output columns to fold object
self._current_fold.visit_output_vertex(fold_vertex_alias,
fold_scope_location,
join_descriptor,
self._all_folded_fields)

def unfold(self):
Expand Down Expand Up @@ -608,6 +706,11 @@ def get_query(self):
return sqlalchemy.select(self._outputs).select_from(
self._from_clause).where(sqlalchemy.and_(*self._filters))

@property
def sql_schema_info(self):
    """Get the SQLAlchemySchemaInfo for the current query."""
    return self._sql_schema_info


def emit_code_from_ir(sql_schema_info, ir):
"""Return a SQLAlchemy Query from a passed SqlQueryTree.
Expand Down

0 comments on commit ff3beba

Please sign in to comment.