Skip to content

Commit

Permalink
Strong enums (#892)
Browse files Browse the repository at this point in the history
* Added new strongly-typed enum construct called "StrongEnum". "StrongEnum" will automatically generate annotations that HDL backends can use to mark components as enums

Removed "override val width" constructor parameter from "Element" so that classes with variable widths, like the new strong enums, can inherit from it

Changed the parameter types of certain functions, such as "switch", "is", and "LitArg.bindLitArg" from "Bits" to "Element", so that they can take the new strong enums as arguments

* Added tests for the new strong enums

* Changed StrongEnum exception names and made sure in StrongEnum tests that the correct types of exceptions are thrown

* Fixed bug where an enum's global annotation would not be set if it was used in multiple circuits

Made styling changes to StrongEnum.scala

* Reverted accidental changes to the AnnotatingDiamond test

* Changed the API for casting non-literal UInts to enums

Added an isValid function that checks whether or not enums have valid values

Calling getWidth on an enum's companion object now returns a BigInt instead of an Int

* Casting a literal to an enum using the StrongEnum.castFromNonLit(n) function is now simply a wrapper for StrongEnum.apply(n)

* Fixed compilation bug

* * Added "next" method to EnumType
* Renamed "castFromNonLit" to "fromBits"

* The FSM example in the test/scala/cookbook now uses StrongEnums

* * Changed strong enum API, so that users no longer have to declare both a class and a companion object for each strong enum

* Strong enums do not have to be static any longer

* * Added scope protections to ChiselEnum.Value so that users cannot call it
outside of a ChiselEnum definition

* Renamed ChiselEnum.Value type to ChiselEnum.Type so that we can give
it a companion object just like UInt and Bool do

* * Moved strong enums into experimental package
* Non-literal UInts can now be cast to enums with apply() rather than
fromBits()
* Reduced code-duplication by moving some functions from EnumType and
Bits to Element
  • Loading branch information
hngenc authored and chick committed Oct 12, 2018
1 parent 10d5472 commit 6004052
Show file tree
Hide file tree
Showing 12 changed files with 753 additions and 57 deletions.
66 changes: 33 additions & 33 deletions chiselFrontend/src/main/scala/chisel3/core/Bits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ import chisel3.internal.firrtl.PrimOp._
*
* @define coll element
*/
abstract class Element(private[chisel3] val width: Width) extends Data {
abstract class Element extends Data {
private[chisel3] final def allElements: Seq[Element] = Seq(this)
def widthKnown: Boolean = width.known
def name: String = getRef.name

private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) {
binding = target
val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection)
Expand All @@ -30,9 +34,32 @@ abstract class Element(private[chisel3] val width: Width) extends Data {
}
}

private[chisel3] final def allElements: Seq[Element] = Seq(this)
def widthKnown: Boolean = width.known
def name: String = getRef.name
private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
// Translate Bundle lit bindings to Element lit bindings
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => Some(ElementLitBinding(litArg))
case _ => Some(DontCareBinding())
}
case topBindingOpt => topBindingOpt
}

private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => Some(litArg)
case _ => None
}

override def litOption: Option[BigInt] = litArgOption.map(_.num)
private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)

// provide bits-specific literal handling functionality here
override private[chisel3] def ref: Arg = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => litArg
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => litArg
case _ => throwException(s"internal error: DontCare should be caught before getting ref")
}
case _ => super.ref
}

private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = {
// If the source is a DontCare, generate a DefInvalid for the sink,
Expand Down Expand Up @@ -69,7 +96,7 @@ private[chisel3] sealed trait ToBoolable extends Element {
* @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`.
* @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`.
*/
sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods
sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods
// TODO: perhaps make this concrete?
// Arguments for: self-checking code (can't do arithmetic on bits)
// Arguments against: generates down to a FIRRTL UInt anyways
Expand All @@ -79,33 +106,6 @@ sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable

def cloneType: this.type = cloneTypeWidth(width)

private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
// Translate Bundle lit bindings to Element lit bindings
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => Some(ElementLitBinding(litArg))
case _ => Some(DontCareBinding())
}
case topBindingOpt => topBindingOpt
}

private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => Some(litArg)
case _ => None
}

override def litOption: Option[BigInt] = litArgOption.map(_.num)
private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)

// provide bits-specific literal handling functionality here
override private[chisel3] def ref: Arg = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => litArg
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => litArg
case _ => throwException(s"internal error: DontCare should be caught before getting ref")
}
case _ => super.ref
}

/** Tail operator
*
* @param n the number of bits to remove
Expand Down Expand Up @@ -1693,7 +1693,7 @@ object FixedPoint {
*
* @note This API is experimental and subject to change
*/
final class Analog private (width: Width) extends Element(width) {
final class Analog private (private[chisel3] val width: Width) extends Element {
require(width.known, "Since Analog is only for use in BlackBoxes, width must be known")

private[core] override def typeEquivalent(that: Data): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Clock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object Clock {
}

// TODO: Document this.
sealed class Clock extends Element(Width(1)) {
sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element {
def cloneType: this.type = Clock().asInstanceOf[this.type]

private[core] def typeEquivalent(that: Data): Boolean =
Expand Down
4 changes: 3 additions & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Data.scala
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,12 @@ object WireInit {
/** RHS (source) for Invalidate API.
* Causes connection logic to emit a DefInvalid when connected to an output port (or wire).
*/
object DontCare extends Element(width = UnknownWidth()) {
object DontCare extends Element {
// This object should be initialized before we execute any user code that refers to it,
// otherwise this "Chisel" object will end up on the UserModule's id list.

private[chisel3] override val width: Width = UnknownWidth()

bind(DontCareBinding(), SpecifiedDirection.Output)
override def cloneType = DontCare

Expand Down
6 changes: 6 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ object MonoConnect {
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: Clock, source_e: Clock) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: UnsafeEnum) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: UnsafeEnum, source_e: UInt) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)

// Handle Vec case
case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) =>
Expand Down
248 changes: 248 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
// See LICENSE for license details.

package chisel3.core

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scala.collection.mutable

import chisel3.internal.Builder.pushOp
import chisel3.internal.firrtl.PrimOp._
import chisel3.internal.firrtl._
import chisel3.internal.sourceinfo._
import chisel3.internal.{Builder, InstanceId, throwException}
import firrtl.annotations._


object EnumAnnotations {
case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] {
def duplicate(n: Named) = this.copy(target = n)
}

case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation {
def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName)
}

case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation

case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation {
override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition)
}
}
import EnumAnnotations._


abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element {
override def cloneType: this.type = factory().asInstanceOf[this.type]

private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = {
requireIsHardware(this, "bits operated on")
requireIsHardware(other, "bits operated on")

if(!this.typeEquivalent(other))
throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}")

pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref))
}

private[core] override def typeEquivalent(that: Data): Boolean = {
this.getClass == that.getClass &&
this.factory == that.asInstanceOf[EnumType].factory
}

// This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we
// inherit it from Data.
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = ???

final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg

def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that)
def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that)
def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that)
def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that)
def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that)
def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that)

override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref))

protected[chisel3] override def width: Width = factory.width

def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = {
if (litOption.isDefined) {
true.B
} else {
factory.all.map(this === _).reduce(_ || _)
}
}

def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
if (litOption.isDefined) {
val index = factory.all.indexOf(this)

if (index < factory.all.length-1)
factory.all(index+1).asInstanceOf[this.type]
else
factory.all.head.asInstanceOf[this.type]
} else {
val enums_with_nexts = factory.all zip (factory.all.tail :+ factory.all.head)
val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e,n) => (this === e, n) } )
next_enum.asInstanceOf[this.type]
}
}

private[core] def bindToLiteral(num: BigInt, w: Width): Unit = {
val lit = ULit(num, w)
lit.bindLitArg(this)
}

override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = {
super.bind(target, parentDirection)

// If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception.
// To workaround that, we only annotate enums that are not bound to literals.
if (selfAnnotating && litOption.isEmpty) {
annotateEnum()
}
}

private def annotateEnum(): Unit = {
annotate(EnumComponentChiselAnnotation(this, enumTypeName))

if (!Builder.annotations.contains(factory.globalAnnotation)) {
annotate(factory.globalAnnotation)
}
}

protected def enumTypeName: String = factory.enumTypeName

def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer
}


abstract class EnumFactory {
class Type extends EnumType(this)
object Type {
def apply(): Type = EnumFactory.this.apply()
}

private var id: BigInt = 0
private[core] var width: Width = 0.W

private case class EnumRecord(inst: Type, name: String)
private val enum_records = mutable.ArrayBuffer.empty[EnumRecord]

private def enumNames = enum_records.map(_.name).toSeq
private def enumValues = enum_records.map(_.inst.litValue()).toSeq
private def enumInstances = enum_records.map(_.inst).toSeq

private[core] val enumTypeName = getClass.getName.init

private[core] def globalAnnotation: EnumDefChiselAnnotation =
EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap)

def getWidth: Int = width.get

def all: Seq[Type] = enumInstances

protected def Value: Type = macro EnumMacros.ValImpl
protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl

protected def do_Value(names: Seq[String]): Type = {
val result = new Type

// We have to use UnknownWidth here, because we don't actually know what the final width will be
result.bindToLiteral(id, UnknownWidth())

val result_name = names.find(!enumNames.contains(_)).get
enum_records.append(EnumRecord(result, result_name))

width = (1 max id.bitLength).W
id += 1

result
}

protected def do_Value(names: Seq[String], id: UInt): Type = {
// TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just
// throw an exception
if (id.litOption.isEmpty)
throwException(s"$enumTypeName defined with a non-literal type")
if (id.litValue() < this.id)
throwException(s"Enums must be strictly increasing: $enumTypeName")

this.id = id.litValue()
do_Value(names)
}

def apply(): Type = new Type

def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
if (n.litOption.isDefined) {
val result = enumInstances.find(_.litValue == n.litValue)

if (result.isEmpty) {
throwException(s"${n.litValue}.U is not a valid value for $enumTypeName")
} else {
result.get
}
} else if (!n.isWidthKnown) {
throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width")
} else if (n.getWidth > this.getWidth) {
throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)")
} else {
Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid")

val glue = Wire(new UnsafeEnum(width))
glue := n
val result = Wire(new Type)
result := glue
result
}
}
}


private[core] object EnumMacros {
def ValImpl(c: Context) : c.Tree = {
import c.universe._
val names = getNames(c)
q"""this.do_Value(Seq(..$names))"""
}

def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = {
import c.universe._
val names = getNames(c)
q"""this.do_Value(Seq(..$names), $id)"""
}

// Much thanks to Travis Brown for this solution:
// stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to
def getNames(c: Context): Seq[String] = {
import c.universe._

val names = c.enclosingClass.collect {
case ValDef(_, name, _, rhs)
if rhs.pos == c.macroApplication.pos => name.decoded
}

if (names.isEmpty)
c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum")

names
}
}


// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts
// to enums.
private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) {
override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type]
}
private object UnsafeEnum extends EnumFactory
Loading

0 comments on commit 6004052

Please sign in to comment.