Skip to content

Commit

Permalink
refactor(analysis): remove ScalarAggregate, reduction_to_aggregation …
Browse files Browse the repository at this point in the history
…and has_multiple_bases
  • Loading branch information
kszucs authored and cpcloud committed Aug 7, 2023
1 parent df63e8b commit ed75866
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 127 deletions.
6 changes: 1 addition & 5 deletions ibis/backends/base/sql/alchemy/registry.py
Expand Up @@ -78,11 +78,7 @@ def formatter(t, op):

def get_sqla_table(ctx, table):
if ctx.has_ref(table, parent_contexts=True):
ctx_level = ctx
sa_table = ctx_level.get_ref(table)
while sa_table is None and ctx_level.parent is not ctx_level:
ctx_level = ctx_level.parent
sa_table = ctx_level.get_ref(table)
sa_table = ctx.get_ref(table, search_parents=True)
else:
sa_table = ctx.get_compiled_expr(table)

Expand Down
18 changes: 3 additions & 15 deletions ibis/backends/base/sql/compiler/select_builder.py
Expand Up @@ -14,20 +14,6 @@ class _LimitSpec(NamedTuple):
offset: int


def _get_scalar(field):
def scalar_handler(results):
return results[field][0]

return scalar_handler


def _get_column(name):
def column_handler(results):
return results[name]

return column_handler


