[compiler] add newtype for bound names in IR (#14547)
This PR introduces a new type `Name` for representing bound variables in
the IR, replacing `String`. For now, it is just an `AnyVal` wrapper
around `String`, but in the future I would like to take advantage of the
new type. For example, I'd like to:
* change equality of `Name` from string comparison to comparing object
identity with `eq` (sketched after this list). That way `freshName`
becomes just `new Name()`, with stronger guarantees that the new name
doesn't occur anywhere in the current IR, and without needing to
maintain global state as we do now.
* get rid of `NormalizeNames`, instead enforcing the global uniqueness
of names as a basic invariant of the IR (typecheck could also check this
invariant)
* keep a string in the `Name`, but no longer require it to be unique.
Instead it's just a suggestion for how to show the name in printouts,
adding a uniquifying suffix as needed. With `NormalizeNames` gone, this
would let us preserve meaningful variable names further into the
lowering pipeline.
* possibly keep other state in the `Name`, for example to allow a more
efficient implementation of environments, similar to the `mark` state on
`BaseIR`
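
To make the first and third points above concrete, here is a minimal sketch of what an identity-based `Name` could look like. This is not the code in this PR: the `Name` introduced here is still a case-class `AnyVal` wrapper around `String` (see `BaseIR.scala` below), and the `hint` field and the `fresh(hint)` helper are illustrative assumptions only.

```scala
// Sketch of a possible future Name: equality is object identity, and the stored
// string is only a display hint rather than a unique identifier.
final class Name(val hint: String = "x") {
  override def equals(other: Any): Boolean = other match {
    case that: Name => this eq that // identity, not string comparison
    case _ => false
  }
  override def hashCode: Int = System.identityHashCode(this)
  // A printer could append a uniquifying suffix when two in-scope hints collide.
  override def toString: String = hint
}

object Name {
  // Freshness falls out of allocation: a newly constructed Name cannot equal
  // any Name already present in the IR, so no global counter is needed.
  def fresh(hint: String = "x"): Name = new Name(hint)
}
```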

This is obviously a large change, but there are only a few conceptual
pieces (apologies for not managing to separate these out):
* attempt to minimize the number of locations in which the `Name`
constructor is called, to make future refactorings easier
* add `freshName()`, which just wraps `genUID()`, returning a `Name`
* convert IR construction to use the convenience methods in
`ir.package`, which take Scala lambdas to represent blocks with bound
variables, instead of manually creating new variable names (see the
sketch after this list)
* replace uses of the magic constant variable names (`row`, `va`, `sa`,
`g`, `global`) with constants (`TableIR.{rowName, globalName}`,
`MatrixIR.{rowName, colName, entryName, globalName}`)
* the above changes modified the names we use for bound variables in
many places. That shouldn't matter, but the process caught a couple of
bugs where it did.
* `NormalizeNames` optionally allows the IR to contain free variables,
but it did nothing to ensure that newly generated variable names were
distinct from any contained free variables, so renaming a bound
variable could mistakenly capture a contained free variable. I've fixed
that.
* `SimplifySuite` compared simplified IR with the pre-constructed
expected IR, carefully controlling the `genUID` global state so that
simplification generated exactly the expected names. I've replaced that
with an alpha-equivalence comparison against the expected IR.
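
Here is the sketch referenced in the list above. It contrasts the manual style being replaced with the lambda-taking helper style being adopted, assuming `freshName()` is available from the `ir` package object (as its unqualified use in the diff below suggests). The `orElse` logic mirrors `DeprecatedIRBuilder.orElse` in the diff; `bindIR` is written out with an assumed signature to show the shape of the `ir.package` helpers, not copied from them.

```scala
import is.hail.expr.ir._
import is.hail.utils.FastSeq

object BoundNameSketch {
  // Manual style: the caller invents a fresh name and must thread it through
  // both the binding and every reference to it.
  def orElseManual(value: IR, alt: IR): IR = {
    val uid = freshName()
    Let(FastSeq(uid -> value), If(IsNA(Ref(uid, value.typ)), alt, Ref(uid, value.typ)))
  }

  // Assumed-shape helper: the bound variable is handed to a Scala lambda as a
  // Ref, so call sites never construct or repeat a Name themselves.
  def bindIR(value: IR)(body: Ref => IR): IR = {
    val name = freshName()
    Let(FastSeq(name -> value), body(Ref(name, value.typ)))
  }

  def orElseWithHelper(value: IR, alt: IR): IR =
    bindIR(value)(x => If(IsNA(x), alt, x))
}
```

The `SimplifySuite` change relies on the new `BaseIR.isAlphaEquiv` shown in the `BaseIR.scala` diff below, which normalizes names on both sides with `NormalizeNames` and compares the results.
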
patrick-schultz committed May 16, 2024
1 parent 1edd50d commit 8a3c60d
Showing 77 changed files with 2,528 additions and 3,094 deletions.
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
@@ -334,7 +334,9 @@ class LocalBackend(val tmpdir: String) extends Backend with BackendWithCodeCache
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}
}
@@ -6,7 +6,8 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
Compile, IR, IRParser, LoweringAnalyses, MakeTuple, Name, SortField, TableIR, TableReader,
TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.functions.IRFunctionRegistry
@@ -554,7 +555,7 @@ case class ServiceBackendExecutePayload(
case class SerializedIRFunction(
name: String,
type_parameters: Array[String],
value_parameter_names: Array[String],
value_parameter_names: Array[Name],
value_parameter_types: Array[String],
return_type: String,
rendered_body: String,
4 changes: 3 additions & 1 deletion hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
@@ -783,7 +783,9 @@ class SparkBackend(
IRParser.parse_value_ir(
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}
}
10 changes: 10 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/BaseIR.scala
@@ -6,6 +6,10 @@ import is.hail.types.virtual.Type
import is.hail.utils._
import is.hail.utils.StackSafe._

case class Name(str: String) extends AnyVal {
override def toString: String = str
}

abstract class BaseIR {
def typ: BaseType

@@ -28,6 +32,12 @@ abstract class BaseIR {
// New sentinel values can be obtained by `nextFlag` on `IRMetadata`.
var mark: Int = 0

def isAlphaEquiv(ctx: ExecuteContext, other: BaseIR): Boolean =
/* FIXME: rewrite to not rebuild the irs, by maintaining an env mapping left names to right
* names */
new NormalizeNames(iruid(_), allowFreeVariables = true)(ctx, this) ==
new NormalizeNames(iruid(_), allowFreeVariables = true)(ctx, other)

def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): BaseIR = {
val newChildren = childrenSeq.view.zipWithIndex.map(f.tupled).toArray
if (childrenSeq.elementsSameObjects(newChildren))
14 changes: 7 additions & 7 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
@@ -35,7 +35,7 @@ object AggEnv {
}

object Binds {
def apply(x: IR, v: String, i: Int): Boolean = {
def apply(x: IR, v: Name, i: Int): Boolean = {
val bindings = Bindings.get(x, i)
bindings.all.zipWithIndex.exists { case ((name, _), i) =>
name == v && bindings.eval.contains(i)
@@ -44,14 +44,14 @@
}

final case class Bindings[+T](
all: IndexedSeq[(String, T)],
all: IndexedSeq[(Name, T)],
eval: IndexedSeq[Int],
agg: AggEnv,
scan: AggEnv,
relational: IndexedSeq[Int],
dropEval: Boolean,
) {
def map[U](f: (String, T) => U): Bindings[U] =
def map[U](f: (Name, T) => U): Bindings[U] =
copy(all = all.map { case (n, t) => (n, f(n, t)) })

def allEmpty: Boolean =
@@ -63,7 +63,7 @@ final case class Bindings[+T](

object Bindings {
def apply[T](
bindings: IndexedSeq[(String, T)] = FastSeq.empty,
bindings: IndexedSeq[(Name, T)] = FastSeq.empty,
eval: IndexedSeq[Int] = FastSeq.empty,
agg: AggEnv = AggEnv.NoOp,
scan: AggEnv = AggEnv.NoOp,
@@ -91,7 +91,7 @@

// Create a `Bindings` which cannot see anything bound in the enclosing context.
private def inFreshScope(
bindings: IndexedSeq[(String, Type)] = FastSeq.empty,
bindings: IndexedSeq[(Name, Type)] = FastSeq.empty,
eval: IndexedSeq[Int] = FastSeq.empty,
agg: Option[IndexedSeq[Int]] = None,
scan: Option[IndexedSeq[Int]] = None,
@@ -113,7 +113,7 @@
ir match {
case MatrixMapRows(child, _) if i == 1 =>
Bindings.inFreshScope(
child.typ.entryBindings :+ "n_cols" -> TInt32,
child.typ.entryBindings :+ Name("n_cols") -> TInt32,
eval = rowInEntryBindings :+ 4,
agg = Some(entryBindings),
scan = Some(rowInEntryBindings),
@@ -122,7 +122,7 @@
Bindings.inFreshScope(child.typ.rowBindings)
case MatrixMapCols(child, _, _) if i == 1 =>
Bindings.inFreshScope(
child.typ.entryBindings :+ "n_rows" -> TInt64,
child.typ.entryBindings :+ Name("n_rows") -> TInt64,
eval = colInEntryBindings :+ 4,
agg = Some(entryBindings),
scan = Some(colInEntryBindings),
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala
@@ -293,7 +293,7 @@ case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends Bl
HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id)
}

case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean)
case class BlockMatrixMap(child: BlockMatrixIR, eltName: Name, f: IR, needsDense: Boolean)
extends BlockMatrixIR {
override def typecheck(): Unit =
assert(!(needsDense && child.typ.isSparse))
@@ -432,8 +432,8 @@ case object NeedsDense extends SparsityStrategy {
case class BlockMatrixMap2(
left: BlockMatrixIR,
right: BlockMatrixIR,
leftName: String,
rightName: String,
leftName: Name,
rightName: Name,
f: IR,
sparsityStrategy: SparsityStrategy,
) extends BlockMatrixIR {
@@ -1117,7 +1117,7 @@ case class BlockMatrixRandom(
BlockMatrix.random(shape(0), shape(1), blockSize, ctx.rngNonce, staticUID, gaussian)
}

case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR)
case class RelationalLetBlockMatrix(name: Name, value: IR, body: BlockMatrixIR)
extends BlockMatrixIR {
override def typ: BlockMatrixType = body.typ

10 changes: 5 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -20,7 +20,7 @@ import java.io.PrintWriter

case class CodeCacheKey(
aggSigs: IndexedSeq[AggStateSig],
args: Seq[(String, EmitParamType)],
args: Seq[(Name, EmitParamType)],
body: IR,
)

@@ -35,7 +35,7 @@ case class CompiledFunction[T](
object Compile {
def apply[F: TypeInfo](
ctx: ExecuteContext,
params: IndexedSeq[(String, EmitParamType)],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
@@ -44,7 +44,7 @@
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = {

val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, body).asInstanceOf[IR]
val k =
CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {
@@ -97,7 +97,7 @@ object CompileWithAggregators {
def apply[F: TypeInfo](
ctx: ExecuteContext,
aggSigs: Array[AggStateSig],
params: IndexedSeq[(String, EmitParamType)],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
@@ -107,7 +107,7 @@
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
) = {
val normalizedBody =
new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*))
new NormalizeNames(_.toString, allowFreeVariables = true)(ctx, body).asInstanceOf[IR]
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

48 changes: 28 additions & 20 deletions hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala
@@ -23,7 +23,7 @@ object DeprecatedIRBuilder {
implicit def booleanToProxy(b: Boolean): IRProxy = if (b) True() else False()

implicit def ref(s: Symbol): IRProxy = (env: E) =>
Ref(s.name, env.lookup(s.name))
Ref(Name(s.name), env.lookup(Name(s.name)))

implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s)

@@ -242,7 +242,7 @@
def insertStruct(other: IRProxy, ordering: Option[IndexedSeq[String]] = None): IRProxy =
(env: E) => {
val right = other(env)
val sym = genUID()
val sym = freshName()
Let(
FastSeq(sym -> right),
InsertFields(
@@ -260,7 +260,7 @@
def isNA: IRProxy = (env: E) => IsNA(ir(env))

def orElse(alt: IRProxy): IRProxy = { env: E =>
val uid = genUID()
val uid = freshName()
val eir = ir(env)
Let(FastSeq(uid -> eir), If(IsNA(Ref(uid, eir.typ)), alt(env), Ref(uid, eir.typ)))
}
@@ -270,23 +270,27 @@
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamFilter(
ToStream(array),
pred.s.name,
pred.body(env.bind(pred.s.name -> eltType)),
Name(pred.s.name),
pred.body(env.bind(Name(pred.s.name) -> eltType)),
))
}

def map(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamMap(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType))))
ToArray(StreamMap(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
))
}

def aggExplode(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
AggExplode(
ToStream(array),
f.s.name,
f.body(env.bind(f.s.name, array.typ.asInstanceOf[TArray].elementType)),
Name(f.s.name),
f.body(env.bind(Name(f.s.name), array.typ.asInstanceOf[TArray].elementType)),
isScan = false,
)
}
@@ -296,21 +300,25 @@
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamFlatMap(
ToStream(array),
f.s.name,
ToStream(f.body(env.bind(f.s.name -> eltType))),
Name(f.s.name),
ToStream(f.body(env.bind(Name(f.s.name) -> eltType))),
))
}

def streamAgg(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
StreamAgg(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType)))
StreamAgg(ToStream(array), Name(f.s.name), f.body(env.bind(Name(f.s.name) -> eltType)))
}

def streamAggScan(f: LambdaProxy): IRProxy = (env: E) => {
val array = ir(env)
val eltType = array.typ.asInstanceOf[TArray].elementType
ToArray(StreamAggScan(ToStream(array), f.s.name, f.body(env.bind(f.s.name -> eltType))))
ToArray(StreamAggScan(
ToStream(array),
Name(f.s.name),
f.body(env.bind(Name(f.s.name) -> eltType)),
))
}

def arraySlice(start: IRProxy, stop: Option[IRProxy], step: IRProxy): IRProxy = {
@@ -335,9 +343,9 @@
val eltType = array.typ.asInstanceOf[TArray].elementType
AggArrayPerElement(
array,
elementsSym.name,
indexSym.name,
aggBody.apply(env.bind(elementsSym.name -> eltType, indexSym.name -> TInt32)),
Name(elementsSym.name),
Name(indexSym.name),
aggBody.apply(env.bind(Name(elementsSym.name) -> eltType, Name(indexSym.name) -> TInt32)),
knownLength.map(_(env)),
isScan = false,
)
@@ -384,8 +392,8 @@
var newEnv = env
val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) =>
val resolvedValue = value(newEnv)
newEnv = newEnv.bind(sym.name -> resolvedValue.typ)
Binding(sym.name, resolvedValue, scope)
newEnv = newEnv.bind(Name(sym.name) -> resolvedValue.typ)
Binding(Name(sym.name), resolvedValue, scope)
}
Block(resolvedBindings, body(newEnv))
}
@@ -394,13 +402,13 @@
object let extends Dynamic {
def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = {
assert(method == "apply")
letDyn(args: _*)
letDyn(args.map { case (n, ir) => Name(n) -> ir }: _*)
}
}

object letDyn {
def apply(args: (String, IRProxy)*): LetProxy =
new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.EVAL) }.toFastSeq)
def apply(args: (Name, IRProxy)*): LetProxy =
new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s.str), b, Scope.EVAL) }.toFastSeq)
}

class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal {