95 changes: 72 additions & 23 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,77 @@ def get_leaf_classes(op):


class AggGen:
__slots__ = ("aggfunc",)
"""A descriptor for compiling aggregate functions.
def __init__(self, *, aggfunc: Callable) -> None:
self.aggfunc = aggfunc
Common cases can be handled by setting configuration flags,
special cases should override the `aggregate` method directly.
def __getattr__(self, name: str) -> partial:
return partial(self.aggfunc, name)
Parameters
----------
supports_filter
Whether the backend supports a FILTER clause in the aggregate.
Defaults to False.
"""

def __getitem__(self, key: str) -> partial:
return getattr(self, key)
class _Accessor:
"""An internal type to handle getattr/getitem access."""

__slots__ = ("handler", "compiler")

def __init__(self, handler: Callable, compiler: SQLGlotCompiler):
self.handler = handler
self.compiler = compiler

def __getattr__(self, name: str) -> Callable:
return partial(self.handler, self.compiler, name)

__getitem__ = __getattr__

__slots__ = ("supports_filter",)

def __init__(self, *, supports_filter: bool = False):
self.supports_filter = supports_filter

def __get__(self, instance, owner=None):
if instance is None:
return self

return AggGen._Accessor(self.aggregate, instance)

def aggregate(
self,
compiler: SQLGlotCompiler,
name: str,
*args: Any,
where: Any = None,
):
"""Compile the specified aggregate.
Parameters
----------
compiler
The backend's compiler.
name
The aggregate name (e.g. `"sum"`).
args
Any arguments to pass to the aggregate.
where
An optional column filter to apply before performing the aggregate.
"""
func = compiler.f[name]

if where is None:
return func(*args)

if self.supports_filter:
return sge.Filter(
this=func(*args),
expression=sge.Where(this=where),
)
else:
args = tuple(compiler.if_(where, arg, NULL) for arg in args)
return func(*args)


class VarGen:
Expand Down Expand Up @@ -167,7 +228,10 @@ def wrapper(self, op, *, left, right):

@public
class SQLGlotCompiler(abc.ABC):
__slots__ = "agg", "f", "v"
__slots__ = "f", "v"

agg = AggGen()
"""A generator for handling aggregate functions"""

rewrites: tuple[type[pats.Replace], ...] = (
empty_in_values_right_side,
Expand Down Expand Up @@ -345,7 +409,6 @@ class SQLGlotCompiler(abc.ABC):
lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {}

def __init__(self) -> None:
self.agg = AggGen(aggfunc=self._aggregate)
self.f = FuncGen(copy=self.__class__.copy_func_args)
self.v = VarGen()

Expand Down Expand Up @@ -411,20 +474,6 @@ def dialect(self) -> str:
def type_mapper(self) -> type[SqlglotType]:
"""The type mapper for the backend."""

@abc.abstractmethod
def _aggregate(self, funcname, *args, where):
"""Translate an aggregate function.
Three flavors of filtering aggregate function inputs:
1. supports filter (duckdb, postgres, others)
e.g.: sum(x) filter (where predicate)
2. use null to filter out
e.g.: sum(if(predicate, x, NULL))
3. clickhouse's ${func}If implementation, e.g.:
sumIf(predicate, x)
"""

# Concrete API

def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If:
Expand Down
16 changes: 6 additions & 10 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
from ibis.backends.sql.compiler import NULL, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import SQLiteType
from ibis.backends.sql.dialects import SQLite
from ibis.common.temporal import DateUnit, IntervalUnit
Expand All @@ -22,6 +22,8 @@ class SQLiteCompiler(SQLGlotCompiler):
dialect = SQLite
type_mapper = SQLiteType

agg = AggGen(supports_filter=True)

NAN = NULL
POS_INF = sge.Literal.number("1e999")
NEG_INF = sge.Literal.number("-1e999")
Expand Down Expand Up @@ -97,12 +99,6 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.Date: "date",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_Log10(self, op, *, arg):
return self.f.anon.log10(arg)

Expand Down Expand Up @@ -222,7 +218,7 @@ def _visit_arg_reduction(self, func, op, *, arg, key, where):
if op.where is not None:
cond = sg.and_(cond, where)

agg = self._aggregate(func, key, where=cond)
agg = self.agg[func](key, where=cond)
return self.f.anon.json_extract(self.f.json_array(arg, agg), "$[0]")

def visit_UnwrapJSONString(self, op, *, arg):
Expand Down Expand Up @@ -254,10 +250,10 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
)

def visit_Variance(self, op, *, arg, how, where):
return self._aggregate(f"_ibis_var_{op.how}", arg, where=where)
return self.agg[f"_ibis_var_{op.how}"](arg, where=where)

def visit_StandardDev(self, op, *, arg, how, where):
var = self._aggregate(f"_ibis_var_{op.how}", arg, where=where)
var = self.agg[f"_ibis_var_{op.how}"](arg, where=where)
return self.f.sqrt(var)

def visit_ApproxCountDistinct(self, op, *, arg, where):
Expand Down
11 changes: 4 additions & 7 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.compiler import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import TrinoType
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops
Expand All @@ -21,6 +21,9 @@ class TrinoCompiler(SQLGlotCompiler):

dialect = Trino
type_mapper = TrinoType

agg = AggGen(supports_filter=True)

rewrites = (
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
Expand Down Expand Up @@ -83,12 +86,6 @@ class TrinoCompiler(SQLGlotCompiler):
ops.ExtractIsoYear: "year_of_week",
}

def _aggregate(self, funcname: str, *args, where):
expr = self.f[funcname](*args)
if where is not None:
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

@staticmethod
def _minimize_spec(start, end, spec):
if (
Expand Down