Expand Up
@@ -53,16 +53,77 @@ def get_leaf_classes(op):
class AggGen :
__slots__ = ( "aggfunc" ,)
"""A descriptor for compiling aggregate functions.
def __init__ ( self , * , aggfunc : Callable ) -> None :
self . aggfunc = aggfunc
Common cases can be handled by setting configuration flags,
special cases should override the `aggregate` method directly.
def __getattr__ (self , name : str ) -> partial :
return partial (self .aggfunc , name )
Parameters
----------
supports_filter
Whether the backend supports a FILTER clause in the aggregate.
Defaults to False.
"""
def __getitem__ (self , key : str ) -> partial :
return getattr (self , key )
class _Accessor :
"""An internal type to handle getattr/getitem access."""
__slots__ = ("handler" , "compiler" )
def __init__ (self , handler : Callable , compiler : SQLGlotCompiler ):
self .handler = handler
self .compiler = compiler
def __getattr__ (self , name : str ) -> Callable :
return partial (self .handler , self .compiler , name )
__getitem__ = __getattr__
__slots__ = ("supports_filter" ,)
def __init__ (self , * , supports_filter : bool = False ):
self .supports_filter = supports_filter
def __get__ (self , instance , owner = None ):
if instance is None :
return self
return AggGen ._Accessor (self .aggregate , instance )
def aggregate (
self ,
compiler : SQLGlotCompiler ,
name : str ,
* args : Any ,
where : Any = None ,
):
"""Compile the specified aggregate.
Parameters
----------
compiler
The backend's compiler.
name
The aggregate name (e.g. `"sum"`).
args
Any arguments to pass to the aggregate.
where
An optional column filter to apply before performing the aggregate.
"""
func = compiler .f [name ]
if where is None :
return func (* args )
if self .supports_filter :
return sge .Filter (
this = func (* args ),
expression = sge .Where (this = where ),
)
else :
args = tuple (compiler .if_ (where , arg , NULL ) for arg in args )
return func (* args )
class VarGen :
Expand Down
Expand Up
@@ -167,7 +228,10 @@ def wrapper(self, op, *, left, right):
@public
class SQLGlotCompiler (abc .ABC ):
__slots__ = "agg" , "f" , "v"
__slots__ = "f" , "v"
agg = AggGen ()
"""A generator for handling aggregate functions"""
rewrites : tuple [type [pats .Replace ], ...] = (
empty_in_values_right_side ,
Expand Down
Expand Up
@@ -345,7 +409,6 @@ class SQLGlotCompiler(abc.ABC):
lowered_ops : ClassVar [dict [type [ops .Node ], pats .Replace ]] = {}
def __init__ (self ) -> None :
self .agg = AggGen (aggfunc = self ._aggregate )
self .f = FuncGen (copy = self .__class__ .copy_func_args )
self .v = VarGen ()
Expand Down
Expand Up
@@ -411,20 +474,6 @@ def dialect(self) -> str:
def type_mapper (self ) -> type [SqlglotType ]:
"""The type mapper for the backend."""
@abc .abstractmethod
def _aggregate (self , funcname , * args , where ):
"""Translate an aggregate function.
Three flavors of filtering aggregate function inputs:
1. supports filter (duckdb, postgres, others)
e.g.: sum(x) filter (where predicate)
2. use null to filter out
e.g.: sum(if(predicate, x, NULL))
3. clickhouse's ${func}If implementation, e.g.:
sumIf(predicate, x)
"""
# Concrete API
def if_ (self , condition , true , false : sge .Expression | None = None ) -> sge .If :
Expand Down