From df343a36d4bd9a38c62d5e691f0f1f0b74185be8 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Mon, 17 Sep 2018 16:29:47 -0700 Subject: [PATCH 01/13] Added new strongly-typed enum construct called "StrongEnum". "StrongEnum" will automatically generate annotations that HDL backends can use to mark components as enums Removed "override val width" constructor parameter from "Element" so that classes with variable widths, like the new strong enums, can inherit from it Changed the parameter types of certain functions, such as "switch", "is", and "LitArg.bindLitArg" from "Bits" to "Element", so that they can take the new strong enums as arguments --- .../src/main/scala/chisel3/core/Bits.scala | 8 +- .../src/main/scala/chisel3/core/Clock.scala | 2 +- .../src/main/scala/chisel3/core/Data.scala | 4 +- .../main/scala/chisel3/core/MonoConnect.scala | 6 + .../main/scala/chisel3/core/StrongEnum.scala | 276 ++++++++++++++++++ .../scala/chisel3/internal/firrtl/IR.scala | 8 +- .../chisel3/internal/firrtl/Converter.scala | 1 + .../chisel3/internal/firrtl/Emitter.scala | 1 + src/main/scala/chisel3/package.scala | 4 + src/main/scala/chisel3/util/Conditional.scala | 10 +- 10 files changed, 305 insertions(+), 15 deletions(-) create mode 100644 chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala index 10b6ec8ed7c..b1b67b9d634 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala @@ -17,7 +17,7 @@ import chisel3.internal.firrtl.PrimOp._ /** Element is a leaf data type: it cannot contain other Data objects. Example * uses are for representing primitive data types, like integers and bits. */ -abstract class Element(private[chisel3] val width: Width) extends Data { +abstract class Element extends Data { private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) { binding = target val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection) @@ -61,8 +61,8 @@ private[chisel3] sealed trait ToBoolable extends Element { * bitwise operations. */ //scalastyle:off number.of.methods -sealed abstract class Bits(width: Width) - extends Element(width) with ToBoolable { +sealed abstract class Bits(private[chisel3] val width: Width) + extends Element with ToBoolable { // TODO: perhaps make this concrete? // Arguments for: self-checking code (can't do arithmetic on bits) // Arguments against: generates down to a FIRRTL UInt anyways @@ -1130,7 +1130,7 @@ object FixedPoint { * * @note This API is experimental and subject to change */ -final class Analog private (width: Width) extends Element(width) { +final class Analog private (private[chisel3] val width: Width) extends Element { require(width.known, "Since Analog is only for use in BlackBoxes, width must be known") private[core] override def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala index b728075b954..88208d9a6c0 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala @@ -12,7 +12,7 @@ object Clock { } // TODO: Document this. -sealed class Clock extends Element(Width(1)) { +sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element { def cloneType: this.type = Clock().asInstanceOf[this.type] private[core] def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index 171a2bff5b1..4f9b894c56d 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -508,10 +508,12 @@ object WireInit { /** RHS (source) for Invalidate API. * Causes connection logic to emit a DefInvalid when connected to an output port (or wire). */ -object DontCare extends Element(width = UnknownWidth()) { +object DontCare extends Element { // This object should be initialized before we execute any user code that refers to it, // otherwise this "Chisel" object will end up on the UserModule's id list. + private[chisel3] override val width: Width = UnknownWidth() + bind(DontCareBinding(), SpecifiedDirection.Output) override def cloneType = DontCare diff --git a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala index eba248709d9..2383e093bda 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala @@ -79,6 +79,12 @@ object MonoConnect { elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: Clock, source_e: Clock) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: UnsafeEnum) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: EnumType) if sink_e.getClass.equals(source_e.getClass) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: UnsafeEnum, source_e: UInt) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) // Handle Vec case case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) => diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala new file mode 100644 index 00000000000..92eec10d672 --- /dev/null +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -0,0 +1,276 @@ +package chisel3.core + +import scala.language.experimental.macros +import scala.reflect.ClassTag +import scala.reflect.runtime.currentMirror +import scala.reflect.runtime.universe.{MethodSymbol, runtimeMirror} +import scala.collection.mutable + +import chisel3.internal.Builder.pushOp +import chisel3.internal.firrtl.PrimOp._ +import chisel3.internal.firrtl._ +import chisel3.internal.sourceinfo._ +import chisel3.internal.{Builder, InstanceId, throwException} +import firrtl.annotations._ + +object EnumExceptions { + case class EnumTypeMismatch(message: String) extends Exception(message) + case class EnumHasNoCompanionObject(message: String) extends Exception(message) + case class NonLiteralEnum(message: String) extends Exception(message) + case class NonIncreasingEnum(message: String) extends Exception(message) + case class IllegalDefinitionOfEnum(message: String) extends Exception(message) + case class IllegalCastToEnum(message: String) extends Exception(message) + case class NoEmptyConstructor(message: String) extends Exception(message) +} + +object EnumAnnotations { + case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] { + def duplicate(n: Named) = this.copy(target = n) + } + + case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation { + def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName) + } + + case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends NoTargetAnnotation + + case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends ChiselAnnotation { + override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition) + } +} + +import EnumExceptions._ +import EnumAnnotations._ + +abstract class EnumType(selfAnnotating: Boolean = true) extends Element { + def cloneType: this.type = getClass.getConstructor().newInstance().asInstanceOf[this.type] + + private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { + // Translate Bundle lit bindings to Element lit bindings + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => Some(ElementLitBinding(litArg)) + case _ => Some(DontCareBinding()) + } + case topBindingOpt => topBindingOpt + } + + private[core] def litArgOption: Option[LitArg] = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => Some(litArg) + case _ => None + } + + override def litOption: Option[BigInt] = litArgOption.map(_.num) + private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) + + // provide bits-specific literal handling functionality here + override private[chisel3] def ref: Arg = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => litArg + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => litArg + case _ => throwException(s"internal error: DontCare should be caught before getting ref") + } + case _ => super.ref + } + + private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = { + requireIsHardware(this, "bits operated on") + requireIsHardware(other, "bits operated on") + + checkTypeEquivalency(other) + + pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref)) + } + + private[core] override def typeEquivalent(that: Data): Boolean = this.getClass == that.getClass + + // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we + // inherit it from Data. + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = { + this := that.asUInt + } + + final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + + def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that) + def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that) + def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that) + def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that) + def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that) + def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that) + + override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = + pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref)) + + val companionModule = currentMirror.reflect(this).symbol.companion.asModule + val companionObject = + try { + currentMirror.reflectModule(companionModule).instance.asInstanceOf[StrongEnum[this.type]] + } catch { + case ex: java.lang.ClassNotFoundException => + throw EnumHasNoCompanionObject(s"$enumTypeName's companion object was not found") + } + + private[chisel3] override def width: Width = companionObject.width + + private[core] def bindToLiteral(bits: UInt): Unit = { + val litNum = bits.litOption.get + val lit = ULit(litNum, width) // We must make sure to use the enum's width, rather than the UInt's width + lit.bindLitArg(this) + } + + override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = { + super.bind(target, parentDirection) + + // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. + // To workaround that, we only annotate enums that are not bound to literals. + if (selfAnnotating && !litOption.isDefined) + annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + } + + private def enumTypeName: String = getClass.getName + + // TODO: See if there is a way to catch this at compile-time + def checkTypeEquivalency(that: EnumType): Unit = + if (!typeEquivalent(that)) + throw EnumTypeMismatch(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types") + + def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer +} + +// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts +// to enums. +sealed private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(selfAnnotating = false) { + override def cloneType: this.type = getClass.getConstructor(classOf[Width]).newInstance(width).asInstanceOf[this.type] +} +private object UnsafeEnum extends StrongEnum[UnsafeEnum] { + override def checkEmptyConstructorExists(): Unit = {} +} + +abstract class StrongEnum[T <: EnumType : ClassTag] { + private var id: BigInt = 0 + private[core] var width: Width = 0.W + + private val enum_names = getEnumNames + private val enum_values = mutable.ArrayBuffer.empty[BigInt] + private val enum_instances = mutable.ArrayBuffer.empty[T] + + private def getEnumNames(implicit ct: ClassTag[T]): Seq[String] = { + val mirror = runtimeMirror(this.getClass.getClassLoader) + val reflection = mirror.reflect(this) + + // We use Java reflection to get all the enum fields, and then we use Scala reflection to sort them in declaration + // order. TODO: Use only Scala reflection here + val fields = getClass.getDeclaredFields.filter(_.getType == ct.runtimeClass).map(_.getName) + val getters = mirror.classSymbol(this.getClass).toType.members.sorted.collect { + case m: MethodSymbol if m.isGetter => m.name.toString + } + + getters.filter(fields.contains(_)) + } + + private def bindAllEnums(): Unit = + (enum_instances, enum_values).zipped.foreach((inst, v) => inst.bindToLiteral(v.U(width))) + + private def createAnnotation(): Unit = + annotate(EnumDefChiselAnnotation(enumTypeName, + (enum_names, enum_values.map(_.U(width))).zipped.toMap)) + + private def newEnum()(implicit ct: ClassTag[T]): T = + ct.runtimeClass.newInstance.asInstanceOf[T] + + // TODO: This depends upon undocumented behavior (which, to be fair, is unlikely to change). Use reflection to find + // the companion class's name in a more robust way. + private val enumTypeName = getClass.getName.init + + def getWidth: BigInt = width.get + + def all: List[T] = enum_instances.toList + + def Value: T = { + val result = newEnum() + enum_instances.append(result) + enum_values.append(id) + + width = (1 max id.bitLength).W + id += 1 + + // Check whether we've instantiated all the enums + if (enum_instances.length == enum_names.length && isTopLevelConstructor) { + bindAllEnums() + createAnnotation() + } + + result + } + + def Value(id: UInt): T = { + // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just + // throw an exception + if (!id.litOption.isDefined) + throw NonLiteralEnum(s"$enumTypeName defined with a non-literal type in companion object") + if (id.litValue() <= this.id) + throw NonIncreasingEnum(s"Enums must be strictly increasing: $enumTypeName") + + this.id = id.litValue() + Value + } + + def apply(): T = newEnum() + + def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { + if (n.litOption.isDefined) { + if (!enum_values.contains(n.litValue)) + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + + val result = newEnum() + result.bindToLiteral(n) + result + } else { + Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. No automatic bounds checking will be done!") + + val glue = Wire(new UnsafeEnum(width)) + glue := n + val result = Wire(newEnum()) + result := glue + result + } + } + + // StrongEnum basically has a recursive constructor. It instantiates a copy of itself internally, so that it can + // make sure that all EnumType's inside of it were instantiated using the "Value" function. However, in order to + // instantiate its copy, as well as to instantiate new enums, it has to make sure that it has a no-args constructor + // as it won't know what parameters to add otherwise. + + protected def checkEmptyConstructorExists(): Unit = { + try { + implicitly[ClassTag[T]].runtimeClass.getDeclaredConstructor() + getClass.getDeclaredConstructor() + } catch { + case ex: NoSuchMethodException => throw NoEmptyConstructor(s"$enumTypeName does not have a no-args constructor. Did you declare it inside a class?") + } + } + + private val isTopLevelConstructor: Boolean = { + val stack_trace = Thread.currentThread().getStackTrace + val constructorName = "" + + stack_trace.count(se => se.getClassName.equals(getClass.getName) && se.getMethodName.equals(constructorName)) == 1 + } + + if (isTopLevelConstructor) { + checkEmptyConstructorExists() + + val constructor = getClass.getDeclaredConstructor() + constructor.setAccessible(true) + val childInstance = constructor.newInstance() + + if (childInstance.enum_names.length != childInstance.enum_instances.length) + throw IllegalDefinitionOfEnum(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") + } +} diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala index b6630f7f0f6..ae8b248aaf4 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala @@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg { private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth) override def fullName(ctx: Component): String = name // Ensure the node representing this LitArg has a ref to it and a literal binding. - def bindLitArg[T <: Bits](bits: T): T = { - bits.bind(ElementLitBinding(this)) - bits.setRef(this) - bits + def bindLitArg[T <: Element](elem: T): T = { + elem.bind(ElementLitBinding(this)) + elem.setRef(this) + elem } protected def minWidth: Int diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala index 97504aba19a..9e9616c893e 100644 --- a/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala @@ -211,6 +211,7 @@ private[chisel3] object Converter { def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match { case _: Clock => fir.ClockType + case d: EnumType => fir.UIntType(convert(d.width)) case d: UInt => fir.UIntType(convert(d.width)) case d: SInt => fir.SIntType(convert(d.width)) case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint)) diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala index 26ccc09d62a..3e4d6f2151b 100644 --- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala @@ -28,6 +28,7 @@ private class Emitter(circuit: Circuit) { private def emitType(d: Data, clearDir: Boolean = false): String = d match { case d: Clock => "Clock" + case d: EnumType => s"UInt${d.width}" case d: UInt => s"UInt${d.width}" case d: SInt => s"SInt${d.width}" case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}" diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala index 7f1ad040852..c433ceb8fae 100644 --- a/src/main/scala/chisel3/package.scala +++ b/src/main/scala/chisel3/package.scala @@ -243,6 +243,10 @@ package object chisel3 { // scalastyle:ignore package.object.name object Bool extends BoolFactory val Mux = chisel3.core.Mux + type EnumType = chisel3.core.EnumType + type StrongEnum[T <: EnumType] = chisel3.core.StrongEnum[T] + val EnumAnnotations = chisel3.core.EnumAnnotations + type BlackBox = chisel3.core.BlackBox type InstanceId = chisel3.internal.InstanceId diff --git a/src/main/scala/chisel3/util/Conditional.scala b/src/main/scala/chisel3/util/Conditional.scala index bf2d4268717..3630f8add27 100644 --- a/src/main/scala/chisel3/util/Conditional.scala +++ b/src/main/scala/chisel3/util/Conditional.scala @@ -24,7 +24,7 @@ object unless { // scalastyle:ignore object.name * user-facing API. * @note DO NOT USE. This API is subject to change without warning. */ -class SwitchContext[T <: Bits](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { +class SwitchContext[T <: Element](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { def is(v: Iterable[T])(block: => Unit): SwitchContext[T] = { if (!v.isEmpty) { val newLits = v.map { w => @@ -60,19 +60,19 @@ object is { // scalastyle:ignore object.name // TODO: Begin deprecation of non-type-parameterized is statements. /** Executes `block` if the switch condition is equal to any of the values in `v`. */ - def apply(v: Iterable[Bits])(block: => Unit) { + def apply(v: Iterable[Element])(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to `v`. */ - def apply(v: Bits)(block: => Unit) { + def apply(v: Element)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to any of the values in the argument list. */ - def apply(v: Bits, vr: Bits*)(block: => Unit) { + def apply(v: Element, vr: Element*)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } } @@ -91,7 +91,7 @@ object is { // scalastyle:ignore object.name * }}} */ object switch { // scalastyle:ignore object.name - def apply[T <: Bits](cond: T)(x: => Unit): Unit = macro impl + def apply[T <: Element](cond: T)(x: => Unit): Unit = macro impl def impl(c: Context)(cond: c.Tree)(x: c.Tree): c.Tree = { import c.universe._ val q"..$body" = x val res = body.foldLeft(q"""new SwitchContext($cond, None, Set.empty)""") { From 6feccd9cdadf2438d7417c395388193d0da871fc Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Mon, 17 Sep 2018 16:32:44 -0700 Subject: [PATCH 02/13] Added tests for the new strong enums --- src/test/scala/chiselTests/StrongEnum.scala | 303 ++++++++++++++++++++ 1 file changed, 303 insertions(+) create mode 100644 src/test/scala/chiselTests/StrongEnum.scala diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala new file mode 100644 index 00000000000..29b54fc609b --- /dev/null +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -0,0 +1,303 @@ +// See LICENSE for license details. + +package chiselTests + +import chisel3._ +import chisel3.core.{EnumAnnotations, EnumExceptions} +import chisel3.util._ +import chisel3.testers.BasicTester +import firrtl.annotations.ComponentName +import org.scalatest.{FreeSpec, Matchers} + +class EnumExample extends EnumType +object EnumExample extends StrongEnum[EnumExample] { + val e0, e1, e2 = Value + val e100 = Value(100.U) + val e101 = Value + + val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U) +} + +class OtherEnum extends EnumType +object OtherEnum extends StrongEnum[OtherEnum] { + val otherEnum = Value +} + +class EnumWithoutCompanionObj extends EnumType + +class NonLiteralEnumType extends EnumType +object NonLiteralEnumType extends StrongEnum[NonLiteralEnumType] { + val nonLit = Value(UInt()) +} + +class SimpleConnector(inType: Data, outType: Data) extends Module { + val io = IO(new Bundle { + val in = Input(inType) + val out = Output(outType) + }) + + io.out := io.in +} + +class CastToUInt extends Module { + val io = IO(new Bundle { + val in = Input(EnumExample()) + val out = Output(UInt()) + }) + + io.out := io.in.asUInt() +} + +class CastToEnum extends Module { + val io = IO(new Bundle { + val in = Input(UInt()) + val out = Output(EnumExample()) + }) + + io.out := EnumExample(io.in) +} + +class EnumOps(xType: EnumType, yType: EnumType) extends Module { + val io = IO(new Bundle { + val x = Input(xType) + val y = Input(yType) + + val lt = Output(Bool()) + val le = Output(Bool()) + val gt = Output(Bool()) + val ge = Output(Bool()) + val eq = Output(Bool()) + val ne = Output(Bool()) + }) + + io.lt := io.x < io.y + io.le := io.x <= io.y + io.gt := io.x > io.y + io.ge := io.x >= io.y + io.eq := io.x === io.y + io.ne := io.x =/= io.y +} + +object StrongEnumFSM { + class State extends EnumType + object State extends StrongEnum[State] { + val sNone, sOne1, sTwo1s = Value + + val correct_annotation_map = Map[String, UInt]("sNone" -> 0.U(2.W), "sOne1" -> 1.U(2.W), "sTwo1s" -> 2.U(2.W)) + } +} + +class StrongEnumFSM extends Module { + // This FSM detects two 1's one after the other + val io = IO(new Bundle { + val in = Input(Bool()) + val out = Output(Bool()) + }) + + import StrongEnumFSM.State._ + + val state = RegInit(sNone) + + io.out := (state === sTwo1s) + + switch (state) { + is (sNone) { + when (io.in) { + state := sOne1 + } + } + is (sOne1) { + when (io.in) { + state := sTwo1s + } .otherwise { + state := sNone + } + } + is (sTwo1s) { + when (!io.in) { + state := sNone + } + } + } +} + +class CastToUIntTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastToUInt) + mod.io.in := enum + assert(mod.io.out === lit) + } + stop() +} + +class CastToEnumTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastToEnum) + mod.io.in := lit + assert(mod.io.out === enum) + } + stop() +} + +class CastToInvalidEnumTester extends BasicTester { + val invalid_value: UInt = EnumExample.litValues.last + 1.U + val mod = Module(new CastToEnum { + io.out := invalid_value + }) +} + +class EnumOpsTester extends BasicTester { + for (x <- EnumExample.all; + y <- EnumExample.all) { + val mod = Module(new EnumOps(EnumExample(), EnumExample())) + mod.io.x := x + mod.io.y := y + + assert(mod.io.lt === (x.asUInt() < y.asUInt())) + assert(mod.io.le === (x.asUInt() <= y.asUInt())) + assert(mod.io.gt === (x.asUInt() > y.asUInt())) + assert(mod.io.ge === (x.asUInt() >= y.asUInt())) + assert(mod.io.eq === (x.asUInt() === y.asUInt())) + assert(mod.io.ne === (x.asUInt() =/= y.asUInt())) + } + stop() +} + +class InvalidEnumOpsTester extends BasicTester { + val mod = Module(new EnumOps(EnumExample(), OtherEnum())) + mod.io.x := EnumExample.e0 + mod.io.y := OtherEnum.otherEnum +} + +class IsLitTester extends BasicTester { + for (e <- EnumExample.all) { + val wire = WireInit(e) + + assert(e.isLit()) + assert(!wire.isLit()) + } + stop() +} + +class StrongEnumFSMTester extends BasicTester { + val dut = Module(new StrongEnumFSM) + + // Inputs and expected results + val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B) + val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B) + + val cntr = Counter(inputs.length) + val cycle = cntr.value + + dut.io.in := inputs(cycle) + assert(dut.io.out === expected(cycle)) + + when(cntr.inc()) { + stop() + } +} + +class StrongEnumSpec extends ChiselFlatSpec { + behavior of "Strong enum tester" + + it should "fail to instantiate enums without a companion class" in { + an [Exception] should be thrownBy { + elaborate(new SimpleConnector(new EnumWithoutCompanionObj(), new EnumWithoutCompanionObj())) + } + } + + it should "fail to instantiate non-literal enums in a companion object" in { + an [Error] should be thrownBy { + elaborate(new SimpleConnector(new NonLiteralEnumType(), new NonLiteralEnumType())) + } + } + + it should "connect enums of the same type" in { + elaborate(new SimpleConnector(EnumExample(), EnumExample())) + } + + it should "fail to connect a strong enum to a UInt" in { + an [Exception] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), UInt())) + } + } + + it should "fail to connect enums of different types" in { + an [Exception] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), OtherEnum())) + } + } + + it should "cast enums to UInts correctly" in { + assertTesterPasses(new CastToUIntTester) + } + + it should "cast UInts to enums correctly" in { + assertTesterPasses(new CastToEnumTester) + } + + it should "catch illegal literal casts to enums" in { + an [Exception] should be thrownBy { + elaborate(new CastToInvalidEnumTester) + } + } + + it should "execute enum comparison operations correctly" in { + assertTesterPasses(new EnumOpsTester) + } + + it should "fail to compare enums of different types" in { + an [Exception] should be thrownBy { + elaborate(new InvalidEnumOpsTester) + } + } + + it should "correctly check whether or not enums are literal" in { + assertTesterPasses(new IsLitTester) + } + + it should "return the correct widths for enums" in { + EnumExample.getWidth == EnumExample.litValues.last.getWidth + } + + "StrongEnum FSM" should "work" in { + assertTesterPasses(new StrongEnumFSMTester) + } +} + +class StrongEnumAnnotationSpec extends FreeSpec with Matchers { + "Test that strong enums annotate themselves appropriately" in { + + Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { + case ChiselExecutionSuccess(Some(circuit), emitted, _) => + val annos = circuit.annotations.map(_.toFirrtl) + + // Check that the global annotation is correct + annos.exists { + case EnumAnnotations.EnumDefAnnotation(name, map) => + name.endsWith("State") && + map.size == StrongEnumFSM.State.correct_annotation_map.size && + map.forall { + case (k, v) => + val correctValue = StrongEnumFSM.State.correct_annotation_map(k) + + val correctValLit = correctValue.litValue() + val vLitValue = v.litValue() + + correctValue.getWidth == v.getWidth && correctValue.litValue() == v.litValue() + } + case _ => false + } should be(true) + + // Check that the component annotations are correct + annos.exists { + case EnumAnnotations.EnumComponentAnnotation(target, enumName) => + val ComponentName(targetName, _) = target + targetName == "state" && enumName.endsWith("State") + case _ => false + } should be(true) + case _ => + assert(false) + } + } +} From 192d47106f78f848fb03cfcdf78b84f32de74412 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Mon, 17 Sep 2018 17:56:05 -0700 Subject: [PATCH 03/13] Changed StrongEnum exception names and made sure in StrongEnum tests that the correct types of exceptions are thrown --- .../main/scala/chisel3/core/StrongEnum.scala | 26 +++++++++---------- src/test/scala/chiselTests/StrongEnum.scala | 15 ++++++----- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 92eec10d672..15c4dc70426 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -14,13 +14,13 @@ import chisel3.internal.{Builder, InstanceId, throwException} import firrtl.annotations._ object EnumExceptions { - case class EnumTypeMismatch(message: String) extends Exception(message) - case class EnumHasNoCompanionObject(message: String) extends Exception(message) - case class NonLiteralEnum(message: String) extends Exception(message) - case class NonIncreasingEnum(message: String) extends Exception(message) - case class IllegalDefinitionOfEnum(message: String) extends Exception(message) - case class IllegalCastToEnum(message: String) extends Exception(message) - case class NoEmptyConstructor(message: String) extends Exception(message) + case class EnumTypeMismatchException(message: String) extends Exception(message) + case class EnumHasNoCompanionObjectException(message: String) extends Exception(message) + case class NonLiteralEnumException(message: String) extends Exception(message) + case class NonIncreasingEnumException(message: String) extends Exception(message) + case class IllegalDefinitionOfEnumException(message: String) extends Exception(message) + case class IllegalCastToEnumException(message: String) extends Exception(message) + case class NoEmptyConstructorException(message: String) extends Exception(message) } object EnumAnnotations { @@ -113,7 +113,7 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { currentMirror.reflectModule(companionModule).instance.asInstanceOf[StrongEnum[this.type]] } catch { case ex: java.lang.ClassNotFoundException => - throw EnumHasNoCompanionObject(s"$enumTypeName's companion object was not found") + throw EnumHasNoCompanionObjectException(s"$enumTypeName's companion object was not found") } private[chisel3] override def width: Width = companionObject.width @@ -138,7 +138,7 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { // TODO: See if there is a way to catch this at compile-time def checkTypeEquivalency(that: EnumType): Unit = if (!typeEquivalent(that)) - throw EnumTypeMismatch(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types") + throw EnumTypeMismatchException(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types") def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer } @@ -213,9 +213,9 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just // throw an exception if (!id.litOption.isDefined) - throw NonLiteralEnum(s"$enumTypeName defined with a non-literal type in companion object") + throw NonLiteralEnumException(s"$enumTypeName defined with a non-literal type in companion object") if (id.litValue() <= this.id) - throw NonIncreasingEnum(s"Enums must be strictly increasing: $enumTypeName") + throw NonIncreasingEnumException(s"Enums must be strictly increasing: $enumTypeName") this.id = id.litValue() Value @@ -252,7 +252,7 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { implicitly[ClassTag[T]].runtimeClass.getDeclaredConstructor() getClass.getDeclaredConstructor() } catch { - case ex: NoSuchMethodException => throw NoEmptyConstructor(s"$enumTypeName does not have a no-args constructor. Did you declare it inside a class?") + case ex: NoSuchMethodException => throw NoEmptyConstructorException(s"$enumTypeName does not have a no-args constructor. Did you declare it inside a class?") } } @@ -271,6 +271,6 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { val childInstance = constructor.newInstance() if (childInstance.enum_names.length != childInstance.enum_instances.length) - throw IllegalDefinitionOfEnum(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") + throw IllegalDefinitionOfEnumException(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") } } diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 29b54fc609b..605a422422a 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -198,16 +198,19 @@ class StrongEnumFSMTester extends BasicTester { } class StrongEnumSpec extends ChiselFlatSpec { + import chisel3.core.EnumExceptions._ + import chisel3.internal.ChiselException + behavior of "Strong enum tester" it should "fail to instantiate enums without a companion class" in { - an [Exception] should be thrownBy { + an [EnumHasNoCompanionObjectException] should be thrownBy { elaborate(new SimpleConnector(new EnumWithoutCompanionObj(), new EnumWithoutCompanionObj())) } } it should "fail to instantiate non-literal enums in a companion object" in { - an [Error] should be thrownBy { + an [ExceptionInInitializerError] should be thrownBy { elaborate(new SimpleConnector(new NonLiteralEnumType(), new NonLiteralEnumType())) } } @@ -217,13 +220,13 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "fail to connect a strong enum to a UInt" in { - an [Exception] should be thrownBy { + a [ChiselException] should be thrownBy { elaborate(new SimpleConnector(EnumExample(), UInt())) } } it should "fail to connect enums of different types" in { - an [Exception] should be thrownBy { + an [ChiselException] should be thrownBy { elaborate(new SimpleConnector(EnumExample(), OtherEnum())) } } @@ -237,7 +240,7 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "catch illegal literal casts to enums" in { - an [Exception] should be thrownBy { + an [ChiselException] should be thrownBy { elaborate(new CastToInvalidEnumTester) } } @@ -247,7 +250,7 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "fail to compare enums of different types" in { - an [Exception] should be thrownBy { + an [EnumTypeMismatchException] should be thrownBy { elaborate(new InvalidEnumOpsTester) } } From 300fbc97e65dd0a34aa198bbea5e7a52f768cd94 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Tue, 18 Sep 2018 23:46:05 -0700 Subject: [PATCH 04/13] Fixed bug where an enum's global annotation would not be set if it was used in multiple circuits Made styling changes to StrongEnum.scala --- .../main/scala/chisel3/core/StrongEnum.scala | 37 ++++++++++++------- src/test/scala/chiselTests/StrongEnum.scala | 26 +++++++++---- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 15c4dc70426..b9b0979d809 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package chisel3.core import scala.language.experimental.macros @@ -107,8 +109,8 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref)) - val companionModule = currentMirror.reflect(this).symbol.companion.asModule - val companionObject = + private val companionModule = currentMirror.reflect(this).symbol.companion.asModule + private val companionObject = try { currentMirror.reflectModule(companionModule).instance.asInstanceOf[StrongEnum[this.type]] } catch { @@ -129,16 +131,26 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. // To workaround that, we only annotate enums that are not bound to literals. - if (selfAnnotating && !litOption.isDefined) - annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + if (selfAnnotating && !litOption.isDefined) { + annotateEnum() + } + } + + private def annotateEnum(): Unit = { + annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + + if (!Builder.annotations.contains(companionObject.globalAnnotation)) { + annotate(companionObject.globalAnnotation) + } } private def enumTypeName: String = getClass.getName // TODO: See if there is a way to catch this at compile-time def checkTypeEquivalency(that: EnumType): Unit = - if (!typeEquivalent(that)) + if (!typeEquivalent(that)) { throw EnumTypeMismatchException(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types") + } def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer } @@ -162,7 +174,6 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { private def getEnumNames(implicit ct: ClassTag[T]): Seq[String] = { val mirror = runtimeMirror(this.getClass.getClassLoader) - val reflection = mirror.reflect(this) // We use Java reflection to get all the enum fields, and then we use Scala reflection to sort them in declaration // order. TODO: Use only Scala reflection here @@ -177,9 +188,8 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { private def bindAllEnums(): Unit = (enum_instances, enum_values).zipped.foreach((inst, v) => inst.bindToLiteral(v.U(width))) - private def createAnnotation(): Unit = - annotate(EnumDefChiselAnnotation(enumTypeName, - (enum_names, enum_values.map(_.U(width))).zipped.toMap)) + private[core] def globalAnnotation: EnumDefChiselAnnotation = + EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values.map(_.U(width))).zipped.toMap) private def newEnum()(implicit ct: ClassTag[T]): T = ct.runtimeClass.newInstance.asInstanceOf[T] @@ -200,10 +210,9 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { width = (1 max id.bitLength).W id += 1 - // Check whether we've instantiated all the enums + // Instantiate all the enums when Value is called for the last time if (enum_instances.length == enum_names.length && isTopLevelConstructor) { bindAllEnums() - createAnnotation() } result @@ -225,8 +234,9 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { if (n.litOption.isDefined) { - if (!enum_values.contains(n.litValue)) + if (!enum_values.contains(n.litValue)) { throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } val result = newEnum() result.bindToLiteral(n) @@ -270,7 +280,8 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { constructor.setAccessible(true) val childInstance = constructor.newInstance() - if (childInstance.enum_names.length != childInstance.enum_instances.length) + if (childInstance.enum_names.length != childInstance.enum_instances.length) { throw IllegalDefinitionOfEnumException(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") + } } } diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 605a422422a..cca0202120e 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -88,17 +88,20 @@ object StrongEnumFSM { } class StrongEnumFSM extends Module { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + // This FSM detects two 1's one after the other val io = IO(new Bundle { val in = Input(Bool()) val out = Output(Bool()) + val state = Output(State()) }) - import StrongEnumFSM.State._ - val state = RegInit(sNone) io.out := (state === sTwo1s) + io.state := state switch (state) { is (sNone) { @@ -269,15 +272,20 @@ class StrongEnumSpec extends ChiselFlatSpec { } class StrongEnumAnnotationSpec extends FreeSpec with Matchers { + import EnumAnnotations._ + "Test that strong enums annotate themselves appropriately" in { Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { case ChiselExecutionSuccess(Some(circuit), emitted, _) => val annos = circuit.annotations.map(_.toFirrtl) + val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } + val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + // Check that the global annotation is correct - annos.exists { - case EnumAnnotations.EnumDefAnnotation(name, map) => + enumDefAnnos.exists { + case EnumDefAnnotation(name, map) => name.endsWith("State") && map.size == StrongEnumFSM.State.correct_annotation_map.size && map.forall { @@ -293,12 +301,14 @@ class StrongEnumAnnotationSpec extends FreeSpec with Matchers { } should be(true) // Check that the component annotations are correct - annos.exists { - case EnumAnnotations.EnumComponentAnnotation(target, enumName) => + enumCompAnnos.count { + case EnumComponentAnnotation(target, enumName) => val ComponentName(targetName, _) = target - targetName == "state" && enumName.endsWith("State") + (targetName == "state" && enumName.endsWith("State")) || + (targetName == "io.state" && enumName.endsWith("State")) case _ => false - } should be(true) + } should be(2) + case _ => assert(false) } From 863409fa7e0f64c7da4e7af831ccd15a49e5fd8a Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Wed, 19 Sep 2018 16:10:25 -0700 Subject: [PATCH 05/13] Reverted accidental changes to the AnnotatingDiamond test --- .../chiselTests/AnnotatingDiamondSpec.scala | 62 +++---------------- 1 file changed, 9 insertions(+), 53 deletions(-) diff --git a/src/test/scala/chiselTests/AnnotatingDiamondSpec.scala b/src/test/scala/chiselTests/AnnotatingDiamondSpec.scala index 4b8d82daa4f..e88d475e158 100644 --- a/src/test/scala/chiselTests/AnnotatingDiamondSpec.scala +++ b/src/test/scala/chiselTests/AnnotatingDiamondSpec.scala @@ -3,13 +3,16 @@ package chiselTests import chisel3._ -import chisel3.core.{EnumAnnotations, FixedPoint} -import chisel3.experimental.{ChiselAnnotation, RunFirrtlTransform, annotate} +import chisel3.experimental.{annotate, ChiselAnnotation, RunFirrtlTransform} import chisel3.internal.InstanceId -import chisel3.internal.firrtl.BinaryPoint import chisel3.testers.BasicTester import firrtl.{CircuitState, LowForm, Transform} -import firrtl.annotations.{Annotation, ModuleName, Named, SingleTargetAnnotation} +import firrtl.annotations.{ + Annotation, + SingleTargetAnnotation, + ModuleName, + Named +} import org.scalatest._ /** These annotations and the IdentityTransform class serve as an example of how to write a @@ -95,39 +98,14 @@ class ModB(widthB: Int) extends Module { modC.io.in := io.in io.out := modC.io.out - val reg = RegInit(MyEnum.enum) - identify(io.in, s"modB.io.in annotated from inside modB") } -class MyEnum extends EnumType -object MyEnum extends StrongEnum[MyEnum] { - val enum, e2 = Value - val e3 = Value(5.U) - val e4, e10, e11 = Value -} - -/*class OtherEnum extends EnumType -object OtherEnum extends StrongEnum[OtherEnum] { - val err = Value -}*/ - class TopOfDiamond extends Module { val io = IO(new Bundle { val in = Input(UInt(32.W)) - // val out = Output(UInt(32.W)) - val out = Output(MyEnum()) + val out = Output(UInt(32.W)) }) - - val wire = WireInit(MyEnum.enum) - println(MyEnum.e11.litValue) - - val uiWire = WireInit(1.U) - - wire := MyEnum(uiWire)// OtherEnum.err - - val m = Mux(wire === MyEnum.enum, wire, MyEnum.e2) - val x = Reg(UInt(32.W)) val y = Reg(UInt(32.W)) @@ -139,8 +117,7 @@ class TopOfDiamond extends Module { modB.io.in := x y := modA.io.out + modB.io.out - // io.out := y - io.out := m//.asUInt() + io.out := y identify(this, s"TopOfDiamond\nWith\nSome new lines") @@ -179,27 +156,6 @@ class AnnotatingDiamondSpec extends FreeSpec with Matchers { case IdentityAnnotation(ModuleName("ModC_1", _), "ModC(32)") => true case _ => false } should be (1) - - println(s"Enum defs:") - annos.foreach { - case EnumAnnotations.EnumDefAnnotation(name, m) => - print(s"\t$name: ") - for ((k,v) <- m) { - print(s"($k -> ${v.litValue()}), ") - } - println() - case _ => - } - - println(s"Enum comps:") - annos.foreach { - case EnumAnnotations.EnumComponentAnnotation(target, eName) => - println(s"\t$target: $eName") - case _ => - } - - println("\n\n-----------\n\n") - println(emitted) case _ => assert(false) } From 8999553f4062c98a2a8f1baceb9a993fe0653ad8 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Wed, 19 Sep 2018 16:25:08 -0700 Subject: [PATCH 06/13] Changed the API for casting non-literal UInts to enums Added an isValid function that checks whether or not enums have valid values Calling getWidth on an enum's companion object now returns a BigInt instead of an Int --- .../main/scala/chisel3/core/StrongEnum.scala | 53 ++++-- src/test/scala/chiselTests/StrongEnum.scala | 164 +++++++++++++----- 2 files changed, 159 insertions(+), 58 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index b9b0979d809..153141a9b74 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -45,7 +45,7 @@ import EnumExceptions._ import EnumAnnotations._ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { - def cloneType: this.type = getClass.getConstructor().newInstance().asInstanceOf[this.type] + override def cloneType: this.type = getClass.getConstructor().newInstance().asInstanceOf[this.type] private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { // Translate Bundle lit bindings to Element lit bindings @@ -120,6 +120,22 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { private[chisel3] override def width: Width = companionObject.width + def isValid: Bool = { + if (!companionObject.finishedInstantiation) + throwException(s"Not all enums values have been defined yet") + + if (litOption.isDefined) { + true.B + } else { + def mux_builder(enums: Seq[this.type]): Bool = enums match { + case Nil => false.B + case e :: es => Mux(this === e, true.B, mux_builder(es)) + } + + mux_builder(companionObject.all) + } + } + private[core] def bindToLiteral(bits: UInt): Unit = { val litNum = bits.litOption.get val lit = ULit(litNum, width) // We must make sure to use the enum's width, rather than the UInt's width @@ -191,6 +207,9 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { private[core] def globalAnnotation: EnumDefChiselAnnotation = EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values.map(_.U(width))).zipped.toMap) + private[core] def finishedInstantiation: Boolean = + enum_names.length == enum_instances.length + private def newEnum()(implicit ct: ClassTag[T]): T = ct.runtimeClass.newInstance.asInstanceOf[T] @@ -198,7 +217,7 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { // the companion class's name in a more robust way. private val enumTypeName = getClass.getName.init - def getWidth: BigInt = width.get + def getWidth: Int = width.get def all: List[T] = enum_instances.toList @@ -233,16 +252,28 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { def apply(): T = newEnum() def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { - if (n.litOption.isDefined) { - if (!enum_values.contains(n.litValue)) { - throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") - } + if (!n.litOption.isDefined) { + throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use castFromNonLit instead") + } else if (!enum_values.contains(n.litValue)) { + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } - val result = newEnum() - result.bindToLiteral(n) - result + val result = newEnum() + result.bindToLiteral(n) + result + } + + def castFromNonLit(n: UInt): T = { + if (!n.isWidthKnown) { + throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") + } else if (n.getWidth > this.getWidth) { + throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") + } + + if (n.litOption.isDefined) { + apply(n) } else { - Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. No automatic bounds checking will be done!") + Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that the value is legal by calling isValid") val glue = Wire(new UnsafeEnum(width)) glue := n @@ -280,7 +311,7 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { constructor.setAccessible(true) val childInstance = constructor.newInstance() - if (childInstance.enum_names.length != childInstance.enum_instances.length) { + if (!childInstance.finishedInstantiation) { throw IllegalDefinitionOfEnumException(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") } } diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index cca0202120e..427cbf4d210 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -4,6 +4,7 @@ package chiselTests import chisel3._ import chisel3.core.{EnumAnnotations, EnumExceptions} +import chisel3.internal.firrtl.UnknownWidth import chisel3.util._ import chisel3.testers.BasicTester import firrtl.annotations.ComponentName @@ -48,13 +49,36 @@ class CastToUInt extends Module { io.out := io.in.asUInt() } -class CastToEnum extends Module { +class CastFromLit(in: UInt) extends Module { val io = IO(new Bundle { - val in = Input(UInt()) val out = Output(EnumExample()) + val valid = Output(Bool()) }) - io.out := EnumExample(io.in) + io.out := EnumExample(in) + io.valid := io.out.isValid +} + +class CastFromNonLit extends Module { + val io = IO(new Bundle { + val in = Input(UInt(EnumExample.getWidth.W)) + val out = Output(EnumExample()) + val valid = Output(Bool()) + }) + + io.out := EnumExample.castFromNonLit(io.in) + io.valid := io.out.isValid +} + +class CastFromNonLitWidth(w: Option[Int] = None) extends Module { + val width = if (w.isDefined) w.get.W else UnknownWidth() + + override val io = IO(new Bundle { + val in = Input(UInt(width)) + val out = Output(EnumExample()) + }) + + io.out := EnumExample.castFromNonLit(io.in) } class EnumOps(xType: EnumType, yType: EnumType) extends Module { @@ -133,20 +157,41 @@ class CastToUIntTester extends BasicTester { stop() } -class CastToEnumTester extends BasicTester { +class CastFromLitTester extends BasicTester { for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { - val mod = Module(new CastToEnum) + val mod = Module(new CastFromLit(lit)) + assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + stop() +} + +class CastFromNonLitTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastFromNonLit) mod.io.in := lit assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + + import scala.util.Random + val invalid_values = (for(i <- 0 until 200) yield Random.nextInt((1 << EnumExample.getWidth)-1)). + filter(!EnumExample.litValues.map(_.litValue).contains(_)). + map(_.U) + + for (invalid_val <- invalid_values) { + val mod = Module(new CastFromNonLit) + mod.io.in := invalid_val + + assert(mod.io.valid === false.B) } + stop() } class CastToInvalidEnumTester extends BasicTester { val invalid_value: UInt = EnumExample.litValues.last + 1.U - val mod = Module(new CastToEnum { - io.out := invalid_value - }) + Module(new CastFromLit(invalid_value)) } class EnumOpsTester extends BasicTester { @@ -238,16 +283,35 @@ class StrongEnumSpec extends ChiselFlatSpec { assertTesterPasses(new CastToUIntTester) } - it should "cast UInts to enums correctly" in { - assertTesterPasses(new CastToEnumTester) + it should "cast literal UInts to enums correctly" in { + assertTesterPasses(new CastFromLitTester) } - it should "catch illegal literal casts to enums" in { - an [ChiselException] should be thrownBy { + it should "cast non-literal UInts to enums correctly and detect illegal casts" in { + assertTesterPasses(new CastFromNonLitTester) + } + + it should "prevent illegal literal casts to enums" in { + a [ChiselException] should be thrownBy { elaborate(new CastToInvalidEnumTester) } } + it should "only allow non-literal casts to enums if the width is smaller than or equal to the enum width" in { + for (w <- 0 to EnumExample.getWidth) + elaborate(new CastFromNonLitWidth(Some(w))) + + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth) + } + + for (w <- (EnumExample.getWidth+1) to (EnumExample.getWidth+100)) { + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth(Some(w))) + } + } + } + it should "execute enum comparison operations correctly" in { assertTesterPasses(new EnumOpsTester) } @@ -276,41 +340,47 @@ class StrongEnumAnnotationSpec extends FreeSpec with Matchers { "Test that strong enums annotate themselves appropriately" in { - Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { - case ChiselExecutionSuccess(Some(circuit), emitted, _) => - val annos = circuit.annotations.map(_.toFirrtl) - - val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } - val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } - - // Check that the global annotation is correct - enumDefAnnos.exists { - case EnumDefAnnotation(name, map) => - name.endsWith("State") && - map.size == StrongEnumFSM.State.correct_annotation_map.size && - map.forall { - case (k, v) => - val correctValue = StrongEnumFSM.State.correct_annotation_map(k) - - val correctValLit = correctValue.litValue() - val vLitValue = v.litValue() - - correctValue.getWidth == v.getWidth && correctValue.litValue() == v.litValue() - } - case _ => false - } should be(true) - - // Check that the component annotations are correct - enumCompAnnos.count { - case EnumComponentAnnotation(target, enumName) => - val ComponentName(targetName, _) = target - (targetName == "state" && enumName.endsWith("State")) || - (targetName == "io.state" && enumName.endsWith("State")) - case _ => false - } should be(2) - - case _ => - assert(false) + def test() = { + Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { + case ChiselExecutionSuccess(Some(circuit), emitted, _) => + val annos = circuit.annotations.map(_.toFirrtl) + + val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } + val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + + // Check that the global annotation is correct + enumDefAnnos.exists { + case EnumDefAnnotation(name, map) => + name.endsWith("State") && + map.size == StrongEnumFSM.State.correct_annotation_map.size && + map.forall { + case (k, v) => + val correctValue = StrongEnumFSM.State.correct_annotation_map(k) + + val correctValLit = correctValue.litValue() + val vLitValue = v.litValue() + + correctValue.getWidth == v.getWidth && correctValue.litValue() == v.litValue() + } + case _ => false + } should be(true) + + // Check that the component annotations are correct + enumCompAnnos.count { + case EnumComponentAnnotation(target, enumName) => + val ComponentName(targetName, _) = target + (targetName == "state" && enumName.endsWith("State")) || + (targetName == "io.state" && enumName.endsWith("State")) + case _ => false + } should be(2) + + case _ => + assert(false) + } } + + // We run this test twice, to test for an older bug where only the first circuit would be annotated + test() + test() } } From 6c0793bf82f5740d39b72e65e62dc72a203b2976 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Wed, 19 Sep 2018 16:36:17 -0700 Subject: [PATCH 07/13] Casting a literal to an enum using the StrongEnum.castFromNonLit(n) function is now simply a wrapper for StrongEnum.apply(n) --- .../src/main/scala/chisel3/core/StrongEnum.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 153141a9b74..438eb253ab1 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -264,14 +264,12 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { } def castFromNonLit(n: UInt): T = { - if (!n.isWidthKnown) { + if (n.litOption.isDefined) { + apply(n) + } else if (!n.isWidthKnown) { throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") } else if (n.getWidth > this.getWidth) { throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") - } - - if (n.litOption.isDefined) { - apply(n) } else { Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that the value is legal by calling isValid") From f67ffb69d5e13bbbf29df345186a1936d3ce6ce2 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Wed, 19 Sep 2018 16:49:06 -0700 Subject: [PATCH 08/13] Fixed compilation bug --- chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 438eb253ab1..0802c9bceb7 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -120,14 +120,14 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { private[chisel3] override def width: Width = companionObject.width - def isValid: Bool = { + def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { if (!companionObject.finishedInstantiation) throwException(s"Not all enums values have been defined yet") if (litOption.isDefined) { true.B } else { - def mux_builder(enums: Seq[this.type]): Bool = enums match { + def mux_builder(enums: List[this.type]): Bool = enums match { case Nil => false.B case e :: es => Mux(this === e, true.B, mux_builder(es)) } @@ -263,7 +263,7 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { result } - def castFromNonLit(n: UInt): T = { + def castFromNonLit(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { if (n.litOption.isDefined) { apply(n) } else if (!n.isWidthKnown) { From 50e3ffda5655f37e12dcc12a8b4d675e790ef20b Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Fri, 21 Sep 2018 16:55:48 -0700 Subject: [PATCH 09/13] * Added "next" method to EnumType * Renamed "castFromNonLit" to "fromBits" --- .../main/scala/chisel3/core/StrongEnum.scala | 45 +++++++++---- src/test/scala/chiselTests/StrongEnum.scala | 63 +++++++++++++++++-- 2 files changed, 92 insertions(+), 16 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 0802c9bceb7..1c831a5ea21 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -116,23 +116,44 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { } catch { case ex: java.lang.ClassNotFoundException => throw EnumHasNoCompanionObjectException(s"$enumTypeName's companion object was not found") + case default => throw default } private[chisel3] override def width: Width = companionObject.width def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { - if (!companionObject.finishedInstantiation) + if (!companionObject.finishedInstantiation) { throwException(s"Not all enums values have been defined yet") + } if (litOption.isDefined) { true.B } else { - def mux_builder(enums: List[this.type]): Bool = enums match { + def muxBuilder(enums: List[this.type]): Bool = enums match { case Nil => false.B - case e :: es => Mux(this === e, true.B, mux_builder(es)) + case e :: es => Mux(this === e, true.B, muxBuilder(es)) + } + + muxBuilder(companionObject.all) + } + } + + def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + if (!companionObject.finishedInstantiation) { + throwException(s"Not all enums values have been defined yet") + } + + if (litOption.isDefined) { + val index = companionObject.all.indexOf(this) + if (index < companionObject.all.length-1) companionObject.all(index+1) + else companionObject.all.head + } else { + def muxBuilder(enums: List[this.type], first_enum: this.type): this.type = enums match { + case e :: Nil => first_enum + case e :: e_next :: es => Mux(this === e, e_next, muxBuilder(e_next :: es, first_enum)) } - mux_builder(companionObject.all) + muxBuilder(companionObject.all, companionObject.all.head) } } @@ -253,17 +274,19 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { if (!n.litOption.isDefined) { - throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use castFromNonLit instead") - } else if (!enum_values.contains(n.litValue)) { - throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use fromBits instead") } - val result = newEnum() - result.bindToLiteral(n) - result + val result = enum_instances.find(_.litValue == n.litValue) + + if (result.isEmpty) { + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } else { + result.get + } } - def castFromNonLit(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { + def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { if (n.litOption.isDefined) { apply(n) } else if (!n.isWidthKnown) { diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 427cbf4d210..0cd5b333e2a 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -31,6 +31,20 @@ object NonLiteralEnumType extends StrongEnum[NonLiteralEnumType] { val nonLit = Value(UInt()) } +class EnumWithEarlyIsValid extends EnumType +object EnumWithEarlyIsValid extends StrongEnum[EnumWithEarlyIsValid] { + val s1 = Value + val isV = s1.isValid + val s2 = Value +} + +class EnumWithEarlyNext extends EnumType +object EnumWithEarlyNext extends StrongEnum[EnumWithEarlyNext] { + val s1 = Value + val n = s1.next + val s2 = Value +} + class SimpleConnector(inType: Data, outType: Data) extends Module { val io = IO(new Bundle { val in = Input(inType) @@ -66,7 +80,7 @@ class CastFromNonLit extends Module { val valid = Output(Bool()) }) - io.out := EnumExample.castFromNonLit(io.in) + io.out := EnumExample.fromBits(io.in) io.valid := io.out.isValid } @@ -78,7 +92,7 @@ class CastFromNonLitWidth(w: Option[Int] = None) extends Module { val out = Output(EnumExample()) }) - io.out := EnumExample.castFromNonLit(io.in) + io.out := EnumExample.fromBits(io.in) } class EnumOps(xType: EnumType, yType: EnumType) extends Module { @@ -174,8 +188,7 @@ class CastFromNonLitTester extends BasicTester { assert(mod.io.valid === true.B) } - import scala.util.Random - val invalid_values = (for(i <- 0 until 200) yield Random.nextInt((1 << EnumExample.getWidth)-1)). + val invalid_values = (1 until (1 << EnumExample.getWidth)). filter(!EnumExample.litValues.map(_.litValue).contains(_)). map(_.U) @@ -227,18 +240,42 @@ class IsLitTester extends BasicTester { stop() } +class NextTester extends BasicTester { + for ((e,n) <- EnumExample.all.zip(EnumExample.litValues.tail :+ EnumExample.litValues.head)) { + assert(e.next.litValue == n.litValue) + val w = WireInit(e) + assert(w.next === EnumExample(n)) + } + stop() +} + +class WidthTester extends BasicTester { + assert(EnumExample.getWidth == EnumExample.litValues.last.getWidth) + assert(EnumExample.all.forall(_.getWidth == EnumExample.litValues.last.getWidth)) + assert(EnumExample.all.forall{e => + val w = WireInit(e) + w.getWidth == EnumExample.litValues.last.getWidth + }) + stop() +} + class StrongEnumFSMTester extends BasicTester { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + val dut = Module(new StrongEnumFSM) // Inputs and expected results val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B) val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B) + val expected_state: Vec[State] = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s) val cntr = Counter(inputs.length) val cycle = cntr.value dut.io.in := inputs(cycle) assert(dut.io.out === expected(cycle)) + assert(dut.io.state === expected_state(cycle)) when(cntr.inc()) { stop() @@ -263,6 +300,18 @@ class StrongEnumSpec extends ChiselFlatSpec { } } + it should "fail to call isValid early" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(EnumWithEarlyIsValid(), EnumWithEarlyIsValid())) + } + } + + it should "fail to call next early" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(EnumWithEarlyNext(), EnumWithEarlyNext())) + } + } + it should "connect enums of the same type" in { elaborate(new SimpleConnector(EnumExample(), EnumExample())) } @@ -326,8 +375,12 @@ class StrongEnumSpec extends ChiselFlatSpec { assertTesterPasses(new IsLitTester) } + it should "return the correct next values for enums" in { + assertTesterPasses(new NextTester) + } + it should "return the correct widths for enums" in { - EnumExample.getWidth == EnumExample.litValues.last.getWidth + assertTesterPasses(new WidthTester) } "StrongEnum FSM" should "work" in { From cd903d6f2517d45489fd61f2582a412fedc00ecd Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Fri, 21 Sep 2018 17:10:04 -0700 Subject: [PATCH 10/13] The FSM example in the test/scala/cookbook now uses StrongEnums --- src/test/scala/cookbook/FSM.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala index 22cf8059e8f..688cfecb6bc 100644 --- a/src/test/scala/cookbook/FSM.scala +++ b/src/test/scala/cookbook/FSM.scala @@ -7,16 +7,24 @@ import chisel3.util._ /* ### How do I create a finite state machine? * - * Use Chisel Enum to construct the states and switch & is to construct the FSM + * Use Chisel StrongEnum to construct the states and switch & is to construct the FSM * control logic */ + +object DetectTwoOnes { + class State extends EnumType + object State extends StrongEnum[State] { + val sNone, sOne1, sTwo1s = Value + } +} + class DetectTwoOnes extends Module { val io = IO(new Bundle { val in = Input(Bool()) val out = Output(Bool()) }) - val sNone :: sOne1 :: sTwo1s :: Nil = Enum(3) + import DetectTwoOnes.State._ val state = RegInit(sNone) io.out := (state === sTwo1s) From 6537628643db210f790c4d99acab247ef04bee1e Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Mon, 24 Sep 2018 16:44:55 -0700 Subject: [PATCH 11/13] * Changed strong enum API, so that users no longer have to declare both a class and a companion object for each strong enum * Strong enums do not have to be static any longer --- .../main/scala/chisel3/core/MonoConnect.scala | 2 +- .../main/scala/chisel3/core/StrongEnum.scala | 228 +++++++----------- .../chisel3/internal/firrtl/Converter.scala | 2 +- .../chisel3/internal/firrtl/Emitter.scala | 4 +- src/main/scala/chisel3/package.scala | 3 +- src/test/scala/chiselTests/StrongEnum.scala | 94 +++----- src/test/scala/cookbook/FSM.scala | 13 +- 7 files changed, 138 insertions(+), 208 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala index 2383e093bda..c9420ba70de 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala @@ -81,7 +81,7 @@ object MonoConnect { elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: EnumType, source_e: UnsafeEnum) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) - case (sink_e: EnumType, source_e: EnumType) if sink_e.getClass.equals(source_e.getClass) => + case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: UnsafeEnum, source_e: UInt) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 1c831a5ea21..8b9d6779ef2 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -3,9 +3,7 @@ package chisel3.core import scala.language.experimental.macros -import scala.reflect.ClassTag -import scala.reflect.runtime.currentMirror -import scala.reflect.runtime.universe.{MethodSymbol, runtimeMirror} +import scala.reflect.macros.blackbox.Context import scala.collection.mutable import chisel3.internal.Builder.pushOp @@ -15,15 +13,6 @@ import chisel3.internal.sourceinfo._ import chisel3.internal.{Builder, InstanceId, throwException} import firrtl.annotations._ -object EnumExceptions { - case class EnumTypeMismatchException(message: String) extends Exception(message) - case class EnumHasNoCompanionObjectException(message: String) extends Exception(message) - case class NonLiteralEnumException(message: String) extends Exception(message) - case class NonIncreasingEnumException(message: String) extends Exception(message) - case class IllegalDefinitionOfEnumException(message: String) extends Exception(message) - case class IllegalCastToEnumException(message: String) extends Exception(message) - case class NoEmptyConstructorException(message: String) extends Exception(message) -} object EnumAnnotations { case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] { @@ -34,18 +23,17 @@ object EnumAnnotations { def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName) } - case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends NoTargetAnnotation + case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation - case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, UInt]) extends ChiselAnnotation { + case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation { override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition) } } - -import EnumExceptions._ import EnumAnnotations._ -abstract class EnumType(selfAnnotating: Boolean = true) extends Element { - override def cloneType: this.type = getClass.getConstructor().newInstance().asInstanceOf[this.type] + +abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element { + override def cloneType: this.type = factory().asInstanceOf[this.type] private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { // Translate Bundle lit bindings to Element lit bindings @@ -78,12 +66,16 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { requireIsHardware(this, "bits operated on") requireIsHardware(other, "bits operated on") - checkTypeEquivalency(other) + if(!this.typeEquivalent(other)) + throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}") pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref)) } - private[core] override def typeEquivalent(that: Data): Boolean = this.getClass == that.getClass + private[core] override def typeEquivalent(that: Data): Boolean = { + this.getClass == that.getClass && + this.factory == that.asInstanceOf[EnumType].factory + } // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we // inherit it from Data. @@ -109,57 +101,41 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref)) - private val companionModule = currentMirror.reflect(this).symbol.companion.asModule - private val companionObject = - try { - currentMirror.reflectModule(companionModule).instance.asInstanceOf[StrongEnum[this.type]] - } catch { - case ex: java.lang.ClassNotFoundException => - throw EnumHasNoCompanionObjectException(s"$enumTypeName's companion object was not found") - case default => throw default - } - - private[chisel3] override def width: Width = companionObject.width + protected[chisel3] override def width: Width = factory.width def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { - if (!companionObject.finishedInstantiation) { - throwException(s"Not all enums values have been defined yet") - } - if (litOption.isDefined) { true.B } else { - def muxBuilder(enums: List[this.type]): Bool = enums match { + def muxBuilder(enums: List[EnumType]): Bool = enums match { case Nil => false.B case e :: es => Mux(this === e, true.B, muxBuilder(es)) } - muxBuilder(companionObject.all) + muxBuilder(factory.all.toList) } } def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { - if (!companionObject.finishedInstantiation) { - throwException(s"Not all enums values have been defined yet") - } - if (litOption.isDefined) { - val index = companionObject.all.indexOf(this) - if (index < companionObject.all.length-1) companionObject.all(index+1) - else companionObject.all.head + val index = factory.all.indexOf(this) + + if (index < factory.all.length-1) + factory.all(index+1).asInstanceOf[this.type] + else + factory.all.head.asInstanceOf[this.type] } else { - def muxBuilder(enums: List[this.type], first_enum: this.type): this.type = enums match { + def muxBuilder(enums: List[EnumType], first_enum: EnumType): EnumType = enums match { case e :: Nil => first_enum case e :: e_next :: es => Mux(this === e, e_next, muxBuilder(e_next :: es, first_enum)) } - muxBuilder(companionObject.all, companionObject.all.head) + muxBuilder(factory.all.toList, factory.all.head).asInstanceOf[this.type] } } - private[core] def bindToLiteral(bits: UInt): Unit = { - val litNum = bits.litOption.get - val lit = ULit(litNum, width) // We must make sure to use the enum's width, rather than the UInt's width + private[core] def bindToLiteral(num: BigInt, w: Width): Unit = { + val lit = ULit(num, w) lit.bindLitArg(this) } @@ -176,104 +152,70 @@ abstract class EnumType(selfAnnotating: Boolean = true) extends Element { private def annotateEnum(): Unit = { annotate(EnumComponentChiselAnnotation(this, enumTypeName)) - if (!Builder.annotations.contains(companionObject.globalAnnotation)) { - annotate(companionObject.globalAnnotation) + if (!Builder.annotations.contains(factory.globalAnnotation)) { + annotate(factory.globalAnnotation) } } - private def enumTypeName: String = getClass.getName - - // TODO: See if there is a way to catch this at compile-time - def checkTypeEquivalency(that: EnumType): Unit = - if (!typeEquivalent(that)) { - throw EnumTypeMismatchException(s"${this.getClass.getName} and ${that.getClass.getName} are different enum types") - } + protected def enumTypeName: String = factory.enumTypeName def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer } -// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts -// to enums. -sealed private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(selfAnnotating = false) { - override def cloneType: this.type = getClass.getConstructor(classOf[Width]).newInstance(width).asInstanceOf[this.type] -} -private object UnsafeEnum extends StrongEnum[UnsafeEnum] { - override def checkEmptyConstructorExists(): Unit = {} -} - -abstract class StrongEnum[T <: EnumType : ClassTag] { - private var id: BigInt = 0 - private[core] var width: Width = 0.W - - private val enum_names = getEnumNames - private val enum_values = mutable.ArrayBuffer.empty[BigInt] - private val enum_instances = mutable.ArrayBuffer.empty[T] - - private def getEnumNames(implicit ct: ClassTag[T]): Seq[String] = { - val mirror = runtimeMirror(this.getClass.getClassLoader) - // We use Java reflection to get all the enum fields, and then we use Scala reflection to sort them in declaration - // order. TODO: Use only Scala reflection here - val fields = getClass.getDeclaredFields.filter(_.getType == ct.runtimeClass).map(_.getName) - val getters = mirror.classSymbol(this.getClass).toType.members.sorted.collect { - case m: MethodSymbol if m.isGetter => m.name.toString - } +abstract class EnumFactory { + class E extends EnumType(this) - getters.filter(fields.contains(_)) - } + var id: BigInt = 0 + var width: Width = 0.W - private def bindAllEnums(): Unit = - (enum_instances, enum_values).zipped.foreach((inst, v) => inst.bindToLiteral(v.U(width))) + val enum_names = mutable.ArrayBuffer.empty[String] + val enum_values = mutable.ArrayBuffer.empty[BigInt] + val enum_instances = mutable.ArrayBuffer.empty[E] private[core] def globalAnnotation: EnumDefChiselAnnotation = - EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values.map(_.U(width))).zipped.toMap) - - private[core] def finishedInstantiation: Boolean = - enum_names.length == enum_instances.length + EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values).zipped.toMap) - private def newEnum()(implicit ct: ClassTag[T]): T = - ct.runtimeClass.newInstance.asInstanceOf[T] - - // TODO: This depends upon undocumented behavior (which, to be fair, is unlikely to change). Use reflection to find - // the companion class's name in a more robust way. - private val enumTypeName = getClass.getName.init + private[core] val enumTypeName = getClass.getName.init def getWidth: Int = width.get - def all: List[T] = enum_instances.toList + def all: Seq[E] = enum_instances.toSeq + + def Value: E = macro EnumMacros.ValImpl + def Value(id: UInt): E = macro EnumMacros.ValCustomImpl - def Value: T = { - val result = newEnum() + def do_Value(names: Seq[String]): E = { + val result = new E + enum_names ++= names.filter(!enum_names.contains(_)) enum_instances.append(result) enum_values.append(id) + // We have to use UnknownWidth here, because we don't actually know what the final width will be + result.bindToLiteral(id, UnknownWidth()) + width = (1 max id.bitLength).W id += 1 - // Instantiate all the enums when Value is called for the last time - if (enum_instances.length == enum_names.length && isTopLevelConstructor) { - bindAllEnums() - } - result } - def Value(id: UInt): T = { + def do_Value(names: Seq[String], id: UInt): E = { // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just // throw an exception if (!id.litOption.isDefined) - throw NonLiteralEnumException(s"$enumTypeName defined with a non-literal type in companion object") - if (id.litValue() <= this.id) - throw NonIncreasingEnumException(s"Enums must be strictly increasing: $enumTypeName") + throwException(s"$enumTypeName defined with a non-literal type") + if (id.litValue() < this.id) + throwException(s"Enums must be strictly increasing: $enumTypeName") this.id = id.litValue() - Value + do_Value(names) } - def apply(): T = newEnum() + def apply(): E = new E - def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { - if (!n.litOption.isDefined) { + def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): E = { + if (n.litOption.isEmpty) { throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use fromBits instead") } @@ -286,7 +228,7 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { } } - def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): T = { + def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): E = { if (n.litOption.isDefined) { apply(n) } else if (!n.isWidthKnown) { @@ -298,42 +240,48 @@ abstract class StrongEnum[T <: EnumType : ClassTag] { val glue = Wire(new UnsafeEnum(width)) glue := n - val result = Wire(newEnum()) + val result = Wire(new E) result := glue result } } +} - // StrongEnum basically has a recursive constructor. It instantiates a copy of itself internally, so that it can - // make sure that all EnumType's inside of it were instantiated using the "Value" function. However, in order to - // instantiate its copy, as well as to instantiate new enums, it has to make sure that it has a no-args constructor - // as it won't know what parameters to add otherwise. - - protected def checkEmptyConstructorExists(): Unit = { - try { - implicitly[ClassTag[T]].runtimeClass.getDeclaredConstructor() - getClass.getDeclaredConstructor() - } catch { - case ex: NoSuchMethodException => throw NoEmptyConstructorException(s"$enumTypeName does not have a no-args constructor. Did you declare it inside a class?") - } - } - - private val isTopLevelConstructor: Boolean = { - val stack_trace = Thread.currentThread().getStackTrace - val constructorName = "" - stack_trace.count(se => se.getClassName.equals(getClass.getName) && se.getMethodName.equals(constructorName)) == 1 +object EnumMacros { + def ValImpl(c: Context) : c.Tree = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names))""" } - if (isTopLevelConstructor) { - checkEmptyConstructorExists() + def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names), $id)""" + } - val constructor = getClass.getDeclaredConstructor() - constructor.setAccessible(true) - val childInstance = constructor.newInstance() + // Much thanks to Travis Brown for this solution: + // stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to + def getNames(c: Context): Seq[String] = { + import c.universe._ - if (!childInstance.finishedInstantiation) { - throw IllegalDefinitionOfEnumException(s"$enumTypeName defined illegally. Did you forget to call Value when defining a new enum?") + val names = c.enclosingClass.collect { + case ValDef(_, name, _, rhs) + if rhs.pos == c.macroApplication.pos => name.decoded } + + if (names.isEmpty) + c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum") + + names } } + + +// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts +// to enums. +private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) { + override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type] +} +private object UnsafeEnum extends EnumFactory diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala index 9e9616c893e..181bdfe8b72 100644 --- a/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo} import firrtl.{ir => fir} diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala index 3e4d6f2151b..ac4bf8e7d0f 100644 --- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine} @@ -28,7 +28,7 @@ private class Emitter(circuit: Circuit) { private def emitType(d: Data, clearDir: Boolean = false): String = d match { case d: Clock => "Clock" - case d: EnumType => s"UInt${d.width}" + case d: chisel3.core.EnumType => s"UInt${d.width}" case d: UInt => s"UInt${d.width}" case d: SInt => s"SInt${d.width}" case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}" diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala index fd4f4136046..f4e62d3de72 100644 --- a/src/main/scala/chisel3/package.scala +++ b/src/main/scala/chisel3/package.scala @@ -243,8 +243,7 @@ package object chisel3 { // scalastyle:ignore package.object.name object Bool extends BoolFactory val Mux = chisel3.core.Mux - type EnumType = chisel3.core.EnumType - type StrongEnum[T <: EnumType] = chisel3.core.StrongEnum[T] + type ChiselEnum = chisel3.core.EnumFactory val EnumAnnotations = chisel3.core.EnumAnnotations type BlackBox = chisel3.core.BlackBox diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 0cd5b333e2a..d9a17b84bdf 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -3,46 +3,33 @@ package chiselTests import chisel3._ -import chisel3.core.{EnumAnnotations, EnumExceptions} +import chisel3.core.EnumAnnotations import chisel3.internal.firrtl.UnknownWidth import chisel3.util._ import chisel3.testers.BasicTester import firrtl.annotations.ComponentName import org.scalatest.{FreeSpec, Matchers} -class EnumExample extends EnumType -object EnumExample extends StrongEnum[EnumExample] { +object EnumExample extends ChiselEnum { val e0, e1, e2 = Value + val e100 = Value(100.U) - val e101 = Value + val e101 = Value(101.U) val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U) } -class OtherEnum extends EnumType -object OtherEnum extends StrongEnum[OtherEnum] { +object OtherEnum extends ChiselEnum { val otherEnum = Value } -class EnumWithoutCompanionObj extends EnumType - -class NonLiteralEnumType extends EnumType -object NonLiteralEnumType extends StrongEnum[NonLiteralEnumType] { +object NonLiteralEnumType extends ChiselEnum { val nonLit = Value(UInt()) } -class EnumWithEarlyIsValid extends EnumType -object EnumWithEarlyIsValid extends StrongEnum[EnumWithEarlyIsValid] { - val s1 = Value - val isV = s1.isValid - val s2 = Value -} - -class EnumWithEarlyNext extends EnumType -object EnumWithEarlyNext extends StrongEnum[EnumWithEarlyNext] { - val s1 = Value - val n = s1.next - val s2 = Value +object NonIncreasingEnum extends ChiselEnum { + val x = Value(2.U) + val y = Value(2.U) } class SimpleConnector(inType: Data, outType: Data) extends Module { @@ -95,10 +82,10 @@ class CastFromNonLitWidth(w: Option[Int] = None) extends Module { io.out := EnumExample.fromBits(io.in) } -class EnumOps(xType: EnumType, yType: EnumType) extends Module { +class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module { val io = IO(new Bundle { - val x = Input(xType) - val y = Input(yType) + val x = Input(xType()) + val y = Input(yType()) val lt = Output(Bool()) val le = Output(Bool()) @@ -117,11 +104,10 @@ class EnumOps(xType: EnumType, yType: EnumType) extends Module { } object StrongEnumFSM { - class State extends EnumType - object State extends StrongEnum[State] { + object State extends ChiselEnum { val sNone, sOne1, sTwo1s = Value - val correct_annotation_map = Map[String, UInt]("sNone" -> 0.U(2.W), "sOne1" -> 1.U(2.W), "sTwo1s" -> 2.U(2.W)) + val correct_annotation_map = Map[String, BigInt]("sNone" -> 0, "sOne1" -> 1, "sTwo1s" -> 2) } } @@ -210,7 +196,7 @@ class CastToInvalidEnumTester extends BasicTester { class EnumOpsTester extends BasicTester { for (x <- EnumExample.all; y <- EnumExample.all) { - val mod = Module(new EnumOps(EnumExample(), EnumExample())) + val mod = Module(new EnumOps(EnumExample, EnumExample)) mod.io.x := x mod.io.y := y @@ -225,7 +211,7 @@ class EnumOpsTester extends BasicTester { } class InvalidEnumOpsTester extends BasicTester { - val mod = Module(new EnumOps(EnumExample(), OtherEnum())) + val mod = Module(new EnumOps(EnumExample, OtherEnum)) mod.io.x := EnumExample.e0 mod.io.y := OtherEnum.otherEnum } @@ -268,7 +254,7 @@ class StrongEnumFSMTester extends BasicTester { // Inputs and expected results val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B) val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B) - val expected_state: Vec[State] = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s) + val expected_state = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s) val cntr = Counter(inputs.length) val cycle = cntr.value @@ -283,32 +269,19 @@ class StrongEnumFSMTester extends BasicTester { } class StrongEnumSpec extends ChiselFlatSpec { - import chisel3.core.EnumExceptions._ import chisel3.internal.ChiselException behavior of "Strong enum tester" - it should "fail to instantiate enums without a companion class" in { - an [EnumHasNoCompanionObjectException] should be thrownBy { - elaborate(new SimpleConnector(new EnumWithoutCompanionObj(), new EnumWithoutCompanionObj())) - } - } - - it should "fail to instantiate non-literal enums in a companion object" in { - an [ExceptionInInitializerError] should be thrownBy { - elaborate(new SimpleConnector(new NonLiteralEnumType(), new NonLiteralEnumType())) - } - } - - it should "fail to call isValid early" in { + it should "fail to instantiate non-literal enums with the Value function" in { an [ExceptionInInitializerError] should be thrownBy { - elaborate(new SimpleConnector(EnumWithEarlyIsValid(), EnumWithEarlyIsValid())) + elaborate(new SimpleConnector(NonLiteralEnumType(), NonLiteralEnumType())) } } - it should "fail to call next early" in { + it should "fail to instantiate non-increasing enums with the Value function" in { an [ExceptionInInitializerError] should be thrownBy { - elaborate(new SimpleConnector(EnumWithEarlyNext(), EnumWithEarlyNext())) + elaborate(new SimpleConnector(NonIncreasingEnum(), NonIncreasingEnum())) } } @@ -366,7 +339,7 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "fail to compare enums of different types" in { - an [EnumTypeMismatchException] should be thrownBy { + a [ChiselException] should be thrownBy { elaborate(new InvalidEnumOpsTester) } } @@ -383,6 +356,13 @@ class StrongEnumSpec extends ChiselFlatSpec { assertTesterPasses(new WidthTester) } + it should "maintain Scala-level type-safety" in { + def foo(e: EnumExample.E) = {} + + "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile + "foo(OtherEnum.otherEnum)" shouldNot compile + } + "StrongEnum FSM" should "work" in { assertTesterPasses(new StrongEnumFSMTester) } @@ -401,6 +381,16 @@ class StrongEnumAnnotationSpec extends FreeSpec with Matchers { val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + // Print the annotations out onto the screen + println("Enum definitions:") + enumDefAnnos.foreach { + case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition") + } + println("Enum components:") + enumCompAnnos.foreach{ + case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName") + } + // Check that the global annotation is correct enumDefAnnos.exists { case EnumDefAnnotation(name, map) => @@ -409,11 +399,7 @@ class StrongEnumAnnotationSpec extends FreeSpec with Matchers { map.forall { case (k, v) => val correctValue = StrongEnumFSM.State.correct_annotation_map(k) - - val correctValLit = correctValue.litValue() - val vLitValue = v.litValue() - - correctValue.getWidth == v.getWidth && correctValue.litValue() == v.litValue() + correctValue == v } case _ => false } should be(true) diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala index 688cfecb6bc..e71494f1814 100644 --- a/src/test/scala/cookbook/FSM.scala +++ b/src/test/scala/cookbook/FSM.scala @@ -11,20 +11,17 @@ import chisel3.util._ * control logic */ -object DetectTwoOnes { - class State extends EnumType - object State extends StrongEnum[State] { - val sNone, sOne1, sTwo1s = Value - } -} - class DetectTwoOnes extends Module { val io = IO(new Bundle { val in = Input(Bool()) val out = Output(Bool()) }) - import DetectTwoOnes.State._ + object State extends ChiselEnum { + val sNone, sOne1, sTwo1s = Value + } + import State._ + val state = RegInit(sNone) io.out := (state === sTwo1s) From 323c20cca8aa3870249591609552db2a8cc8e45e Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Thu, 27 Sep 2018 20:32:36 -0700 Subject: [PATCH 12/13] * Added scope protections to ChiselEnum.Value so that users cannot call it outside of a ChiselEnum definition * Renamed ChiselEnum.Value type to ChiselEnum.Type so that we can give it a companion object just like UInt and Bool do --- .../main/scala/chisel3/core/StrongEnum.scala | 29 ++++++++++--------- src/test/scala/chiselTests/StrongEnum.scala | 9 ++++-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index eb2d7d932e8..3ef8ba1d914 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -164,14 +164,17 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea abstract class EnumFactory { - class Value extends EnumType(this) + class Type extends EnumType(this) + object Type { + def apply(): Type = EnumFactory.this.apply() + } var id: BigInt = 0 var width: Width = 0.W val enum_names = mutable.ArrayBuffer.empty[String] val enum_values = mutable.ArrayBuffer.empty[BigInt] - val enum_instances = mutable.ArrayBuffer.empty[Value] + val enum_instances = mutable.ArrayBuffer.empty[Type] private[core] def globalAnnotation: EnumDefChiselAnnotation = EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values).zipped.toMap) @@ -180,13 +183,13 @@ abstract class EnumFactory { def getWidth: Int = width.get - def all: Seq[Value] = enum_instances.toSeq + def all: Seq[Type] = enum_instances.toSeq - def Value: Value = macro EnumMacros.ValImpl - def Value(id: UInt): Value = macro EnumMacros.ValCustomImpl + protected def Value: Type = macro EnumMacros.ValImpl + protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl - def do_Value(names: Seq[String]): Value = { - val result = new Value + protected def do_Value(names: Seq[String]): Type = { + val result = new Type enum_names ++= names.filter(!enum_names.contains(_)) enum_instances.append(result) enum_values.append(id) @@ -200,7 +203,7 @@ abstract class EnumFactory { result } - def do_Value(names: Seq[String], id: UInt): Value = { + protected def do_Value(names: Seq[String], id: UInt): Type = { // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just // throw an exception if (!id.litOption.isDefined) @@ -212,9 +215,9 @@ abstract class EnumFactory { do_Value(names) } - def apply(): Value = new Value + def apply(): Type = new Type - def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Value = { + def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { if (n.litOption.isEmpty) { throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use fromBits instead") } @@ -228,7 +231,7 @@ abstract class EnumFactory { } } - def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Value = { + def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { if (n.litOption.isDefined) { apply(n) } else if (!n.isWidthKnown) { @@ -240,7 +243,7 @@ abstract class EnumFactory { val glue = Wire(new UnsafeEnum(width)) glue := n - val result = Wire(new Value) + val result = Wire(new Type) result := glue result } @@ -248,7 +251,7 @@ abstract class EnumFactory { } -object EnumMacros { +private[core] object EnumMacros { def ValImpl(c: Context) : c.Tree = { import c.universe._ val names = getNames(c) diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 49a6c342499..6b9dd3af0b8 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -287,6 +287,7 @@ class StrongEnumSpec extends ChiselFlatSpec { it should "connect enums of the same type" in { elaborate(new SimpleConnector(EnumExample(), EnumExample())) + elaborate(new SimpleConnector(EnumExample(), EnumExample.Type())) } it should "fail to connect a strong enum to a UInt" in { @@ -296,9 +297,13 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "fail to connect enums of different types" in { - an [ChiselException] should be thrownBy { + a [ChiselException] should be thrownBy { elaborate(new SimpleConnector(EnumExample(), OtherEnum())) } + + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample.Type(), OtherEnum.Type())) + } } it should "cast enums to UInts correctly" in { @@ -357,7 +362,7 @@ class StrongEnumSpec extends ChiselFlatSpec { } it should "maintain Scala-level type-safety" in { - def foo(e: EnumExample.Value) = {} + def foo(e: EnumExample.Type) = {} "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile "foo(OtherEnum.otherEnum)" shouldNot compile From 8e1344d1b82e1c43a7bfeffea4168ee252d4879f Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Fri, 12 Oct 2018 00:10:46 -0700 Subject: [PATCH 13/13] * Moved strong enums into experimental package * Non-literal UInts can now be cast to enums with apply() rather than fromBits() * Reduced code-duplication by moving some functions from EnumType and Bits to Element --- .../src/main/scala/chisel3/core/Bits.scala | 60 +++++------ .../main/scala/chisel3/core/StrongEnum.scala | 100 +++++------------- src/main/scala/chisel3/package.scala | 6 +- src/test/scala/chiselTests/StrongEnum.scala | 10 +- src/test/scala/cookbook/FSM.scala | 20 ++-- 5 files changed, 77 insertions(+), 119 deletions(-) diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala index 40b2bc77991..9356a91cc6d 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala @@ -20,6 +20,10 @@ import chisel3.internal.firrtl.PrimOp._ * @define coll element */ abstract class Element extends Data { + private[chisel3] final def allElements: Seq[Element] = Seq(this) + def widthKnown: Boolean = width.known + def name: String = getRef.name + private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) { binding = target val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection) @@ -30,9 +34,32 @@ abstract class Element extends Data { } } - private[chisel3] final def allElements: Seq[Element] = Seq(this) - def widthKnown: Boolean = width.known - def name: String = getRef.name + private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { + // Translate Bundle lit bindings to Element lit bindings + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => Some(ElementLitBinding(litArg)) + case _ => Some(DontCareBinding()) + } + case topBindingOpt => topBindingOpt + } + + private[core] def litArgOption: Option[LitArg] = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => Some(litArg) + case _ => None + } + + override def litOption: Option[BigInt] = litArgOption.map(_.num) + private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) + + // provide bits-specific literal handling functionality here + override private[chisel3] def ref: Arg = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => litArg + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => litArg + case _ => throwException(s"internal error: DontCare should be caught before getting ref") + } + case _ => super.ref + } private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = { // If the source is a DontCare, generate a DefInvalid for the sink, @@ -79,33 +106,6 @@ sealed abstract class Bits(private[chisel3] val width: Width) extends Element wi def cloneType: this.type = cloneTypeWidth(width) - private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { - // Translate Bundle lit bindings to Element lit bindings - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => Some(ElementLitBinding(litArg)) - case _ => Some(DontCareBinding()) - } - case topBindingOpt => topBindingOpt - } - - private[core] def litArgOption: Option[LitArg] = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => Some(litArg) - case _ => None - } - - override def litOption: Option[BigInt] = litArgOption.map(_.num) - private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) - - // provide bits-specific literal handling functionality here - override private[chisel3] def ref: Arg = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => litArg - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => litArg - case _ => throwException(s"internal error: DontCare should be caught before getting ref") - } - case _ => super.ref - } - /** Tail operator * * @param n the number of bits to remove diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index 3ef8ba1d914..a9f513872e9 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -35,33 +35,6 @@ import EnumAnnotations._ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element { override def cloneType: this.type = factory().asInstanceOf[this.type] - private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { - // Translate Bundle lit bindings to Element lit bindings - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => Some(ElementLitBinding(litArg)) - case _ => Some(DontCareBinding()) - } - case topBindingOpt => topBindingOpt - } - - private[core] def litArgOption: Option[LitArg] = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => Some(litArg) - case _ => None - } - - override def litOption: Option[BigInt] = litArgOption.map(_.num) - private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) - - // provide bits-specific literal handling functionality here - override private[chisel3] def ref: Arg = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => litArg - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => litArg - case _ => throwException(s"internal error: DontCare should be caught before getting ref") - } - case _ => super.ref - } - private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = { requireIsHardware(this, "bits operated on") requireIsHardware(other, "bits operated on") @@ -80,9 +53,7 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we // inherit it from Data. private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, - compileOptions: CompileOptions): Unit = { - this := that.asUInt - } + compileOptions: CompileOptions): Unit = ??? final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg @@ -107,12 +78,7 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea if (litOption.isDefined) { true.B } else { - def muxBuilder(enums: List[EnumType]): Bool = enums match { - case Nil => false.B - case e :: es => Mux(this === e, true.B, muxBuilder(es)) - } - - muxBuilder(factory.all.toList) + factory.all.map(this === _).reduce(_ || _) } } @@ -125,12 +91,9 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea else factory.all.head.asInstanceOf[this.type] } else { - def muxBuilder(enums: List[EnumType], first_enum: EnumType): EnumType = enums match { - case e :: Nil => first_enum - case e :: e_next :: es => Mux(this === e, e_next, muxBuilder(e_next :: es, first_enum)) - } - - muxBuilder(factory.all.toList, factory.all.head).asInstanceOf[this.type] + val enums_with_nexts = factory.all zip (factory.all.tail :+ factory.all.head) + val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e,n) => (this === e, n) } ) + next_enum.asInstanceOf[this.type] } } @@ -144,7 +107,7 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. // To workaround that, we only annotate enums that are not bound to literals. - if (selfAnnotating && !litOption.isDefined) { + if (selfAnnotating && litOption.isEmpty) { annotateEnum() } } @@ -169,34 +132,37 @@ abstract class EnumFactory { def apply(): Type = EnumFactory.this.apply() } - var id: BigInt = 0 - var width: Width = 0.W + private var id: BigInt = 0 + private[core] var width: Width = 0.W - val enum_names = mutable.ArrayBuffer.empty[String] - val enum_values = mutable.ArrayBuffer.empty[BigInt] - val enum_instances = mutable.ArrayBuffer.empty[Type] + private case class EnumRecord(inst: Type, name: String) + private val enum_records = mutable.ArrayBuffer.empty[EnumRecord] - private[core] def globalAnnotation: EnumDefChiselAnnotation = - EnumDefChiselAnnotation(enumTypeName, (enum_names, enum_values).zipped.toMap) + private def enumNames = enum_records.map(_.name).toSeq + private def enumValues = enum_records.map(_.inst.litValue()).toSeq + private def enumInstances = enum_records.map(_.inst).toSeq private[core] val enumTypeName = getClass.getName.init + private[core] def globalAnnotation: EnumDefChiselAnnotation = + EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap) + def getWidth: Int = width.get - def all: Seq[Type] = enum_instances.toSeq + def all: Seq[Type] = enumInstances protected def Value: Type = macro EnumMacros.ValImpl protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl protected def do_Value(names: Seq[String]): Type = { val result = new Type - enum_names ++= names.filter(!enum_names.contains(_)) - enum_instances.append(result) - enum_values.append(id) // We have to use UnknownWidth here, because we don't actually know what the final width will be result.bindToLiteral(id, UnknownWidth()) + val result_name = names.find(!enumNames.contains(_)).get + enum_records.append(EnumRecord(result, result_name)) + width = (1 max id.bitLength).W id += 1 @@ -206,7 +172,7 @@ abstract class EnumFactory { protected def do_Value(names: Seq[String], id: UInt): Type = { // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just // throw an exception - if (!id.litOption.isDefined) + if (id.litOption.isEmpty) throwException(s"$enumTypeName defined with a non-literal type") if (id.litValue() < this.id) throwException(s"Enums must be strictly increasing: $enumTypeName") @@ -218,28 +184,20 @@ abstract class EnumFactory { def apply(): Type = new Type def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { - if (n.litOption.isEmpty) { - throwException(s"Illegal cast from non-literal UInt to $enumTypeName. Use fromBits instead") - } - - val result = enum_instances.find(_.litValue == n.litValue) - - if (result.isEmpty) { - throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") - } else { - result.get - } - } - - def fromBits(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { if (n.litOption.isDefined) { - apply(n) + val result = enumInstances.find(_.litValue == n.litValue) + + if (result.isEmpty) { + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } else { + result.get + } } else if (!n.isWidthKnown) { throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") } else if (n.getWidth > this.getWidth) { throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") } else { - Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that the value is legal by calling isValid") + Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid") val glue = Wire(new UnsafeEnum(width)) glue := n diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala index f4e62d3de72..e79a11867dc 100644 --- a/src/main/scala/chisel3/package.scala +++ b/src/main/scala/chisel3/package.scala @@ -243,9 +243,6 @@ package object chisel3 { // scalastyle:ignore package.object.name object Bool extends BoolFactory val Mux = chisel3.core.Mux - type ChiselEnum = chisel3.core.EnumFactory - val EnumAnnotations = chisel3.core.EnumAnnotations - type BlackBox = chisel3.core.BlackBox type InstanceId = chisel3.internal.InstanceId @@ -423,6 +420,9 @@ package object chisel3 { // scalastyle:ignore package.object.name val Analog = chisel3.core.Analog val attach = chisel3.core.attach + type ChiselEnum = chisel3.core.EnumFactory + val EnumAnnotations = chisel3.core.EnumAnnotations + val withClockAndReset = chisel3.core.withClockAndReset val withClock = chisel3.core.withClock val withReset = chisel3.core.withReset diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index 6b9dd3af0b8..982866244ec 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -3,11 +3,10 @@ package chiselTests import chisel3._ -import chisel3.core.EnumAnnotations +import chisel3.experimental.ChiselEnum import chisel3.internal.firrtl.UnknownWidth import chisel3.util._ import chisel3.testers.BasicTester -import firrtl.annotations.ComponentName import org.scalatest.{FreeSpec, Matchers} object EnumExample extends ChiselEnum { @@ -67,7 +66,7 @@ class CastFromNonLit extends Module { val valid = Output(Bool()) }) - io.out := EnumExample.fromBits(io.in) + io.out := EnumExample(io.in) io.valid := io.out.isValid } @@ -79,7 +78,7 @@ class CastFromNonLitWidth(w: Option[Int] = None) extends Module { val out = Output(EnumExample()) }) - io.out := EnumExample.fromBits(io.in) + io.out := EnumExample(io.in) } class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module { @@ -374,7 +373,8 @@ class StrongEnumSpec extends ChiselFlatSpec { } class StrongEnumAnnotationSpec extends FreeSpec with Matchers { - import EnumAnnotations._ + import chisel3.experimental.EnumAnnotations._ + import firrtl.annotations.ComponentName "Test that strong enums annotate themselves appropriately" in { diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala index e71494f1814..170d110ff97 100644 --- a/src/test/scala/cookbook/FSM.scala +++ b/src/test/scala/cookbook/FSM.scala @@ -4,6 +4,7 @@ package cookbook import chisel3._ import chisel3.util._ +import chisel3.experimental.ChiselEnum /* ### How do I create a finite state machine? * @@ -20,28 +21,27 @@ class DetectTwoOnes extends Module { object State extends ChiselEnum { val sNone, sOne1, sTwo1s = Value } - import State._ - val state = RegInit(sNone) + val state = RegInit(State.sNone) - io.out := (state === sTwo1s) + io.out := (state === State.sTwo1s) switch (state) { - is (sNone) { + is (State.sNone) { when (io.in) { - state := sOne1 + state := State.sOne1 } } - is (sOne1) { + is (State.sOne1) { when (io.in) { - state := sTwo1s + state := State.sTwo1s } .otherwise { - state := sNone + state := State.sNone } } - is (sTwo1s) { + is (State.sTwo1s) { when (!io.in) { - state := sNone + state := State.sNone } } }