diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala index e9458446a43..9356a91cc6d 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala @@ -19,7 +19,11 @@ import chisel3.internal.firrtl.PrimOp._ * * @define coll element */ -abstract class Element(private[chisel3] val width: Width) extends Data { +abstract class Element extends Data { + private[chisel3] final def allElements: Seq[Element] = Seq(this) + def widthKnown: Boolean = width.known + def name: String = getRef.name + private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) { binding = target val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection) @@ -30,9 +34,32 @@ abstract class Element(private[chisel3] val width: Width) extends Data { } } - private[chisel3] final def allElements: Seq[Element] = Seq(this) - def widthKnown: Boolean = width.known - def name: String = getRef.name + private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { + // Translate Bundle lit bindings to Element lit bindings + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => Some(ElementLitBinding(litArg)) + case _ => Some(DontCareBinding()) + } + case topBindingOpt => topBindingOpt + } + + private[core] def litArgOption: Option[LitArg] = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => Some(litArg) + case _ => None + } + + override def litOption: Option[BigInt] = litArgOption.map(_.num) + private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) + + // provide bits-specific literal handling functionality here + override private[chisel3] def ref: Arg = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => litArg + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => litArg + case _ => throwException(s"internal error: DontCare should be caught before getting ref") + } + case _ => super.ref + } private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = { // If the source is a DontCare, generate a DefInvalid for the sink, @@ -69,7 +96,7 @@ private[chisel3] sealed trait ToBoolable extends Element { * @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`. * @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`. */ -sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods +sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods // TODO: perhaps make this concrete? // Arguments for: self-checking code (can't do arithmetic on bits) // Arguments against: generates down to a FIRRTL UInt anyways @@ -79,33 +106,6 @@ sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable def cloneType: this.type = cloneTypeWidth(width) - private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { - // Translate Bundle lit bindings to Element lit bindings - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => Some(ElementLitBinding(litArg)) - case _ => Some(DontCareBinding()) - } - case topBindingOpt => topBindingOpt - } - - private[core] def litArgOption: Option[LitArg] = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => Some(litArg) - case _ => None - } - - override def litOption: Option[BigInt] = litArgOption.map(_.num) - private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) - - // provide bits-specific literal handling functionality here - override private[chisel3] def ref: Arg = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => litArg - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => litArg - case _ => throwException(s"internal error: DontCare should be caught before getting ref") - } - case _ => super.ref - } - /** Tail operator * * @param n the number of bits to remove @@ -1693,7 +1693,7 @@ object FixedPoint { * * @note This API is experimental and subject to change */ -final class Analog private (width: Width) extends Element(width) { +final class Analog private (private[chisel3] val width: Width) extends Element { require(width.known, "Since Analog is only for use in BlackBoxes, width must be known") private[core] override def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala index b728075b954..88208d9a6c0 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala @@ -12,7 +12,7 @@ object Clock { } // TODO: Document this. -sealed class Clock extends Element(Width(1)) { +sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element { def cloneType: this.type = Clock().asInstanceOf[this.type] private[core] def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index 869e22fb111..f292d3c673b 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -533,10 +533,12 @@ object WireInit { /** RHS (source) for Invalidate API. * Causes connection logic to emit a DefInvalid when connected to an output port (or wire). */ -object DontCare extends Element(width = UnknownWidth()) { +object DontCare extends Element { // This object should be initialized before we execute any user code that refers to it, // otherwise this "Chisel" object will end up on the UserModule's id list. + private[chisel3] override val width: Width = UnknownWidth() + bind(DontCareBinding(), SpecifiedDirection.Output) override def cloneType = DontCare diff --git a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala index eba248709d9..c9420ba70de 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala @@ -79,6 +79,12 @@ object MonoConnect { elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: Clock, source_e: Clock) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: UnsafeEnum) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: UnsafeEnum, source_e: UInt) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) // Handle Vec case case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) => diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala new file mode 100644 index 00000000000..a9f513872e9 --- /dev/null +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -0,0 +1,248 @@ +// See LICENSE for license details. + +package chisel3.core + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox.Context +import scala.collection.mutable + +import chisel3.internal.Builder.pushOp +import chisel3.internal.firrtl.PrimOp._ +import chisel3.internal.firrtl._ +import chisel3.internal.sourceinfo._ +import chisel3.internal.{Builder, InstanceId, throwException} +import firrtl.annotations._ + + +object EnumAnnotations { + case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] { + def duplicate(n: Named) = this.copy(target = n) + } + + case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation { + def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName) + } + + case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation + + case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation { + override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition) + } +} +import EnumAnnotations._ + + +abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element { + override def cloneType: this.type = factory().asInstanceOf[this.type] + + private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = { + requireIsHardware(this, "bits operated on") + requireIsHardware(other, "bits operated on") + + if(!this.typeEquivalent(other)) + throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}") + + pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref)) + } + + private[core] override def typeEquivalent(that: Data): Boolean = { + this.getClass == that.getClass && + this.factory == that.asInstanceOf[EnumType].factory + } + + // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we + // inherit it from Data. + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = ??? + + final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + + def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that) + def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that) + def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that) + def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that) + def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that) + def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that) + + override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = + pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref)) + + protected[chisel3] override def width: Width = factory.width + + def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { + if (litOption.isDefined) { + true.B + } else { + factory.all.map(this === _).reduce(_ || _) + } + } + + def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + if (litOption.isDefined) { + val index = factory.all.indexOf(this) + + if (index < factory.all.length-1) + factory.all(index+1).asInstanceOf[this.type] + else + factory.all.head.asInstanceOf[this.type] + } else { + val enums_with_nexts = factory.all zip (factory.all.tail :+ factory.all.head) + val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e,n) => (this === e, n) } ) + next_enum.asInstanceOf[this.type] + } + } + + private[core] def bindToLiteral(num: BigInt, w: Width): Unit = { + val lit = ULit(num, w) + lit.bindLitArg(this) + } + + override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = { + super.bind(target, parentDirection) + + // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. + // To workaround that, we only annotate enums that are not bound to literals. + if (selfAnnotating && litOption.isEmpty) { + annotateEnum() + } + } + + private def annotateEnum(): Unit = { + annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + + if (!Builder.annotations.contains(factory.globalAnnotation)) { + annotate(factory.globalAnnotation) + } + } + + protected def enumTypeName: String = factory.enumTypeName + + def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer +} + + +abstract class EnumFactory { + class Type extends EnumType(this) + object Type { + def apply(): Type = EnumFactory.this.apply() + } + + private var id: BigInt = 0 + private[core] var width: Width = 0.W + + private case class EnumRecord(inst: Type, name: String) + private val enum_records = mutable.ArrayBuffer.empty[EnumRecord] + + private def enumNames = enum_records.map(_.name).toSeq + private def enumValues = enum_records.map(_.inst.litValue()).toSeq + private def enumInstances = enum_records.map(_.inst).toSeq + + private[core] val enumTypeName = getClass.getName.init + + private[core] def globalAnnotation: EnumDefChiselAnnotation = + EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap) + + def getWidth: Int = width.get + + def all: Seq[Type] = enumInstances + + protected def Value: Type = macro EnumMacros.ValImpl + protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl + + protected def do_Value(names: Seq[String]): Type = { + val result = new Type + + // We have to use UnknownWidth here, because we don't actually know what the final width will be + result.bindToLiteral(id, UnknownWidth()) + + val result_name = names.find(!enumNames.contains(_)).get + enum_records.append(EnumRecord(result, result_name)) + + width = (1 max id.bitLength).W + id += 1 + + result + } + + protected def do_Value(names: Seq[String], id: UInt): Type = { + // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just + // throw an exception + if (id.litOption.isEmpty) + throwException(s"$enumTypeName defined with a non-literal type") + if (id.litValue() < this.id) + throwException(s"Enums must be strictly increasing: $enumTypeName") + + this.id = id.litValue() + do_Value(names) + } + + def apply(): Type = new Type + + def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { + if (n.litOption.isDefined) { + val result = enumInstances.find(_.litValue == n.litValue) + + if (result.isEmpty) { + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } else { + result.get + } + } else if (!n.isWidthKnown) { + throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") + } else if (n.getWidth > this.getWidth) { + throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") + } else { + Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid") + + val glue = Wire(new UnsafeEnum(width)) + glue := n + val result = Wire(new Type) + result := glue + result + } + } +} + + +private[core] object EnumMacros { + def ValImpl(c: Context) : c.Tree = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names))""" + } + + def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names), $id)""" + } + + // Much thanks to Travis Brown for this solution: + // stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to + def getNames(c: Context): Seq[String] = { + import c.universe._ + + val names = c.enclosingClass.collect { + case ValDef(_, name, _, rhs) + if rhs.pos == c.macroApplication.pos => name.decoded + } + + if (names.isEmpty) + c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum") + + names + } +} + + +// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts +// to enums. +private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) { + override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type] +} +private object UnsafeEnum extends EnumFactory diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala index b6630f7f0f6..ae8b248aaf4 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala @@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg { private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth) override def fullName(ctx: Component): String = name // Ensure the node representing this LitArg has a ref to it and a literal binding. - def bindLitArg[T <: Bits](bits: T): T = { - bits.bind(ElementLitBinding(this)) - bits.setRef(this) - bits + def bindLitArg[T <: Element](elem: T): T = { + elem.bind(ElementLitBinding(this)) + elem.setRef(this) + elem } protected def minWidth: Int diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala index 97504aba19a..181bdfe8b72 100644 --- a/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo} import firrtl.{ir => fir} @@ -211,6 +211,7 @@ private[chisel3] object Converter { def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match { case _: Clock => fir.ClockType + case d: EnumType => fir.UIntType(convert(d.width)) case d: UInt => fir.UIntType(convert(d.width)) case d: SInt => fir.SIntType(convert(d.width)) case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint)) diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala index 26ccc09d62a..ac4bf8e7d0f 100644 --- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine} @@ -28,6 +28,7 @@ private class Emitter(circuit: Circuit) { private def emitType(d: Data, clearDir: Boolean = false): String = d match { case d: Clock => "Clock" + case d: chisel3.core.EnumType => s"UInt${d.width}" case d: UInt => s"UInt${d.width}" case d: SInt => s"SInt${d.width}" case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}" diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala index b7c39bad9e7..e79a11867dc 100644 --- a/src/main/scala/chisel3/package.scala +++ b/src/main/scala/chisel3/package.scala @@ -420,6 +420,9 @@ package object chisel3 { // scalastyle:ignore package.object.name val Analog = chisel3.core.Analog val attach = chisel3.core.attach + type ChiselEnum = chisel3.core.EnumFactory + val EnumAnnotations = chisel3.core.EnumAnnotations + val withClockAndReset = chisel3.core.withClockAndReset val withClock = chisel3.core.withClock val withReset = chisel3.core.withReset diff --git a/src/main/scala/chisel3/util/Conditional.scala b/src/main/scala/chisel3/util/Conditional.scala index bf2d4268717..3630f8add27 100644 --- a/src/main/scala/chisel3/util/Conditional.scala +++ b/src/main/scala/chisel3/util/Conditional.scala @@ -24,7 +24,7 @@ object unless { // scalastyle:ignore object.name * user-facing API. * @note DO NOT USE. This API is subject to change without warning. */ -class SwitchContext[T <: Bits](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { +class SwitchContext[T <: Element](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { def is(v: Iterable[T])(block: => Unit): SwitchContext[T] = { if (!v.isEmpty) { val newLits = v.map { w => @@ -60,19 +60,19 @@ object is { // scalastyle:ignore object.name // TODO: Begin deprecation of non-type-parameterized is statements. /** Executes `block` if the switch condition is equal to any of the values in `v`. */ - def apply(v: Iterable[Bits])(block: => Unit) { + def apply(v: Iterable[Element])(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to `v`. */ - def apply(v: Bits)(block: => Unit) { + def apply(v: Element)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to any of the values in the argument list. */ - def apply(v: Bits, vr: Bits*)(block: => Unit) { + def apply(v: Element, vr: Element*)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } } @@ -91,7 +91,7 @@ object is { // scalastyle:ignore object.name * }}} */ object switch { // scalastyle:ignore object.name - def apply[T <: Bits](cond: T)(x: => Unit): Unit = macro impl + def apply[T <: Element](cond: T)(x: => Unit): Unit = macro impl def impl(c: Context)(cond: c.Tree)(x: c.Tree): c.Tree = { import c.universe._ val q"..$body" = x val res = body.foldLeft(q"""new SwitchContext($cond, None, Set.empty)""") { diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala new file mode 100644 index 00000000000..982866244ec --- /dev/null +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -0,0 +1,430 @@ +// See LICENSE for license details. + +package chiselTests + +import chisel3._ +import chisel3.experimental.ChiselEnum +import chisel3.internal.firrtl.UnknownWidth +import chisel3.util._ +import chisel3.testers.BasicTester +import org.scalatest.{FreeSpec, Matchers} + +object EnumExample extends ChiselEnum { + val e0, e1, e2 = Value + + val e100 = Value(100.U) + val e101 = Value(101.U) + + val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U) +} + +object OtherEnum extends ChiselEnum { + val otherEnum = Value +} + +object NonLiteralEnumType extends ChiselEnum { + val nonLit = Value(UInt()) +} + +object NonIncreasingEnum extends ChiselEnum { + val x = Value(2.U) + val y = Value(2.U) +} + +class SimpleConnector(inType: Data, outType: Data) extends Module { + val io = IO(new Bundle { + val in = Input(inType) + val out = Output(outType) + }) + + io.out := io.in +} + +class CastToUInt extends Module { + val io = IO(new Bundle { + val in = Input(EnumExample()) + val out = Output(UInt()) + }) + + io.out := io.in.asUInt() +} + +class CastFromLit(in: UInt) extends Module { + val io = IO(new Bundle { + val out = Output(EnumExample()) + val valid = Output(Bool()) + }) + + io.out := EnumExample(in) + io.valid := io.out.isValid +} + +class CastFromNonLit extends Module { + val io = IO(new Bundle { + val in = Input(UInt(EnumExample.getWidth.W)) + val out = Output(EnumExample()) + val valid = Output(Bool()) + }) + + io.out := EnumExample(io.in) + io.valid := io.out.isValid +} + +class CastFromNonLitWidth(w: Option[Int] = None) extends Module { + val width = if (w.isDefined) w.get.W else UnknownWidth() + + override val io = IO(new Bundle { + val in = Input(UInt(width)) + val out = Output(EnumExample()) + }) + + io.out := EnumExample(io.in) +} + +class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module { + val io = IO(new Bundle { + val x = Input(xType()) + val y = Input(yType()) + + val lt = Output(Bool()) + val le = Output(Bool()) + val gt = Output(Bool()) + val ge = Output(Bool()) + val eq = Output(Bool()) + val ne = Output(Bool()) + }) + + io.lt := io.x < io.y + io.le := io.x <= io.y + io.gt := io.x > io.y + io.ge := io.x >= io.y + io.eq := io.x === io.y + io.ne := io.x =/= io.y +} + +object StrongEnumFSM { + object State extends ChiselEnum { + val sNone, sOne1, sTwo1s = Value + + val correct_annotation_map = Map[String, BigInt]("sNone" -> 0, "sOne1" -> 1, "sTwo1s" -> 2) + } +} + +class StrongEnumFSM extends Module { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + + // This FSM detects two 1's one after the other + val io = IO(new Bundle { + val in = Input(Bool()) + val out = Output(Bool()) + val state = Output(State()) + }) + + val state = RegInit(sNone) + + io.out := (state === sTwo1s) + io.state := state + + switch (state) { + is (sNone) { + when (io.in) { + state := sOne1 + } + } + is (sOne1) { + when (io.in) { + state := sTwo1s + } .otherwise { + state := sNone + } + } + is (sTwo1s) { + when (!io.in) { + state := sNone + } + } + } +} + +class CastToUIntTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastToUInt) + mod.io.in := enum + assert(mod.io.out === lit) + } + stop() +} + +class CastFromLitTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastFromLit(lit)) + assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + stop() +} + +class CastFromNonLitTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastFromNonLit) + mod.io.in := lit + assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + + val invalid_values = (1 until (1 << EnumExample.getWidth)). + filter(!EnumExample.litValues.map(_.litValue).contains(_)). + map(_.U) + + for (invalid_val <- invalid_values) { + val mod = Module(new CastFromNonLit) + mod.io.in := invalid_val + + assert(mod.io.valid === false.B) + } + + stop() +} + +class CastToInvalidEnumTester extends BasicTester { + val invalid_value: UInt = EnumExample.litValues.last + 1.U + Module(new CastFromLit(invalid_value)) +} + +class EnumOpsTester extends BasicTester { + for (x <- EnumExample.all; + y <- EnumExample.all) { + val mod = Module(new EnumOps(EnumExample, EnumExample)) + mod.io.x := x + mod.io.y := y + + assert(mod.io.lt === (x.asUInt() < y.asUInt())) + assert(mod.io.le === (x.asUInt() <= y.asUInt())) + assert(mod.io.gt === (x.asUInt() > y.asUInt())) + assert(mod.io.ge === (x.asUInt() >= y.asUInt())) + assert(mod.io.eq === (x.asUInt() === y.asUInt())) + assert(mod.io.ne === (x.asUInt() =/= y.asUInt())) + } + stop() +} + +class InvalidEnumOpsTester extends BasicTester { + val mod = Module(new EnumOps(EnumExample, OtherEnum)) + mod.io.x := EnumExample.e0 + mod.io.y := OtherEnum.otherEnum +} + +class IsLitTester extends BasicTester { + for (e <- EnumExample.all) { + val wire = WireInit(e) + + assert(e.isLit()) + assert(!wire.isLit()) + } + stop() +} + +class NextTester extends BasicTester { + for ((e,n) <- EnumExample.all.zip(EnumExample.litValues.tail :+ EnumExample.litValues.head)) { + assert(e.next.litValue == n.litValue) + val w = WireInit(e) + assert(w.next === EnumExample(n)) + } + stop() +} + +class WidthTester extends BasicTester { + assert(EnumExample.getWidth == EnumExample.litValues.last.getWidth) + assert(EnumExample.all.forall(_.getWidth == EnumExample.litValues.last.getWidth)) + assert(EnumExample.all.forall{e => + val w = WireInit(e) + w.getWidth == EnumExample.litValues.last.getWidth + }) + stop() +} + +class StrongEnumFSMTester extends BasicTester { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + + val dut = Module(new StrongEnumFSM) + + // Inputs and expected results + val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B) + val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B) + val expected_state = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s) + + val cntr = Counter(inputs.length) + val cycle = cntr.value + + dut.io.in := inputs(cycle) + assert(dut.io.out === expected(cycle)) + assert(dut.io.state === expected_state(cycle)) + + when(cntr.inc()) { + stop() + } +} + +class StrongEnumSpec extends ChiselFlatSpec { + import chisel3.internal.ChiselException + + behavior of "Strong enum tester" + + it should "fail to instantiate non-literal enums with the Value function" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(NonLiteralEnumType(), NonLiteralEnumType())) + } + } + + it should "fail to instantiate non-increasing enums with the Value function" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(NonIncreasingEnum(), NonIncreasingEnum())) + } + } + + it should "connect enums of the same type" in { + elaborate(new SimpleConnector(EnumExample(), EnumExample())) + elaborate(new SimpleConnector(EnumExample(), EnumExample.Type())) + } + + it should "fail to connect a strong enum to a UInt" in { + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), UInt())) + } + } + + it should "fail to connect enums of different types" in { + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), OtherEnum())) + } + + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample.Type(), OtherEnum.Type())) + } + } + + it should "cast enums to UInts correctly" in { + assertTesterPasses(new CastToUIntTester) + } + + it should "cast literal UInts to enums correctly" in { + assertTesterPasses(new CastFromLitTester) + } + + it should "cast non-literal UInts to enums correctly and detect illegal casts" in { + assertTesterPasses(new CastFromNonLitTester) + } + + it should "prevent illegal literal casts to enums" in { + a [ChiselException] should be thrownBy { + elaborate(new CastToInvalidEnumTester) + } + } + + it should "only allow non-literal casts to enums if the width is smaller than or equal to the enum width" in { + for (w <- 0 to EnumExample.getWidth) + elaborate(new CastFromNonLitWidth(Some(w))) + + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth) + } + + for (w <- (EnumExample.getWidth+1) to (EnumExample.getWidth+100)) { + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth(Some(w))) + } + } + } + + it should "execute enum comparison operations correctly" in { + assertTesterPasses(new EnumOpsTester) + } + + it should "fail to compare enums of different types" in { + a [ChiselException] should be thrownBy { + elaborate(new InvalidEnumOpsTester) + } + } + + it should "correctly check whether or not enums are literal" in { + assertTesterPasses(new IsLitTester) + } + + it should "return the correct next values for enums" in { + assertTesterPasses(new NextTester) + } + + it should "return the correct widths for enums" in { + assertTesterPasses(new WidthTester) + } + + it should "maintain Scala-level type-safety" in { + def foo(e: EnumExample.Type) = {} + + "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile + "foo(OtherEnum.otherEnum)" shouldNot compile + } + + "StrongEnum FSM" should "work" in { + assertTesterPasses(new StrongEnumFSMTester) + } +} + +class StrongEnumAnnotationSpec extends FreeSpec with Matchers { + import chisel3.experimental.EnumAnnotations._ + import firrtl.annotations.ComponentName + + "Test that strong enums annotate themselves appropriately" in { + + def test() = { + Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { + case ChiselExecutionSuccess(Some(circuit), emitted, _) => + val annos = circuit.annotations.map(_.toFirrtl) + + val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } + val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + + // Print the annotations out onto the screen + println("Enum definitions:") + enumDefAnnos.foreach { + case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition") + } + println("Enum components:") + enumCompAnnos.foreach{ + case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName") + } + + // Check that the global annotation is correct + enumDefAnnos.exists { + case EnumDefAnnotation(name, map) => + name.endsWith("State") && + map.size == StrongEnumFSM.State.correct_annotation_map.size && + map.forall { + case (k, v) => + val correctValue = StrongEnumFSM.State.correct_annotation_map(k) + correctValue == v + } + case _ => false + } should be(true) + + // Check that the component annotations are correct + enumCompAnnos.count { + case EnumComponentAnnotation(target, enumName) => + val ComponentName(targetName, _) = target + (targetName == "state" && enumName.endsWith("State")) || + (targetName == "io.state" && enumName.endsWith("State")) + case _ => false + } should be(2) + + case _ => + assert(false) + } + } + + // We run this test twice, to test for an older bug where only the first circuit would be annotated + test() + test() + } +} diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala index 22cf8059e8f..170d110ff97 100644 --- a/src/test/scala/cookbook/FSM.scala +++ b/src/test/scala/cookbook/FSM.scala @@ -4,39 +4,44 @@ package cookbook import chisel3._ import chisel3.util._ +import chisel3.experimental.ChiselEnum /* ### How do I create a finite state machine? * - * Use Chisel Enum to construct the states and switch & is to construct the FSM + * Use Chisel StrongEnum to construct the states and switch & is to construct the FSM * control logic */ + class DetectTwoOnes extends Module { val io = IO(new Bundle { val in = Input(Bool()) val out = Output(Bool()) }) - val sNone :: sOne1 :: sTwo1s :: Nil = Enum(3) - val state = RegInit(sNone) + object State extends ChiselEnum { + val sNone, sOne1, sTwo1s = Value + } + + val state = RegInit(State.sNone) - io.out := (state === sTwo1s) + io.out := (state === State.sTwo1s) switch (state) { - is (sNone) { + is (State.sNone) { when (io.in) { - state := sOne1 + state := State.sOne1 } } - is (sOne1) { + is (State.sOne1) { when (io.in) { - state := sTwo1s + state := State.sTwo1s } .otherwise { - state := sNone + state := State.sNone } } - is (sTwo1s) { + is (State.sTwo1s) { when (!io.in) { - state := sNone + state := State.sNone } } }