Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strong enums #892

Merged
merged 16 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions chiselFrontend/src/main/scala/chisel3/core/Bits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import chisel3.internal.firrtl.PrimOp._
*
* @define coll element
*/
abstract class Element(private[chisel3] val width: Width) extends Data {
abstract class Element extends Data {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) {
binding = target
val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection)
Expand Down Expand Up @@ -69,7 +69,7 @@ private[chisel3] sealed trait ToBoolable extends Element {
* @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`.
* @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`.
*/
sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods
sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods
// TODO: perhaps make this concrete?
// Arguments for: self-checking code (can't do arithmetic on bits)
// Arguments against: generates down to a FIRRTL UInt anyways
Expand Down Expand Up @@ -1693,7 +1693,7 @@ object FixedPoint {
*
* @note This API is experimental and subject to change
*/
final class Analog private (width: Width) extends Element(width) {
final class Analog private (private[chisel3] val width: Width) extends Element {
require(width.known, "Since Analog is only for use in BlackBoxes, width must be known")

private[core] override def typeEquivalent(that: Data): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Clock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object Clock {
}

// TODO: Document this.
sealed class Clock extends Element(Width(1)) {
sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element {
def cloneType: this.type = Clock().asInstanceOf[this.type]

private[core] def typeEquivalent(that: Data): Boolean =
Expand Down
4 changes: 3 additions & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Data.scala
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,12 @@ object WireInit {
/** RHS (source) for Invalidate API.
* Causes connection logic to emit a DefInvalid when connected to an output port (or wire).
*/
object DontCare extends Element(width = UnknownWidth()) {
object DontCare extends Element {
// This object should be initialized before we execute any user code that refers to it,
// otherwise this "Chisel" object will end up on the UserModule's id list.

private[chisel3] override val width: Width = UnknownWidth()

bind(DontCareBinding(), SpecifiedDirection.Output)
override def cloneType = DontCare

Expand Down
6 changes: 6 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ object MonoConnect {
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: Clock, source_e: Clock) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: UnsafeEnum) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: UnsafeEnum, source_e: UInt) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)

// Handle Vec case
case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) =>
Expand Down
290 changes: 290 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
// See LICENSE for license details.

package chisel3.core

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scala.collection.mutable

import chisel3.internal.Builder.pushOp
import chisel3.internal.firrtl.PrimOp._
import chisel3.internal.firrtl._
import chisel3.internal.sourceinfo._
import chisel3.internal.{Builder, InstanceId, throwException}
import firrtl.annotations._


object EnumAnnotations {
case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] {
def duplicate(n: Named) = this.copy(target = n)
}

case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation {
def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName)
}

case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation

case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation {
override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition)
}
}
import EnumAnnotations._


abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element {
override def cloneType: this.type = factory().asInstanceOf[this.type]

private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
// Translate Bundle lit bindings to Element lit bindings
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => Some(ElementLitBinding(litArg))
case _ => Some(DontCareBinding())
}
case topBindingOpt => topBindingOpt
}

private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
case Some(ElementLitBinding(litArg)) => Some(litArg)
case _ => None
}
hngenc marked this conversation as resolved.
Show resolved Hide resolved

override def litOption: Option[BigInt] = litArgOption.map(_.num)
private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)

// provide bits-specific literal handling functionality here
override private[chisel3] def ref: Arg = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => litArg
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => litArg
case _ => throwException(s"internal error: DontCare should be caught before getting ref")
}
case _ => super.ref
}

private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = {
requireIsHardware(this, "bits operated on")
requireIsHardware(other, "bits operated on")

if(!this.typeEquivalent(other))
throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}")

pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref))
}

private[core] override def typeEquivalent(that: Data): Boolean = {
this.getClass == that.getClass &&
this.factory == that.asInstanceOf[EnumType].factory
}

// This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we
// inherit it from Data.
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = {
this := that.asUInt
}

final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg

def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that)
def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that)
def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that)
def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that)
def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that)
def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that)

override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref))

protected[chisel3] override def width: Width = factory.width

def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = {
if (litOption.isDefined) {
true.B
} else {
def muxBuilder(enums: List[EnumType]): Bool = enums match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
case Nil => false.B
case e :: es => Mux(this === e, true.B, muxBuilder(es))
}

muxBuilder(factory.all.toList)
}
}

def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
chick marked this conversation as resolved.
Show resolved Hide resolved
if (litOption.isDefined) {
val index = factory.all.indexOf(this)
hngenc marked this conversation as resolved.
Show resolved Hide resolved

if (index < factory.all.length-1)
factory.all(index+1).asInstanceOf[this.type]
else
factory.all.head.asInstanceOf[this.type]
} else {
def muxBuilder(enums: List[EnumType], first_enum: EnumType): EnumType = enums match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
case e :: Nil => first_enum
case e :: e_next :: es => Mux(this === e, e_next, muxBuilder(e_next :: es, first_enum))
}

