Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strong enums #892

Merged
merged 16 commits into from
Oct 12, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions chiselFrontend/src/main/scala/chisel3/core/Bits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import chisel3.internal.firrtl.PrimOp._
*
* @define coll element
*/
abstract class Element(private[chisel3] val width: Width) extends Data {
abstract class Element extends Data {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) {
binding = target
val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection)
Expand Down Expand Up @@ -69,7 +69,7 @@ private[chisel3] sealed trait ToBoolable extends Element {
* @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`.
* @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`.
*/
sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods
sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods
// TODO: perhaps make this concrete?
// Arguments for: self-checking code (can't do arithmetic on bits)
// Arguments against: generates down to a FIRRTL UInt anyways
Expand Down Expand Up @@ -1693,7 +1693,7 @@ object FixedPoint {
*
* @note This API is experimental and subject to change
*/
final class Analog private (width: Width) extends Element(width) {
final class Analog private (private[chisel3] val width: Width) extends Element {
require(width.known, "Since Analog is only for use in BlackBoxes, width must be known")

private[core] override def typeEquivalent(that: Data): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Clock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object Clock {
}

// TODO: Document this.
sealed class Clock extends Element(Width(1)) {
sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element {
def cloneType: this.type = Clock().asInstanceOf[this.type]

private[core] def typeEquivalent(that: Data): Boolean =
Expand Down
4 changes: 3 additions & 1 deletion chiselFrontend/src/main/scala/chisel3/core/Data.scala
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,12 @@ object WireInit {
/** RHS (source) for Invalidate API.
* Causes connection logic to emit a DefInvalid when connected to an output port (or wire).
*/
object DontCare extends Element(width = UnknownWidth()) {
object DontCare extends Element {
// This object should be initialized before we execute any user code that refers to it,
// otherwise this "Chisel" object will end up on the UserModule's id list.

private[chisel3] override val width: Width = UnknownWidth()

bind(DontCareBinding(), SpecifiedDirection.Output)
override def cloneType = DontCare

Expand Down
6 changes: 6 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ object MonoConnect {
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: Clock, source_e: Clock) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: UnsafeEnum) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: EnumType, source_e: EnumType) if sink_e.getClass.equals(source_e.getClass) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: UnsafeEnum, source_e: UInt) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)

// Handle Vec case
case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) =>
Expand Down
316 changes: 316 additions & 0 deletions chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
// See LICENSE for license details.

package chisel3.core

import scala.language.experimental.macros
import scala.reflect.ClassTag
import scala.reflect.runtime.currentMirror
import scala.reflect.runtime.universe.{MethodSymbol, runtimeMirror}
import scala.collection.mutable

import chisel3.internal.Builder.pushOp
import chisel3.internal.firrtl.PrimOp._
import chisel3.internal.firrtl._
import chisel3.internal.sourceinfo._
import chisel3.internal.{Builder, InstanceId, throwException}
import firrtl.annotations._

object EnumExceptions {
case class EnumTypeMismatchException(message: String) extends Exception(message)
case class EnumHasNoCompanionObjectException(message: String) extends Exception(message)
case class NonLiteralEnumException(message: String) extends Exception(message)
case class NonIncreasingEnumException(message: String) extends Exception(message)
case class IllegalDefinitionOfEnumException(message: String) extends Exception(message)
case class IllegalCastToEnumException(message: String) extends Exception(message)
case class NoEmptyConstructorException(message: String) extends Exception(message)
}

object EnumAnnotations {
case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] {
def duplicate(n: Named) = this.copy(target = n)
}

case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation {
def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName)
}

case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends NoTargetAnnotation

case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends ChiselAnnotation {
override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition)
}
}

import EnumExceptions._
import EnumAnnotations._

abstract class EnumType(selfAnnotating: Boolean = true) extends Element {
override def cloneType: this.type = getClass.getConstructor().newInstance().asInstanceOf[this.type]

private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
// Translate Bundle lit bindings to Element lit bindings
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => Some(ElementLitBinding(litArg))
case _ => Some(DontCareBinding())
}
case topBindingOpt => topBindingOpt
}

private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
case Some(ElementLitBinding(litArg)) => Some(litArg)
case _ => None
}
hngenc marked this conversation as resolved.
Show resolved Hide resolved

