Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RemoveAccesses, delete CSESubAccesses #2157

Merged
merged 2 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions src/main/scala/firrtl/passes/RemoveAccesses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
import firrtl.transforms.CSESubAccesses

import scala.collection.mutable

Expand All @@ -22,8 +21,7 @@ object RemoveAccesses extends Pass {
Dependency(PullMuxes),
Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects),
Dependency[CSESubAccesses]
Dependency(ExpandConnects)
) ++ firrtl.stage.Forms.Deduped

override def invalidates(a: Transform): Boolean = a match {
Expand Down Expand Up @@ -122,26 +120,26 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given source expression
*/
val stmts = mutable.ArrayBuffer[Statement]()
def removeSource(e: Expression): Expression = e match {
case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
val rs = getLocations(e)
rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
stmts += IsInvalid(get_info(s), getTemp(i))
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
case _ => e
// Only called on RefLikes that definitely have a SubAccess
// Must accept Expression because that's the output type of fixIndices
def removeSource(e: Expression): Expression = {
val rs = getLocations(e)
rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
stmts += IsInvalid(get_info(s), getTemp(i))
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
Comment on lines +123 to +142
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this diff is just removing the match since we only ever call this on RefLikeExpressions, the match is redundant.

}

/** Replaces a subaccess in a given sink expression
Expand All @@ -162,14 +160,23 @@ object RemoveAccesses extends Pass {
case _ => loc
}

/** Recurse until find SubAccess and call fixSource on its index
* @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map
* requires Expression => Expression
*/
def fixIndices(e: Expression): Expression = e match {
case e: SubAccess => e.copy(index = fixSource(e.index))
case other => other.map(fixIndices)
}

/** Recursively walks a source expression and fixes all subaccesses
* If we see a sub-access, replace it.
* Otherwise, map to children.
*
* If we see a RefLikeExpression that contains a SubAccess, we recursively remove
* subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression
*/
def fixSource(e: Expression): Expression = e match {
case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
//case w: WSubIndex => removeSource(w)
//case w: WSubField => removeSource(w)
case ref: RefLikeExpression =>
if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref
case x => x.map(fixSource)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The diff above here is the real meat of the PR, by recursing until you see the SubAccess for RHS subaccesses ("Sources"), you end up always expanding the entire aggregate. This changes it to match on the outermost RefLikeExpression it can to ensure that we respect any SubFields or SubIndexes when doing the expansion. This is the fix.

}

Expand Down
1 change: 0 additions & 1 deletion src/main/scala/firrtl/stage/Forms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ object Forms {
val MidForm: Seq[TransformDependency] = HighForm ++
Seq(
Dependency(passes.PullMuxes),
Dependency[firrtl.transforms.CSESubAccesses],
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
Expand Down
168 changes: 0 additions & 168 deletions src/main/scala/firrtl/transforms/CSESubAccesses.scala

This file was deleted.

3 changes: 1 addition & 2 deletions src/test/scala/firrtlTests/LowerTypesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| out <= in0[in1[in2[0]]][in1[in2[1]]]
|""".stripMargin
val expected = Seq(
"node _in0_in1_in1 = _in0_in1_in1_in2_1",
"out <= _in0_in1_in1"
"out <= _in0_in1_in1_in2_1"
)

executeTest(input, expected)
Expand Down
1 change: 0 additions & 1 deletion src/test/scala/firrtlTests/LoweringCompilersSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])),
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
// Uniquify is now part of [[firrtl.passes.LowerTypes]]
Expand Down
10 changes: 5 additions & 5 deletions src/test/scala/firrtlTests/UnitTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,14 @@ class UnitTests extends FirrtlFlatSpec {
//TODO(azidar): I realize this is brittle, but unfortunately there
// isn't a better way to test this pass
val check = Seq(
"""wire _table_1 : { a : UInt<8>}""",
"""_table_1.a is invalid""",
"""wire _table_1_a : UInt<8>""",
"""_table_1_a is invalid""",
"""when UInt<1>("h1") :""",
"""_table_1.a <= table[1].a""",
"""_table_1_a <= table[1].a""",
"""wire _otherTable_table_1_a_a : UInt<8>""",
"""when eq(UInt<1>("h0"), _table_1.a) :""",
"""when eq(UInt<1>("h0"), _table_1_a) :""",
"""otherTable[0].a <= _otherTable_table_1_a_a""",
"""when eq(UInt<1>("h1"), _table_1.a) :""",
"""when eq(UInt<1>("h1"), _table_1_a) :""",
"""otherTable[1].a <= _otherTable_table_1_a_a""",
"""_otherTable_table_1_a_a <= UInt<1>("h0")"""
)
Expand Down
Loading