From f8572ba6532359e8a0f1bc34f3eb8241a29129ab Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 22 Jun 2017 17:51:39 -0700 Subject: [PATCH] Add support for wires in ConstProp This requires a quick second pass to back propagate constant wires but the QoR win is substantial. We also only need to count back propagations in determining whether to run ConstProp again which shaves off an iteration in the common case. --- src/main/scala/firrtl/passes/ConstProp.scala | 28 ++++++++++++++++--- .../scala/firrtlTests/AnnotationTests.scala | 7 +++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 8a2d6ec628..f2aa1a0301 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -234,21 +234,38 @@ object ConstProp extends Pass { case _ => r } + // Two pass process + // 1. Propagate constants in expressions and forward propagate references + // 2. Propagate references again for backwards reference (Wires) + // TODO Replacing all wires with nodes makes the second pass unnecessary @tailrec private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() + def backPropExpr(expr: Expression): Expression = { + val old = expr map backPropExpr + val propagated = old match { + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case x => x + } + if (old ne propagated) { + nPropagated += 1 + } + propagated + } + def backPropStmt(stmt: Statement): Statement = stmt map backPropStmt map backPropExpr + def constPropExpression(e: Expression): Expression = { val old = e map constPropExpression val propagated = old match { case p: DoPrim => constPropPrim(p) case m: Mux => constPropMux(m) - case r: WRef if nodeMap contains r.name => constPropNodeRef(r, nodeMap(r.name)) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) case x => x } - if (old ne propagated) - nPropagated += 1 propagated } @@ -256,12 +273,15 @@ object ConstProp extends Pass { val stmtx = s map constPropStmt map constPropExpression stmtx match { case x: DefNode => nodeMap(x.name) = x.value + case Connect(_, WRef(wname, wtpe, WireKind, _), expr) => + val exprx = constPropExpression(pad(expr, wtpe)) + nodeMap(wname) = exprx case _ => } stmtx } - val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) + val res = Module(m.info, m.name, m.ports, backPropStmt(constPropStmt(m.body))) if (nPropagated > 0) constPropModule(res) else res } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index e3dd3dbd48..3e93081e4f 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -272,7 +272,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r") + dontTouch("Top.r"), dontTouch("Top.w") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations @@ -326,7 +326,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"), + dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_a")) @@ -362,7 +363,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), - dontTouch("Top.r")) + dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_b_0"))