override def litOption: Option[BigInt] = litArgOption.map(_.num)
private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)

// provide bits-specific literal handling functionality here
override private[chisel3] def ref: Arg = topBindingOpt match {
case Some(ElementLitBinding(litArg)) => litArg
case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
case Some(litArg) => litArg
case _ => throwException(s"internal error: DontCare should be caught before getting ref")
}
case _ => super.ref
}

private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = {
requireIsHardware(this, "bits operated on")
requireIsHardware(other, "bits operated on")

checkTypeEquivalency(other)

pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref))
}

private[core] override def typeEquivalent(that: Data): Boolean = this.getClass == that.getClass

// This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we
// inherit it from Data.
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = {
this := that.asUInt
}

final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg
final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg

def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that)
def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that)
def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that)
def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that)
def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that)
def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that)

override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref))

private val companionModule = currentMirror.reflect(this).symbol.companion.asModule
hngenc marked this conversation as resolved.
Show resolved Hide resolved
private val companionObject =
try {
currentMirror.reflectModule(companionModule).instance.asInstanceOf[StrongEnum[this.type]]
} catch {
case ex: java.lang.ClassNotFoundException =>
throw EnumHasNoCompanionObjectException(s"$enumTypeName's companion object was not found")
}

private[chisel3] override def width: Width = companionObject.width

def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = {
if (!companionObject.finishedInstantiation)
throwException(s"Not all enums values have been defined yet")

if (litOption.isDefined) {
true.B
} else {
def mux_builder(enums: List[this.type]): Bool = enums match {
case Nil => false.B
case e :: es => Mux(this === e, true.B, mux_builder(es))
}

mux_builder(companionObject.all)
}
}

private[core] def bindToLiteral(bits: UInt): Unit = {
val litNum = bits.litOption.get
val lit = ULit(litNum, width) // We must make sure to use the enum's width, rather than the UInt's width
lit.bindLitArg(this)
}

override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = {
super.bind(target, parentDirection)

// If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception.
// To workaround that, we only annotate enums that are not bound to literals.
if (selfAnnotating && !litOption.isDefined) {
annotateEnum()
}
}

private def annotateEnum(): Unit = {
annotate(EnumComponentChiselAnnotation(this, enumTypeName))

if (!Builder.annotations.contains(companionObject.globalAnnotation)) {
annotate(companionObject.globalAnnotation)
}
}

private def enumTypeName: String = getClass.getName

// TODO: See if there is a way to catch this at compile-time
def checkTypeEquivalency(that: EnumType): Unit =
if (!typeEquivalent(that)) {
throw EnumTypeMismatchException(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types")
}

def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer
}

// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts
// to enums.
sealed private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(selfAnnotating = false) {
override def cloneType: this.type = getClass.getConstructor(classOf[Width]).newInstance(width).asInstanceOf[this.type]
}
private object UnsafeEnum extends StrongEnum[UnsafeEnum] {
override def checkEmptyConstructorExists(): Unit = {}
}

abstract class StrongEnum[T <: EnumType : ClassTag] {
private var id: BigInt = 0
private[core] var width: Width = 0.W

private val enum_names = getEnumNames
private val enum_values = mutable.ArrayBuffer.empty[BigInt]
private val enum_instances = mutable.ArrayBuffer.empty[T]

private def getEnumNames(implicit ct: ClassTag[T]): Seq[String] = {
val mirror = runtimeMirror(this.getClass.getClassLoader)

// We use Java reflection to get all the enum fields, and then we use Scala reflection to sort them in declaration
// order. TODO: Use only Scala reflection here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ducky64 Could we use this mechanism for Bundle field ordering? I suspect Scala reflection was too flakey back in the day for this to work, but maybe it does now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there were cases where Scala reflection just pukes all over the place. I think this had to do with inner classes and fun like that, so we're still stuck with looking at id order, which is kind of hacky and nasty.
We use Scala reflection in autoclonetype, but that's fine because that's best effort - it's more or less accepted to tell the user that they need to overload cloneType (though we might not have if we were doing clean-slate design).

val fields = getClass.getDeclaredFields.filter(_.getType == ct.runtimeClass).map(_.getName)
val getters = mirror.classSymbol(this.getClass).toType.members.sorted.collect {
case m: MethodSymbol if m.isGetter => m.name.toString
}

getters.filter(fields.contains(_))
}

private def bindAllEnums(): Unit =
(enum_instances, enum_values).zipped.foreach((inst, v) => inst.bindToLiteral(v.U(width)))

private[core] def globalAnnotation: EnumDefChiselAnnotation =
EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values.map(_.U(width))).zipped.toMap)

