Skip to content

Commit

Permalink
feat(datafusion): add count_distinct, median, approx_median, stddev a…
Browse files Browse the repository at this point in the history
…nd var aggregations
  • Loading branch information
mesejo authored and cpcloud committed Jul 22, 2023
1 parent 27e17d6 commit 45089c4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
41 changes: 41 additions & 0 deletions ibis/backends/datafusion/compiler.py
Expand Up @@ -351,6 +351,11 @@ def count(op):
return df.functions.count(translate(op.arg))


@translate.register(ops.CountDistinct)
def count_distinct(op):
return df.functions.count(translate(op.arg), distinct=True)


@translate.register(ops.CountStar)
def count_star(_):
return df.functions.count(df.literal(1))
Expand Down Expand Up @@ -382,6 +387,42 @@ def mean(op):
return df.functions.avg(arg)


@translate.register(ops.Median)
def median(op):
arg = translate(op.arg)
return df.functions.median(arg)


@translate.register(ops.ApproxMedian)
def approx_median(op):
arg = translate(op.arg)
return df.functions.approx_median(arg)


@translate.register(ops.Variance)
def variance(op):
arg = translate(op.arg)

if op.how == "sample":
return df.functions.var_samp(arg)
elif op.how == "pop":
return df.functions.var_pop(arg)
else:
raise ValueError(f"Unrecognized how value: {op.how}")


@translate.register(ops.StandardDev)
def stddev(op):
arg = translate(op.arg)

if op.how == "sample":
return df.functions.stddev_samp(arg)
elif op.how == "pop":
return df.functions.stddev_pop(arg)
else:
raise ValueError(f"Unrecognized how value: {op.how}")


@translate.register(ops.Contains)
def contains(op):
value = translate(op.value)
Expand Down
23 changes: 2 additions & 21 deletions ibis/backends/tests/test_aggregation.py
Expand Up @@ -261,9 +261,6 @@ def mean_and_std(v):
lambda t, where: t.bool_col.nunique(where=where),
lambda t, where: t.bool_col[where].dropna().nunique(),
id='nunique',
marks=pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError
),
),
param(
lambda t, where: t.bool_col.any(where=where),
Expand Down Expand Up @@ -483,10 +480,6 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].std(ddof=1),
id='std',
marks=[
mark.notimpl(
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand All @@ -499,10 +492,6 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].var(ddof=1),
id='var',
marks=[
mark.notimpl(
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand All @@ -515,10 +504,6 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].std(ddof=0),
id='std_pop',
marks=[
mark.notimpl(
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand All @@ -531,10 +516,6 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].var(ddof=0),
id='var_pop',
marks=[
mark.notimpl(
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand Down Expand Up @@ -1080,7 +1061,7 @@ def test_corr_cov(


@pytest.mark.notimpl(
["datafusion", "mysql", "sqlite", "mssql", "druid"],
["mysql", "sqlite", "mssql", "druid"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
Expand All @@ -1095,7 +1076,7 @@ def test_approx_median(alltypes):


@pytest.mark.notimpl(
["bigquery", "datafusion", "druid", "sqlite"], raises=com.OperationNotDefinedError
["bigquery", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
Expand Down

0 comments on commit 45089c4

Please sign in to comment.