Skip to content

Commit

Permalink
refactor: make quantile, any, and all reductions filterable
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Jan 5, 2023
1 parent 4f03c49 commit 1bafc9e
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 40 deletions.
25 changes: 25 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,31 @@ def mean_and_std(v):
)
],
),
param(
lambda t, where: t.double_col.quantile(0.5, where=where),
lambda t, where: t.double_col[where].quantile(0.5),
id="quantile",
marks=[
mark.notimpl(
[
"bigquery",
"clickhouse",
"dask",
"datafusion",
"duckdb",
"impala",
"mssql",
"mysql",
"polars",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
]
)
],
),
],
)
@pytest.mark.parametrize(
Expand Down
11 changes: 8 additions & 3 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,13 @@ def finder(node: ops.Node):


# TODO(kszucs): move to types/logical.py
def _make_any(expr, any_op_class: type[ops.Any] | type[ops.NotAny]):
assert isinstance(expr, ir.Expr)
def _make_any(
expr,
any_op_class: type[ops.Any] | type[ops.NotAny],
*,
where: ir.BooleanValue | None = None,
):
assert isinstance(expr, ir.Expr), type(expr)

tables = find_immediate_parent_tables(expr.op())
predicates = find_predicates(expr.op(), flatten=True)
Expand All @@ -784,7 +789,7 @@ def _make_any(expr, any_op_class: type[ops.Any] | type[ops.NotAny]):
predicates=predicates,
)
else:
op = any_op_class(expr)
op = any_op_class(expr, where=where)
return op.to_expr()


Expand Down
63 changes: 40 additions & 23 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,22 @@ class Filterable(Value):
@public
class Count(Filterable, Reduction):
    """Reduction node whose result dtype is always ``int64``.

    NOTE(review): presumably counts the values of ``arg``; null handling is
    decided by each backend and is not visible here -- confirm.
    """

    # Input: a column expression of any dtype.
    arg = rlz.column(rlz.any)

    # The result dtype is fixed and does not depend on the dtype of `arg`.
    output_dtype = dt.int64


@public
class CountStar(Filterable, Reduction):
    """Reduction node over a whole table whose result dtype is ``int64``.

    NOTE(review): presumably the ``COUNT(*)`` row count -- confirm against
    the backend translation rules.
    """

    # Input: a table expression rather than a single column.
    arg = rlz.table

    # The result dtype is fixed.
    output_dtype = dt.int64


@public
class Arbitrary(Filterable, Reduction):
    """Reduction node that yields a value with the same dtype as ``arg``.

    NOTE(review): presumably picks one value of the column according to
    ``how`` -- the selection semantics live in the backends, confirm there.
    """

    # Input: a column expression of any dtype.
    arg = rlz.column(rlz.any)
    # Selection strategy; must be one of the three listed strings.
    how = rlz.isin({'first', 'last', 'heavy'})

    # The result dtype follows the dtype of `arg`.
    output_dtype = rlz.dtype_like('arg')


Expand All @@ -52,6 +55,7 @@ class BitAnd(Filterable, Reduction):
"""

arg = rlz.column(rlz.integer)

output_dtype = rlz.dtype_like('arg')


Expand All @@ -69,6 +73,7 @@ class BitOr(Filterable, Reduction):
"""

arg = rlz.column(rlz.integer)

output_dtype = rlz.dtype_like('arg')


