New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add set functions #3552

Merged
merged 8 commits into from May 14, 2018
File filter...
Filter file types
Jump to file or symbol
Failed to load files and symbols.
+305 −23
Diff settings

Always

Just for now

Copy path View file
@@ -1376,3 +1376,28 @@ def assert_min_reps_to(old, new, pos_change=0):
assert_min_reps_to(['GCTAA', 'GCAAA', 'G'], ['GCTAA', 'GCAAA', 'G'])
assert_min_reps_to(['GCTAA', 'GCAAA', 'GCCAA'], ['T', 'A', 'C'], pos_change=2)
assert_min_reps_to(['GCTAA', 'GCAAA', 'GCCAA', '*'], ['T', 'A', 'C', '*'], pos_change=2)

def assert_evals_to(self, e, v):
    """Assert that the expression ``e`` evaluates to ``v``.

    Reads ``e.value`` and compares it against ``v`` with ``assertEqual``.
    """
    actual = e.value
    self.assertEqual(actual, v)

def test_set_functions(self):
    """Check each set expression method against the equivalent Python set."""
    left = hl.set([1, 3, 7])
    right = hl.set([3, 8])
    self.assert_evals_to(left, {1, 3, 7})

    # add: an element already present leaves the set unchanged
    self.assert_evals_to(left.add(3), {1, 3, 7})
    self.assert_evals_to(left.add(4), {1, 3, 4, 7})

    # remove: an absent element leaves the set unchanged
    self.assert_evals_to(left.remove(3), {1, 7})
    self.assert_evals_to(left.remove(4), {1, 3, 7})

    # membership
    self.assert_evals_to(left.contains(3), True)
    self.assert_evals_to(left.contains(4), False)

    # binary set operations against `right`
    self.assert_evals_to(left.difference(right), {1, 7})
    self.assert_evals_to(left.intersection(right), {3})

    # subset checks against a superset and a strict subset of `left`
    self.assert_evals_to(left.is_subset(hl.set([1, 3, 4, 7])), True)
    self.assert_evals_to(left.is_subset(hl.set([1, 3])), False)

    self.assert_evals_to(left.union(right), {1, 3, 7, 8})
