Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Direct to firrtl #829

Merged
merged 13 commits into from
Jul 2, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Printable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ case class PString(str: String) extends Printable {
(str replaceAll ("%", "%%"), List.empty)
}
/** Superclass for Firrtl format specifiers for Bits */
sealed abstract class FirrtlFormat(specifier: Char) extends Printable {
sealed abstract class FirrtlFormat(private[chisel3] val specifier: Char) extends Printable {
def bits: Bits
def unpack(ctx: Component): (String, Iterable[String]) = {
(s"%$specifier", List(bits.ref.fullName(ctx)))
Expand Down
44 changes: 38 additions & 6 deletions src/main/scala/chisel3/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

package chisel3

import chisel3.internal.firrtl.Emitter
import chisel3.internal.firrtl.Converter
import chisel3.experimental.{RawModule, RunFirrtlTransform}

import java.io._
Expand Down Expand Up @@ -92,7 +92,9 @@ object Driver extends BackendCompilationUtilities {
*/
def elaborate[T <: RawModule](gen: () => T): Circuit = internal.Builder.build(Module(gen()))

def emit[T <: RawModule](gen: () => T): String = Emitter.emit(elaborate(gen))
def toFirrtl(ir: Circuit): firrtl.ir.Circuit = Converter.convert(ir)

def emit[T <: RawModule](gen: () => T): String = Driver.emit(elaborate(gen))

def emit[T <: RawModule](ir: Circuit): String = Emitter.emit(ir)

Expand All @@ -108,14 +110,41 @@ object Driver extends BackendCompilationUtilities {
}
}

/** Dumps the elaborated Circuit to FIRRTL
*
* If no File is given as input, it will dump to a default filename based on the name of the
* Top Module
*
* @param c Elaborated Chisel Circuit
* @param optName Optional File to dump to
* @return The File the circuit was dumped to
*/
def dumpFirrtl(ir: Circuit, optName: Option[File]): File = {
val f = optName.getOrElse(new File(ir.name + ".fir"))
val w = new FileWriter(f)
w.write(Emitter.emit(ir))
w.write(Driver.emit(ir))
w.close()
f
}

/** Dumps the elaborated Circuit to ProtoBuf
*
* If no File is given as input, it will dump to a default filename based on the name of the
* Top Module
*
* @param c Elaborated Chisel Circuit
* @param optFile Optional File to dump to
* @return The File the circuit was dumped to
*/
def dumpProto(c: Circuit, optFile: Option[File]): File = {
val f = optFile.getOrElse(new File(c.name + ".pb"))
val ostream = new java.io.FileOutputStream(f)
// Lazily convert modules to make intermediate objects garbage collectable
val modules = c.components.map(m => () => Converter.convert(m))
firrtl.proto.ToProto.writeToStreamFast(ostream, ir.NoInfo, modules, c.name)
f
}

private var target_dir: Option[String] = None
def parseArgs(args: Array[String]): Unit = {
for (i <- 0 until args.size) {
Expand Down Expand Up @@ -145,15 +174,18 @@ object Driver extends BackendCompilationUtilities {
val firrtlOptions = optionsManager.firrtlOptions
val chiselOptions = optionsManager.chiselOptions

// use input because firrtl will be reading this
val firrtlString = Emitter.emit(circuit)
val firrtlCircuit = Converter.convert(circuit)

// Still emit to leave an artifact (and because this always has been the behavior)
val firrtlString = Driver.emit(circuit)
val firrtlFileName = firrtlOptions.getInputFileName(optionsManager)
val firrtlFile = new File(firrtlFileName)

val w = new FileWriter(firrtlFile)
w.write(firrtlString)
w.close()

// Emit the annotations because it has always been the behavior
val annotationFile = new File(optionsManager.getBuildFileName("anno.json"))
val af = new FileWriter(annotationFile)
val firrtlAnnos = circuit.annotations.map(_.toFirrtl)
Expand All @@ -174,7 +206,7 @@ object Driver extends BackendCompilationUtilities {
}
/* This passes the firrtl source and annotations directly to firrtl */
optionsManager.firrtlOptions = optionsManager.firrtlOptions.copy(
firrtlSource = Some(firrtlString),
firrtlCircuit = Some(firrtlCircuit),
annotations = optionsManager.firrtlOptions.annotations ++ firrtlAnnos,
customTransforms = optionsManager.firrtlOptions.customTransforms ++ transforms.toList)

Expand Down
263 changes: 263 additions & 0 deletions src/main/scala/chisel3/internal/firrtl/Converter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
// See LICENSE for license details.

package chisel3.internal.firrtl
import chisel3._
import chisel3.core.SpecifiedDirection
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo}
import firrtl.{ir => fir}
import chisel3.internal.throwException

import scala.annotation.tailrec
import scala.collection.immutable.{Queue}

private[chisel3] object Converter {
// TODO modeled on unpack method on Printable, refactor?
def unpack(pable: Printable, ctx: Component): (String, Seq[Arg]) = pable match {
case Printables(pables) =>
val (fmts, args) = pables.map(p => unpack(p, ctx)).unzip
(fmts.mkString, args.flatten.toSeq)
case PString(str) => (str.replaceAll("%", "%%"), List.empty)
case format: FirrtlFormat =>
("%" + format.specifier, List(format.bits.ref))
case Name(data) => (data.ref.name, List.empty)
case FullName(data) => (data.ref.fullName(ctx), List.empty)
case Percent => ("%%", List.empty)
}

def convert(info: SourceInfo): fir.Info = info match {
case _: NoSourceInfo => fir.NoInfo
case SourceLine(fn, line, col) => fir.FileInfo(fir.StringLit(s"$fn $line:$col"))
}

def convert(op: PrimOp): fir.PrimOp = firrtl.PrimOps.fromString(op.name)

def convert(dir: MemPortDirection): firrtl.MPortDir = dir match {
case MemPortDirection.INFER => firrtl.MInfer
case MemPortDirection.READ => firrtl.MRead
case MemPortDirection.WRITE => firrtl.MWrite
case MemPortDirection.RDWR => firrtl.MReadWrite
}

// TODO
// * Memoize?
// * Move into the Chisel IR?
def convert(arg: Arg, ctx: Component): fir.Expression = arg match {
case Node(id) =>
convert(id.getRef, ctx)
case Ref(name) =>
fir.Reference(name, fir.UnknownType)
case Slot(imm, name) =>
fir.SubField(convert(imm, ctx), name, fir.UnknownType)
case Index(imm, ILit(idx)) =>
fir.SubIndex(convert(imm, ctx), idx.toInt, fir.UnknownType)
case Index(imm, value) =>
fir.SubAccess(convert(imm, ctx), convert(value, ctx), fir.UnknownType)
case ModuleIO(mod, name) =>
if (mod eq ctx.id) fir.Reference(name, fir.UnknownType)
else fir.SubField(fir.Reference(mod.getRef.name, fir.UnknownType), name, fir.UnknownType)
case u @ ULit(n, UnknownWidth()) =>
fir.UIntLiteral(n, fir.IntWidth(u.minWidth))
case ULit(n, w) =>
fir.UIntLiteral(n, convert(w))
case slit @ SLit(n, w) => fir.SIntLiteral(n, convert(w))
val unsigned = if (n < 0) (BigInt(1) << slit.width.get) + n else n
val uint = convert(ULit(unsigned, slit.width), ctx)
fir.DoPrim(firrtl.PrimOps.AsSInt, Seq(uint), Seq.empty, fir.UnknownType)
// TODO Simplify
case fplit @ FPLit(n, w, bp) =>
val unsigned = if (n < 0) (BigInt(1) << fplit.width.get) + n else n
val uint = convert(ULit(unsigned, fplit.width), ctx)
val lit = bp.asInstanceOf[KnownBinaryPoint].value
fir.DoPrim(firrtl.PrimOps.AsFixedPoint, Seq(uint), Seq(lit), fir.UnknownType)
case lit: ILit =>
throwException(s"Internal Error! Unexpected ILit: $lit")
}

/** Convert Commands that map 1:1 to Statements */
def convertSimpleCommand(cmd: Command, ctx: Component): Option[fir.Statement] = cmd match {
case e: DefPrim[_] =>
val consts = e.args.collect { case ILit(i) => i }
val args = e.args.flatMap {
case _: ILit => None
case other => Some(convert(other, ctx))
}
val expr = e.op.name match {
case "mux" =>
assert(args.size == 3, s"Mux with unexpected args: $args")
fir.Mux(args(0), args(1), args(2), fir.UnknownType)
case _ =>
fir.DoPrim(convert(e.op), args, consts, fir.UnknownType)
}
Some(fir.DefNode(convert(e.sourceInfo), e.name, expr))
case e @ DefWire(info, id) =>
Some(fir.DefWire(convert(info), e.name, extractType(id)))
case e @ DefReg(info, id, clock) =>
Some(fir.DefRegister(convert(info), e.name, extractType(id), convert(clock, ctx),
firrtl.Utils.zero, convert(id.getRef, ctx)))
case e @ DefRegInit(info, id, clock, reset, init) =>
Some(fir.DefRegister(convert(info), e.name, extractType(id), convert(clock, ctx),
convert(reset, ctx), convert(init, ctx)))
case e @ DefMemory(info, id, t, size) =>
Some(firrtl.CDefMemory(convert(info), e.name, extractType(t), size, false))
case e @ DefSeqMemory(info, id, t, size) =>
Some(firrtl.CDefMemory(convert(info), e.name, extractType(t), size, true))
case e: DefMemPort[_] =>
Some(firrtl.CDefMPort(convert(e.sourceInfo), e.name, fir.UnknownType,
e.source.fullName(ctx), Seq(convert(e.index, ctx), convert(e.clock, ctx)), convert(e.dir)))
case Connect(info, loc, exp) =>
Some(fir.Connect(convert(info), convert(loc, ctx), convert(exp, ctx)))
case BulkConnect(info, loc, exp) =>
Some(fir.PartialConnect(convert(info), convert(loc, ctx), convert(exp, ctx)))
case Attach(info, locs) =>
Some(fir.Attach(convert(info), locs.map(l => convert(l, ctx))))
case DefInvalid(info, arg) =>
Some(fir.IsInvalid(convert(info), convert(arg, ctx)))
case e @ DefInstance(info, id, _) =>
Some(fir.DefInstance(convert(info), e.name, id.name))
case Stop(info, clock, ret) =>
Some(fir.Stop(convert(info), ret, convert(clock, ctx), firrtl.Utils.one))
case Printf(info, clock, pable) =>
val (fmt, args) = unpack(pable, ctx)
Some(fir.Print(convert(info), fir.StringLit(fmt),
args.map(a => convert(a, ctx)), convert(clock, ctx), firrtl.Utils.one))
case _ => None
}

/** Internal datastructure to help translate Chisel's flat Command structure to FIRRTL's AST
*
* In particular, when scoping is translated from flat with begin end to a nested datastructure
*
* @param when Current when Statement, holds info, condition, and consequence as they are
* available
* @param outer Already converted Statements that precede the current when block in the scope in
* which the when is defined (ie. 1 level up from the scope inside the when)
* @param alt Indicates if currently processing commands in the alternate (else) of the when scope
*/
// TODO we should probably have a different structure in the IR to close elses
private case class WhenFrame(when: fir.Conditionally, outer: Queue[fir.Statement], alt: Boolean)

/** Convert Chisel IR Commands into FIRRTL Statements
*
* @note ctx is needed because references to ports translate differently when referenced within
* the module in which they are defined vs. parent modules
* @param cmds Chisel IR Commands to convert
* @param ctx Component (Module) context within which we are translating
* @return FIRRTL Statement that is equivalent to the input cmds
*/
def convert(cmds: Seq[Command], ctx: Component): fir.Statement = {
@tailrec
def rec(acc: Queue[fir.Statement],
scope: List[WhenFrame])
(cmds: Seq[Command]): Seq[fir.Statement] = {
if (cmds.isEmpty) {
assert(scope.isEmpty)
acc
} else convertSimpleCommand(cmds.head, ctx) match {
// Most Commands map 1:1
case Some(stmt) =>
rec(acc :+ stmt, scope)(cmds.tail)
// When scoping logic does not map 1:1 and requires pushing/popping WhenFrames
// Please see WhenFrame for more details
case None => cmds.head match {
case WhenBegin(info, pred) =>
val when = fir.Conditionally(convert(info), convert(pred, ctx), fir.EmptyStmt, fir.EmptyStmt)
val frame = WhenFrame(when, acc, false)
rec(Queue.empty, frame +: scope)(cmds.tail)
case WhenEnd(info, depth, _) =>
val frame = scope.head
val when = if (frame.alt) frame.when.copy(alt = fir.Block(acc))
else frame.when.copy(conseq = fir.Block(acc))
// Check if this when has an else
cmds.tail.headOption match {
case Some(AltBegin(_)) =>
assert(!frame.alt, "Internal Error! Unexpected when structure!") // Only 1 else per when
rec(Queue.empty, frame.copy(when = when, alt = true) +: scope.tail)(cmds.drop(2))
case _ => // Not followed by otherwise
// If depth > 0 then we need to close multiple When scopes so we add a new WhenEnd
// If we're nested we need to add more WhenEnds to ensure each When scope gets
// properly closed
val cmdsx = if (depth > 0) WhenEnd(info, depth - 1, false) +: cmds.tail else cmds.tail
rec(frame.outer :+ when, scope.tail)(cmdsx)
}
case OtherwiseEnd(info, depth) =>
val frame = scope.head
val when = frame.when.copy(alt = fir.Block(acc))
// TODO For some reason depth == 1 indicates the last closing otherwise whereas
// depth == 0 indicates last closing when
val cmdsx = if (depth > 1) OtherwiseEnd(info, depth - 1) +: cmds.tail else cmds.tail
rec(scope.head.outer :+ when, scope.tail)(cmdsx)
}
}
}
fir.Block(rec(Queue.empty, List.empty)(cmds))
}

def convert(width: Width): fir.Width = width match {
case UnknownWidth() => fir.UnknownWidth
case KnownWidth(value) => fir.IntWidth(value)
}

def convert(bp: BinaryPoint): fir.Width = bp match {
case UnknownBinaryPoint => fir.UnknownWidth
case KnownBinaryPoint(value) => fir.IntWidth(value)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was confused on the exact replication of convert(BinaryPoint) and convert(Width), but this is how it's done in chisel3.internal.firrtl (with associated replication there). That may benefit from a refactor, but that is clearly out of scope of this encapsulated PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that was my thinking, a later refactor can clean a lot of stuff related to this and the Emitter


private def firrtlUserDirOf(d: Data): SpecifiedDirection = d match {
case d: Vec[_] =>
SpecifiedDirection.fromParent(d.specifiedDirection, firrtlUserDirOf(d.sample_element))
case d => d.specifiedDirection
}

def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match {
case _: Clock => fir.ClockType
case d: UInt => fir.UIntType(convert(d.width))
case d: SInt => fir.SIntType(convert(d.width))
case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint))
case d: Analog => fir.AnalogType(convert(d.width))
case d: Vec[_] => fir.VectorType(extractType(d.sample_element, clearDir), d.length)
case d: Record =>
val childClearDir = clearDir ||
d.specifiedDirection == SpecifiedDirection.Input || d.specifiedDirection == SpecifiedDirection.Output
def eltField(elt: Data): fir.Field = (childClearDir, firrtlUserDirOf(elt)) match {
case (true, _) => fir.Field(elt.getRef.name, fir.Default, extractType(elt, true))
case (false, SpecifiedDirection.Unspecified | SpecifiedDirection.Output) =>
fir.Field(elt.getRef.name, fir.Default, extractType(elt, false))
case (false, SpecifiedDirection.Flip | SpecifiedDirection.Input) =>
fir.Field(elt.getRef.name, fir.Flip, extractType(elt, false))
}
fir.BundleType(d.elements.toIndexedSeq.reverse.map { case (_, e) => eltField(e) })
}

def convert(name: String, param: Param): fir.Param = param match {
case IntParam(value) => fir.IntParam(name, value)
case DoubleParam(value) => fir.DoubleParam(name, value)
case StringParam(value) => fir.StringParam(name, fir.StringLit(value))
case RawParam(value) => fir.RawStringParam(name, value)
}
def convert(port: Port, topDir: SpecifiedDirection = SpecifiedDirection.Unspecified): fir.Port = {
val resolvedDir = SpecifiedDirection.fromParent(topDir, port.dir)
val dir = resolvedDir match {
case SpecifiedDirection.Unspecified | SpecifiedDirection.Output => fir.Output
case SpecifiedDirection.Flip | SpecifiedDirection.Input => fir.Input
}
val clearDir = resolvedDir match {
case SpecifiedDirection.Input | SpecifiedDirection.Output => true
case SpecifiedDirection.Unspecified | SpecifiedDirection.Flip => false
}
val tpe = extractType(port.id, clearDir)
fir.Port(fir.NoInfo, port.id.getRef.name, dir, tpe)
}

def convert(component: Component): fir.DefModule = component match {
case ctx @ DefModule(_, name, ports, cmds) =>
fir.Module(fir.NoInfo, name, ports.map(p => convert(p)), convert(cmds.toList, ctx))
case ctx @ DefBlackBox(id, name, ports, topDir, params) =>
fir.ExtModule(fir.NoInfo, name, ports.map(p => convert(p, topDir)), id.desiredName,
params.map { case (name, p) => convert(name, p) }.toSeq)
}

def convert(circuit: Circuit): fir.Circuit =
fir.Circuit(fir.NoInfo, circuit.components.map(convert), circuit.name)
}

6 changes: 2 additions & 4 deletions src/test/scala/chiselTests/PrintableSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import chisel3.testers.BasicTester

/* Printable Tests */
class PrintableSpec extends FlatSpec with Matchers {
private val PrintfRegex = """\s*printf\((.*)\).*""".r
// This regex is brittle, it relies on the first two arguments of the printf
// not containing quotes, problematic if Chisel were to emit UInt<1>("h01")
// instead of the current UInt<1>(1) for the enable signal
// This regex is brittle, it specifically finds the clock and enable signals followed by commas
private val PrintfRegex = """\s*printf\(\w+, [^,]+,(.*)\).*""".r
private val StringRegex = """([^"]*)"(.*?)"(.*)""".r
private case class Printf(str: String, args: Seq[String])
private def getPrintfs(firrtl: String): Seq[Printf] = {
Expand Down