Skip to content

Commit

Permalink
Implement asTypeOf, refactor internal APIs (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
ducky64 committed Feb 15, 2017
1 parent 41bee3f commit 375e2b6
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 137 deletions.
75 changes: 44 additions & 31 deletions chiselFrontend/src/main/scala/chisel3/core/Aggregate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,25 @@ import chisel3.internal.sourceinfo._
* of) other Data objects.
*/
sealed abstract class Aggregate extends Data {
private[core] def cloneTypeWidth(width: Width): this.type = cloneType
private[core] def width: Width = flatten.map(_.width).reduce(_ + _)
/** Returns a Seq of the immediate contents of this Aggregate, in order.
*/
def getElements: Seq[Data]

private[core] def width: Width = getElements.map(_.width).reduce(_ + _)
private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit =
pushCommand(BulkConnect(sourceInfo, this.lref, that.lref))

override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = SeqUtils.do_asUInt(this.flatten)
def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
override def do_asUInt(implicit sourceInfo: SourceInfo): UInt = {
SeqUtils.do_asUInt(getElements.map(_.asUInt()))
}
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = {
var i = 0
val wire = Wire(this.chiselCloneType)
val bits =
if (that.width.known && that.width.get >= wire.width.get) {
that
} else {
Wire(that.cloneTypeWidth(wire.width), init = that)
}
for (x <- wire.flatten) {
x := x.fromBits(bits(i + x.getWidth-1, i))
val bits = Wire(UInt(this.width), init=that) // handles width padding
for (x <- getElements) {
x.connectFromBits(bits(i + x.getWidth - 1, i))
i += x.getWidth
}
wire.asInstanceOf[this.type]
}
}

Expand Down Expand Up @@ -78,19 +77,16 @@ object Vec {

// Check that types are homogeneous. Width mismatch for Elements is safe.
require(!elts.isEmpty)
def eltsCompatible(a: Data, b: Data) = a match {
case _: Element => a.getClass == b.getClass
case _: Aggregate => Mux.typesCompatible(a, b)
}

val t = elts.head
for (e <- elts.tail)
require(eltsCompatible(t, e), s"can't create Vec of heterogeneous types ${t.getClass} and ${e.getClass}")
val vec = Wire(new Vec(cloneSupertype(elts, "Vec"), elts.length))

val maxWidth = elts.map(_.width).reduce(_ max _)
val vec = Wire(new Vec(t.cloneTypeWidth(maxWidth).chiselCloneType, elts.length))
def doConnect(sink: T, source: T) = {
if (elts.head.flatten.exists(_.dir != Direction.Unspecified)) {
// TODO: this looks bad, and should feel bad. Replace with a better abstraction.
val hasDirectioned = vec.sample_element match {
case t: Aggregate => t.flatten.exists(_.dir != Direction.Unspecified)
case t: Element => t.dir != Direction.Unspecified
}
if (hasDirectioned) {
sink bulkConnect source
} else {
sink connect source
Expand Down Expand Up @@ -163,13 +159,20 @@ object Vec {
*/
sealed class Vec[T <: Data] private (gen: => T, val length: Int)
extends Aggregate with VecLike[T] {
private[core] override def typeEquivalent(that: Data): Boolean = that match {
case that: Vec[T] =>
this.length == that.length &&
(this.sample_element typeEquivalent that.sample_element)
case _ => false
}

// Note: the constructor takes a gen() function instead of a Seq to enforce
// that all elements must be the same and because it makes FIRRTL generation
// simpler.
private val self: Seq[T] = Vector.fill(length)(gen)

/**
* sample_element 'tracks' all changes to the elements of self.
* sample_element 'tracks' all changes to the elements.
* For consistency, sample_element is always used for creating dynamically
* indexed ports and outputing the FIRRTL type.
*
Expand All @@ -181,7 +184,7 @@ sealed class Vec[T <: Data] private (gen: => T, val length: Int)
// This is somewhat weird although I think the best course of action here is
// to deprecate allElements in favor of dispatched functions to Data or
// a pattern matched recursive descent
private[chisel3] final def allElements: Seq[Element] =
private[chisel3] final override def allElements: Seq[Element] =
(sample_element +: self).flatMap(_.allElements)

/** Strong bulk connect, assigning elements in this Vec from elements in a Seq.
Expand Down Expand Up @@ -244,8 +247,8 @@ sealed class Vec[T <: Data] private (gen: => T, val length: Int)
}

private[chisel3] def toType: String = s"${sample_element.toType}[$length]"
private[chisel3] lazy val flatten: IndexedSeq[Bits] =
(0 until length).flatMap(i => this.apply(i).flatten)
override def getElements: Seq[Data] =
(0 until length).map(apply(_))

for ((elt, i) <- self.zipWithIndex)
elt.setRef(this, i)
Expand Down Expand Up @@ -316,14 +319,14 @@ trait VecLike[T <: Data] extends collection.IndexedSeq[T] with HasId {
*/
def indexWhere(p: T => Bool): UInt = macro SourceInfoTransform.pArg

def do_indexWhere(p: T => Bool)(implicit sourceInfo: SourceInfo): UInt =
def do_indexWhere(p: T => Bool)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
SeqUtils.priorityMux(indexWhereHelper(p))

/** Outputs the index of the last element for which p outputs true.
*/
def lastIndexWhere(p: T => Bool): UInt = macro SourceInfoTransform.pArg

def do_lastIndexWhere(p: T => Bool)(implicit sourceInfo: SourceInfo): UInt =
def do_lastIndexWhere(p: T => Bool)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
SeqUtils.priorityMux(indexWhereHelper(p).reverse)

/** Outputs the index of the element for which p outputs true, assuming that
Expand Down Expand Up @@ -377,7 +380,15 @@ abstract class Record extends Aggregate {
elements.toIndexedSeq.reverse.map(e => eltPort(e._2)).mkString("{", ", ", "}")
}

private[chisel3] lazy val flatten = elements.toIndexedSeq.flatMap(_._2.flatten)
private[core] override def typeEquivalent(that: Data): Boolean = that match {
case that: Record =>
this.getClass == that.getClass &&
this.elements.size == that.elements.size &&
this.elements.forall{case (name, model) =>
that.elements.contains(name) &&
(that.elements(name) typeEquivalent model)}
case _ => false
}

// NOTE: This sets up dependent references, it can be done before closing the Module
private[chisel3] override def _onModuleClose: Unit = { // scalastyle:ignore method.name
Expand All @@ -390,6 +401,8 @@ abstract class Record extends Aggregate {

private[chisel3] final def allElements: Seq[Element] = elements.toIndexedSeq.flatMap(_._2.allElements)

override def getElements: Seq[Data] = elements.toIndexedSeq.map(_._2)

// Helper because Bundle elements are reversed before printing
private[chisel3] def toPrintableHelper(elts: Seq[(String, Data)]): Printable = {
val xs =
Expand Down
114 changes: 39 additions & 75 deletions chiselFrontend/src/main/scala/chisel3/core/Bits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import chisel3.internal._
import chisel3.internal.Builder.{pushCommand, pushOp}
import chisel3.internal.firrtl._
import chisel3.internal.sourceinfo.{SourceInfo, DeprecatedSourceInfo, SourceInfoTransform, SourceInfoWhiteboxTransform,
UIntTransform, MuxTransform}
UIntTransform}
import chisel3.internal.firrtl.PrimOp._
// TODO: remove this once we have CompileOptions threaded through the macro system.
import chisel3.core.ExplicitCompileOptions.NotStrict
Expand Down Expand Up @@ -58,7 +58,8 @@ sealed abstract class Bits(width: Width, override val litArg: Option[LitArg])
// Arguments for: self-checking code (can't do arithmetic on bits)
// Arguments against: generates down to a FIRRTL UInt anyways

private[chisel3] def flatten: IndexedSeq[Bits] = IndexedSeq(this)
// Only used for in a few cases, hopefully to be removed
private[core] def cloneTypeWidth(width: Width): this.type

def cloneType: this.type = cloneTypeWidth(width)

Expand Down Expand Up @@ -399,6 +400,9 @@ abstract trait Num[T <: Data] {
sealed class UInt private[core] (width: Width, lit: Option[ULit] = None)
extends Bits(width, lit) with Num[UInt] {

private[core] override def typeEquivalent(that: Data): Boolean =
that.isInstanceOf[UInt] && this.width == that.width

private[core] override def cloneTypeWidth(w: Width): this.type =
new UInt(w).asInstanceOf[this.type]
private[chisel3] def toType = s"UInt$width"
Expand Down Expand Up @@ -525,13 +529,9 @@ sealed class UInt private[core] (width: Width, lit: Option[ULit] = None)
throwException(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint")
}
}
def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
val res = Wire(this, null).asInstanceOf[this.type]
res := (that match {
case u: UInt => u
case _ => that.asUInt
})
res
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = {
this := that.asUInt
}
}

Expand Down Expand Up @@ -567,6 +567,9 @@ object Bits extends UIntFactory
sealed class SInt private[core] (width: Width, lit: Option[SLit] = None)
extends Bits(width, lit) with Num[SInt] {

private[core] override def typeEquivalent(that: Data): Boolean =
this.getClass == that.getClass && this.width == that.width // TODO: should this be true for unspecified widths?

private[core] override def cloneTypeWidth(w: Width): this.type =
new SInt(w).asInstanceOf[this.type]
private[chisel3] def toType = s"SInt$width"
Expand Down Expand Up @@ -669,13 +672,8 @@ sealed class SInt private[core] (width: Width, lit: Option[SLit] = None)
throwException(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint")
}
}
def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
val res = Wire(this, null).asInstanceOf[this.type]
res := (that match {
case s: SInt => s
case _ => that.asSInt
})
res
private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) {
this := that.asSInt
}
}

Expand Down Expand Up @@ -768,52 +766,6 @@ trait BoolFactory {

object Bool extends BoolFactory

object Mux {
/** Creates a mux, whose output is one of the inputs depending on the
* value of the condition.
*
* @param cond condition determining the input to choose
* @param con the value chosen when `cond` is true
* @param alt the value chosen when `cond` is false
* @example
* {{{
* val muxOut = Mux(data_in === 3.U, 3.U(4.W), 0.U(4.W))
* }}}
*/
def apply[T <: Data](cond: Bool, con: T, alt: T): T = macro MuxTransform.apply[T]

def do_apply[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T =
(con, alt) match {
// Handle Mux(cond, UInt, Bool) carefully so that the concrete type is UInt
case (c: Bool, a: Bool) => doMux(cond, c, a).asInstanceOf[T]
case (c: UInt, a: Bool) => doMux(cond, c, a << 0).asInstanceOf[T]
case (c: Bool, a: UInt) => doMux(cond, c << 0, a).asInstanceOf[T]
case (c: Bits, a: Bits) => doMux(cond, c, a).asInstanceOf[T]
case _ => doAggregateMux(cond, con, alt)
}

private def doMux[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = {
require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}")
Binding.checkSynthesizable(cond, s"'cond' ($cond)")
Binding.checkSynthesizable(con, s"'con' ($con)")
Binding.checkSynthesizable(alt, s"'alt' ($alt)")
val d = alt.cloneTypeWidth(con.width max alt.width)
pushOp(DefPrim(sourceInfo, d, MultiplexOp, cond.ref, con.ref, alt.ref))
}

private[core] def typesCompatible[T <: Data](x: T, y: T): Boolean = {
val sameTypes = x.getClass == y.getClass
val sameElements = x.flatten zip y.flatten forall { case (a, b) => a.getClass == b.getClass && a.width == b.width }
val sameNumElements = x.flatten.size == y.flatten.size
sameTypes && sameElements && sameNumElements
}

private def doAggregateMux[T <: Data](cond: Bool, con: T, alt: T)(implicit sourceInfo: SourceInfo): T = {
require(typesCompatible(con, alt), s"can't Mux between heterogeneous types ${con.getClass} and ${alt.getClass}")
doMux(cond, con, alt)
}
}

//scalastyle:off number.of.methods
/**
* A sealed class representing a fixed point number that has a bit width and a binary point
Expand All @@ -829,6 +781,11 @@ object Mux {
*/
sealed class FixedPoint private (width: Width, val binaryPoint: BinaryPoint, lit: Option[FPLit] = None)
extends Bits(width, lit) with Num[FixedPoint] {
private[core] override def typeEquivalent(that: Data): Boolean = that match {
case that: FixedPoint => this.width == that.width && this.binaryPoint == that.binaryPoint // TODO: should this be true for unspecified widths?
case _ => false
}

private[core] override def cloneTypeWidth(w: Width): this.type =
new FixedPoint(w, binaryPoint).asInstanceOf[this.type]
private[chisel3] def toType = s"Fixed$width$binaryPoint"
Expand Down Expand Up @@ -945,13 +902,13 @@ sealed class FixedPoint private (width: Width, val binaryPoint: BinaryPoint, lit
throwException(s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint")
}
}
def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
val res = Wire(this, null).asInstanceOf[this.type]
res := (that match {

private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) {
// TODO: redefine as just asFixedPoint on that, where FixedPoint.asFixedPoint just works.
this := (that match {
case fp: FixedPoint => fp.asSInt.asFixedPoint(this.binaryPoint)
case _ => that.asFixedPoint(this.binaryPoint)
})
res
}
//TODO(chick): Consider "convert" as an arithmetic conversion to UInt/SInt
}
Expand Down Expand Up @@ -1073,6 +1030,13 @@ object FixedPoint {
final class Analog private (width: Width) extends Element(width) {
require(width.known, "Since Analog is only for use in BlackBoxes, width must be known")

private[chisel3] def toType = s"Analog$width"

private[core] override def typeEquivalent(that: Data): Boolean =
that.isInstanceOf[Analog] && this.width == that.width

def cloneType: this.type = new Analog(width).asInstanceOf[this.type]

// Used to enforce single bulk connect of Analog types, multi-attach is still okay
// Note that this really means 1 bulk connect per Module because a port can
// be connected in the parent module as well
Expand All @@ -1084,15 +1048,15 @@ final class Analog private (width: Width) extends Element(width) {
case (_: UnboundBinding | _: WireBinding | PortBinding(_, None)) => super.binding_=(target)
case _ => throwException("Only Wires and Ports can be of type Analog")
}
private[core] override def cloneTypeWidth(w: Width): this.type =
new Analog(w).asInstanceOf[this.type]
private[chisel3] def toType = s"Analog$width"
def cloneType: this.type = cloneTypeWidth(width)
// What do flatten and fromBits mean?
private[chisel3] def flatten: IndexedSeq[Bits] =
throwException("Chisel Internal Error: Analog cannot be flattened into Bits")
def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type =
throwException("Analog does not support fromBits")

override def do_asUInt(implicit sourceInfo: SourceInfo): UInt =
throwException("Analog does not support asUInt")

private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
compileOptions: CompileOptions): Unit = {
throwException("Analog does not support connectFromBits")
}

final def toPrintable: Printable = PString("Analog")
}
/** Object that provides factory methods for [[Analog]] objects
Expand Down

12 comments on commit 375e2b6

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 16, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is causing chisel-testers failures in peek/poke which assume that there is an API ( the old flatten method) that pulls apart a structured signal into its basic components.

@ducky64
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you need flatten there is a version in Data (though private, because it's probably really not the best API). Another option might be to refactor things to work more locally (getElements), one layer at a time?

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 16, 2017 via email

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ducky64
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getElements is essentially the flatten replacement. I'm not really sure what to do with flatten, since it really doesn't belong in Data. Also, if we further expand the amount of types Chisel has (like Analog, which can't flatten), flatten looks more and more like a broken, brittle API.

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks chisel-tutorial.examples.AdderTester. The problem area appears to be:

class FullAdder extends Module {
  val io = IO(new Bundle {
    val a    = Input(UInt(1.W))
    val b    = Input(UInt(1.W))
    val cin  = Input(UInt(1.W))
    val sum  = Output(UInt(1.W))
    val cout = Output(UInt(1.W))
  })
  ...
}

class Adder(val n:Int) extends Module {
  val io = IO(new Bundle {
    val A    = Input(UInt(n.W))
    val B    = Input(UInt(n.W))
    val Cin  = Input(UInt(1.W))
    val Sum  = Output(UInt(n.W))
    val Cout = Output(UInt(1.W))
  })
 ...
 //create a vector of FullAdders
  val FAs   = Vec(Seq.fill(n)(Module(new FullAdder()).io))
  //wire up the ports of the full adders
  for (i <- 0 until n) {
    FAs(i).a := io.A(i)
    FAs(i).b := io.B(i)
    FAs(i).cin := carry(i)
    carry(i+1) := FAs(i).cout
    sum(i) := FAs(i).sum.toBool()
  }

Before this commit, the generated firrtl was:

    FAs[0].cout <= FullAdder.io.cout @[Adder.scala 17:18]
    FAs[0].sum <= FullAdder.io.sum @[Adder.scala 17:18]
    FullAdder.io.cin <= FAs[0].cin @[Adder.scala 17:18]
    FullAdder.io.b <= FAs[0].b @[Adder.scala 17:18]
    FullAdder.io.a <= FAs[0].a @[Adder.scala 17:18]

With this commit, the generated firrtl is:

    FAs[0].cout <= FullAdder.io.cout @[Adder.scala 17:18]
    FAs[0].sum <= FullAdder.io.sum @[Adder.scala 17:18]
    FAs[0].cin <= FullAdder.io.cin @[Adder.scala 17:18]
    FAs[0].b <= FullAdder.io.b @[Adder.scala 17:18]
    FAs[0].a <= FullAdder.io.a @[Adder.scala 17:18]

@ducky64
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated FIRRTL with this commit looks more sane? The only thing I see is that the directionality of the assign of cin, a, b is flipped, and more consistent with the input Chisel?

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the broken firrtl, all the FA elements are on the same side of the <= operator. In the original chisel, some are sources and some are sinks. I don't think they should all be on the same side of the <= in the generated firrtl.

@ducky64
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure the previous code was sane? For example, there's a bit slice operation in the Chisel (io.A, io.B) that I don't see in the generated FIRRTL. And the new directionality of cin, a, b makes sense (compared to the input Chisel) while cout and sum doesn't match the old FIRRTL either...

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 18, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll see if I can spot another culprit ...

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial analysis of the location of the error was incorrect. The failing source line is:

  val FAs   = Vec(Seq.fill(n)(Module(new FullAdder()).io))

and the problem is due to the wiring up of the vector.

    val vec = Wire(new Vec(cloneSupertype(elts, "Vec"), elts.length))

loses the directionality of the elts so the following test in doConnect() returns false for the FAs vector:

    def doConnect(sink: T, source: T) = {
      val hasDirectioned = vec.sample_element match {
        case t: Aggregate => t.flatten.exists(_.dir != Direction.Unspecified)
        case t: Element => t.dir != Direction.Unspecified
      }
      if (hasDirectioned) {
        sink bulkConnect source
      } else {
        sink connect source
      }
    }

and we take the sink connect source branch instead of the sink bulkConnect source which we took in the old code.

@ducky64
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something we should keep an eye on and revisit once the directionality / vec / binding refactors happen, I don't think it's desirable that this doesn't work.

@ucbjrl
Copy link
Contributor

@ucbjrl ucbjrl commented on 375e2b6 Feb 22, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed issue #522.

Please sign in to comment.