Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ApproximateCDF aggregator #5570

Merged
merged 21 commits into from Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion hail/python/hail/expr/aggregators/__init__.py
@@ -1,9 +1,10 @@
from .aggregators import collect, collect_as_set, count, count_where, counter, \
from .aggregators import approx_cdf, collect, collect_as_set, count, count_where, counter, \
any, all, take, min, max, sum, array_sum, mean, stats, product, fraction, \
hardy_weinberg_test, explode, filter, inbreeding, call_stats, info_score, \
hist, linreg, corr, group_by, downsample, array_agg, _prev_nonnull

__all__ = [
'approx_cdf',
'collect',
'collect_as_set',
'count',
Expand Down
44 changes: 44 additions & 0 deletions hail/python/hail/expr/aggregators/aggregators.py
Expand Up @@ -187,6 +187,50 @@ def _check_agg_bindings(expr, bindings):
raise ExpressionException("dynamic variables created by 'hl.bind' or lambda methods like 'hl.map' may not be aggregated")


def approx_cdf(expr, k=100):
"""Produce a summary of the distribution of values.

.. include: _templates/experimental.rst

Notes
-----
This method returns a struct containing two arrays: `values` and `ranks`.
The `values` array contains an ordered sample of values seen. The `ranks`
array is one longer, and contains the approximate ranks for the
corresponding values.

These represent a summary of the CDF of the distribution of values. In
particular, for any value `x = values(i)` in the summary, we estimate that
there are `ranks(i)` values strictly less than `x`, and that there are
`ranks(i+1)` values less than or equal to `x`. For any value `y` (not
necessarily in the summary), we estimate CDF(y) to be `ranks(i)`, where `i`
is such that `values(i-1) < y ≤ values(i)`.

An alternative intuition is that the summary encodes a compressed
approximation to the sorted list of values. For example, values=[0,2,5,6,9]
and ranks=[0,3,4,5,8,10] represents the approximation [0,0,0,2,5,6,6,6,9,9],
with the value `values(i)` occupying indices `ranks(i)` (inclusive) to
`ranks(i+1)` (exclusive).

Warning
-------
This is an approximate and nondeterministic method.

Parameters
----------
expr : :class:`.Expression`
Expression to collect.
k : :obj:`int`
Parameter controlling the accuracy vs. memory usage tradeoff.

Returns
-------
:class:`.StructExpression`
Struct containing `values` and `ranks` arrays.
"""
return _agg_func('ApproxCDF', [expr], tstruct(values=tarray(expr.dtype), ranks=tarray(tint64)), constructor_args=[k])


@typecheck(expr=expr_any)
def collect(expr) -> ArrayExpression:
"""Collect records into an array.
Expand Down
8 changes: 8 additions & 0 deletions hail/python/hail/ir/register_aggregators.py
Expand Up @@ -3,6 +3,14 @@
def register_aggregators():
from hail.expr.types import dtype

register_aggregator('ApproxCDF', (dtype('int32'),), None, (dtype('int32'),),
dtype('struct{values:array<int32>,ranks:array<int64>}'))
register_aggregator('ApproxCDF', (dtype('int32'),), None, (dtype('int64'),),
dtype('struct{values:array<int64>,ranks:array<int64>}'))
register_aggregator('ApproxCDF', (dtype('int32'),), None, (dtype('float32'),),
dtype('struct{values:array<float32>,ranks:array<int64>}'))
register_aggregator('ApproxCDF', (dtype('int32'),), None, (dtype('float64'),),
dtype('struct{values:array<float64>,ranks:array<int64>}'))
register_aggregator('Fraction', (), None, (dtype('bool'),), dtype('float64'))

stats_aggregator_type = dtype('struct{mean:float64,stdev:float64,min:float64,max:float64,n:int64,sum:float64}')
Expand Down
8 changes: 8 additions & 0 deletions hail/python/test/hail/expr/test_expr.py
Expand Up @@ -305,6 +305,14 @@ def test_aggregators(self):
self.assertTrue(r.assert1)
self.assertTrue(r.assert2)

def test_approx_cdf(self):
table = hl.utils.range_table(100)
table = table.annotate(i=table.idx)
table.aggregate(hl.agg.approx_cdf(table.i))
table.aggregate(hl.agg.approx_cdf(hl.int64(table.i)))
table.aggregate(hl.agg.approx_cdf(hl.float32(table.i)))
table.aggregate(hl.agg.approx_cdf(hl.float64(table.i)))

def test_counter_ordering(self):
ht = hl.utils.range_table(10)
assert ht.aggregate(hl.agg.counter(10 - ht.idx).get(10, -1)) == 1
Expand Down