
Add ExtractIntervalFilters optimizer pass. #5979

Merged
merged 8 commits into hail-is:master from the pushdown-pass branch on May 13, 2019

Conversation

@tpoterba (Contributor) commented Apr 29, 2019

This PR will not be merged as-is, but will be split into the 3 commits it contains:

  • Add Coalesce IR node (a brief usage sketch follows below)
  • Expose pruning on FilterIntervals relational functions. These should be promoted to full IR nodes, especially after this PR.
  • Add ExtractIntervalFilters optimizer pass.

I also have yet to add tests for the last commit.
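
To illustrate the first item: Coalesce evaluates its arguments left to right and returns the first non-missing one. A minimal usage sketch, assuming the node is (or will be) exposed in Python as hl.coalesce:

import hail as hl

# coalesce returns the first non-missing argument, evaluated left to right
x = hl.coalesce(hl.null('int32'), 5, 7)
print(hl.eval(x))  # 5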

What does this PR do?

In [2]: mt = hl.read_matrix_table('data/1kg.rep.mt')

In [3]: mt.filter_rows(mt.locus.contig == '16').count()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[3]: (384, 284)

In [4]: mt.filter_rows(mt.locus.contig == '16').count_rows()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[4]: 384

In [5]: mt.filter_rows((mt.locus.contig == '16') | (mt.locus.contig == '19')).count()
Hail: INFO: interval filter loaded 10 of 128 partitions
Out[5]: (730, 284)

In [6]: mt.filter_rows(hl.literal({'16', '19'}).contains(mt.locus.contig)).count_rows()
Hail: INFO: interval filter loaded 10 of 128 partitions
Out[6]: 730

In [7]: mt.filter_rows((mt.locus.contig == '16') & (mt.locus.position > 10_000_000)).count_rows()
Hail: INFO: interval filter loaded 2 of 128 partitions
Out[7]: 82

In [8]: mt.filter_rows((mt.locus.contig == '16') & (mt.locus.position > 10_000_000) & (mt.locus.position < 12_000_000)).count_rows()
Hail: INFO: interval filter loaded 5 of 128 partitions
Out[8]: 384

In [9]: mt.filter_rows(mt.locus == hl.parse_locus('1:3761547')).count()
Hail: INFO: interval filter loaded 1 of 128 partitions
Out[9]: (1, 284)

In [10]: mt.filter_rows(hl.parse_locus_interval('16:20000000-30000000').contains(mt.locus)).count()
Hail: INFO: interval filter loaded 1 of 128 partitions
Out[10]: (35, 284)
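
Under the hood, the pass walks the row-filter predicate, turns conditions on the row key (comparisons, set membership, interval containment) into key intervals, intersecting them across & and unioning them across |, and pushes the result into a FilterIntervals so that only overlapping partitions are loaded; the original filter still runs afterward, since the intervals only over-approximate the predicate. A simplified, Hail-independent sketch of the interval-extraction idea (toy predicate representation and illustrative names only, not the actual IR code):

# Toy predicate nodes over a single integer key:
#   ('eq', c), ('gt', c), ('lt', c), ('and', p, q), ('or', p, q)
NEG_INF, POS_INF = float('-inf'), float('inf')

def to_intervals(pred):
    """Return (start, end) intervals over-approximating the key values that can
    satisfy pred, or None if the predicate can't be analyzed."""
    op = pred[0]
    if op == 'eq':
        return [(pred[1], pred[1])]
    if op == 'gt':
        return [(pred[1], POS_INF)]
    if op == 'lt':
        return [(NEG_INF, pred[1])]
    if op == 'and':
        left, right = to_intervals(pred[1]), to_intervals(pred[2])
        if left is None:
            return right   # prune using the analyzable side only
        if right is None:
            return left
        return [(max(a, c), min(b, d))
                for (a, b) in left for (c, d) in right
                if max(a, c) <= min(b, d)]
    if op == 'or':
        left, right = to_intervals(pred[1]), to_intervals(pred[2])
        if left is None or right is None:
            return None    # an unanalyzable branch defeats pruning
        return left + right
    return None

# (position > 10_000_000) & (position < 12_000_000) -> [(10000000, 12000000)]
print(to_intervals(('and', ('gt', 10_000_000), ('lt', 12_000_000))))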

@tpoterba force-pushed the pushdown-pass branch 2 times, most recently from 9fa703c to 38fd68d on May 8, 2019 15:45
@tpoterba (Contributor, Author) commented May 8, 2019

stacked on #6073

@@ -454,12 +457,6 @@ object Simplify {
mct,
Subst(newRow, BindingEnv(Env("sa" -> Ref("row", mct.typ.rowType)))))

case MatrixColsTable(MatrixFilterCols(child, pred)) =>
tpoterba (Author):
This rewrite was a deoptimization; I encountered it while doing some manual tests.

@@ -41,6 +41,13 @@ case class MatrixFilterIntervals(
override def lower(): Option[TableToTableFunction] = Some(TableFilterIntervals(keyType, intervals, keep))

def execute(mv: MatrixValue): MatrixValue = throw new UnsupportedOperationException

override def requestType(requestedType: MatrixType, childBaseType: MatrixType): MatrixType = {
tpoterba (Author):
These will be backed out in a subsequent PR that promotes FilterIntervals to full IR nodes.

@@ -15,7 +15,7 @@ object CanEmit {
 object IsConstant {
   def apply(ir: IR): Boolean = {
     ir match {
-      case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Literal(_, _) => true
+      case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Str(_) | Literal(_, _) => true
tpoterba (Author):
This change is required to make the contig handling work (the contig comparisons are against Str literals, which must be recognized as constants), so it can't really be staged in a separate PR.

@tpoterba (Contributor, Author) commented May 9, 2019

Randomly assigned Patrick. This is a big PR but about half of that is tests.

@patrick-schultz (Collaborator) left a comment
Awesome. Here's my first pass; I still need to understand some details better.

val ApplySpecial(_, Seq(Literal(_, lit), _)) = comp
val intervals = lit match {
  case null => Array[Interval]()
  case i: Interval => Array(Interval(endpoint(i.left.point, i.left.sign), endpoint(i.right.point, i.right.sign)))
patrick-schultz (Collaborator):
Couldn't this just be Array(i)?

tpoterba (Author):
yes. endpoint used to wrap things in Row, but I now do that at the end.

val intOrd = TInt32().ordering.intervalEndpointOrdering
val intervals = rg.contigs.indices
  .flatMap { i =>
    Interval.intersection(Array(openInterval(pos, TInt32(), comp.op, isFlipped)),
patrick-schultz (Collaborator):
You're only intersecting single intervals here. I think this could be made a bit more efficient using the Interval.intersect instance method (as opposed to the companion object's method).


val k1 = GetField(ref, key.head)

val (nodes, intervals) = processPredicates(extract(cond, ref, k1), k1.typ)
patrick-schultz (Collaborator):
You don't need to make this change, but I think you could refactor this in a way that avoids building an intermediate KeyFilterPredicate, but keeping the code mostly unchanged.

For example, in processPredicates replace

      case KeyComparison(comp) =>
        val (v, isFlipped) = if (IsConstant(comp.l)) (comp.l, false) else (comp.r, true)
        Set[IR](comp) -> Array(openInterval(constValue(v), v.typ, comp.op, isFlipped))

by

def keyComparison(comp: ApplyComparisonOp): (Set[IR], Array[Interval]) = {
    val (v, isFlipped) = if (IsConstant(comp.l)) (comp.l, false) else (comp.r, true)
    Set[IR](comp) -> Array(openInterval(constValue(v), v.typ, comp.op, isFlipped))
}

and instead of constructing a KeyComparison node in extract, just call keyComparison directly, combining processPredicates(extract(cond, ref, k1), k1.typ) into a single recursive traversal.

tpoterba (Author):
ah, yeah, you're totally right. I had to do some transformations of the KeyFilterPredicate objects in an earlier iteration, but now this should be straightforward! I'll take a stab at it.

tpoterba (Author):
yep, even simpler after this change!

import is.hail.variant.{Locus, ReferenceGenome}
import org.apache.spark.sql.Row

sealed trait KeyFilterPredicate
patrick-schultz (Collaborator):
Unused now?

tpoterba (Author):
oh, duh.

@danking merged commit 1801ce0 into hail-is:master on May 13, 2019