muxBuilder(factory.all.toList, factory.all.head).asInstanceOf[this.type]
}
}

private[core] def bindToLiteral(num: BigInt, w: Width): Unit = {
val lit = ULit(num, w)
lit.bindLitArg(this)
}

override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = {
super.bind(target, parentDirection)

// If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception.
// To workaround that, we only annotate enums that are not bound to literals.
if (selfAnnotating && !litOption.isDefined) {
annotateEnum()
}
}

private def annotateEnum(): Unit = {
annotate(EnumComponentChiselAnnotation(this, enumTypeName))

if (!Builder.annotations.contains(factory.globalAnnotation)) {
annotate(factory.globalAnnotation)
}
}

protected def enumTypeName: String = factory.enumTypeName

def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer
}


abstract class EnumFactory {
class Type extends EnumType(this)
object Type {
def apply(): Type = EnumFactory.this.apply()
}

var id: BigInt = 0
var width: Width = 0.W

val enum_names = mutable.ArrayBuffer.empty[String]
hngenc marked this conversation as resolved.
Show resolved Hide resolved
val enum_values = mutable.ArrayBuffer.empty[BigInt]
val enum_instances = mutable.ArrayBuffer.empty[Type]

private[core] def globalAnnotation: EnumDefChiselAnnotation =
EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values).zipped.toMap)

private[core] val enumTypeName = getClass.getName.init
hngenc marked this conversation as resolved.
Show resolved Hide resolved

def getWidth: Int = width.get

def all: Seq[Type] = enum_instances.toSeq

protected def Value: Type = macro EnumMacros.ValImpl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably have some documentation, preferably ScalaDoc

protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl

protected def do_Value(names: Seq[String]): Type = {
val result = new Type
enum_names ++= names.filter(!enum_names.contains(_))
enum_instances.append(result)
enum_values.append(id)

// We have to use UnknownWidth here, because we don't actually know what the final width will be
result.bindToLiteral(id, UnknownWidth())

width = (1 max id.bitLength).W
id += 1

result
}

protected def do_Value(names: Seq[String], id: UInt): Type = {
// TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just
// throw an exception
if (!id.litOption.isDefined)
throwException(s"$enumTypeName defined with a non-literal type")
if (id.litValue() < this.id)
throwException(s"Enums must be strictly increasing: $enumTypeName")

this.id = id.litValue()
do_Value(names)
}

def apply(): Type = new Type

def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
if (n.litOption.isEmpty) {
throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use fromBits instead")
}

val result = enum_instances.find(_.litValue == n.litValue)

if (result.isEmpty) {
throwException(s"${n.litValue}.U is not a valid value for $enumTypeName")
} else {
result.get
}
}

def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
if (n.litOption.isDefined) {
apply(n)
} else if (!n.isWidthKnown) {
throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width")
} else if (n.getWidth > this.getWidth) {
throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)")
} else {
Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that the value is legal by calling isValid")

val glue = Wire(new UnsafeEnum(width))
glue := n
val result = Wire(new Type)
result := glue
result
}
}
}


private[core] object EnumMacros {
def ValImpl(c: Context) : c.Tree = {
import c.universe._
val names = getNames(c)
q"""this.do_Value(Seq(..$names))"""
}

def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = {
import c.universe._
val names = getNames(c)
q"""this.do_Value(Seq(..$names), $id)"""
}

// Much thanks to Travis Brown for this solution:
// stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to
def getNames(c: Context): Seq[String] = {
import c.universe._

val names = c.enclosingClass.collect {
case ValDef(_, name, _, rhs)
if rhs.pos == c.macroApplication.pos => name.decoded
}

if (names.isEmpty)
c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum")

names
}
}


// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts
// to enums.
private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) {
override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type]
}
private object UnsafeEnum extends EnumFactory
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg {
private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth)
override def fullName(ctx: Component): String = name
// Ensure the node representing this LitArg has a ref to it and a literal binding.
def bindLitArg[T <: Bits](bits: T): T = {
bits.bind(ElementLitBinding(this))
bits.setRef(this)
bits
def bindLitArg[T <: Element](elem: T): T = {
elem.bind(ElementLitBinding(this))
elem.setRef(this)
elem
}

protected def minWidth: Int
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/chisel3/internal/firrtl/Converter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

package chisel3.internal.firrtl
import chisel3._
import chisel3.core.SpecifiedDirection
import chisel3.core.{SpecifiedDirection, EnumType}
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo}
import firrtl.{ir => fir}
Expand Down Expand Up @@ -211,6 +211,7 @@ private[chisel3] object Converter {

def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match {
case _: Clock => fir.ClockType
case d: EnumType => fir.UIntType(convert(d.width))
case d: UInt => fir.UIntType(convert(d.width))
case d: SInt => fir.SIntType(convert(d.width))
case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint))
Expand Down