
Add `ExtractIntervalFilters` optimizer pass. #5979

Merged: 8 commits into hail-is:master from tpoterba:pushdown-pass, May 13, 2019

3 participants
@tpoterba (Collaborator) commented Apr 29, 2019:

This PR will not be merged as-is, but will be split along the three commits it contains:

- Add `Coalesce` IR node
- Expose pruning on `FilterIntervals` relational functions. These should be promoted to full IR nodes, especially after this PR.
- Add `ExtractIntervalFilters` optimizer pass.

I also have yet to add tests for the last commit.

What does this PR do?

```
In [2]: mt = hl.read_matrix_table('data/1kg.rep.mt')

In [3]: mt.filter_rows(mt.locus.contig == '16').count()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[3]: (384, 284)

In [4]: mt.filter_rows(mt.locus.contig == '16').count_rows()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[4]: 384

In [5]: mt.filter_rows((mt.locus.contig == '16') | (mt.locus.contig == '19')).count()
Hail: INFO: interval filter loaded 10 of 128 partitions
Out[5]: (730, 284)

In [6]: mt.filter_rows(hl.literal({'16', '19'}).contains(mt.locus.contig)).count_rows()
Hail: INFO: interval filter loaded 10 of 128 partitions
Out[6]: 730

In [7]: mt.filter_rows((mt.locus.contig == '16') & (mt.locus.position > 10_000_000)).count_rows()
Hail: INFO: interval filter loaded 2 of 128 partitions
Out[7]: 82

In [8]: mt.filter_rows((mt.locus.contig == '16') & (mt.locus.position > 10_000_000) & (mt.locus.position < 12_000_000)).count_rows()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[8]: 384

In [9]: mt.filter_rows(mt.locus == hl.parse_locus('1:3761547')).count()
Hail: INFO: interval filter loaded 1 of 128 partitions
Out[9]: (1, 284)

In [10]: mt.filter_rows(hl.parse_locus_interval('16:20000000-30000000').contains(mt.locus)).count()
Hail: INFO: interval filter loaded 1 of 128 partitions
Out[10]: (35, 284)
```

@tpoterba tpoterba force-pushed the tpoterba:pushdown-pass branch 2 times, most recently from 9fa703c to 38fd68d May 7, 2019

@tpoterba tpoterba added the stacked PR label May 8, 2019

@tpoterba (Author, Collaborator) commented May 8, 2019:

stacked on #6073

@tpoterba tpoterba force-pushed the tpoterba:pushdown-pass branch from 38da222 to c451bb9 May 9, 2019

```scala
@@ -454,12 +457,6 @@ object Simplify {
        mct,
        Subst(newRow, BindingEnv(Env("sa" -> Ref("row", mct.typ.rowType)))))

      case MatrixColsTable(MatrixFilterCols(child, pred)) =>
```

@tpoterba (Author, Collaborator) commented May 9, 2019:

Deoptimization; I encountered this while doing some manual tests.

```scala
@@ -41,6 +41,13 @@ case class MatrixFilterIntervals(
  override def lower(): Option[TableToTableFunction] = Some(TableFilterIntervals(keyType, intervals, keep))

  def execute(mv: MatrixValue): MatrixValue = throw new UnsupportedOperationException

  override def requestType(requestedType: MatrixType, childBaseType: MatrixType): MatrixType = {
```

@tpoterba (Author, Collaborator) commented May 9, 2019:

These will be backed out in a subsequent PR that creates full IR nodes for FilterIntervals.

```diff
@@ -15,7 +15,7 @@ object CanEmit {
 object IsConstant {
   def apply(ir: IR): Boolean = {
     ir match {
-      case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Literal(_, _) => true
+      case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Str(_) | Literal(_, _) => true
```

@tpoterba (Author, Collaborator) commented May 9, 2019:

This change is required to make the contig handling work, so it can't really be staged in a separate PR.
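The idea behind adding the `Str(_)` case can be illustrated with a toy stand-in (plain Python with hypothetical node shapes, not the actual Scala IR): string literals must count as constants so that a predicate like `locus.contig == '16'` has one constant side the pass can turn into an interval.

```python
import numbers

def is_constant(node):
    # Toy IR: Python literals play the role of I32/I64/F32/F64/Str/NA.
    # Before the change, a string like '16' would fall through and the
    # contig comparison could not be recognized as key-vs-constant.
    return isinstance(node, (numbers.Number, str, bool)) or node is None

print(is_constant('16'))                  # string literal: now constant
print(is_constant(('GetField', 'contig')))  # a non-constant IR node
```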

@tpoterba (Author, Collaborator) commented May 9, 2019:

Randomly assigned Patrick. This is a big PR, but about half of it is tests.

@patrick-schultz (Collaborator) left a comment:

Awesome. Here's my first pass; I still need to understand some details better.

```scala
val ApplySpecial(_, Seq(Literal(_, lit), _)) = comp
val intervals = lit match {
  case null => Array[Interval]()
  case i: Interval => Array(Interval(endpoint(i.left.point, i.left.sign), endpoint(i.right.point, i.right.sign)))
```

@patrick-schultz (Collaborator) commented May 13, 2019:

Couldn't this just be `Array(i)`?

@tpoterba (Author, Collaborator) commented May 13, 2019:

Yes. `endpoint` used to wrap things in `Row`, but I now do that at the end.

```scala
val intOrd = TInt32().ordering.intervalEndpointOrdering
val intervals = rg.contigs.indices
  .flatMap { i =>
    Interval.intersection(Array(openInterval(pos, TInt32(), comp.op, isFlipped)),
```

@patrick-schultz (Collaborator) commented May 13, 2019:

You're only intersecting single intervals here. I think this could be made a bit more efficient using the `Interval.intersect` method (as opposed to the companion object method).
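For contrast, intersecting two sorted, disjoint interval lists in a single linear merge pass can be sketched in plain Python (an illustrative stand-in, not Hail's `Interval` API):

```python
def intersect(xs, ys):
    """Intersect two sorted lists of half-open (start, end) intervals."""
    out, i, j = [], 0, 0
    while i < len(xs) and j < len(ys):
        lo = max(xs[i][0], ys[j][0])
        hi = min(xs[i][1], ys[j][1])
        if lo < hi:
            out.append((lo, hi))
        # advance whichever interval ends first
        if xs[i][1] < ys[j][1]:
            i += 1
        else:
            j += 1
    return out

print(intersect([(0, 10), (20, 30)], [(5, 25)]))  # [(5, 10), (20, 25)]
```

The merge visits each input interval once, versus the quadratic cost of pairwise single-interval intersections.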


```scala
val k1 = GetField(ref, key.head)

val (nodes, intervals) = processPredicates(extract(cond, ref, k1), k1.typ)
```

@patrick-schultz (Collaborator) commented May 13, 2019:

You don't need to make this change, but I think you could refactor this in a way that avoids building an intermediate KeyFilterPredicate, while keeping the code mostly unchanged.

For example, in `processPredicates` replace

```scala
case KeyComparison(comp) =>
  val (v, isFlipped) = if (IsConstant(comp.l)) (comp.l, false) else (comp.r, true)
  Set[IR](comp) -> Array(openInterval(constValue(v), v.typ, comp.op, isFlipped))
```

with

```scala
def keyComparison(comp): (Set[IR], Array[Interval]) = {
  val (v, isFlipped) = if (IsConstant(comp.l)) (comp.l, false) else (comp.r, true)
  Set[IR](comp) -> Array(openInterval(constValue(v), v.typ, comp.op, isFlipped))
}
```

and instead of constructing a `KeyComparison` node in `extract`, just call `keyComparison` directly, combining `processPredicates(extract(cond, ref, k1), k1.typ)` into a single recursive traversal.
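The suggested single-traversal shape can be sketched in plain Python (a hypothetical predicate AST; the `extract` name and node encodings are illustrative, not the Hail code). One recursion returns both the matched comparison nodes and their key intervals, with no intermediate predicate tree:

```python
def intersect1(a, b):
    """Intersect two half-open (lo, hi) intervals; None if disjoint."""
    lo, hi = max(a[0], b[0]), min(a[1], b[1])
    return (lo, hi) if lo < hi else None

def extract(node):
    """Return (matched comparison nodes, list of key intervals)."""
    kind = node[0]
    if kind == 'cmp':                       # ('cmp', op, constant)
        _, op, v = node
        iv = {'<': (float('-inf'), v), '>': (v, float('inf'))}[op]
        return {node}, [iv]
    if kind == 'and':                       # ('and', left, right)
        ln, li = extract(node[1])
        rn, ri = extract(node[2])
        both = [intersect1(a, b) for a in li for b in ri]
        return ln | rn, [iv for iv in both if iv is not None]
    if kind == 'or':                        # ('or', left, right)
        ln, li = extract(node[1])
        rn, ri = extract(node[2])
        return ln | rn, li + ri

# (key > 10) & (key < 20)  ->  one interval (10, 20)
nodes, ivs = extract(('and', ('cmp', '>', 10), ('cmp', '<', 20)))
print(ivs)  # [(10, 20)]
```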

@tpoterba (Author, Collaborator) commented May 13, 2019:

Ah, yeah, you're totally right. I had to do some transformations of the KeyFilterPredicate objects in an earlier iteration, but now this should be straightforward. I'll take a stab at it.

@tpoterba (Author, Collaborator) commented May 13, 2019:

Yep, even simpler after this change!

```scala
import is.hail.variant.{Locus, ReferenceGenome}
import org.apache.spark.sql.Row

sealed trait KeyFilterPredicate
```

@patrick-schultz (Collaborator) commented May 13, 2019:

Unused now?

@tpoterba (Author, Collaborator) commented May 13, 2019:

Oh, duh.

@danking danking merged commit 1801ce0 into hail-is:master May 13, 2019

1 check passed: ci-test (success)