Skip to content

Commit

Permalink
Support unrolling functions with scalar arguments inside unrolled loo…
Browse files Browse the repository at this point in the history
…ps (#419)

* Support functions with scalar arguments inside unrolled loops

* Update error message

---------

Co-authored-by: Rachit Nigam <rachit.nigam12@gmail.com>
  • Loading branch information
Mark1626 and rachitnigam committed Mar 3, 2024
1 parent 04ef979 commit d0480c9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/common/Errors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ object Errors {
extends TypeError(s"$op cannot be inside an unrolled loop", pos)

case class FuncInUnroll(pos: Position)
extends TypeError("Cannot call function inside unrolled loop.", pos)
extends TypeError("Cannot call function with non scalar arguments (like arrays) inside unrolled loop.", pos)

// Unrolling and banking errors
case class UnrollRangeError(pos: Position, rSize: Int, uFactor: Int)
Expand Down
38 changes: 33 additions & 5 deletions src/main/scala/passes/WellFormedCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,37 @@ object WellFormedChecker {
def check(p: Prog) = WFCheck.check(p)

private case class WFEnv(
map: Map[Id, FuncDef] = Map(),
insideUnroll: Boolean = false,
insideFunc: Boolean = false
) extends ScopeManager[WFEnv] {
) extends ScopeManager[WFEnv]
with Tracker[Id, FuncDef, WFEnv] {
def merge(that: WFEnv): WFEnv = this

override def add(k: Id, v: FuncDef): WFEnv =
WFEnv(
insideUnroll=insideUnroll,
insideFunc=insideFunc,
map=this.map + (k -> v)
)

override def get(k: Id): Option[FuncDef] = this.map.get(k)

def canHaveFunctionInUnroll(k: Id): Boolean = {
this.get(k) match {
case Some(FuncDef(_, args, _, _)) =>
if (this.insideUnroll) {
args.foldLeft(true)({
(r, arg) => arg.typ match {
case TArray(_, _, _) => false
case _ => r
}
})
} else
true
case None => true // This is supposed to be unreachable
}
}
}

private final case object WFCheck extends PartialChecker {
Expand All @@ -27,8 +54,9 @@ object WellFormedChecker {
val emptyEnv = WFEnv()

override def checkDef(defi: Definition)(implicit env: Env) = defi match {
case FuncDef(_, _, _, bodyOpt) =>
bodyOpt.map(checkC(_)(env.copy(insideFunc = true))).getOrElse(env)
case fndef @ FuncDef(id, _, _, bodyOpt) =>
val nenv = env.add(id, fndef)
bodyOpt.map(checkC(_)(nenv.copy(insideFunc = true))).getOrElse(nenv)
case _: RecordDef => env
}

Expand All @@ -42,8 +70,8 @@ object WellFormedChecker {
throw NotInBinder(expr.pos, "Record Literal")
case (expr: EArrLiteral, _) =>
throw NotInBinder(expr.pos, "Array Literal")
case (expr: EApp, env) => {
assertOrThrow(env.insideUnroll == false, FuncInUnroll(expr.pos))
case (expr @ EApp(id, _), env) => {
assertOrThrow(env.canHaveFunctionInUnroll(id) == true, FuncInUnroll(expr.pos))
env
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/test/scala/TypeCheckerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1009,17 +1009,28 @@ class TypeCheckerSpec extends FunSpec {
}
}

it("disallowed inside unrolled loops") {
it("should not allow functions with array arguments inside unrolled loops") {
assertThrows[FuncInUnroll] {
typeCheck("""
def bar(a: bool) = { }
def bar(a: bool[4]) = { }
for (let i = 0..10) unroll 5 {
bar(tre);
}
""")
}
}

it("should allow functions with scalar args in unrolled loops") {
typeCheck(
"""
def bar(a: bool) = { }
let tre: bool;
for (let i = 0..10) unroll 5 {
bar(tre);
}
""")
}

it("completely consume array parameters") {
assertThrows[AlreadyConsumed] {
typeCheck("""
Expand Down

0 comments on commit d0480c9

Please sign in to comment.