Skip to content

Commit

Permalink
feat(sqlalchemy): properly implement Intersection and Difference
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 29, 2022
1 parent cd9a34c commit 2bc0b69
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 47 deletions.
37 changes: 29 additions & 8 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
Select,
SelectBuilder,
TableSetFormatter,
Union,
)
from ibis.backends.base.sql.compiler.base import SetOp


def _schema_to_sqlalchemy_columns(schema: sch.Schema) -> list[sa.Column]:
Expand Down Expand Up @@ -343,21 +343,40 @@ def _convert_group_by(self, exprs):
return exprs


class AlchemyUnion(Union):
def compile(self):
def reduce_union(left, right, distincts=iter(self.distincts)):
distinct = next(distincts)
sa_func = sa.union if distinct else sa.union_all
return sa_func(left, right)
class AlchemySetOp(SetOp):
@classmethod
def reduce(cls, left, right, distincts):
distinct = next(distincts)
sa_func = cls.distinct_func if distinct else cls.non_distinct_func
return sa_func(left, right)

def compile(self):
context = self.context
selects = []

for table in self.tables:
table_set = context.get_compiled_expr(table)
selects.append(table_set.cte().select())

return functools.reduce(reduce_union, selects)
return functools.reduce(
functools.partial(self.reduce, distincts=iter(self.distincts)),
selects,
)


class AlchemyUnion(AlchemySetOp):
distinct_func = sa.union
non_distinct_func = sa.union_all


class AlchemyIntersection(AlchemySetOp):
distinct_func = sa.intersect
non_distinct_func = sa.intersect_all


class AlchemyDifference(AlchemySetOp):
distinct_func = sa.except_
non_distinct_func = sa.except_all


class AlchemyCompiler(Compiler):
Expand All @@ -367,6 +386,8 @@ class AlchemyCompiler(Compiler):
select_builder_class = AlchemySelectBuilder
select_class = AlchemySelect
union_class = AlchemyUnion
intersect_class = AlchemyIntersection
difference_class = AlchemyDifference

@classmethod
def to_sql(cls, expr, context=None, params=None, exists=False):
Expand Down
13 changes: 9 additions & 4 deletions ibis/backends/base/sql/compiler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,20 @@ def compile(self):


class SetOp(DML):
def __init__(self, tables, expr, context):
def __init__(self, tables, expr, context, distincts):
self.context = context
self.tables = tables
self.table_set = expr
self.distincts = distincts
self.filters = []

@classmethod
def keyword(cls, distinct):
return cls._keyword + (not distinct) * " ALL"

def _get_keyword_list(self):
return map(self.keyword, self.distincts)

def _extract_subqueries(self):
self.subqueries = _extract_common_table_expressions(
[self.table_set, *self.filters]
Expand All @@ -84,9 +92,6 @@ def format_relation(self, expr):
return f'SELECT *\nFROM {ref}'
return self.context.get_compiled_expr(expr)

def _get_keyword_list(self):
raise NotImplementedError("Need objects to interleave")

def compile(self):
self._extract_subqueries()

Expand Down
59 changes: 24 additions & 35 deletions ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,33 +455,18 @@ def format_limit(self):


class Union(SetOp):
def __init__(self, tables, expr, context, distincts):
super().__init__(tables, expr, context)
self.distincts = distincts

@staticmethod
def keyword(distinct):
return 'UNION' if distinct else 'UNION ALL'

def _get_keyword_list(self):
return map(self.keyword, self.distincts)
_keyword = "UNION"


class Intersection(SetOp):
_keyword = "INTERSECT"

def _get_keyword_list(self):
return [self._keyword] * (len(self.tables) - 1)


class Difference(SetOp):
_keyword = "EXCEPT"

def _get_keyword_list(self):
return [self._keyword] * (len(self.tables) - 1)


def flatten_union(table: ir.Table):
def flatten_set_op(table: ir.Table):
"""Extract all union queries from `table`.
Parameters
Expand All @@ -493,14 +478,14 @@ def flatten_union(table: ir.Table):
Iterable[Union[Table, bool]]
"""
op = table.op()
if isinstance(op, ops.Union):
if isinstance(op, ops.SetOp):
# For some reason mypy considers `op.left` and `op.right`
# of `Argument` type, and fails the validation. While in
# `flatten` types are the same, and it works
return toolz.concatv(
flatten_union(op.left), # type: ignore
flatten_set_op(op.left), # type: ignore
[op.distinct],
flatten_union(op.right), # type: ignore
flatten_set_op(op.right), # type: ignore
)
return [table]

Expand All @@ -517,7 +502,9 @@ def flatten(table: ir.Table):
Iterable[Union[Table]]
"""
op = table.op()
return list(toolz.concatv(flatten_union(op.left), flatten_union(op.right)))
return list(
toolz.concatv(flatten_set_op(op.left), flatten_set_op(op.right))
)


class Compiler:
Expand Down Expand Up @@ -617,35 +604,37 @@ def _generate_setup_queries(expr, context):
def _generate_teardown_queries(expr, context):
return []

@classmethod
def _make_union(cls, expr, context):
@staticmethod
def _make_set_op(cls, expr, context):
# flatten unions so that we can codegen them all at once
union_info = list(flatten_union(expr))
set_op_info = list(flatten_set_op(expr))

# since op is a union, we have at least 3 elements in union_info (left
# distinct right) and if there is more than a single union we have an
# additional two elements per union (distinct right) which means the
# total number of elements is at least 3 + (2 * number of unions - 1)
# and is therefore an odd number
npieces = len(union_info)
assert npieces >= 3 and npieces % 2 != 0, 'Invalid union expression'
npieces = len(set_op_info)
assert (
npieces >= 3 and npieces % 2 != 0
), 'Invalid set operation expression'

# 1. every other object starting from 0 is a Table instance
# 2. every other object starting from 1 is a bool indicating the type
# of union (distinct or not distinct)
table_exprs, distincts = union_info[::2], union_info[1::2]
return cls.union_class(
table_exprs, expr, distincts=distincts, context=context
)
# of $set_op (distinct or not distinct)
table_exprs, distincts = set_op_info[::2], set_op_info[1::2]
return cls(table_exprs, expr, distincts=distincts, context=context)

@classmethod
def _make_union(cls, expr, context):
return cls._make_set_op(cls.union_class, expr, context)

@classmethod
def _make_intersect(cls, expr, context):
# flatten intersections so that we can codegen them all at once
table_exprs = list(flatten(expr))
return cls.intersect_class(table_exprs, expr, context=context)
return cls._make_set_op(cls.intersect_class, expr, context)

@classmethod
def _make_difference(cls, expr, context):
# flatten differences so that we can codegen them all at once
table_exprs = list(flatten(expr))
return cls.difference_class(table_exprs, expr, context=context)
return cls._make_set_op(cls.difference_class, expr, context)

0 comments on commit 2bc0b69

Please sign in to comment.