Skip to content

Commit

Permalink
Merge pull request #15 from maxmzkr/master
Browse files (browse the repository at this point in the history)
Allow filters before analytic; speed up analytics
  • Loading branch information
maxmzkr committed Jan 12, 2017
2 parents 05056a2 + 3e73b56 commit b7911b0
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 56 deletions.
49 changes: 35 additions & 14 deletions ibis/impala/compiler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -374,24 +374,45 @@ def _format_table(self, expr):

class ImpalaUnion(comp.Union):

def compile(self):
def _extract_subqueries(self):
self.subqueries = comp._extract_subqueries(self)
for subquery in self.subqueries:
self.context.set_extracted(subquery)

def format_subqueries(self):
context = self.context
subqueries = self.subqueries

if self.distinct:
union_keyword = 'UNION'
else:
union_keyword = 'UNION ALL'
return ',\n'.join([
'{0} AS (\n{1}\n)'.format(
context.get_ref(expr),
util.indent(context.get_compiled_expr(expr), 2)
) for expr in subqueries
])

left_set = context.get_compiled_expr(self.left, isolated=True)
right_set = context.get_compiled_expr(self.right, isolated=True)
def format_relation(self, expr):
ref = self.context.get_ref(expr)
if ref is not None:
return 'SELECT *\nFROM {0}'.format(ref)
return self.context.get_compiled_expr(expr)

# XXX: hack of all trades - our right relation has a CTE
# TODO: factor out common subqueries in the union
if right_set.startswith('WITH'):
format_string = '({0})\n{1}\n({2})'
else:
format_string = '{0}\n{1}\n{2}'
return format_string.format(left_set, union_keyword, right_set)
def compile(self):
union_keyword = 'UNION' if self.distinct else 'UNION ALL'

self._extract_subqueries()

left_set = self.format_relation(self.left)
right_set = self.format_relation(self.right)
extracted = self.format_subqueries()

buf = []

if extracted:
buf.append('WITH {0}'.format(extracted))

buf.extend([left_set, union_keyword, right_set])

return '\n'.join(buf)


# ---------------------------------------------------------------------
Expand Down
11 changes: 4 additions & 7 deletions ibis/sql/alchemy.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -412,17 +412,14 @@ def __init__(self, *args, **kwargs):
self.dialect = kwargs.pop('dialect', AlchemyDialect)
comp.QueryContext.__init__(self, *args, **kwargs)

def subcontext(self, isolated=False):
if not isolated:
return type(self)(dialect=self.dialect, parent=self)
else:
return type(self)(dialect=self.dialect)
def subcontext(self):
return type(self)(dialect=self.dialect, parent=self)

def _to_sql(self, expr, ctx):
return to_sqlalchemy(expr, context=ctx)

def _compile_subquery(self, expr, isolated=False):
sub_ctx = self.subcontext(isolated=isolated)
def _compile_subquery(self, expr):
sub_ctx = self.subcontext()
return self._to_sql(expr, sub_ctx)

def has_table(self, expr, parent_contexts=False):
Expand Down
51 changes: 28 additions & 23 deletions ibis/sql/compiler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -404,10 +404,8 @@ def _collect_Limit(self, expr, toplevel=False):
self._collect(op.table, toplevel=toplevel)

def _collect_Union(self, expr, toplevel=False):
if not toplevel:
return
else:
raise NotImplementedError
if toplevel:
raise NotImplementedError()

def _collect_Aggregation(self, expr, toplevel=False):
# The select set includes the grouping keys (if any), and these are
Expand Down Expand Up @@ -680,7 +678,9 @@ def _visit_Limit(self, expr):
self.visit(expr.op().table)

def _visit_Union(self, expr):
self.observe(expr)
op = expr.op()
self.visit(op.left)
self.visit(op.right)

def _visit_MaterializedJoin(self, expr):
self.observe(expr)
Expand Down Expand Up @@ -728,7 +728,8 @@ def get_result(self):
self._visit(self.expr)
return self.has_query_root and self.has_foreign_root

def _visit(self, expr, in_subquery=False, visit_cache=None, visit_table_cache=None):
def _visit(self, expr, in_subquery=False, visit_cache=None,
visit_table_cache=None):
if visit_cache is None:
visit_cache = set()

