diff --git a/build.sbt b/build.sbt index 775edf3e..99cf6f3d 100644 --- a/build.sbt +++ b/build.sbt @@ -49,9 +49,6 @@ Compile / resourceGenerators += Def.task { } /* sbt-assembly configuration: build an executable jar. */ -//assembly / assemblyOption := (assembly / assemblyOption).value.copy( -// prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript) -//) ThisBuild / assemblyPrependShellScript := Some(sbtassembly.AssemblyPlugin.defaultShellScript) assembly / assemblyJarName := "fuse.jar" assembly / test := {} diff --git a/src/main/scala/Compiler.scala b/src/main/scala/Compiler.scala index 8a8babed..9c049101 100644 --- a/src/main/scala/Compiler.scala +++ b/src/main/scala/Compiler.scala @@ -9,7 +9,7 @@ import Configuration._ import Syntax._ import Transformer.{PartialTransformer, TypedPartialTransformer} -object Compiler { +object Compiler: // Transformers to execute *before* type checking. val preTransformers: List[(String, PartialTransformer)] = List( @@ -28,28 +28,25 @@ object Compiler { "Add bitwidth" -> (passes.AddBitWidth, true) ) - def showDebug(ast: Prog, pass: String, c: Config): Unit = { - if c.passDebug then { + def showDebug(ast: Prog, pass: String, c: Config): Unit = + if c.passDebug then val top = ("=" * 15) + pass + ("=" * 15) println(top) println(Pretty.emitProg(ast)(c.logLevel == scribe.Level.Debug).trim) println("=" * top.length) - } - } - def toBackend(str: BackendOption): fuselang.backend.Backend = str match { + def toBackend(str: BackendOption): fuselang.backend.Backend = str match case Vivado => backend.VivadoBackend case Cpp => backend.CppRunnable case Calyx => backend.calyx.CalyxBackend - } - def checkStringWithError(prog: String, c: Config = emptyConf) = { + def checkStringWithError(prog: String, c: Config = emptyConf) = val preAst = Parser(prog).parse() showDebug(preAst, "Original", c) // Run pre transformers if lowering is enabled - val ast = if c.enableLowering then { + val ast = if c.enableLowering then preTransformers.foldLeft(preAst)({ case (ast, (name, pass)) => { val newAst = pass.rewrite(ast) @@ -70,9 +67,8 @@ object Compiler { } */ } }) - } else { + else preAst - } passes.WellFormedChecker.check(ast) typechecker.TypeChecker.typeCheck(ast); showDebug(ast, "Type Checking", c) @@ -83,9 +79,8 @@ object Compiler { showDebug(ast, "Capability Checking", c) typechecker.AffineChecker.check(ast); // Doesn't modify the AST ast - } - def codegen(ast: Prog, c: Config = emptyConf) = { + def codegen(ast: Prog, c: Config = emptyConf) = // Filter out transformers not running in this mode val toRun = postTransformers.filter({ case (_, (_, onlyLower)) => { @@ -101,14 +96,12 @@ object Compiler { } }) toBackend(c.backend).emit(transformedAst, c) - } // Outputs red text to the console - def red(txt: String): String = { + def red(txt: String): String = Console.RED + txt + Console.RESET - } - def compileString(prog: String, c: Config): Either[String, String] = { + def compileString(prog: String, c: Config): Either[String, String] = Try(codegen(checkStringWithError(prog, c), c)).toEither.left .map(err => { scribe.info(err.getStackTrace().take(10).mkString("\n")) @@ -136,13 +129,12 @@ object Compiler { val commentPre = toBackend(c.backend).commentPrefix s"$commentPre $meta\n" + out }) - } def compileStringToFile( prog: String, c: Config, out: String - ): Either[String, Path] = { + ): Either[String, Path] = compileString(prog, c).map(p => { Files.write( @@ -153,6 +145,4 @@ object Compiler { StandardOpenOption.WRITE ) }) - } -} diff --git a/src/main/scala/GenerateExec.scala b/src/main/scala/GenerateExec.scala index 2c78806c..6e05e9e3 100644 --- a/src/main/scala/GenerateExec.scala +++ b/src/main/scala/GenerateExec.scala @@ -10,7 +10,7 @@ import common.CompilerError.HeaderMissing * Provides utilities to compile a program and link it with headers required * by the CppRunnable backend. */ -object GenerateExec { +object GenerateExec: // TODO(rachit): Move this to build.sbt val headers = List("parser.cpp", "json.hpp") @@ -19,18 +19,18 @@ object GenerateExec { // Not the compiler directory, check if the fallback directory has been setup. - if Files.exists(headerLocation) == false then { + if Files.exists(headerLocation) == false then // Fallback for headers not setup. Unpack headers from JAR file. headerLocation = headerFallbackLocation - if Files.exists(headerFallbackLocation) == false then { + if Files.exists(headerFallbackLocation) == false then scribe.warn( s"Missing headers required for `fuse run`." + s" Unpacking from JAR file into $headerFallbackLocation." ) val dir = Files.createDirectory(headerFallbackLocation) - for header <- headers do { + for header <- headers do val stream = getClass.getResourceAsStream(s"/headers/$header") val hdrSource = Source.fromInputStream(stream).toArray.map(_.toByte) Files.write( @@ -39,9 +39,6 @@ object GenerateExec { StandardOpenOption.CREATE_NEW, StandardOpenOption.WRITE ) - } - } - } /** * Generates an executable object [[out]]. Assumes that [[src]] is a valid @@ -54,14 +51,12 @@ object GenerateExec { src: Path, out: String, compilerOpts: List[String] - ): Either[String, Int] = { + ): Either[String, Int] = // Make sure all headers are downloaded. - for header <- headers do { - if Files.exists(headerLocation.resolve(header)) == false then { + for header <- headers do + if Files.exists(headerLocation.resolve(header)) == false then throw HeaderMissing(header, headerLocation.toString) - } - } val CXX = Seq("g++", "-g", "--std=c++14", "-Wall", "-I", headerLocation.toString) ++ compilerOpts @@ -75,10 +70,7 @@ object GenerateExec { scribe.info(cmd.mkString(" ")) val status = cmd ! logger - if status != 0 then { + if status != 0 then Left(s"Failed to generate the executable $out.\n${stderr}") - } else { + else Right(status) - } - } -} diff --git a/src/main/scala/Main.scala b/src/main/scala/Main.scala index 914a3bb4..3199a142 100644 --- a/src/main/scala/Main.scala +++ b/src/main/scala/Main.scala @@ -31,7 +31,7 @@ object Main: }) .toMap - val parser = new scopt.OptionParser[Config]("fuse") { + val parser = new scopt.OptionParser[Config]("fuse"): head(s"Dahlia (sha = ${meta("git.hash")}, status = ${meta("git.status")})") @@ -112,16 +112,14 @@ object Main: .action((f, c) => c.copy(output = Some(f))) .text("Name of the output artifact.") ) - } def runWithConfig(conf: Config): Either[String, Int] = type ErrString = String val path = conf.srcFile.toPath - val prog = Files.exists(path) match { + val prog = Files.exists(path) match case true => Right(new String(Files.readAllBytes(path))) case false => Left(s"$path: No such file in working directory") - } val cppPath: Either[ErrString, Option[Path]] = prog.flatMap(prog => conf.output match { diff --git a/src/main/scala/Utils.scala b/src/main/scala/Utils.scala index cbd305a9..c3874a8f 100644 --- a/src/main/scala/Utils.scala +++ b/src/main/scala/Utils.scala @@ -2,70 +2,57 @@ package fuselang import scala.{PartialFunction => PF} import scala.math.{log10, ceil} -object Utils { +object Utils: - implicit class RichOption[A](opt: => Option[A]) { - def getOrThrow[T <: Throwable](except: T) = opt match { + implicit class RichOption[A](opt: => Option[A]): + def getOrThrow[T <: Throwable](except: T) = opt match case Some(v) => v case None => throw except - } - } // https://codereview.stackexchange.com/questions/14561/matching-bigints-in-scala // TODO: This can overflow and result in an runtime exception - object Big { + object Big: def unapply(n: BigInt) = Some(n.toInt) - } - def bitsNeeded(n: Int): Int = n match { + def bitsNeeded(n: Int): Int = n match case 0 => 1 case n if n > 0 => ceil(log10(n + 1) / log10(2)).toInt case n if n < 0 => bitsNeeded(n.abs) + 1 - } - def bitsNeeded(n: BigInt): Int = n match { + def bitsNeeded(n: BigInt): Int = n match case Big(0) => 1 case n if n > 0 => ceil(log10((n + 1).toDouble) / log10(2)).toInt case n if n < 0 => bitsNeeded(n.abs) + 1 - } - def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] = { + def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] = def pel(e: T, ll: Seq[Seq[T]], a: Seq[Seq[T]] = Nil): Seq[Seq[T]] = - ll match { + ll match case Nil => a.reverse case x +: xs => pel(e, xs, (e +: x) +: a) - } - llst match { + llst match case Nil => Nil case x +: Nil => x.map(Seq(_)) case x +: _ => - x match { + x match case Nil => Nil case _ => llst .foldRight(Seq(x))((l, a) => l.flatMap(x => pel(x, a))) .map(_.dropRight(x.size)) - } - } - } - @inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] = { + @inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] = case (a, b) => f(a, b) - } - @inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = { + @inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = if !cond then throw except - } @deprecated( "pr is used for debugging. Remove all call to it before committing", "fuse 0.0.1" ) - @inline def pr[T](v: T) = { + @inline def pr[T](v: T) = println(v) v - } -} diff --git a/src/main/scala/backends/Backend.scala b/src/main/scala/backends/Backend.scala index 1d6c8763..a5191d7e 100644 --- a/src/main/scala/backends/Backend.scala +++ b/src/main/scala/backends/Backend.scala @@ -6,14 +6,12 @@ import CompilerError.BackendError /** * Abstract definition of a Fuse backend. */ -trait Backend { +trait Backend: - def emit(p: Syntax.Prog, c: Configuration.Config): String = { - if c.header && (canGenerateHeader == false) then { + def emit(p: Syntax.Prog, c: Configuration.Config): String = + if c.header && (canGenerateHeader == false) then throw BackendError(s"Backend $this does not support header generation.") - } emitProg(p, c) - } /** * Generate a String representation of the Abstract Syntax Tree of the @@ -32,4 +30,3 @@ trait Backend { */ val commentPrefix: String = "//" -} diff --git a/src/main/scala/backends/CppLike.scala b/src/main/scala/backends/CppLike.scala index 7601b1dc..45b183d7 100644 --- a/src/main/scala/backends/CppLike.scala +++ b/src/main/scala/backends/CppLike.scala @@ -6,13 +6,13 @@ import CompilerError._ import PrettyPrint.Doc._ import PrettyPrint.Doc -object Cpp { +object Cpp: /** * A C++ backend that only emits one dimensionals arrays and one dimensional * array accesses. */ - trait CppLike { + trait CppLike: /** * This class aggressively uses Scala's implicitConversions. Make sure @@ -27,17 +27,15 @@ object Cpp { /** * Helper to generate a variable declaration with an initial value. */ - def cBind(id: String, rhs: Doc): Doc = { + def cBind(id: String, rhs: Doc): Doc = text("auto") <+> text(id) <+> text("=") <+> rhs <> semi - } /** * Helper to generate a function call that might have a type parameter */ - def cCall(f: String, tParam: Option[Doc], args: Seq[Doc]): Doc = { + def cCall(f: String, tParam: Option[Doc], args: Seq[Doc]): Doc = text(f) <> (if tParam.isDefined then angles(tParam.get) else emptyDoc) <> parens(commaSep(args)) - } /** * Function used for converting types from Fuse to C++. @@ -82,13 +80,12 @@ object Cpp { implicit def IdToString(id: Id): Doc = value(id.v) - def emitBaseInt(v: BigInt, base: Int): String = base match { + def emitBaseInt(v: BigInt, base: Int): String = base match case 8 => s"0${v.toString(8)}" case 10 => v.toString case 16 => s"0x${v.toString(16)}" - } - implicit def emitExpr(e: Expr): Doc = e match { + implicit def emitExpr(e: Expr): Doc = e match case ECast(e, typ) => parens(emitType(typ)) <> emitExpr(e) case EApp(fn, args) => fn <> parens(commaSep(args.map(emitExpr))) case EInt(v, base) => value(emitBaseInt(v, base)) @@ -103,31 +100,27 @@ object Cpp { throw NotImplemented("Physical access code gen for cpp-like backends.") case ERecAccess(rec, field) => rec <> dot <> field case ERecLiteral(fs) => - scope { + scope: commaSep(fs.toSeq.map({ case (id, expr) => dot <> id <+> equal <+> expr })) - } - } /** * Turns a range object into the parameter of a `for` loop. * (int = ; < ; ++) */ - def emitRange(range: CRange): Doc = parens { + def emitRange(range: CRange): Doc = parens: val CRange(id, _, rev, s, e, _) = range - if rev then { + if rev then text("int") <+> id <+> equal <+> value(e - 1) <> semi <+> id <+> text(">=") <+> value(s) <> semi <+> id <> text("--") - } else { + else text("int") <+> id <+> equal <+> value(s) <> semi <+> id <+> text("<") <+> value(e) <> semi <+> id <> text("++") - } - } - implicit def emitCmd(c: Command): Doc = c match { + implicit def emitCmd(c: Command): Doc = c match case CPar(cmds) => vsep(cmds.map(emitCmd)) case CSeq(cmds) => vsep(cmds.map(emitCmd), text("//---")) case l: CLet => emitLet(l) @@ -149,14 +142,12 @@ object Cpp { case _: CView | _: CSplit => throw Impossible("Views should not exist during codegen.") case CBlock(cmd) => scope(cmd) - } - def emitDecl(id: Id, typ: Type): Doc = typ match { + def emitDecl(id: Id, typ: Type): Doc = typ match case ta: TArray => emitArrayDecl(ta, id) case _ => emitType(typ) <+> id - } - def emitFunc(func: FuncDef, entry: Boolean = false): Doc = func match { + def emitFunc(func: FuncDef, entry: Boolean = false): Doc = func match case func @ FuncDef(id, args, ret, bodyOpt) => val as = commaSep(args.map(decl => emitDecl(decl.id, decl.typ))) // If body is not defined, this is an extern. Elide the definition. @@ -169,9 +160,8 @@ object Cpp { if entry then text("extern") <+> quote(text("C")) <+> scope(body) else body - } - def emitDef(defi: Definition): Doc = defi match { + def emitDef(defi: Definition): Doc = defi match case func: FuncDef => emitFunc(func) case RecordDef(name, fields) => text("typedef struct") <+> scope { @@ -179,11 +169,8 @@ object Cpp { case (id, typ) => emitType(typ) <+> id <> semi })) } <+> name <> semi - } def emitInclude(incl: String): Doc = text("#include") <+> quote(text(incl)) - } -} diff --git a/src/main/scala/backends/CppRunnable.scala b/src/main/scala/backends/CppRunnable.scala index 0fac3ccb..a3ef107d 100644 --- a/src/main/scala/backends/CppRunnable.scala +++ b/src/main/scala/backends/CppRunnable.scala @@ -16,12 +16,12 @@ import fuselang.common.{Configuration => C} * header file for parsing. It also emits `int` instead of `ap_int` so * that the code runnable by gcc. */ -private class CppRunnable extends CppLike { +private class CppRunnable extends CppLike: // Variable to store the results of the updated arrays. val serializer = text("__") - def emitType(typ: Type): Doc = typ match { + def emitType(typ: Type): Doc = typ match case _: TVoid => text("void") case _: TBool => text("bool") case _: TIndex => text("int") @@ -38,11 +38,10 @@ private class CppRunnable extends CppLike { case _: TFun => throw Impossible("Cannot emit function types") case TAlias(n) => value(n) - } def emitArrayDecl(ta: TArray, id: Id) = emitType(ta) <+> text(s"&$id") - override def emitLet(l: CLet) = l match { + override def emitLet(l: CLet) = l match case CLet(id, Some(TArray(typ, dims, _)), init) => { /* @@ -59,13 +58,12 @@ private class CppRunnable extends CppLike { vector>(5, vector(3, 0))))); */ - val initVal = init match { + val initVal = init match case Some(expr) => emitExpr(expr) case None => parens(value(dims.head._1) <> comma <+> dims.tail.foldRight((text("vector") <> angles(emitType(typ)), value(0)))({ case ((len, _), acc) => (text("vector") <> angles(acc._1), acc._1 <> parens(value(len) <> comma <+> acc._2)) })._2) - } dims.foldLeft(emitType(typ))({ @@ -74,17 +72,14 @@ private class CppRunnable extends CppLike { } case _ => super.emitLet(l) - } def emitFor(cmd: CFor): Doc = - text("for") <> emitRange(cmd.range) <+> scope { - cmd.par <> { - if cmd.combine != CEmpty then - line <> text("// combiner:") <@> cmd.combine - else - emptyDoc - } - } + val scopeContent : Doc = + if cmd.combine != CEmpty then + line <> text("// combiner:") <@> cmd.combine + else + emptyDoc + text("for") <> emitRange(cmd.range) <+> scope(cmd.par <> scopeContent) def emitFuncHeader(func: FuncDef, entry: Boolean = false) = emptyDoc @@ -97,18 +92,18 @@ private class CppRunnable extends CppLike { * * is generated based on the type of the param: */ - def emitParseDecl: Decl => Doc = { + def emitParseDecl: Decl => Doc = case Decl(id, _) => { // Use the type decoration for id since it's guaranteed to be resolved. val typ = id.typ.get - val (typeName, cTyp): (Doc, Doc) = typ match { + val (typeName, cTyp): (Doc, Doc) = typ match case _: TAlias | _: TRecType | _: TBool | _: IntType | _: TFloat | _: TFixed => { val typeName = emitType(typ) (quote(typeName), typeName) } - case arr @ TArray(_, dims, _) => { + case arr @ TArray(_, dims, _) => val typeName = quote( text(s"${arr.typ}${dims.map(_ => "[]").mkString}") ) @@ -116,12 +111,10 @@ private class CppRunnable extends CppLike { text("n_dim_vec_t") <> angles(emitType(arr.typ) <> comma <+> value(dims.length)) (typeName, cType) - } case t => throw NotImplemented( s"Cannot parse type `$t' with CppRunnable backend." ) - } cBind( s"${id}", @@ -129,20 +122,17 @@ private class CppRunnable extends CppLike { ) } - } - def emitSerializeDecl: Decl => Doc = { - case Decl(id, _) => { + def emitSerializeDecl: Decl => Doc = + case Decl(id, _) => serializer <> brackets(quote(id)) <+> text("=") <+> id <> semi - } - } /** * Generates [[from_json]] and [[to_json]] for a given record. Used by the * json library to extract records from json. * See: https://github.com/nlohmann/json#basic-usage */ - private def recordHelpers: RecordDef => Doc = { + private def recordHelpers: RecordDef => Doc = case RecordDef(name, fields) => text("void to_json") <> parens(text(s"nlohmann::json& j, const ${name}& r")) <+> scope { @@ -155,7 +145,7 @@ private class CppRunnable extends CppLike { })) <> semi } <@> text("void from_json") <> - parens(text(s"const nlohmann::json& j, ${name}& r")) <+> scope { + parens(text(s"const nlohmann::json& j, ${name}& r")) <+> scope: vsep({ fields .map({ @@ -166,10 +156,8 @@ private class CppRunnable extends CppLike { }) .toList }) - } - } - private def emitKernel(func: FuncDef): Doc = { + private def emitKernel(func: FuncDef): Doc = val FuncDef(id, args, ret, bodyOpt) = func // Generate serialization of decls @@ -192,9 +180,8 @@ private class CppRunnable extends CppLike { .getOrElse(emptyDoc) body - } - def emitProg(p: Prog, c: Config) = { + def emitProg(p: Prog, c: Config) = // Comments to demarcate autogenerated struct parsing helpers val startHelpers = value( "/***************** Parse helpers ******************/" @@ -211,7 +198,7 @@ private class CppRunnable extends CppLike { vsep(p.defs.collect({ case rec: RecordDef => recordHelpers(rec) })) // Generate code for the main kernel in this file. - val kernel = vsep { + val kernel = vsep: includes.map(emitInclude) ++ p.defs.map(emitDef) ++ (startHelpers :: @@ -219,14 +206,13 @@ private class CppRunnable extends CppLike { endHelpers :: emitKernel(FuncDef(Id(c.kernelName), p.decls, TVoid(), Some(p.cmd))) :: Nil) - } // Generate function calls to extract all kernel parameters from the JSON val getArgs: Doc = vsep(p.decls.map(emitParseDecl)) // Generate a main function that parses are kernel parameters and calls // the kernel function. - val main = value("int main(int argc, char** argv)") <+> scope { + val main = value("int main(int argc, char** argv)") <+> scope: text("using namespace flattening;") <@> cBind("v", cCall("parse_data", None, List(text("argc"), text("argv")))) <> semi <@> @@ -237,24 +223,20 @@ private class CppRunnable extends CppLike { p.decls.map(decl => value(decl.id.v)) ) <> semi <@> text("return 0") <> semi - } // Emit string (kernel <@> main).pretty - } -} -private class CppRunnableHeader extends CppRunnable { +private class CppRunnableHeader extends CppRunnable: override def emitCmd(c: Command): Doc = emptyDoc - override def emitFunc(func: FuncDef, entry: Boolean): Doc = func match { + override def emitFunc(func: FuncDef, entry: Boolean): Doc = func match case FuncDef(id, args, ret, _) => { val as = commaSep(args.map(d => emitDecl(d.id, d.typ))) emitType(ret) <+> id <> parens(as) <> semi } - } - override def emitProg(p: Prog, c: Config) = { + override def emitProg(p: Prog, c: Config) = val includes: Seq[String] = p.includes.flatMap(_.backends.get(C.Cpp)) :+ "parser.cpp" @@ -264,13 +246,9 @@ private class CppRunnableHeader extends CppRunnable { emitFunc(FuncDef(Id(c.kernelName), p.decls, TVoid(), None), true) declarations.pretty - } -} -case object CppRunnable extends Backend { - def emitProg(p: Prog, c: Config) = c.header match { +case object CppRunnable extends Backend: + def emitProg(p: Prog, c: Config) = c.header match case true => (new CppRunnableHeader()).emitProg(p, c) case false => (new CppRunnable()).emitProg(p, c) - } val canGenerateHeader = true -} diff --git a/src/main/scala/backends/VivadoBackend.scala b/src/main/scala/backends/VivadoBackend.scala index 12581edc..e3228a4e 100644 --- a/src/main/scala/backends/VivadoBackend.scala +++ b/src/main/scala/backends/VivadoBackend.scala @@ -10,7 +10,7 @@ import PrettyPrint.Doc import PrettyPrint.Doc._ import fuselang.common.{Configuration => C} -private class VivadoBackend(config: Config) extends CppLike { +private class VivadoBackend(config: Config) extends CppLike: val CppPreamble: Doc = text(""" |#include """.stripMargin.trim) @@ -35,7 +35,7 @@ private class VivadoBackend(config: Config) extends CppLike { } }) - def bankAndResource(id: Id, ports: Int, banks: Seq[Int]): Doc = { + def bankAndResource(id: Id, ports: Int, banks: Seq[Int]): Doc = val bankPragma = banks.zipWithIndex.map({ case (1, _) => emptyDoc case (bank, dim) => @@ -43,17 +43,15 @@ private class VivadoBackend(config: Config) extends CppLike { s"#pragma HLS ARRAY_PARTITION variable=$id cyclic factor=$bank dim=${dim + 1}" ) }) - val resource = ports match { + val resource = ports match case 1 => "RAM_1P_BRAM" case 2 => "RAM_T2P_BRAM" case n => throw BackendError(s"SDAccel does not support ${n}-ported memories.") - } val resPragma = text( s"#pragma HLS resource variable=${id} core=${resource}" ) vsep(resPragma +: bankPragma) - } def memoryPragmas(decls: Seq[Decl]): Seq[Doc] = decls @@ -62,36 +60,33 @@ private class VivadoBackend(config: Config) extends CppLike { bankAndResource(id, typ.ports, typ.dims.map(_._2)) }) - override def emitLet(let: CLet): Doc = { + override def emitLet(let: CLet): Doc = super.emitLet(let) <@> (let.typ match { case Some(t) => vsep(memoryPragmas(List(Decl(let.id, t)))) case None => emptyDoc }) - } def emitPipeline(enabled: Boolean): Doc = if enabled then value(s"#pragma HLS PIPELINE") <> line else emptyDoc def emitFor(cmd: CFor): Doc = - text("for") <> emitRange(cmd.range) <+> scope { + text("for") <> emitRange(cmd.range) <+> scope: emitPipeline(cmd.pipeline) <> unroll(cmd.range.u) <@> text("#pragma HLS LOOP_FLATTEN off") <@> cmd.par <> (if cmd.combine != CEmpty then line <> text("// combiner:") <@> cmd.combine else emptyDoc) - } override def emitWhile(cmd: CWhile): Doc = - text("while") <> parens(cmd.cond) <+> scope { + text("while") <> parens(cmd.cond) <+> scope: emitPipeline(cmd.pipeline) <> text("#pragma HLS LOOP_FLATTEN off") <@> cmd.body - } - private def axiHeader(arg: Syntax.Decl): Doc = { - arg.typ match { + private def axiHeader(arg: Syntax.Decl): Doc = + arg.typ match case _: TArray => { text( s"#pragma HLS INTERFACE m_axi port=${arg.id} offset=slave bundle=gmem" @@ -104,11 +99,9 @@ private class VivadoBackend(config: Config) extends CppLike { text( s"#pragma HLS INTERFACE s_axilite port=${arg.id} bundle=control" ) - } - } - private def apMemoryHeader(arg: Syntax.Decl): Doc = { - arg.typ match { + private def apMemoryHeader(arg: Syntax.Decl): Doc = + arg.typ match case _: TArray => { text(s"#pragma HLS INTERFACE ap_memory port=${arg.id}") } @@ -116,14 +109,12 @@ private class VivadoBackend(config: Config) extends CppLike { text( s"#pragma HLS INTERFACE s_axilite port=${arg.id} bundle=control" ) - } - } - def emitFuncHeader(func: FuncDef, entry: Boolean = false): Doc = { + def emitFuncHeader(func: FuncDef, entry: Boolean = false): Doc = // Error if function arguments are partitioned/ported. interfaceValid(func.args) - if entry then { + if entry then val argPragmas = func.args.map(arg => config.memoryInterface match { case Axi => axiHeader(arg) @@ -132,11 +123,9 @@ private class VivadoBackend(config: Config) extends CppLike { ) vsep(argPragmas) <@> text(s"#pragma HLS INTERFACE s_axilite port=return bundle=control") - } else { + else text(s"#pragma HLS INLINE") - } - } def emitArrayDecl(ta: TArray, id: Id): Doc = emitType(ta.typ) <+> id <> generateDims(ta.dims) @@ -144,7 +133,7 @@ private class VivadoBackend(config: Config) extends CppLike { def generateDims(dims: Seq[DimSpec]): Doc = ssep(dims.map(d => brackets(value(d._1))), emptyDoc) - def emitType(typ: Type): Doc = typ match { + def emitType(typ: Type): Doc = typ match case _: TVoid => text("void") case _: TBool | _: TIndex => text("int") case _: TStaticInt => throw Impossible("TStaticInt type should not exist") @@ -158,9 +147,8 @@ private class VivadoBackend(config: Config) extends CppLike { case TRecType(n, _) => text(n.toString) case _: TFun => throw Impossible("Cannot emit function types") case TAlias(n) => text(n.toString) - } - def emitProg(p: Prog, c: Config): String = { + def emitProg(p: Prog, c: Config): String = val layout = CppPreamble <@> vsep(p.includes.flatMap(_.backends.get(C.Vivado).map(emitInclude))) <@> @@ -169,33 +157,26 @@ private class VivadoBackend(config: Config) extends CppLike { emitFunc(FuncDef(Id(c.kernelName), p.decls, TVoid(), Some(p.cmd)), true) layout.pretty - } -} -private class VivadoBackendHeader(c: Config) extends VivadoBackend(c) { +private class VivadoBackendHeader(c: Config) extends VivadoBackend(c): override def emitCmd(c: Command): Doc = emptyDoc - override def emitFunc(func: FuncDef, entry: Boolean): Doc = func match { + override def emitFunc(func: FuncDef, entry: Boolean): Doc = func match case FuncDef(id, args, ret, _) => val as = commaSep(args.map(d => emitDecl(d.id, d.typ))) emitType(ret) <+> id <> parens(as) <> semi - } - override def emitProg(p: Prog, c: Config) = { + override def emitProg(p: Prog, c: Config) = val declarations = vsep(p.includes.flatMap(_.backends.get(C.Vivado).map(emitInclude))) <@> vsep(p.defs.map(emitDef)) <@> emitFunc(FuncDef(Id(c.kernelName), p.decls, TVoid(), None)) declarations.pretty - } -} -case object VivadoBackend extends Backend { - def emitProg(p: Prog, c: Config) = c.header match { +case object VivadoBackend extends Backend: + def emitProg(p: Prog, c: Config) = c.header match case true => (new VivadoBackendHeader(c)).emitProg(p, c) case false => (new VivadoBackend(c)).emitProg(p, c) - } val canGenerateHeader = true -} diff --git a/src/main/scala/backends/calyx/Ast.scala b/src/main/scala/backends/calyx/Ast.scala index b20ac16e..f6a87505 100644 --- a/src/main/scala/backends/calyx/Ast.scala +++ b/src/main/scala/backends/calyx/Ast.scala @@ -7,24 +7,22 @@ import scala.util.parsing.input.Position import fuselang.common.Syntax import scala.collection.mutable.{Map => MutableMap} -object Calyx { +object Calyx: // Track metadata while generating Calyx code. case class Metadata( // Mapping from position to the value of the counter map: MutableMap[Position, Int] = MutableMap(), var counter: Int = 0 - ) extends Emitable { - def addPos(pos: Position): Int = { + ) extends Emitable: + def addPos(pos: Position): Int = val key = pos - if !this.map.contains(key) then { + if !this.map.contains(key) then this.map.update(key, this.counter) this.counter = this.counter + 1 - } this.map(key) - } - override def doc(): Doc = { + override def doc(): Doc = text("metadata") <+> scope( vsep( this.map.toSeq @@ -39,19 +37,16 @@ object Calyx { left = text("#") <> lbrace, right = rbrace <> text("#") ) - } - } private def emitPos(pos: Position, @annotation.unused span: Int)( implicit meta: Metadata - ): Doc = { + ): Doc = // Add position information to the metadata. - if pos.line != 0 && pos.column != 0 then { + if pos.line != 0 && pos.column != 0 then val count = meta.addPos(pos) text("@pos") <> parens(text(count.toString)) <> space - } else { + else emptyDoc - } /* (if (pos.line == 0 && pos.column == 0) { emptyDoc } else { @@ -64,9 +59,8 @@ object Calyx { } else { emptyDoc }) */ - } - def emitCompStructure(structs: List[Structure]): Doc = { + def emitCompStructure(structs: List[Structure]): Doc = val (cells, connections) = structs.partition(st => st match { case _: Cell => true @@ -75,45 +69,38 @@ object Calyx { ) text("cells") <+> scope(vsep(cells.map(_.doc()))) <@> text("wires") <+> scope(vsep(connections.map(_.doc()))) - } - sealed trait Emitable { + sealed trait Emitable: def doc(): Doc def emit(): String = this.doc().pretty - } /** A variable representing the name of a component. **/ - case class CompVar(name: String) extends Emitable with Ordered[CompVar] { + case class CompVar(name: String) extends Emitable with Ordered[CompVar]: override def doc(): Doc = text(name) def port(port: String): CompPort = CompPort(this, port) def addSuffix(suffix: String): CompVar = CompVar(s"$name$suffix") - override def compare(that: CompVar): Int = { + override def compare(that: CompVar): Int = this.name.compare(that.name) - } - } case class PortDef( id: CompVar, width: Int, attrs: List[(String, Int)] = List() - ) extends Emitable { - override def doc(): Doc = { + ) extends Emitable: + override def doc(): Doc = val attrDoc = hsep(attrs.map({ case (attr, v) => text(s"@${attr}") <> parens(text(v.toString())) })) <> (if attrs.isEmpty then emptyDoc else space) attrDoc <> id.doc() <> colon <+> value(width) - } - } /**** definition statements *****/ - case class Namespace(name: String, comps: List[NamespaceStatement]) { + case class Namespace(name: String, comps: List[NamespaceStatement]): def doc(implicit meta: Metadata): Doc = vsep(comps.map(_.doc)) def emit(implicit meta: Metadata) = this.doc.pretty - } /** The statements that can appear at the top-level. */ - sealed trait NamespaceStatement { - def doc(implicit meta: Metadata): Doc = this match { + sealed trait NamespaceStatement: + def doc(implicit meta: Metadata): Doc = this match case Import(filename) => text("import") <+> quote(text(filename)) <> semi case Component(name, inputs, outputs, structure, control) => { text("component") <+> @@ -126,8 +113,6 @@ object Calyx { text("control") <+> scope(control.doc) ) } - } - } case class Import(filename: String) extends NamespaceStatement case class Component( @@ -139,8 +124,8 @@ object Calyx { ) extends NamespaceStatement /***** structure *****/ - sealed trait Port extends Emitable with Ordered[Port] { - override def doc(): Doc = this match { + sealed trait Port extends Emitable with Ordered[Port]: + override def doc(): Doc = this match case CompPort(id, name) => id.doc() <> dot <> text(name) case ThisPort(id) => id.doc() @@ -148,9 +133,8 @@ object Calyx { id.doc() <> brackets(text(name)) case ConstantPort(width, value) => text(width.toString) <> text("'d") <> text(value.toString) - } - override def compare(that: Port): Int = (this, that) match { + override def compare(that: Port): Int = (this, that) match case (ThisPort(thisId), ThisPort(thatId)) => thisId.compare(thatId) case (CompPort(thisId, _), CompPort(thatId, _)) => thisId.compare(thatId) case (HolePort(thisId, _), HolePort(thatId, _)) => thisId.compare(thatId) @@ -162,26 +146,22 @@ object Calyx { case (_, _: CompPort) => -1 case (_: ConstantPort, _) => 1 case (_, _: ConstantPort) => -1 - } - def isHole(): Boolean = this match { + def isHole(): Boolean = this match case _: HolePort => true case _ => false - } - def isConstant(value: Int, width: Int) = this match { + def isConstant(value: Int, width: Int) = this match case ConstantPort(v, w) if v == value && w == width => true case _ => false - } - } case class CompPort(id: CompVar, name: String) extends Port case class ThisPort(id: CompVar) extends Port case class HolePort(id: CompVar, name: String) extends Port case class ConstantPort(width: Int, value: BigInt) extends Port - sealed trait Structure extends Emitable with Ordered[Structure] { - override def doc(): Doc = this match { + sealed trait Structure extends Emitable with Ordered[Structure]: + override def doc(): Doc = this match case Cell(id, comp, ref, attrs) => { val attrDoc = hsep( @@ -206,28 +186,23 @@ object Calyx { angles(text("\"promotable\"") <> equal <> text(delay.get.toString())) else emptyDoc) <+> scope(vsep(conns.map(_.doc()))) - } - def compare(that: Structure): Int = { - (this, that) match { + def compare(that: Structure): Int = + (this, that) match case (Cell(thisId, _, _, _), Cell(thatId, _, _, _)) => thisId.compare(thatId) case (Group(thisId, _, _, _), Group(thatId, _, _, _)) => thisId.compare(thatId) case (Assign(thisSrc, thisDest, _), Assign(thatSrc, thatDest, _)) => { - if thisSrc.compare(thatSrc) == 0 then { + if thisSrc.compare(thatSrc) == 0 then thisDest.compare(thatDest) - } else { + else thisSrc.compare(thatSrc) - } } case (_: Cell, _) => -1 case (_, _: Cell) => 1 case (_: Group, _) => -1 case (_, _: Group) => 1 - } - } - } case class Cell( name: CompVar, comp: CompInst, @@ -244,13 +219,13 @@ object Calyx { case class Assign(src: Port, dest: Port, guard: GuardExpr = True) extends Structure - object Group { + object Group: def fromStructure( id: CompVar, structure: List[Structure], staticDelay: Option[Int], comb: Boolean - ): (Group, List[Structure]) = { + ): (Group, List[Structure]) = assert( !(comb && staticDelay.isDefined && staticDelay.get != 0), @@ -265,71 +240,60 @@ object Calyx { ) (this(id, connections, if comb then None else staticDelay, comb), st) - } - } - case class CompInst(id: String, args: List[BigInt]) extends Emitable { - override def doc(): Doc = { + case class CompInst(id: String, args: List[BigInt]) extends Emitable: + override def doc(): Doc = val strList = args.map((x: BigInt) => text(x.toString())) text(id) <> parens(hsep(strList, comma)) - } - } - sealed trait GuardExpr extends Emitable { - override def doc(): Doc = this match { + sealed trait GuardExpr extends Emitable: + override def doc(): Doc = this match case Atom(item) => item.doc() case And(left, right) => parens(left.doc() <+> text("&") <+> right.doc()) case Or(left, right) => parens(left.doc() <+> text("|") <+> right.doc()) case Not(inner) => text("!") <> inner.doc() case True => emptyDoc - } - } case class Atom(item: Port) extends GuardExpr - object Atom { - def apply(item: Port): GuardExpr = item match { + object Atom: + def apply(item: Port): GuardExpr = item match case ConstantPort(1, v) if v == 1 => True case _ => new Atom(item) - } - } case class And(left: GuardExpr, right: GuardExpr) extends GuardExpr case class Or(left: GuardExpr, right: GuardExpr) extends GuardExpr case class Not(inner: GuardExpr) extends GuardExpr case object True extends GuardExpr /***** control *****/ - sealed trait Control { + sealed trait Control: var attributes = Map[String, Int]() - def seq(c: Control): Control = (this, c) match { + def seq(c: Control): Control = (this, c) match case (Empty, c) => c case (c, Empty) => c case (seq0: SeqComp, seq1: SeqComp) => SeqComp(seq0.stmts ++ seq1.stmts) case (seq: SeqComp, _) => SeqComp(seq.stmts ++ List(c)) case (_, seq: SeqComp) => SeqComp(this :: seq.stmts) case _ => SeqComp(List(this, c)) - } - def par(c: Control): Control = (this, c) match { + def par(c: Control): Control = (this, c) match case (Empty, c) => c case (c, Empty) => c case (par0: ParComp, par1: ParComp) => ParComp(par0.stmts ++ par1.stmts) case (par0: ParComp, par1) => ParComp(par0.stmts ++ List(par1)) case (par0, par1: ParComp) => ParComp(par0 :: par1.stmts) case _ => ParComp(List(this, c)) - } def attributesDoc(): Doc = - if this.attributes.isEmpty then { + if this.attributes.isEmpty then emptyDoc - } else { + else hsep(attributes.map({ case (attr, v) => text(s"@$attr") <> parens(text(v.toString())) })) <> space - } - def doc(implicit meta: Metadata): Doc = { - val controlDoc = this match { + def doc(implicit meta: Metadata): Doc = + val controlDoc = this match case SeqComp(stmts) => text("seq") <+> scope(vsep(stmts.map(_.doc))) case ParComp(stmts) => @@ -370,10 +334,7 @@ object Calyx { parens(commaSep(outputDefs)) <> semi } case Empty => text("empty") - } attributesDoc() <> controlDoc - } - } case class SeqComp(stmts: List[Control]) extends Control case class ParComp(stmts: List[Control]) extends Control case class If(port: Port, cond: CompVar, trueBr: Control, falseBr: Control) @@ -388,10 +349,9 @@ object Calyx { ) extends Control with Syntax.PositionalWithSpan case object Empty extends Control -} /** Construct primitives in Calyx. */ -object Stdlib { +object Stdlib: def register(name: Calyx.CompVar, width: Int) = Calyx.Cell( name, @@ -431,4 +391,3 @@ object Stdlib { val staticTimingMap: Map[String, Int] = Map( "mult" -> 3 ) -} diff --git a/src/main/scala/backends/calyx/Backend.scala b/src/main/scala/backends/calyx/Backend.scala index f3065430..b14dcdc4 100644 --- a/src/main/scala/backends/calyx/Backend.scala +++ b/src/main/scala/backends/calyx/Backend.scala @@ -62,7 +62,7 @@ private case class EmitOutput( * by the Calyx compiler to enable such uses: * https://github.com/cucapra/Calyx/issues/304 */ -private class CalyxBackendHelper { +private class CalyxBackendHelper: /** A list of function IDs that require width arguments * in their SystemVerilog module definition. @@ -71,14 +71,12 @@ private class CalyxBackendHelper { /** Helper for generating unique names. */ var idx: Map[String, Int] = Map(); - def genName(base: String): CompVar = { + def genName(base: String): CompVar = // update idx - idx get base match { + idx get base match case Some(n) => idx = idx + (base -> (n + 1)) case None => idx = idx + (base -> 0) - } CompVar(s"$base${idx(base)}") - } /** A Calyx variable will either be a * local variable (LocalVar) or @@ -98,7 +96,7 @@ private class CalyxBackendHelper { arr: TArray, id: Id, attrs: List[(String, Int)] = List() - ): Cell = { + ): Cell = // No support for multi-ported memories or banked memories. assertOrThrow( arr.ports == 1, @@ -135,24 +133,21 @@ private class CalyxBackendHelper { false, attrs ) - } /** Returns the name of `p`, if it has one. */ - def getPortName(p: Port): String = { - p match { + def getPortName(p: Port): String = + p match case CompPort(id, _) => id.name case ThisPort(id) => id.name case HolePort(id, _) => id.name case ConstantPort(_, _) => throw Impossible("Constant Ports do not have names.") - } - } /** Returns a list of tuples (name, width) for each address port in a memory. For example, a D1 Memory declared as (32, 1, 1) would return List[("addr0", 1)]. */ - def getAddrPortToWidths(typ: TArray, id: Id): List[(String, BigInt)] = { + def getAddrPortToWidths(typ: TArray, id: Id): List[(String, BigInt)] = // Emit the array to determine the port widths. val Cell(_, CompInst(_, arrayArgs), _, _) = emitArrayDecl(typ, id) @@ -160,19 +155,17 @@ private class CalyxBackendHelper { // (bitwidth, size0, ..., sizeX, addr0, ..., addrX), // where X is the number of dimensions - 1. val dims = - arrayArgs.length match { + arrayArgs.length match case 3 => 1 case 5 => 2 case 7 => 3 case 9 => 4 case _ => throw NotImplemented(s"Arrays of dimension > 4.") - } val addressIndices = (dims + 1 to dims << 1).toList addressIndices.zipWithIndex.map({ case (n: Int, i: Int) => (s"addr${i}", arrayArgs(n)) }) - } /** Returns the width argument(s) of a given function, based on the return * type of the function. This is necessary because some components may @@ -188,26 +181,23 @@ private class CalyxBackendHelper { * `requiresWidthArguments`. */ def getCompInstArgs( funcId: Id - )(implicit id2FuncDef: FunctionMapping): List[BigInt] = { + )(implicit id2FuncDef: FunctionMapping): List[BigInt] = val id = funcId.toString() - if !requiresWidthArguments.contains(id) then { + if !requiresWidthArguments.contains(id) then List() - } else { + else val typ = id2FuncDef(funcId).retTy; - typ match { + typ match case TSizedInt(width, _) => List(width) case TFixed(width, intWidth, _) => List(width, intWidth, width - intWidth) case _ => throw Impossible(s"Type: $typ for $id is not supported.") - } - } - } /** `emitInvokeDecl` computes the necessary structure and control for Syntax.EApp. */ def emitInvokeDecl(app: EApp)( implicit store: Store, id2FuncDef: FunctionMapping - ): (Cell, Seq[Structure], Control) = { + ): (Cell, Seq[Structure], Control) = val functionName = app.func.toString() val declName = genName(functionName) val compInstArgs = getCompInstArgs(app.func) @@ -240,31 +230,29 @@ private class CalyxBackendHelper { argSt.flatten, Invoke(declName, refCells.toList, inConnects.toList, List()).withPos(app) ) - } /** `emitDecl(d)` computes the structure that is needed to * represent the declaration `d`. Simply returns a `List[Structure]`. */ - def emitDecl(d: Decl): Structure = d.typ match { + def emitDecl(d: Decl): Structure = d.typ match case tarr: TArray => emitArrayDecl(tarr, d.id, List("external" -> 1)) case _: TBool => Stdlib.register(CompVar(s"${d.id}"), 1) case TSizedInt(size, _) => Stdlib.register(CompVar(s"${d.id}"), size) case TFixed(ltotal, _, _) => Stdlib.register(CompVar(s"${d.id}"), ltotal) case x => throw NotImplemented(s"Type $x not implemented for decls.", x.pos) - } /** `emitBinop` is a helper function to generate the structure * for `e1 binop e2`. The return type is described in `emitExpr`. */ def emitBinop(compName: String, e1: Expr, e2: Expr)( implicit store: Store - ): EmitOutput = { + ): EmitOutput = val e1Out = emitExpr(e1) val e2Out = emitExpr(e2) val (e1Bits, e1Int) = bitsForType(e1.typ, e1.pos) val (e2Bits, e2Int) = bitsForType(e2.typ, e2.pos) // Throw error on numeric or bitwidth mismatch. - (e1Int, e2Int) match { + (e1Int, e2Int) match case (Some(_), Some(_)) => { /* Fixed-points allow this */ } case (None, None) => { assertOrThrow( @@ -283,9 +271,8 @@ private class CalyxBackendHelper { s"\nright: ${Pretty.emitExpr(e2)(false).pretty}" ) } - } - bitsForType(e1.typ, e1.pos) match { + bitsForType(e1.typ, e1.pos) match case (e1Bits, None) => { val isSigned = signed(e1.typ) val binOp = Stdlib.binop(s"$compName", e1Bits, isSigned) @@ -341,8 +328,6 @@ private class CalyxBackendHelper { None ) } - } - } def emitMultiCycleBinop( compName: String, @@ -352,13 +337,13 @@ private class CalyxBackendHelper { delay: Option[Int] )( implicit store: Store - ): EmitOutput = { + ): EmitOutput = val e1Out = emitExpr(e1) val e2Out = emitExpr(e2) val (e1Bits, e1Int) = bitsForType(e1.typ, e1.pos) val (e2Bits, e2Int) = bitsForType(e2.typ, e2.pos) // Check if we can compile this expression. - (e1Int, e2Int) match { + (e1Int, e2Int) match case (Some(intBit1), Some(intBit2)) => { assertOrThrow( intBit1 == intBit2, @@ -384,8 +369,7 @@ private class CalyxBackendHelper { s"\nright: ${Pretty.emitExpr(e2)(false).pretty}" ) } - } - val binOp = e1.typ match { + val binOp = e1.typ match case Some(TFixed(width, intWidth, unsigned)) => Stdlib.fixed_point_binop( s"$compName", @@ -398,7 +382,6 @@ private class CalyxBackendHelper { Stdlib.binop(s"$compName", width, !unsigned) case _ => throw NotImplemented(s"Multi-cycle binary operation with type: $e1.typ") - } val compVar = genName(compName) val comp = Cell(compVar, binOp, false, List()) val struct = List( @@ -419,7 +402,6 @@ private class CalyxBackendHelper { yield d1 + d2 + d3, Some((comp.name.port("done"), delay)) ) - } /** `emitExpr(expr, rhsInfo)(implicit store)` calculates the necessary structure * to compute `expr`. @@ -430,7 +412,7 @@ private class CalyxBackendHelper { def emitExpr(expr: Expr, rhsInfo: Option[(Port, Option[Int])] = None)( implicit store: Store ): EmitOutput = - expr match { + expr match case _: EInt => { throw PassError( "Cannot compile unannotated constants. Wrap constant in `as` expression", @@ -438,7 +420,7 @@ private class CalyxBackendHelper { ) } case EBinop(op, e1, e2) => { - val compName = op.op match { + val compName = op.op match case "+" => "add" case "-" => "sub" case "*" => "mult_pipe" @@ -462,8 +444,7 @@ private class CalyxBackendHelper { s"Calyx backend does not support '$x' yet.", op.pos ) - } - op.op match { + op.op match case "*" => emitMultiCycleBinop( compName, @@ -489,7 +470,6 @@ private class CalyxBackendHelper { None ) case _ => emitBinop(compName, e1, e2) - } } case EVar(id) => val (cell, calyxVarType) = store @@ -497,7 +477,7 @@ private class CalyxBackendHelper { .getOrThrow(BackendError(s"`$id' was not in store", expr.pos)) val (struct, port, done, delay) = - rhsInfo match { + rhsInfo match case Some((port, delay)) => ( List(Assign(port, cell.port("write_en"))), @@ -512,7 +492,6 @@ private class CalyxBackendHelper { None, Some(0) /* reading from a register is combinational */ ) - } EmitOutput( if calyxVarType == LocalVar then cell.port(port) @@ -566,27 +545,24 @@ private class CalyxBackendHelper { val isNegative = value.startsWith("-") val partition = value.split('.') val sIntPart = partition(0) - val intPart = if isNegative then { + val intPart = if isNegative then sIntPart.substring(1, sIntPart.length()) - } else { + else sIntPart - } val bdFracPart = BigDecimal("0." + partition(1)) val fracValue = (bdFracPart * BigDecimal(2).pow(fracWidth)) - if !fracValue.isWhole then { + if !fracValue.isWhole then throw BackendError( s"The value $value of type $typ is not representable in fixed point", expr.pos ) - } val intBits = binaryString(intPart.toInt, intWidth) val fracBits = binaryString(fracValue.toBigInt, fracWidth) - val bits = if isNegative then { + val bits = if isNegative then negateTwosComplement(intBits + fracBits) - } else { + else intBits + fracBits - } val fpconst = Cell( genName("fp_const"), @@ -606,7 +582,7 @@ private class CalyxBackendHelper { val (vBits, _) = bitsForType(e.typ, e.pos) val (cBits, _) = bitsForType(Some(t), e.pos) val res = emitExpr(e) - if vBits == cBits then { + if vBits == cBits then // No slicing or padding is necessary. EmitOutput( res.port, @@ -615,12 +591,11 @@ private class CalyxBackendHelper { Some(0), res.multiCycleInfo ) - } else { - val comp = if cBits > vBits then { + else + val comp = if cBits > vBits then Cell(genName("pad"), Stdlib.pad(vBits, cBits), false, List()) - } else { + else Cell(genName("slice"), Stdlib.slice(vBits, cBits), false, List()) - } val struct = List( comp, Assign(res.port, comp.name.port("in")) @@ -632,7 +607,6 @@ private class CalyxBackendHelper { Some(0), res.multiCycleInfo ) - } } case EArrAccess(id, accessors) => { val (arr, typ) = store @@ -657,19 +631,18 @@ private class CalyxBackendHelper { val isParam = (typ == ParameterVar) val (writeEnPort, donePort, accessPort) = - if isParam then { + if isParam then ( ThisPort(CompVar(s"${id}_write_en")), ThisPort(CompVar(s"${id}_${donePortName}")), ThisPort(CompVar(s"${id}_${portName}")) ) - } else { + else ( arr.port("write_en"), arr.port(donePortName), arr.port(portName) ) - } // We always need to specify and address on the `addr` ports. Generate // the additional structure. @@ -684,25 +657,22 @@ private class CalyxBackendHelper { } }) - val readEnPort = if isParam then { + val readEnPort = if isParam then ThisPort(CompVar(s"${id}_read_en")) - } else { + else arr.port("read_en") - } // always assign 1 to read_en port if we want to read from seq mem val readEnStruct = if rhsInfo.isDefined then List() else List(Assign(ConstantPort(1,1), readEnPort)) val writeEnStruct = - rhsInfo match { + rhsInfo match case Some((port, _)) => List(Assign(port, writeEnPort)) case None => List() - } - val delay = (rhsInfo) match { + val delay = (rhsInfo) match case (None) => Some(1) case (Some((_, delay))) => delay.map(_ + 1) - } EmitOutput( accessPort, @@ -727,15 +697,14 @@ private class CalyxBackendHelper { } case x => throw NotImplemented(s"Calyx backend does not support $x yet.", x.pos) - } def emitCmd( c: Command )( implicit store: Store, id2FuncDef: FunctionMapping - ): (List[Structure], Control, Store) = { - c match { + ): (List[Structure], Control, Store) = + c match case CBlock(cmd) => emitCmd(cmd) case CPar(cmds) => { cmds.foldLeft[(List[Structure], Control, Store)]( @@ -808,13 +777,12 @@ private class CalyxBackendHelper { ) // The write enable signal should not be high until // the multi-cycle operation is complete, if it exists. - val (writeEnableSrcPort, delay) = out.multiCycleInfo match { + val (writeEnableSrcPort, delay) = out.multiCycleInfo match case Some((port, Some(delay))) => (Some(port), out.delay.map(_ + delay + 1)) case Some((port, None)) => (Some(port), None) case None => (out.done, out.delay.map(_ + 1)) - } val struct = Assign(out.port, reg.name.port("in")) :: Assign( writeEnableSrcPort.getOrElse(ConstantPort(1, 1)), @@ -913,7 +881,7 @@ private class CalyxBackendHelper { val groupName = genName("cond") // If the conditional computation is not combinational, generate a group. - condOut.done match { + condOut.done match case Some(done) => { val doneAssign = Assign(done, HolePort(groupName, "done")) val (group, st) = @@ -939,7 +907,6 @@ private class CalyxBackendHelper { val control = If(condOut.port, group.id, tCon, fCon) (group :: st ++ struct, control, store) } - } } case CEmpty => (List(), Empty, store) case wh @ CWhile(cond, _, body) => { @@ -971,7 +938,7 @@ private class CalyxBackendHelper { // to // lhs = lhs + rhs val (op, numOp) = - rop.toString match { + rop.toString match case "+=" => ("+", (x: Double, y: Double) => x + y) case "*=" => ("*", (x: Double, y: Double) => x * y) case _ => @@ -979,9 +946,8 @@ private class CalyxBackendHelper { s"Calyx backend does not support $rop yet", c.pos ) - } - e1 match { + e1 match case _: EVar => emitCmd(CUpdate(e1, EBinop(NumOp(op, numOp), e1, e2))) case ea: EArrAccess => { @@ -997,7 +963,6 @@ private class CalyxBackendHelper { throw Impossible( s"LHS is neither a variable nor a memory access: ${Pretty.emitExpr(e)(false).pretty}" ) - } } case CReturn(expr:EVar) => { // Hooks the output port of the emitted `expr` to PortDef `out` of the component. @@ -1016,19 +981,15 @@ private class CalyxBackendHelper { case _: CDecorate => (List(), Empty, store) case x => throw NotImplemented(s"Calyx backend does not support $x yet", x.pos) - } - } /** Emits the function definition if a body exists. */ - def emitDefinition(definition: Definition): FuncDef = { - definition match { + def emitDefinition(definition: Definition): FuncDef = + definition match case fd: FuncDef => fd case x => throw NotImplemented(s"Calyx backend does not support $x yet", x.pos) - } - } - def emitProg(p: Prog, c: Config): String = { + def emitProg(p: Prog, c: Config): String = implicit val meta = Metadata() @@ -1048,8 +1009,8 @@ private class CalyxBackendHelper { ) val functionDefinitions: List[Component] = - for ( case (id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList ) - yield { + for case (id, FuncDef(_, params, retType, Some(body))) <- id2FuncDef.toList + yield val (refCells, inputPorts) = params.partitionMap(param => param.typ match { case tarr: TArray => { @@ -1089,7 +1050,6 @@ private class CalyxBackendHelper { refCells.toList ++ cmdStructure.sorted, controls ) - } val imports = Import("primitives/core.futil") :: @@ -1097,7 +1057,7 @@ private class CalyxBackendHelper { Import("primitives/binary_operators.futil") :: p.includes.flatMap(_.backends.get(C.Calyx)).map(i => Import(i)).toList - val main = if !c.compilerOpts.contains("no-main") then { + val main = if !c.compilerOpts.contains("no-main") then val declStruct = p.decls.map(emitDecl) val store = declStruct.foldLeft(Map[CompVar, (CompVar, VType)]())((store, struct) => @@ -1114,9 +1074,8 @@ private class CalyxBackendHelper { List( Component(mainComponentName, List(), List(), struct.sorted, control) ) - } else { + else List() - } // Emit the program (PrettyPrint.Doc @@ -1124,13 +1083,9 @@ private class CalyxBackendHelper { (imports ++ functionDefinitions ++ main) .map(_.doc(meta)) ) <@> meta.doc()).pretty - } -} -case object CalyxBackend extends fuselang.backend.Backend { - def emitProg(p: Prog, c: Config) = { +case object CalyxBackend extends fuselang.backend.Backend: + def emitProg(p: Prog, c: Config) = (new CalyxBackendHelper()).emitProg(p, c) - } val canGenerateHeader = false override val commentPrefix: String = "//" -} diff --git a/src/main/scala/backends/calyx/Helpers.scala b/src/main/scala/backends/calyx/Helpers.scala index f0fedb13..5660bf54 100644 --- a/src/main/scala/backends/calyx/Helpers.scala +++ b/src/main/scala/backends/calyx/Helpers.scala @@ -7,36 +7,33 @@ import fuselang.common._ import Syntax._ import CompilerError._ -object Helpers { +object Helpers: val slowBinops = List("*", "/", "%") /** Given a binary string, returns the negated * two's complement representation. */ - def negateTwosComplement(bitString: String): String = { - if bitString.forall(_ == '0') then { + def negateTwosComplement(bitString: String): String = + if bitString.forall(_ == '0') then return bitString - } val t = bitString .replaceAll("0", "_") .replaceAll("1", "0") .replaceAll("_", "1") (BigInt(t, 2) + 1).toString(2) - } /** Given an integer, returns the corresponding * zero-padded string of size `width`. */ - def binaryString(value: BigInt, width: Int): String = { + def binaryString(value: BigInt, width: Int): String = val s = value.toString(2) "0" * max(width - s.length(), 0) + s - } /** Extracts the bits needed from an optional type annotation. * Returns (total size, Option[integral]) bits for the computation. */ - def bitsForType(t: Option[Type], pos: Position): (Int, Option[Int]) = { - t match { + def bitsForType(t: Option[Type], pos: Position): (Int, Option[Int]) = + t match case Some(TSizedInt(width, _)) => (width, None) case Some(TFixed(t, i, _)) => (t, Some(i)) case Some(_: TBool) => (1, None) @@ -47,17 +44,12 @@ object Helpers { pos ) case None => throw Impossible(s"Explicit type missing. Try running with `--lower` or report an error with a reproducible program.") - } - } /** Returns true if the given int or fixed point is signed */ - def signed(typ: Option[Type]) = { - typ match { + def signed(typ: Option[Type]) = + typ match case Some(TSizedInt(_, un)) => un == false case Some(TFixed(_, _, un)) => un == false case _ => false - } - } -} diff --git a/src/main/scala/common/Checker.scala b/src/main/scala/common/Checker.scala index f75d390f..7f2f4fc3 100644 --- a/src/main/scala/common/Checker.scala +++ b/src/main/scala/common/Checker.scala @@ -5,7 +5,7 @@ import fuselang.Utils.asPartial import Syntax._ import EnvHelpers._ -object Checker { +object Checker: /** * A checker is a compiler pass that collects information using some Environment @@ -18,7 +18,7 @@ object Checker { * val env1 = check(e1)(currEnv) * check(e2)(env1) */ - abstract class Checker { + abstract class Checker: type Env <: ScopeManager[Env] @@ -27,7 +27,7 @@ object Checker { /** * Top level function called on the AST. */ - def check(p: Prog): Unit = { + def check(p: Prog): Unit = val Prog(_, defs, _, _, cmd) = p val env = defs.foldLeft(emptyEnv)({ @@ -35,7 +35,6 @@ object Checker { }) checkC(cmd)(env); () - } /** * Helper functions for checking sequences of the same element. @@ -43,20 +42,17 @@ object Checker { def checkSeqWith[T](f: (T, Env) => Env)(iter: Iterable[T])(env: Env): Env = iter.foldLeft(env)({ case (env, t) => f(t, env) }) - def checkESeq(exprs: Iterable[Expr])(implicit env: Env): Env = { + def checkESeq(exprs: Iterable[Expr])(implicit env: Env): Env = checkSeqWith[Expr](checkE(_: Expr)(_: Env))(exprs)(env) - } - def checkCSeq(cmds: Iterable[Command])(implicit env: Env): Env = { + def checkCSeq(cmds: Iterable[Command])(implicit env: Env): Env = checkSeqWith[Command](checkC(_: Command)(_: Env))(cmds)(env) - } - def checkDef(defi: Definition)(implicit env: Env) = defi match { + def checkDef(defi: Definition)(implicit env: Env) = defi match case FuncDef(_, _, _, bodyOpt) => bodyOpt.map(checkC).getOrElse(env) case _: RecordDef => env - } - def checkE(expr: Expr)(implicit env: Env): Env = expr match { + def checkE(expr: Expr)(implicit env: Env): Env = expr match case _: ERational | _: EInt | _: EBool | _: EVar => env case ERecLiteral(fields) => checkESeq(fields.map(_._2)) case EArrLiteral(idxs) => checkESeq(idxs) @@ -67,11 +63,10 @@ object Checker { case EArrAccess(_, idxs) => checkESeq(idxs) case EPhysAccess(_, bankIdxs) => checkESeq(bankIdxs.map(_._2)) - } def checkLVal(e: Expr)(implicit env: Env): Env = checkE(e) - def checkC(cmd: Command)(implicit env: Env): Env = cmd match { + def checkC(cmd: Command)(implicit env: Env): Env = cmd match case _: CSplit | _: CView | CEmpty | _: CDecorate => env case CPar(cmds) => checkCSeq(cmds) case CSeq(cmds) => checkCSeq(cmds) @@ -94,9 +89,7 @@ object Checker { checkE(cond).withScope(checkC(body)(_)) } case CBlock(cmd) => env.withScope(checkC(cmd)(_)) - } - } /** * Partial checker defines helper functions for writing down @@ -121,7 +114,7 @@ object Checker { * executes myCheckE first and if there are no matching cases, falls * back to partialRewriteE which has the default traversal behavior. */ - abstract class PartialChecker extends Checker { + abstract class PartialChecker extends Checker: private val partialCheckE: PF[(Expr, Env), Env] = asPartial(super.checkE(_: Expr)(_: Env)) @@ -132,13 +125,9 @@ object Checker { // pattern. def mergeCheckE( myCheckE: PF[(Expr, Env), Env] - ): PF[(Expr, Env), Env] = { + ): PF[(Expr, Env), Env] = myCheckE.orElse(partialCheckE) - } def mergeCheckC( myCheckC: PF[(Command, Env), Env] - ): PF[(Command, Env), Env] = { + ): PF[(Command, Env), Env] = myCheckC.orElse(partialCheckC) - } - } -} diff --git a/src/main/scala/common/CodeGenHelpers.scala b/src/main/scala/common/CodeGenHelpers.scala index 77382c94..25f17919 100644 --- a/src/main/scala/common/CodeGenHelpers.scala +++ b/src/main/scala/common/CodeGenHelpers.scala @@ -4,10 +4,10 @@ import fuselang.Utils.Big import scala.math.log10 -object CodeGenHelpers { +object CodeGenHelpers: import Syntax._ - implicit class RichExpr(e1: Expr) { + implicit class RichExpr(e1: Expr): import Syntax.{OpConstructor => OC} def +(e2: Expr) = @@ -34,7 +34,6 @@ object CodeGenHelpers { def &(e2: Expr) = binop(BitOp("&"), e1, e2) - } // Using the trick defined here: https://www.geeksforgeeks.org/program-to-find-whether-a-no-is-power-of-two/ def isPowerOfTwo(x: BigInt) = @@ -42,34 +41,31 @@ object CodeGenHelpers { def log2(n: BigInt) = log10(n.toDouble) / log10(2) - def fastDiv(l: Expr, r: Expr) = (l, r) match { + def fastDiv(l: Expr, r: Expr) = (l, r) match case (EInt(n, b), EInt(m, _)) => EInt(n / m, b) case (_, EInt(n, _)) if (isPowerOfTwo(n)) => l >> EInt(log2(n).toInt, 10) case _ => { scribe.warn(s"Cannot generate fast division for denominator $r") l div r } - } // Using the trick defined here: http://mziccard.me/2015/05/08/modulo-and-division-vs-bitwise-operations/ - def fastMod(l: Expr, r: Expr) = (l, r) match { + def fastMod(l: Expr, r: Expr) = (l, r) match case (EInt(n, b), EInt(m, _)) => EInt(n % m, b) case (_, EInt(n, _)) if (isPowerOfTwo(n)) => l & EInt(n - 1, 10) case _ => { scribe.warn(s"Cannot generate fast division for denominator $r") l mod r } - } - def and(l: Expr, r: Expr) = (l, r) match { + def and(l: Expr, r: Expr) = (l, r) match case (EBool(true), r) => r case (l, EBool(true)) => l case (_, EBool(false)) | (EBool(false), _) => EBool(false) case _ => EBinop(BoolOp("&&"), l, r) - } // Simple peephole optimization to turn: 1 * x => x, 0 + x => x, 0 * x => 0 - def binop(op: BOp, l: Expr, r: Expr) = (op, l, r) match { + def binop(op: BOp, l: Expr, r: Expr) = (op, l, r) match case (NumOp("*", _), EInt(Big(1), _), r) => r case (NumOp("*", _), l, EInt(Big(1), _)) => l case (NumOp("*", _), EInt(Big(0), b), _) => EInt(0, b) @@ -79,6 +75,4 @@ object CodeGenHelpers { case (BitOp("<<"), l, EInt(Big(0), _)) => l case (BitOp(">>"), l, EInt(Big(0), _)) => l case _ => EBinop(op, l, r) - } -} diff --git a/src/main/scala/common/Configuration.scala b/src/main/scala/common/Configuration.scala index 679222d1..1879c3be 100644 --- a/src/main/scala/common/Configuration.scala +++ b/src/main/scala/common/Configuration.scala @@ -2,27 +2,24 @@ package fuselang.common import java.io.File -object Configuration { +object Configuration: sealed trait Mode case object Compile extends Mode case object Run extends Mode - def stringToBackend(name: String): Option[BackendOption] = name match { + def stringToBackend(name: String): Option[BackendOption] = name match case "vivado" => Some(Vivado) case "c++" => Some(Cpp) case "futil" | "calyx" => Some(Calyx) case _ => None - } // What kind of code to generate. - sealed trait BackendOption { - override def toString() = this match { + sealed trait BackendOption: + override def toString() = this match case Vivado => "vivado" case Cpp => "c++" case Calyx => "calyx" - } - } case object Vivado extends BackendOption case object Cpp extends BackendOption case object Calyx extends BackendOption @@ -48,4 +45,3 @@ object Configuration { memoryInterface: MemoryInterface = Axi, // The memory interface to use for vivado ) -} diff --git a/src/main/scala/common/Document.scala b/src/main/scala/common/Document.scala index a7029396..4a4b068b 100644 --- a/src/main/scala/common/Document.scala +++ b/src/main/scala/common/Document.scala @@ -5,7 +5,7 @@ package fuselang.common import java.io.Writer -object PrettyPrint { +object PrettyPrint: case object DocNil extends Doc case object DocBreak extends Doc case object DocSpace extends Doc @@ -20,40 +20,36 @@ object PrettyPrint { * @author Michel Schinz * @version 1.0 */ - abstract class Doc { - def <@>(hd: Doc): Doc = { + abstract class Doc: + def <@>(hd: Doc): Doc = if hd == DocNil then this else this <> DocBreak <> hd - } - def <>(hd: Doc): Doc = (this, hd) match { + def <>(hd: Doc): Doc = (this, hd) match case (_, DocNil) => this case (DocNil, _) => hd case _ => new DocCons(this, hd) - } def <+>(hd: Doc): Doc = this <> DocSpace <> hd - def pretty: String = { + def pretty: String = val writer = new java.io.StringWriter() format(writer) writer.toString - } /** * Format this Doc on `writer`. */ - def format(writer: Writer): Unit = { + def format(writer: Writer): Unit = type FmtState = (Int, Doc) - def spaces(n: Int): Unit = { + def spaces(n: Int): Unit = var rem = n while rem >= 16 do { writer write " "; rem -= 16 } if rem >= 8 then { writer write " "; rem -= 8 } if rem >= 4 then { writer write " "; rem -= 4 } if rem >= 2 then { writer write " "; rem -= 2 } if rem == 1 then { writer write " " } - } - def fmt(state: List[FmtState]): Unit = state match { + def fmt(state: List[FmtState]): Unit = state match case List() => () case (_, DocNil) :: z => fmt(z) case (i, DocCons(h, t)) :: z => fmt((i, h) :: (i, t) :: z) @@ -71,13 +67,10 @@ object PrettyPrint { writer.write(" "); fmt(z) } case _ => () - } fmt(List((0, this))) - } - } - object Doc { + object Doc: /** The empty Doc */ def emptyDoc = DocNil @@ -142,5 +135,3 @@ object PrettyPrint { def braces(d: Doc) = enclose(text("{"), d, text("}")) def brackets(d: Doc) = enclose(text("["), d, text("]")) def angles(d: Doc) = enclose(text("<"), d, text(">")) - } -} diff --git a/src/main/scala/common/EnvHelpers.scala b/src/main/scala/common/EnvHelpers.scala index 932e1b02..b36bf517 100644 --- a/src/main/scala/common/EnvHelpers.scala +++ b/src/main/scala/common/EnvHelpers.scala @@ -1,6 +1,6 @@ package fuselang.common -object EnvHelpers { +object EnvHelpers: trait ScopeManager[T <: ScopeManager[_]] { this: T => @@ -10,9 +10,8 @@ object EnvHelpers { * * @param inScope Commands executed inside a new Scope level. */ - def withScope(inScope: T => T): T = { + def withScope(inScope: T => T): T = inScope(this) - } /** * Open a new scope and run commands in it. When the scope ends, the @@ -21,9 +20,8 @@ object EnvHelpers { * * @param inScope Commands executed inside a new Scope level. */ - def withScopeAndRet[V](inScope: T => (V, T)): (V, T) = { + def withScopeAndRet[V](inScope: T => (V, T)): (V, T) = inScope(this) - } /** * Merge this environment with [[that]] for some abstract merge function. @@ -53,8 +51,6 @@ object EnvHelpers { * Definition of a trivial environment that doesn't track any * information. */ - case class UnitEnv() extends ScopeManager[UnitEnv] { + case class UnitEnv() extends ScopeManager[UnitEnv]: def merge(that: UnitEnv) = this - } -} diff --git a/src/main/scala/common/Errors.scala b/src/main/scala/common/Errors.scala index bebc158d..574f43dc 100644 --- a/src/main/scala/common/Errors.scala +++ b/src/main/scala/common/Errors.scala @@ -4,16 +4,15 @@ import Syntax._ import MultiSet._ import scala.util.parsing.input.Position -object Errors { +object Errors: def withPos(s: String, pos: Position, postMsg: String = "") = s"[Line ${pos.line}, Column ${pos.column}] $s\n${pos.longString}\n${postMsg}" - class TypeError(msg: String) extends RuntimeException(msg) { + class TypeError(msg: String) extends RuntimeException(msg): def this(msg: String, pos: Position, postMsg: String) = this(withPos(msg, pos, postMsg)) def this(msg: String, pos: Position) = this(msg, pos, "") - } def alreadyConsumedError( id: Id, @@ -22,7 +21,7 @@ object Errors { conLocs: MultiSet[Position], pos: Position, trace: Seq[String] - ) = { + ) = val prevCons = conLocs.setMap .dropRight(1) .map({ @@ -43,7 +42,6 @@ object Errors { |Last gadget trace was: |${trace.mkString("\n")} """.stripMargin.trim - } @deprecated( "MsgErrors are not informative. Either create a new Error case or reuse one of the exisiting ones", @@ -277,24 +275,21 @@ object Errors { extends RuntimeException(withPos(s"$construct are not supported.", pos)) case class Malformed(pos: Position, msg: String) extends RuntimeException(withPos(msg, pos)) -} -object CompilerError { +object CompilerError: // Errors generated by a pass. Usually occur when an assumption is // violated. case class PassError(msg: String) extends RuntimeException(msg) - object PassError { + object PassError: def apply(msg: String, pos: Position): PassError = this(Errors.withPos(msg, pos)) - } // Errors generated by backends. case class BackendError(msg: String) extends RuntimeException(msg) - object BackendError { + object BackendError: def apply(msg: String, pos: Position): BackendError = this(Errors.withPos(msg, pos)) - } // Errors generated by fuse CLI case class HeaderMissing(hdr: String, hdrLoc: String) @@ -305,10 +300,9 @@ object CompilerError { implicit func: sourcecode.Enclosing, line: sourcecode.Line ) extends RuntimeException(s"[$func:$line] $msg") - object Impossible { + object Impossible: def apply(msg: String, pos: Position): Impossible = this(Errors.withPos(msg, pos)) - } // Used when a feature is not yet implemented case class NotImplemented(msg: String) @@ -316,8 +310,6 @@ object CompilerError { s"$msg This feature is not yet implemented. Please open a feature request for it." ) - object NotImplemented { + object NotImplemented: def apply(msg: String, pos: Position): NotImplemented = this(Errors.withPos(msg, pos)) - } -} diff --git a/src/main/scala/common/Logger.scala b/src/main/scala/common/Logger.scala index bc5ce282..194a5922 100644 --- a/src/main/scala/common/Logger.scala +++ b/src/main/scala/common/Logger.scala @@ -4,36 +4,31 @@ import scala.util.parsing.input.Positional import scribe._ import scribe.format._ -object Logger { +object Logger: /** Makes all positionals logable by scribe */ - implicit object PositionalLoggable extends Loggable[(String, Positional)] { - override def apply(value: (String, Positional)) = { + implicit object PositionalLoggable extends Loggable[(String, Positional)]: + override def apply(value: (String, Positional)) = val pos = value._2.pos new output.TextOutput( s"[${pos.line}.${pos.column}] ${value._1}\n${pos.longString}" ) - } - } - def stringToLevel(str: String): Level = str match { + def stringToLevel(str: String): Level = str match case "error" => Level.Error case "warn" => Level.Warn case "debug" => Level.Debug case _ => throw new RuntimeException(s"Unknown level: $str") - } /** * Stateful function to set the logging level in the compiler. */ - def setLogLevel(level: Level) = { + def setLogLevel(level: Level) = scribe.Logger.root .clearHandlers() .withHandler(formatter = format, minimumLevel = Some(level)) .replace() - } val format: Formatter = formatter"[$levelColored] $message$newLine" -} diff --git a/src/main/scala/common/MultiSet.scala b/src/main/scala/common/MultiSet.scala index 5a599230..d230f645 100644 --- a/src/main/scala/common/MultiSet.scala +++ b/src/main/scala/common/MultiSet.scala @@ -2,7 +2,7 @@ package fuselang.common import scala.collection.immutable.Map -object MultiSet { +object MultiSet: def emptyMultiSet[K]() = MultiSet[K](Map[K, Int]()) @@ -13,7 +13,7 @@ object MultiSet { else ms + (v -> 1) })) - case class MultiSet[K](val setMap: Map[K, Int]) extends AnyVal { + case class MultiSet[K](val setMap: Map[K, Int]) extends AnyVal: /** * Contains at least [[num]] copies of [[element]] @@ -30,16 +30,14 @@ object MultiSet { /** * Apply [[op]] on the values associated with the same key in [[this]] and [[that]]. */ - def zipWith(that: MultiSet[K], op: (Int, Int) => Int): MultiSet[K] = { + def zipWith(that: MultiSet[K], op: (Int, Int) => Int): MultiSet[K] = val thatMap = that.setMap val (thisKeys, thatKeys) = (setMap.keys.toSet, thatMap.keys.toSet) - if thisKeys != thatKeys then { + if thisKeys != thatKeys then throw new NoSuchElementException( s"Element ${thisKeys.diff(thatKeys).head} not in both multisets.\nThis: ${setMap}\nThat: ${thatMap}." ) - } MultiSet(setMap.map({ case (k, v) => k -> op(v, thatMap(k)) })) - } /** Calculate multiset difference */ def diff(that: MultiSet[K]) = @@ -62,6 +60,4 @@ object MultiSet { def getCount(k: K): Int = setMap(k) - } -} diff --git a/src/main/scala/common/Pretty.scala b/src/main/scala/common/Pretty.scala index 577c3584..fc3c2dd3 100644 --- a/src/main/scala/common/Pretty.scala +++ b/src/main/scala/common/Pretty.scala @@ -6,9 +6,9 @@ import Syntax._ import PrettyPrint.Doc import PrettyPrint.Doc._ -object Pretty { +object Pretty: - def emitProg(p: Prog)(implicit debug: Boolean): String = { + def emitProg(p: Prog)(implicit debug: Boolean): String = val layout = vsep(p.includes.map(emitInclude)) <@> vsep(p.defs.map(emitDef)) <@> vsep(p.decors.map(d => text(d.value))) <@> @@ -16,23 +16,20 @@ object Pretty { emitCmd(p.cmd) layout.pretty - } - def emitInclude(incl: Include)(implicit debug: Boolean): Doc = { + def emitInclude(incl: Include)(implicit debug: Boolean): Doc = text("import") <+> vsep(incl.backends.map({ case (b, incl) => text(b.toString) <> parens(quote(text(incl))) })) <+> scope( vsep(incl.defs.map(emitDef)) ) - } - def emitDef(defi: Definition)(implicit debug: Boolean): Doc = defi match { + def emitDef(defi: Definition)(implicit debug: Boolean): Doc = defi match case FuncDef(id, args, ret, bodyOpt) => { - val retDoc = ret match { + val retDoc = ret match case _: TVoid => emptyDoc case _ => colon <+> emitTyp(ret) - } text("def") <+> id <> parens(ssep(args.map(emitDecl), comma <> space)) <+> retDoc <> bodyOpt.map(c => equal <+> scope(emitCmd(c))).getOrElse(semi) } @@ -41,31 +38,27 @@ object Pretty { case (id, typ) => id <> colon <+> emitTyp(typ) <> semi }))) } - } def emitDecl(d: Decl): Doc = emitId(d.id)(false) <> colon <+> emitTyp(d.typ) - def emitConsume(ann: Annotations.Consumable): Doc = ann match { + def emitConsume(ann: Annotations.Consumable): Doc = ann match case Annotations.ShouldConsume => text("consume") case Annotations.SkipConsume => text("skip") - } - implicit def emitId(id: Id)(implicit debug: Boolean): Doc = { + implicit def emitId(id: Id)(implicit debug: Boolean): Doc = val idv = value(id.v) if debug then id.typ.map(t => idv <> text("@") <> emitTyp(t)).getOrElse(idv) else idv - } def emitTyp(t: Type): Doc = text(t.toString) - def emitBaseInt(v: BigInt, base: Int): String = base match { + def emitBaseInt(v: BigInt, base: Int): String = base match case 8 => s"0${v.toString(8)}" case 10 => v.toString case 16 => s"0x${v.toString(16)}" - } - implicit def emitExpr(e: Expr)(implicit debug: Boolean): Doc = e match { + implicit def emitExpr(e: Expr)(implicit debug: Boolean): Doc = e match case ECast(e, typ) => parens(e <+> text("as") <+> emitTyp(typ)) case EApp(fn, args) => fn <> parens(commaSep(args.map(emitExpr))) case EInt(v, base) => value(emitBaseInt(v, base)) @@ -100,14 +93,12 @@ object Pretty { case EArrLiteral(idxs) => braces(commaSep(idxs.map(idx => emitExpr(idx)))) case ERecAccess(rec, field) => rec <> dot <> field case ERecLiteral(fs) => - scope { + scope: hsep(fs.toList.map({ case (id, expr) => id <+> equal <+> expr <> semi })) - } - } - def emitRange(range: CRange)(implicit debug: Boolean): Doc = { + def emitRange(range: CRange)(implicit debug: Boolean): Doc = val CRange(id, t, rev, s, e, u) = range val typAnnot = @@ -120,30 +111,26 @@ object Pretty { ) ) <> (if u > 1 then space <> text("unroll") <+> value(u) else emptyDoc) - } - def emitView(view: View)(implicit debug: Boolean): Doc = { + def emitView(view: View)(implicit debug: Boolean): Doc = val View(suf, pre, sh) = view - val sufDoc = suf match { + val sufDoc = suf match case Aligned(f, e) => value(f) <+> text("*") <+> e case Rotation(e) => e <> text("!") - } sufDoc <+> colon <> pre.map(p => space <> text("+") <+> value(p)).getOrElse(emptyDoc) <> sh.map(sh => space <> text("bank") <+> value(sh)).getOrElse(emptyDoc) - } - def emitAttributes(attrs: Map[String, Int]): Doc = { + def emitAttributes(attrs: Map[String, Int]): Doc = hsep(attrs.map({ case (attr, v) => text(s"@${attr}") <> parens(text(v.toString())) })) - } implicit def emitCmd(c: Command)( implicit debug: Boolean - ): Doc = { + ): Doc = val attr = if c.attributes.isEmpty then emptyDoc else @@ -155,9 +142,8 @@ object Pretty { case _: CPar | _: CSeq | _: CBlock => attr case _ => attr }) <> emitCmdBare(c)(debug) - } - def emitCmdBare(c: Command)(implicit debug: Boolean): Doc = c match { + def emitCmdBare(c: Command)(implicit debug: Boolean): Doc = c match case CPar(cmds) => vsep(cmds.map(emitCmd)) case CSeq(cmds) => vsep(cmds.map(emitCmd), text("---")) case CLet(id, typ, e) => @@ -200,5 +186,3 @@ object Pretty { emptyDoc ) <> semi case CBlock(cmd) => scope(emitCmd(cmd)) - } -} diff --git a/src/main/scala/common/ScopeMap.scala b/src/main/scala/common/ScopeMap.scala index 47862543..f2537605 100644 --- a/src/main/scala/common/ScopeMap.scala +++ b/src/main/scala/common/ScopeMap.scala @@ -3,14 +3,14 @@ package fuselang.common import CompilerError._ import fuselang.Utils._ -object ScopeMap { +object ScopeMap: /** * A map that undestands scopes. A ScopedMap is a chain of maps from * [[K]] to [[V]]. */ case class ScopedMap[K, V](val mapList: List[Map[K, V]] = List(Map[K, V]())) - extends AnyVal { + extends AnyVal: override def toString = mapList @@ -33,18 +33,16 @@ object ScopeMap { * @returns None if the value is already bound in the scope chain, otherwise * a new [[ScopedMap]] with the binding in the top most scope. */ - def add(key: K, value: V): Option[ScopedMap[K, V]] = get(key) match { + def add(key: K, value: V): Option[ScopedMap[K, V]] = get(key) match case Some(_) => None case None => Some(this.copy(mapList = mapList.head + (key -> value) :: mapList.tail)) - } /** * Add key -> value binding to the topmost scope. */ - def addShadow(key: K, value: V): ScopedMap[K, V] = { + def addShadow(key: K, value: V): ScopedMap[K, V] = this.copy(mapList = mapList.head + (key -> value) :: mapList.tail) - } /** * Update the binding for [[key]] to [[value]]. The update method walks @@ -52,20 +50,18 @@ object ScopeMap { * @returns a new [[ScopedMap]] with the key bound to value * @throw [[Errors.Unbound]] If the key is not found. */ - def update(key: K, value: V): ScopedMap[K, V] = { + def update(key: K, value: V): ScopedMap[K, V] = val scope = mapList.indexWhere(m => m.get(key).isDefined) assertOrThrow(scope > -1, Impossible(s"$key was not found.")) val newMapList = mapList.updated(scope, mapList(scope) + (key -> value)) this.copy(mapList = newMapList) - } /** Methods to manage scopes. */ def addScope: ScopedMap[K, V] = ScopedMap(Map[K, V]() :: mapList) - def endScope: Option[(Map[K, V], ScopedMap[K, V])] = mapList match { + def endScope: Option[(Map[K, V], ScopedMap[K, V])] = mapList match case Nil => None case hd :: tl => Some((hd, this.copy(mapList = tl))) - } /** Return the set of all keys. */ def keys = mapList.flatMap(m => m.keys).toSet @@ -75,5 +71,3 @@ object ScopeMap { get(k).getOrThrow(Impossible(s"$k was not found in $this.")) def +(bind: (K, V)) = add(bind._1, bind._2) - } -} diff --git a/src/main/scala/common/ScopedSet.scala b/src/main/scala/common/ScopedSet.scala index 17ddf7de..daeb4858 100644 --- a/src/main/scala/common/ScopedSet.scala +++ b/src/main/scala/common/ScopedSet.scala @@ -4,7 +4,7 @@ package fuselang.common * Set that provides support for Scopes. */ case class ScopedSet[V](val setList: List[Set[V]] = List(Set[V]())) - extends AnyVal { + extends AnyVal: override def toString = setList.map(set => s"{${set.mkString(", ")}}").mkString("->") @@ -17,8 +17,6 @@ case class ScopedSet[V](val setList: List[Set[V]] = List(Set[V]())) /** Managing Scopes */ def addScope: ScopedSet[V] = this.copy(setList = Set[V]() :: setList) - def endScope: Option[(Set[V], ScopedSet[V])] = setList match { + def endScope: Option[(Set[V], ScopedSet[V])] = setList match case Nil => None case hd :: tl => Some((hd, this.copy(setList = tl))) - } -} diff --git a/src/main/scala/common/Syntax.scala b/src/main/scala/common/Syntax.scala index b9e12f5b..a9400132 100644 --- a/src/main/scala/common/Syntax.scala +++ b/src/main/scala/common/Syntax.scala @@ -6,60 +6,52 @@ import scala.math.abs import Errors._ import Configuration.BackendOption -object Syntax { +object Syntax: - trait PositionalWithSpan extends Positional { + trait PositionalWithSpan extends Positional: var span: Int = 0 - def setSpan(span: Int): this.type = { + def setSpan(span: Int): this.type = this.span = span this - } - def withPos[T <: PositionalWithSpan](other: T): this.type = { + def withPos[T <: PositionalWithSpan](other: T): this.type = this.setPos(other.pos).setSpan(other.span) this - } - } /** * Annotations added by the various passes of the type checker. */ - object Annotations { + object Annotations: sealed trait Consumable case object ShouldConsume extends Consumable case object SkipConsume extends Consumable - sealed trait ConsumableAnnotation { + sealed trait ConsumableAnnotation: var consumable: Option[Consumable] = None - } - sealed trait TypeAnnotation { + sealed trait TypeAnnotation: var typ: Option[Type] = None; - } - } - object OpConstructor { + object OpConstructor: val add: (Double, Double) => Double = (_ + _) val mul: (Double, Double) => Double = (_ * _) val div: (Double, Double) => Double = (_ / _) val sub: (Double, Double) => Double = (_ - _) val mod: (Double, Double) => Double = (_ % _) - } import Annotations._ - case class Id(v: String) extends PositionalWithSpan with TypeAnnotation { + case class Id(v: String) extends PositionalWithSpan with TypeAnnotation: override def toString = s"$v" - } // Capabilities for read/write sealed trait Capability case object Read extends Capability case object Write extends Capability - sealed trait Type extends PositionalWithSpan { - override def toString = this match { + sealed trait Type extends PositionalWithSpan: + override def toString = this match case _: TVoid => "void" case _: TBool => "bool" case _: TRational => "rational" @@ -76,19 +68,16 @@ object Syntax { case TFun(args, ret) => s"${args.mkString("->")} -> ${ret}" case TRecType(n, _) => s"$n" case TAlias(n) => n.toString - } - } // Types that can be upcast to Ints sealed trait IntType case class TSizedInt(len: Int, unsigned: Boolean) extends Type with IntType case class TStaticInt(v: BigInt) extends Type with IntType case class TIndex(static: (Int, Int), dynamic: (Int, Int)) extends Type - with IntType { + with IntType: // Our ranges are represented as s..e with e excluded from the range. // Therefore, the maximum value is one than the product of the interval ends. val maxVal: Int = static._2 * dynamic._2 - 1 - } // Use case class instead of case object to get unique positions case class TVoid() extends Type case class TBool() extends Type @@ -103,7 +92,7 @@ object Syntax { // Each dimension has a length and a bank type DimSpec = (Int, Int) - case class TArray(typ: Type, dims: Seq[DimSpec], ports: Int) extends Type { + case class TArray(typ: Type, dims: Seq[DimSpec], ports: Int) extends Type: dims.zipWithIndex.foreach({ case ((len, bank), dim) => if bank > len || len % bank != 0 then { @@ -112,16 +101,13 @@ object Syntax { ) } }) - } - sealed trait BOp extends PositionalWithSpan { + sealed trait BOp extends PositionalWithSpan: val op: String; override def toString = this.op - def toFun: Option[(Double, Double) => Double] = this match { + def toFun: Option[(Double, Double) => Double] = this match case n: NumOp => Some(n.fun) case _ => None - } - } case class EqOp(op: String) extends BOp case class CmpOp(op: String) extends BOp @@ -129,12 +115,10 @@ object Syntax { case class NumOp(op: String, fun: (Double, Double) => Double) extends BOp case class BitOp(op: String) extends BOp - sealed trait Expr extends PositionalWithSpan with TypeAnnotation { - def isLVal = this match { + sealed trait Expr extends PositionalWithSpan with TypeAnnotation: + def isLVal = this match case _: EVar | _: EArrAccess | _: EPhysAccess => true case _ => false - } - } case class EInt(v: BigInt, base: Int = 10) extends Expr case class ERational(d: String) extends Expr case class EBool(v: Boolean) extends Expr @@ -159,19 +143,15 @@ object Syntax { s: Int, e: Int, u: Int - ) extends PositionalWithSpan { - def idxType: TIndex = { - if abs(e - s) % u != 0 then { + ) extends PositionalWithSpan: + def idxType: TIndex = + if abs(e - s) % u != 0 then throw UnrollRangeError(this.pos, e - s, u) - } else { + else TIndex((0, u), (s / u, e / u)) - } - } - } - case class ROp(op: String) extends PositionalWithSpan { + case class ROp(op: String) extends PositionalWithSpan: override def toString = this.op - } /** Views **/ sealed trait Suffix extends PositionalWithSpan @@ -189,9 +169,8 @@ object Syntax { case class View(suffix: Suffix, prefix: Option[Int], shrink: Option[Int]) extends PositionalWithSpan - sealed trait Command extends PositionalWithSpan { + sealed trait Command extends PositionalWithSpan: var attributes: Map[String, Int] = Map() - } case class CPar(cmds: Seq[Command]) extends Command case class CSeq(cmds: Seq[Command]) extends Command case class CLet(id: Id, var typ: Option[Type], e: Option[Expr]) @@ -208,29 +187,26 @@ object Syntax { case class CWhile(cond: Expr, pipeline: Boolean, body: Command) extends Command case class CDecorate(value: String) extends Command - case class CUpdate(lhs: Expr, rhs: Expr) extends Command { + case class CUpdate(lhs: Expr, rhs: Expr) extends Command: if lhs.isLVal == false then throw UnexpectedLVal(lhs, "assignment") - } - case class CReduce(rop: ROp, lhs: Expr, rhs: Expr) extends Command { + case class CReduce(rop: ROp, lhs: Expr, rhs: Expr) extends Command: if lhs.isLVal == false then throw UnexpectedLVal(lhs, "reduction") - } case class CReturn(exp: Expr) extends Command case class CExpr(exp: Expr) extends Command case class CBlock(cmd: Command) extends Command case object CEmpty extends Command // Smart constructors for composition - object CPar { - def smart(c1: Command, c2: Command): Command = (c1, c2) match { + object CPar: + def smart(c1: Command, c2: Command): Command = (c1, c2) match case (l: CPar, r: CPar) => l.copy(cmds = l.cmds ++ r.cmds) case (l: CPar, r) => l.copy(cmds = l.cmds :+ r) case (l, r: CPar) => r.copy(cmds = l +: r.cmds) case (CEmpty, r) => r case (l, CEmpty) => l case _ => CPar(Seq(c1, c2)) - } - def smart(cmds: Seq[Command]): Command = { + def smart(cmds: Seq[Command]): Command = val flat = cmds.flatMap(cmd => cmd match { case CPar(cs) => cs @@ -238,26 +214,22 @@ object Syntax { case _ => Seq(cmd) } ) - if flat.length == 0 then { + if flat.length == 0 then CEmpty - } else if flat.length == 1 then { + else if flat.length == 1 then flat(0) - } else { + else CPar(flat) - } - } - } - object CSeq { - def smart(c1: Command, c2: Command): Command = (c1, c2) match { + object CSeq: + def smart(c1: Command, c2: Command): Command = (c1, c2) match case (l: CSeq, r: CSeq) => l.copy(cmds = l.cmds ++ r.cmds) case (l: CSeq, r) => l.copy(cmds = l.cmds :+ r) case (l, r: CSeq) => r.copy(cmds = l +: r.cmds) case (CEmpty, r) => r case (l, CEmpty) => l case _ => CSeq(Seq(c1, c2)) - } - def smart(cmds: Seq[Command]): Command = { + def smart(cmds: Seq[Command]): Command = val flat = cmds.flatMap(cmd => cmd match { case CSeq(cs) => cs @@ -265,15 +237,12 @@ object Syntax { case _ => Seq(cmd) } ) - if flat.length == 0 then { + if flat.length == 0 then CEmpty - } else if flat.length == 1 then { + else if flat.length == 1 then flat(0) - } else { + else CSeq(flat) - } - } - } sealed trait Definition extends PositionalWithSpan @@ -287,7 +256,7 @@ object Syntax { retTy: Type, bodyOpt: Option[Command] ) extends Definition - case class RecordDef(name: Id, fields: Map[Id, Type]) extends Definition { + case class RecordDef(name: Id, fields: Map[Id, Type]) extends Definition: fields.foreach({ case (f, t) => t match { @@ -295,7 +264,6 @@ object Syntax { case _ => () } }) - } /** * An include with the name of the module and function definitions. @@ -317,14 +285,10 @@ object Syntax { /** * Define common helper methods implicit classes. */ - implicit class RichType(typ: Type) { + implicit class RichType(typ: Type): def matchOrError[A](pos: Position, construct: String, exp: String)( andThen: PartialFunction[Type, A] - ): A = { - val mismatchError: PartialFunction[Type, A] = { + ): A = + val mismatchError: PartialFunction[Type, A] = case _ => throw UnexpectedType(pos, construct, exp, typ) - } andThen.orElse(mismatchError)(typ) - } - } -} diff --git a/src/main/scala/common/Transformer.scala b/src/main/scala/common/Transformer.scala index 0a229b90..74336c90 100644 --- a/src/main/scala/common/Transformer.scala +++ b/src/main/scala/common/Transformer.scala @@ -5,9 +5,9 @@ import Syntax._ import EnvHelpers._ import scala.{PartialFunction => PF} -object Transformer { +object Transformer: - abstract class Transformer { + abstract class Transformer: type Env <: ScopeManager[Env] @@ -18,86 +18,76 @@ object Transformer { */ def transferPos(cmd: Command, f: PF[(Command, Env), (Command, Env)])( implicit env: Env - ): (Command, Env) = { + ): (Command, Env) = val (c1, env1) = f(cmd, env) - c1 match { + c1 match case _: CPar | _: CSeq | _: CBlock => () case _ => { - if c1.pos.line == 0 && c1.pos.column == 0 then { + if c1.pos.line == 0 && c1.pos.column == 0 then c1.withPos(cmd) - } } - } // If the position for the command is undefined, add the previous position (c1, env1) - } /** * Top level function called on the AST. */ - def rewrite(p: Prog): Prog = { + def rewrite(p: Prog): Prog = val Prog(_, defs, _, decls, cmd) = p val (ndefs, env) = rewriteDefSeq(defs)(emptyEnv) val (ndecls, env1) = rewriteDeclSeq(decls)(env) val (ncmd, _) = rewriteC(cmd)(env1) p.copy(defs = ndefs.toSeq, decls = ndecls.toSeq, cmd = ncmd) - } /** * Helper functions for checking sequences of the same element. */ def rewriteSeqWith[T]( f: (T, Env) => (T, Env) - )(iter: Iterable[T])(env: Env): (Iterable[T], Env) = { + )(iter: Iterable[T])(env: Env): (Iterable[T], Env) = val (ts, env1) = iter.foldLeft(Seq[T](), env)({ case ((ts, env), t) => val (t1, env1) = f(t, env) (t1 +: ts, env1) }) (ts.reverse, env1) - } def rewriteESeq( exprs: Iterable[Expr] - )(implicit env: Env): (Iterable[Expr], Env) = { + )(implicit env: Env): (Iterable[Expr], Env) = rewriteSeqWith[Expr](rewriteE(_: Expr)(_: Env))(exprs)(env) - } def rewriteCSeq( cmds: Iterable[Command] - )(implicit env: Env): (Iterable[Command], Env) = { + )(implicit env: Env): (Iterable[Command], Env) = rewriteSeqWith[Command](rewriteC(_: Command)(_: Env))(cmds)(env) - } def rewriteDefSeq( defs: Iterable[Definition] - )(implicit env: Env): (Iterable[Definition], Env) = { + )(implicit env: Env): (Iterable[Definition], Env) = rewriteSeqWith[Definition](rewriteDef(_: Definition)(_: Env))(defs)(env) - } def rewriteDeclSeq(ds: Seq[Decl])(implicit env: Env): (Seq[Decl], Env) = (ds, env) - def rewriteDef(defi: Definition)(implicit env: Env) = defi match { + def rewriteDef(defi: Definition)(implicit env: Env) = defi match case fdef @ FuncDef(_, args, _, bodyOpt) => { val (nArgs, env1) = rewriteDeclSeq(args) - val (nBody, env2) = bodyOpt match { + val (nBody, env2) = bodyOpt match case None => (None, env1) case Some(body) => { val (nbody, nEnv) = env1.withScopeAndRet(rewriteC(body)(_)) Some(nbody) -> nEnv } - } fdef.copy(args = nArgs, bodyOpt = nBody) -> env2 } case _: RecordDef => (defi, env) - } def rewriteE(expr: Expr)(implicit env: Env): (Expr, Env) = - expr match { + expr match case _: ERational | _: EInt | _: EBool | _: EVar => (expr, env) case ERecLiteral(fields) => { val (fs, env1) = rewriteESeq(fields.map(_._2)) @@ -137,11 +127,10 @@ object Transformer { }) acc.copy(bankIdxs = nBankIdxsReversed.reverse) -> nEnv } - } def rewriteLVal(e: Expr)(implicit env: Env): (Expr, Env) = rewriteE(e) - def rewriteCBare(cmd: Command)(implicit env: Env): (Command, Env) = cmd match { + def rewriteCBare(cmd: Command)(implicit env: Env): (Command, Env) = cmd match case _: CSplit | _: CView | CEmpty | _: CDecorate => (cmd, env) case CPar(cmds) => { val (ncmds, env1) = rewriteCSeq(cmds) @@ -162,13 +151,12 @@ object Transformer { red.copy(lhs = nlhs, rhs = nrhs) -> env2 } case let @ CLet(_, _, eOpt) => - eOpt match { + eOpt match case None => let -> env case Some(e) => { val (e1, env1) = rewriteE(e) let.copy(e = Some(e1)) -> env1 } - } case CExpr(e) => { val (e1, env1) = rewriteE(e) CExpr(e1) -> env1 @@ -197,16 +185,13 @@ object Transformer { val (nbody, env1) = env.withScopeAndRet(rewriteC(body)(_)) cb.copy(cmd = nbody) -> env1 } - } // Rewrite the command and transfer the attribute - def rewriteC(cmd: Command)(implicit env: Env): (Command, Env) = { + def rewriteC(cmd: Command)(implicit env: Env): (Command, Env) = val (nCmd, nEnv) = rewriteCBare(cmd)(env) nCmd.attributes = cmd.attributes nCmd -> nEnv - } - } /** * Partial transformer defines helper functions for writing down @@ -231,7 +216,7 @@ object Transformer { * executes myRewriteE first and if there are no matching cases, falls * back to partialRewriteE which has the default traversal behavior. */ - abstract class PartialTransformer extends Transformer { + abstract class PartialTransformer extends Transformer: private val partialRewriteE: PF[(Expr, Env), (Expr, Env)] = asPartial(super.rewriteE(_: Expr)(_: Env)) @@ -242,19 +227,16 @@ object Transformer { // pattern. def mergeRewriteE( myRewriteE: PF[(Expr, Env), (Expr, Env)] - ): PF[(Expr, Env), (Expr, Env)] = { + ): PF[(Expr, Env), (Expr, Env)] = myRewriteE.orElse(partialRewriteE) - } def mergeRewriteC( myRewriteC: PF[(Command, Env), (Command, Env)] - ): PF[(Command, Env), (Command, Env)] = { + ): PF[(Command, Env), (Command, Env)] = val func = (cmd: Command, env: Env) => { transferPos(cmd, myRewriteC.orElse(partialRewriteC))(env) } asPartial(func(_: Command, _: Env)) - } - } /** * Transformer that adds type annotations to newly created AST nodes @@ -264,7 +246,7 @@ object Transformer { * environment returned from the TypeChecker and use it to add type * annotations to every AST node. */ - abstract class TypedPartialTransformer extends PartialTransformer { + abstract class TypedPartialTransformer extends PartialTransformer: private val partialRewriteE: PF[(Expr, Env), (Expr, Env)] = asPartial(super.rewriteE(_: Expr)(_: Env)) @@ -274,20 +256,16 @@ object Transformer { */ def transferType(expr: Expr, f: PF[(Expr, Env), (Expr, Env)])( implicit env: Env - ): (Expr, Env) = { + ): (Expr, Env) = val (e1, env1) = f(expr, env) e1.typ = expr.typ (e1, env1) - } override def mergeRewriteE( myRewriteE: PF[(Expr, Env), (Expr, Env)] - ): PF[(Expr, Env), (Expr, Env)] = { + ): PF[(Expr, Env), (Expr, Env)] = val func = (expr: Expr, env: Env) => { transferType(expr, myRewriteE.orElse(partialRewriteE))(env) } asPartial(func(_: Expr, _: Env)) - } - } -} diff --git a/src/main/scala/passes/AddBitWidth.scala b/src/main/scala/passes/AddBitWidth.scala index 9f5c1d4f..a227d77a 100644 --- a/src/main/scala/passes/AddBitWidth.scala +++ b/src/main/scala/passes/AddBitWidth.scala @@ -11,19 +11,17 @@ import fuselang.typechecker.Subtyping // Add bitwidth information to all leaves of a binary expression by adding // case expressions. -object AddBitWidth extends TypedPartialTransformer { +object AddBitWidth extends TypedPartialTransformer: - case class ABEnv(curTyp: Option[Type]) extends ScopeManager[ABEnv] { - def merge(that: ABEnv) = { + case class ABEnv(curTyp: Option[Type]) extends ScopeManager[ABEnv]: + def merge(that: ABEnv) = assert(this == that, "Tried to merge different bitwidth envs") this - } - } type Env = ABEnv val emptyEnv = ABEnv(None) - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (e: ECast, env) => e -> env case (e @ EArrAccess(arrId, idxs), env) => { val Some(TArray(_, dims, _)) = arrId.typ : @unchecked @@ -45,11 +43,10 @@ object AddBitWidth extends TypedPartialTransformer { e.copy(idxs = nIdxs) -> env } case (e: EInt, env) => - if env.curTyp.isDefined then { + if env.curTyp.isDefined then (ECast(e, env.curTyp.get), env) - } else { + else e -> env - } case (expr @ EBinop(_: EqOp | _: CmpOp, l, r), env) => { val typ = Subtyping .joinOf(l.typ.get, r.typ.get, expr.op) @@ -60,20 +57,18 @@ object AddBitWidth extends TypedPartialTransformer { expr.copy(e1 = nl, e2 = nr) -> env } case (expr @ EBinop(_: NumOp | _: BitOp, l, r), env) => { - val nEnv = if env.curTyp.isDefined then { + val nEnv = if env.curTyp.isDefined then env - } else { + else ABEnv( Some(expr.typ.getOrThrow(PassError("Expression is missing type"))) ) - } val (nl, _) = rewriteE(l)(nEnv) val (nr, _) = rewriteE(r)(nEnv) expr.copy(e1 = nl, e2 = nr) -> env } - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = case (CUpdate(l, r), env) => { val nEnv = ABEnv( Some(l.typ.getOrThrow(PassError("LHS is missing type"))) @@ -89,24 +84,20 @@ object AddBitWidth extends TypedPartialTransformer { val (ne, _) = rewriteE(e)(nEnv) cmd.copy(e = Some(ne)) -> env } - } override def transferType(expr: Expr, f: PF[(Expr, Env), (Expr, Env)])( implicit env: Env - ): (Expr, Env) = { + ): (Expr, Env) = val (e1, env1) = f(expr, env) - val nTyp = e1 match { + val nTyp = e1 match case ECast(_, t) => { Some(t) } case _ => expr.typ - } e1.typ = nTyp (e1, env1) - } override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) -} diff --git a/src/main/scala/passes/BoundsCheck.scala b/src/main/scala/passes/BoundsCheck.scala index 528c034f..bbd3cb09 100644 --- a/src/main/scala/passes/BoundsCheck.scala +++ b/src/main/scala/passes/BoundsCheck.scala @@ -12,11 +12,11 @@ import Logger._ import Checker._ import EnvHelpers._ -object BoundsChecker { +object BoundsChecker: def check(p: Prog) = BCheck.check(p) - private case object BCheck extends PartialChecker { + private case object BCheck extends PartialChecker: type Env = UnitEnv @@ -26,19 +26,18 @@ object BoundsChecker { * Given a view with a known prefix length, check if it **might** cause an * out of bound access when accessed. */ - private def checkView(arrLen: Int, viewId: Id, view: View) = { - if view.prefix.isDefined then { + private def checkView(arrLen: Int, viewId: Id, view: View) = + if view.prefix.isDefined then val View(suf, Some(pre), _) = view : @unchecked - val (sufExpr, fac) = suf match { + val (sufExpr, fac) = suf match case Aligned(fac, e) => (e, fac) case Rotation(e) => (e, 1) - } val maxVal: BigInt = sufExpr.typ .getOrThrow(Impossible(s"$sufExpr is missing type")) - .matchOrError(viewId.pos, "view", "Integer Type") { + .matchOrError(viewId.pos, "view", "Integer Type"): case idx: TIndex => fac * idx.maxVal case TStaticInt(v) => fac * v case idx: TSizedInt => @@ -49,19 +48,15 @@ object BoundsChecker { ) ); 1 - } - if maxVal + pre > arrLen then { + if maxVal + pre > arrLen then throw IndexOutOfBounds(viewId, arrLen, maxVal + pre, viewId.pos) - } - } - } - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (EArrAccess(id, idxs), e) => { id.typ .getOrThrow(Impossible(s"$id missing type in $e")) - .matchOrError(id.pos, "array access", s"array type") { + .matchOrError(id.pos, "array access", s"array type"): case TArray(_, dims, _) => idxs .map(idx => idx -> idx.typ) @@ -89,16 +84,14 @@ object BoundsChecker { throw UnexpectedType(id.pos, "array access", s"[$t]", t) }) }) - } e } - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (c @ CView(viewId, arrId, views), e) => { val typ = arrId.typ.getOrThrow(Impossible(s"$arrId is missing type in $c")) - typ.matchOrError(c.pos, "view", "array type") { + typ.matchOrError(c.pos, "view", "array type"): case TArray(_, dims, _) => views .zip(dims) @@ -106,14 +99,10 @@ object BoundsChecker { case (view, (len, _)) => checkView(len, viewId, view) }) - } e } - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - } -} diff --git a/src/main/scala/passes/DependentLoops.scala b/src/main/scala/passes/DependentLoops.scala index 2a96df95..6e359e4f 100644 --- a/src/main/scala/passes/DependentLoops.scala +++ b/src/main/scala/passes/DependentLoops.scala @@ -10,71 +10,61 @@ import Errors._ import Checker._ import EnvHelpers._ -object DependentLoops { +object DependentLoops: def check(p: Prog) = DepCheck.check(p) private case class UseEnv( used: Set[Id] - ) extends ScopeManager[UseEnv] { + ) extends ScopeManager[UseEnv]: def merge(that: UseEnv): UseEnv = UseEnv(this.used ++ that.used) - def add(id: Id): UseEnv = { + def add(id: Id): UseEnv = UseEnv(this.used + id) - } - } - private case object UseCheck extends PartialChecker { + private case object UseCheck extends PartialChecker: type Env = UseEnv val emptyEnv = UseEnv(Set()) - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (EVar(id), env) => env.add(id) - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) - } private case class DepEnv( loopVars: Set[Id], depVars: Set[Id] - ) extends ScopeManager[DepEnv] { + ) extends ScopeManager[DepEnv]: - def forgetScope(inScope: DepEnv => DepEnv): DepEnv = { + def forgetScope(inScope: DepEnv => DepEnv): DepEnv = inScope(this) this - } def merge(that: DepEnv): DepEnv = DepEnv(this.loopVars ++ that.loopVars, this.depVars ++ that.depVars) - def addLoopVar(id: Id): DepEnv = { + def addLoopVar(id: Id): DepEnv = // remove id and then add it back so that the most recent id is in the set DepEnv(this.loopVars + id, this.depVars) - } - def addDep(id: Id): DepEnv = { + def addDep(id: Id): DepEnv = DepEnv(this.loopVars, (this.depVars - id) + id) - } - def removeDep(id: Id): DepEnv = { + def removeDep(id: Id): DepEnv = DepEnv(this.loopVars, this.depVars - id) - } - def intersect(set: Set[Id]): Set[Id] = { + def intersect(set: Set[Id]): Set[Id] = (this.loopVars.union(this.depVars)).intersect(set) - } - } - private case object DepCheck extends PartialChecker { + private case object DepCheck extends PartialChecker: type Env = DepEnv val emptyEnv = DepEnv(Set(), Set()) - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (EArrAccess(id @ _, idxs), env) => { idxs.foreach(e => { val used = UseCheck.checkE(e)(UseCheck.emptyEnv) @@ -86,37 +76,30 @@ object DependentLoops { }) env } - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (CFor(range, _, par, _), env) => { - if range.u > 1 then { + if range.u > 1 then env.forgetScope(e1 => checkC(par)(e1.addLoopVar(range.iter))) - } else { + else env.forgetScope(e1 => checkC(par)(e1)) - } } case (CLet(id, _, Some(exp)), env) => { val used = UseCheck.checkE(exp)(UseCheck.emptyEnv) - if env.intersect(used.used).size != 0 then { + if env.intersect(used.used).size != 0 then env.addDep(id) - } else { + else env - } } case (CUpdate(EVar(id), rhs), env) => { val used = UseCheck.checkE(rhs)(UseCheck.emptyEnv) - if env.intersect(used.used).size != 0 then { + if env.intersect(used.used).size != 0 then env.addDep(id) - } else { + else env.removeDep(id) - } } - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - } -} diff --git a/src/main/scala/passes/HoistMemoryReads.scala b/src/main/scala/passes/HoistMemoryReads.scala index 5d98a6a8..e2c69770 100644 --- a/src/main/scala/passes/HoistMemoryReads.scala +++ b/src/main/scala/passes/HoistMemoryReads.scala @@ -7,36 +7,31 @@ import Transformer._ import EnvHelpers._ import Syntax._ -object HoistMemoryReads extends PartialTransformer { +object HoistMemoryReads extends PartialTransformer: // Env for storing the assignments for reads to replace case class BufferEnv(map: ListMap[Expr, CLet] = ListMap()) extends ScopeManager[BufferEnv] - with Tracker[Expr, CLet, BufferEnv] { - def merge(that: BufferEnv) = { + with Tracker[Expr, CLet, BufferEnv]: + def merge(that: BufferEnv) = BufferEnv(this.map ++ that.map) - } def get(key: Expr) = this.map.get(key) - def add(key: Expr, value: CLet) = { + def add(key: Expr, value: CLet) = BufferEnv(this.map + (key -> value)) - } - } type Env = BufferEnv val emptyEnv = BufferEnv() /** Helper for generating unique names. */ var idx: Map[String, Int] = Map(); - def genName(base: String): Id = { + def genName(base: String): Id = // update idx - idx.get(base) match { + idx.get(base) match case Some(n) => idx = idx + (base -> (n + 1)) case None => idx = idx + (base -> 0) - } Id(s"$base${idx(base)}") - } /** Constructs a (Command, Env) tuple from a command * and an environment containing new let bindings. */ @@ -44,22 +39,20 @@ object HoistMemoryReads extends PartialTransformer { cmd: Command, env: Env, acc: Command = CEmpty - ): Command = { - if env.map.values.isEmpty && acc == CEmpty then { + ): Command = + if env.map.values.isEmpty && acc == CEmpty then cmd - } else { + else CPar.smart(env.map.values.toSeq :+ acc :+ cmd) - } - } /** Replaces array accesses with reads from a temporary variable. * Inserts a let binding into the Env and relies on the rewriteC. * to insert this into the code. */ - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (e @ EArrAccess(id, exprs), env) => { val (nexprs, env1) = rewriteSeqWith[Expr](rewriteE(_: Expr)(_: Env))(exprs)(env) - env1.get(e) match { + env1.get(e) match case Some(let) => EVar(let.id) -> env1 case None => { val readTmp = genName(s"${id}_read") @@ -67,17 +60,14 @@ object HoistMemoryReads extends PartialTransformer { val nEnv = env.add(e, read) EVar(readTmp) -> nEnv } - } } - } /** Simple wrapper that calls rewriteC with an emptyEnv * and projects the first element of the result. */ - def rewrC(c: Command): Command = { + def rewrC(c: Command): Command = rewriteC(c)(emptyEnv)._1 - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = // Don't rewrite directly bound array reads. Rewrite access expressions // if any. case (c @ CLet(_, _, Some(arr @ EArrAccess(_, exprs))), env) => { @@ -139,11 +129,9 @@ object HoistMemoryReads extends PartialTransformer { val (rewrite, env) = rewriteE(rhs)(emptyEnv) construct(CReduce(rop, e, rewrite), env) -> emptyEnv } - } override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) -} diff --git a/src/main/scala/passes/HoistSlowBinop.scala b/src/main/scala/passes/HoistSlowBinop.scala index 059dbe62..b6e81e20 100644 --- a/src/main/scala/passes/HoistSlowBinop.scala +++ b/src/main/scala/passes/HoistSlowBinop.scala @@ -9,56 +9,47 @@ import Transformer._ import EnvHelpers._ import fuselang.backend.calyx.{Helpers => calyx}; -object HoistSlowBinop extends TypedPartialTransformer { +object HoistSlowBinop extends TypedPartialTransformer: case class ExprEnv(map: ListMap[Expr, CLet]) extends ScopeManager[ExprEnv] - with Tracker[Expr, CLet, ExprEnv] { - def merge(that: ExprEnv) = { + with Tracker[Expr, CLet, ExprEnv]: + def merge(that: ExprEnv) = ExprEnv(this.map ++ that.map) - } def get(key: Expr) = this.map.get(key) - def add(key: Expr, value: CLet) = { + def add(key: Expr, value: CLet) = ExprEnv(this.map + (key -> value)) - } - } type Env = ExprEnv val emptyEnv = ExprEnv(ListMap()) var idx: Map[String, Int] = Map(); - def genName(base: String): Id = { + def genName(base: String): Id = // update idx - idx.get(base) match { + idx.get(base) match case Some(n) => idx = idx + (base -> (n + 1)) case None => idx = idx + (base -> 0) - } Id(s"$base${idx(base)}_") - } def construct( cmd: Command, env: Env, acc: Command = CEmpty - ): (Command, Env) = { - if env.map.values.isEmpty && acc == CEmpty then { + ): (Command, Env) = + if env.map.values.isEmpty && acc == CEmpty then cmd -> emptyEnv - } else { + else CSeq.smart(env.map.values.toSeq :+ acc :+ cmd) -> emptyEnv - } - } - def binopRecur(expr: Expr, env: Env): (Expr, Env) = { - expr match { + def binopRecur(expr: Expr, env: Env): (Expr, Env) = + expr match // only recur when children are binops case EBinop(_, _, _) => rewriteE(expr)(env) case _ => (expr, env) - } - } - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (e @ EBinop(op, left, right), env) if calyx.slowBinops.contains(op.op) => { - env.get(e) match { + env.get(e) match case Some(let) => (EVar(let.id), env) case None => { val (leftRead, leftEnv) = binopRecur(left, env) @@ -70,11 +61,9 @@ object HoistSlowBinop extends TypedPartialTransformer { ) EVar(let.id) -> rightEnv.add(EBinop(op, leftRead, rightRead), let) } - } } - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = case (CLet(id, typ, Some(e)), _) => { val (expr, env) = rewriteE(e)(emptyEnv) construct(CLet(id, typ, Some(expr)), env) @@ -99,10 +88,9 @@ object HoistSlowBinop extends TypedPartialTransformer { construct(CUpdate(rewrLhs, rewrRhs), nEnv) } case (CReduce(rop, lhs, rhs), _) => { - rop.op match { + rop.op match case "*=" | "/=" => throw NotImplemented(s"Hoisting $rop.op", rop.pos) case _ => () - } val (rewrLhs, env) = rewriteE(lhs)(emptyEnv) val (rewrRhs, nEnv) = rewriteE(rhs)(env) construct(CReduce(rop, rewrLhs, rewrRhs), nEnv) @@ -115,10 +103,8 @@ object HoistSlowBinop extends TypedPartialTransformer { val (nExpr, env) = rewriteE(expr)(emptyEnv) construct(CExpr(nExpr), env) } - } override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) -} diff --git a/src/main/scala/passes/LoopCheck.scala b/src/main/scala/passes/LoopCheck.scala index 8d7c76ae..3fbf04c2 100644 --- a/src/main/scala/passes/LoopCheck.scala +++ b/src/main/scala/passes/LoopCheck.scala @@ -10,7 +10,7 @@ import CompilerError._ import Checker._ import EnvHelpers._ -object LoopChecker { +object LoopChecker: // Possible mappings for stateMap sealed trait States @@ -25,37 +25,33 @@ object LoopChecker { nameMap: ScopedMap[Id, Id] = ScopedMap(), exprMap: ScopedMap[Id, Seq[Expr]] = ScopedMap() )(implicit val res: Int = 1) - extends ScopeManager[LEnv] { + extends ScopeManager[LEnv]: // Helper functions for nameMap - def addName(vid: Id, tid: Id): LEnv = nameMap.add(vid, tid) match { + def addName(vid: Id, tid: Id): LEnv = nameMap.add(vid, tid) match case None => throw Impossible("nameMap has this view id before, redefinition") case Some(m) => LEnv(stateMap, m) - } def getName(aid: Id): Id = nameMap.get(aid).getOrElse(aid) // Helper functions for stateMap - def atDef(id: Id): LEnv = stateMap.head.get(id) match { + def atDef(id: Id): LEnv = stateMap.head.get(id) match case None | Some(DontKnow) => LEnv(stateMap.addShadow(id, Def), nameMap, exprMap) case Some(Def) => this case Some(Use) => throw LoopDepSequential(id) - } - def atDk(id: Id): LEnv = stateMap.head.get(id) match { + def atDk(id: Id): LEnv = stateMap.head.get(id) match case None | Some(Def) => LEnv(stateMap.addShadow(id, DontKnow), nameMap, exprMap) case Some(DontKnow) => this case Some(Use) => throw LoopDepSequential(id) - } - def atUse(id: Id): LEnv = stateMap.head.get(id) match { + def atUse(id: Id): LEnv = stateMap.head.get(id) match case None => LEnv(stateMap.addShadow(id, Use), nameMap, exprMap) case Some(DontKnow) => throw LoopDepSequential(id) case Some(Def) | Some(Use) => this //Use/Def -> Use don't update - } - def checkExprMap(idxs: Option[EArrAccess]): (LEnv, Boolean) = res match { + def checkExprMap(idxs: Option[EArrAccess]): (LEnv, Boolean) = res match case 1 => (this, false) case _ => idxs.map({ case EArrAccess(id, idxs) => exprMap.get(id) match { case None => (this.copy(exprMap = exprMap.add(id, idxs).get), true) @@ -64,75 +60,61 @@ object LoopChecker { idxs.zip(oldIdxs).exists({ case (nIdx, oIdx) => nIdx != oIdx }) (this, shouldCheck) }}).getOrElse((this, true)) - } //check and update the state table def updateState( id: Id, state: States, idxs: Option[EArrAccess] = None - ): LEnv = { + ): LEnv = val (env, check) = checkExprMap(idxs) - if check then { - val e2 = state match { + if check then + val e2 = state match case DontKnow => env.atDk(env.getName(id)) case Def => env.atDef(env.getName(id)) case Use => env.atUse(env.getName(id)) - } e2 - } else + else env - } // Helper functions for ScopeManager - def withScope(resources: Int)(inScope: LEnv => LEnv): LEnv = { - if resources == 1 then { - inScope(this.addNameScope) match { + def withScope(resources: Int)(inScope: LEnv => LEnv): LEnv = + if resources == 1 then + inScope(this.addNameScope) match case env: LEnv => env.endNameScope - } - } else { - inScope(this.addScope(resources)) match { + else + inScope(this.addScope(resources)) match case env: LEnv => env.endScope(resources) - } - } - } - override def withScopeAndRet[R](inScope: LEnv => (R, LEnv)): (R, LEnv) = { - inScope(this.addNameScope) match { + override def withScopeAndRet[R](inScope: LEnv => (R, LEnv)): (R, LEnv) = + inScope(this.addNameScope) match case (r, env: LEnv) => (r, env.endNameScope) - } - } // To satisfy envhelper override def withScope(inScope: LEnv => LEnv): LEnv = withScope(1)(inScope) - def addScope(resources: Int) = { + def addScope(resources: Int) = LEnv(stateMap.addScope, nameMap.addScope, exprMap.addScope)( res * resources ) - } - def endScope(resources: Int) = { + def endScope(resources: Int) = val nmap = nameMap.endScope.get._2 val emap = exprMap.endScope.get._2 val (innermap, outermap) = stateMap.endScope.get var outerenv = LEnv(outermap, nmap, emap)(res / resources) val keys = innermap.keys - for k <- keys do { + for k <- keys do outerenv = outerenv.updateState(k, innermap(k)) //inner map is a scala map - } outerenv - } - def addNameScope = { + def addNameScope = LEnv(stateMap, nameMap.addScope, exprMap) - } - def endNameScope = { + def endNameScope = val nmap = nameMap.endScope.get._2 LEnv(stateMap, nmap, exprMap) - } def mergeHelper( k: Id, v1: Option[States], v2: Option[States], env: LEnv - ): LEnv = (v1, v2) match { + ): LEnv = (v1, v2) match case (None, None) => throw Impossible("No such merging") case (None, Some(Use)) => env.copy(stateMap = env.stateMap.addShadow(k, Use)) @@ -142,44 +124,38 @@ object LoopChecker { case (Some(Def), Some(Use)) => throw LoopDepMerge(k) case (Some(DontKnow), Some(Use)) => throw LoopDepMerge(k) case (v1, v2) => if v1 == v2 then env else mergeHelper(k, v2, v1, env) - } // If statement - def merge(that: LEnv): LEnv = { + def merge(that: LEnv): LEnv = val m1 = this.stateMap val m2 = that.stateMap val result = m1.head.keys.foldLeft[LEnv](LEnv(m1, nameMap, exprMap))({ case (env, k) => mergeHelper(k, m1.get(k), m2.get(k), env) }) result - } - } - private case object LCheck extends PartialChecker { + private case object LCheck extends PartialChecker: type Env = LEnv val emptyEnv = LEnv()(1) - def myCheckLVal(e: Expr, env: Env): Env = { - e match { + def myCheckLVal(e: Expr, env: Env): Env = + e match case EVar(id) => env.updateState(id, Def) case EArrAccess(id, idxs) => env.updateState(id, Def, Some(EArrAccess(id, idxs))) case ERecAccess(rec, _) => myCheckLVal(rec, env) case _ => throw Impossible("Cannot be lhs value") - } - } - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = // by default, this is rval case (EVar(id), e) => e.updateState(id, Use) case (EArrAccess(id, idxs), e) => { e.updateState(id, Use, Some(EArrAccess(id, idxs))) } - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (CUpdate(lhs, rhs), e) => myCheckLVal(lhs, checkE(rhs)(e)) case (CReduce(_, lhs, rhs), e) => myCheckLVal(lhs, checkE(rhs)(e)) case (CLet(id, _, None), e) => myCheckLVal(EVar(id), e) @@ -207,11 +183,8 @@ object LoopChecker { e1 merge e2 }) } - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - } -} diff --git a/src/main/scala/passes/LowerForLoops.scala b/src/main/scala/passes/LowerForLoops.scala index bd5ae242..22d09e4c 100644 --- a/src/main/scala/passes/LowerForLoops.scala +++ b/src/main/scala/passes/LowerForLoops.scala @@ -10,25 +10,22 @@ import CompilerError._ /** * Lower for loops to while loops. */ -object LowerForLoops extends PartialTransformer { +object LowerForLoops extends PartialTransformer: case class ForEnv(map: Map[Id, Type]) extends ScopeManager[ForEnv] - with Tracker[Id, Type, ForEnv] { - def merge(that: ForEnv) = { + with Tracker[Id, Type, ForEnv]: + def merge(that: ForEnv) = ForEnv(this.map ++ that.map) - } def get(key: Id) = this.map.get(key) - def add(key: Id, typ: Type) = { + def add(key: Id, typ: Type) = ForEnv(this.map + (key -> typ)) - } - } type Env = ForEnv val emptyEnv = ForEnv(Map()) - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = case (cfor @ CFor(range, pipeline, par, combine), env) => { if pipeline then throw NotImplemented("Lowering pipelined for loops.") @@ -41,29 +38,26 @@ object LowerForLoops extends PartialTransformer { itVar.typ = typ // Refuse lowering without explicit type on iterator. - if typ.isDefined == false then { + if typ.isDefined == false then throw NotImplemented( "Cannot lower `for` loop without iterator type. Add explicit type for the iterator", it.pos ) - } val t = typ.get val init = CLet(it, typ, Some(ECast(if rev then EInt(e - 1) else EInt(s), t))).withPos(range) - val op = if rev then { + val op = if rev then NumOp("-", OpConstructor.sub) - } else { + else NumOp("+", OpConstructor.add) - } val upd = CUpdate(itVar.copy(), EBinop(op, itVar.copy(), ECast(EInt(1), t))).withPos(range) val cond = - if rev then { + if rev then EBinop(CmpOp(">="), itVar.copy(), ECast(EInt(s), t)) - } else { + else EBinop(CmpOp("<="), itVar.copy(), ECast(EInt(e - 1), t)) - } val nEnv = env.add(it.copy(), t) // Rewrite par and combine @@ -74,22 +68,17 @@ object LowerForLoops extends PartialTransformer { wh.attributes = cfor.attributes CBlock(CSeq.smart(Seq(init, wh))) -> nEnv } - } /** We need to change the types of iterator variables so that * they match the iterators annotated type. */ - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (v @ EVar(id), env) => { - env.get(id) match { + env.get(id) match case Some(t) => v.typ = Some(t) case None => () - } v -> env } - } - override def rewriteC(cmd: Command)(implicit env: Env) = { + override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) - } -} diff --git a/src/main/scala/passes/LowerUnroll.scala b/src/main/scala/passes/LowerUnroll.scala index e08f6555..30c34e33 100644 --- a/src/main/scala/passes/LowerUnroll.scala +++ b/src/main/scala/passes/LowerUnroll.scala @@ -12,18 +12,16 @@ import Errors._ import CodeGenHelpers._ import ScopeMap._ -object LowerUnroll extends PartialTransformer { +object LowerUnroll extends PartialTransformer: var curIdx = 0 - def genName(prefix: String): String = { + def genName(prefix: String): String = curIdx += 1; prefix + curIdx - } private def genViewAccessExpr(suffix: Suffix, idx: Expr): Expr = - suffix match { + suffix match case Aligned(factor, e2) => (EInt(factor) * e2) + idx case Rotation(e) => e + idx - } // Key for transformers: A sequence that tracks the index expression // and the bank implied by that expression. @@ -52,28 +50,25 @@ object LowerUnroll extends PartialTransformer { * * isDecl: True if this transformer is associated with a `decl` memory */ - case class ViewTransformer(t: TKey => TVal, isDecl: Boolean) { + case class ViewTransformer(t: TKey => TVal, isDecl: Boolean): // Apply method for accesses - def apply(key: TKey): TVal = { + def apply(key: TKey): TVal = t(key) - } - } - object ViewTransformer { + object ViewTransformer: /** * Define a transformer from a unbanked declaration */ - def fromDecl(@annotation.unused id: Id, @annotation.unused ta: TArray) = { + def fromDecl(@annotation.unused id: Id, @annotation.unused ta: TArray) = val t = (_: TKey) => { throw Impossible("Transformer on `decl` memory should not be called") } ViewTransformer(t, true) - } /** * Define a transformer for a base memory. */ - def fromArray(id: Id, ta: TArray) = { + def fromArray(id: Id, ta: TArray) = val t = (idxs: TKey) => { // If any of the indices are constants, scale them to the index in the // bank. @@ -116,16 +111,14 @@ object LowerUnroll extends PartialTransformer { .toMap } ViewTransformer(t, false) - } /** * Define a transformer for view built on top of a memory. */ - def fromView(dims: Seq[DimSpec], v: CView): ViewTransformer = { + def fromView(dims: Seq[DimSpec], v: CView): ViewTransformer = val t = (idxs: Seq[(Expr, Option[Int])]) => { - if idxs.length != dims.length then { + if idxs.length != dims.length then throw PassError("LowerUnroll: Incorrect access dimensions") - } // Bank and index for a dimension type PhyIdx = (Int, Expr) @@ -164,8 +157,6 @@ object LowerUnroll extends PartialTransformer { .toMap } ViewTransformer(t, false) - } - } // XXX(rachit): There are two maps that currently track some form of // rewriting: rewriteMap transforms local variables, combineReg transforms @@ -182,33 +173,28 @@ object LowerUnroll extends PartialTransformer { // DimSpec of bound arrays in this context. dimsMap: Map[Id, TArray] ) extends ScopeManager[ForEnv] - with Tracker[Id, Int, ForEnv] { - def merge(that: ForEnv) = { + with Tracker[Id, Int, ForEnv]: + def merge(that: ForEnv) = assert(this == that, "Tried to merge different unroll envs") this - } - override def withScopeAndRet[V](inScope: Env => (V, Env)) = { + override def withScopeAndRet[V](inScope: Env => (V, Env)) = val (ret, nEnv) = inScope(this.copy(rewrites = rewrites.addScope)) - val rws = nEnv.rewrites.endScope match { + val rws = nEnv.rewrites.endScope match case Some((_, rs)) => rs case None => throw Impossible("unroll env failed to end scope.") - } ret -> nEnv.copy(rewrites = rws) - } def get(key: Id) = this.idxMap.get(key) def add(key: Id, bank: Int) = this.copy(idxMap = this.idxMap + (key -> bank)) def rewriteGet(key: Id) = this.rewrites.get(key) - def rewriteAdd(k: Id, v: Id) = { - val newRewrites = this.rewrites.add(k, v) match { + def rewriteAdd(k: Id, v: Id) = + val newRewrites = this.rewrites.add(k, v) match case None => throw AlreadyBound(k) case Some(rw) => rw - } this.copy(rewrites = newRewrites) - } def viewAdd(k: Id, v: ViewTransformer) = this.copy(viewMap = viewMap + (k -> v)) @@ -222,14 +208,13 @@ object LowerUnroll extends PartialTransformer { .get(k) .getOrThrow(Impossible(s"Dimensions for `$k' not bound")) - } type Env = ForEnv val emptyEnv = ForEnv(Map(), ScopedMap(), Map(), Map()) // Given a logically banked memory type, generate several physical memories // corresponding to it. - def unbankedDecls(id: Id, ta: TArray): Seq[(Id, Type)] = { + def unbankedDecls(id: Id, ta: TArray): Seq[(Id, Type)] = val TArray(typ, dims, ports) = ta cartesianProduct(dims.map({ case (size, banks) => (0 to banks - 1).map((size / banks, _)) @@ -238,9 +223,8 @@ object LowerUnroll extends PartialTransformer { val dims = idxs.map({ case (s, _) => (s, 1) }) (Id(name), TArray(typ, dims, ports)) }) - } - override def rewriteDeclSeq(ds: Seq[Decl])(implicit env: Env) = { + override def rewriteDeclSeq(ds: Seq[Decl])(implicit env: Env) = // Memory decls cannot be banked val nEnv = ds.foldLeft(env)({ case (env, Decl(id, typ)) => @@ -260,7 +244,6 @@ object LowerUnroll extends PartialTransformer { }) ds -> nEnv - } private def getBanks(arr: Id, idxs: Seq[Expr])(implicit env: Env) = env @@ -295,24 +278,22 @@ object LowerUnroll extends PartialTransformer { * ``` * Read https://github.com/cucapra/dahlia/issues/311 for details. */ - private def mergePar(cmds: Seq[Command]): Command = { - if cmds.isEmpty then { + private def mergePar(cmds: Seq[Command]): Command = + if cmds.isEmpty then CEmpty - } else if cmds.length == 1 then { + else if cmds.length == 1 then cmds(0) - } // [{ a0 -- b0 -- ...}, {a1 -- b1 -- ..}] // => // { merge([a0, a1]) -- merge([b0, b1]) } - else if cmds.forall(_.isInstanceOf[CSeq]) then { + else if cmds.forall(_.isInstanceOf[CSeq]) then CSeq.smart( cmds.collect({ case CSeq(cs) => cs }).transpose.map(mergePar(_)) ) - } // [for (r) { b0 } combine { c0 }, for (r) { b1 } combine { c1 }, ...] // => // for (r) { merge([b0, b1, ...]) } combine { merge(c0, c1, ...) } - else if cmds.forall(_.isInstanceOf[CFor]) then { + else if cmds.forall(_.isInstanceOf[CFor]) then val fors = cmds.map[CFor](_.asInstanceOf[CFor]) val merged = fors .groupBy(f => f.range) @@ -328,11 +309,10 @@ object LowerUnroll extends PartialTransformer { } }) CPar.smart(merged.toSeq) - } // [while (c) { b0 }, while (c) { b1 }, ...] // => // while (c) { merge([b0, b1]) } - else if cmds.forall(_.isInstanceOf[CWhile]) then { + else if cmds.forall(_.isInstanceOf[CWhile]) then val whiles = cmds.map[CWhile](_.asInstanceOf[CWhile]) val merged = whiles .groupBy(w => w.cond) @@ -343,11 +323,10 @@ object LowerUnroll extends PartialTransformer { ) }) CPar.smart(merged.toSeq) - } // [if (c) { t0 } else { f0 }, if (c) { t1 } else { f1 }, ...] // => // if (c) { merge([t0, t1]) } else { merge([f0, f1]) } - else if cmds.forall(_.isInstanceOf[CIf]) then { + else if cmds.forall(_.isInstanceOf[CIf]) then val ifs = cmds.map[CIf](_.asInstanceOf[CIf]) val merged = ifs .groupBy(i => i.cond) @@ -362,25 +341,22 @@ object LowerUnroll extends PartialTransformer { } }) CPar.smart(merged.toSeq) - } // [ {a0; b0, ...}, {a1; b1, ...} ] // => // merge([a0, a1, ...]); merge([b0, b1, ...]) ... - else if cmds.forall(_.isInstanceOf[CPar]) then { + else if cmds.forall(_.isInstanceOf[CPar]) then CPar( cmds.collect({ case CPar(cs) => cs }).transpose.map(mergePar(_)) ) - } // [ { b0 }, { b1 } ...] // => // { merge([b0, b1 ...]) } - else if cmds.forall(_.isInstanceOf[CBlock]) then { + else if cmds.forall(_.isInstanceOf[CBlock]) then CBlock(mergePar(cmds.collect({ case CBlock(cmd) => cmd }))) - } // [ r0 += l1, r0 += l2, r1 += l3, r1 += l4 ... ] // => // [ r1 += l1 + l2 + ...; r1 += l3 + l4 ] - else if cmds.forall(_.isInstanceOf[CReduce]) then { + else if cmds.forall(_.isInstanceOf[CReduce]) then val creds = cmds.collect[CReduce]({ case c: CReduce => c }) val merged = creds .groupBy(c => c.lhs) @@ -412,12 +388,9 @@ object LowerUnroll extends PartialTransformer { } }) CPar.smart(merged.toSeq) - } // Just merge the statements - else { + else CPar.smart(cmds) - } - } // Generate a sequence of commands based on `allExps` private def condCmd( @@ -425,16 +398,15 @@ object LowerUnroll extends PartialTransformer { idxs: Seq[Expr], arrDims: Seq[DimSpec], newCommand: Expr => Command - )(implicit env: Env): (Command, Env) = { + )(implicit env: Env): (Command, Env) = // If we got exactly one value in TVal, that means that the returned // expression corresponds exactly to the input bank. In this case, // don't generate a condition. - if allExprs.size == 1 then { + if allExprs.size == 1 then val elems = allExprs.toArray val elem = elems(0)._2 val (nE, nEnv) = rewriteE(elem)(env) return (newCommand(nE), nEnv) - } val condAssigns = allExprs.map({ case (bankVals, accExpr) => { val cond = bankVals @@ -454,13 +426,12 @@ object LowerUnroll extends PartialTransformer { // Update the environment with all the newly generated names CPar.smart(condAssigns.toSeq) -> env - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = // Transform reads from memories case (c @ CLet(bind, typ, Some(EArrAccess(arrId, idxs))), env) => { val transformer = env.viewGet(arrId) - if transformer.isDefined && !transformer.get.isDecl then { + if transformer.isDefined && !transformer.get.isDecl then val t = transformer.get val allExprs = t(idxs.zip(getBanks(arrId, idxs)(env))) @@ -472,11 +443,11 @@ object LowerUnroll extends PartialTransformer { condCmd(allExprs, idxs, arrDims, updCmd)(env) val cmd = CPar.smart(CLet(bind, Some(typ), None), newCmd) rewriteC(cmd)(nEnv) - } else { + else // Rewrite the idxs val (nIdxs, nEnv) = rewriteESeq(idxs)(env) // If there is a rewrite for the LHS, use it - env.rewriteGet(bind) match { + env.rewriteGet(bind) match case Some(lhs) => CLet(lhs, typ, Some(EArrAccess(arrId, nIdxs.toSeq))) -> nEnv case None => { @@ -490,8 +461,6 @@ object LowerUnroll extends PartialTransformer { lhs ) } - } - } } // Rewrite banked let bound memories case (CLet(id, Some(ta: TArray), None), env) => { @@ -505,9 +474,8 @@ object LowerUnroll extends PartialTransformer { } // Handle case for initialized, unbanked memories. case (c @ CLet(_, Some(ta: TArray), _), env) => { - if ta.dims.exists({ case (_, bank) => bank > 1 }) then { + if ta.dims.exists({ case (_, bank) => bank > 1 }) then throw NotImplemented("Banked local arrays with initial values") - } c -> env } // Rewrite let bound variables if needed. @@ -516,16 +484,15 @@ object LowerUnroll extends PartialTransformer { // Don't rewrite this name if there is already a binding in // rewrite map. val rewriteVal = env.rewriteGet(id) - val (cmd, nEnv) = if rewriteVal.isDefined then { + val (cmd, nEnv) = if rewriteVal.isDefined then c.copy(e = nInit).withPos(c) -> env - } else { + else val suf = env.idxMap.toList.sortBy(_._1.v).map(_._2).mkString("_") val newName = id.copy(s"${id.v}_${suf}") c.copy(id = newName, e = nInit).withPos(c) -> env.rewriteAdd( id, newName ) - } (cmd, nEnv) } // Handle views @@ -543,7 +510,7 @@ object LowerUnroll extends PartialTransformer { (CEmpty, nEnv) } case (c @ CFor(range, _, par, combine), env) => { - if range.u == 1 then { + if range.u == 1 then val ((nPar, nComb), _) = env.withScopeAndRet(env => { val (p, e1) = rewriteC(par)(env) val (c, _) = rewriteC(combine)(e1) @@ -553,27 +520,25 @@ object LowerUnroll extends PartialTransformer { // Add bound attribute cfor.attributes = cfor.attributes + ("bound" -> (range.e - range.s)) cfor -> env - } else { + else // Create duplicates of the loop bodies and merge them together. mergePar((0 until range.u).map(idx => { val nRange = range.copy(e = range.e / range.u, u = 1) val nEnv = env.add(range.iter, idx) rewriteC(c.copy(range = nRange))(nEnv)._1 })) -> env - } } case (c @ CReduce(_, l, r), env) => { // Transform RHS. RHS cannot have an array read. - val (nR, nEnv) = r match { + val (nR, nEnv) = r match case _: EArrAccess => throw PassError("Unexpected array read on reduce RHS.") case _ => rewriteE(r)(env) - } val nCmd = c.copy(rhs = nR) - l match { + l match case EArrAccess(id, idxs) => { val transformer = env.viewGet(id) - if transformer.isDefined && !transformer.get.isDecl then { + if transformer.isDefined && !transformer.get.isDecl then val t = transformer.get val allExprs = t(idxs.zip(getBanks(id, idxs)(env))) @@ -582,24 +547,22 @@ object LowerUnroll extends PartialTransformer { val (newCmd, nEnv) = condCmd(allExprs, idxs, arrDims, (e) => nCmd.copy(lhs = e))(env) rewriteC(newCmd)(nEnv) - } else { + else nCmd.copy(lhs = rewriteE(l)(nEnv)._1) -> env - } } case _ => nCmd.copy(lhs = rewriteE(l)(nEnv)._1) -> env - } } case (CUpdate(lhs, rhs), env0) => val (nRhs, env) = rewriteE(rhs)(env0) val c = CUpdate(lhs, nRhs) - lhs match { + lhs match case e @ EVar(id) => c.copy(lhs = env.rewriteGet(id).map(nId => EVar(nId)).getOrElse(e)) -> env case e @ EArrAccess(id, idxs) => { val nIdxs = idxs.map(rewriteE(_)(env)._1) val nCmd = c.copy(lhs = e.copy(idxs = nIdxs)) val transformer = env.viewGet(id) - if transformer.isDefined && !transformer.get.isDecl then { + if transformer.isDefined && !transformer.get.isDecl then val t = transformer.get val allExprs = t(idxs.zip(getBanks(id, idxs)(env))) @@ -608,9 +571,8 @@ object LowerUnroll extends PartialTransformer { val (newCmd, nEnv) = condCmd(allExprs, idxs, arrDims, (e) => nCmd.copy(lhs = e))(env) rewriteC(newCmd)(nEnv) - } else { + else nCmd -> env - } } case (EPhysAccess(id, physIdxs)) => { val transformer = env @@ -639,25 +601,22 @@ object LowerUnroll extends PartialTransformer { rewriteC(newCmd)(nEnv) } case _ => throw Impossible("Not an LHS") - } - } - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (e @ EVar(id), env) => { val varRewrite = env.rewriteGet(id) val arrRewrite = env.viewGet(id) - if varRewrite.isDefined then { + if varRewrite.isDefined then EVar(varRewrite.get) -> env - } else if arrRewrite.isDefined && !arrRewrite.get.isDecl then { + else if arrRewrite.isDefined && !arrRewrite.get.isDecl then val TArray(_, dims, _) = env.dimsGet(id) // Construct a fake access expression val t = arrRewrite.get val map = t(dims.map(_ => EInt(0) -> None)) - if map.size != 1 then { + if map.size != 1 then throw Impossible(s"Memory parameter is banked: $id.", e.pos) - } val List((_, acc)) = map.toList - acc match { + acc match case EArrAccess(id, _) => { rewriteE(EVar(id))(env) } @@ -666,10 +625,8 @@ object LowerUnroll extends PartialTransformer { s"Memory parameter returned unexpected access expression: ${Pretty .emitExpr(e)(false)}" ) - } - } else { + else e -> env - } } // Since physical access expression imply exactly one expression, we can // rewrite them. @@ -688,7 +645,6 @@ object LowerUnroll extends PartialTransformer { val nExpr = allExprs.values.toArray rewriteE(nExpr(0))(env) } - } override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) @@ -696,4 +652,3 @@ object LowerUnroll extends PartialTransformer { override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) -} diff --git a/src/main/scala/passes/RewriteView.scala b/src/main/scala/passes/RewriteView.scala index dfe1a4fb..edecab1b 100644 --- a/src/main/scala/passes/RewriteView.scala +++ b/src/main/scala/passes/RewriteView.scala @@ -21,52 +21,47 @@ import fuselang.Utils.RichOption * For information about the monadic implementation, refer to the docs for * [[fuselang.StateHelper.State]]. */ -object RewriteView extends TypedPartialTransformer { +object RewriteView extends TypedPartialTransformer: case class ViewEnv(map: Map[Id, Seq[Expr] => Expr]) extends ScopeManager[ViewEnv] - with Tracker[Id, Seq[Expr] => Expr, ViewEnv] { - def merge(that: ViewEnv) = { + with Tracker[Id, Seq[Expr] => Expr, ViewEnv]: + def merge(that: ViewEnv) = if this.map.keys != that.map.keys then throw Impossible("Tried to merge ViewEnvs with different keys.") this - } def get(arrId: Id) = this.map.get(arrId) def add(arrId: Id, func: Seq[Expr] => Expr) = ViewEnv(this.map + (arrId -> func)) - } type Env = ViewEnv val emptyEnv = ViewEnv(Map()) private def genViewAccessExpr(view: View, idx: Expr): Expr = - view.suffix match { + view.suffix match case Aligned(factor, e2) => (EInt(factor) * e2) + idx case Rotation(e) => e + idx - } private def splitAccessExpr( i: Expr, j: Expr, arrBank: Int, viewBank: Int - ): Expr = { + ): Expr = (i * EInt(viewBank)) + ((j / EInt(viewBank)) * EInt(arrBank)) + (j % EInt(viewBank)) - } - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (acc @ EArrAccess(arrId, idxs), env) => { // Rewrite the indexing expressions val (nIdxs, nEnv) = super.rewriteESeq(idxs)(env) val rewrite = nEnv.get(arrId) - if rewrite.isDefined then { + if rewrite.isDefined then rewriteE((rewrite.get)(nIdxs.toSeq))(nEnv) - } else { + else acc.copy(idxs = nIdxs.toSeq) -> nEnv - } } case (acc @ EPhysAccess(arrId, bankIdxs), env) => { // Rewrite the indexing expressions @@ -76,14 +71,12 @@ object RewriteView extends TypedPartialTransformer { (bank, nIdx) -> env1 }: ((Int, Expr), Env) => ((Int, Expr), Env))(bankIdxs)(env) - if nEnv.get(arrId).isDefined then { + if nEnv.get(arrId).isDefined then throw NotImplemented("Rewriting physical accesses on views.") - } acc.copy(bankIdxs = nBankIdxs.toSeq) -> nEnv } - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = case (CView(id, arrId, dims), env) => { val f = (es: Seq[Expr]) => EArrAccess( @@ -98,10 +91,9 @@ object RewriteView extends TypedPartialTransformer { } case (c @ CSplit(id, arrId, factors), env) => { val arrBanks = arrId.typ - .getOrThrow(Impossible(s"$arrId is missing type in $c")) match { + .getOrThrow(Impossible(s"$arrId is missing type in $c")) match case TArray(_, dims, _) => dims.map(_._2) case t => throw Impossible(s"Array has type $t in $c") - } val f = (es: Seq[Expr]) => { val it = es.iterator // For each dimension, if it was split by more than 1, group the next @@ -120,11 +112,9 @@ object RewriteView extends TypedPartialTransformer { } (CEmpty, env.add(id, f)) } - } // Compose custom traversal with parent's generic traversal. override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) -} diff --git a/src/main/scala/passes/Sequentialize.scala b/src/main/scala/passes/Sequentialize.scala index 0386d527..be868b9e 100644 --- a/src/main/scala/passes/Sequentialize.scala +++ b/src/main/scala/passes/Sequentialize.scala @@ -8,11 +8,11 @@ import EnvHelpers._ import Syntax._ import CompilerError._ -object Sequentialize extends PartialTransformer { +object Sequentialize extends PartialTransformer: case class SeqEnv(uses: Set[Id], defines: Set[Id], useLHS: Boolean) - extends ScopeManager[SeqEnv] { - def merge(that: SeqEnv) = { + extends ScopeManager[SeqEnv]: + def merge(that: SeqEnv) = assert( this.useLHS == that.useLHS, "Attempting to merge environment with different useLHS" @@ -22,19 +22,17 @@ object Sequentialize extends PartialTransformer { this.defines union that.defines, this.useLHS ) - } def addUse(x: Id) = this.copy(uses = this.uses + x) def addDefine(x: Id) = this.copy(defines = this.defines + x) def setUseLHS(useLHS: Boolean) = this.copy(useLHS = useLHS) - } type Env = SeqEnv val emptyEnv = SeqEnv(Set(), Set(), false) - def myRewriteE: PF[(Expr, Env), (Expr, Env)] = { + def myRewriteE: PF[(Expr, Env), (Expr, Env)] = case (e @ EVar(id), env) => e -> env.addUse(id) case (e @ EArrAccess(id, idxs), env) => { val (nIdxs, e1) = rewriteESeq(idxs)(env) @@ -42,10 +40,9 @@ object Sequentialize extends PartialTransformer { } case (e: EPhysAccess, _) => throw NotImplemented("Physical accesses in sequentialize", e.pos) - } override def rewriteLVal(e: Expr)(implicit env: SeqEnv): (Expr, SeqEnv) = - e match { + e match case EVar(id) => { val env1 = if env.useLHS then env.addUse(id) else env e -> env1.addDefine(id) @@ -59,9 +56,8 @@ object Sequentialize extends PartialTransformer { throw NotImplemented("Physical accesses in sequentialize", e.pos) case e => throw Impossible(s"Not an LVal: ${Pretty.emitExpr(e)(false).pretty}") - } - def myRewriteC: PF[(Command, Env), (Command, Env)] = { + def myRewriteC: PF[(Command, Env), (Command, Env)] = case (CUpdate(lhs, rhs), env) => { val (nRhs, e1) = rewriteE(rhs)(env) val (nLhs, e2) = rewriteLVal(lhs)(e1) @@ -83,7 +79,7 @@ object Sequentialize extends PartialTransformer { var curUses: SetM[Id] = SetM() val newSeq: Buffer[Buffer[Command]] = Buffer(Buffer()) - for cmd <- cmds do { + for cmd <- cmds do val (nCmd, e1) = rewriteC(cmd)(emptyEnv) /* System.err.println(Pretty.emitCmd(cmd)(false).pretty) System.err.println(s""" @@ -99,29 +95,25 @@ object Sequentialize extends PartialTransformer { // If there are no conflicts, add this to the current parallel // block. if curDefines.intersect(e1.uses).isEmpty && - curUses.intersect(e1.defines).isEmpty then { + curUses.intersect(e1.defines).isEmpty then newSeq.last += nCmd - } else { + else curUses = SetM() curDefines = SetM() newSeq += Buffer(nCmd) - } curUses ++= e1.uses curDefines ++= e1.defines allDefines ++= e1.defines allUses ++= e1.uses - } // Add all the uses and defines from this loop into the summary. val allEnv = SeqEnv(allUses.toSet, allDefines.toSet, false).merge(env) CSeq.smart(newSeq.map(ps => CPar.smart(ps.toSeq)).toSeq) -> allEnv } - } override def rewriteC(cmd: Command)(implicit env: Env) = mergeRewriteC(myRewriteC)(cmd, env) // No need to traverse expressions override def rewriteE(expr: Expr)(implicit env: Env) = mergeRewriteE(myRewriteE)(expr, env) -} diff --git a/src/main/scala/passes/WellFormedCheck.scala b/src/main/scala/passes/WellFormedCheck.scala index 7e4ae35d..a096d020 100644 --- a/src/main/scala/passes/WellFormedCheck.scala +++ b/src/main/scala/passes/WellFormedCheck.scala @@ -10,7 +10,7 @@ import Errors._ import Checker._ import EnvHelpers._ -object WellFormedChecker { +object WellFormedChecker: def check(p: Prog) = WFCheck.check(p) @@ -19,7 +19,7 @@ object WellFormedChecker { insideUnroll: Boolean = false, insideFunc: Boolean = false ) extends ScopeManager[WFEnv] - with Tracker[Id, FuncDef, WFEnv] { + with Tracker[Id, FuncDef, WFEnv]: def merge(that: WFEnv): WFEnv = this override def add(k: Id, v: FuncDef): WFEnv = @@ -31,36 +31,32 @@ object WellFormedChecker { override def get(k: Id): Option[FuncDef] = this.map.get(k) - def canHaveFunctionInUnroll(k: Id): Boolean = { - this.get(k) match { + def canHaveFunctionInUnroll(k: Id): Boolean = + this.get(k) match case Some(FuncDef(_, args, _, _)) => - if this.insideUnroll then { + if this.insideUnroll then args.foldLeft(true)({ (r, arg) => arg.typ match { case TArray(_, _, _) => false case _ => r } }) - } else + else true case None => true // This is supposed to be unreachable - } - } - } - private case object WFCheck extends PartialChecker { + private case object WFCheck extends PartialChecker: type Env = WFEnv val emptyEnv = WFEnv() - override def checkDef(defi: Definition)(implicit env: Env) = defi match { + override def checkDef(defi: Definition)(implicit env: Env) = defi match case fndef @ FuncDef(id, _, _, bodyOpt) => val nenv = env.add(id, fndef) bodyOpt.map(checkC(_)(nenv.copy(insideFunc = true))).getOrElse(nenv) case _: RecordDef => env - } - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (expr: EPhysAccess, _) => throw CompilerError.PassError( "Physical accesses should be removed up the lowering passes.", @@ -74,9 +70,8 @@ object WellFormedChecker { assertOrThrow(env.canHaveFunctionInUnroll(id) == true, FuncInUnroll(expr.pos)) env } - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (cmd @ CReduce(op, l, r), e) => { assertOrThrow(e.insideUnroll == false, ReduceInsideUnroll(op, cmd.pos)) checkE(r)(checkE(l)(e)) @@ -84,14 +79,13 @@ object WellFormedChecker { case (l @ CLet(id, typ, Some(EArrLiteral(_))), e) => { val expTyp = typ .getOrThrow(ExplicitTypeMissing(l.pos, "Array literal", id)) - expTyp match { + expTyp match case TArray(_, dims, _) => assertOrThrow( dims.length == 1, Unsupported(l.pos, "Multidimensional array literals") ) case _ => () - } e } case (l @ CLet(id, typ, Some(ERecLiteral(_))), e) => { @@ -122,11 +116,8 @@ object WellFormedChecker { assertOrThrow(env.insideFunc, ReturnNotInFunc(cmd.pos)) env } - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - } -} diff --git a/src/main/scala/typechecker/AffineCheck.scala b/src/main/scala/typechecker/AffineCheck.scala index 7d910aa5..53b464de 100644 --- a/src/main/scala/typechecker/AffineCheck.scala +++ b/src/main/scala/typechecker/AffineCheck.scala @@ -72,27 +72,25 @@ import CompilerError._ * */ -object AffineChecker { +object AffineChecker: def check(p: Prog) = AffineChecker.check(p) - private case object AffineChecker extends PartialChecker { + private case object AffineChecker extends PartialChecker: type Env = AffineEnv.Environment val emptyEnv = AffineEnv.emptyEnv - override def check(p: Prog): Unit = { + override def check(p: Prog): Unit = val Prog(_, defs, _, decls, cmd) = p val topFunc = FuncDef(Id(""), decls, TVoid(), Some(cmd)) - (defs ++ Seq(topFunc)).foldLeft(emptyEnv) { + (defs ++ Seq(topFunc)).foldLeft(emptyEnv): case (e, d) => checkDef(d)(e) - } () - } - override def checkDef(defi: Definition)(implicit env: Env) = defi match { + override def checkDef(defi: Definition)(implicit env: Env) = defi match case FuncDef(_, args, _, bodyOpt) => { val (env2, _, _) = env.withScope(1) { newScope => // Add physical resources corresponding to array decls @@ -110,18 +108,16 @@ object AffineChecker { env2 } case _: RecordDef => env - } /** * Add physical resources and default accessor gadget corresponding to a new * array. This is used for `decl` with arrays and new `let` bound arrays. */ - private def addPhysicalResource(id: Id, typ: TArray, env: Env) = { + private def addPhysicalResource(id: Id, typ: TArray, env: Env) = val banks = typ.dims.map(_._2) env .addResource(id, ArrayInfo(id, banks, typ.ports)) .add(id, MultiDimGadget(ResourceGadget(id, banks), typ.dims)) - } /** * Generate a ConsumeList corresponding to the underlying memory type and @@ -129,7 +125,7 @@ object AffineChecker { */ private def getConsumeList(idxs: Seq[Expr], dims: Seq[DimSpec])( implicit arrId: Id - ) = { + ) = val (bres, consume) = idxs.zipWithIndex.foldLeft( (1, IndexedSeq[Seq[Int]]()) @@ -158,7 +154,6 @@ object AffineChecker { // Reverse the types list to match the order with idxs. (bres, consume.reverse) - } /** * Checks a given simple view and returns the dimensions for the view, @@ -179,11 +174,11 @@ object AffineChecker { (newBank, (pre.getOrElse(len) -> newBank)) }*/ - override def checkLVal(e: Expr)(implicit env: Env) = e match { + override def checkLVal(e: Expr)(implicit env: Env) = e match case acc @ EArrAccess(id, idxs) => { // This only triggers for l-values. val TArray(_, dims, _) = id.typ.get : @unchecked - acc.consumable match { + acc.consumable match case Some(Annotations.ShouldConsume) => { val (bres, consumeList) = getConsumeList(idxs, dims)(id) // Check if the accessors generated enough copies for the context. @@ -198,12 +193,10 @@ object AffineChecker { } case con => throw Impossible(s"$acc in write position has $con annotation") - } } case _ => checkE(e) - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (CLet(id, Some(ta @ TArray(_, _, _)), _), env) => { addPhysicalResource(id, ta, env) } @@ -255,9 +248,8 @@ object AffineChecker { val TArray(_, vdims, _) = id.typ.get : @unchecked env.add(id, splitGadget(env(arrId), adims, vdims)) } - } - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (EApp(_, args), env) => { args.foldLeft(env)({ case (e, argExpr) => { @@ -278,7 +270,7 @@ object AffineChecker { } case (expr @ EArrAccess(id, idxs), env) => { val TArray(_, dims, _) = id.typ.get : @unchecked - expr.consumable match { + expr.consumable match case None => throw Impossible( s"$expr in read position has no consumable annotation" @@ -289,16 +281,12 @@ object AffineChecker { // Consume the resources required by this gadget. env.consumeWithGadget(id, consumeList)(expr.pos) } - } } case (_: EPhysAccess, _) => { throw NotImplemented("Affine checking for physical accesses.") } - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) override def checkC(cmd: Command)(implicit env: Env) = mergeCheckC(myCheckC)(cmd, env) - } -} diff --git a/src/main/scala/typechecker/AffineEnv.scala b/src/main/scala/typechecker/AffineEnv.scala index d73fd3ae..0f6a68c4 100644 --- a/src/main/scala/typechecker/AffineEnv.scala +++ b/src/main/scala/typechecker/AffineEnv.scala @@ -13,7 +13,7 @@ import Errors._ import CompilerError._ import EnvHelpers._ -object AffineEnv { +object AffineEnv: val emptyEnv: Environment = Env()(1) @@ -25,7 +25,7 @@ object AffineEnv { */ sealed trait Environment extends ScopeManager[Environment] - with Tracker[Id, Gadget, Environment] { + with Tracker[Id, Gadget, Environment]: /** * Associate a gadget to the name of the physical resource it consumes. @@ -100,23 +100,21 @@ object AffineEnv { implicit pos: Position ): Environment - } private case class Env( phyRes: ScopedMap[Id, ArrayInfo] = ScopedMap(), gadgetMap: ScopedMap[Id, Gadget] = ScopedMap() )(implicit val res: Int) - extends Environment { + extends Environment: def consumeWithGadget(gadget: Id, consumeList: ConsumeList)( implicit pos: Position - ) = { + ) = val (resName, resources, trace) = this(gadget).getSummary(consumeList) implicit val t = trace this.consumeResource(resName, resources) - } - override def toString = { + override def toString = val lst = for (ps, gs) <- phyRes.iterator.zip(gadgetMap.iterator) yield ( ps.map({ case (k, v) => s"$k -> $v" }).mkString(", "), @@ -124,7 +122,6 @@ object AffineEnv { ) lst.mkString(" ==> ") - } /** Tracking bound gadgets */ def add(id: Id, resource: Gadget) = @@ -135,15 +132,14 @@ object AffineEnv { def get(id: Id) = gadgetMap.get(id) /** Managing physical resources */ - def addResource(id: Id, info: ArrayInfo) = { + def addResource(id: Id, info: ArrayInfo) = val pRes = phyRes.add(id, info).getOrThrow(AlreadyBound(id)) this.copy(phyRes = pRes) - } def consumeResource( name: Id, resources: Seq[Int] - )(implicit pos: Position, trace: Seq[String]): Environment = { - phyRes.get(name) match { + )(implicit pos: Position, trace: Seq[String]): Environment = + phyRes.get(name) match case None => throw Impossible(s"No physical resource named $name.") case Some(info) => { @@ -151,32 +147,28 @@ object AffineEnv { phyRes.update(name, info.consumeResources(resources)) ) } - } - } /** Helper functions for Mergable[Env] */ - def merge(env: Environment): Environment = env match { + def merge(env: Environment): Environment = env match case next: Env => val (oldRes, nextRes) = (this.phyRes.keys, next.phyRes.keys) val (oldGads, nextGads) = (this.gadgetMap.keys, next.gadgetMap.keys) // The next environment should bind all resources in this env. - if oldRes.subsetOf(nextRes) == false then { + if oldRes.subsetOf(nextRes) == false then throw Impossible( "New environment is missing resources bound in old env." + s"\n\nOld Env: ${oldRes}" + s"\n\nNew Env: ${nextRes}" + s"\n\nMissing: ${oldRes diff nextRes}" ) - } - if oldRes.subsetOf(nextRes) == false then { + if oldRes.subsetOf(nextRes) == false then throw Impossible( "New environment is missing gadgets bound in old env." + s"\n\nOld Env: ${oldGads}" + s"\n\nNew Env: ${nextGads}" + s"\n\nMissing: ${oldGads diff nextGads}.\n" ) - } /** * For each bound id, set consumed banks to the union of consumed bank @@ -188,39 +180,29 @@ object AffineEnv { env.phyRes.update(id, env.phyRes(id) merge this.phyRes(id)) ) }) - } /** Helper functions for ScopeManager */ - def addScope(resources: Int) = { + def addScope(resources: Int) = Env(phyRes.addScope, gadgetMap.addScope)(res * resources) - } - def endScope(resources: Int) = { + def endScope(resources: Int) = val scopes = for (pDefs, pMap) <- phyRes.endScope (gDefs, gMap) <- gadgetMap.endScope yield (Env(pMap, gMap)(res / resources), pDefs, gDefs) scopes.getOrThrow(Impossible("Removed topmost scope")) - } def withScope( resources: Int - )(inScope: Environment => Environment) = { - inScope(this.addScope(resources)) match { + )(inScope: Environment => Environment) = + inScope(this.addScope(resources)) match case env: Env => env.endScope(resources) - } - } override def withScopeAndRet[R]( inScope: Environment => (R, Environment) - ) = { - inScope(this.addScope(1)) match { + ) = + inScope(this.addScope(1)) match case (r, env: Env) => (r, env.endScope(1)._1) - } - } - override def withScope(inScope: Environment => Environment): Environment = { + override def withScope(inScope: Environment => Environment): Environment = this.withScope(1)(inScope)._1 - } val getResources = res - } -} diff --git a/src/main/scala/typechecker/CapabilityChecker.scala b/src/main/scala/typechecker/CapabilityChecker.scala index 28d58fb7..d69ba974 100644 --- a/src/main/scala/typechecker/CapabilityChecker.scala +++ b/src/main/scala/typechecker/CapabilityChecker.scala @@ -26,27 +26,24 @@ object CapabilityChecker: * - If there are no capabilities or a write capability, add a consume * annotation. */ - def myCheckE: PF[(Expr, Env), Env] = { + def myCheckE: PF[(Expr, Env), Env] = case (acc @ EArrAccess(_, idxs), env) => { - val (nEnv, consumableAnn, cap) = env.get(acc) match { + val (nEnv, consumableAnn, cap) = env.get(acc) match case Some(Read) => (env, SkipConsume, Read) case Some(Write) | None => (checkESeq(idxs)(env), ShouldConsume, Read) - } acc.consumable = Some(consumableAnn) nEnv.add(acc, cap) } case (_: EPhysAccess, _) => { throw NotImplemented("Capability checking for physical accesses.") } - } - def myCheckC: PF[(Command, Env), Env] = { + def myCheckC: PF[(Command, Env), Env] = case (CSeq(cmds), env) => { // Check all seq commands under the same environment cmds.foreach(c => checkC(c)(env)); env } - } /** * Check an array write. If there is already a write capability, error. @@ -55,18 +52,16 @@ object CapabilityChecker: * This doesn't need to be partial function since it deals with all * cases in checkLVal. */ - override def checkLVal(e: Expr)(implicit env: Env) = e match { + override def checkLVal(e: Expr)(implicit env: Env) = e match case acc @ EArrAccess(_, idxs) => { - val (nEnv, consumableAnn, cap) = env.get(e) match { + val (nEnv, consumableAnn, cap) = env.get(e) match case Some(Write) => throw AlreadyWrite(e) case Some(Read) | None => (checkESeq(idxs), ShouldConsume, Write) - } acc.consumable = Some(consumableAnn) nEnv.add(e, cap) } case _ => checkE(e) - } override def checkE(expr: Expr)(implicit env: Env) = mergeCheckE(myCheckE)(expr, env) diff --git a/src/main/scala/typechecker/CapabilityEnv.scala b/src/main/scala/typechecker/CapabilityEnv.scala index b2fc841b..e2408494 100644 --- a/src/main/scala/typechecker/CapabilityEnv.scala +++ b/src/main/scala/typechecker/CapabilityEnv.scala @@ -7,7 +7,7 @@ import EnvHelpers._ import Syntax._ import CompilerError._ -object CapabilityEnv { +object CapabilityEnv: val emptyEnv: CapabilityEnv = Env() @@ -18,50 +18,41 @@ object CapabilityEnv { private case class Env( readSet: ScopedSet[Expr] = ScopedSet(), writeSet: ScopedSet[Expr] = ScopedSet() - ) extends CapabilityEnv { + ) extends CapabilityEnv: def get(e: Expr) = if readSet.contains(e) then Some(Read) else if writeSet.contains(e) then Some(Write) else None - def add(e: Expr, cap: Capability) = cap match { + def add(e: Expr, cap: Capability) = cap match case Read => this.copy(readSet = readSet.add(e)) case Write => this.copy(writeSet = writeSet.add(e)) - } - def endScope = { + def endScope = val scopes = for (_, rSet) <- readSet.endScope; (_, wSet) <- writeSet.endScope yield this.copy(readSet = rSet, writeSet = wSet) scopes.getOrThrow(Impossible("Removed topmost scope")) - } override def withScopeAndRet[R]( inScope: CapabilityEnv => (R, CapabilityEnv) - ): (R, CapabilityEnv) = { + ): (R, CapabilityEnv) = inScope( this.copy(readSet = readSet.addScope, writeSet = writeSet.addScope) - ) match { + ) match case (r, that: Env) => (r, that.endScope) - } - } override def withScope( inScope: CapabilityEnv => CapabilityEnv - ): CapabilityEnv = { + ): CapabilityEnv = inScope( this.copy(readSet = readSet.addScope, writeSet = writeSet.addScope) - ) match { + ) match case that: Env => that.endScope - } - } - def merge(that: CapabilityEnv) = { + def merge(that: CapabilityEnv) = assert(this == that, "Tried to merge different capability envs") this - } - } -} diff --git a/src/main/scala/typechecker/Gadget.scala b/src/main/scala/typechecker/Gadget.scala index c81df434..476a97e6 100644 --- a/src/main/scala/typechecker/Gadget.scala +++ b/src/main/scala/typechecker/Gadget.scala @@ -14,28 +14,25 @@ import CompilerError._ * asks for resources from a gadget, the gadget determines which resources * it requires from the underlying gadget. */ -object Gadgets { +object Gadgets: type ConsumeList = Seq[Seq[Int]] - def clString(cl: Seq[Seq[Int]]): String = { + def clString(cl: Seq[Seq[Int]]): String = cl.map(els => els.mkString("{", ",", "}")).mkString("[", "][", "]") - } - trait Gadget { + trait Gadget: // Return the name of the resource, the list of banks to be consumed, // and a trace of transformations done on the original resource. def getSummary(consume: ConsumeList): (Id, Seq[Int], Seq[String]) - } - case class ResourceGadget(resource: Id, banks: Seq[Int]) extends Gadget { - private def cross[A](acc: Seq[Seq[A]], l: Seq[A]): Seq[Seq[A]] = { + case class ResourceGadget(resource: Id, banks: Seq[Int]) extends Gadget: + private def cross[A](acc: Seq[Seq[A]], l: Seq[A]): Seq[Seq[A]] = for a <- acc; el <- l yield a :+ el - } override def toString = resource.toString - private def hyperBankToBank(hyperBanks: Seq[Int]) = { + private def hyperBankToBank(hyperBanks: Seq[Int]) = if hyperBanks.length != banks.length then throw Impossible("hyperbank size is different from original banking") @@ -44,13 +41,12 @@ object Gadgets { .foldLeft(0)({ case (acc, (hb, b)) => b * acc + hb }) - } /** * The root for all gadgets. Maps a multidimensional consume list to * corresponding one dimensional banks. */ - def getSummary(consume: ConsumeList) = { + def getSummary(consume: ConsumeList) = // Transform consumelist into a Seq[Seq[A]] where the inner list // represents a sequence of banks for the dimension. These are // latter transformed to 1D banks. @@ -62,36 +58,30 @@ object Gadgets { val outRes = hyperBanks.map(hyperBankToBank) (resource, outRes, Seq(clString(Seq(outRes)))) - } - } case class MultiDimGadget(underlying: Gadget, dim: Seq[DimSpec]) - extends Gadget { + extends Gadget: /** * A base physical memory with `k` banks redirects access from bank `b` to * to `b % k`. */ - def getSummary(consume: ConsumeList) = { + def getSummary(consume: ConsumeList) = val resourceTransform = consume .zip(dim) .map({ case (resources, (_, banks)) => resources.map(_ % banks) }) val (res, sum, trace) = underlying.getSummary(resourceTransform) (res, sum, clString(resourceTransform) +: trace) - } - } case class ViewGadget( underlying: Gadget, transformer: ConsumeList => ConsumeList - ) extends Gadget { - def getSummary(consume: ConsumeList) = { + ) extends Gadget: + def getSummary(consume: ConsumeList) = val outRes = transformer(consume) val (res, sum, trace) = underlying.getSummary(outRes) (res, sum, clString(outRes) +: trace) - } - } /** * Creates a conservative simple view that fully consumes the array @@ -103,7 +93,7 @@ object Gadgets { underlying: Gadget, shrinks: Seq[Int], arrDims: Seq[DimSpec] - ): ViewGadget = { + ): ViewGadget = // Multiply the resource requirements by the origBanking / shrink. // This simulates that shrinking "connects" multiple banks into a // single one. @@ -135,7 +125,6 @@ object Gadgets { }) } ViewGadget(underlying, transformer) - } /** * Creates logic for a split view. A split view always has an even number @@ -148,7 +137,5 @@ object Gadgets { underlying: Gadget, arrayDims: Seq[DimSpec], @deprecated("Not used", "0.0.1") splitDims: Seq[DimSpec] - ): ViewGadget = { + ): ViewGadget = viewGadget(underlying, arrayDims.map(_._2), arrayDims) - } -} diff --git a/src/main/scala/typechecker/Info.scala b/src/main/scala/typechecker/Info.scala index bbaa30df..3cc6959c 100644 --- a/src/main/scala/typechecker/Info.scala +++ b/src/main/scala/typechecker/Info.scala @@ -9,7 +9,7 @@ import Syntax._ import Errors._ import MultiSet._ -object Info { +object Info: case class ArrayInfo( id: Id, @@ -19,32 +19,31 @@ object Info { remBanks: MultiSet[Int], // Source code locations that consumed a bank. conLocs: Map[Int, MultiSet[Position]] = Map() - ) { + ): override def toString = remBanks.toString def consumeResources( resources: Seq[Int] - )(implicit pos: Position, trace: Seq[String]) = { + )(implicit pos: Position, trace: Seq[String]) = // Make sure banks exist. val missingBank = resources.find(!avBanks.containsAtLeast(_, 1)) assertOrThrow(missingBank.isEmpty, UnknownBank(id, missingBank.head)) val newConLocs = - conLocs ++ resources.foldLeft(conLocs) { + conLocs ++ resources.foldLeft(conLocs): case (newConLocs, res) => newConLocs + (res -> newConLocs .getOrElse(res, emptyMultiSet[Position]()) .add(pos)) - } // Calculate multi-set difference b/w required resource and available // resources. val resourceMS = fromSeq(resources) val afterConsume = remBanks.diff(resourceMS) val hasRequired = afterConsume.forall({ case (_, v) => v >= 0 }) - if hasRequired == false then { + if hasRequired == false then val bank = afterConsume.find({ case (_, v) => v < 0 }).get._1 throw AlreadyConsumed( id, @@ -52,21 +51,17 @@ object Info { avBanks.getCount(bank), newConLocs(bank) ) - } this.copy(remBanks = afterConsume, conLocs = newConLocs) - } // Return a copy of the physical resource with all the resources available. def toFresh = this.copy(remBanks = avBanks, conLocs = Map()) - def merge(that: ArrayInfo) = { + def merge(that: ArrayInfo) = val remBanks = this.remBanks.zipWith(that.remBanks, Math.min) this.copy(remBanks = remBanks, conLocs = this.conLocs ++ that.conLocs) - } - } - object ArrayInfo { + object ArrayInfo: private def cross[A](acc: Seq[Seq[A]], l: Seq[A]): Seq[Seq[A]] = for a <- acc; el <- l yield a :+ el @@ -77,7 +72,7 @@ object Info { case (acc, (hb, maxBank)) => acc * maxBank + hb }) - def apply(id: Id, banks: Iterable[Int], ports: Int): ArrayInfo = { + def apply(id: Id, banks: Iterable[Int], ports: Int): ArrayInfo = val startResources: MultiSet[Int] = fromSeq( banks .map(b => List.tabulate(b)(identity)) @@ -90,6 +85,3 @@ object Info { ) ArrayInfo(id, startResources, startResources) - } - } -} diff --git a/src/main/scala/typechecker/Subtyping.scala b/src/main/scala/typechecker/Subtyping.scala index 6f0b1f35..3e53f85f 100644 --- a/src/main/scala/typechecker/Subtyping.scala +++ b/src/main/scala/typechecker/Subtyping.scala @@ -40,17 +40,16 @@ import fuselang.Utils.bitsNeeded * if there is no type T' statisfying T' < T and is an upper bound for t1 and * t2. */ -object Subtyping { - def areEqual(t1: Type, t2: Type) = (t1, t2) match { +object Subtyping: + def areEqual(t1: Type, t2: Type) = (t1, t2) match case (TStaticInt(v1), TStaticInt(v2)) => v1 == v2 case (_: TIndex, _: TIndex) => true case (_: TFloat, _: TFloat) => true case (TAlias(r1), t) => r1.toString == t.toString case (t, TAlias(r1)) => t.toString == r1.toString case _ => t1 == t2 - } - def isSubtype(sub: Type, sup: Type): Boolean = (sub, sup) match { + def isSubtype(sub: Type, sup: Type): Boolean = (sub, sup) match case (TSizedInt(v1, un1), TSizedInt(v2, un2)) => un1 == un2 && v1 <= v2 case (TStaticInt(v1), TSizedInt(v2, un2)) => ((v1 < 0 && un2 == false) || (v1 >= 0)) && bitsNeeded(v1) <= v2 @@ -70,20 +69,18 @@ object Subtyping { case (TFixed(t1, i1, un1), TFixed(t2, i2, un2)) => (un1 == un2 && i1 <= i2 && (t1 - i1) <= (t2 - i2)) case _ => areEqual(sub, sup) - } private def joinOfHelper(t1: Type, t2: Type, op: BOp): Option[Type] = - (t1, t2) match { + (t1, t2) match //XXX(Zhijing): what happens for multiplication? Overflow? case (TStaticInt(v1), TStaticInt(v2)) => - op.toFun match { + op.toFun match case Some(fun) => Some(TStaticInt(fun(v1.toDouble, v2.toDouble).toInt)) case None => Some(TSizedInt(max(bitsNeeded(v1), bitsNeeded(v2)), false)) - } case (TRational(v1), TRational(v2)) => - op.toFun match { + op.toFun match //XXX(Zhijing):deprecated case Some(fun) => Some(TRational(fun(v1.toDouble, v2.toDouble).toString)) @@ -91,7 +88,6 @@ object Subtyping { if bitsNeeded(v1.toDouble.toInt) > bitsNeeded(v2.toDouble.toInt) then Some(TRational(v1)) else Some(TRational(v2)) - } case (TSizedInt(s1, un1), TSizedInt(s2, un2)) => if un1 == un2 then Some(TSizedInt(max(s1, s2), un1)) else None @@ -121,19 +117,17 @@ object Subtyping { TSizedInt(max(bitsNeeded(ti1.maxVal), bitsNeeded(ti2.maxVal)), false) ) case (t1, t2) => if t1 == t2 then Some(t1) else None - } /** * Try finding the join of either ordering and use the result. */ - def joinOf(t1: Type, t2: Type, op: BOp): Option[Type] = { + def joinOf(t1: Type, t2: Type, op: BOp): Option[Type] = val j1 = joinOfHelper(t1, t2, op) if j1.isDefined then j1 else joinOfHelper(t2, t1, op) - } def safeCast(originalType: Type, castType: Type) = - (originalType, castType) match { + (originalType, castType) match case (t1: IntType, t2: TSizedInt) => isSubtype(t1, t2) case (_: TFloat, _: TSizedInt) => false case (_: TDouble, _: TSizedInt) => false @@ -144,5 +138,3 @@ object Subtyping { case (_: TRational, _: TDouble) => true case (TSizedInt(i1, un1), TFixed(_, i2, un2)) => (un1 == un2 && i1 <= i2) case (t1, t2) => areEqual(t1, t2) - } -} diff --git a/src/main/scala/typechecker/TypeCheck.scala b/src/main/scala/typechecker/TypeCheck.scala index a8b8d9c1..951446f8 100644 --- a/src/main/scala/typechecker/TypeCheck.scala +++ b/src/main/scala/typechecker/TypeCheck.scala @@ -14,30 +14,27 @@ import Logger.PositionalLoggable * Type checker implementation for Dahlia. * */ -object TypeChecker { +object TypeChecker: - def pr[A](v: A): A = { + def pr[A](v: A): A = println(v) v - } /* A program consists of a list of function or type definitions, a list of * variable declarations and then a command. We build up an environment with * all the declarations and definitions, then check the command in that environment * (`checkC`). */ - def typeCheck(p: Prog) = { + def typeCheck(p: Prog) = val Prog(includes, defs, _, decls, cmd) = p val allDefs = includes.flatMap(_.defs) ++ defs val topFunc = FuncDef(Id(""), decls, TVoid(), Some(cmd)) - (allDefs ++ List(topFunc)).foldLeft(emptyEnv) { + (allDefs ++ List(topFunc)).foldLeft(emptyEnv): case (e, d) => checkDef(d, e) - } - } - private def checkDef(defi: Definition, env: Environment) = defi match { + private def checkDef(defi: Definition, env: Environment) = defi match case FuncDef(id, args, ret, bodyOpt) => { val env2 = env.withScope { newScope => // Bind all declarations to the body. @@ -61,9 +58,8 @@ object TypeChecker { val rFields = fields.map({ case (k, t) => k -> env.resolveType(t) }) env.addType(name, TRecType(name, rFields)) } - } - private def checkB(t1: Type, t2: Type, op: BOp) = op match { + private def checkB(t1: Type, t2: Type, op: BOp) = op match case _: EqOp => { if t1.isInstanceOf[TArray] then throw UnexpectedType(op.pos, op.toString, "primitive types", t1) @@ -73,12 +69,11 @@ object TypeChecker { throw NoJoin(op.pos, op.toString, t1, t2) } case _: BoolOp => - (t1, t2) match { + (t1, t2) match case (TBool(), TBool()) => TBool() case _ => throw BinopError(op, "booleans", t1, t2) - } case _: CmpOp => - (t1, t2) match { + (t1, t2) match case (_: IntType, _: IntType) => TBool() case (_: TFloat, _: TFloat) => TBool() case (_: TDouble, _: TDouble) => TBool() @@ -95,20 +90,17 @@ object TypeChecker { t1, t2 ) - } case _: NumOp => joinOf(t1, t2, op).getOrThrow(NoJoin(op.pos, op.toString, t1, t2)) //case _:DoubleOp => // joinOf(t1, t2, op).getOrThrow(NoJoin(op.pos, op.toString, t1, t2)) case _: BitOp => - (t1, t2) match { + (t1, t2) match case (_: TSizedInt, _: IntType) => t1 case (TStaticInt(v), _: IntType) => TSizedInt(bitsNeeded(v), false) case (tidx @ TIndex(_, _), _: IntType) => TSizedInt(bitsNeeded(tidx.maxVal), false) case _ => throw BinopError(op, "integer type", t1, t2) - } - } /** * Wrapper for checkE that annotates each expression with it's full type. @@ -116,22 +108,20 @@ object TypeChecker { */ private def checkE( e: Expr - )(implicit env: Environment): (Type, Environment) = { + )(implicit env: Environment): (Type, Environment) = val (typ, nEnv) = _checkE(e) - if e.typ.isDefined && typ != e.typ.get then { + if e.typ.isDefined && typ != e.typ.get then throw Impossible( s"$e was type checked multiple times and given different types." ) - } e.typ = Some(typ) typ -> nEnv - } // Implicit parameters can be elided when a recursive call is reusing the // same env and its. See EBinop case for an example. private def _checkE( expr: Expr - )(implicit env: Environment): (Type, Environment) = expr match { + )(implicit env: Environment): (Type, Environment) = expr match case ERational(v) => TRational(v) -> env case EInt(v, _) => TStaticInt(v) -> env case EBool(_) => TBool() -> env @@ -139,11 +129,9 @@ object TypeChecker { case EArrLiteral(_) => throw NotInBinder(expr.pos, "Array Literal") case ECast(e, castType) => { val (typ, nEnv) = checkE(e) - if safeCast(typ, castType) == false then { - scribe.warn { + if safeCast(typ, castType) == false then + scribe.warn: (s"Casting $typ to $castType which may lose precision.", expr) - } - } castType -> nEnv } case EVar(id) => { @@ -157,11 +145,10 @@ object TypeChecker { checkB(t1, t2, op) -> env2 } case EApp(f, args) => - env(f) match { + env(f) match case TFun(argTypes, retType) => { - if argTypes.length != args.length then { + if argTypes.length != args.length then throw ArgLengthMismatch(expr.pos, argTypes.length, args.length) - } retType -> args .zip(argTypes) @@ -181,23 +168,19 @@ object TypeChecker { }) } case t => throw UnexpectedType(expr.pos, "application", "function", t) - } case ERecAccess(rec, field) => - checkE(rec) match { + checkE(rec) match case (TRecType(name, fields), env1) => - fields.get(field) match { + fields.get(field) match case Some(typ) => typ -> env1 case None => throw UnknownRecordField(expr.pos, name, field) - } case (t, _) => throw UnexpectedType(expr.pos, "record access", "record type", t) - } case EArrAccess(id, idxs) => - env(id).matchOrError(expr.pos, "array access", s"array type") { + env(id).matchOrError(expr.pos, "array access", s"array type"): case TArray(typ, dims, _) => { - if dims.length != idxs.length then { + if dims.length != idxs.length then throw IncorrectAccessDims(id, dims.length, idxs.length) - } idxs.foldLeft(env)((env, idx) => { val (typ, nEnv) = checkE(idx)(env) typ match { @@ -216,13 +199,11 @@ object TypeChecker { id.typ = Some(env(id)); typ -> env } - } case EPhysAccess(id, bankIdxs) => - env(id).matchOrError(expr.pos, "array access", s"array type") { + env(id).matchOrError(expr.pos, "array access", s"array type"): case TArray(typ, dims, _) => { - if dims.length != bankIdxs.length then { + if dims.length != bankIdxs.length then throw IncorrectAccessDims(id, dims.length, bankIdxs.length) - } bankIdxs.foldLeft(env)((env, bankIdx) => { val (_, idx) = bankIdx val (idxTyp, env1) = checkE(idx)(env) @@ -235,73 +216,62 @@ object TypeChecker { id.typ = Some(env(id)); typ -> env } - } - } // Check if this array dimension is well formed and return the dimension // spec for the corresponding dimension in the view. private def checkView(view: View, arrDim: DimSpec)( implicit env: Environment - ): (Environment, DimSpec) = { + ): (Environment, DimSpec) = val View(suf, prefix, shrink) = view val (len, bank) = arrDim // Shrinking factor must be a factor of banking for the dimension - if shrink.isDefined && (shrink.get > bank || bank % shrink.get != 0) then { + if shrink.isDefined && (shrink.get > bank || bank % shrink.get != 0) then throw InvalidShrinkWidth(view.pos, bank, shrink.get) - } val newBank = shrink.getOrElse(bank) // Get the indexing expression - val idx = suf match { + val idx = suf match case Aligned(fac, idx) => - if newBank > fac then { + if newBank > fac then throw InvalidAlignFactor( suf.pos, s"Invalid align factor. Banking factor $newBank is bigger than alignment factor $fac." ) - } else if fac % newBank != 0 then { + else if fac % newBank != 0 then throw InvalidAlignFactor( suf.pos, s"Invalid align factor. Banking factor $newBank not a factor of the alignment factor $fac." ) - } else { + else idx - } case Rotation(idx) => idx - } val (typ, nEnv) = checkE(idx) - typ.matchOrError(idx.pos, "view", "integer type") { + typ.matchOrError(idx.pos, "view", "integer type"): case _: IntType => () // IntTypes are valid - } (nEnv, (prefix.getOrElse(len) -> newBank)) - } - private def checkPipeline(enabled: Boolean, loop: Command, body: Command) = { + private def checkPipeline(enabled: Boolean, loop: Command, body: Command) = // Only loops without sequencing may be pipelined. - body match { + body match case _: CSeq => - if enabled then { + if enabled then throw PipelineError(loop.pos) - } case _ => {} - } - } private def checkC(cmd: Command)(implicit env: Environment): Environment = - cmd match { + cmd match case CBlock(cmd) => env.withScope(checkC(cmd)(_)) case CPar(cmds) => cmds.foldLeft(env)({ case (env, c) => checkC(c)(env) }) case CSeq(cmds) => cmds.foldLeft(env)({ case (env, c) => checkC(c)(env) }) case CIf(cond, cons, alt) => { val (cTyp, e1) = checkE(cond)(env) - cTyp.matchOrError(cond.pos, "if condition", "bool") { + cTyp.matchOrError(cond.pos, "if condition", "bool"): case _: TBool => () - } e1.withScope(e => checkC(cons)(e)) e1.withScope(e => checkC(alt)(e)) // No binding updates need to be reflected. @@ -310,14 +280,13 @@ object TypeChecker { case CWhile(cond, pipeline, body) => { checkPipeline(pipeline, cmd, body) val (cTyp, e1) = checkE(cond)(env) - if cTyp != TBool() then { + if cTyp != TBool() then throw UnexpectedType( cond.pos, "while condition", TBool().toString, cTyp ) - } e1.withScope(e => checkC(body)(e)) } case CUpdate(lhs, rhs) => { @@ -339,7 +308,7 @@ object TypeChecker { env .resolveType(expTyp) - .matchOrError(l.pos, "Let bound array literal", "array type") { + .matchOrError(l.pos, "Let bound array literal", "array type"): case ta @ TArray(elemTyp, dims, _) => { assertOrThrow( dims.length == 1, @@ -364,12 +333,11 @@ object TypeChecker { id.typ = typ nEnv.add(id, ta) } - } } case l @ CLet(id, typ, Some(exp @ ERecLiteral(fs))) => { val expTyp = typ.getOrThrow(ExplicitTypeMissing(l.pos, "Record literal", id)) - env.resolveType(expTyp) match { + env.resolveType(expTyp) match case recTyp @ TRecType(name, expTypes) => { // Typecheck expressions in the literal and generate a new id to type map. val (env1, actualTypes) = fs.foldLeft((env, Map[Id, Type]()))({ @@ -405,7 +373,6 @@ object TypeChecker { env1.add(id, recTyp) } case t => throw UnexpectedType(exp.pos, "let", "record type", t) - } } case l @ CLet(id, typ, Some(exp)) => { // Check if the explicit type is bound in scope. Also, if the type is @@ -420,7 +387,7 @@ object TypeChecker { // Check the type of the expression val (t, e1) = checkE(exp) // Check if type of expression is a subtype of the annotated type. - rTyp match { + rTyp match case Some(t2) => { if isSubtype(t, t2) then e1.add(id, t2) @@ -428,15 +395,13 @@ object TypeChecker { throw UnexpectedSubtype(exp.pos, "let", t2, t) } case None => { - val typ = t match { + val typ = t match case TStaticInt(v) => TSizedInt(bitsNeeded(v), false) case _: TRational => TDouble() case t => t - } // Add inferred type to the AST Node. l.typ = Some(typ); e1.add(id, typ) } - } } case l @ CLet(id, typ, None) => { val fullTyp = typ @@ -461,12 +426,11 @@ object TypeChecker { } } case CView(id, arrId, vdims) => - env(arrId) match { + env(arrId) match case TArray(typ, adims, port) => { val (vlen, alen) = (vdims.length, adims.length) - if vlen != alen then { + if vlen != alen then throw IncorrectAccessDims(arrId, alen, vlen) - } val (env1, viewDims) = adims @@ -486,14 +450,12 @@ object TypeChecker { env1.add(id, viewTyp) } case t => throw UnexpectedType(cmd.pos, "view", "array", t) - } case CSplit(id, arrId, dims) => - env(arrId) match { + env(arrId) match case TArray(typ, adims, ports) => { val (vlen, alen) = (dims.length, adims.length) - if vlen != alen then { + if vlen != alen then throw IncorrectAccessDims(arrId, alen, vlen) - } /** * Create a type for the split view. For the following split view: @@ -525,17 +487,13 @@ object TypeChecker { env.add(id, viewTyp) } case t => throw UnexpectedType(cmd.pos, "split", "array", t) - } case CExpr(e) => checkE(e)._2 case CReturn(expr) => { val retType = env.getReturn.get val (t, e) = checkE(expr) - if isSubtype(t, retType) == false then { + if isSubtype(t, retType) == false then throw UnexpectedSubtype(expr.pos, "return", retType, t) - } e } case CEmpty => env case _: CDecorate => env - } -} diff --git a/src/main/scala/typechecker/TypeEnv.scala b/src/main/scala/typechecker/TypeEnv.scala index 8f77e79f..c361d375 100644 --- a/src/main/scala/typechecker/TypeEnv.scala +++ b/src/main/scala/typechecker/TypeEnv.scala @@ -11,7 +11,7 @@ import Errors._ import CompilerError._ import EnvHelpers._ -object TypeEnv { +object TypeEnv: val emptyEnv: Environment = Env() @@ -26,7 +26,7 @@ object TypeEnv { * last two associations. A scope is a logical grouping of assoication * corresponding to lexical scope in programs. */ - sealed trait Environment extends Tracker[Id, Type, Environment] { + sealed trait Environment extends Tracker[Id, Type, Environment]: /** * Type binding manipulation @@ -80,27 +80,24 @@ object TypeEnv { def getReturn: Option[Type] def withReturn(typ: Type): Environment - } private case class Env( typeMap: ScopedMap[Id, Type] = ScopedMap(), typeDefMap: Map[Id, Type] = Map(), retType: Option[Type] = None - ) extends Environment { + ) extends Environment: /** Type definitions */ - def addType(alias: Id, typ: Type) = typeDefMap.get(alias) match { + def addType(alias: Id, typ: Type) = typeDefMap.get(alias) match case Some(_) => throw AlreadyBound(alias) case None => this.copy(typeDefMap = typeDefMap + (alias -> typ)) - } def getType(alias: Id) = typeDefMap.get(alias).getOrThrow(Unbound(alias)) - def resolveType(typ: Type): Type = typ match { + def resolveType(typ: Type): Type = typ match case TAlias(n) => getType(n) case TFun(args, ret) => TFun(args.map(resolveType(_)), resolveType(ret)) case arr @ TArray(t, _, _) => arr.copy(typ = resolveType(t)) case t => t - } /** Type binding methods */ def get(id: Id) = typeMap.get(id) @@ -110,23 +107,17 @@ object TypeEnv { ) /** Helper functions for ScopeManager */ - def addScope = { + def addScope = Env(typeMap.addScope, typeDefMap, retType) - } - def endScope = { + def endScope = val scopes = for (_, tMap) <- typeMap.endScope yield Env(tMap, typeDefMap, retType) scopes.getOrThrow(Impossible("Removed topmost scope")) - } - def withScope(inScope: Environment => Environment) = { - inScope(this.addScope) match { + def withScope(inScope: Environment => Environment) = + inScope(this.addScope) match case env: Env => env.endScope - } - } def getReturn = retType def withReturn(typ: Type) = this.copy(retType = Some(typ)) - } -} diff --git a/src/test/scala/ParsingPositive.scala b/src/test/scala/ParsingPositive.scala index 6d441e5d..3dc45aca 100644 --- a/src/test/scala/ParsingPositive.scala +++ b/src/test/scala/ParsingPositive.scala @@ -3,23 +3,21 @@ package fuselang import TestUtils._ import org.scalatest.funsuite.AnyFunSuite -class ParsingTests extends AnyFunSuite { - test("numbers") { +class ParsingTests extends AnyFunSuite: + test("numbers"): parseAst("1;") parseAst("1.25;") parseAst("0.25;") parseAst("0x19;") parseAst("014;") parseAst("0x9e3779b9;") - } - test("atoms") { + test("atoms"): parseAst("true;") parseAst("false;") parseAst("true;") - } - test("comments") { + test("comments"): parseAst(""" /* this is a comment * on @@ -28,9 +26,8 @@ class ParsingTests extends AnyFunSuite { // this is comment x; """) - } - test("binops") { + test("binops"): parseAst("1 + 2;") parseAst("1 + 2;") parseAst("1 + 2.5;") @@ -41,42 +38,36 @@ class ParsingTests extends AnyFunSuite { parseAst("1 % 2;") parseAst("true || false;") parseAst("true && false;") - } - test("binop precedence order") { + test("binop precedence order"): parseAst("(1 + 2) * 3;") parseAst("1 + 2 * 3 >= 10 - 5 / 7;") parseAst("1 >> 2 | 3 ^ 4 & 5;") parseAst("1 >= 2 || 4 < 5;") - } - test("if") { + test("if"): parseAst("if (true) {}") parseAst("if (false) { 1 + 2; }") parseAst("if (false) { 1 + 2; }") - } - test("decl") { + test("decl"): parseAst("decl x: bit<64>;") parseAst("decl x: bool;") parseAst("decl x: bit<64>[10 bank 5];") - } - test("let") { + test("let"): parseAst("let x = 1; x + 2;") parseAst("let force = 1; x + 2;") parseAst("let x: bit<32>; x + 2;") - } - test("for loop") { + test("for loop"): parseAst(""" for (let i = 0..10) unroll 5 { x + 1; } """) - } - test("while loop") { + test("while loop"): parseAst(""" while (false) { let x = 1; @@ -86,9 +77,8 @@ class ParsingTests extends AnyFunSuite { } } """) - } - test("combiner syntax") { + test("combiner syntax"): parseAst(""" for (let i = 0..10) { } combine { @@ -102,21 +92,18 @@ class ParsingTests extends AnyFunSuite { let x = 1; } """) - } - test("refresh banks") { + test("refresh banks"): parseAst(""" x + 1; --- x + 2; """) - } - test("commands") { + test("commands"): parseAst("""{ x+1; }""") - } - test("functions") { + test("functions"): parseAst(""" def foo(a: bit<32>) = {} """) @@ -130,9 +117,8 @@ class ParsingTests extends AnyFunSuite { bar(1, 2, 3); } """) - } - test("records definitions") { + test("records definitions"): parseAst(""" record Point { x: bit<32>; @@ -145,21 +131,18 @@ class ParsingTests extends AnyFunSuite { y: bit<32> } """) - } - test("record literals") { + test("record literals"): parseAst(""" let res: point = { x = 10; y = 10 }; """) - } - test("array literals") { + test("array literals"): parseAst(""" let res: bit<32>[10] = { 1, 2, 3 }; """) - } - test("records access") { + test("records access"): parseAst(""" let k = p.x; """) @@ -169,9 +152,8 @@ class ParsingTests extends AnyFunSuite { parseAst(""" let k = rec.po.x; """) - } - test("imports") { + test("imports"): parseAst(""" import vivado("print.h") {} """) @@ -180,9 +162,8 @@ class ParsingTests extends AnyFunSuite { def foo(a: bit<32>); } """) - } - test("simple views") { + test("simple views"): parseAst(""" view v = a[_ :]; """) @@ -222,9 +203,8 @@ class ParsingTests extends AnyFunSuite { parseAst(""" view v = a[i + 1! : +3 bank 5]; """) - } - test("split views") { + test("split views"): parseAst(""" split b = a[by 10]; """) @@ -232,9 +212,8 @@ class ParsingTests extends AnyFunSuite { parseAst(""" split b = a[by 10][by 20]; """) - } - test("casting") { + test("casting"): parseAst(""" let x = (y as bit<32>); """) @@ -245,5 +224,3 @@ class ParsingTests extends AnyFunSuite { let x = (0x9e3779b9 as ubit<32>); let y = (023615674671 as ubit<32>); """) - } -} diff --git a/src/test/scala/TestUtils.scala b/src/test/scala/TestUtils.scala index 3d464efe..c5c302df 100644 --- a/src/test/scala/TestUtils.scala +++ b/src/test/scala/TestUtils.scala @@ -3,15 +3,13 @@ package fuselang import common._ import Compiler._ -object TestUtils { +object TestUtils: import scala.language.implicitConversions // Allow for env("x") style calls. - implicit def stringToId(s: String): Syntax.Id = { + implicit def stringToId(s: String): Syntax.Id = Syntax.Id(s) - } def parseAst(s: String) = Parser(s).parse() def typeCheck(s: String) = checkStringWithError(s) -} diff --git a/src/test/scala/TypeCheckerSpec.scala b/src/test/scala/TypeCheckerSpec.scala index 196754f1..1bd3c160 100644 --- a/src/test/scala/TypeCheckerSpec.scala +++ b/src/test/scala/TypeCheckerSpec.scala @@ -5,103 +5,73 @@ import TestUtils._ import Errors._ import org.scalatest.funspec.AnyFunSpec -class TypeCheckerSpec extends AnyFunSpec { +class TypeCheckerSpec extends AnyFunSpec: // Suppress logging. common.Logger.setLogLevel(scribe.Level.Error) - describe("Let bindings") { - describe("with explicit type and initializer") { - it("disallows using smaller sized int in assignment") { - assertThrows[UnexpectedSubtype] { + describe("Let bindings"): + describe("with explicit type and initializer"): + it("disallows using smaller sized int in assignment"): + assertThrows[UnexpectedSubtype]: typeCheck("decl a: bit<16>; let x: bit<8> = a;") - } - } - it("disallows using smaller range of fix in assignment") { - assertThrows[UnexpectedSubtype] { + it("disallows using smaller range of fix in assignment"): + assertThrows[UnexpectedSubtype]: typeCheck("decl a: fix<16,8>; let x: fix<8,4> = a;") - } - } - it("RHS type must be equal to LHS type") { - assertThrows[UnexpectedSubtype] { + it("RHS type must be equal to LHS type"): + assertThrows[UnexpectedSubtype]: typeCheck("let x: bit<16> = true;") - } - } - it("RHS type must be equal to LHS type for computable numbers") { - assertThrows[NoJoin] { + it("RHS type must be equal to LHS type for computable numbers"): + assertThrows[NoJoin]: typeCheck("let x: fix<2,2> = 2+2.1;") - } - } - it("should allow large unsigned literals") { + it("should allow large unsigned literals"): typeCheck("let x: ubit<32> = 0x9e3779b9;") typeCheck("let x: ubit<64> = 0xffffffffffffffff;") typeCheck("let x: bit<64> = 0x7fffffffffffffff;") typeCheck("let x: ubit<128> = 0xffffffffffffffffffffffffffffffff;") - } - } - describe("with explicit type and without initializer") { - it("bit type works") { + describe("with explicit type and without initializer"): + it("bit type works"): typeCheck("let x: bit<16>;") - } - it("bit type can be assigned to") { + it("bit type can be assigned to"): typeCheck("let x: bit<16>; x := 1;") - } - it("bit type requires correct type in assignment") { - assertThrows[UnexpectedSubtype] { + it("bit type requires correct type in assignment"): + assertThrows[UnexpectedSubtype]: typeCheck("let x: bit<16>; x := true;") - } - } - it("fix type works") { + it("fix type works"): typeCheck("let x: fix<2,1>; x := 1.1;") - } - it("fix type requires correct type in assignment") { - assertThrows[UnexpectedSubtype] { + it("fix type requires correct type in assignment"): + assertThrows[UnexpectedSubtype]: typeCheck("let x: fix<2,2>; x := true;") - } - } - it("fix type has range checking") { - assertThrows[UnexpectedSubtype] { + it("fix type has range checking"): + assertThrows[UnexpectedSubtype]: typeCheck("let x: fix<2,1>; x := 2;") - } - } - } - } - describe("Cannot reference undeclared var") { - it("in top level") { - assertThrows[Unbound] { + describe("Cannot reference undeclared var"): + it("in top level"): + assertThrows[Unbound]: typeCheck("x + 1;") - } - } - } - describe("Array access") { - it("with invalid accessor type") { - assertThrows[UnexpectedType] { + describe("Array access"): + it("with invalid accessor type"): + assertThrows[UnexpectedType]: typeCheck(""" decl a: bit<10>[10]; a[true]; """) - } - } - it("with too many dimensions") { - assertThrows[IncorrectAccessDims] { + it("with too many dimensions"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bit<10>[10]; a[1][1]; """) - } - } - it("with too few dimensions") { - assertThrows[IncorrectAccessDims] { + it("with too few dimensions"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bit<10>[10][10]; a[1]; """) - } - } - it("consumes bank without unroll") { - assertThrows[AlreadyConsumed] { + it("consumes bank without unroll"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<64>[10]; for (let i = 0..10) { @@ -109,10 +79,8 @@ class TypeCheckerSpec extends AnyFunSpec { } a[0]; """) - } - } - it("consumes all banks with unroll") { - assertThrows[AlreadyConsumed] { + it("consumes all banks with unroll"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<64>[10 bank 5]; for (let i = 0..10) unroll 5 { @@ -120,41 +88,27 @@ class TypeCheckerSpec extends AnyFunSpec { } a[0]; """) - } - } - } - describe("Variables scoping") { - describe("shadowing variables not allowed") { - it("in top level") { - assertThrows[AlreadyBound] { + describe("Variables scoping"): + describe("shadowing variables not allowed"): + it("in top level"): + assertThrows[AlreadyBound]: typeCheck("let x = 1; let x = 1;") - } - } - it("in if body") { - assertThrows[AlreadyBound] { + it("in if body"): + assertThrows[AlreadyBound]: typeCheck("let x = 1; if(true) { let x = 1; }") - } - } - } - it("in if") { - assertThrows[Unbound] { + it("in if"): + assertThrows[Unbound]: typeCheck("if (true) {let x = 1;} x + 2;") - } - } - it("in for") { - assertThrows[Unbound] { + it("in for"): + assertThrows[Unbound]: typeCheck("for (let i = 0..10){let x = 1;} x + 2;") - } - } - it("in while") { - assertThrows[Unbound] { + it("in while"): + assertThrows[Unbound]: typeCheck("while (true) {let x = 1;} x + 2;") - } - } - it("allows same name in different scopes") { + it("allows same name in different scopes"): typeCheck(""" for (let i = 0..1) { let x = 10; @@ -163,105 +117,71 @@ class TypeCheckerSpec extends AnyFunSpec { let x = 10; } """) - } - it("variables not available outside block scope") { - assertThrows[Unbound] { + it("variables not available outside block scope"): + assertThrows[Unbound]: typeCheck(""" { let x = 10; } x; """) - } - } - } - describe("Binary operations") { + describe("Binary operations"): - it("comparisons on rational returns a boolean") { + it("comparisons on rational returns a boolean"): typeCheck("if (2.5 < 23.5) { 1; }") - } - it("comparisons on float and rational a boolean") { + it("comparisons on float and rational a boolean"): typeCheck("let x: float = 1.0; x < 10.0;") - } - it("can add sized int to sized int") { + it("can add sized int to sized int"): typeCheck("decl x: bit<64>; let y = 1; x + y;") - } - it("cannot add fixed point to int") { - assertThrows[NoJoin] { + it("cannot add fixed point to int"): + assertThrows[NoJoin]: typeCheck("decl x: fix<64,32>; let y = 1 + x;") - } - } - it("can add fixed point to rational") { + it("can add fixed point to rational"): typeCheck("decl x: fix<64,32>; let y = 1.5 + x;") - } - it("can add float with double") { + it("can add float with double"): typeCheck("decl f: float; let y = 1.5; f + y;") - } - it("cannot add int and rational") { - assertThrows[NoJoin] { + it("cannot add int and rational"): + assertThrows[NoJoin]: typeCheck("1 + 2.5;") - } - } - it("cannot add float dec and int") { - assertThrows[NoJoin] { + it("cannot add float dec and int"): + assertThrows[NoJoin]: typeCheck("decl f: float; f + 1;") - } - } - it("cannot add fix dec and int") { - assertThrows[NoJoin] { + it("cannot add fix dec and int"): + assertThrows[NoJoin]: typeCheck("decl f: fix<32,16>; f + 1;") - } - } - it("comparison not defined for memories") { - assertThrows[UnexpectedType] { + it("comparison not defined for memories"): + assertThrows[UnexpectedType]: typeCheck("decl a: bit<10>[10]; decl b: bit<10>[10]; a == b;") - } - } - it("cannot shift rational") { - assertThrows[BinopError] { + it("cannot shift rational"): + assertThrows[BinopError]: typeCheck("10.5 << 1;") - } - } - it("logical and defined on booleans") { - assertThrows[BinopError] { + it("logical and defined on booleans"): + assertThrows[BinopError]: typeCheck("1 || 2;") - } - } - it("adding static int does NOT perform type level computation") { + it("adding static int does NOT perform type level computation"): typeCheck("let x = 1; let y = 2; let z = x + y;") - } - it("result of bit type addition upcast to subtype join") { + it("result of bit type addition upcast to subtype join"): typeCheck("decl x: bit<32>; decl y: bit<16>; let z = x + y;") - } - it("result of fix type addition upcast to subtype join") { + it("result of fix type addition upcast to subtype join"): typeCheck("decl x: fix<32,16>; decl y: fix<16,8>; let z = x + y;") - } - it("result of peephole optimization on large unsigned int should be valid") { + it("result of peephole optimization on large unsigned int should be valid"): typeCheck("let x: ubit<32> = 0x9e3779b9; let y = x + 0;") - } - } - describe("Reassign") { - it("cannot reassign to non-subtype") { - assertThrows[UnexpectedSubtype] { + describe("Reassign"): + it("cannot reassign to non-subtype"): + assertThrows[UnexpectedSubtype]: typeCheck("let x = 1; x := 2.5;") - } - } - it("can reassign decl") { + it("can reassign decl"): typeCheck("decl x: bit<32>; decl y: bit<16>; x := y;") - } - } - describe("Conditionals (if)") { - it("condition must be a boolean") { - assertThrows[UnexpectedType] { + describe("Conditionals (if)"): + it("condition must be a boolean"): + assertThrows[UnexpectedType]: typeCheck("if (1) { let x = 10; }") - } - } - it("cannot consume same banks in sequenced statements") { - assertThrows[AlreadyConsumed] { + it("cannot consume same banks in sequenced statements"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<10>[2 bank 2]; if (true) { @@ -269,10 +189,8 @@ class TypeCheckerSpec extends AnyFunSpec { } a[0]; """) - } - } - it("cannot consume same banks in sequenced statements from else branch") { - assertThrows[AlreadyConsumed] { + it("cannot consume same banks in sequenced statements from else branch"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<10>[2 bank 2]; if (true) { @@ -282,9 +200,7 @@ class TypeCheckerSpec extends AnyFunSpec { } a[1]; """) - } - } - it("can consume same banks in branches") { + it("can consume same banks in branches"): typeCheck(""" decl a: bit<10>[10]; if (true) { @@ -293,8 +209,7 @@ class TypeCheckerSpec extends AnyFunSpec { a[0]; } """) - } - it("can consume bank in sequenced statement not used in either branch") { + it("can consume bank in sequenced statement not used in either branch"): typeCheck(""" decl a: bit<10>[2 bank 2]; if (true) { @@ -304,8 +219,7 @@ class TypeCheckerSpec extends AnyFunSpec { } a[1]; """) - } - it("can create different capabilities in branches") { + it("can create different capabilities in branches"): // See discussion in: https://github.com/cucapra/dahlia/pull/81 typeCheck(""" decl a: bit<10>[2 bank 2]; @@ -315,34 +229,27 @@ class TypeCheckerSpec extends AnyFunSpec { let x = a[0]; } """) - } - } - describe("while loops") { - it("work") { + describe("while loops"): + it("work"): typeCheck(""" while (true) { let x = 1; } """) - } - } - describe("Ranges") { - it("Unrolling constant must be factor of length") { - assertThrows[UnrollRangeError] { + describe("Ranges"): + it("Unrolling constant must be factor of length"): + assertThrows[UnrollRangeError]: typeCheck(""" for (let i = 0..10) unroll 3 { let x = 1; } """) - } - } - } - describe("Reductions") { + describe("Reductions"): // This is equivalent to the example above - it("fully unrolled loop and fully banked array") { + it("fully unrolled loop and fully banked array"): typeCheck(""" decl a: bit<64>[10 bank 10]; let sum: bit<64> = 0; @@ -352,11 +259,9 @@ class TypeCheckerSpec extends AnyFunSpec { sum += v; } """) - } - } - describe("Combine") { - it("without unrolling") { + describe("Combine"): + it("without unrolling"): typeCheck(""" decl a: bit<64>[10]; let sum: bit<64> = 0; @@ -366,9 +271,8 @@ class TypeCheckerSpec extends AnyFunSpec { sum += x; } """) - } - it("with unrolling") { + it("with unrolling"): typeCheck(""" decl a: bit<64>[10 bank 5]; let sum: bit<64> = 0; @@ -378,36 +282,30 @@ class TypeCheckerSpec extends AnyFunSpec { sum += x; } """) - } - } - describe("Multi-ported Memories") { - it("without an annotaion default to single port") { - assertThrows[AlreadyConsumed] { + describe("Multi-ported Memories"): + it("without an annotaion default to single port"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[10]; a[0]; a[1]; """) - } - } - it("multiple writes consume resources") { + it("multiple writes consume resources"): typeCheck(""" decl a: bit<32>{2}[10]; a[0] := 0; a[1] := 1; """) - } - it("multiple reads from different locations consume resources") { + it("multiple reads from different locations consume resources"): typeCheck(""" decl a: bit<32>{2}[10]; a[0]; a[1]; """) - } - it("multiple reads from same locations don't consume resources") { + it("multiple reads from same locations don't consume resources"): typeCheck(""" decl a: bit<32>{2}[10]; a[0]; @@ -415,27 +313,23 @@ class TypeCheckerSpec extends AnyFunSpec { a[0]; a[1]; """) - } - it("allow reads and writes to the same location") { + it("allow reads and writes to the same location"): typeCheck(""" decl a: bit<32>{2}[10]; a[0] := 1; let x = a[0]; """) - } - it("disallow writes to the same location") { - assertThrows[AlreadyWrite] { + it("disallow writes to the same location"): + assertThrows[AlreadyWrite]: typeCheck(""" decl a: bit<32>{2}[10]; a[0] := 1; a[0] := 2; """) - } - } - it("each bank gets multiple ports") { + it("each bank gets multiple ports"): typeCheck(""" decl a: bit<32>{2}[10 bank 2]; // Bank 0 @@ -445,17 +339,15 @@ class TypeCheckerSpec extends AnyFunSpec { a[1] := 1; a[3] := 3; """) - } - it("index types only consume one port") { + it("index types only consume one port"): typeCheck(""" decl a: bit<32>{2}[10]; for (let i = 0..10) { a[i] := 1; } a[0] := 2; """) - } - it("regenerate after ---") { + it("regenerate after ---"): typeCheck(""" decl a: bit<32>{2}[10]; a[1] := 1; @@ -464,44 +356,36 @@ class TypeCheckerSpec extends AnyFunSpec { a[1] := 3; a[0] := 4; """) - } - it("allow more unrolling than banks") { + it("allow more unrolling than banks"): typeCheck(""" decl a: bit<32>{2}[8 bank 2]; for (let i = 0..8) unroll 4 { a[i] := 1; } """) - } - } - describe("Banking factor not equal to unrolling factor") { - it("bank not equal") { - assertThrows[BankUnrollInvalid] { + describe("Banking factor not equal to unrolling factor"): + it("bank not equal"): + assertThrows[BankUnrollInvalid]: typeCheck(""" decl a: bit<32>[10 bank 5]; for (let i = 0..10) unroll 2 { let x = a[i]; } """) - } - } - it("bank factor of unroll") { - assertThrows[BankUnrollInvalid] { + it("bank factor of unroll"): + assertThrows[BankUnrollInvalid]: typeCheck(""" decl a: bit<32>[8 bank 4]; for (let i = 0..8) unroll 2 { let x = a[i]; } """) - } - } - } - describe("Sequential composition") { - it("total resources consumed is union of all resources consumed") { - assertThrows[AlreadyConsumed] { + describe("Sequential composition"): + it("total resources consumed is union of all resources consumed"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[8]; decl b: bit<32>[8]; @@ -515,11 +399,9 @@ class TypeCheckerSpec extends AnyFunSpec { } b[0] := 1; """) - } - } - it("works with scoped blocks") { - assertThrows[AlreadyConsumed] { + it("works with scoped blocks"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[8]; { @@ -529,18 +411,15 @@ class TypeCheckerSpec extends AnyFunSpec { } a[0] := 1; """) - } - } - it("allow declarations in first statement to be used in second") { + it("allow declarations in first statement to be used in second"): typeCheck(""" let bucket_idx = 10; --- bucket_idx := (20 as bit<4>); """) - } - it("check for declarations used in both branches") { + it("check for declarations used in both branches"): typeCheck(""" let test_var:bit<32> = 10; { @@ -549,11 +428,9 @@ class TypeCheckerSpec extends AnyFunSpec { test_var := 30; } """) - } - } - describe("Parallel composition") { - it("allows same banks to used") { + describe("Parallel composition"): + it("allows same banks to used"): typeCheck(""" decl a: bit<64>[10 bank 5]; for (let i = 0..10) unroll 5 { @@ -562,9 +439,8 @@ class TypeCheckerSpec extends AnyFunSpec { let y = a[i]; } """) - } - it("allows same banks to be used with reassignment") { + it("allows same banks to be used with reassignment"): typeCheck(""" decl a: bit<10>[20 bank 10]; for (let i = 0..20) unroll 10 { @@ -573,9 +449,8 @@ class TypeCheckerSpec extends AnyFunSpec { a[i] := 2; } """) - } - it("reuses banks of multidimensional array") { + it("reuses banks of multidimensional array"): typeCheck(""" decl a: bit<10>[20 bank 10][10 bank 5]; for (let i = 0..20) unroll 10 { @@ -586,23 +461,19 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - describe("Capabilities in simple contexts") { - it("read capabilities end at scope boundaries") { - assertThrows[AlreadyConsumed] { + describe("Capabilities in simple contexts"): + it("read capabilities end at scope boundaries"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[6 bank 6]; for(let i = 0..6) { a[0]; } for(let i = 0..6) { a[0]; } """) - } - } - it("write capabilities can only be used once") { - assertThrows[AlreadyWrite] { + it("write capabilities can only be used once"): + assertThrows[AlreadyWrite]: typeCheck(""" decl a: bit<32>[6 bank 6]; @@ -611,10 +482,8 @@ class TypeCheckerSpec extends AnyFunSpec { a[0] := 1; } """) - } - } - it("read capabilities can be used multiple times") { + it("read capabilities can be used multiple times"): typeCheck(""" decl a: bit<32>[6 bank 6]; @@ -623,10 +492,9 @@ class TypeCheckerSpec extends AnyFunSpec { let y = a[0]; } """) - } - it("read cannot occur after write") { - assertThrows[AlreadyConsumed] { + it("read cannot occur after write"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[6 bank 6]; for (let i = 0..6) { @@ -634,11 +502,9 @@ class TypeCheckerSpec extends AnyFunSpec { let x = a[0] + 1; } """) - } - } - it("write cannot occur after read") { - assertThrows[AlreadyConsumed] { + it("write cannot occur after read"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[6 bank 6]; for (let i = 0..6) { @@ -646,10 +512,8 @@ class TypeCheckerSpec extends AnyFunSpec { a[0] := 1; } """) - } - } - it("read after write in same context with seq composition") { + it("read after write in same context with seq composition"): typeCheck(""" decl a: bit<32>[6 bank 6]; for (let i = 0..6) { @@ -658,9 +522,8 @@ class TypeCheckerSpec extends AnyFunSpec { a[0] := 1; } """) - } - it("write after read in same context with seq composition") { + it("write after read in same context with seq composition"): typeCheck(""" decl a: bit<32>[6 bank 6]; for (let i = 0..6) { @@ -669,9 +532,8 @@ class TypeCheckerSpec extends AnyFunSpec { let x = a[0] + 1; } """) - } - it("write after write in same context with seq composition") { + it("write after write in same context with seq composition"): typeCheck(""" decl a: bit<32>[6 bank 6]; for (let i = 0..6) { @@ -680,22 +542,18 @@ class TypeCheckerSpec extends AnyFunSpec { a[0] := 2; } """) - } - } - describe("Capabilities in unrolled context") { - it("write in one unrolled loop and a constant access") { - assertThrows[InsufficientResourcesInUnrollContext] { + describe("Capabilities in unrolled context"): + it("write in one unrolled loop and a constant access"): + assertThrows[InsufficientResourcesInUnrollContext]: typeCheck(""" decl a: bit<32>[10]; for (let i = 0..10) unroll 5 { a[0] := 1; } """) - } - } - it("write in two unrolled loops and incorrect idx accessor") { - assertThrows[InsufficientResourcesInUnrollContext] { + it("write in two unrolled loops and incorrect idx accessor"): + assertThrows[InsufficientResourcesInUnrollContext]: typeCheck(""" decl a: bit<32>[10][10 bank 5]; for (let i = 0..10) { @@ -704,10 +562,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("write with three loops, 2 unrolled") { - assertThrows[InsufficientResourcesInUnrollContext] { + it("write with three loops, 2 unrolled"): + assertThrows[InsufficientResourcesInUnrollContext]: typeCheck(""" decl a: bit<32>[10][10 bank 5]; for (let k = 0..10) { @@ -718,9 +574,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("read with one constant accessor and one unrolled iterator") { + it("read with one constant accessor and one unrolled iterator"): typeCheck(""" decl a: bit<32>[10 bank 5][10 bank 5]; for (let i = 0..10) { @@ -731,16 +585,14 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("read in one unrolled loop and a constant access") { + it("read in one unrolled loop and a constant access"): typeCheck(""" decl a: bit<32>[10]; for (let i = 0..10) unroll 5 { let x = a[0]; } """) - } - it("read in two unrolled loops and incorrect idx accessor") { + it("read in two unrolled loops and incorrect idx accessor"): typeCheck(""" decl a: bit<32>[10][10 bank 5]; for (let i = 0..10) { @@ -749,8 +601,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("read with three loops, 2 unrolled") { + it("read with three loops, 2 unrolled"): typeCheck(""" decl a: bit<32>[10][10 bank 5]; for (let k = 0..10) { @@ -761,8 +612,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("read with two dimensional unrolling") { + it("read with two dimensional unrolling"): typeCheck(""" decl a: bit<32>[10 bank 5][10 bank 5]; for (let i = 0..10) unroll 5 { @@ -773,11 +623,9 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - describe("Loop depedency in unrolled context") { - it("defined use inside multiple unrolled contexts") { + describe("Loop depedency in unrolled context"): + it("defined use inside multiple unrolled contexts"): typeCheck(""" decl q: ubit<32>[8 bank 2]; @@ -789,9 +637,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("use after define in an unrolled loop is not allowed") { - assertThrows[LoopDepSequential] { + it("use after define in an unrolled loop is not allowed"): + assertThrows[LoopDepSequential]: typeCheck(""" let a: bit<32>{2}[10]; for (let i = 0..10) unroll 2 { @@ -800,10 +647,8 @@ class TypeCheckerSpec extends AnyFunSpec { a[i] := 1; } """) - } - } - it("use after define in a nested unrolled loop is not allowed") { - assertThrows[LoopDepSequential] { + it("use after define in a nested unrolled loop is not allowed"): + assertThrows[LoopDepSequential]: typeCheck(""" decl a: bit<32>[10 bank 5][10 bank 5]; for (let i = 0..10) unroll 5 { @@ -814,9 +659,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("skip loop dependency check when indexing is the same") { + it("skip loop dependency check when indexing is the same"): typeCheck(""" decl a: bit<32>[10 bank 2]; for (let i = 0..10) unroll 2 { @@ -825,10 +668,9 @@ class TypeCheckerSpec extends AnyFunSpec { a[i] := 1; } """) - } it( "skip loop dependency check when indexing is the same in a nested unrolled loop" - ) { + ): typeCheck(""" decl a: bit<32>{2}[10][10 bank 2]; for (let i = 0..6) unroll 2 { @@ -839,9 +681,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("no skip on loop dependency check when encountered a view") { - assertThrows[LoopDepSequential] { + it("no skip on loop dependency check when encountered a view"): + assertThrows[LoopDepSequential]: typeCheck(""" decl a: bit<32>{2}[10 bank 2]; view a_v = a[1!: bank 2]; @@ -851,9 +692,7 @@ class TypeCheckerSpec extends AnyFunSpec { a_v[i] := 1; } """) - } - } - it("condition in if is always a use") { + it("condition in if is always a use"): typeCheck(""" decl a: bit<32>[10 bank 5]; for (let i = 0..10) unroll 5 { @@ -864,9 +703,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("merging condition creates don't know state") { - assertThrows[LoopDepSequential] { + it("merging condition creates don't know state"): + assertThrows[LoopDepSequential]: typeCheck(""" decl a: bit<32>; decl b: bit<32>[10 bank 2]; @@ -878,10 +716,8 @@ class TypeCheckerSpec extends AnyFunSpec { a; } """) - } - } - it("merging use and define is not allowed") { - assertThrows[LoopDepMerge] { + it("merging use and define is not allowed"): + assertThrows[LoopDepMerge]: typeCheck(""" decl a: bit<32>[10 bank 5]; decl b: bit<32>[10 bank 5]; @@ -893,9 +729,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("define after use in an unrolled loop works") { + it("define after use in an unrolled loop works"): typeCheck(""" decl a: bit<32>[10 bank 5]; for (let i = 0..10) unroll 5 { @@ -904,8 +738,7 @@ class TypeCheckerSpec extends AnyFunSpec { let x = a[i]; } """) - } - it("define after use in nested unrolled loop works") { + it("define after use in nested unrolled loop works"): typeCheck(""" decl a: bit<32>[10 bank 5][10 bank 5]; for (let i = 0..10) unroll 5 { @@ -916,8 +749,7 @@ class TypeCheckerSpec extends AnyFunSpec { a[i][0]; } """) - } - it("don't know state can be transferred to defined state again") { + it("don't know state can be transferred to defined state again"): typeCheck(""" decl a: bit<32>[10 bank 5]; decl b: bit<32>[10 bank 5]; @@ -931,8 +763,7 @@ class TypeCheckerSpec extends AnyFunSpec { a[i]; } """) - } - it("don't know state is fine as long as the loop execution ends") { + it("don't know state is fine as long as the loop execution ends"): typeCheck(""" let b1min = 10; for(let i = 0..2) unroll 2 { @@ -944,43 +775,34 @@ class TypeCheckerSpec extends AnyFunSpec { } b1min :=10; """) - } - } - describe("Functions") { - it("cannot have same name for multiple params") { - assertThrows[AlreadyBound] { + describe("Functions"): + it("cannot have same name for multiple params"): + assertThrows[AlreadyBound]: typeCheck(""" def foo(a: bool, a: bit<10>) = {} """) - } - } - it("cannot be used before defintion") { - assertThrows[Unbound] { + it("cannot be used before defintion"): + assertThrows[Unbound]: typeCheck(""" def bar(a: bool) = { foo(a); } def foo(a: bool) = { foo(a); } """) - } - } - it("do not allow recursion") { - assertThrows[Unbound] { + it("do not allow recursion"): + assertThrows[Unbound]: typeCheck(""" def bar(a: bool) = { bar(a); } """) - } - } - it("allow return values") { + it("allow return values"): typeCheck(""" def foo(): bit<10> = { return 5; } let res: bit<10> = foo(); """) - } - it("allow record as return value") { + it("allow record as return value"): typeCheck(""" record point { x: bit<32> } def f(p: point): point = { @@ -988,39 +810,31 @@ class TypeCheckerSpec extends AnyFunSpec { return np; } """) - } - it("disallow ill-typed return values") { - assertThrows[UnexpectedSubtype] { + it("disallow ill-typed return values"): + assertThrows[UnexpectedSubtype]: typeCheck(""" def foo(): bool = { return 5; } """) - } - } - } - describe("Function applications") { - it("require the correct types") { - assertThrows[UnexpectedSubtype] { + describe("Function applications"): + it("require the correct types"): + assertThrows[UnexpectedSubtype]: typeCheck(""" def bar(a: bool) = { } bar(1); """) - } - } - it("should not allow functions with array arguments inside unrolled loops") { - assertThrows[FuncInUnroll] { + it("should not allow functions with array arguments inside unrolled loops"): + assertThrows[FuncInUnroll]: typeCheck(""" def bar(a: bool[4]) = { } for (let i = 0..10) unroll 5 { bar(tre); } """) - } - } - it("should allow functions with scalar args in unrolled loops") { + it("should allow functions with scalar args in unrolled loops"): typeCheck( """ def bar(a: bool) = { } @@ -1029,63 +843,51 @@ class TypeCheckerSpec extends AnyFunSpec { bar(tre); } """) - } - it("completely consume array parameters") { - assertThrows[AlreadyConsumed] { + it("completely consume array parameters"): + assertThrows[AlreadyConsumed]: typeCheck(""" def bar(a: bit<10>[10 bank 5]) = { } decl x: bit<10>[10 bank 5]; bar(x); x[1]; """) - } - } - it("consume array accesses") { - assertThrows[AlreadyConsumed] { + it("consume array accesses"): + assertThrows[AlreadyConsumed]: typeCheck(""" def bar(a: bit<10>) = { } decl x: bit<10>[10]; bar(x[0]); x[0] := 1; """) - } - } - it("Require exact match for array dimensions and banks") { - assertThrows[UnexpectedSubtype] { + it("Require exact match for array dimensions and banks"): + assertThrows[UnexpectedSubtype]: typeCheck(""" def foo(a: bit<32>[10 bank 5]) = { } decl b: bit<32>[5 bank 5]; foo(b); """) - } - } - it("Require argument and parameter lengths to match") { - assertThrows[ArgLengthMismatch] { + it("Require argument and parameter lengths to match"): + assertThrows[ArgLengthMismatch]: typeCheck(""" def foo(a: bit<32>, b: bit<32>) = { } foo(1); """) - } - } - } - describe("Simple views") { - it("must have dimensions equal to array") { - assertThrows[IncorrectAccessDims] { + describe("Simple views"): + it("must have dimensions equal to array"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bit<10>[10 bank 5][10 bank 5]; view v = a[5 * i :]; """) - } - } - it("cannot be nested inside unroll context") { - assertThrows[ViewInsideUnroll] { + it("cannot be nested inside unroll context"): + assertThrows[ViewInsideUnroll]: typeCheck(""" decl a: bit<10>[16 bank 8]; for (let i = 0..4) unroll 4 { @@ -1094,45 +896,36 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("can use loop iterator if not unrolled") { + it("can use loop iterator if not unrolled"): typeCheck(""" decl a: bit<10>[16 bank 8]; for (let i = 0..4) { view v = a[8 * i :]; } """) - } - it("has the same type as the underlying array") { - assertThrows[NoJoin] { + it("has the same type as the underlying array"): + assertThrows[NoJoin]: typeCheck(""" decl a: bool[10 bank 5]; view v = a[0!:]; v[3] + 1; """) - } - } - it("has the same dimensions as underlying array") { - assertThrows[IncorrectAccessDims] { + it("has the same dimensions as underlying array"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bool[10 bank 5][10 bank 5]; view v = a[0!:][0!:]; v[1]; """) - } - } - it("cannot be used in the same time step as underlying array") { - assertThrows[AlreadyConsumed] { + it("cannot be used in the same time step as underlying array"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bool[10 bank 5][10 bank 5]; view v = a[0!:][0!:]; a[0][0]; v[0][0]; """) - } - } - it("throw an error if require more resources than the underlying array.") { - assertThrows[AlreadyConsumed] { + it("throw an error if require more resources than the underlying array."): + assertThrows[AlreadyConsumed]: typeCheck(""" decl m1: double[64][64]; @@ -1141,10 +934,8 @@ class TypeCheckerSpec extends AnyFunSpec { let temp_x = m1_v[i][1]; } """) - } - } - it("with shrink require more resources than consume list") { - assertThrows[AlreadyConsumed] { + it("with shrink require more resources than consume list"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl m1: double{2}[64 bank 2]; @@ -1153,10 +944,8 @@ class TypeCheckerSpec extends AnyFunSpec { let temp_x = m1_v[i]; } """) - } - } - it("with multiple dimensions share ports") { - assertThrows[AlreadyConsumed] { + it("with multiple dimensions share ports"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl m1: double{2}[64][64]; @@ -1166,9 +955,7 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - } - it("can cross sequential composition boundary") { + it("can cross sequential composition boundary"): typeCheck(""" let A: float[10 bank 2]; view m1 = A[_: bank 2]; @@ -1176,153 +963,122 @@ class TypeCheckerSpec extends AnyFunSpec { --- m1[1]; """) - } - } - describe("Split views") { - it("requires the same dimesions as the underlying array") { - assertThrows[IncorrectAccessDims] { + describe("Split views"): + it("requires the same dimesions as the underlying array"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bit<32>[10 bank 5][2]; split v = a[by 5]; """) - } - } - it("cannot be created inside an unrolled loop") { - assertThrows[ViewInsideUnroll] { + it("cannot be created inside an unrolled loop"): + assertThrows[ViewInsideUnroll]: typeCheck(""" decl a: bit<32>[10 bank 5]; for (let i = 0 .. 8) unroll 2 { split v = a[by 5]; } """) - } - } - it("requires split factor to divide banking factor") { - assertThrows[InvalidSplitFactor] { + it("requires split factor to divide banking factor"): + assertThrows[InvalidSplitFactor]: typeCheck(""" decl a: bit<32>[10]; split v = a[by 5]; """) - } - } - it("add another dimension for each non-zero split") { - assertThrows[IncorrectAccessDims] { + it("add another dimension for each non-zero split"): + assertThrows[IncorrectAccessDims]: typeCheck(""" decl a: bit<32>[10 bank 5]; split v = a[by 5]; v[0]; """) - } - } - it("can be created inside an normal loop") { + it("can be created inside an normal loop"): typeCheck(""" decl a: bit<32>[10 bank 5]; split v = a[by 5]; """) - } - } - describe("Simple aligned views") { - it("width must be factor of banking factor") { - assertThrows[InvalidAlignFactor] { + describe("Simple aligned views"): + it("width must be factor of banking factor"): + assertThrows[InvalidAlignFactor]: typeCheck(""" decl x: bit<32>; decl a: bit<10>[10 bank 5]; view v = a[3 * x :]; """) - } - } - it("width must be a multiple of the banking factor") { - assertThrows[InvalidAlignFactor] { + it("width must be a multiple of the banking factor"): + assertThrows[InvalidAlignFactor]: typeCheck(""" decl x: bit<32>; decl a: bit<10>[10 bank 10]; view v = a[5 * x :]; """) - } - } - it("width must be statically known") { - assertThrows[ParserError] { + it("width must be statically known"): + assertThrows[ParserError]: typeCheck(""" decl x: bit<32>; decl a: bit<10>[10 bank 5]; view v = a[x * x :]; """) - } - } - it("suffix factor is a factor of the new banking") { + it("suffix factor is a factor of the new banking"): typeCheck(""" decl x: bit<32>; decl a: bit<10>[16 bank 8]; view v = a[6 * x : bank 2]; """) - } - } - describe("Simple rotation views") { - it("can describe arbitrary, unrestricted rotations") { + describe("Simple rotation views"): + it("can describe arbitrary, unrestricted rotations"): typeCheck(""" decl a: bit<10>[10 bank 5]; decl i: bit<32>; view v = a[i * i ! :]; """) - } - } - describe("Gadget checking") { - it("simple views aliasing same underlying array cannot be used together") { - assertThrows[AlreadyConsumed] { + describe("Gadget checking"): + it("simple views aliasing same underlying array cannot be used together"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[10 bank 2][10 bank 5]; view v1 = a[0!:][0!:]; view v2 = a[1!:][2!:]; v1[0][0]; v2[0][0]; """) - } - } - it("views created from other views cannot be used together") { - assertThrows[AlreadyConsumed] { + it("views created from other views cannot be used together"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[10 bank 2][10 bank 5]; view v1 = a[0!:][0!:]; view v2 = v1[1!:][2!:]; a[0][0]; v2[0][0]; """) - } - } it( "split views created from the same underlying arrays cannot be used together" - ) { - assertThrows[AlreadyConsumed] { + ): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: bit<32>[10 bank 2][10 bank 2]; view v1 = a[0!:][0!:]; split v2 = a[by 2][by 2]; v2[0][1][0][2]; v1[0][0]; """) - } - } - it("arrays have fine grained bank tracking") { + it("arrays have fine grained bank tracking"): typeCheck(""" decl a: float[2 bank 2][10 bank 2]; a[0][0]; a[1][1]; """) - } - it("simple views dont have fine grained bank tracking") { - assertThrows[AlreadyConsumed] { + it("simple views dont have fine grained bank tracking"): + assertThrows[AlreadyConsumed]: typeCheck(""" decl a: float[10 bank 2][10]; view v1 = a[2 * 1: +2][0!:]; v1[0][0]; v1[1][0]; """) - } - } - it("views defined in sequencing are available") { + it("views defined in sequencing are available"): typeCheck(""" let A: float{2}[8 bank 8]; @@ -1335,8 +1091,7 @@ class TypeCheckerSpec extends AnyFunSpec { A_2[0]; } """) - } - it("physical resources defined in sequencing are available") { + it("physical resources defined in sequencing are available"): typeCheck(""" let A: float{2}[8]; --- @@ -1344,19 +1099,16 @@ class TypeCheckerSpec extends AnyFunSpec { --- let x = A[1] + B[1]; """) - } - } - describe("Loop iterators") { - it("can be used for arithmetic") { + describe("Loop iterators"): + it("can be used for arithmetic"): typeCheck(""" for (let i = 0..10) { let x = i * 2; } """) - } - it("can be used for comparisons") { + it("can be used for comparisons"): typeCheck(""" let temp = 0; for (let i = 0..10) { @@ -1365,9 +1117,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("can be passed to functions with int types") { + it("can be passed to functions with int types"): typeCheck(""" def test(a: bit<32>) = { let test2 = a; @@ -1377,38 +1128,33 @@ class TypeCheckerSpec extends AnyFunSpec { test(i); } """) - } - it("can be used with bit shifts") { + it("can be used with bit shifts"): typeCheck(""" for (let i = 0..10) { let x = i | 2; } """) - } - it("with arithmetic, can be used for access") { + it("with arithmetic, can be used for access"): typeCheck(""" decl a: bit<10>[10]; for (let i = 0..10) { a[i * 2]; } """) - } - } - describe("Loop pipelining") { - it("allowed on simple for loop") { + describe("Loop pipelining"): + it("allowed on simple for loop"): typeCheck(""" for (let i = 0..4) pipeline { let a = 1 + 2; let b = 3 + 4; } """) - } - it("disallowed on sequenced for loop") { - assertThrows[PipelineError] { + it("disallowed on sequenced for loop"): + assertThrows[PipelineError]: typeCheck(""" for (let i = 0..4) pipeline { let a = 1 + 2; @@ -1416,10 +1162,8 @@ class TypeCheckerSpec extends AnyFunSpec { let b = 3 + 4; } """) - } - } - it("allowed on simple while loop") { + it("allowed on simple while loop"): typeCheck(""" let x = 10; while (x < 100) pipeline { @@ -1427,10 +1171,9 @@ class TypeCheckerSpec extends AnyFunSpec { let b = 3 + 4; } """) - } - it("disallowed on sequenced while loop") { - assertThrows[PipelineError] { + it("disallowed on sequenced while loop"): + assertThrows[PipelineError]: typeCheck(""" let x = 10; while (x < 100) pipeline { @@ -1439,20 +1182,16 @@ class TypeCheckerSpec extends AnyFunSpec { let b = 3 + 4; } """) - } - } - } - describe("Records") { - it("type can be defined") { + describe("Records"): + it("type can be defined"): typeCheck(""" record point { x: bit<32>; y: bit<32> } """) - } - it("can be used in a declaration") { + it("can be used in a declaration"): typeCheck(""" record point { x: bit<32>; @@ -1460,8 +1199,7 @@ class TypeCheckerSpec extends AnyFunSpec { } decl k: point; """) - } - it("can be used in a declaration nested declaration") { + it("can be used in a declaration nested declaration"): typeCheck(""" record point { x: bit<32>; @@ -1471,27 +1209,22 @@ class TypeCheckerSpec extends AnyFunSpec { k: point } """) - } - it("cannot use undefined type aliases") { - assertThrows[Unbound] { + it("cannot use undefined type aliases"): + assertThrows[Unbound]: typeCheck(""" record bars { k: point } """) - } - } - it("cannot contain arrays") { - assertThrows[ArrayInRecord] { + it("cannot contain arrays"): + assertThrows[ArrayInRecord]: typeCheck(""" record bars { k: bit<10>[10] } """) - } - } - it("cannot rebind type alias") { - assertThrows[AlreadyBound] { + it("cannot rebind type alias"): + assertThrows[AlreadyBound]: typeCheck(""" record bars { k: bit<32> @@ -1500,9 +1233,7 @@ class TypeCheckerSpec extends AnyFunSpec { l: bit<32> } """) - } - } - it("can access bound field") { + it("can access bound field"): typeCheck(""" record point { x: bit<32> @@ -1510,8 +1241,7 @@ class TypeCheckerSpec extends AnyFunSpec { decl k: point; let x = k.x; """) - } - it("can bound field has the right return type") { + it("can bound field has the right return type"): typeCheck(""" record point { x: bit<32> @@ -1519,8 +1249,7 @@ class TypeCheckerSpec extends AnyFunSpec { decl k: point; let x = k.x + 1; """) - } - it("can bound field has the right return type in nested struct") { + it("can bound field has the right return type in nested struct"): typeCheck(""" record point { x: bit<32> @@ -1531,99 +1260,76 @@ class TypeCheckerSpec extends AnyFunSpec { decl k: foo; let x = k.p.x + 1; """) - } - it("should not throw error on casting") { + it("should not throw error on casting"): typeCheck(""" record point { x: ubit<32> } let a: point = {x=10}; let b: point = (a as point); """) - } - } - describe("Record Literals") { - it("can be defined with let") { + describe("Record Literals"): + it("can be defined with let"): typeCheck(""" record point { x: bit<32>; y: bit<32> } let p: point = {x = 1; y = 2 }; """) - } - it("cannot be defined without explicit type in let") { - assertThrows[ExplicitTypeMissing] { + it("cannot be defined without explicit type in let"): + assertThrows[ExplicitTypeMissing]: typeCheck(""" record point { x: bit<32>; y: bit<32> } let p = {x = 1; y = 2 }; """) - } - } - it("cannot be used inside expressions") { - assertThrows[NotInBinder] { + it("cannot be used inside expressions"): + assertThrows[NotInBinder]: typeCheck(""" record point { x: bit<32>; y: bit<32> } let p = 1 + {x = 1; y = 2 }; """) - } - } - it("get the right type") { + it("get the right type"): typeCheck(""" record point { x: bit<32>; y: bit<32> } let p: point = {x = 1; y = 2 }; let f: bit<32> = p.x; """) - } - it("cannot have fields missing") { - assertThrows[MissingField] { + it("cannot have fields missing"): + assertThrows[MissingField]: typeCheck(""" record point { x: bit<32>; y: bit<32> } let p: point = {x = 1}; """) - } - } - it("cannot have extra fields") { - assertThrows[ExtraField] { + it("cannot have extra fields"): + assertThrows[ExtraField]: typeCheck(""" record point { x: bit<32> } let p: point = {x = 1; y = 2}; """) - } - } - } - describe("Array literals") { - it("requires explicit type in the let binder") { - assertThrows[ExplicitTypeMissing] { + describe("Array literals"): + it("requires explicit type in the let binder"): + assertThrows[ExplicitTypeMissing]: typeCheck(""" let x = {1, 2, 3}; """) - } - } - it("does not support multidimensional literals") { - assertThrows[Unsupported] { + it("does not support multidimensional literals"): + assertThrows[Unsupported]: typeCheck(""" let x: bit<32>[10][10] = {1, 2, 3}; """) - } - } - it("requires literal to have the same size") { - assertThrows[LiteralLengthMismatch] { + it("requires literal to have the same size"): + assertThrows[LiteralLengthMismatch]: typeCheck(""" let x: bit<32>[5] = {1, 2, 3}; """) - } - } - it("requires subtypes in the array literal") { - assertThrows[UnexpectedSubtype] { + it("requires subtypes in the array literal"): + assertThrows[UnexpectedSubtype]: typeCheck(""" let x: bit<32>[3] = {true, false, true}; """) - } - } - it("can be banked") { + it("can be banked"): typeCheck(""" let x: bool[3 bank 3] = {true, false, true}; """) - } - it("can be used without initializer") { + it("can be used without initializer"): typeCheck(""" let x: bool[3] = {true, false, true}; { @@ -1632,61 +1338,50 @@ class TypeCheckerSpec extends AnyFunSpec { x[0] := false; } """) - } - } - describe("Indexing with dynamic (sized) var") { - it("works with an unbanked array") { + describe("Indexing with dynamic (sized) var"): + it("works with an unbanked array"): typeCheck("decl a: bit<10>[10]; decl x: bit<10>; a[x] := 5;") - } - it("doesn't work with banked array") { - assertThrows[InvalidDynamicIndex] { + it("doesn't work with banked array"): + assertThrows[InvalidDynamicIndex]: typeCheck("decl a: bit<10>[10 bank 5]; decl x: bit<10>; a[x] := 5;") - } - } - } - describe("Subtyping relations") { - it("static ints are always subtypes") { + describe("Subtyping relations"): + it("static ints are always subtypes"): typeCheck("1 == 2;") - } - it("smaller sized ints are subtypes of larger sized ints") { + it("smaller sized ints are subtypes of larger sized ints"): typeCheck(""" decl x: bit<16>; decl y: bit<32>; x == y; """) - } - it("static ints are subtypes of index types") { + it("static ints are subtypes of index types"): typeCheck(""" for (let i = 0..12) { i == 1; } """) - } - it("index types are subtypes of sized ints") { + it("index types are subtypes of sized ints"): typeCheck(""" decl x: bit<32>; for (let i = 0..12) { i == x; } """) - } - it("index types get upcast to sized int with log2(maxVal)") { + it("index types get upcast to sized int with log2(maxVal)"): typeCheck(""" decl arr:bit<32>[10]; for (let i = 0..33) { arr[5] := i * 1; } """) - } - it("Array subtyping is not allowed") { - assertThrows[UnexpectedSubtype] { + it("Array subtyping is not allowed"): + assertThrows[UnexpectedSubtype]: typeCheck(""" def foo(b: bit<10>[10]) = { b[0] := 10; // overflows @@ -1695,10 +1390,8 @@ class TypeCheckerSpec extends AnyFunSpec { decl a: bit<2>[10]; foo(a); """) - } - } - it("join of index types is dynamic") { + it("join of index types is dynamic"): typeCheck(""" for (let i = 0..10) { for (let j = 0..2) { @@ -1706,9 +1399,8 @@ class TypeCheckerSpec extends AnyFunSpec { } } """) - } - it("equal types have joins") { + it("equal types have joins"): typeCheck(""" let x = true; let y = false; @@ -1718,55 +1410,44 @@ class TypeCheckerSpec extends AnyFunSpec { let i2: bit<32> = 11; i1 == i2; """) - } - it("float is a subtype of double") { + it("float is a subtype of double"): typeCheck(""" decl x: float; decl y: float; let z: double = x + y; """) - } - it("fix is not a subtype of double") { - assertThrows[UnexpectedSubtype] { + it("fix is not a subtype of double"): + assertThrows[UnexpectedSubtype]: typeCheck(""" decl x: fix<3,2>; decl y: fix<3,2>; let z: double = x + y; """) - } - } - it("unsigned and signed bit types are incomparable") { - assertThrows[NoJoin] { + it("unsigned and signed bit types are incomparable"): + assertThrows[NoJoin]: typeCheck(""" decl x: ubit<32>; decl y: bit<32>; let z = x + y; """) - } - } - it("unsigned and signed fixed point types are incomparable") { - assertThrows[NoJoin] { + it("unsigned and signed fixed point types are incomparable"): + assertThrows[NoJoin]: typeCheck(""" decl x: ufix<32,16>; decl y: fix<32,16>; let z = x + y; """) - } - } it( "negative rational number cannot be assigend to unsigned fixed point type" - ) { - assertThrows[UnexpectedSubtype] { + ): + assertThrows[UnexpectedSubtype]: typeCheck(""" let z:ufix<32,16> = -0.5; """) - } - } - } - describe("Imports") { - it("adds imported functions into type checking scope") { + describe("Imports"): + it("adds imported functions into type checking scope"): typeCheck(""" import vivado("print.cpp") { def print_vect(f: float[4]); @@ -1774,37 +1455,30 @@ class TypeCheckerSpec extends AnyFunSpec { decl a: float[4]; print_vect(a); """) - } - } - describe("Bound Checking") { - it("static ints smaller than array size are valid") { + describe("Bound Checking"): + it("static ints smaller than array size are valid"): typeCheck(""" decl a: bit<32>[10]; let v = a[9]; """) - } - it("static ints larger than array size fail") { - assertThrows[IndexOutOfBounds] { + it("static ints larger than array size fail"): + assertThrows[IndexOutOfBounds]: typeCheck(""" decl a: bit<32>[10]; let v = a[10]; """) - } - } - it("nested expressions are checked") { - assertThrows[IndexOutOfBounds] { + it("nested expressions are checked"): + assertThrows[IndexOutOfBounds]: typeCheck(""" decl a: bit<32>[10]; decl b: bit<32>[10]; let v = a[9] + b[10]; """) - } - } - it("array access with index types is valid when maxVal <= array length") { + it("array access with index types is valid when maxVal <= array length"): typeCheck(""" decl a: bit<32>[10]; @@ -1812,10 +1486,9 @@ class TypeCheckerSpec extends AnyFunSpec { a[i] := 0; } """) - } - it("array access with index types fails when maxVal > array length") { - assertThrows[IndexOutOfBounds] { + it("array access with index types fails when maxVal > array length"): + assertThrows[IndexOutOfBounds]: typeCheck(""" decl a: bit<32>[10]; @@ -1823,80 +1496,66 @@ class TypeCheckerSpec extends AnyFunSpec { a[i] := 0; } """) - } - } - it("simple views with prefixes have their bounds checked") { - assertThrows[IndexOutOfBounds] { + it("simple views with prefixes have their bounds checked"): + assertThrows[IndexOutOfBounds]: typeCheck(""" decl a: bit<32>[10]; view v_a = a[0!:+3]; v_a[3]; """) - } - } - it("creation of simple views is bounds checked") { - assertThrows[IndexOutOfBounds] { + it("creation of simple views is bounds checked"): + assertThrows[IndexOutOfBounds]: typeCheck(""" decl a: bit<32>[10]; for (let i = 0..10) { view v_a = a[i!:+3]; } """) - } - } - } - describe("Explicit casting") { - it("safe to cast integer types to float") { + describe("Explicit casting"): + it("safe to cast integer types to float"): typeCheck(""" decl x: float; decl y: bit<32>; (y as float) + x; """) - } - it("safe to cast bit to fix as long as the integer length part is shorter") { + it("safe to cast bit to fix as long as the integer length part is shorter"): typeCheck(""" decl x: fix<32,16>; decl y: bit<16>; (y as fix<32,16>) + x; """) - } - it("warning when casting float to bit type") { + it("warning when casting float to bit type"): typeCheck(""" decl x: float; decl y: bit<32>; (x as bit<32>) + y; """) - } - it("warning when casting fix to bit type") { + it("warning when casting fix to bit type"): typeCheck(""" decl x: fix<32,16>; decl y: bit<16>; (x as bit<16>) + y; """) - } - it("safe to cast integer float to double") { + it("safe to cast integer float to double"): typeCheck(""" decl x: float; decl y: double; y + (x as double); """) - } - it("safe to cast fix to double") { + it("safe to cast fix to double"): typeCheck(""" decl x: fix<10,5>; decl y: double; y + (x as double); """) - } - } - describe("Array access dependent on loop iteration variables") { + describe("Array access dependent on loop iteration variables"): it( "assignments in both branches of a conditional can override variable that depends on loop iteration" - ) { + ): typeCheck(""" let foo: float[2]; for (let j = 0..2) unroll 2 { @@ -1910,11 +1569,10 @@ class TypeCheckerSpec extends AnyFunSpec { foo[1 + y]; } }""") - } it( "single branche conditional does not override variable that depends on loop iteration" - ) { - assertThrows[TypeError] { + ): + assertThrows[TypeError]: typeCheck(""" let foo: float[2]; for (let j = 0..2) unroll 2 { @@ -1926,10 +1584,8 @@ class TypeCheckerSpec extends AnyFunSpec { foo[1 + y]; } }""") - } - } - it("can't have array access that depends on loop iterator") { - assertThrows[TypeError] { + it("can't have array access that depends on loop iterator"): + assertThrows[TypeError]: typeCheck(""" let foo: float[2]; for (let j = 0..2) unroll 2 { @@ -1938,10 +1594,8 @@ class TypeCheckerSpec extends AnyFunSpec { foo[1 + y]; } }""") - } - } - it("Chain of assignments that leads back to loop iterator") { - assertThrows[TypeError] { + it("Chain of assignments that leads back to loop iterator"): + assertThrows[TypeError]: typeCheck(""" let foo: float[2]; for (let j = 0..2) unroll 2 { @@ -1953,11 +1607,9 @@ class TypeCheckerSpec extends AnyFunSpec { foo[1 + a]; } }""") - } - } it( "Chain of assignments that leads back to loop iterator + conditional overwriting of last var" - ) { + ): typeCheck(""" let foo: float[2]; for (let j = 0..2) unroll 2 { @@ -1974,6 +1626,3 @@ class TypeCheckerSpec extends AnyFunSpec { foo[1 + a]; } }""") - } - } -}