Expand All @@ -86,6 +91,7 @@ class BitXor(Filterable, Reduction):
"""

arg = rlz.column(rlz.integer)

output_dtype = rlz.dtype_like('arg')


Expand All @@ -107,14 +113,14 @@ class Mean(Filterable, Reduction):

@attribute.default
def output_dtype(self):
if self.arg.output_dtype.is_boolean():
if (dtype := self.arg.output_dtype).is_boolean():
return dt.float64
else:
return dt.higher_precedence(self.arg.output_dtype, dt.float64)
return dt.higher_precedence(dtype, dt.float64)


@public
class Quantile(Reduction):
class Quantile(Filterable, Reduction):
arg = rlz.any
quantile = rlz.strict_numeric
interpolation = rlz.isin({'linear', 'lower', 'higher', 'midpoint', 'nearest'})
Expand All @@ -138,8 +144,8 @@ class VarianceBase(Filterable, Reduction):

@attribute.default
def output_dtype(self):
if self.arg.output_dtype.is_decimal():
return self.arg.output_dtype.largest
if (dtype := self.arg.output_dtype).is_decimal():
return dtype.largest
else:
return dt.float64

Expand Down Expand Up @@ -179,32 +185,37 @@ class Covariance(Filterable, Reduction):
@public
class Mode(Filterable, Reduction):
    """Reduction node whose result has the same dtype as its input column.

    NOTE(review): presumably the most frequent value of ``arg``; how ties
    are broken is not visible here -- confirm per backend.
    """

    # Input: a column expression of any dtype.
    arg = rlz.column(rlz.any)

    # The result dtype follows the dtype of `arg`.
    output_dtype = rlz.dtype_like('arg')


@public
class Max(Filterable, Reduction):
    """Reduction node whose result has the same dtype as its input column.

    NOTE(review): presumably the maximum of ``arg`` -- the actual
    computation is implemented by the backends, not in this class.
    """

    # Input: a column expression of any dtype.
    arg = rlz.column(rlz.any)

    # The result dtype follows the dtype of `arg`.
    output_dtype = rlz.dtype_like('arg')


@public
class Min(Filterable, Reduction):
    """Reduction node whose result has the same dtype as its input column.

    NOTE(review): presumably the minimum of ``arg`` -- the actual
    computation is implemented by the backends, not in this class.
    """

    # Input: a column expression of any dtype.
    arg = rlz.column(rlz.any)

    # The result dtype follows the dtype of `arg`.
    output_dtype = rlz.dtype_like('arg')


@public
class ArgMax(Filterable, Reduction):
    """Two-column reduction node whose result dtype follows ``arg``.

    NOTE(review): presumably returns the value of ``arg`` at the row where
    ``key`` is largest -- confirm against the backend implementations.
    """

    # Column the result is taken from; any dtype is accepted.
    arg = rlz.column(rlz.any)
    # Second input column; any dtype is accepted.
    key = rlz.column(rlz.any)

    # The result dtype follows `arg`, not `key`.
    output_dtype = rlz.dtype_like("arg")


@public
class ArgMin(Filterable, Reduction):
    """Two-column reduction node whose result dtype follows ``arg``.

    NOTE(review): presumably returns the value of ``arg`` at the row where
    ``key`` is smallest -- confirm against the backend implementations.
    """

    # Column the result is taken from; any dtype is accepted.
    arg = rlz.column(rlz.any)
    # Second input column; any dtype is accepted.
    key = rlz.column(rlz.any)

    # The result dtype follows `arg`, not `key`.
    output_dtype = rlz.dtype_like("arg")


Expand All @@ -217,7 +228,7 @@ class ApproxCountDistinct(Filterable, Reduction):

arg = rlz.column(rlz.any)

# Impala 2.0 and higher returns a DOUBLE return ir.DoubleScalar
# Impala 2.0 and higher returns a DOUBLE
output_dtype = dt.int64


Expand All @@ -226,6 +237,7 @@ class ApproxMedian(Filterable, Reduction):
"""Compute the approximate median of a set of comparable values."""

arg = rlz.column(rlz.any)

output_dtype = rlz.dtype_like('arg')


Expand All @@ -237,21 +249,6 @@ class GroupConcat(Filterable, Reduction):
output_dtype = dt.string


@public
class All(Reduction):
arg = rlz.column(rlz.boolean)
output_dtype = dt.boolean

def negate(self):
return NotAll(self.arg)


@public
class NotAll(All):
def negate(self):
return All(self.arg)


@public
class CountDistinct(Filterable, Reduction):
arg = rlz.column(rlz.any)
Expand All @@ -269,7 +266,27 @@ def output_dtype(self):


@public
class Any(Reduction, _Negatable):
class All(Filterable, Reduction, _Negatable):
    """Boolean reduction over a boolean column, negatable into ``NotAll``.

    The node takes a boolean column and always produces a boolean result.
    Call sites build it as ``All(column, where=...)``, so the node can carry
    an optional filter in addition to ``arg``.
    """

    # Input: a column expression of boolean dtype.
    arg = rlz.column(rlz.boolean)

    output_dtype = dt.boolean

    def negate(self) -> NotAll:
        """Return the ``NotAll`` node built from this node's arguments.

        Returns
        -------
        NotAll
            The negated reduction over the same inputs.
        """
        # Forward every argument rather than only `arg`: rebuilding from
        # `self.arg` alone would silently drop a `where` filter, so the
        # negated node would reduce over different rows than this one.
        return NotAll(*self.args)


@public
class NotAll(Filterable, Reduction, _Negatable):
    """Boolean reduction over a boolean column, negatable into ``All``.

    The node takes a boolean column and always produces a boolean result.
    """

    # Input: a column expression of boolean dtype.
    arg = rlz.column(rlz.boolean)

    output_dtype = dt.boolean

    def negate(self) -> All:
        """Return the ``All`` node built from this node's arguments."""
        # `self.args` carries every argument of this node, so nothing
        # besides the operation class changes in the negated node.
        return All(*self.args)


