Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed May 13, 2024
1 parent e4d1046 commit 38496f3
Show file tree
Hide file tree
Showing 77 changed files with 2,543 additions and 3,107 deletions.
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
Compile, IR, IRParser, LoweringAnalyses, MakeTuple, Name, SortField, TableIR, TableReader,
TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.functions.IRFunctionRegistry
Expand Down Expand Up @@ -552,7 +553,7 @@ case class ServiceBackendExecutePayload(
case class SerializedIRFunction(
name: String,
type_parameters: Array[String],
value_parameter_names: Array[String],
value_parameter_names: Array[Name],
value_parameter_types: Array[String],
return_type: String,
rendered_body: String,
Expand Down
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,9 @@ class SparkBackend(
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}
}
Expand Down
10 changes: 10 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import is.hail.types.virtual.Type
import is.hail.utils._
import is.hail.utils.StackSafe._

/** Typed wrapper around the string identifier of an IR binding.
  *
  * Using a dedicated type instead of a bare `String` lets signatures state that a value is a
  * binding name rather than arbitrary text. It is declared as a value class (`extends AnyVal`),
  * wrapping exactly one `String` field.
  *
  * @param str the underlying identifier text
  */
final case class Name(str: String) extends AnyVal {

  /** Renders as the bare identifier, not the default case-class form `Name(...)`. */
  override def toString: String = {
    str
  }
}

abstract class BaseIR {
def typ: BaseType

Expand All @@ -28,6 +32,12 @@ abstract class BaseIR {
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

/** Tests whether this IR and `other` are equal once both have been passed through
  * `NormalizeNames`.
  *
  * Both sides are normalized with the same naming function (`iruid`) and with
  * `allowFreeVariables = true`, then the two results are compared with `==`.
  *
  * @param ctx   execution context handed to the normalizer
  * @param other the IR to compare against
  * @return `true` when the two normalized IRs compare equal
  */
def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean = {
  /* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
   * names */
  // A fresh normalizer is built for every call, exactly as the original did for each operand.
  // NOTE(review): NormalizeNames is not visible here; presumably it carries per-run state
  // (e.g. a counter for generated names), so do not share one instance across both sides
  // without confirming.
  def normalized(ir: BaseIR) =
    new NormalizeNames(iruid(_), allowFreeVariables = true)(ctx, ir)

  val lhs = normalized(this)
  val rhs = normalized(other)
  lhs == rhs
}

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
if (childrenSeq.elementsSameObjects(newChildren))
Expand Down
26 changes: 13 additions & 13 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ import is.hail.utils.FastSeq
import scala.collection.mutable

object Binds {
def apply(x: IR, v: String, i: Int): Boolean =
def apply(x: IR, v: Name, i: Int): Boolean =
Bindings.get(x, i).eval.exists(_._1 == v)
}

final case class Bindings[+T](
eval: IndexedSeq[(String, T)] = FastSeq.empty,
eval: IndexedSeq[(Name, T)] = FastSeq.empty,
agg: AggEnv[T] = AggEnv.NoOp,
scan: AggEnv[T] = AggEnv.NoOp,
relational: IndexedSeq[(String, T)] = FastSeq.empty,
relational: IndexedSeq[(Name, T)] = FastSeq.empty,
dropEval: Boolean = false,
) {
def map[U](f: (String, T) => U): Bindings[U] = Bindings(
def map[U](f: (Name, T) => U): Bindings[U] = Bindings(
eval.map { case (n, v) => n -> f(n, v) },
agg.map(f),
scan.map(f),
Expand Down Expand Up @@ -51,10 +51,10 @@ object Bindings {

// Create a `Bindings` which cannot see anything bound in the enclosing context.
private def inFreshScope(
eval: IndexedSeq[(String, Type)] = FastSeq.empty,
agg: Option[IndexedSeq[(String, Type)]] = None,
scan: Option[IndexedSeq[(String, Type)]] = None,
relational: IndexedSeq[(String, Type)] = FastSeq.empty,
eval: IndexedSeq[(Name, Type)] = FastSeq.empty,
agg: Option[IndexedSeq[(Name, Type)]] = None,
scan: Option[IndexedSeq[(Name, Type)]] = None,
relational: IndexedSeq[(Name, Type)] = FastSeq.empty,
): Bindings[Type] = Bindings(
eval,
agg.map(AggEnv.Create(_)).getOrElse(AggEnv.Drop),
Expand All @@ -67,15 +67,15 @@ object Bindings {
ir match {
case MatrixMapRows(child, _) if i == 1 =>
Bindings.inFreshScope(
eval = child.typ.rowBindings :+ "n_cols" -> TInt32,
eval = child.typ.rowBindings :+ Name("n_cols") -> TInt32,
agg = Some(child.typ.entryBindings),
scan = Some(child.typ.rowBindings),
)
case MatrixFilterRows(child, _) if i == 1 =>
Bindings.inFreshScope(child.typ.rowBindings)
case MatrixMapCols(child, _, _) if i == 1 =>
Bindings.inFreshScope(
eval = child.typ.colBindings :+ "n_rows" -> TInt64,
eval = child.typ.colBindings :+ Name("n_rows") -> TInt64,
agg = Some(child.typ.entryBindings),
scan = Some(child.typ.colBindings),
)
Expand Down Expand Up @@ -176,9 +176,9 @@ object Bindings {
private def childEnvValue(ir: IR, i: Int): Bindings[Type] =
ir match {
case Block(bindings, _) =>
val eval = mutable.ArrayBuilder.make[(String, Type)]
val agg = mutable.ArrayBuilder.make[(String, Type)]
val scan = mutable.ArrayBuilder.make[(String, Type)]
val eval = mutable.ArrayBuilder.make[(Name, Type)]
val agg = mutable.ArrayBuilder.make[(Name, Type)]
val scan = mutable.ArrayBuilder.make[(Name, Type)]
for (k <- 0 until i) bindings(k) match {
case Binding(name, value, Scope.EVAL) =>
eval += name -> value.typ
Expand Down
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends Bl
HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id)
}

case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean)
case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean)
extends BlockMatrixIR {
override def typecheck(): Unit =
assert(!(needsDense && child.typ.isSparse))
Expand Down Expand Up @@ -432,8 +432,8 @@ case object NeedsDense extends SparsityStrategy {
case class BlockMatrixMap2(
left: BlockMatrixIR,
right: BlockMatrixIR,
leftName: String,
rightName: String,
leftName: Name,
rightName: Name,
f: IR,
sparsityStrategy: SparsityStrategy,
) extends BlockMatrixIR {
Expand Down Expand Up @@ -1117,7 +1117,7 @@ case class BlockMatrixRandom(
BlockMatrix.random(shape(0), shape(1), blockSize, ctx.rngNonce, staticUID, gaussian)
}

case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR)
case class RelationalLetBlockMatrix(name: Name, value: IR, body: BlockMatrixIR)
extends BlockMatrixIR {
override def typ: BlockMatrixType = body.typ

Expand Down
10 changes: 5 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.io.PrintWriter

case class CodeCacheKey(
aggSigs: IndexedSeq[AggStateSig],
args: Seq[(String, EmitParamType)],
args: Seq[(Name, EmitParamType)],
body: IR,
)

Expand All @@ -35,7 +35,7 @@ case class CompiledFunction[T](
object Compile {
def apply[F: TypeInfo](
ctx: ExecuteContext,
params: IndexedSeq[(String, EmitParamType)],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
Expand All @@ -44,7 +44,7 @@ object Compile {
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = {

val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, body).asInstanceOf[IR]
val k =
CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {
Expand Down Expand Up @@ -97,7 +97,7 @@ object CompileWithAggregators {
def apply[F: TypeInfo](
ctx: ExecuteContext,
aggSigs: Array[AggStateSig],
params: IndexedSeq[(String, EmitParamType)],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
Expand All @@ -107,7 +107,7 @@ object CompileWithAggregators {
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
) = {
val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, body).asInstanceOf[IR]
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

Expand Down
48 changes: 28 additions & 20 deletions hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ object DeprecatedIRBuilder {
implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False()

implicit def ref(s: Symbol): IRProxy = (env: E) =>
Ref(s.name, env.lookup(s.name))
Ref(Name(s.name), env.lookup(Name(s.name)))

implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s)

Expand Down Expand Up @@ -242,7 +242,7 @@ object DeprecatedIRBuilder {
def insertStruct(other: IRProxy, ordering: Option[IndexedSeq[String]] = None): IRProxy =
(env: E) => {
val right = other(env)
val sym = genUID()
val sym = freshName()
Let(
FastSeq(sym -> right),
InsertFields(
Expand All @@ -260,7 +260,7 @@ object DeprecatedIRBuilder {
def isNA: IRProxy = (env: E) => IsNA(ir(env))

def orElse(alt: IRProxy): IRProxy = { env: E =>
val uid = genUID()
val uid = freshName()
val eir = ir(env)
Let(FastSeq(uid -> eir), If(IsNA(Ref(uid, eir.typ)), alt(env), Ref(uid, eir.typ)))
}
Expand All @@ -270,23 +270,27 @@ object DeprecatedIRBuilder {
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamFilter(
ToStream(array),
pred.s.name,
pred.body(env.bind(pred.s.name -> eltType)),
Name(pred.s.name),
pred.body(env.bind(Name(pred.s.name) -> eltType)),
))
}

def map(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamMap(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType))))
ToArray(StreamMap(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
))
}

def aggExplode(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
AggExplode(
ToStream(array),
f.s.name,
f.body(env.bind(f.s.name, array.typ.asInstanceOf[TArray].elementType)),
Name(f.s.name),
f.body(env.bind(Name(f.s.name), array.typ.asInstanceOf[TArray].elementType)),
isScan = false,
)
}
Expand All @@ -296,21 +300,25 @@ object DeprecatedIRBuilder {
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamFlatMap(
ToStream(array),
f.s.name,
ToStream(f.body(env.bind(f.s.name -> eltType))),
Name(f.s.name),
ToStream(f.body(env.bind(Name(f.s.name) -> eltType))),
))
}

def streamAgg(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
StreamAgg(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType)))
StreamAgg(ToStream(array), Name(f.s.name), f.body(env.bind(Name(f.s.name) -> eltType)))
}

def streamAggScan(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamAggScan(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType))))
ToArray(StreamAggScan(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
))
}

def arraySlice(start: IRProxy, stop: Option[IRProxy], step: IRProxy): IRProxy = {
Expand All @@ -335,9 +343,9 @@ object DeprecatedIRBuilder {
val eltType = array.typ.asInstanceOf[TArray].elementType
AggArrayPerElement(
array,
elementsSym.name,
indexSym.name,
aggBody.apply(env.bind(elementsSym.name -> eltType, indexSym.name -> TInt32)),
Name(elementsSym.name),
Name(indexSym.name),
aggBody.apply(env.bind(Name(elementsSym.name) -> eltType, Name(indexSym.name) -> TInt32)),
knownLength.map(_(env)),
isScan = false,
)
Expand Down Expand Up @@ -384,8 +392,8 @@ object DeprecatedIRBuilder {
var newEnv = env
val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) =>
val resolvedValue = value(newEnv)
newEnv = newEnv.bind(sym.name -> resolvedValue.typ)
Binding(sym.name, resolvedValue, scope)
newEnv = newEnv.bind(Name(sym.name) -> resolvedValue.typ)
Binding(Name(sym.name), resolvedValue, scope)
}
Block(resolvedBindings, body(newEnv))
}
Expand All @@ -394,13 +402,13 @@ object DeprecatedIRBuilder {
object let extends Dynamic {
def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = {
assert(method == "apply")
letDyn(args: _*)
letDyn(args.map { case (n, ir) => Name(n) -> ir }: _*)
}
}

object letDyn {
def apply(args: (String, IRProxy)*): LetProxy =
new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.EVAL) }.toFastSeq)
def apply(args: (Name, IRProxy)*): LetProxy =
new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s.str), b, Scope.EVAL) }.toFastSeq)
}

class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal {
Expand Down
Loading

0 comments on commit 38496f3

Please sign in to comment.