From d0480c9c9c4e28cd2d34da3d482546a5161ba289 Mon Sep 17 00:00:00 2001 From: Nimalan Date: Sun, 3 Mar 2024 22:42:59 +0530 Subject: [PATCH] Support unrolling functions with scalar arguments inside unrolled loops (#419) * Support functions with scalar arguments inside unrolled loops * Update error message --------- Co-authored-by: Rachit Nigam --- src/main/scala/common/Errors.scala | 2 +- src/main/scala/passes/WellFormedCheck.scala | 38 ++++++++++++++++++--- src/test/scala/TypeCheckerSpec.scala | 15 ++++++-- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/main/scala/common/Errors.scala b/src/main/scala/common/Errors.scala index 162110fa..bebc158d 100644 --- a/src/main/scala/common/Errors.scala +++ b/src/main/scala/common/Errors.scala @@ -88,7 +88,7 @@ object Errors { extends TypeError(s"$op cannot be inside an unrolled loop", pos) case class FuncInUnroll(pos: Position) - extends TypeError("Cannot call function inside unrolled loop.", pos) + extends TypeError("Cannot call function with non scalar arguments (like arrays) inside unrolled loop.", pos) // Unrolling and banking errors case class UnrollRangeError(pos: Position, rSize: Int, uFactor: Int) diff --git a/src/main/scala/passes/WellFormedCheck.scala b/src/main/scala/passes/WellFormedCheck.scala index 7696ddc2..ca0d2aaf 100644 --- a/src/main/scala/passes/WellFormedCheck.scala +++ b/src/main/scala/passes/WellFormedCheck.scala @@ -15,10 +15,37 @@ object WellFormedChecker { def check(p: Prog) = WFCheck.check(p) private case class WFEnv( + map: Map[Id, FuncDef] = Map(), insideUnroll: Boolean = false, insideFunc: Boolean = false - ) extends ScopeManager[WFEnv] { + ) extends ScopeManager[WFEnv] + with Tracker[Id, FuncDef, WFEnv] { def merge(that: WFEnv): WFEnv = this + + override def add(k: Id, v: FuncDef): WFEnv = + WFEnv( + insideUnroll=insideUnroll, + insideFunc=insideFunc, + map=this.map + (k -> v) + ) + + override def get(k: Id): Option[FuncDef] = this.map.get(k) + + def canHaveFunctionInUnroll(k: Id): Boolean = { + this.get(k) match { + case Some(FuncDef(_, args, _, _)) => + if (this.insideUnroll) { + args.foldLeft(true)({ + (r, arg) => arg.typ match { + case TArray(_, _, _) => false + case _ => r + } + }) + } else + true + case None => true // This is supposed to be unreachable + } + } } private final case object WFCheck extends PartialChecker { @@ -27,8 +54,9 @@ object WellFormedChecker { val emptyEnv = WFEnv() override def checkDef(defi: Definition)(implicit env: Env) = defi match { - case FuncDef(_, _, _, bodyOpt) => - bodyOpt.map(checkC(_)(env.copy(insideFunc = true))).getOrElse(env) + case fndef @ FuncDef(id, _, _, bodyOpt) => + val nenv = env.add(id, fndef) + bodyOpt.map(checkC(_)(nenv.copy(insideFunc = true))).getOrElse(nenv) case _: RecordDef => env } @@ -42,8 +70,8 @@ object WellFormedChecker { throw NotInBinder(expr.pos, "Record Literal") case (expr: EArrLiteral, _) => throw NotInBinder(expr.pos, "Array Literal") - case (expr: EApp, env) => { - assertOrThrow(env.insideUnroll == false, FuncInUnroll(expr.pos)) + case (expr @ EApp(id, _), env) => { + assertOrThrow(env.canHaveFunctionInUnroll(id) == true, FuncInUnroll(expr.pos)) env } } diff --git a/src/test/scala/TypeCheckerSpec.scala b/src/test/scala/TypeCheckerSpec.scala index 507b4cda..0a650c32 100644 --- a/src/test/scala/TypeCheckerSpec.scala +++ b/src/test/scala/TypeCheckerSpec.scala @@ -1009,10 +1009,10 @@ class TypeCheckerSpec extends FunSpec { } } - it("disallowed inside unrolled loops") { + it("should not allow functions with array arguments inside unrolled loops") { assertThrows[FuncInUnroll] { typeCheck(""" - def bar(a: bool) = { } + def bar(a: bool[4]) = { } for (let i = 0..10) unroll 5 { bar(tre); } @@ -1020,6 +1020,17 @@ class TypeCheckerSpec extends FunSpec { } } + it("should allow functions with scalar args in unrolled loops") { + typeCheck( + """ + def bar(a: bool) = { } + let tre: bool; + for (let i = 0..10) unroll 5 { + bar(tre); + } + """) + } + it("completely consume array parameters") { assertThrows[AlreadyConsumed] { typeCheck("""