Skip to content

Commit

Permalink
[query] Support != contig comparisons in filter intervals (#14335)
Browse files Browse the repository at this point in the history
Fixes: #14288 
All other unsupported comparisons now return top
  • Loading branch information
ehigham committed Feb 22, 2024
1 parent 3966066 commit 86db498
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
18 changes: 11 additions & 7 deletions hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala
Expand Up @@ -633,22 +633,27 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
}

private def compareWithConstant(l: Any, r: Value, op: ComparisonOp[_], keySet: KeySet)
: BoolValue = {
if (op.strict && l == null) return BoolValue.allNA(keySet)
r match {
: BoolValue =
if (op.strict && l == null) BoolValue.allNA(keySet)
else r match {
case r: KeyField if r.idx == 0 =>
// simple key comparison
BoolValue.fromComparison(l, op).restrict(keySet)
case Contig(rgStr) =>
// locus contig comparison
assert(op.isInstanceOf[EQ])
// locus contig equality comparison
val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match {
case Some(i) =>
BoolValue(
val b = BoolValue(
KeySet(i),
KeySet(Interval(negInf, i.left), Interval(i.right, endpoint(null, -1))),
KeySet(Interval(endpoint(null, -1), posInf)),
)

op match {
case _: EQ => b
case _: NEQ => BoolValue.not(b)
case _ => BoolValue.top(keySet)
}
case None =>
BoolValue(
KeySetLattice.bottom,
Expand All @@ -671,7 +676,6 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) {
BoolValue.fromComparisonKeyPrefix(l.asInstanceOf[Row], op).restrict(keySet)
case _ => BoolValue.top(keySet)
}
}

private def opIsSupported(op: ComparisonOp[_]): Boolean = op match {
case _: EQ | _: NEQ | _: LTEQ | _: LT | _: GTEQ | _: GT | _: EQWithNA | _: NEQWithNA => true
Expand Down
Expand Up @@ -497,6 +497,10 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer =>

checkAll(ir1, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals)
checkAll(ir2, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals)

val ir3 = neq(Str("chr2"), invoke("contig", TString, k))
checkAll(ir3, ref, ref, testRows, falseIntervals, trueIntervals, naIntervals)
checkAll(not(ir1), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals)
}

@Test def testLocusPositionComparison(): Unit = {
Expand Down

0 comments on commit 86db498

Please sign in to comment.