Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite to Scala 3 syntax using -rewrite flag #423

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ Compile / resourceGenerators += Def.task {
}

/* sbt-assembly configuration: build an executable jar. */
//assembly / assemblyOption := (assembly / assemblyOption).value.copy(
// prependShellScript = Some(sbtassembly.AssemblyPlugin.defaultShellScript)
//)
ThisBuild / assemblyPrependShellScript := Some(sbtassembly.AssemblyPlugin.defaultShellScript)
assembly / assemblyJarName := "fuse.jar"
assembly / test := {}
Expand Down
32 changes: 11 additions & 21 deletions src/main/scala/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Configuration._
import Syntax._
import Transformer.{PartialTransformer, TypedPartialTransformer}

object Compiler {
object Compiler:

// Transformers to execute *before* type checking.
val preTransformers: List[(String, PartialTransformer)] = List(
Expand All @@ -28,28 +28,25 @@ object Compiler {
"Add bitwidth" -> (passes.AddBitWidth, true)
)

def showDebug(ast: Prog, pass: String, c: Config): Unit = {
if c.passDebug then {
def showDebug(ast: Prog, pass: String, c: Config): Unit =
if c.passDebug then
val top = ("=" * 15) + pass + ("=" * 15)
println(top)
println(Pretty.emitProg(ast)(c.logLevel == scribe.Level.Debug).trim)
println("=" * top.length)
}
}

def toBackend(str: BackendOption): fuselang.backend.Backend = str match {
def toBackend(str: BackendOption): fuselang.backend.Backend = str match
case Vivado => backend.VivadoBackend
case Cpp => backend.CppRunnable
case Calyx => backend.calyx.CalyxBackend
}

def checkStringWithError(prog: String, c: Config = emptyConf) = {
def checkStringWithError(prog: String, c: Config = emptyConf) =
val preAst = Parser(prog).parse()

showDebug(preAst, "Original", c)

// Run pre transformers if lowering is enabled
val ast = if c.enableLowering then {
val ast = if c.enableLowering then
preTransformers.foldLeft(preAst)({
case (ast, (name, pass)) => {
val newAst = pass.rewrite(ast)
Expand All @@ -70,9 +67,8 @@ object Compiler {
} */
}
})
} else {
else
preAst
}
passes.WellFormedChecker.check(ast)
typechecker.TypeChecker.typeCheck(ast);
showDebug(ast, "Type Checking", c)
Expand All @@ -83,9 +79,8 @@ object Compiler {
showDebug(ast, "Capability Checking", c)
typechecker.AffineChecker.check(ast); // Doesn't modify the AST
ast
}

def codegen(ast: Prog, c: Config = emptyConf) = {
def codegen(ast: Prog, c: Config = emptyConf) =
// Filter out transformers not running in this mode
val toRun = postTransformers.filter({
case (_, (_, onlyLower)) => {
Expand All @@ -101,14 +96,12 @@ object Compiler {
}
})
toBackend(c.backend).emit(transformedAst, c)
}

// Outputs red text to the console
def red(txt: String): String = {
def red(txt: String): String =
Console.RED + txt + Console.RESET
}

def compileString(prog: String, c: Config): Either[String, String] = {
def compileString(prog: String, c: Config): Either[String, String] =
Try(codegen(checkStringWithError(prog, c), c)).toEither.left
.map(err => {
scribe.info(err.getStackTrace().take(10).mkString("\n"))
Expand Down Expand Up @@ -136,13 +129,12 @@ object Compiler {
val commentPre = toBackend(c.backend).commentPrefix
s"$commentPre $meta\n" + out
})
}

def compileStringToFile(
prog: String,
c: Config,
out: String
): Either[String, Path] = {
): Either[String, Path] =

compileString(prog, c).map(p => {
Files.write(
Expand All @@ -153,6 +145,4 @@ object Compiler {
StandardOpenOption.WRITE
)
})
}

}
26 changes: 9 additions & 17 deletions src/main/scala/GenerateExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import common.CompilerError.HeaderMissing
* Provides utilities to compile a program and link it with headers required
* by the CppRunnable backend.
*/
object GenerateExec {
object GenerateExec:
// TODO(rachit): Move this to build.sbt
val headers = List("parser.cpp", "json.hpp")

Expand All @@ -19,18 +19,18 @@ object GenerateExec {


// Not the compiler directory, check if the fallback directory has been setup.
if Files.exists(headerLocation) == false then {
if Files.exists(headerLocation) == false then
// Fallback for headers not setup. Unpack headers from JAR file.
headerLocation = headerFallbackLocation

if Files.exists(headerFallbackLocation) == false then {
if Files.exists(headerFallbackLocation) == false then
scribe.warn(
s"Missing headers required for `fuse run`." +
s" Unpacking from JAR file into $headerFallbackLocation."
)

val dir = Files.createDirectory(headerFallbackLocation)
for header <- headers do {
for header <- headers do
val stream = getClass.getResourceAsStream(s"/headers/$header")
val hdrSource = Source.fromInputStream(stream).toArray.map(_.toByte)
Files.write(
Expand All @@ -39,9 +39,6 @@ object GenerateExec {
StandardOpenOption.CREATE_NEW,
StandardOpenOption.WRITE
)
}
}
}

/**
* Generates an executable object [[out]]. Assumes that [[src]] is a valid
Expand All @@ -54,14 +51,12 @@ object GenerateExec {
src: Path,
out: String,
compilerOpts: List[String]
): Either[String, Int] = {
): Either[String, Int] =

// Make sure all headers are downloaded.
for header <- headers do {
if Files.exists(headerLocation.resolve(header)) == false then {
for header <- headers do
if Files.exists(headerLocation.resolve(header)) == false then
throw HeaderMissing(header, headerLocation.toString)
}
}

val CXX =
Seq("g++", "-g", "--std=c++14", "-Wall", "-I", headerLocation.toString) ++ compilerOpts
Expand All @@ -75,10 +70,7 @@ object GenerateExec {
scribe.info(cmd.mkString(" "))
val status = cmd ! logger

if status != 0 then {
if status != 0 then
Left(s"Failed to generate the executable $out.\n${stderr}")
} else {
else
Right(status)
}
}
}
6 changes: 2 additions & 4 deletions src/main/scala/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object Main:
})
.toMap

val parser = new scopt.OptionParser[Config]("fuse") {
val parser = new scopt.OptionParser[Config]("fuse"):

head(s"Dahlia (sha = ${meta("git.hash")}, status = ${meta("git.status")})")

Expand Down Expand Up @@ -112,16 +112,14 @@ object Main:
.action((f, c) => c.copy(output = Some(f)))
.text("Name of the output artifact.")
)
}

def runWithConfig(conf: Config): Either[String, Int] =
type ErrString = String

val path = conf.srcFile.toPath
val prog = Files.exists(path) match {
val prog = Files.exists(path) match
case true => Right(new String(Files.readAllBytes(path)))
case false => Left(s"$path: No such file in working directory")
}

val cppPath: Either[ErrString, Option[Path]] = prog.flatMap(prog =>
conf.output match {
Expand Down
39 changes: 13 additions & 26 deletions src/main/scala/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,70 +2,57 @@ package fuselang
import scala.{PartialFunction => PF}
import scala.math.{log10, ceil}

object Utils {
object Utils:

implicit class RichOption[A](opt: => Option[A]) {
def getOrThrow[T <: Throwable](except: T) = opt match {
implicit class RichOption[A](opt: => Option[A]):
def getOrThrow[T <: Throwable](except: T) = opt match
case Some(v) => v
case None => throw except
}
}

// https://codereview.stackexchange.com/questions/14561/matching-bigints-in-scala
// TODO: This can overflow and result in an runtime exception
object Big {
object Big:
def unapply(n: BigInt) = Some(n.toInt)
}

def bitsNeeded(n: Int): Int = n match {
def bitsNeeded(n: Int): Int = n match
case 0 => 1
case n if n > 0 => ceil(log10(n + 1) / log10(2)).toInt
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def bitsNeeded(n: BigInt): Int = n match {
def bitsNeeded(n: BigInt): Int = n match
case Big(0) => 1
case n if n > 0 => ceil(log10((n + 1).toDouble) / log10(2)).toInt
case n if n < 0 => bitsNeeded(n.abs) + 1
}

def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] = {
def cartesianProduct[T](llst: Seq[Seq[T]]): Seq[Seq[T]] =
def pel(e: T, ll: Seq[Seq[T]], a: Seq[Seq[T]] = Nil): Seq[Seq[T]] =
ll match {
ll match
case Nil => a.reverse
case x +: xs => pel(e, xs, (e +: x) +: a)
}

llst match {
llst match
case Nil => Nil
case x +: Nil => x.map(Seq(_))
case x +: _ =>
x match {
x match
case Nil => Nil
case _ =>
llst
.foldRight(Seq(x))((l, a) => l.flatMap(x => pel(x, a)))
.map(_.dropRight(x.size))
}
}
}


@inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] = {
@inline def asPartial[A, B, C](f: (A, B) => C): PF[(A, B), C] =
case (a, b) => f(a, b)
}

@inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) = {
@inline def assertOrThrow[T <: Throwable](cond: Boolean, except: => T) =
if !cond then throw except
}

@deprecated(
"pr is used for debugging. Remove all call to it before committing",
"fuse 0.0.1"
)
@inline def pr[T](v: T) = {
@inline def pr[T](v: T) =
println(v)
v
}

}
9 changes: 3 additions & 6 deletions src/main/scala/backends/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@ import CompilerError.BackendError
/**
* Abstract definition of a Fuse backend.
*/
trait Backend {
trait Backend:

def emit(p: Syntax.Prog, c: Configuration.Config): String = {
if c.header && (canGenerateHeader == false) then {
def emit(p: Syntax.Prog, c: Configuration.Config): String =
if c.header && (canGenerateHeader == false) then
throw BackendError(s"Backend $this does not support header generation.")
}
emitProg(p, c)
}

/**
* Generate a String representation of the Abstract Syntax Tree of the
Expand All @@ -32,4 +30,3 @@ trait Backend {
*/
val commentPrefix: String = "//"

}
Loading
Loading