Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make InferTypes error on enable conditions > 1-bit wide #2182

Merged
merged 1 commit into from
Apr 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions src/main/scala/firrtl/passes/CheckTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ object CheckTypes extends Pass {
class RegReqClk(info: Info, mname: String, name: String)
extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.")
class EnNotUInt(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.")
extends PassException(s"$info: [module $mname] Enable must be a 1-bit UIntType typed signal.")
class PredNotUInt(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Predicate not a UIntType.")
extends PassException(s"$info: [module $mname] Predicate not a 1-bit UIntType.")
class OpNotGround(info: Info, mname: String, op: String)
extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
class OpNotUInt(info: Info, mname: String, op: String, e: String)
Expand All @@ -81,7 +81,7 @@ object CheckTypes extends Pass {
class MuxPassiveTypes(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Must mux between passive types.")
class MuxCondUInt(info: Info, mname: String)
extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.")
extends PassException(s"$info: [module $mname] A mux condition must be of type 1-bit UInt.")
class MuxClock(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.")
class ValidIfPassiveTypes(info: Info, mname: String)
Expand Down Expand Up @@ -120,6 +120,15 @@ object CheckTypes extends Pass {
case _ => false
}

private def legalCondType(tpe: Type): Boolean = tpe match {
// If width is known, must be 1
case UIntType(IntWidth(w)) => w == 1
// Unknown width or variable widths (for width inference) are acceptable (checked in later run)
case UIntType(_) => true
// Any other type is not okay
case _ => false
}

private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
(t1, t2) match {
case (ClockType, ClockType) => flip1 == flip2
Expand Down Expand Up @@ -165,7 +174,8 @@ object CheckTypes extends Pass {
bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)

//;---------------- Helper Functions --------------
def ut: UIntType = UIntType(UnknownWidth)
private val UIntUnknown = UIntType(UnknownWidth)
def ut: UIntType = UIntUnknown
def st: SIntType = SIntType(UnknownWidth)

def run(c: Circuit): Circuit = {
Expand Down Expand Up @@ -332,9 +342,8 @@ object CheckTypes extends Pass {
errors.append(new MuxSameType(info, mname, e.tval.tpe.serialize, e.fval.tpe.serialize))
if (!passive(e.tpe))
errors.append(new MuxPassiveTypes(info, mname))
e.cond.tpe match {
case _: UIntType =>
case _ => errors.append(new MuxCondUInt(info, mname))
if (!legalCondType(e.cond.tpe)) {
errors.append(new MuxCondUInt(info, mname))
}
case (e: ValidIf) =>
if (!passive(e.tpe))
Expand Down Expand Up @@ -375,7 +384,7 @@ object CheckTypes extends Pass {
if (sx.clock.tpe != ClockType) {
errors.append(new RegReqClk(info, mname, sx.name))
}
case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) =>
case sx: Conditionally if !legalCondType(sx.pred.tpe) =>
errors.append(new PredNotUInt(info, mname))
case sx: DefNode =>
sx.value.tpe match {
Expand All @@ -396,16 +405,16 @@ object CheckTypes extends Pass {
}
case sx: Stop =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: Print =>
if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
errors.append(new PrintfArgNotGround(info, mname))
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: Verification =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
if (!legalCondType(sx.pred.tpe)) errors.append(new PredNotUInt(info, mname))
if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: DefMemory =>
sx.dataType match {
case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
Expand Down
100 changes: 100 additions & 0 deletions src/test/scala/firrtlTests/CheckSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,106 @@ class CheckSpec extends AnyFlatSpec with Matchers {
}
}

behavior.of("Check Types")

def runCheckTypes(input: String) = {
val passes = List(InferTypes, CheckTypes)
val wrapped = "circuit test:\n module test:\n " + input.replaceAll("\n", "\n ")
passes.foldLeft(Parser.parse(wrapped)) { case (c, p) => p.run(c) }
}

it should "disallow mux enable conditions that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : $tpe
|input foo : UInt<8>
|input bar : UInt<8>
|node x = mux(en, foo, bar)""".stripMargin
a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

it should "disallow when predicates that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : $tpe
|input foo : UInt<8>
|input bar : UInt<8>
|output out : UInt<8>
|when en :
| out <= foo
|else:
| out <= bar""".stripMargin
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

it should "disallow print enables that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : $tpe
|input clock : Clock
|printf(clock, en, "Hello World!\\n")""".stripMargin
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

it should "disallow stop enables that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : $tpe
|input clock : Clock
|stop(clock, en, 0)""".stripMargin
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

it should "disallow verif node predicates that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : $tpe
|input cond : UInt<1>
|input clock : Clock
|assert(clock, en, cond, "Howdy!")""".stripMargin
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

it should "disallow verif node enables that are not 1-bit UInts (or unknown width)" in {
def mk(tpe: String) =
s"""|input en : UInt<1>
|input cond : $tpe
|input clock : Clock
|assert(clock, en, cond, "Howdy!")""".stripMargin
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
runCheckTypes(mk("UInt"))
runCheckTypes(mk("UInt<1>"))
}

"Instance loops a -> b -> a" should "be detected" in {
val input =
"""
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/firrtlTests/LowerTypesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,10 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2]
| output a_0_b : UInt<1>
| input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2]
| a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e)))
| a_0_b <= mux(bits(a[UInt(0)].c_1_e, 0, 0), or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e)))
""".stripMargin
val expected = Seq(
"a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))"
"a_0_b <= mux(bits(a___0_c_1_e, 0, 0), or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))"
)

executeTest(input, expected)
Expand Down
4 changes: 2 additions & 2 deletions src/test/scala/firrtlTests/ReplSeqMemTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ circuit CustomMemory :
circuit CustomMemory :
module CustomMemory :
input clock : Clock
output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] }

smem mem : UInt<8>[2][1024]
read mport r = mem[io.raddr], clock
Expand Down Expand Up @@ -452,7 +452,7 @@ circuit CustomMemory :
circuit CustomMemory :
module CustomMemory :
input clock : Clock
output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] }

io.out is invalid

Expand Down