Expand All @@ -742,9 +743,13 @@ def _visit(self, expr, in_subquery=False, visit_cache=None, visit_table_cache=No

for arg in node.flat_args():
if isinstance(arg, ir.TableExpr):
self._visit_table(arg, in_subquery=in_subquery, visit_cache=visit_cache, visit_table_cache=visit_table_cache)
self._visit_table(arg, in_subquery=in_subquery,
visit_cache=visit_cache,
visit_table_cache=visit_table_cache)
elif isinstance(arg, ir.Expr):
self._visit(arg, in_subquery=in_subquery, visit_cache=visit_cache, visit_table_cache=visit_table_cache)
self._visit(arg, in_subquery=in_subquery,
visit_cache=visit_cache,
visit_table_cache=visit_table_cache)
else:
continue

Expand All @@ -760,7 +765,8 @@ def _is_subquery(self, node):

return False

def _visit_table(self, expr, in_subquery=False, visit_cache=None, visit_table_cache=None):
def _visit_table(self, expr, in_subquery=False, visit_cache=None,
visit_table_cache=None):
if visit_table_cache is None:
visit_table_cache = set()

Expand All @@ -775,7 +781,9 @@ def _visit_table(self, expr, in_subquery=False, visit_cache=None, visit_table_ca

for arg in node.flat_args():
if isinstance(arg, ir.Expr):
self._visit(arg, in_subquery=in_subquery, visit_cache=visit_cache, visit_table_cache=visit_table_cache)
self._visit(arg, in_subquery=in_subquery,
visit_cache=visit_cache,
visit_table_cache=visit_table_cache)

def _ref_check(self, node, in_subquery=False):
is_aliased = self.ctx.has_ref(node)
Expand Down Expand Up @@ -918,7 +926,7 @@ def get_result(self):

def _make_union(self):
op = self.expr.op()
return self._union_class(op.left, op.right,
return self._union_class(op.left, op.right, self.expr,
distinct=op.distinct,
context=self.context)

Expand Down Expand Up @@ -947,8 +955,8 @@ def __init__(self, indent=2, parent=None, memo=None):
self._table_key_memo = {}
self.memo = memo or format.FormatMemo()

def _compile_subquery(self, expr, isolated=False):
sub_ctx = self.subcontext(isolated=isolated)
def _compile_subquery(self, expr):
sub_ctx = self.subcontext()
return self._to_sql(expr, sub_ctx)

def _to_sql(self, expr, ctx):
Expand All @@ -964,7 +972,7 @@ def top_context(self):
def set_always_alias(self):
self.always_alias = True

def get_compiled_expr(self, expr, isolated=False):
def get_compiled_expr(self, expr):
this = self.top_context

key = self._get_table_key(expr)
Expand All @@ -975,7 +983,7 @@ def get_compiled_expr(self, expr, isolated=False):
if isinstance(op, ops.SQLQueryResult):
result = op.query
else:
result = self._compile_subquery(expr, isolated=isolated)
result = self._compile_subquery(expr)

this.subquery_memo[key] = result
return result
Expand Down Expand Up @@ -1030,11 +1038,8 @@ def set_extracted(self, expr):
self.extracted_subexprs.add(key)
self.make_alias(expr)

def subcontext(self, isolated=False):
if not isolated:
return type(self)(indent=self.indent, parent=self)
else:
return type(self)(indent=self.indent)
def subcontext(self):
return type(self)(indent=self.indent, parent=self)

# Maybe temporary hacks for correlated / uncorrelated subqueries

Expand Down Expand Up @@ -1432,10 +1437,10 @@ def _validate_join_predicates(self, predicates):

class Union(DDL):

def __init__(self, left_table, right_table, distinct=False,
context=None):
def __init__(self, left_table, right_table, expr, distinct=False, context=None):
self.context = context
self.left = left_table
self.right = right_table

self.distinct = distinct
self.table_set = expr
self.filters = []
23 changes: 11 additions & 12 deletions ibis/sql/tests/test_compiler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1459,25 +1459,24 @@ def test_subquery_in_union(self):
expr = join1.union(join2)
result = to_sql(expr)
expected = """\
(WITH t0 AS (
WITH t0 AS (
SELECT `a`, `g`, sum(`f`) AS `metric`
FROM alltypes
GROUP BY 1, 2
),
t1 AS (
SELECT t0.*
FROM t0
INNER JOIN t0 t3
ON t0.`g` = t3.`g`
)
SELECT t0.*
FROM t0
INNER JOIN t0 t1
ON t0.`g` = t1.`g`)
SELECT *
FROM t1
UNION ALL
(WITH t0 AS (
SELECT `a`, `g`, sum(`f`) AS `metric`
FROM alltypes
GROUP BY 1, 2
)
SELECT t0.*
FROM t0
INNER JOIN t0 t1
ON t0.`g` = t1.`g`)"""
INNER JOIN t0 t3
ON t0.`g` = t3.`g`"""
assert result == expected

def test_subquery_factor_correlated_subquery(self):
Expand Down

0 comments on commit b7911b0

Please sign in to comment.