@@ -869,6 +869,10 @@ case class ApplyMethod(posn: Position, lhs: AST, method: String, args: Array[AST
case (_: TArray, "map") => ir.ArrayMap(a, name, b)
case (_: TArray, "filter") => ir.ArrayFilter(a, name, b)
case (_: TArray, "flatMap") => ir.ArrayFlatMap(a, name, b)
case (_: TSet, "flatMap") =>
ir.ToSet(
ir.ArrayFlatMap(ir.ToArray(a), name,
ir.ToArray(b)))
case (_: TArray, "exists") =>
val v = ir.genUID()
ir.ArrayFold(a, ir.False(), v, name, ir.ApplySpecial("||", FastSeq(ir.Ref(v, TBoolean()), b)))
@@ -254,7 +254,6 @@ private class Emit(
val codeR = emit(r)
EmitTriplet(Code(codeL.setup, codeR.setup),
codeL.m || codeR.m,

BinaryOp.emit(op, l.typ, r.typ, codeL.v, codeR.v))
case ApplyUnaryPrimOp(op, x) =>
val typ = ir.typ
@@ -816,14 +815,14 @@ private class Emit(
val filterCont = { (cont: F, m: Code[Boolean], v: Code[_]) =>
Code(
xmv := m,
xmv.mux(
Code._empty,
Code(
xvv := v,
codeCond.setup,
(codeCond.m || !coerce[Boolean](codeCond.v)).mux(
Code._empty,
cont(false, xvv)))))
xvv := xmv.mux(
defaultValue(x.typ.elementType),
v),
Code(
codeCond.setup,
(codeCond.m || !coerce[Boolean](codeCond.v)).mux(
Code._empty,
cont(xmv, xvv))))
}
emitArrayIterator(a).copy(length = None).wrapContinuation(filterCont)

@@ -168,7 +168,7 @@ final case class ApplySpecial(function: String, args: Seq[IR]) extends IR {
implementation.unify(argTypes)
implementation.returnType.subst()
}

def isDeterministic: Boolean = implementation.isDeterministic
}

@@ -81,6 +81,7 @@ object IRFunctionRegistry {
}
}

SetFunctions.registerAll()
CallFunctions.registerAll()
GenotypeFunctions.registerAll()
MathFunctions.registerAll()
@@ -0,0 +1,73 @@
package is.hail.expr.ir.functions

import is.hail.expr.ir._
import is.hail.expr.types.{TArray, TBoolean, TSet}
import is.hail.utils.FastSeq

/**
  * Registers the set functions (`toSet`, `contains`, `remove`, `add`, `union`,
  * `intersection`, `difference`, `isSubset`) as IR rewrites built from `ToSet`,
  * `ToArray`, `ArrayFilter`, `ArrayFlatMap`, `ArrayFold` and `SetContains`.
  */
object SetFunctions extends RegistryFunctions {
  def registerAll(): Unit = {
    // Element type of a set-typed IR with unary `-` applied, exactly as each
    // registration originally computed it inline.
    // NOTE(review): `-` presumably yields the non-required variant of the type; confirm.
    def elementTypeOf(set: IR) = -set.typ.asInstanceOf[TSet].elementType

    registerIR("toSet", TArray(tv("T"))) { arr =>
      ToSet(arr)
    }

    registerIR("contains", TSet(tv("T")), tv("T")) { case (set, elem) =>
      SetContains(set, elem)
    }

    // remove: keep every member that is not (non-strictly) equal to `elem`.
    registerIR("remove", TSet(tv("T")), tv("T")) { case (set, elem) =>
      val elt = elem.typ
      val member = genUID()
      val isDifferent = ApplyUnaryPrimOp(Bang(), nonstrictEQ(Ref(member, elt), elem))
      ToSet(ArrayFilter(ToArray(set), member, isDifferent))
    }

    // add: flatten [members of set, [elem]] and convert back to a set.
    registerIR("add", TSet(tv("T")), tv("T")) { case (set, elem) =>
      val elt = elem.typ
      val chunk = genUID()
      val singleton = MakeArray(FastSeq(elem), TArray(elt))
      val chunks = MakeArray(FastSeq(ToArray(set), singleton), TArray(TArray(elt)))
      ToSet(ArrayFlatMap(chunks, chunk, Ref(chunk, TArray(elt))))
    }

    // union: flatten [members of s1, members of s2] and convert back to a set.
    registerIR("union", TSet(tv("T")), TSet(tv("T"))) { case (s1, s2) =>
      val elt = elementTypeOf(s1)
      val chunk = genUID()
      val chunks = MakeArray(FastSeq(ToArray(s1), ToArray(s2)), TArray(TArray(elt)))
      ToSet(ArrayFlatMap(chunks, chunk, Ref(chunk, TArray(elt))))
    }

    // intersection: members of s1 that s2 contains.
    registerIR("intersection", TSet(tv("T")), TSet(tv("T"))) { case (s1, s2) =>
      val elt = elementTypeOf(s1)
      val member = genUID()
      val inOther = SetContains(s2, Ref(member, elt))
      ToSet(ArrayFilter(ToArray(s1), member, inOther))
    }

    // difference: members of s1 that s2 does not contain.
    registerIR("difference", TSet(tv("T")), TSet(tv("T"))) { case (s1, s2) =>
      val elt = elementTypeOf(s1)
      val member = genUID()
      val notInOther = ApplyUnaryPrimOp(Bang(), SetContains(s2, Ref(member, elt)))
      ToSet(ArrayFilter(ToArray(s1), member, notInOther))
    }

    // isSubset: fold `&&` over "w contains member" for every member of s, starting from true.
    registerIR("isSubset", TSet(tv("T")), TSet(tv("T"))) { case (s, w) =>
      val elt = elementTypeOf(s)
      val acc = genUID()
      val member = genUID()
      // FIXME short circuit
      val step = ApplySpecial("&&",
        FastSeq(Ref(acc, TBoolean()), SetContains(w, Ref(member, elt))))
      ArrayFold(ToArray(s), True(), acc, member, step)
    }
  }
}
@@ -2,6 +2,7 @@ package is.hail.expr

import is.hail.asm4s
import is.hail.asm4s._
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.types._

package object ir {
@@ -53,8 +54,26 @@ package object ir {

private[ir] def coerce[T](ti: TypeInfo[_]): TypeInfo[T] = ti.asInstanceOf[TypeInfo[T]]

private[ir] def coerce[T <: Type](x: Type): T = {
import is.hail.expr.types
types.coerce[T](x)
private[ir] def coerce[T <: Type](x: Type): T = types.coerce[T](x)

/**
  * Builds the IR for calling the registered function `name` on `args`.
  *
  * The implementation is looked up in `IRFunctionRegistry` by name and argument types.
  *
  * @param name name of the registered function
  * @param args argument IRs; their types select the implementation
  * @return the IR produced by the matching registered conversion
  * @throws IllegalArgumentException if no registered function matches `name` and the
  *                                  argument types
  */
def invoke(name: String, args: IR*): IR = {
  val argTypes = args.map(_.typ)
  IRFunctionRegistry.lookupConversion(name, argTypes) match {
    case Some(f) => f(args)
    // Previously this case was missing, so an unknown function surfaced as a bare MatchError.
    case None =>
      throw new IllegalArgumentException(
        s"no IR function '$name' registered for argument types (${ argTypes.mkString(", ") })")
  }
}

/**
  * Builds an IR that compares `l` and `r` for equality while treating missing values
  * as comparable: the result is `IsNA(r)` when `l` is missing, false when only `r`
  * is missing, and `EQ` of the two values otherwise.
  *
  * Both operands are bound with `Let` under fresh names and referenced from there.
  * Asserts that `l` and `r` have the same type.
  */
def nonstrictEQ(l: IR, r: IR): IR = {
  // FIXME better as a (non-strict) BinaryOp?
  assert(l.typ == r.typ)
  val operandType = l.typ
  val leftName = genUID()
  val rightName = genUID()

  // Fresh Ref nodes on every use, matching the original tree shape.
  def leftRef = Ref(leftName, operandType)
  def rightRef = Ref(rightName, operandType)

  val whenLeftPresent =
    If(IsNA(rightRef),
      False(),
      ApplyBinaryPrimOp(EQ(), leftRef, rightRef))
  val comparison = If(IsNA(leftRef), IsNA(rightRef), whenLeftPresent)

  Let(leftName, l, Let(rightName, r, comparison))
}
}
@@ -4,9 +4,9 @@ import scala.collection.mutable
import scala.reflect.ClassTag

object FastSeq {
def empty[T](implicit tct: ClassTag[T]): Seq[T] = FastSeq()
def empty[T](implicit tct: ClassTag[T]): IndexedSeq[T] = FastSeq()

def apply[T](args: T*)(implicit tct: ClassTag[T]): Seq[T] = {
def apply[T](args: T*)(implicit tct: ClassTag[T]): IndexedSeq[T] = {
args match {
case args: mutable.WrappedArray[T] => args
case args: mutable.ArrayBuffer[T] => args
@@ -4,8 +4,9 @@ import java.net.URI
import java.nio.file.{Files, Paths}

import breeze.linalg.{DenseMatrix, Matrix, Vector}
import is.hail.annotations.Annotation
import is.hail.expr.types.{TFloat64, TString}
import is.hail.annotations.{Annotation, Region, RegionValueBuilder, SafeRow}
import is.hail.expr.ir._
import is.hail.expr.types._
import is.hail.linalg.BlockMatrix
import is.hail.methods.{KinshipMatrix, SplitMulti}
import is.hail.table.Table
@@ -249,8 +250,65 @@ object TestUtils {
def exportGen(mt: MatrixTable, path: String, precision: Int): Unit = {
mt.selectCols(""" {id1: sa.s, id2: sa.s, missing: 0.toFloat64} """, Some(FastIndexedSeq()))
.annotateRowsExpr(
"varid" -> """let l = va.locus and a = va.alleles in [l.contig, str(l.position), a[0], a[1]].mkString(":")""",
"rsid" -> "\".\"")
"varid" -> """let l = va.locus and a = va.alleles in [l.contig, str(l.position), a[0], a[1]].mkString(":")""",
"rsid" -> "\".\"")
.exportGen(path, precision)
}

/** Evaluates `x` with an empty argument list; see the `eval` overload taking arguments. */
def eval(x: IR): Any = eval(x, FastSeq())

/**
  * Evaluates `x` under `env` by turning each binding into a positional input:
  * the i-th binding's name is substituted with `In(i, type)` and its (value, type)
  * pair becomes the i-th argument of the argument-list `eval` overload.
  */
def eval(x: IR, env: Env[(Any, Type)]): Any = {
  val indexedBindings = env.m.toFastSeq.zipWithIndex
  val start = (Env.empty[IR], FastSeq.empty[(Any, Type)])
  val (substEnv, args) = indexedBindings.foldLeft(start) {
    case ((boundSoFar, argsSoFar), ((name, (value, typ)), idx)) =>
      (boundSoFar.bind(name, In(idx, typ)), argsSoFar :+ ((value, typ)))
  }
  eval(Subst(x, substEnv), args)
}

/**
  * Evaluates `x` against positional inputs `args` and cross-checks three execution paths:
  * the interpreter with default settings, the interpreter with `optimize = false`, and
  * the compiler. Asserts that all three agree and returns the interpreter's result.
  *
  * @param x    the IR to evaluate; `In(i, _)` nodes refer to `args(i)`
  * @param args (value, type) pairs supplying the inputs
  * @return the value computed by the interpreter
  */
def eval(x: IR, args: IndexedSeq[(Any, Type)]): Any = {
  val i = Interpret[Any](x, Env.empty[(Any, Type)], args, None)

  val i2 = Interpret[Any](x, Env.empty[(Any, Type)], args, None, optimize = false)
  assert(i == i2)

  // verify compiler and interpreter agree
  val argsVar = genUID()
  val argsType = TTuple(args.map(_._2): _*)
  val resultType = TTuple(x.typ)

  // Redirects every input node to the matching element of the single tuple argument
  // that the compiled function receives.
  // Review note from the pull request (maccum, May 11, 2018): "rewrite() isn't being
  // used. delete?" It is applied to `x` when the compiled body is built below.
  def rewrite(x: IR): IR = {
    x match {
      case In(i, t) =>
        GetTupleElement(Ref(argsVar, argsType), i)
      case _ =>
        Recur(rewrite)(x)
    }
  }

  // The compiled argument must be bound under `argsVar`, the name `rewrite` refers to.
  // Binding it as the literal "args" left that reference unbound for any IR with `In` nodes.
  val (resultType2, f) = Compile[Long, Long](argsVar, argsType, MakeTuple(FastSeq(rewrite(x))))
  assert(resultType2 == resultType)

  val c = Region.scoped { region =>
    // Write the argument values into the region as one tuple of type `argsType`.
    val rvb = new RegionValueBuilder(region)
    rvb.start(argsType)
    rvb.startTuple()
    args.foreach { case (v, t) =>
      rvb.addAnnotation(t, v)
    }
    rvb.endTuple()
    val argsOff = rvb.end()

    // NOTE(review): the trailing `false` is presumably the "argument is missing" flag; confirm.
    val resultOff = f()(region, argsOff, false)
    SafeRow(resultType.asInstanceOf[TBaseStruct], region, resultOff).get(0)
  }

  assert(i == c)

  i
}

/** Asserts that evaluating `x` with no arguments yields `expected`. */
def assertEvalsTo(x: IR, expected: Any): Unit = {
  val actual = eval(x)
  assert(actual == expected)
}
}
@@ -7,14 +7,13 @@ import is.hail.annotations._
import is.hail.asm4s._
import is.hail.expr.ir.functions.{IRFunctionRegistry, RegistryFunctions}
import is.hail.expr.types._
import is.hail.TestUtils._
import org.testng.annotations.Test
import is.hail.expr.{EvalContext, Parser}
import is.hail.table.Table
import is.hail.utils.FastSeq
import is.hail.variant.Call2

import scala.reflect.ClassTag

// Test fixture object exposing a zero-argument function that always returns 1.
// NOTE(review): presumably referenced by name from the function-registry tests; confirm
// before renaming.
object ScalaTestObject {
def testFunction(): Int = 1
}
@@ -0,0 +1,34 @@
package is.hail.expr.ir

import is.hail.expr.types._
import is.hail.TestUtils._
import is.hail.utils.FastIndexedSeq
import org.testng.annotations.Test
import org.scalatest.testng.TestNGSuite

/** Tests for `nonstrictEQ` and `ArrayFilter`, evaluated through `assertEvalsTo`. */
class IRSuite extends TestNGSuite {
  @Test def testNonstrictEQ(): Unit = {
    // A fresh missing-int node per use.
    def missingInt = NA(TInt32())

    assertEvalsTo(nonstrictEQ(missingInt, missingInt), true)
    assertEvalsTo(nonstrictEQ(I32(5), I32(5)), true)
    assertEvalsTo(nonstrictEQ(missingInt, I32(5)), false)
  }

  @Test def testArrayFilter(): Unit = {
    val naa = NA(TArray(TInt32()))
    val a = MakeArray(Seq(I32(3), NA(TInt32()), I32(7)), TArray(TInt32()))

    // Reference to the element bound as "x" inside each filter condition.
    def elem = Ref("x", TInt32())
    // Filters the array [3, NA, 7] with the given condition.
    def filtered(cond: IR) = ArrayFilter(a, "x", cond)

    // A missing array stays missing.
    assertEvalsTo(ArrayFilter(naa, "x", True()), null)

    // Constant conditions: a missing or false condition drops every element.
    assertEvalsTo(filtered(NA(TBoolean())), FastIndexedSeq())
    assertEvalsTo(filtered(False()), FastIndexedSeq())
    assertEvalsTo(filtered(True()), FastIndexedSeq(3, null, 7))

    // Conditions on the element's missingness.
    assertEvalsTo(filtered(IsNA(elem)), FastIndexedSeq(null))
    assertEvalsTo(filtered(ApplyUnaryPrimOp(Bang(), IsNA(elem))), FastIndexedSeq(3, 7))

    // A comparison that is missing for the missing element.
    assertEvalsTo(filtered(ApplyBinaryPrimOp(LT(), elem, I32(6))), FastIndexedSeq(3))
  }
}
Oops, something went wrong.
ProTip! Use n and p to navigate between commits in a pull request.