@public
class Any(Filterable, Reduction, _Negatable):
arg = rlz.column(rlz.boolean)

output_dtype = dt.boolean
Expand All @@ -279,7 +296,7 @@ def negate(self) -> NotAny:


@public
class NotAny(Reduction, _Negatable):
class NotAny(Filterable, Reduction, _Negatable):
arg = rlz.column(rlz.boolean)

output_dtype = dt.boolean
Expand Down
22 changes: 11 additions & 11 deletions ibis/expr/types/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import ibis.expr.types as ir

from public import public

import ibis.expr.operations as ops
from ibis.expr.types.core import _binop
from ibis.expr.types.numeric import NumericColumn, NumericScalar, NumericValue

if TYPE_CHECKING:
import ibis.expr.types as ir


@public
class BooleanValue(NumericValue):
Expand Down Expand Up @@ -77,21 +77,21 @@ class BooleanScalar(NumericScalar, BooleanValue):

@public
class BooleanColumn(NumericColumn, BooleanValue):
def any(self) -> BooleanValue:
def any(self, where: BooleanValue | None = None) -> BooleanValue:
import ibis.expr.analysis as an

return an._make_any(self, ops.Any)
return an._make_any(self, ops.Any, where=where)

def notany(self) -> BooleanValue:
def notany(self, where: BooleanValue | None = None) -> BooleanValue:
import ibis.expr.analysis as an

return an._make_any(self, ops.NotAny)
return an._make_any(self, ops.NotAny, where=where)

def all(self) -> BooleanScalar:
return ops.All(self).to_expr()
def all(self, where: BooleanValue | None = None) -> BooleanScalar:
return ops.All(self, where=where).to_expr()

def notall(self) -> BooleanScalar:
return ops.NotAll(self).to_expr()
def notall(self, where: BooleanValue | None = None) -> BooleanScalar:
return ops.NotAll(self, where=where).to_expr()

def cumany(self) -> BooleanColumn:
return ops.CumulativeAny(self).to_expr()
Expand Down
11 changes: 8 additions & 3 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def quantile(
"midpoint",
"nearest",
] = "linear",
where: ir.BooleanValue | None = None,
) -> NumericScalar:
"""Return value at the given quantile.
Expand All @@ -328,15 +329,19 @@ def quantile(
quantile
`0 <= quantile <= 1`, the quantile(s) to compute
interpolation
This optional parameter specifies the interpolation method to use,
when the desired quantile lies between two data points `i` and `j`:
!!! warning "This parameter is backend dependent and may have no effect"
This parameter specifies the interpolation method to use, when the
desired quantile lies between two data points `i` and `j`:
* linear: `i + (j - i) * fraction`, where `fraction` is the
fractional part of the index surrounded by `i` and `j`.
* lower: `i`.
* higher: `j`.
* nearest: `i` or `j` whichever is nearest.
* midpoint: (`i` + `j`) / 2.
where
Boolean filter for input values
Returns
-------
Expand All @@ -347,7 +352,7 @@ def quantile(
op = ops.MultiQuantile
else:
op = ops.Quantile
return op(self, quantile, interpolation).to_expr()
return op(self, quantile, interpolation, where=where).to_expr()

def std(
self,
Expand Down

0 comments on commit 1bafc9e

Please sign in to comment.