class SelectBuilder:
"""Transforms expression IR to a query pipeline.
Expand Down Expand Up @@ -127,7 +113,9 @@ def _collect_elements(self):
# expression that is being translated only depends on a single table
# expression.

if isinstance(self.op, ops.TableNode):
if isinstance(self.op, ops.DummyTable):
self.select_set = list(self.op.values)
elif isinstance(self.op, ops.TableNode):
self._collect(self.op, toplevel=True)
else:
self.select_set = [self.op]
Expand Down
19 changes: 2 additions & 17 deletions ibis/backends/polars/__init__.py
Expand Up @@ -8,7 +8,6 @@

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
Expand Down Expand Up @@ -356,22 +355,8 @@ def compile(
node = node.replace(replacements)
expr = node.to_expr()

if isinstance(expr, ir.Table):
return translate(node, ctx=ctx)
elif isinstance(expr, ir.Column):
# expression must be named for the projection
node = expr.as_table().op()
return translate(node, ctx=ctx)
elif isinstance(expr, ir.Scalar):
if an.is_scalar_reduction(node):
node = an.reduction_to_aggregation(node).op()
return translate(node, ctx=ctx)
else:
# doesn't have any _tables associated so create projection
# based off of an empty table
return pl.DataFrame().lazy().select(translate(node, ctx=ctx))
else:
raise com.IbisError(f"Cannot compile expression of type: {type(expr)}")
node = expr.as_table().op()
return translate(node, ctx=ctx)

def _get_schema_using_query(self, query: str) -> sch.Schema:
return schema_from_polars(
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/polars/compiler.py
Expand Up @@ -42,6 +42,12 @@ def table(op, **_):
return op.source._tables[op.name]


@translate.register(ops.DummyTable)
def dummy_table(op, **kw):
    """Translate a ``DummyTable`` into a polars lazy frame.

    Each entry of ``op.values`` is translated with the same keyword arguments
    this handler received, and the resulting expressions are selected from an
    empty lazy ``DataFrame``.
    """
    selections = [translate(arg, **kw) for arg in op.values]
    return pl.DataFrame().lazy().select(selections)


@translate.register(ops.InMemoryTable)
def pandas_in_memory_table(op, **_):
lf = pl.from_pandas(op.data.to_frame()).lazy()
Expand Down
75 changes: 3 additions & 72 deletions ibis/expr/analysis.py
Expand Up @@ -55,70 +55,6 @@ def sub_immediate_parents(op: ops.Node, table: ops.TableNode) -> ops.Node:
return sub_for(op, {base: table for base in find_immediate_parent_tables(op)})


class ScalarAggregate:
    """Rewrite an expression of scalar reductions into a table expression.

    While walking the expression, every scalar reduction that does not have
    multiple base tables is converted into its own aggregation (through
    ``reduction_to_aggregation``) and replaced by a reference to the
    corresponding column of that aggregation. The collected aggregations are
    cross joined and the rewritten expression is selected from the result.
    """

    def __init__(self, expr):
        assert isinstance(expr, ir.Expr)
        self.expr = expr
        # Aggregation table expressions collected by ``_visit``, in visit order.
        self.tables = []

    def get_result(self):
        """Return the rewritten expression selected from the joined aggregations.

        Raises ``IndexError`` when the walk did not collect any aggregation,
        because the first collected table seeds the cross join.
        """
        subbed_expr = self._visit(self.expr)

        table = self.tables[0]
        for other in self.tables[1:]:
            table = table.cross_join(other)

        return table.select(subbed_expr)

    def _visit(self, expr):
        """Recursively substitute aggregation units inside ``expr``."""
        assert isinstance(expr, ir.Expr), type(expr)

        if is_scalar_reduction(expr.op()) and not has_multiple_bases(expr.op()):
            # An aggregation unit
            if not expr.has_name():
                expr = expr.name('tmp')
            agg_expr = reduction_to_aggregation(expr.op())
            self.tables.append(agg_expr)
            return agg_expr[expr.get_name()]

        # NOTE: the former ``elif not isinstance(expr, ir.Expr)`` branch was
        # dropped: it could never run, since both the assert above and the
        # ``expr.op()`` call in the condition already require an expression.

        node = expr.op()
        # TODO(kszucs): use the substitute() utility instead
        # Only node arguments are visited; every other argument is passed
        # through untouched when the node is rebuilt.
        new_args = tuple(
            self._visit(arg.to_expr()) if isinstance(arg, ops.Node) else arg
            for arg in node.args
        )
        new_node = node.__class__(*new_args)
        new_expr = new_node.to_expr()

        # Rebuilding the node loses the name, so carry it over explicitly.
        if expr.has_name():
            new_expr = new_expr.name(name=expr.get_name())

        return new_expr


def has_multiple_bases(node):
    """Return whether ``node`` has more than one immediate parent table.

    ``node`` must be an ``ops.Node``; the parent tables are those reported by
    ``find_immediate_parent_tables``.
    """
    assert isinstance(node, ops.Node), type(node)
    return len(find_immediate_parent_tables(node)) > 1


def reduction_to_aggregation(node):
    """Convert a reduction node into an aggregation table expression.

    The node is aliased to its own name first. With exactly one immediate
    parent table, the aliased expression is aggregated directly on that
    table; otherwise the rewrite is delegated to ``ScalarAggregate``.
    """
    tables = find_immediate_parent_tables(node)

    # TODO(kszucs): avoid the expression roundtrip
    node = ops.Alias(node, node.name)
    expr = node.to_expr()
    if len(tables) == 1:
        (table,) = tables
        agg = table.to_expr().aggregate([expr])
    else:
        agg = ScalarAggregate(expr).get_result()

    return agg


def find_physical_tables(node):
"""Find every first occurrence of a `ir.PhysicalTable` object in `node`."""

Expand Down Expand Up @@ -728,11 +664,6 @@ def predicate(node):
return any(g.traverse(predicate, node))


def is_scalar_reduction(node):
    """Return whether ``node`` is a reduction with a scalar output shape.

    ``node`` must be an ``ops.Node``. The shape check runs first, so
    ``is_reduction`` is only consulted for scalar-shaped nodes.
    """
    assert isinstance(node, ops.Node), type(node)
    return node.output_shape.is_scalar() and is_reduction(node)


_ANY_OP_MAPPING = {
ops.Any: ops.UnresolvedExistsSubquery,
ops.NotAny: ops.UnresolvedNotExistsSubquery,
Expand Down Expand Up @@ -805,11 +736,11 @@ def _rewrite_filter_reduction(op, name: str | None = None, **kwargs):
# TODO: what about reductions that reference a join that isn't visible at
# this level? Means we probably have the wrong design, but will have to
# revisit when it becomes a problem.

if name is not None:
op = ops.Alias(op, name=name)
aggregation = reduction_to_aggregation(op)
return ops.TableArrayView(aggregation)

agg = op.to_expr().as_table()
return ops.TableArrayView(agg)


@_rewrite_filter.register(ops.Any)
Expand Down
4 changes: 0 additions & 4 deletions ibis/expr/decompile.py
Expand Up @@ -86,10 +86,6 @@ def translate(op, *args, **kwargs):
raise NotImplementedError(op)


# TODO(kszucs): we do rewrites on construction, so we need to handle specific
# cases like when reduction_to_aggregation is called


@translate.register(ops.Value)
@translate.register(ops.TableNode)
def value(op, *args, **kwargs):
Expand Down
14 changes: 3 additions & 11 deletions ibis/expr/types/generic.py
Expand Up @@ -959,22 +959,14 @@ def as_table(self) -> ir.Table:
>>> isinstance(lit, ir.Table)
True
"""
from ibis.expr.analysis import (
find_first_base_table,
is_scalar_reduction,
reduction_to_aggregation,
)
from ibis.expr.analysis import find_first_base_table

op = self.op()
if is_scalar_reduction(op):
return reduction_to_aggregation(op)

table = find_first_base_table(op)
if table is not None:
agg = ops.Aggregation(table=table, metrics=(op,))
return table.to_expr().aggregate([self])
else:
agg = ops.DummyTable(values=(op,))
return agg.to_expr()
return ops.DummyTable(values=(op,)).to_expr()

def _repr_html_(self) -> str | None:
return None
Expand Down
3 changes: 0 additions & 3 deletions ibis/tests/expr/test_analysis.py
Expand Up @@ -8,9 +8,6 @@
import ibis.expr.operations as ops
from ibis.tests.util import assert_equal

# TODO: test is_reduction
# TODO: test is_scalar_reduction

# Place to collect esoteric expression analysis bugs and tests


Expand Down

0 comments on commit ed75866

Please sign in to comment.