diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala index 77280de352c..eed29b756cb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -633,22 +633,27 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { } private def compareWithConstant(l: Any, r: Value, op: ComparisonOp[_], keySet: KeySet) - : BoolValue = { - if (op.strict && l == null) return BoolValue.allNA(keySet) - r match { + : BoolValue = + if (op.strict && l == null) BoolValue.allNA(keySet) + else r match { case r: KeyField if r.idx == 0 => // simple key comparison BoolValue.fromComparison(l, op).restrict(keySet) case Contig(rgStr) => - // locus contig comparison - assert(op.isInstanceOf[EQ]) + // locus contig equality comparison val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match { case Some(i) => - BoolValue( + val b = BoolValue( KeySet(i), KeySet(Interval(negInf, i.left), Interval(i.right, endpoint(null, -1))), KeySet(Interval(endpoint(null, -1), posInf)), ) + + op match { + case _: EQ => b + case _: NEQ => BoolValue.not(b) + case _ => BoolValue.top(keySet) + } case None => BoolValue( KeySetLattice.bottom, @@ -671,7 +676,6 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { BoolValue.fromComparisonKeyPrefix(l.asInstanceOf[Row], op).restrict(keySet) case _ => BoolValue.top(keySet) } - } private def opIsSupported(op: ComparisonOp[_]): Boolean = op match { case _: EQ | _: NEQ | _: LTEQ | _: LT | _: GTEQ | _: GT | _: EQWithNA | _: NEQWithNA => true diff --git a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala index 7a820f67253..12a76d10bda 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala @@ -497,6 +497,10 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => checkAll(ir1, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) checkAll(ir2, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) + + val ir3 = neq(Str("chr2"), invoke("contig", TString, k)) + checkAll(ir3, ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) + checkAll(not(ir1), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) } @Test def testLocusPositionComparison(): Unit = {