[query] force aggregate_cols to be local #13405

Merged (5 commits) on Aug 16, 2023

patrick-schultz
Collaborator

@patrick-schultz patrick-schultz commented Aug 10, 2023

CHANGELOG: MatrixTable.aggregate_cols no longer forces a distributed computation. This should be what you want in the majority of cases. If you know the aggregation is very slow and should be parallelized, use mt.cols().aggregate instead.

Most of the time, aggregate_cols will be much faster performing the aggregation locally. Currently, we generate a TableAggregate over a TableParallelize of the columns. We shouldn't try to optimize that to a local computation during compilation; TableParallelize should express the intent that the computation is expensive and really should be parallelized. This should be considered part of the semantics the compiler must preserve.

This PR changes aggregate_cols to explicitly generate a local computation using StreamAgg (which was only exposed in Python relatively recently, which is why we haven't made this change sooner). Longer term, aggregating columns should probably get its own IR node, especially once we start partitioning along columns.
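
As a rough illustration of the two entry points named in the CHANGELOG line, here is a minimal sketch. The helper names (`aggregate_cols_local`, `aggregate_cols_distributed`, `demo`) and the `build_agg` callback are hypothetical and exist only for this example; the Hail calls themselves (`MatrixTable.aggregate_cols`, `MatrixTable.cols`, `Table.aggregate`) are the ones mentioned in this PR. It assumes an order-insensitive aggregator, since `mt.cols()` may reorder columns by key.

```python
def aggregate_cols_local(mt, build_agg):
    """Aggregate over columns with MatrixTable.aggregate_cols.

    After this PR, this path is intended to run as a local computation.
    `build_agg` is a hypothetical callback: it receives the object whose
    column fields are being aggregated and returns an aggregation expression.
    """
    return mt.aggregate_cols(build_agg(mt))


def aggregate_cols_distributed(mt, build_agg):
    """Aggregate over the columns table with Table.aggregate.

    This is the route the CHANGELOG suggests when the aggregation is known
    to be slow and worth parallelizing.
    """
    cols = mt.cols()  # columns as a Table
    return cols.aggregate(build_agg(cols))


def demo():
    """Usage sketch; requires Hail to be installed, so it is not run on import."""
    import hail as hl

    mt = hl.utils.range_matrix_table(3, 3)
    local = aggregate_cols_local(mt, lambda t: hl.agg.sum(t.col_idx))
    distributed = aggregate_cols_distributed(mt, lambda t: hl.agg.sum(t.col_idx))
    return local, distributed
```

For an order-insensitive aggregator such as a sum, both helpers should agree; they differ only in where the work is expected to run.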

@danking
Contributor

danking commented Aug 10, 2023

Looks like you need the globals too. Two concerns:

  1. Does localize_entries not produce a table aggregate?
  2. Does this preserve the column ordering that we all agreed on? Is that tested?

Contributor

@danking danking left a comment

see comment

@patrick-schultz
Collaborator Author

  1. localize_entries is a no-op at runtime; it just exposes the table representation to Python.
  2. Good point. We agreed on ordering by column key, right?

@danking
Contributor

danking commented Aug 10, 2023

I can't find the conversation we all had about this, but I strongly disagree with (2). It has to be the case that this always evaluates to true. This would be profoundly confusing if not.

mt = mt.annotate_globals(
    phenos = hl.agg.collect(mt.pheno)
)
mt = mt.add_col_index('col_idx')
mt = mt.annotate_cols(
    same = mt.phenos[mt.col_idx] == mt.pheno
)

I hate it, but I'm willing to accept that

phenos = mt.cols().pheno.collect()

returns them out of order, even though I don't like that.

EDIT: hit enter too fast

@danking
Contributor

danking commented Aug 10, 2023

Heh. We can't annotate_globals.

@danking
Contributor

danking commented Aug 10, 2023

OK, alright. I think you're right. I'm just reliving how frustrated I am by this situation. Can you modify aggregate_cols to include the same warning from cols()?

In [6]: import hail as hl
   ...: mt = hl.utils.range_matrix_table(3, 3)
   ...: mt = mt.choose_cols([2, 1, 0])
   ...: mt = mt.checkpoint('/tmp/foo.mt', overwrite=True)
   ...: mt.col_idx.show()
   ...: print(mt.aggregate_cols(hl.agg.collect(mt.col_idx)))
   ...: mt = mt.key_cols_by()
   ...: print(mt.aggregate_cols(hl.agg.collect(mt.col_idx)))
2023-08-10 14:43:14.286 Hail: INFO: wrote matrix table with 3 rows and 3 columns in 3 partitions to /tmp/foo.mt
+---------+
| col_idx |
+---------+
|   int32 |
+---------+
|       2 |
|       1 |
|       0 |
+---------+
2023-08-10 14:43:16.338 Hail: INFO: Coerced sorted dataset
[0, 1, 2]
[2, 1, 0]
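
The ordering shown in the session above can be summarized with a small pure-Python model. This is not Hail code and says nothing about how Hail implements it; `model_collect_col_idx` and the dict-based column records are hypothetical stand-ins that only mirror the observed behavior: with a column key, values come back in key order, and after `key_cols_by()` they come back in stored order.

```python
def model_collect_col_idx(columns, key=None):
    """Toy model of the collect order seen in the session (not Hail's implementation).

    columns: list of dicts in stored column order (hypothetical representation).
    key: name of the column key field, or None for an unkeyed matrix table.
    """
    if key is not None:
        # Keyed columns: aggregate in column-key order.
        ordered = sorted(columns, key=lambda col: col[key])
    else:
        # Unkeyed columns: aggregate in stored order.
        ordered = list(columns)
    return [col["col_idx"] for col in ordered]


# Stored order after choose_cols([2, 1, 0]) in the session.
stored = [{"col_idx": 2}, {"col_idx": 1}, {"col_idx": 0}]

keyed_result = model_collect_col_idx(stored, key="col_idx")  # [0, 1, 2]
unkeyed_result = model_collect_col_idx(stored)               # [2, 1, 0]
```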

@danking
Contributor

danking commented Aug 10, 2023

I can't find good tests either, can you add some tests in the spirit of my shared ipython session?

@patrick-schultz
Collaborator Author

Had to chase down a latent bug in StreamAgg, but I think everything should be sorted now (pun intended).

Contributor

@danking danking left a comment

Can you update your PR message with a CHANGELOG line indicating that we fixed the performance regression?

mt = mt.checkpoint(path)
assert(mt.aggregate_cols(hl.agg.collect(mt.col_idx)) == [0, 1, 2])
mt = mt.key_cols_by()
assert(mt.aggregate_cols(hl.agg.collect(mt.col_idx)) == [2, 1, 0])
Contributor

oops, sorry, assert in Python doesn't use parentheses. My bad

@danking danking merged commit bfdd53e into hail-is:main Aug 16, 2023
@patrick-schultz patrick-schultz deleted the local-agg-cols branch January 2, 2025 13:45