From d0480c9c9c4e28cd2d34da3d482546a5161ba289 Mon Sep 17 00:00:00 2001 From: Nimalan Date: Sun, 3 Mar 2024 22:42:59 +0530 Subject: [PATCH 1/2] Support unrolling functions with scalar arguments inside unrolled loops (#419) * Support functions with scalar arguments inside unrolled loops * Update error message --------- Co-authored-by: Rachit Nigam --- src/main/scala/common/Errors.scala | 2 +- src/main/scala/passes/WellFormedCheck.scala | 38 ++++++++++++++++++--- src/test/scala/TypeCheckerSpec.scala | 15 ++++++-- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/main/scala/common/Errors.scala b/src/main/scala/common/Errors.scala index 162110fa..bebc158d 100644 --- a/src/main/scala/common/Errors.scala +++ b/src/main/scala/common/Errors.scala @@ -88,7 +88,7 @@ object Errors { extends TypeError(s"$op cannot be inside an unrolled loop", pos) case class FuncInUnroll(pos: Position) - extends TypeError("Cannot call function inside unrolled loop.", pos) + extends TypeError("Cannot call function with non scalar arguments (like arrays) inside unrolled loop.", pos) // Unrolling and banking errors case class UnrollRangeError(pos: Position, rSize: Int, uFactor: Int) diff --git a/src/main/scala/passes/WellFormedCheck.scala b/src/main/scala/passes/WellFormedCheck.scala index 7696ddc2..ca0d2aaf 100644 --- a/src/main/scala/passes/WellFormedCheck.scala +++ b/src/main/scala/passes/WellFormedCheck.scala @@ -15,10 +15,37 @@ object WellFormedChecker { def check(p: Prog) = WFCheck.check(p) private case class WFEnv( + map: Map[Id, FuncDef] = Map(), insideUnroll: Boolean = false, insideFunc: Boolean = false - ) extends ScopeManager[WFEnv] { + ) extends ScopeManager[WFEnv] + with Tracker[Id, FuncDef, WFEnv] { def merge(that: WFEnv): WFEnv = this + + override def add(k: Id, v: FuncDef): WFEnv = + WFEnv( + insideUnroll=insideUnroll, + insideFunc=insideFunc, + map=this.map + (k -> v) + ) + + override def get(k: Id): Option[FuncDef] = this.map.get(k) + + def canHaveFunctionInUnroll(k: Id): Boolean = { + this.get(k) match { + case Some(FuncDef(_, args, _, _)) => + if (this.insideUnroll) { + args.foldLeft(true)({ + (r, arg) => arg.typ match { + case TArray(_, _, _) => false + case _ => r + } + }) + } else + true + case None => true // This is supposed to be unreachable + } + } } private final case object WFCheck extends PartialChecker { @@ -27,8 +54,9 @@ object WellFormedChecker { val emptyEnv = WFEnv() override def checkDef(defi: Definition)(implicit env: Env) = defi match { - case FuncDef(_, _, _, bodyOpt) => - bodyOpt.map(checkC(_)(env.copy(insideFunc = true))).getOrElse(env) + case fndef @ FuncDef(id, _, _, bodyOpt) => + val nenv = env.add(id, fndef) + bodyOpt.map(checkC(_)(nenv.copy(insideFunc = true))).getOrElse(nenv) case _: RecordDef => env } @@ -42,8 +70,8 @@ object WellFormedChecker { throw NotInBinder(expr.pos, "Record Literal") case (expr: EArrLiteral, _) => throw NotInBinder(expr.pos, "Array Literal") - case (expr: EApp, env) => { - assertOrThrow(env.insideUnroll == false, FuncInUnroll(expr.pos)) + case (expr @ EApp(id, _), env) => { + assertOrThrow(env.canHaveFunctionInUnroll(id) == true, FuncInUnroll(expr.pos)) env } } diff --git a/src/test/scala/TypeCheckerSpec.scala b/src/test/scala/TypeCheckerSpec.scala index 507b4cda..0a650c32 100644 --- a/src/test/scala/TypeCheckerSpec.scala +++ b/src/test/scala/TypeCheckerSpec.scala @@ -1009,10 +1009,10 @@ class TypeCheckerSpec extends FunSpec { } } - it("disallowed inside unrolled loops") { + it("should not allow functions with array arguments inside unrolled loops") { assertThrows[FuncInUnroll] { typeCheck(""" - def bar(a: bool) = { } + def bar(a: bool[4]) = { } for (let i = 0..10) unroll 5 { bar(tre); } @@ -1020,6 +1020,17 @@ class TypeCheckerSpec extends FunSpec { } } + it("should allow functions with scalar args in unrolled loops") { + typeCheck( + """ + def bar(a: bool) = { } + let tre: bool; + for (let i = 0..10) unroll 5 { + bar(tre); + } + """) + } + it("completely consume array parameters") { assertThrows[AlreadyConsumed] { typeCheck(""" From 410d42b544eb040fdbe886732fc0e296b6db0642 Mon Sep 17 00:00:00 2001 From: Nimalan Date: Tue, 5 Mar 2024 22:11:01 +0530 Subject: [PATCH 2/2] Migrate to Scala 3.3.1 LTS (#421) --- .scalafmt.conf | 2 +- build.sbt | 36 +++++++------- fuse | 2 +- project/assembly.sbt | 2 +- src/main/scala/Compiler.scala | 10 ++-- src/main/scala/GenerateExec.scala | 12 ++--- src/main/scala/Main.scala | 31 ++++-------- src/main/scala/Parser.scala | 26 ++++------ src/main/scala/Utils.scala | 2 +- src/main/scala/backends/Backend.scala | 2 +- src/main/scala/backends/CppLike.scala | 10 ++-- src/main/scala/backends/CppRunnable.scala | 4 +- src/main/scala/backends/VivadoBackend.scala | 14 +++--- src/main/scala/backends/calyx/Ast.scala | 28 +++++------ src/main/scala/backends/calyx/Backend.scala | 48 +++++++++---------- src/main/scala/backends/calyx/Helpers.scala | 4 +- src/main/scala/common/Configuration.scala | 14 +++--- src/main/scala/common/Document.scala | 14 +++--- src/main/scala/common/EnvHelpers.scala | 2 +- src/main/scala/common/MultiSet.scala | 6 +-- src/main/scala/common/Pretty.scala | 20 ++++---- src/main/scala/common/Syntax.scala | 32 ++++++------- src/main/scala/common/Transformer.scala | 2 +- src/main/scala/passes/AddBitWidth.scala | 6 +-- src/main/scala/passes/BoundsCheck.scala | 14 +++--- src/main/scala/passes/DependentLoops.scala | 12 ++--- src/main/scala/passes/HoistMemoryReads.scala | 2 +- src/main/scala/passes/HoistSlowBinop.scala | 2 +- src/main/scala/passes/LoopCheck.scala | 10 ++-- src/main/scala/passes/LowerForLoops.scala | 12 ++--- src/main/scala/passes/LowerUnroll.scala | 42 ++++++++-------- src/main/scala/passes/RewriteView.scala | 6 +-- src/main/scala/passes/Sequentialize.scala | 10 ++-- src/main/scala/passes/WellFormedCheck.scala | 4 +- src/main/scala/typechecker/AffineCheck.scala | 20 ++++---- src/main/scala/typechecker/AffineEnv.scala | 10 ++-- .../scala/typechecker/CapabilityChecker.scala | 7 +-- .../scala/typechecker/CapabilityEnv.scala | 8 ++-- src/main/scala/typechecker/Gadget.scala | 4 +- src/main/scala/typechecker/Info.scala | 4 +- src/main/scala/typechecker/Subtyping.scala | 10 ++-- src/main/scala/typechecker/TypeCheck.scala | 44 ++++++++--------- src/main/scala/typechecker/TypeEnv.scala | 4 +- src/project/build.properties | 2 +- src/test/scala/ParsingPositive.scala | 3 +- src/test/scala/TypeCheckerSpec.scala | 4 +- 46 files changed, 270 insertions(+), 293 deletions(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 0ec5c8fe..74d57207 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,2 +1,2 @@ -version = "2.4.2" +version = "3.8.0" align.tokens = [] diff --git a/build.sbt b/build.sbt index d973c9ab..775edf3e 100644 --- a/build.sbt +++ b/build.sbt @@ -1,12 +1,13 @@ name := "Dahlia" version := "0.0.2" -scalaVersion := "2.13.12" +scalaVersion := "3.3.1" libraryDependencies ++= Seq( - "org.scalatest" %% "scalatest" % "3.0.8" % "test", + "org.scalatest" %% "scalatest" % "3.2.18" % "test", + "org.scalatest" %% "scalatest-funspec" % "3.2.18" % "test", "org.scala-lang.modules" %% "scala-parser-combinators" % "2.0.0", - "com.lihaoyi" %% "fastparse" % "2.3.0", + "com.lihaoyi" %% "fastparse" % "3.0.2", "com.github.scopt" %% "scopt" % "4.0.1", "com.outr" %% "scribe" % "3.5.5", "com.lihaoyi" %% "sourcecode" % "0.2.7" @@ -16,26 +17,26 @@ scalacOptions ++= Seq( "-deprecation", "-unchecked", "-feature", - "-Ywarn-unused", - "-Ywarn-value-discard", - "-Xfatal-warnings" + "-Xfatal-warnings", + "-new-syntax", + "-indent" ) // Reload changes to this file. Global / onChangedBuildSource := ReloadOnSourceChanges // Disable options in sbt console. -scalacOptions in (Compile, console) ~= +Compile / console / scalacOptions ~= (_ filterNot ((Set("-Xfatal-warnings", "-Ywarn-unused").contains(_)))) -testOptions in Test += Tests.Argument("-oD") -parallelExecution in Test := false -logBuffered in Test := false +Test / testOptions += Tests.Argument("-oD") +Test / parallelExecution := false +Test / logBuffered := false /* Store commit hash information */ -resourceGenerators in Compile += Def.task { +Compile / resourceGenerators += Def.task { import scala.sys.process._ - val file = (resourceManaged in Compile).value / "version.properties" + val file = (Compile / resourceManaged).value / "version.properties" val gitHash = "git rev-parse --short HEAD".!! val gitDiff = "git diff --stat".!! val status = if (gitDiff.trim() != "") "dirty" else "clean" @@ -48,11 +49,12 @@ resourceGenerators in Compile += Def.task { } /* sbt-assembly configuration: build an executable jar. */ -assemblyOption in assembly := (assemblyOption in assembly).value.copy( - prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript) -) -assemblyJarName in assembly := "fuse.jar" -test in assembly := {} +//assembly / assemblyOption := (assembly / assemblyOption).value.copy( +// prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript) +//) +ThisBuild / assemblyPrependShellScript := Some(sbtassembly.AssemblyPlugin.defaultShellScript) +assembly / assemblyJarName := "fuse.jar" +assembly / test := {} /* Define task to download picojson headers */ val getHeaders = taskKey[Unit]("Download header dependencies for runnable backend.") diff --git a/fuse b/fuse index 35a8f7c5..3a25f110 120000 --- a/fuse +++ b/fuse @@ -1 +1 @@ -target/scala-2.13/fuse.jar \ No newline at end of file +target/scala-3.3.1/fuse.jar \ No newline at end of file diff --git a/project/assembly.sbt b/project/assembly.sbt index 9c014713..d83c8830 100644 --- a/project/assembly.sbt +++ b/project/assembly.sbt @@ -1 +1 @@ -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.9") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.1.5") diff --git a/src/main/scala/Compiler.scala b/src/main/scala/Compiler.scala index 9c22a46e..8a8babed 100644 --- a/src/main/scala/Compiler.scala +++ b/src/main/scala/Compiler.scala @@ -29,7 +29,7 @@ object Compiler { ) def showDebug(ast: Prog, pass: String, c: Config): Unit = { - if (c.passDebug) { + if c.passDebug then { val top = ("=" * 15) + pass + ("=" * 15) println(top) println(Pretty.emitProg(ast)(c.logLevel == scribe.Level.Debug).trim) @@ -49,7 +49,7 @@ object Compiler { showDebug(preAst, "Original", c) // Run pre transformers if lowering is enabled - val ast = if (c.enableLowering) { + val ast = if c.enableLowering then { preTransformers.foldLeft(preAst)({ case (ast, (name, pass)) => { val newAst = pass.rewrite(ast) @@ -115,7 +115,7 @@ object Compiler { err match { case _: Errors.TypeError => { s"[${red("Type error")}] ${err.getMessage}" + - (if (c.enableLowering) + (if c.enableLowering then "\nDoes this program type check without the `--lower` flag? If it does, please report this as a bug: https://github.com/cucapra/dahlia/issues/new" else "") } @@ -129,9 +129,7 @@ object Compiler { }) .map(out => { // Get metadata about the compiler build. - val version = getClass.getResourceAsStream("/version.properties") - val meta = Source - .fromInputStream(version) + val meta = scala.io.Source.fromResource("version.properties") .getLines() .filter(l => l.trim != "") .mkString(", ") diff --git a/src/main/scala/GenerateExec.scala b/src/main/scala/GenerateExec.scala index 03449c31..2c78806c 100644 --- a/src/main/scala/GenerateExec.scala +++ b/src/main/scala/GenerateExec.scala @@ -19,18 +19,18 @@ object GenerateExec { // Not the compiler directory, check if the fallback directory has been setup. - if (Files.exists(headerLocation) == false) { + if Files.exists(headerLocation) == false then { // Fallback for headers not setup. Unpack headers from JAR file. headerLocation = headerFallbackLocation - if (Files.exists(headerFallbackLocation) == false) { + if Files.exists(headerFallbackLocation) == false then { scribe.warn( s"Missing headers required for `fuse run`." + s" Unpacking from JAR file into $headerFallbackLocation." ) val dir = Files.createDirectory(headerFallbackLocation) - for (header <- headers) { + for header <- headers do { val stream = getClass.getResourceAsStream(s"/headers/$header") val hdrSource = Source.fromInputStream(stream).toArray.map(_.toByte) Files.write( @@ -57,8 +57,8 @@ object GenerateExec { ): Either[String, Int] = { // Make sure all headers are downloaded. - for (header <- headers) { - if (Files.exists(headerLocation.resolve(header)) == false) { + for header <- headers do { + if Files.exists(headerLocation.resolve(header)) == false then { throw HeaderMissing(header, headerLocation.toString) } } @@ -75,7 +75,7 @@ object GenerateExec { scribe.info(cmd.mkString(" ")) val status = cmd ! logger - if (status != 0) { + if status != 0 then { Left(s"Failed to generate the executable $out.\n${stderr}") } else { Right(status) diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala index 38b298df..914a3bb4 100644 --- a/src/main/scala/Main.scala +++ b/src/main/scala/Main.scala @@ -8,8 +8,7 @@ import Compiler._ import common.Logger import common.Configuration._ -object Main { - +object Main: // Command-line names for backends. val backends = Map( "vivado" -> Vivado, @@ -23,9 +22,7 @@ object Main { "axi" -> Axi ) - val version = getClass.getResourceAsStream("/version.properties") - val meta = Source - .fromInputStream(version) + val meta = scala.io.Source.fromResource("version.properties") .getLines() .filter(l => l.trim != "") .map(d => { @@ -53,7 +50,7 @@ object Main { opt[String]('n', "name") .valueName("") .validate(x => - if (x.matches("[A-Za-z0-9_]+")) success + if x.matches("[A-Za-z0-9_]+") then success else failure("Kernel name should only contain alphanumerals and _") ) .action((x, c) => c.copy(kernelName = x)) @@ -62,7 +59,7 @@ object Main { opt[String]('b', "backend") .valueName("") .validate(b => - if (backends.contains(b)) success + if backends.contains(b) then success else failure( s"Invalid backend name. Valid backends are ${backends.keys.mkString(", ")}" @@ -89,7 +86,7 @@ object Main { opt[String]("memory-interface") .validate(b => - if (memoryInterfaces.contains(b)) success + if memoryInterfaces.contains(b) then success else failure( s"Invalid memory interface. Valid memory interfaces are ${memoryInterfaces.keys.mkString(", ")}" @@ -117,7 +114,7 @@ object Main { ) } - def runWithConfig(conf: Config): Either[String, Int] = { + def runWithConfig(conf: Config): Either[String, Int] = type ErrString = String val path = conf.srcFile.toPath @@ -146,14 +143,11 @@ object Main { case _ => Right(0) } ) - status - } - - def main(args: Array[String]): Unit = { - parser.parse(args, emptyConf) match { - case Some(conf) => { + def main(args: Array[String]): Unit = + parser.parse(args, emptyConf) match + case Some(conf) => Logger.setLogLevel(conf.logLevel) val status = runWithConfig(conf) sys.exit( @@ -161,10 +155,5 @@ object Main { .map(compileErr => { System.err.println(compileErr); 1 }) .merge ) - } - case None => { + case None => sys.exit(1) - } - } - } -} diff --git a/src/main/scala/Parser.scala b/src/main/scala/Parser.scala index dad40f4d..37955af0 100644 --- a/src/main/scala/Parser.scala +++ b/src/main/scala/Parser.scala @@ -11,7 +11,7 @@ import Configuration.stringToBackend import Utils.RichOption import CompilerError.BackendError -case class Parser(input: String) { +case class Parser(input: String): // Common surround expressions def braces[K: P, T](p: => P[T]): P[T] = P("{" ~/ p ~ "}") @@ -19,19 +19,18 @@ case class Parser(input: String) { def angular[K: P, T](p: => P[T]): P[T] = P("<" ~/ p ~ ">") def parens[K: P, T](p: => P[T]): P[T] = P("(" ~/ p ~ ")") - def positioned[K: P, T <: PositionalWithSpan](p: => P[T]): P[T] = { + def positioned[K: P, T <: PositionalWithSpan](p: => P[T]): P[T] = P(Index ~ p ~ Index).map({ case (index, t, end) => { val startPos = OffsetPosition(input, index) val out = t.setPos(startPos) val endPos = OffsetPosition(input, end) - if (startPos.line == endPos.line) { + if startPos.line == endPos.line then { out.setSpan(end - index) } out } }) - } /*def notKws[K: P] = { import fastparse.NoWhitespace._ @@ -43,18 +42,16 @@ case class Parser(input: String) { ) ~ &(" "))).opaque("non reserved keywords") }*/ - def kw[K: P](word: String): P[Unit] = { + def kw[K: P](word: String): P[Unit] = import fastparse.NoWhitespace._ P(word ~ !CharsWhileIn("a-zA-Z0-9_")) - } // Basic atoms - def iden[K: P]: P[Id] = { + def iden[K: P]: P[Id] = import fastparse.NoWhitespace._ positioned(P(CharIn("a-zA-Z_") ~ CharsWhileIn("a-zA-Z0-9_").?).!.map({ case rest => Id(rest) }).opaque("Expected valid identifier")) - } def number[K: P]: P[Int] = P(CharIn("0-9").rep(1).!.map(_.toInt)).opaque("Expected positive number") @@ -469,19 +466,12 @@ case class Parser(input: String) { }) ) - def parse(): Prog = { - fastparse.parse[Prog](input, prog(_)) match { + def parse(): Prog = + fastparse.parse[Prog](input, prog(_)) match case Parsed.Success(e, _) => e - case Parsed.Failure(_, index, extra) => { + case Parsed.Failure(_, index, extra) => val traced = extra.trace() val loc = OffsetPosition(input, index) val msg = Errors.withPos(s"Expected ${traced.failure.label}", loc) throw Errors.ParserError(msg) - } - // XXX(rachit): Scala 2.13.4 complains that this pattern is not exhaustive. - // This is not true... - case _ => ??? - } - } -} diff --git a/src/main/scala/Utils.scala b/src/main/scala/Utils.scala index 94eebc13..cbd305a9 100644 --- a/src/main/scala/Utils.scala +++ b/src/main/scala/Utils.scala @@ -56,7 +56,7 @@ object Utils { } @inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = { - if (!cond) throw except + if !cond then throw except } @deprecated( diff --git a/src/main/scala/backends/Backend.scala b/src/main/scala/backends/Backend.scala index 8804eeee..1d6c8763 100644 --- a/src/main/scala/backends/Backend.scala +++ b/src/main/scala/backends/Backend.scala @@ -9,7 +9,7 @@ import CompilerError.BackendError trait Backend { def emit(p: Syntax.Prog, c: Configuration.Config): String = { - if (c.header && (canGenerateHeader == false)) { + if c.header && (canGenerateHeader == false) then { throw BackendError(s"Backend $this does not support header generation.") } emitProg(p, c) diff --git a/src/main/scala/backends/CppLike.scala b/src/main/scala/backends/CppLike.scala index 8a92638e..7601b1dc 100644 --- a/src/main/scala/backends/CppLike.scala +++ b/src/main/scala/backends/CppLike.scala @@ -35,7 +35,7 @@ object Cpp { * Helper to generate a function call that might have a type parameter */ def cCall(f: String, tParam: Option[Doc], args: Seq[Doc]): Doc = { - text(f) <> (if (tParam.isDefined) angles(tParam.get) else emptyDoc) <> + text(f) <> (if tParam.isDefined then angles(tParam.get) else emptyDoc) <> parens(commaSep(args)) } @@ -76,7 +76,7 @@ object Cpp { */ def emitLet(let: CLet): Doc = emitDecl(let.id, let.typ.get) <> - (if (let.e.isDefined) space <> equal <+> emitExpr(let.e.get) + (if let.e.isDefined then space <> equal <+> emitExpr(let.e.get) else emptyDoc) <> semi @@ -93,7 +93,7 @@ object Cpp { case EApp(fn, args) => fn <> parens(commaSep(args.map(emitExpr))) case EInt(v, base) => value(emitBaseInt(v, base)) case ERational(d) => value(d) - case EBool(b) => value(if (b) 1 else 0) + case EBool(b) => value(if b then 1 else 0) case EVar(id) => value(id) case EBinop(op, e1, e2) => parens(e1 <+> text(op.toString) <+> e2) case EArrAccess(id, idxs) => @@ -116,7 +116,7 @@ object Cpp { */ def emitRange(range: CRange): Doc = parens { val CRange(id, _, rev, s, e, _) = range - if (rev) { + if rev then { text("int") <+> id <+> equal <+> value(e - 1) <> semi <+> id <+> text(">=") <+> value(s) <> semi <+> id <> text("--") @@ -167,7 +167,7 @@ object Cpp { ) .getOrElse(emptyDoc) - if (entry) text("extern") <+> quote(text("C")) <+> scope(body) + if entry then text("extern") <+> quote(text("C")) <+> scope(body) else body } diff --git a/src/main/scala/backends/CppRunnable.scala b/src/main/scala/backends/CppRunnable.scala index e703dd00..0fac3ccb 100644 --- a/src/main/scala/backends/CppRunnable.scala +++ b/src/main/scala/backends/CppRunnable.scala @@ -26,7 +26,7 @@ private class CppRunnable extends CppLike { case _: TBool => text("bool") case _: TIndex => text("int") case _: TStaticInt => throw Impossible("TStaticInt type should not exist") - case TSizedInt(_, un) => text(if (un) "unsigned int" else "int") + case TSizedInt(_, un) => text(if un then "unsigned int" else "int") case _: TFloat => text("float") case _: TDouble | _: TFixed => text("double") case _: TRational => throw Impossible("Rational type should not exist") @@ -79,7 +79,7 @@ private class CppRunnable extends CppLike { def emitFor(cmd: CFor): Doc = text("for") <> emitRange(cmd.range) <+> scope { cmd.par <> { - if (cmd.combine != CEmpty) + if cmd.combine != CEmpty then line <> text("// combiner:") <@> cmd.combine else emptyDoc diff --git a/src/main/scala/backends/VivadoBackend.scala b/src/main/scala/backends/VivadoBackend.scala index 3b8a52e3..12581edc 100644 --- a/src/main/scala/backends/VivadoBackend.scala +++ b/src/main/scala/backends/VivadoBackend.scala @@ -21,13 +21,13 @@ private class VivadoBackend(config: Config) extends CppLike { def interfaceValid(decls: Seq[Decl]) = decls.collect({ case Decl(id, typ: TArray) => { - if (typ.ports > 1) + if typ.ports > 1 then throw BackendError( s"Multiported array argument `${id}' is disallowed. SDAccel inconsistently fails with RESOURCE pragma on argument arrays." ) typ.dims.foreach({ case (_, bank) => - if (bank > 1) + if bank > 1 then throw BackendError( s"Partitioned array argument `${id}' is disallowed. SDAccel generates incorrect hardware for partitioned argument arrays." ) @@ -71,7 +71,7 @@ private class VivadoBackend(config: Config) extends CppLike { } def emitPipeline(enabled: Boolean): Doc = - if (enabled) value(s"#pragma HLS PIPELINE") <> line else emptyDoc + if enabled then value(s"#pragma HLS PIPELINE") <> line else emptyDoc def emitFor(cmd: CFor): Doc = text("for") <> emitRange(cmd.range) <+> scope { @@ -79,7 +79,7 @@ private class VivadoBackend(config: Config) extends CppLike { unroll(cmd.range.u) <@> text("#pragma HLS LOOP_FLATTEN off") <@> cmd.par <> - (if (cmd.combine != CEmpty) line <> text("// combiner:") <@> cmd.combine + (if cmd.combine != CEmpty then line <> text("// combiner:") <@> cmd.combine else emptyDoc) } @@ -123,7 +123,7 @@ private class VivadoBackend(config: Config) extends CppLike { // Error if function arguments are partitioned/ported. interfaceValid(func.args) - if (entry) { + if entry then { val argPragmas = func.args.map(arg => config.memoryInterface match { case Axi => axiHeader(arg) @@ -151,9 +151,9 @@ private class VivadoBackend(config: Config) extends CppLike { case _: TFloat => text("float") case _: TDouble => text("double") case _: TRational => throw Impossible("Rational type should not exist") - case TSizedInt(s, un) => text(if (un) s"ap_uint<$s>" else s"ap_int<$s>") + case TSizedInt(s, un) => text(if un then s"ap_uint<$s>" else s"ap_int<$s>") case TFixed(t, i, un) => - text(if (un) s"ap_ufixed<$t,$i>" else s"ap_fixed<$t,$i>") + text(if un then s"ap_ufixed<$t,$i>" else s"ap_fixed<$t,$i>") case TArray(typ, _, _) => emitType(typ) case TRecType(n, _) => text(n.toString) case _: TFun => throw Impossible("Cannot emit function types") diff --git a/src/main/scala/backends/calyx/Ast.scala b/src/main/scala/backends/calyx/Ast.scala index c9005959..b20ac16e 100644 --- a/src/main/scala/backends/calyx/Ast.scala +++ b/src/main/scala/backends/calyx/Ast.scala @@ -17,7 +17,7 @@ object Calyx { ) extends Emitable { def addPos(pos: Position): Int = { val key = pos - if (!this.map.contains(key)) { + if !this.map.contains(key) then { this.map.update(key, this.counter) this.counter = this.counter + 1 } @@ -46,7 +46,7 @@ object Calyx { implicit meta: Metadata ): Doc = { // Add position information to the metadata. - if (pos.line != 0 && pos.column != 0) { + if pos.line != 0 && pos.column != 0 then { val count = meta.addPos(pos) text("@pos") <> parens(text(count.toString)) <> space } else { @@ -99,7 +99,7 @@ object Calyx { override def doc(): Doc = { val attrDoc = hsep(attrs.map({ case (attr, v) => text(s"@${attr}") <> parens(text(v.toString())) - })) <> (if (attrs.isEmpty) emptyDoc else space) + })) <> (if attrs.isEmpty then emptyDoc else space) attrDoc <> id.doc() <> colon <+> value(width) } } @@ -190,9 +190,9 @@ object Calyx { case (attr, v) => text("@") <> text(attr) <> parens(text(v.toString())) }) - ) <> (if (attrs.isEmpty) emptyDoc else space) + ) <> (if attrs.isEmpty then emptyDoc else space) - attrDoc <> (if (ref) text("ref") <> space else emptyDoc) <> + attrDoc <> (if ref then text("ref") <> space else emptyDoc) <> id.doc() <+> equal <+> comp.doc() <> semi } case Assign(src, dest, True) => @@ -200,9 +200,9 @@ object Calyx { case Assign(src, dest, guard) => dest.doc() <+> equal <+> guard.doc() <+> text("?") <+> src.doc() <> semi case Group(id, conns, delay, comb) => - (if (comb) text("comb ") else emptyDoc) <> + (if comb then text("comb ") else emptyDoc) <> text("group") <+> id.doc() <> - (if (delay.isDefined) + (if delay.isDefined then angles(text("\"promotable\"") <> equal <> text(delay.get.toString())) else emptyDoc) <+> scope(vsep(conns.map(_.doc()))) @@ -215,7 +215,7 @@ object Calyx { case (Group(thisId, _, _, _), Group(thatId, _, _, _)) => thisId.compare(thatId) case (Assign(thisSrc, thisDest, _), Assign(thatSrc, thatDest, _)) => { - if (thisSrc.compare(thatSrc) == 0) { + if thisSrc.compare(thatSrc) == 0 then { thisDest.compare(thatDest) } else { thisSrc.compare(thatSrc) @@ -264,7 +264,7 @@ object Calyx { } ) - (this(id, connections, if (comb) None else staticDelay, comb), st) + (this(id, connections, if comb then None else staticDelay, comb), st) } } @@ -319,7 +319,7 @@ object Calyx { } def attributesDoc(): Doc = - if (this.attributes.isEmpty) { + if this.attributes.isEmpty then { emptyDoc } else { hsep(attributes.map({ @@ -338,7 +338,7 @@ object Calyx { text("if") <+> port.doc() <+> text("with") <+> cond.doc() <+> scope(trueBr.doc) <> ( - if (falseBr == Empty) + if falseBr == Empty then emptyDoc else space <> text("else") <+> scope(falseBr.doc) @@ -352,7 +352,7 @@ object Calyx { } case i @ Invoke(id, refCells, inConnects, outConnects) => { val cells = - if (refCells.isEmpty) + if refCells.isEmpty then emptyDoc else brackets(commaSep(refCells.map({ @@ -405,7 +405,7 @@ object Stdlib { def binop(op: String, bitwidth: Int, signed: Boolean): Calyx.CompInst = Calyx.CompInst( - s"std_${if (signed) "s" else ""}$op", + s"std_${if signed then "s" else ""}$op", List(bitwidth) ) @@ -424,7 +424,7 @@ object Stdlib { signed: Boolean ): Calyx.CompInst = Calyx.CompInst( - s"std_fp_${(if (signed) "s" else "")}$op", + s"std_fp_${(if signed then "s" else "")}$op", List(width, int_width, frac_width) ) diff --git a/src/main/scala/backends/calyx/Backend.scala b/src/main/scala/backends/calyx/Backend.scala index 29b571fa..f3065430 100644 --- a/src/main/scala/backends/calyx/Backend.scala +++ b/src/main/scala/backends/calyx/Backend.scala @@ -190,7 +190,7 @@ private class CalyxBackendHelper { funcId: Id )(implicit id2FuncDef: FunctionMapping): List[BigInt] = { val id = funcId.toString() - if (!requiresWidthArguments.contains(id)) { + if !requiresWidthArguments.contains(id) then { List() } else { val typ = id2FuncDef(funcId).retTy; @@ -299,7 +299,7 @@ private class CalyxBackendHelper { comp.name.port("out"), None, struct ++ e1Out.structure ++ e2Out.structure, - for (d1 <- e1Out.delay; d2 <- e2Out.delay) + for d1 <- e1Out.delay; d2 <- e2Out.delay yield d1 + d2, None ) @@ -307,7 +307,7 @@ private class CalyxBackendHelper { // if there is additional information about the integer bit, // use fixed point binary operation case (e1Bits, Some(intBit1)) => { - val (e2Bits, Some(intBit2)) = bitsForType(e2.typ, e2.pos) + val (e2Bits, Some(intBit2)) = bitsForType(e2.typ, e2.pos) : @unchecked val fracBit1 = e1Bits - intBit1 val fracBit2 = e2Bits - intBit2 val isSigned = signed(e1.typ) @@ -336,7 +336,7 @@ private class CalyxBackendHelper { comp.name.port("out"), None, struct ++ e1Out.structure ++ e2Out.structure, - for (d1 <- e1Out.delay; d2 <- e2Out.delay) + for d1 <- e1Out.delay; d2 <- e2Out.delay yield d1 + d2, None ) @@ -415,7 +415,7 @@ private class CalyxBackendHelper { comp.name.port(outPort), Some(comp.name.port("done")), struct ++ e1Out.structure ++ e2Out.structure, - for (d1 <- e1Out.delay; d2 <- e2Out.delay; d3 <- delay) + for d1 <- e1Out.delay; d2 <- e2Out.delay; d3 <- delay yield d1 + d2 + d3, Some((comp.name.port("done"), delay)) ) @@ -515,7 +515,7 @@ private class CalyxBackendHelper { } EmitOutput( - if (calyxVarType == LocalVar) cell.port(port) + if calyxVarType == LocalVar then cell.port(port) else ThisPort(cell), done, struct, @@ -545,7 +545,7 @@ private class CalyxBackendHelper { val const = Cell( genName("bool"), - Stdlib.constant(1, if (v) 1 else 0), + Stdlib.constant(1, if v then 1 else 0), false, List() ) @@ -560,20 +560,20 @@ private class CalyxBackendHelper { // Cast ERational to Fixed Point. case ECast(ERational(value), typ) => { val _ = rhsInfo - val (width, Some(intWidth)) = bitsForType(Some(typ), expr.pos) + val (width, Some(intWidth)) = bitsForType(Some(typ), expr.pos) : @unchecked val fracWidth = width - intWidth // Interpret as an integer. val isNegative = value.startsWith("-") val partition = value.split('.') val sIntPart = partition(0) - val intPart = if (isNegative) { + val intPart = if isNegative then { sIntPart.substring(1, sIntPart.length()) } else { sIntPart } val bdFracPart = BigDecimal("0." + partition(1)) val fracValue = (bdFracPart * BigDecimal(2).pow(fracWidth)) - if (!fracValue.isWhole) { + if !fracValue.isWhole then { throw BackendError( s"The value $value of type $typ is not representable in fixed point", expr.pos @@ -582,7 +582,7 @@ private class CalyxBackendHelper { val intBits = binaryString(intPart.toInt, intWidth) val fracBits = binaryString(fracValue.toBigInt, fracWidth) - val bits = if (isNegative) { + val bits = if isNegative then { negateTwosComplement(intBits + fracBits) } else { intBits + fracBits @@ -606,7 +606,7 @@ private class CalyxBackendHelper { val (vBits, _) = bitsForType(e.typ, e.pos) val (cBits, _) = bitsForType(Some(t), e.pos) val res = emitExpr(e) - if (vBits == cBits) { + if vBits == cBits then { // No slicing or padding is necessary. EmitOutput( res.port, @@ -616,7 +616,7 @@ private class CalyxBackendHelper { res.multiCycleInfo ) } else { - val comp = if (cBits > vBits) { + val comp = if cBits > vBits then { Cell(genName("pad"), Stdlib.pad(vBits, cBits), false, List()) } else { Cell(genName("slice"), Stdlib.slice(vBits, cBits), false, List()) @@ -645,11 +645,11 @@ private class CalyxBackendHelper { ) val donePortName = - if (rhsInfo.isDefined) "write_done" else "read_done" + if rhsInfo.isDefined then "write_done" else "read_done" // The value is generated on `read_data` and written on `write_data`. val portName = - if (rhsInfo.isDefined) "write_data" else "read_data" + if rhsInfo.isDefined then "write_data" else "read_data" // The array ports change if the array is a function parameter. We want to access the @@ -657,7 +657,7 @@ private class CalyxBackendHelper { val isParam = (typ == ParameterVar) val (writeEnPort, donePort, accessPort) = - if (isParam) { + if isParam then { ( ThisPort(CompVar(s"${id}_write_en")), ThisPort(CompVar(s"${id}_${donePortName}")), @@ -677,21 +677,21 @@ private class CalyxBackendHelper { case (structs, (accessor, idx)) => { val result = emitExpr(accessor) val addrPort = - if (isParam) ThisPort(CompVar(s"${id}_addr${idx}")) + if isParam then ThisPort(CompVar(s"${id}_addr${idx}")) else arr.port("addr" + idx) val con = Assign(result.port, addrPort) con :: result.structure ++ structs } }) - val readEnPort = if (isParam) { + val readEnPort = if isParam then { ThisPort(CompVar(s"${id}_read_en")) } else { arr.port("read_en") } // always assign 1 to read_en port if we want to read from seq mem - val readEnStruct = if (rhsInfo.isDefined) List() else List(Assign(ConstantPort(1,1), readEnPort)) + val readEnStruct = if rhsInfo.isDefined then List() else List(Assign(ConstantPort(1,1), readEnPort)) val writeEnStruct = rhsInfo match { @@ -1043,12 +1043,12 @@ private class CalyxBackendHelper { case FuncDef(id, params, retTy, bodyOpt) => { definitions + (id -> FuncDef(id, params, retTy, bodyOpt)) } - case _ => definitions +// case _: RecordDef => definitions } ) val functionDefinitions: List[Component] = - for ((id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList) + for ( case (id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList ) yield { val (refCells, inputPorts) = params.partitionMap(param => param.typ match { @@ -1080,7 +1080,7 @@ private class CalyxBackendHelper { Component( id.toString(), inputPorts.toList, - if (retType == TVoid()) List() + if retType == TVoid() then List() // If the return type of the component is not void, add an `out` wire. else List( @@ -1097,7 +1097,7 @@ private class CalyxBackendHelper { Import("primitives/binary_operators.futil") :: p.includes.flatMap(_.backends.get(C.Calyx)).map(i => Import(i)).toList - val main = if (!c.compilerOpts.contains("no-main")) { + val main = if !c.compilerOpts.contains("no-main") then { val declStruct = p.decls.map(emitDecl) val store = declStruct.foldLeft(Map[CompVar, (CompVar, VType)]())((store, struct) => @@ -1110,7 +1110,7 @@ private class CalyxBackendHelper { val struct = declStruct.toList ++ cmdStruct val mainComponentName = - if (c.kernelName == "kernel") "main" else c.kernelName + if c.kernelName == "kernel" then "main" else c.kernelName List( Component(mainComponentName, List(), List(), struct.sorted, control) ) diff --git a/src/main/scala/backends/calyx/Helpers.scala b/src/main/scala/backends/calyx/Helpers.scala index b355e11b..f0fedb13 100644 --- a/src/main/scala/backends/calyx/Helpers.scala +++ b/src/main/scala/backends/calyx/Helpers.scala @@ -15,8 +15,8 @@ object Helpers { * two's complement representation. */ def negateTwosComplement(bitString: String): String = { - if (bitString.forall(_ == '0')) { - bitString + if bitString.forall(_ == '0') then { + return bitString } val t = bitString .replaceAll("0", "_") diff --git a/src/main/scala/common/Configuration.scala b/src/main/scala/common/Configuration.scala index 892c7473..679222d1 100644 --- a/src/main/scala/common/Configuration.scala +++ b/src/main/scala/common/Configuration.scala @@ -5,8 +5,8 @@ import java.io.File object Configuration { sealed trait Mode - final case object Compile extends Mode - final case object Run extends Mode + case object Compile extends Mode + case object Run extends Mode def stringToBackend(name: String): Option[BackendOption] = name match { case "vivado" => Some(Vivado) @@ -23,14 +23,14 @@ object Configuration { case Calyx => "calyx" } } - final case object Vivado extends BackendOption - final case object Cpp extends BackendOption - final case object Calyx extends BackendOption + case object Vivado extends BackendOption + case object Cpp extends BackendOption + case object Calyx extends BackendOption // The type of Vivado memory interface to generate sealed trait MemoryInterface - final case object ApMemory extends MemoryInterface - final case object Axi extends MemoryInterface + case object ApMemory extends MemoryInterface + case object Axi extends MemoryInterface val emptyConf = Config(null) diff --git a/src/main/scala/common/Document.scala b/src/main/scala/common/Document.scala index c400b5f9..a7029396 100644 --- a/src/main/scala/common/Document.scala +++ b/src/main/scala/common/Document.scala @@ -22,7 +22,7 @@ object PrettyPrint { */ abstract class Doc { def <@>(hd: Doc): Doc = { - if (hd == DocNil) this + if hd == DocNil then this else this <> DocBreak <> hd } def <>(hd: Doc): Doc = (this, hd) match { @@ -46,11 +46,11 @@ object PrettyPrint { def spaces(n: Int): Unit = { var rem = n - while (rem >= 16) { writer write " "; rem -= 16 } - if (rem >= 8) { writer write " "; rem -= 8 } - if (rem >= 4) { writer write " "; rem -= 4 } - if (rem >= 2) { writer write " "; rem -= 2 } - if (rem == 1) { writer write " " } + while rem >= 16 do { writer write " "; rem -= 16 } + if rem >= 8 then { writer write " "; rem -= 8 } + if rem >= 4 then { writer write " "; rem -= 4 } + if rem >= 2 then { writer write " "; rem -= 2 } + if rem == 1 then { writer write " " } } def fmt(state: List[FmtState]): Unit = state match { @@ -102,7 +102,7 @@ object PrettyPrint { def nest(d: Doc, i: Int): Doc = DocNest(i, d) def folddoc(ds: Iterable[Doc], f: (Doc, Doc) => Doc) = - if (ds.isEmpty) emptyDoc + if ds.isEmpty then emptyDoc else ds.tail.foldLeft(ds.head)(f) /** Builder functions */ diff --git a/src/main/scala/common/EnvHelpers.scala b/src/main/scala/common/EnvHelpers.scala index a113be9c..932e1b02 100644 --- a/src/main/scala/common/EnvHelpers.scala +++ b/src/main/scala/common/EnvHelpers.scala @@ -53,7 +53,7 @@ object EnvHelpers { * Definition of a trivial environment that doesn't track any * information. */ - final case class UnitEnv() extends ScopeManager[UnitEnv] { + case class UnitEnv() extends ScopeManager[UnitEnv] { def merge(that: UnitEnv) = this } diff --git a/src/main/scala/common/MultiSet.scala b/src/main/scala/common/MultiSet.scala index 82c6ff8e..5a599230 100644 --- a/src/main/scala/common/MultiSet.scala +++ b/src/main/scala/common/MultiSet.scala @@ -9,7 +9,7 @@ object MultiSet { def fromSeq[K](seq: Seq[K]): MultiSet[K] = MultiSet(seq.foldLeft(Map[K, Int]())({ case (ms, v) => - if (ms.contains(v)) ms + (v -> (ms(v) + 1)) + if ms.contains(v) then ms + (v -> (ms(v) + 1)) else ms + (v -> 1) })) @@ -33,7 +33,7 @@ object MultiSet { def zipWith(that: MultiSet[K], op: (Int, Int) => Int): MultiSet[K] = { val thatMap = that.setMap val (thisKeys, thatKeys) = (setMap.keys.toSet, thatMap.keys.toSet) - if (thisKeys != thatKeys) { + if thisKeys != thatKeys then { throw new NoSuchElementException( s"Element ${thisKeys.diff(thatKeys).head} not in both multisets.\nThis: ${setMap}\nThat: ${thatMap}." ) @@ -45,7 +45,7 @@ object MultiSet { def diff(that: MultiSet[K]) = MultiSet(setMap.map({ case (k, v) => { - k -> (if (that.setMap.contains(k)) (v - that.setMap(k)) else v) + k -> (if that.setMap.contains(k) then (v - that.setMap(k)) else v) } })) diff --git a/src/main/scala/common/Pretty.scala b/src/main/scala/common/Pretty.scala index e74bd4b1..577c3584 100644 --- a/src/main/scala/common/Pretty.scala +++ b/src/main/scala/common/Pretty.scala @@ -52,7 +52,7 @@ object Pretty { implicit def emitId(id: Id)(implicit debug: Boolean): Doc = { val idv = value(id.v) - if (debug) + if debug then id.typ.map(t => idv <> text("@") <> emitTyp(t)).getOrElse(idv) else idv } @@ -70,12 +70,12 @@ object Pretty { case EApp(fn, args) => fn <> parens(commaSep(args.map(emitExpr))) case EInt(v, base) => value(emitBaseInt(v, base)) case ERational(d) => value(d) - case EBool(b) => value(if (b) "true" else "false") + case EBool(b) => value(if b then "true" else "false") case EVar(id) => emitId(id) case EBinop(op, e1, e2) => parens(e1 <+> text(op.toString) <+> e2) case acc @ EArrAccess(id, idxs) => { val doc = id <> ssep(idxs.map(idx => brackets(emitExpr(idx))), emptyDoc) - if (debug) + if debug then acc.consumable .map(ann => brackets(doc <> colon <+> emitConsume(ann))) .getOrElse(doc) @@ -91,7 +91,7 @@ object Pretty { emptyDoc ) - if (debug) + if debug then brackets( doc <> colon <+> acc.consumable.map(emitConsume).getOrElse(emptyDoc) ) @@ -114,12 +114,12 @@ object Pretty { t.map(x => text(":") <+> text(x.toString)).getOrElse(emptyDoc) parens( text("let") <+> id <> typAnnot <+> equal <+> - (if (rev) text("rev") <+> emptyDoc else emptyDoc) <> + (if rev then text("rev") <+> emptyDoc else emptyDoc) <> value(s) <+> text("..") <+> value( e ) ) <> - (if (u > 1) space <> text("unroll") <+> value(u) else emptyDoc) + (if u > 1 then space <> text("unroll") <+> value(u) else emptyDoc) } def emitView(view: View)(implicit debug: Boolean): Doc = { @@ -145,7 +145,7 @@ object Pretty { implicit debug: Boolean ): Doc = { val attr = - if (c.attributes.isEmpty) emptyDoc + if c.attributes.isEmpty then emptyDoc else text("/*") <+> emitAttributes(c.attributes) <+> text("*/") <> space @@ -173,15 +173,15 @@ object Pretty { } case CFor(r, pipe, par, com) => text("for") <+> emitRange(r) <> - (if (pipe) space <> text("pipeline") else emptyDoc) <+> + (if pipe then space <> text("pipeline") else emptyDoc) <+> scope(emitCmd(par)) <> - (if (com != CEmpty) + (if com != CEmpty then space <> text("combine") <+> scope(emitCmd(com)) else emptyDoc) case CWhile(cond, pipe, body) => text("while") <+> parens(cond) <> - (if (pipe) space <> text("pipeline") else emptyDoc) <+> scope( + (if pipe then space <> text("pipeline") else emptyDoc) <+> scope( emitCmd(body) ) case CDecorate(dec) => text("decor") <+> quote(value(dec)) diff --git a/src/main/scala/common/Syntax.scala b/src/main/scala/common/Syntax.scala index ebf986ea..b9e12f5b 100644 --- a/src/main/scala/common/Syntax.scala +++ b/src/main/scala/common/Syntax.scala @@ -27,8 +27,8 @@ object Syntax { */ object Annotations { sealed trait Consumable - final case object ShouldConsume extends Consumable - final case object SkipConsume extends Consumable + case object ShouldConsume extends Consumable + case object SkipConsume extends Consumable sealed trait ConsumableAnnotation { var consumable: Option[Consumable] = None @@ -55,8 +55,8 @@ object Syntax { // Capabilities for read/write sealed trait Capability - final case object Read extends Capability - final case object Write extends Capability + case object Read extends Capability + case object Write extends Capability sealed trait Type extends PositionalWithSpan { override def toString = this match { @@ -65,12 +65,12 @@ object Syntax { case _: TRational => "rational" case _: TFloat => "float" case _: TDouble => "double" - case TFixed(t, i, un) => s"${if (un) "u" else ""}fix<$t,$i>" - case TSizedInt(l, un) => s"${if (un) "u" else ""}bit<$l>" + case TFixed(t, i, un) => s"${if un then "u" else ""}fix<$t,$i>" + case TSizedInt(l, un) => s"${if un then "u" else ""}bit<$l>" case TStaticInt(s) => s"static($s)" case TArray(t, dims, p) => - (if (p > 1) s"$t{$p}" else s"$t") + dims.foldLeft("")({ - case (acc, (d, b)) => s"$acc[$d${if (b > 1) s" bank $b" else ""}]" + (if p > 1 then s"$t{$p}" else s"$t") + dims.foldLeft("")({ + case (acc, (d, b)) => s"$acc[$d${if b > 1 then s" bank $b" else ""}]" }) case TIndex(s, d) => s"idx($s, $d)" case TFun(args, ret) => s"${args.mkString("->")} -> ${ret}" @@ -106,7 +106,7 @@ object Syntax { case class TArray(typ: Type, dims: Seq[DimSpec], ports: Int) extends Type { dims.zipWithIndex.foreach({ case ((len, bank), dim) => - if (bank > len || len % bank != 0) { + if bank > len || len % bank != 0 then { throw MalformedType( s"Dimension $dim of TArray is malformed. Full type: $this" ) @@ -161,7 +161,7 @@ object Syntax { u: Int ) extends PositionalWithSpan { def idxType: TIndex = { - if (abs(e - s) % u != 0) { + if abs(e - s) % u != 0 then { throw UnrollRangeError(this.pos, e - s, u) } else { TIndex((0, u), (s / u, e / u)) @@ -209,10 +209,10 @@ object Syntax { extends Command case class CDecorate(value: String) extends Command case class CUpdate(lhs: Expr, rhs: Expr) extends Command { - if (lhs.isLVal == false) throw UnexpectedLVal(lhs, "assignment") + if lhs.isLVal == false then throw UnexpectedLVal(lhs, "assignment") } case class CReduce(rop: ROp, lhs: Expr, rhs: Expr) extends Command { - if (lhs.isLVal == false) throw UnexpectedLVal(lhs, "reduction") + if lhs.isLVal == false then throw UnexpectedLVal(lhs, "reduction") } case class CReturn(exp: Expr) extends Command case class CExpr(exp: Expr) extends Command @@ -238,9 +238,9 @@ object Syntax { case _ => Seq(cmd) } ) - if (flat.length == 0) { + if flat.length == 0 then { CEmpty - } else if (flat.length == 1) { + } else if flat.length == 1 then { flat(0) } else { CPar(flat) @@ -265,9 +265,9 @@ object Syntax { case _ => Seq(cmd) } ) - if (flat.length == 0) { + if flat.length == 0 then { CEmpty - } else if (flat.length == 1) { + } else if flat.length == 1 then { flat(0) } else { CSeq(flat) diff --git a/src/main/scala/common/Transformer.scala b/src/main/scala/common/Transformer.scala index 8f7ecf50..0a229b90 100644 --- a/src/main/scala/common/Transformer.scala +++ b/src/main/scala/common/Transformer.scala @@ -23,7 +23,7 @@ object Transformer { c1 match { case _: CPar | _: CSeq | _: CBlock => () case _ => { - if (c1.pos.line == 0 && c1.pos.column == 0) { + if c1.pos.line == 0 && c1.pos.column == 0 then { c1.withPos(cmd) } } diff --git a/src/main/scala/passes/AddBitWidth.scala b/src/main/scala/passes/AddBitWidth.scala index 0b1b4b00..9f5c1d4f 100644 --- a/src/main/scala/passes/AddBitWidth.scala +++ b/src/main/scala/passes/AddBitWidth.scala @@ -26,7 +26,7 @@ object AddBitWidth extends TypedPartialTransformer { def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { case (e: ECast, env) => e -> env case (e @ EArrAccess(arrId, idxs), env) => { - val Some(TArray(_, dims, _)) = arrId.typ + val Some(TArray(_, dims, _)) = arrId.typ : @unchecked val nIdxs = idxs .zip(dims) .map({ @@ -45,7 +45,7 @@ object AddBitWidth extends TypedPartialTransformer { e.copy(idxs = nIdxs) -> env } case (e: EInt, env) => - if (env.curTyp.isDefined) { + if env.curTyp.isDefined then { (ECast(e, env.curTyp.get), env) } else { e -> env @@ -60,7 +60,7 @@ object AddBitWidth extends TypedPartialTransformer { expr.copy(e1 = nl, e2 = nr) -> env } case (expr @ EBinop(_: NumOp | _: BitOp, l, r), env) => { - val nEnv = if (env.curTyp.isDefined) { + val nEnv = if env.curTyp.isDefined then { env } else { ABEnv( diff --git a/src/main/scala/passes/BoundsCheck.scala b/src/main/scala/passes/BoundsCheck.scala index 4e19d355..528c034f 100644 --- a/src/main/scala/passes/BoundsCheck.scala +++ b/src/main/scala/passes/BoundsCheck.scala @@ -16,7 +16,7 @@ object BoundsChecker { def check(p: Prog) = BCheck.check(p) - private final case object BCheck extends PartialChecker { + private case object BCheck extends PartialChecker { type Env = UnitEnv @@ -27,8 +27,8 @@ object BoundsChecker { * out of bound access when accessed. */ private def checkView(arrLen: Int, viewId: Id, view: View) = { - if (view.prefix.isDefined) { - val View(suf, Some(pre), _) = view + if view.prefix.isDefined then { + val View(suf, Some(pre), _) = view : @unchecked val (sufExpr, fac) = suf match { case Aligned(fac, e) => (e, fac) @@ -51,7 +51,7 @@ object BoundsChecker { 1 } - if (maxVal + pre > arrLen) { + if maxVal + pre > arrLen then { throw IndexOutOfBounds(viewId, arrLen, maxVal + pre, viewId.pos) } } @@ -70,7 +70,7 @@ object BoundsChecker { case ((idx, t), (size, _)) => t.foreach({ case idxt @ TSizedInt(n, _) => - if ((math.pow(2, n) - 1) >= size) { + if (math.pow(2, n) - 1) >= size then { scribe.warn( ( s"$idxt is used for an array access. " + @@ -80,10 +80,10 @@ object BoundsChecker { ) } case TStaticInt(v) => - if (v >= size) + if v >= size then throw IndexOutOfBounds(id, size, v, idx.pos) case t @ TIndex(_, _) => - if (t.maxVal >= size) + if t.maxVal >= size then throw IndexOutOfBounds(id, size, t.maxVal, idx.pos) case t => throw UnexpectedType(id.pos, "array access", s"[$t]", t) diff --git a/src/main/scala/passes/DependentLoops.scala b/src/main/scala/passes/DependentLoops.scala index b917b662..2a96df95 100644 --- a/src/main/scala/passes/DependentLoops.scala +++ b/src/main/scala/passes/DependentLoops.scala @@ -25,7 +25,7 @@ object DependentLoops { } } - private final case object UseCheck extends PartialChecker { + private case object UseCheck extends PartialChecker { type Env = UseEnv val emptyEnv = UseEnv(Set()) @@ -69,7 +69,7 @@ object DependentLoops { } } - private final case object DepCheck extends PartialChecker { + private case object DepCheck extends PartialChecker { type Env = DepEnv val emptyEnv = DepEnv(Set(), Set()) @@ -79,7 +79,7 @@ object DependentLoops { idxs.foreach(e => { val used = UseCheck.checkE(e)(UseCheck.emptyEnv) val intersect = env.depVars.intersect(used.used) - if (intersect.size != 0) { + if intersect.size != 0 then { val sourceId = intersect.toList(0) throw LoopDynamicAccess(e, sourceId) } @@ -90,7 +90,7 @@ object DependentLoops { def myCheckC: PF[(Command, Env), Env] = { case (CFor(range, _, par, _), env) => { - if (range.u > 1) { + if range.u > 1 then { env.forgetScope(e1 => checkC(par)(e1.addLoopVar(range.iter))) } else { env.forgetScope(e1 => checkC(par)(e1)) @@ -98,7 +98,7 @@ object DependentLoops { } case (CLet(id, _, Some(exp)), env) => { val used = UseCheck.checkE(exp)(UseCheck.emptyEnv) - if (env.intersect(used.used).size != 0) { + if env.intersect(used.used).size != 0 then { env.addDep(id) } else { env @@ -106,7 +106,7 @@ object DependentLoops { } case (CUpdate(EVar(id), rhs), env) => { val used = UseCheck.checkE(rhs)(UseCheck.emptyEnv) - if (env.intersect(used.used).size != 0) { + if env.intersect(used.used).size != 0 then { env.addDep(id) } else { env.removeDep(id) diff --git a/src/main/scala/passes/HoistMemoryReads.scala b/src/main/scala/passes/HoistMemoryReads.scala index 397e05ab..5d98a6a8 100644 --- a/src/main/scala/passes/HoistMemoryReads.scala +++ b/src/main/scala/passes/HoistMemoryReads.scala @@ -45,7 +45,7 @@ object HoistMemoryReads extends PartialTransformer { env: Env, acc: Command = CEmpty ): Command = { - if (env.map.values.isEmpty && acc == CEmpty) { + if env.map.values.isEmpty && acc == CEmpty then { cmd } else { CPar.smart(env.map.values.toSeq :+ acc :+ cmd) diff --git a/src/main/scala/passes/HoistSlowBinop.scala b/src/main/scala/passes/HoistSlowBinop.scala index 814cc336..059dbe62 100644 --- a/src/main/scala/passes/HoistSlowBinop.scala +++ b/src/main/scala/passes/HoistSlowBinop.scala @@ -41,7 +41,7 @@ object HoistSlowBinop extends TypedPartialTransformer { env: Env, acc: Command = CEmpty ): (Command, Env) = { - if (env.map.values.isEmpty && acc == CEmpty) { + if env.map.values.isEmpty && acc == CEmpty then { cmd -> emptyEnv } else { CSeq.smart(env.map.values.toSeq :+ acc :+ cmd) -> emptyEnv diff --git a/src/main/scala/passes/LoopCheck.scala b/src/main/scala/passes/LoopCheck.scala index add59320..8d7c76ae 100644 --- a/src/main/scala/passes/LoopCheck.scala +++ b/src/main/scala/passes/LoopCheck.scala @@ -72,7 +72,7 @@ object LoopChecker { idxs: Option[EArrAccess] = None ): LEnv = { val (env, check) = checkExprMap(idxs) - if (check) { + if check then { val e2 = state match { case DontKnow => env.atDk(env.getName(id)) case Def => env.atDef(env.getName(id)) @@ -84,7 +84,7 @@ object LoopChecker { } // Helper functions for ScopeManager def withScope(resources: Int)(inScope: LEnv => LEnv): LEnv = { - if (resources == 1) { + if resources == 1 then { inScope(this.addNameScope) match { case env: LEnv => env.endNameScope } @@ -114,7 +114,7 @@ object LoopChecker { val (innermap, outermap) = stateMap.endScope.get var outerenv = LEnv(outermap, nmap, emap)(res / resources) val keys = innermap.keys - for (k <- keys) { + for k <- keys do { outerenv = outerenv.updateState(k, innermap(k)) //inner map is a scala map } outerenv @@ -141,7 +141,7 @@ object LoopChecker { env.copy(stateMap = env.stateMap.addShadow(k, DontKnow)) case (Some(Def), Some(Use)) => throw LoopDepMerge(k) case (Some(DontKnow), Some(Use)) => throw LoopDepMerge(k) - case (v1, v2) => if (v1 == v2) env else mergeHelper(k, v2, v1, env) + case (v1, v2) => if v1 == v2 then env else mergeHelper(k, v2, v1, env) } // If statement @@ -155,7 +155,7 @@ object LoopChecker { } } - private final case object LCheck extends PartialChecker { + private case object LCheck extends PartialChecker { type Env = LEnv diff --git a/src/main/scala/passes/LowerForLoops.scala b/src/main/scala/passes/LowerForLoops.scala index 065595d5..bd5ae242 100644 --- a/src/main/scala/passes/LowerForLoops.scala +++ b/src/main/scala/passes/LowerForLoops.scala @@ -30,10 +30,10 @@ object LowerForLoops extends PartialTransformer { def myRewriteC: PF[(Command, Env), (Command, Env)] = { case (cfor @ CFor(range, pipeline, par, combine), env) => { - if (pipeline) throw NotImplemented("Lowering pipelined for loops.") + if pipeline then throw NotImplemented("Lowering pipelined for loops.") val CRange(it, typ, rev, s, e, u) = range - if (u != 1) throw NotImplemented("Lowering unrolled for loops.") + if u != 1 then throw NotImplemented("Lowering unrolled for loops.") // Generate a let bound variable sequenced with a while loop that // updates the iterator value. @@ -41,7 +41,7 @@ object LowerForLoops extends PartialTransformer { itVar.typ = typ // Refuse lowering without explicit type on iterator. - if (typ.isDefined == false) { + if typ.isDefined == false then { throw NotImplemented( "Cannot lower `for` loop without iterator type. Add explicit type for the iterator", it.pos @@ -50,8 +50,8 @@ object LowerForLoops extends PartialTransformer { val t = typ.get val init = - CLet(it, typ, Some(ECast(if (rev) EInt(e - 1) else EInt(s), t))).withPos(range) - val op = if (rev) { + CLet(it, typ, Some(ECast(if rev then EInt(e - 1) else EInt(s), t))).withPos(range) + val op = if rev then { NumOp("-", OpConstructor.sub) } else { NumOp("+", OpConstructor.add) @@ -59,7 +59,7 @@ object LowerForLoops extends PartialTransformer { val upd = CUpdate(itVar.copy(), EBinop(op, itVar.copy(), ECast(EInt(1), t))).withPos(range) val cond = - if (rev) { + if rev then { EBinop(CmpOp(">="), itVar.copy(), ECast(EInt(s), t)) } else { EBinop(CmpOp("<="), itVar.copy(), ECast(EInt(e - 1), t)) diff --git a/src/main/scala/passes/LowerUnroll.scala b/src/main/scala/passes/LowerUnroll.scala index fbe361a4..e08f6555 100644 --- a/src/main/scala/passes/LowerUnroll.scala +++ b/src/main/scala/passes/LowerUnroll.scala @@ -123,7 +123,7 @@ object LowerUnroll extends PartialTransformer { */ def fromView(dims: Seq[DimSpec], v: CView): ViewTransformer = { val t = (idxs: Seq[(Expr, Option[Int])]) => { - if (idxs.length != dims.length) { + if idxs.length != dims.length then { throw PassError("LowerUnroll: Incorrect access dimensions") } @@ -248,7 +248,7 @@ object LowerUnroll extends PartialTransformer { case ta: TArray => { ta.dims.foreach({ case (_, b) => - if (b > 1) + if b > 1 then throw Malformed(id.pos, "Banked `decl` cannot be lowered") }) env @@ -296,15 +296,15 @@ object LowerUnroll extends PartialTransformer { * Read https://github.com/cucapra/dahlia/issues/311 for details. */ private def mergePar(cmds: Seq[Command]): Command = { - if (cmds.isEmpty) { + if cmds.isEmpty then { CEmpty - } else if (cmds.length == 1) { + } else if cmds.length == 1 then { cmds(0) } // [{ a0 -- b0 -- ...}, {a1 -- b1 -- ..}] // => // { merge([a0, a1]) -- merge([b0, b1]) } - else if (cmds.forall(_.isInstanceOf[CSeq])) { + else if cmds.forall(_.isInstanceOf[CSeq]) then { CSeq.smart( cmds.collect({ case CSeq(cs) => cs }).transpose.map(mergePar(_)) ) @@ -312,7 +312,7 @@ object LowerUnroll extends PartialTransformer { // [for (r) { b0 } combine { c0 }, for (r) { b1 } combine { c1 }, ...] // => // for (r) { merge([b0, b1, ...]) } combine { merge(c0, c1, ...) } - else if (cmds.forall(_.isInstanceOf[CFor])) { + else if cmds.forall(_.isInstanceOf[CFor]) then { val fors = cmds.map[CFor](_.asInstanceOf[CFor]) val merged = fors .groupBy(f => f.range) @@ -332,7 +332,7 @@ object LowerUnroll extends PartialTransformer { // [while (c) { b0 }, while (c) { b1 }, ...] // => // while (c) { merge([b0, b1]) } - else if (cmds.forall(_.isInstanceOf[CWhile])) { + else if cmds.forall(_.isInstanceOf[CWhile]) then { val whiles = cmds.map[CWhile](_.asInstanceOf[CWhile]) val merged = whiles .groupBy(w => w.cond) @@ -347,7 +347,7 @@ object LowerUnroll extends PartialTransformer { // [if (c) { t0 } else { f0 }, if (c) { t1 } else { f1 }, ...] // => // if (c) { merge([t0, t1]) } else { merge([f0, f1]) } - else if (cmds.forall(_.isInstanceOf[CIf])) { + else if cmds.forall(_.isInstanceOf[CIf]) then { val ifs = cmds.map[CIf](_.asInstanceOf[CIf]) val merged = ifs .groupBy(i => i.cond) @@ -366,7 +366,7 @@ object LowerUnroll extends PartialTransformer { // [ {a0; b0, ...}, {a1; b1, ...} ] // => // merge([a0, a1, ...]); merge([b0, b1, ...]) ... - else if (cmds.forall(_.isInstanceOf[CPar])) { + else if cmds.forall(_.isInstanceOf[CPar]) then { CPar( cmds.collect({ case CPar(cs) => cs }).transpose.map(mergePar(_)) ) @@ -374,13 +374,13 @@ object LowerUnroll extends PartialTransformer { // [ { b0 }, { b1 } ...] // => // { merge([b0, b1 ...]) } - else if (cmds.forall(_.isInstanceOf[CBlock])) { + else if cmds.forall(_.isInstanceOf[CBlock]) then { CBlock(mergePar(cmds.collect({ case CBlock(cmd) => cmd }))) } // [ r0 += l1, r0 += l2, r1 += l3, r1 += l4 ... ] // => // [ r1 += l1 + l2 + ...; r1 += l3 + l4 ] - else if (cmds.forall(_.isInstanceOf[CReduce])) { + else if cmds.forall(_.isInstanceOf[CReduce]) then { val creds = cmds.collect[CReduce]({ case c: CReduce => c }) val merged = creds .groupBy(c => c.lhs) @@ -429,7 +429,7 @@ object LowerUnroll extends PartialTransformer { // If we got exactly one value in TVal, that means that the returned // expression corresponds exactly to the input bank. In this case, // don't generate a condition. - if (allExprs.size == 1) { + if allExprs.size == 1 then { val elems = allExprs.toArray val elem = elems(0)._2 val (nE, nEnv) = rewriteE(elem)(env) @@ -460,7 +460,7 @@ object LowerUnroll extends PartialTransformer { // Transform reads from memories case (c @ CLet(bind, typ, Some(EArrAccess(arrId, idxs))), env) => { val transformer = env.viewGet(arrId) - if (transformer.isDefined && !transformer.get.isDecl) { + if transformer.isDefined && !transformer.get.isDecl then { val t = transformer.get val allExprs = t(idxs.zip(getBanks(arrId, idxs)(env))) @@ -505,7 +505,7 @@ object LowerUnroll extends PartialTransformer { } // Handle case for initialized, unbanked memories. case (c @ CLet(_, Some(ta: TArray), _), env) => { - if (ta.dims.exists({ case (_, bank) => bank > 1 })) { + if ta.dims.exists({ case (_, bank) => bank > 1 }) then { throw NotImplemented("Banked local arrays with initial values") } c -> env @@ -516,7 +516,7 @@ object LowerUnroll extends PartialTransformer { // Don't rewrite this name if there is already a binding in // rewrite map. val rewriteVal = env.rewriteGet(id) - val (cmd, nEnv) = if (rewriteVal.isDefined) { + val (cmd, nEnv) = if rewriteVal.isDefined then { c.copy(e = nInit).withPos(c) -> env } else { val suf = env.idxMap.toList.sortBy(_._1.v).map(_._2).mkString("_") @@ -543,7 +543,7 @@ object LowerUnroll extends PartialTransformer { (CEmpty, nEnv) } case (c @ CFor(range, _, par, combine), env) => { - if (range.u == 1) { + if range.u == 1 then { val ((nPar, nComb), _) = env.withScopeAndRet(env => { val (p, e1) = rewriteC(par)(env) val (c, _) = rewriteC(combine)(e1) @@ -573,7 +573,7 @@ object LowerUnroll extends PartialTransformer { l match { case EArrAccess(id, idxs) => { val transformer = env.viewGet(id) - if (transformer.isDefined && !transformer.get.isDecl) { + if transformer.isDefined && !transformer.get.isDecl then { val t = transformer.get val allExprs = t(idxs.zip(getBanks(id, idxs)(env))) @@ -599,7 +599,7 @@ object LowerUnroll extends PartialTransformer { val nIdxs = idxs.map(rewriteE(_)(env)._1) val nCmd = c.copy(lhs = e.copy(idxs = nIdxs)) val transformer = env.viewGet(id) - if (transformer.isDefined && !transformer.get.isDecl) { + if transformer.isDefined && !transformer.get.isDecl then { val t = transformer.get val allExprs = t(idxs.zip(getBanks(id, idxs)(env))) @@ -646,14 +646,14 @@ object LowerUnroll extends PartialTransformer { case (e @ EVar(id), env) => { val varRewrite = env.rewriteGet(id) val arrRewrite = env.viewGet(id) - if (varRewrite.isDefined) { + if varRewrite.isDefined then { EVar(varRewrite.get) -> env - } else if (arrRewrite.isDefined && !arrRewrite.get.isDecl) { + } else if arrRewrite.isDefined && !arrRewrite.get.isDecl then { val TArray(_, dims, _) = env.dimsGet(id) // Construct a fake access expression val t = arrRewrite.get val map = t(dims.map(_ => EInt(0) -> None)) - if (map.size != 1) { + if map.size != 1 then { throw Impossible(s"Memory parameter is banked: $id.", e.pos) } val List((_, acc)) = map.toList diff --git a/src/main/scala/passes/RewriteView.scala b/src/main/scala/passes/RewriteView.scala index 712603ce..dfe1a4fb 100644 --- a/src/main/scala/passes/RewriteView.scala +++ b/src/main/scala/passes/RewriteView.scala @@ -26,7 +26,7 @@ object RewriteView extends TypedPartialTransformer { extends ScopeManager[ViewEnv] with Tracker[Id, Seq[Expr] => Expr, ViewEnv] { def merge(that: ViewEnv) = { - if (this.map.keys != that.map.keys) + if this.map.keys != that.map.keys then throw Impossible("Tried to merge ViewEnvs with different keys.") this } @@ -62,7 +62,7 @@ object RewriteView extends TypedPartialTransformer { // Rewrite the indexing expressions val (nIdxs, nEnv) = super.rewriteESeq(idxs)(env) val rewrite = nEnv.get(arrId) - if (rewrite.isDefined) { + if rewrite.isDefined then { rewriteE((rewrite.get)(nIdxs.toSeq))(nEnv) } else { acc.copy(idxs = nIdxs.toSeq) -> nEnv @@ -76,7 +76,7 @@ object RewriteView extends TypedPartialTransformer { (bank, nIdx) -> env1 }: ((Int, Expr), Env) => ((Int, Expr), Env))(bankIdxs)(env) - if (nEnv.get(arrId).isDefined) { + if nEnv.get(arrId).isDefined then { throw NotImplemented("Rewriting physical accesses on views.") } acc.copy(bankIdxs = nBankIdxs.toSeq) -> nEnv diff --git a/src/main/scala/passes/Sequentialize.scala b/src/main/scala/passes/Sequentialize.scala index 8158e2c3..0386d527 100644 --- a/src/main/scala/passes/Sequentialize.scala +++ b/src/main/scala/passes/Sequentialize.scala @@ -47,11 +47,11 @@ object Sequentialize extends PartialTransformer { override def rewriteLVal(e: Expr)(implicit env: SeqEnv): (Expr, SeqEnv) = e match { case EVar(id) => { - val env1 = if (env.useLHS) env.addUse(id) else env + val env1 = if env.useLHS then env.addUse(id) else env e -> env1.addDefine(id) } case e @ EArrAccess(id, idxs) => { - val env1 = if (env.useLHS) env.addUse(id) else env + val env1 = if env.useLHS then env.addUse(id) else env val (nIdxs, e1) = rewriteESeq(idxs)(env1) e.copy(idxs = nIdxs.toSeq) -> e1.addDefine(id) } @@ -83,7 +83,7 @@ object Sequentialize extends PartialTransformer { var curUses: SetM[Id] = SetM() val newSeq: Buffer[Buffer[Command]] = Buffer(Buffer()) - for (cmd <- cmds) { + for cmd <- cmds do { val (nCmd, e1) = rewriteC(cmd)(emptyEnv) /* System.err.println(Pretty.emitCmd(cmd)(false).pretty) System.err.println(s""" @@ -98,8 +98,8 @@ object Sequentialize extends PartialTransformer { """) */ // If there are no conflicts, add this to the current parallel // block. - if (curDefines.intersect(e1.uses).isEmpty && - curUses.intersect(e1.defines).isEmpty) { + if curDefines.intersect(e1.uses).isEmpty && + curUses.intersect(e1.defines).isEmpty then { newSeq.last += nCmd } else { curUses = SetM() diff --git a/src/main/scala/passes/WellFormedCheck.scala b/src/main/scala/passes/WellFormedCheck.scala index ca0d2aaf..7e4ae35d 100644 --- a/src/main/scala/passes/WellFormedCheck.scala +++ b/src/main/scala/passes/WellFormedCheck.scala @@ -34,7 +34,7 @@ object WellFormedChecker { def canHaveFunctionInUnroll(k: Id): Boolean = { this.get(k) match { case Some(FuncDef(_, args, _, _)) => - if (this.insideUnroll) { + if this.insideUnroll then { args.foldLeft(true)({ (r, arg) => arg.typ match { case TArray(_, _, _) => false @@ -48,7 +48,7 @@ object WellFormedChecker { } } - private final case object WFCheck extends PartialChecker { + private case object WFCheck extends PartialChecker { type Env = WFEnv val emptyEnv = WFEnv() diff --git a/src/main/scala/typechecker/AffineCheck.scala b/src/main/scala/typechecker/AffineCheck.scala index b8e8b0b0..7d910aa5 100644 --- a/src/main/scala/typechecker/AffineCheck.scala +++ b/src/main/scala/typechecker/AffineCheck.scala @@ -76,7 +76,7 @@ object AffineChecker { def check(p: Prog) = AffineChecker.check(p) - private final case object AffineChecker extends PartialChecker { + private case object AffineChecker extends PartialChecker { type Env = AffineEnv.Environment @@ -138,7 +138,7 @@ object AffineChecker { idx.typ.get match { // Index is an index type. case TIndex((s, e), _) => - if ((e - s) % dims(dim)._2 != 0) + if (e - s) % dims(dim)._2 != 0 then throw BankUnrollInvalid(arrId, dims(dim)._2, e - s)(idx.pos) else (bres * (e - s), Range(s, e) +: consume) @@ -147,7 +147,7 @@ object AffineChecker { (bres * 1, Seq((v % dims(dim)._2).toInt) +: consume) // Index is a dynamic number. case _: TSizedInt => - if (dims(dim)._2 != 1) + if dims(dim)._2 != 1 then throw InvalidDynamicIndex(arrId, dims(dim)._2) else (bres * 1, Seq(0) +: consume) @@ -182,12 +182,12 @@ object AffineChecker { override def checkLVal(e: Expr)(implicit env: Env) = e match { case acc @ EArrAccess(id, idxs) => { // This only triggers for l-values. - val TArray(_, dims, _) = id.typ.get + val TArray(_, dims, _) = id.typ.get : @unchecked acc.consumable match { case Some(Annotations.ShouldConsume) => { val (bres, consumeList) = getConsumeList(idxs, dims)(id) // Check if the accessors generated enough copies for the context. - if (bres != env.getResources) + if bres != env.getResources then throw InsufficientResourcesInUnrollContext( env.getResources, bres, @@ -245,14 +245,14 @@ object AffineChecker { case (CView(id, arrId, _), env) => { // Add gadget for the view and add missing well formedness checks // from new type checker - val TArray(_, adims, _) = arrId.typ.get - val TArray(_, vdims, _) = id.typ.get + val TArray(_, adims, _) = arrId.typ.get : @unchecked + val TArray(_, vdims, _) = id.typ.get : @unchecked val shrinks = vdims.map(_._2) env.add(id, viewGadget(env(arrId), shrinks, adims)) } case (CSplit(id, arrId, _), env) => { - val TArray(_, adims, _) = arrId.typ.get - val TArray(_, vdims, _) = id.typ.get + val TArray(_, adims, _) = arrId.typ.get : @unchecked + val TArray(_, vdims, _) = id.typ.get : @unchecked env.add(id, splitGadget(env(arrId), adims, vdims)) } } @@ -277,7 +277,7 @@ object AffineChecker { }) } case (expr @ EArrAccess(id, idxs), env) => { - val TArray(_, dims, _) = id.typ.get + val TArray(_, dims, _) = id.typ.get : @unchecked expr.consumable match { case None => throw Impossible( diff --git a/src/main/scala/typechecker/AffineEnv.scala b/src/main/scala/typechecker/AffineEnv.scala index 3d360bd4..d73fd3ae 100644 --- a/src/main/scala/typechecker/AffineEnv.scala +++ b/src/main/scala/typechecker/AffineEnv.scala @@ -118,7 +118,7 @@ object AffineEnv { override def toString = { val lst = - for { (ps, gs) <- phyRes.iterator.zip(gadgetMap.iterator) } yield ( + for (ps, gs) <- phyRes.iterator.zip(gadgetMap.iterator) yield ( ps.map({ case (k, v) => s"$k -> $v" }).mkString(", "), gs.keys.mkString("{", ", ", "}") ) @@ -161,7 +161,7 @@ object AffineEnv { val (oldGads, nextGads) = (this.gadgetMap.keys, next.gadgetMap.keys) // The next environment should bind all resources in this env. - if (oldRes.subsetOf(nextRes) == false) { + if oldRes.subsetOf(nextRes) == false then { throw Impossible( "New environment is missing resources bound in old env." + s"\n\nOld Env: ${oldRes}" + @@ -169,7 +169,7 @@ object AffineEnv { s"\n\nMissing: ${oldRes diff nextRes}" ) } - if (oldRes.subsetOf(nextRes) == false) { + if oldRes.subsetOf(nextRes) == false then { throw Impossible( "New environment is missing gadgets bound in old env." + s"\n\nOld Env: ${oldGads}" + @@ -195,10 +195,10 @@ object AffineEnv { Env(phyRes.addScope, gadgetMap.addScope)(res * resources) } def endScope(resources: Int) = { - val scopes = for { + val scopes = for (pDefs, pMap) <- phyRes.endScope (gDefs, gMap) <- gadgetMap.endScope - } yield (Env(pMap, gMap)(res / resources), pDefs, gDefs) + yield (Env(pMap, gMap)(res / resources), pDefs, gDefs) scopes.getOrThrow(Impossible("Removed topmost scope")) } diff --git a/src/main/scala/typechecker/CapabilityChecker.scala b/src/main/scala/typechecker/CapabilityChecker.scala index 4e710cad..28d58fb7 100644 --- a/src/main/scala/typechecker/CapabilityChecker.scala +++ b/src/main/scala/typechecker/CapabilityChecker.scala @@ -10,11 +10,11 @@ import CompilerError._ import CapabilityEnv._ import Checker._ -object CapabilityChecker { +object CapabilityChecker: def check(p: Prog) = CapChecker.check(p) - private final case object CapChecker extends PartialChecker { + private case object CapChecker extends PartialChecker: type Env = CapabilityEnv @@ -72,6 +72,3 @@ object CapabilityChecker { mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - - } -} diff --git a/src/main/scala/typechecker/CapabilityEnv.scala b/src/main/scala/typechecker/CapabilityEnv.scala index 756eccef..b2fc841b 100644 --- a/src/main/scala/typechecker/CapabilityEnv.scala +++ b/src/main/scala/typechecker/CapabilityEnv.scala @@ -21,8 +21,8 @@ object CapabilityEnv { ) extends CapabilityEnv { def get(e: Expr) = - if (readSet.contains(e)) Some(Read) - else if (writeSet.contains(e)) Some(Write) + if readSet.contains(e) then Some(Read) + else if writeSet.contains(e) then Some(Write) else None def add(e: Expr, cap: Capability) = cap match { @@ -31,10 +31,10 @@ object CapabilityEnv { } def endScope = { - val scopes = for { + val scopes = for (_, rSet) <- readSet.endScope; (_, wSet) <- writeSet.endScope - } yield this.copy(readSet = rSet, writeSet = wSet) + yield this.copy(readSet = rSet, writeSet = wSet) scopes.getOrThrow(Impossible("Removed topmost scope")) } diff --git a/src/main/scala/typechecker/Gadget.scala b/src/main/scala/typechecker/Gadget.scala index 292d44b2..c81df434 100644 --- a/src/main/scala/typechecker/Gadget.scala +++ b/src/main/scala/typechecker/Gadget.scala @@ -30,13 +30,13 @@ object Gadgets { case class ResourceGadget(resource: Id, banks: Seq[Int]) extends Gadget { private def cross[A](acc: Seq[Seq[A]], l: Seq[A]): Seq[Seq[A]] = { - for { a <- acc; el <- l } yield a :+ el + for a <- acc; el <- l yield a :+ el } override def toString = resource.toString private def hyperBankToBank(hyperBanks: Seq[Int]) = { - if (hyperBanks.length != banks.length) + if hyperBanks.length != banks.length then throw Impossible("hyperbank size is different from original banking") hyperBanks diff --git a/src/main/scala/typechecker/Info.scala b/src/main/scala/typechecker/Info.scala index 6ee5fc91..bbaa30df 100644 --- a/src/main/scala/typechecker/Info.scala +++ b/src/main/scala/typechecker/Info.scala @@ -44,7 +44,7 @@ object Info { val resourceMS = fromSeq(resources) val afterConsume = remBanks.diff(resourceMS) val hasRequired = afterConsume.forall({ case (_, v) => v >= 0 }) - if (hasRequired == false) { + if hasRequired == false then { val bank = afterConsume.find({ case (_, v) => v < 0 }).get._1 throw AlreadyConsumed( id, @@ -68,7 +68,7 @@ object Info { object ArrayInfo { private def cross[A](acc: Seq[Seq[A]], l: Seq[A]): Seq[Seq[A]] = - for { a <- acc; el <- l } yield a :+ el + for a <- acc; el <- l yield a :+ el private def hyperBankToBank(maxBanks: Iterable[Int])(hyperBank: Seq[Int]) = hyperBank diff --git a/src/main/scala/typechecker/Subtyping.scala b/src/main/scala/typechecker/Subtyping.scala index 54782a04..6f0b1f35 100644 --- a/src/main/scala/typechecker/Subtyping.scala +++ b/src/main/scala/typechecker/Subtyping.scala @@ -88,12 +88,12 @@ object Subtyping { case Some(fun) => Some(TRational(fun(v1.toDouble, v2.toDouble).toString)) case None => - if (bitsNeeded(v1.toDouble.toInt) > bitsNeeded(v2.toDouble.toInt)) + if bitsNeeded(v1.toDouble.toInt) > bitsNeeded(v2.toDouble.toInt) then Some(TRational(v1)) else Some(TRational(v2)) } case (TSizedInt(s1, un1), TSizedInt(s2, un2)) => - if (un1 == un2) Some(TSizedInt(max(s1, s2), un1)) + if un1 == un2 then Some(TSizedInt(max(s1, s2), un1)) else None case (TSizedInt(s, un), TStaticInt(v)) => Some(TSizedInt(max(s, bitsNeeded(v)), un)) @@ -113,14 +113,14 @@ object Subtyping { ) ) case (TFixed(t1, i1, un1), TFixed(t2, i2, un2)) => - if (un1 == un2) + if un1 == un2 then Some(TFixed(max(t1 - i1, t2 - i2) + max(i1, i2), max(i1, i2), un1)) else None case (ti1: TIndex, ti2: TIndex) => Some( TSizedInt(max(bitsNeeded(ti1.maxVal), bitsNeeded(ti2.maxVal)), false) ) - case (t1, t2) => if (t1 == t2) Some(t1) else None + case (t1, t2) => if t1 == t2 then Some(t1) else None } /** @@ -128,7 +128,7 @@ object Subtyping { */ def joinOf(t1: Type, t2: Type, op: BOp): Option[Type] = { val j1 = joinOfHelper(t1, t2, op) - if (j1.isDefined) j1 + if j1.isDefined then j1 else joinOfHelper(t2, t1, op) } diff --git a/src/main/scala/typechecker/TypeCheck.scala b/src/main/scala/typechecker/TypeCheck.scala index 32e4c619..a8b8d9c1 100644 --- a/src/main/scala/typechecker/TypeCheck.scala +++ b/src/main/scala/typechecker/TypeCheck.scala @@ -65,9 +65,9 @@ object TypeChecker { private def checkB(t1: Type, t2: Type, op: BOp) = op match { case _: EqOp => { - if (t1.isInstanceOf[TArray]) + if t1.isInstanceOf[TArray] then throw UnexpectedType(op.pos, op.toString, "primitive types", t1) - else if (joinOf(t1, t2, op).isDefined) + else if joinOf(t1, t2, op).isDefined then TBool() else throw NoJoin(op.pos, op.toString, t1, t2) @@ -118,7 +118,7 @@ object TypeChecker { e: Expr )(implicit env: Environment): (Type, Environment) = { val (typ, nEnv) = _checkE(e) - if (e.typ.isDefined && typ != e.typ.get) { + if e.typ.isDefined && typ != e.typ.get then { throw Impossible( s"$e was type checked multiple times and given different types." ) @@ -139,7 +139,7 @@ object TypeChecker { case EArrLiteral(_) => throw NotInBinder(expr.pos, "Array Literal") case ECast(e, castType) => { val (typ, nEnv) = checkE(e) - if (safeCast(typ, castType) == false) { + if safeCast(typ, castType) == false then { scribe.warn { (s"Casting $typ to $castType which may lose precision.", expr) } @@ -159,7 +159,7 @@ object TypeChecker { case EApp(f, args) => env(f) match { case TFun(argTypes, retType) => { - if (argTypes.length != args.length) { + if argTypes.length != args.length then { throw ArgLengthMismatch(expr.pos, argTypes.length, args.length) } @@ -168,7 +168,7 @@ object TypeChecker { .foldLeft(env)({ case (e, (arg, expectedTyp)) => { val (typ, e1) = checkE(arg)(e); - if (isSubtype(typ, expectedTyp) == false) { + if isSubtype(typ, expectedTyp) == false then { throw UnexpectedSubtype( arg.pos, "parameter", @@ -195,7 +195,7 @@ object TypeChecker { case EArrAccess(id, idxs) => env(id).matchOrError(expr.pos, "array access", s"array type") { case TArray(typ, dims, _) => { - if (dims.length != idxs.length) { + if dims.length != idxs.length then { throw IncorrectAccessDims(id, dims.length, idxs.length) } idxs.foldLeft(env)((env, idx) => { @@ -220,7 +220,7 @@ object TypeChecker { case EPhysAccess(id, bankIdxs) => env(id).matchOrError(expr.pos, "array access", s"array type") { case TArray(typ, dims, _) => { - if (dims.length != bankIdxs.length) { + if dims.length != bankIdxs.length then { throw IncorrectAccessDims(id, dims.length, bankIdxs.length) } bankIdxs.foldLeft(env)((env, bankIdx) => { @@ -248,7 +248,7 @@ object TypeChecker { val (len, bank) = arrDim // Shrinking factor must be a factor of banking for the dimension - if (shrink.isDefined && (shrink.get > bank || bank % shrink.get != 0)) { + if shrink.isDefined && (shrink.get > bank || bank % shrink.get != 0) then { throw InvalidShrinkWidth(view.pos, bank, shrink.get) } @@ -257,12 +257,12 @@ object TypeChecker { // Get the indexing expression val idx = suf match { case Aligned(fac, idx) => - if (newBank > fac) { + if newBank > fac then { throw InvalidAlignFactor( suf.pos, s"Invalid align factor. Banking factor $newBank is bigger than alignment factor $fac." ) - } else if (fac % newBank != 0) { + } else if fac % newBank != 0 then { throw InvalidAlignFactor( suf.pos, s"Invalid align factor. Banking factor $newBank not a factor of the alignment factor $fac." @@ -285,7 +285,7 @@ object TypeChecker { // Only loops without sequencing may be pipelined. body match { case _: CSeq => - if (enabled) { + if enabled then { throw PipelineError(loop.pos) } case _ => {} @@ -310,7 +310,7 @@ object TypeChecker { case CWhile(cond, pipeline, body) => { checkPipeline(pipeline, cmd, body) val (cTyp, e1) = checkE(cond)(env) - if (cTyp != TBool()) { + if cTyp != TBool() then { throw UnexpectedType( cond.pos, "while condition", @@ -323,14 +323,14 @@ object TypeChecker { case CUpdate(lhs, rhs) => { val (t1, e1) = checkE(lhs) val (t2, e2) = checkE(rhs)(e1) - if (isSubtype(t2, t1)) e2 + if isSubtype(t2, t1) then e2 else throw UnexpectedSubtype(rhs.pos, "assignment", t1, t2) } case CReduce(_, l, r) => { val (t1, e1) = checkE(l) val (t2, e2) = checkE(r)(e1) - if (isSubtype(t2, t1)) e2 + if isSubtype(t2, t1) then e2 else throw UnexpectedSubtype(r.pos, "reduction operator", t1, t2) } case l @ CLet(id, typ, Some(EArrLiteral(idxs))) => { @@ -383,10 +383,10 @@ object TypeChecker { // required fields. expTypes.keys.foreach(field => { val (eTyp, acTyp) = (expTypes(field), actualTypes.get(field)) - if (acTyp.isDefined == false) { + if acTyp.isDefined == false then { throw MissingField(exp.pos, name, field) } - if (isSubtype(acTyp.get, eTyp) == false) { + if isSubtype(acTyp.get, eTyp) == false then { throw UnexpectedType( fs(field).pos, "record literal", @@ -422,7 +422,7 @@ object TypeChecker { // Check if type of expression is a subtype of the annotated type. rTyp match { case Some(t2) => { - if (isSubtype(t, t2)) + if isSubtype(t, t2) then e1.add(id, t2) else throw UnexpectedSubtype(exp.pos, "let", t2, t) @@ -464,7 +464,7 @@ object TypeChecker { env(arrId) match { case TArray(typ, adims, port) => { val (vlen, alen) = (vdims.length, adims.length) - if (vlen != alen) { + if vlen != alen then { throw IncorrectAccessDims(arrId, alen, vlen) } @@ -491,7 +491,7 @@ object TypeChecker { env(arrId) match { case TArray(typ, adims, ports) => { val (vlen, alen) = (dims.length, adims.length) - if (vlen != alen) { + if vlen != alen then { throw IncorrectAccessDims(arrId, alen, vlen) } @@ -507,7 +507,7 @@ object TypeChecker { .zip(dims) .flatMap({ case ((dim, bank), n) if n > 0 => { - if (bank % n == 0) { + if bank % n == 0 then { List((n, n), (dim / n, bank / n)) } else { throw InvalidSplitFactor(id, arrId, n, bank, dim) @@ -530,7 +530,7 @@ object TypeChecker { case CReturn(expr) => { val retType = env.getReturn.get val (t, e) = checkE(expr) - if (isSubtype(t, retType) == false) { + if isSubtype(t, retType) == false then { throw UnexpectedSubtype(expr.pos, "return", retType, t) } e diff --git a/src/main/scala/typechecker/TypeEnv.scala b/src/main/scala/typechecker/TypeEnv.scala index 675d8fb3..8f77e79f 100644 --- a/src/main/scala/typechecker/TypeEnv.scala +++ b/src/main/scala/typechecker/TypeEnv.scala @@ -114,9 +114,9 @@ object TypeEnv { Env(typeMap.addScope, typeDefMap, retType) } def endScope = { - val scopes = for { + val scopes = for (_, tMap) <- typeMap.endScope - } yield Env(tMap, typeDefMap, retType) + yield Env(tMap, typeDefMap, retType) scopes.getOrThrow(Impossible("Removed topmost scope")) } diff --git a/src/project/build.properties b/src/project/build.properties index c8fcab54..abbbce5d 100644 --- a/src/project/build.properties +++ b/src/project/build.properties @@ -1 +1 @@ -sbt.version=1.6.2 +sbt.version=1.9.8 diff --git a/src/test/scala/ParsingPositive.scala b/src/test/scala/ParsingPositive.scala index ca8707fc..6d441e5d 100644 --- a/src/test/scala/ParsingPositive.scala +++ b/src/test/scala/ParsingPositive.scala @@ -1,8 +1,9 @@ package fuselang import TestUtils._ +import org.scalatest.funsuite.AnyFunSuite -class ParsingTests extends org.scalatest.FunSuite { +class ParsingTests extends AnyFunSuite { test("numbers") { parseAst("1;") parseAst("1.25;") diff --git a/src/test/scala/TypeCheckerSpec.scala b/src/test/scala/TypeCheckerSpec.scala index 0a650c32..196754f1 100644 --- a/src/test/scala/TypeCheckerSpec.scala +++ b/src/test/scala/TypeCheckerSpec.scala @@ -3,9 +3,9 @@ package fuselang import fuselang.common._ import TestUtils._ import Errors._ -import org.scalatest.FunSpec +import org.scalatest.funspec.AnyFunSpec -class TypeCheckerSpec extends FunSpec { +class TypeCheckerSpec extends AnyFunSpec { // Suppress logging. common.Logger.setLogLevel(scribe.Level.Error)