private[core] def finishedInstantiation: Boolean =
enum_names.length == enum_instances.length

private def newEnum()(implicit ct: ClassTag[T]): T =
ct.runtimeClass.newInstance.asInstanceOf[T]

// TODO: This depends upon undocumented behavior (which, to be fair, is unlikely to change). Use reflection to find
// the companion class's name in a more robust way.
private val enumTypeName = getClass.getName.init

def getWidth: Int = width.get

def all: List[T] = enum_instances.toList
hngenc marked this conversation as resolved.
Show resolved Hide resolved

def Value: T = {
val result = newEnum()
enum_instances.append(result)
enum_values.append(id)

width = (1 max id.bitLength).W
id += 1

// Instantiate all the enums when Value is called for the last time
if (enum_instances.length == enum_names.length && isTopLevelConstructor) {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
bindAllEnums()
}

result
}

def Value(id: UInt): T = {
// TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just
// throw an exception
if (!id.litOption.isDefined)
throw NonLiteralEnumException(s"$enumTypeName defined with a non-literal type in companion object")
if (id.litValue() <= this.id)
throw NonIncreasingEnumException(s"Enums must be strictly increasing: $enumTypeName")

this.id = id.litValue()
Value
}

def apply(): T = newEnum()

def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = {
if (!n.litOption.isDefined) {
throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use castFromNonLit instead")
} else if (!enum_values.contains(n.litValue)) {
throwException(s"${n.litValue}.U is not a valid value for $enumTypeName")
}

val result = newEnum()
result.bindToLiteral(n)
result
}

def castFromNonLit(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = {
if (n.litOption.isDefined) {
apply(n)
} else if (!n.isWidthKnown) {
throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width")
} else if (n.getWidth > this.getWidth) {
throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)")
} else {
Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that the value is legal by calling isValid")

val glue = Wire(new UnsafeEnum(width))
glue := n
val result = Wire(newEnum())
result := glue
result
}
}

// StrongEnum basically has a recursive constructor. It instantiates a copy of itself internally, so that it can
// make sure that all EnumType's inside of it were instantiated using the "Value" function. However, in order to
// instantiate its copy, as well as to instantiate new enums, it has to make sure that it has a no-args constructor
// as it won't know what parameters to add otherwise.

protected def checkEmptyConstructorExists(): Unit = {
try {
implicitly[ClassTag[T]].runtimeClass.getDeclaredConstructor()
getClass.getDeclaredConstructor()
} catch {
case ex: NoSuchMethodException => throw NoEmptyConstructorException(s"$enumTypeName does not have a no-args constructor. Did you declare it inside a class?")
hngenc marked this conversation as resolved.
Show resolved Hide resolved
}
}

private val isTopLevelConstructor: Boolean = {
hngenc marked this conversation as resolved.
Show resolved Hide resolved
val stack_trace = Thread.currentThread().getStackTrace
val constructorName = "<init>"

stack_trace.count(se => se.getClassName.equals(getClass.getName) && se.getMethodName.equals(constructorName)) == 1
}

if (isTopLevelConstructor) {
checkEmptyConstructorExists()

val constructor = getClass.getDeclaredConstructor()
constructor.setAccessible(true)
val childInstance = constructor.newInstance()

if (!childInstance.finishedInstantiation) {
throw IllegalDefinitionOfEnumException(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg {
private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth)
override def fullName(ctx: Component): String = name
// Ensure the node representing this LitArg has a ref to it and a literal binding.
def bindLitArg[T <: Bits](bits: T): T = {
bits.bind(ElementLitBinding(this))
bits.setRef(this)
bits
def bindLitArg[T <: Element](elem: T): T = {
elem.bind(ElementLitBinding(this))
elem.setRef(this)
elem
}

protected def minWidth: Int
Expand Down