Skip to content

Commit

Permalink
[LowerToHW] Implement MuxCell intrinsics lowering (#5458)
Browse files Browse the repository at this point in the history
This commit deprecates addMuxPragma flag in LowerToHW and implements lowering of MuxCell intrinsics.

MuxCell intrinsics are lowered into a wire declaration, an assignment with infer_mux_override pragma and a value with map_to_mux pragma. Operands are probed to hw.wire with inner symbols to prevent optimizations from destructing AST structures.
  • Loading branch information
uenoku committed Jul 3, 2023
1 parent 4ffcdd2 commit b903f48
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 45 deletions.
3 changes: 1 addition & 2 deletions include/circt/Conversion/FIRRTLToHW.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ namespace circt {

std::unique_ptr<mlir::Pass> createLowerFIRRTLToHWPass(
bool enableAnnotationWarning = false, bool emitChiselAssertsAsSVA = false,
bool addMuxPragmas = false, bool disableMemRandomization = false,
bool disableRegRandomization = false);
bool disableMemRandomization = false, bool disableRegRandomization = false);

} // namespace circt

Expand Down
4 changes: 1 addition & 3 deletions include/circt/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,7 @@ def LowerFIRRTLToHW : Pass<"lower-firrtl-to-hw", "mlir::ModuleOp"> {
"bool", "false",
"Emit warnings on unprocessed annotations during lower-to-hw pass">,
Option<"emitChiselAssertsAsSVA", "emit-chisel-asserts-as-sva",
"bool", "false","Convert all Chisel asserts to SVA">,
Option<"addMuxPragmas", "add-mux-pragmas", "bool", "false",
"Annotate mux pragmas to multibit mux and subacess results">
"bool", "false","Convert all Chisel asserts to SVA">
];
}

Expand Down
3 changes: 2 additions & 1 deletion include/circt/Firtool/Firtool.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ struct FirtoolOptions {
llvm::cl::init(false), llvm::cl::cat(category)};

llvm::cl::opt<bool> addMuxPragmas{
"add-mux-pragmas", llvm::cl::desc("Annotate mux pragmas"),
"add-mux-pragmas",
llvm::cl::desc("Annotate mux pragmas for memory array access"),
llvm::cl::init(false), llvm::cl::cat(category)};

llvm::cl::opt<bool> emitChiselAssertsAsSVA{
Expand Down
114 changes: 78 additions & 36 deletions lib/Conversion/FIRRTLToHW/LowerToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,11 @@ struct CircuitLoweringState {
std::atomic<bool> used_RANDOMIZE_GARBAGE_ASSIGN{false};

CircuitLoweringState(CircuitOp circuitOp, bool enableAnnotationWarning,
bool emitChiselAssertsAsSVA, bool addMuxPragmas,
bool emitChiselAssertsAsSVA,
InstanceGraph *instanceGraph, NLATable *nlaTable)
: circuitOp(circuitOp), instanceGraph(instanceGraph),
enableAnnotationWarning(enableAnnotationWarning),
emitChiselAssertsAsSVA(emitChiselAssertsAsSVA),
addMuxPragmas(addMuxPragmas), nlaTable(nlaTable) {
emitChiselAssertsAsSVA(emitChiselAssertsAsSVA), nlaTable(nlaTable) {
auto *context = circuitOp.getContext();

// Get the testbench output directory.
Expand Down Expand Up @@ -327,7 +326,6 @@ struct CircuitLoweringState {
std::mutex annotationPrintingMtx;

const bool emitChiselAssertsAsSVA;
const bool addMuxPragmas;

// Records any sv::BindOps that are found during the course of execution.
// This is unsafe to access directly and should only be used through addBind.
Expand Down Expand Up @@ -423,7 +421,6 @@ struct FIRRTLModuleLowering : public LowerFIRRTLToHWBase<FIRRTLModuleLowering> {
void setDisableRegRandomization() { disableRegRandomization = true; }
void setEnableAnnotationWarning() { enableAnnotationWarning = true; }
void setEmitChiselAssertAsSVA() { emitChiselAssertsAsSVA = true; }
void setAddMuxPragmas() { addMuxPragmas = true; }

private:
void lowerFileHeader(CircuitOp op, CircuitLoweringState &loweringState);
Expand Down Expand Up @@ -456,15 +453,12 @@ struct FIRRTLModuleLowering : public LowerFIRRTLToHWBase<FIRRTLModuleLowering> {
/// This is the pass constructor.
std::unique_ptr<mlir::Pass> circt::createLowerFIRRTLToHWPass(
bool enableAnnotationWarning, bool emitChiselAssertsAsSVA,
bool addMuxPragmas, bool disableMemRandomization,
bool disableRegRandomization) {
bool disableMemRandomization, bool disableRegRandomization) {
auto pass = std::make_unique<FIRRTLModuleLowering>();
if (enableAnnotationWarning)
pass->setEnableAnnotationWarning();
if (emitChiselAssertsAsSVA)
pass->setEmitChiselAssertAsSVA();
if (addMuxPragmas)
pass->setAddMuxPragmas();
if (disableMemRandomization)
pass->setDisableMemRandomization();
if (disableRegRandomization)
Expand Down Expand Up @@ -495,7 +489,7 @@ void FIRRTLModuleLowering::runOnOperation() {
// Keep track of the mapping from old to new modules. The result may be null
// if lowering failed.
CircuitLoweringState state(
circuit, enableAnnotationWarning, emitChiselAssertsAsSVA, addMuxPragmas,
circuit, enableAnnotationWarning, emitChiselAssertsAsSVA,
&getAnalysis<InstanceGraph>(), &getAnalysis<NLATable>());

SmallVector<FModuleOp, 32> modulesToProcess;
Expand Down Expand Up @@ -1547,6 +1541,7 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
FIRRTLBaseType destType,
bool allowTruncate);
Value createArrayIndexing(Value array, Value index);
Value createValueWithMuxAnnotation(Operation *op, bool isMux2);

// Create a temporary wire at the current insertion point, and try to
// eliminate it later as part of lowering post processing.
Expand Down Expand Up @@ -1710,6 +1705,8 @@ struct FIRRTLLowering : public FIRRTLVisitor<FIRRTLLowering, LogicalResult> {
}
LogicalResult visitExpr(TailPrimOp op);
LogicalResult visitExpr(MuxPrimOp op);
LogicalResult visitExpr(Mux2CellIntrinsicOp op);
LogicalResult visitExpr(Mux4CellIntrinsicOp op);
LogicalResult visitExpr(MultibitMuxOp op);
LogicalResult visitExpr(VerbatimExprOp op);

Expand Down Expand Up @@ -3918,14 +3915,81 @@ LogicalResult FIRRTLLowering::visitExpr(MuxPrimOp op) {
true);
}

// Construct array indexing annotated with vendor pragmas to get
// better synthesis results. Specifically we annotate pragmas in the following
// form.
LogicalResult FIRRTLLowering::visitExpr(Mux2CellIntrinsicOp op) {
auto cond = getLoweredValue(op.getSel());
auto ifTrue = getLoweredAndExtendedValue(op.getHigh(), op.getType());
auto ifFalse = getLoweredAndExtendedValue(op.getLow(), op.getType());
if (!cond || !ifTrue || !ifFalse)
return failure();

auto val = builder.create<comb::MuxOp>(ifTrue.getType(), cond, ifTrue,
ifFalse, true);
return setLowering(op, createValueWithMuxAnnotation(val, true));
}

LogicalResult FIRRTLLowering::visitExpr(Mux4CellIntrinsicOp op) {
auto sel = getLoweredValue(op.getSel());
auto v3 = getLoweredAndExtendedValue(op.getV3(), op.getType());
auto v2 = getLoweredAndExtendedValue(op.getV2(), op.getType());
auto v1 = getLoweredAndExtendedValue(op.getV1(), op.getType());
auto v0 = getLoweredAndExtendedValue(op.getV0(), op.getType());
if (!sel || !v3 || !v2 || !v1 || !v0)
return failure();
Value array[] = {v3, v2, v1, v0};
auto create = builder.create<hw::ArrayCreateOp>(array);
auto val = builder.create<hw::ArrayGetOp>(create, sel);
return setLowering(op, createValueWithMuxAnnotation(val, false));
}

// Construct a value with vendor specific pragmas to utilize MUX cells.
// Specifically we annotate pragmas in the following form.
//
// For an array indexing:
// ```
// wire GEN;
// /* synopsys infer_mux_override */
// assign GEN = array[index] /* cadence map_to_mux */;
// ```
//
// For a mux:
// ```
// wire GEN;
// /* synopsys infer_mux_override */
// assign GEN = sel ? /* cadence map_to_mux */ high : low;
// ```
Value FIRRTLLowering::createValueWithMuxAnnotation(Operation *op, bool isMux2) {
assert(op->getNumResults() == 1 && "only expect a single result");
auto val = op->getResult(0);
auto valWire = builder.create<sv::WireOp>(val.getType());
// Use SV attributes to annotate pragmas.
circt::sv::setSVAttributes(
op, sv::SVAttributeAttr::get(builder.getContext(), "cadence map_to_mux",
/*emitAsComment=*/true));

// For operands, create temporary wires with optimization blockers(inner
// symbols) so that the AST structure will never be destoyed in the later
// pipeline.
{
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(op);
StringRef namehint = isMux2 ? "mux2cell_in" : "mux4cell_in";
for (auto [idx, operand] : llvm::enumerate(op->getOperands())) {
auto sym = moduleNamespace.newName(Twine("__") + theModule.getName() +
Twine("__MUX__PRAGMA"));
auto wire =
builder.create<hw::WireOp>(operand, namehint + Twine(idx), sym);
op->setOperand(idx, wire);
}
}

auto assignOp = builder.create<sv::AssignOp>(valWire, val);
sv::setSVAttributes(assignOp,
sv::SVAttributeAttr::get(builder.getContext(),
"synopsys infer_mux_override",
/*emitAsComment=*/true));
return builder.create<sv::ReadInOutOp>(valWire);
}

Value FIRRTLLowering::createArrayIndexing(Value array, Value index) {

auto size = hw::type_cast<hw::ArrayType>(array.getType()).getSize();
Expand All @@ -3942,29 +4006,7 @@ Value FIRRTLLowering::createArrayIndexing(Value array, Value index) {
array = builder.create<hw::ArrayConcatOp>(temp2);
}

Value inBoundsRead;
// If `addMuxPragmas` is enabled, add mux pragmas to array reads.
// Don't annotate mux pragmas if the array size is 1 since it causes a
// complication failure.
if (!circuitState.addMuxPragmas || size <= 1) {
inBoundsRead = builder.create<hw::ArrayGetOp>(array, index);
} else {
auto arrayGet = builder.create<hw::ArrayGetOp>(array, index);
auto valWire = builder.create<sv::WireOp>(
hw::type_cast<hw::ArrayType>(array.getType()).getElementType());
// Use SV attributes to annotate pragmas.
circt::sv::setSVAttributes(
arrayGet,
sv::SVAttributeAttr::get(builder.getContext(), "cadence map_to_mux",
/*emitAsComment=*/true));

auto assignOp = builder.create<sv::AssignOp>(valWire, arrayGet);
sv::setSVAttributes(assignOp,
sv::SVAttributeAttr::get(builder.getContext(),
"synopsys infer_mux_override",
/*emitAsComment=*/true));
inBoundsRead = builder.create<sv::ReadInOutOp>(valWire);
}
Value inBoundsRead = builder.create<hw::ArrayGetOp>(array, index);

return inBoundsRead;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Firtool/Firtool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ LogicalResult firtool::populateLowFIRRTLToHW(mlir::PassManager &pm,

pm.addPass(createLowerFIRRTLToHWPass(
opt.enableAnnotationWarning.getValue(),
opt.emitChiselAssertsAsSVA.getValue(), opt.addMuxPragmas.getValue(),
opt.emitChiselAssertsAsSVA.getValue(),
!opt.isRandomEnabled(FirtoolOptions::RandomKind::Mem),
!opt.isRandomEnabled(FirtoolOptions::RandomKind::Reg)));

Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/FIRRTLToHW/lower-to-hw.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1649,4 +1649,26 @@ firrtl.circuit "Simple" attributes {annotations = [{class =

// CHECK-NEXT: hw.output %[[OR]], %[[AND]], %[[XOR]] : !hw.array<2xi1>, !hw.array<2xi1>, !hw.array<2xi1>
}
// CHECK-LABEL: @MuxIntrinsics
firrtl.module @MuxIntrinsics(in %sel1: !firrtl.uint<1>, in %sel2: !firrtl.uint<2>, in %v3: !firrtl.uint<32>, in %v2: !firrtl.uint<32>, in %v1: !firrtl.uint<32>, in %v0: !firrtl.uint<32>, out %out1: !firrtl.uint<32>, out %out2: !firrtl.uint<32>) attributes {convention = #firrtl<convention scalarized>} {
%0 = firrtl.int.mux2cell(%sel1, %v1, %v0) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
firrtl.strictconnect %out1, %0 : !firrtl.uint<32>
// CHECK-NEXT: %mux2cell_in0 = hw.wire %sel1 sym @__MuxIntrinsics__MUX__PRAGMA : i1
// CHECK-NEXT: %mux2cell_in1 = hw.wire %v1 sym @__MuxIntrinsics__MUX__PRAGMA_0 : i32
// CHECK-NEXT: %mux2cell_in2 = hw.wire %v0 sym @__MuxIntrinsics__MUX__PRAGMA_1 : i32
// CHECK-NEXT: %0 = comb.mux bin %mux2cell_in0, %mux2cell_in1, %mux2cell_in2 {sv.attributes = [#sv.attribute<"cadence map_to_mux", emitAsComment>]} : i32
// CHECK-NEXT: %1 = sv.wire : !hw.inout<i32>
// CHECK-NEXT: sv.assign %1, %0 {sv.attributes = [#sv.attribute<"synopsys infer_mux_override", emitAsComment>]} : i32
// CHECK-NEXT: %2 = sv.read_inout %1 : !hw.inout<i32>

%1 = firrtl.int.mux4cell(%sel2, %v3, %v2, %v1, %v0) : (!firrtl.uint<2>, !firrtl.uint<32>, !firrtl.uint<32>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
firrtl.strictconnect %out2, %1 : !firrtl.uint<32>
// CHECK: %mux4cell_in0 = hw.wire %3 sym @__MuxIntrinsics__MUX__PRAGMA_2 : !hw.array<4xi32>
// CHECK-NEXT: %mux4cell_in1 = hw.wire %sel2 sym @__MuxIntrinsics__MUX__PRAGMA_3 : i2
// CHECK-NEXT: %4 = hw.array_get %mux4cell_in0[%mux4cell_in1] {sv.attributes = [#sv.attribute<"cadence map_to_mux", emitAsComment>]} : !hw.array<4xi32>, i2
// CHECK-NEXT: %5 = sv.wire : !hw.inout<i32>
// CHECK-NEXT: sv.assign %5, %4 {sv.attributes = [#sv.attribute<"synopsys infer_mux_override", emitAsComment>]} : i32
// CHECK-NEXT: %6 = sv.read_inout %5 : !hw.inout<i32>
// CHECK-NEXT: hw.output %2, %6 : i32, i32
}
}
49 changes: 47 additions & 2 deletions test/firtool/firtool.fir
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ circuit test_mod : %[[{"class": "circt.testNT", "data": "a"}]]
output out_implicitTrunc: UInt<1>
output out_prettifyExample: UInt<1>
output out_multibitMux: UInt<1>
output out_mux2cell: UInt<1>[2]

inst cat of Cat
cat.a <= b
Expand Down Expand Up @@ -84,6 +85,11 @@ circuit test_mod : %[[{"class": "circt.testNT", "data": "a"}]]
inst unusedPortsMod of UnusedPortsMod
unusedPortsMod.in <= a

inst mux2cell of Mux2Cell
mux2cell.cond <= a
mux2cell.low <= a
out_mux2cell <= mux2cell.out

; These outputs exist to work around the aggressive removal of unused module
; ports.
;
Expand All @@ -94,7 +100,7 @@ circuit test_mod : %[[{"class": "circt.testNT", "data": "a"}]]
out_multibitMux <= multibitMux.b


; MLIR-LABEL: firrtl.module @test_mod(in %clock: !firrtl.clock, in %a: !firrtl.uint<1>, in %b: !firrtl.uint<2>, out %c: !firrtl.uint<1>, in %vec_0: !firrtl.uint<1>, in %vec_1: !firrtl.uint<1>, in %vec_2: !firrtl.uint<1>, out %out_implicitTrunc: !firrtl.uint<1>, out %out_prettifyExample: !firrtl.uint<1>, out %out_multibitMux: !firrtl.uint<1>) {{.*}}{
; MLIR-LABEL: firrtl.module @test_mod(in %clock: !firrtl.clock, in %a: !firrtl.uint<1>, in %b: !firrtl.uint<2>, out %c: !firrtl.uint<1>, in %vec_0: !firrtl.uint<1>, in %vec_1: !firrtl.uint<1>, in %vec_2: !firrtl.uint<1>, out %out_implicitTrunc: !firrtl.uint<1>, out %out_prettifyExample: !firrtl.uint<1>, out %out_multibitMux: !firrtl.uint<1>, out %out_mux2cell_0: !firrtl.uint<1>, out %out_mux2cell_1: !firrtl.uint<1>) {{.*}}{
; MLIR-NEXT: %cat_a, %cat_b, %cat_c, %cat_d = firrtl.instance cat @Cat(in a: !firrtl.uint<2>, in b: !firrtl.uint<2>, in c: !firrtl.uint<2>, out d: !firrtl.uint<6>)
; MLIR-NEXT: firrtl.strictconnect %cat_a, %b : !firrtl.uint<2>
; MLIR-NEXT: firrtl.strictconnect %cat_b, %b : !firrtl.uint<2>
Expand Down Expand Up @@ -136,7 +142,9 @@ circuit test_mod : %[[{"class": "circt.testNT", "data": "a"}]]
; VERILOG-NEXT: output c,
; VERILOG-NEXT: out_implicitTrunc,
; VERILOG-NEXT: out_prettifyExample,
; VERILOG-NEXT: out_multibitMux
; VERILOG-NEXT: out_multibitMux,
; VERILOG-NEXT: out_mux2cell_0,
; VERILOG-NEXT: out_mux2cell_1
; VERILOG-NEXT: );
; VERILOG-EMPTY:
; VERILOG-NEXT: wire [9:0] _prettifyExample_out1;
Expand Down Expand Up @@ -277,3 +285,40 @@ circuit test_mod : %[[{"class": "circt.testNT", "data": "a"}]]
input in : UInt<1>
output out : UInt<1>
out is invalid

intmodule Mux2:
input sel: UInt<1>
input high: UInt
input low: UInt
output out: UInt
intrinsic = circt_mux2cell

extmodule Val:
output v: UInt<1>
; VERILOG-LABEL: Mux2Cell
; Make sure that mux annotations are emitted properly.
; VERILOG: /* synopsys infer_mux_override */
; VERILOG-NEXT: assign [[OUT1:.+]] = mux2cell_in0 ? /* cadence map_to_mux */ mux2cell_in1 : mux2cell_in2;
; VERILOG: /* synopsys infer_mux_override */
; VERILOG-NEXT: assign [[OUT2:.+]] =
; VERILOG-NEXT: mux2cell_in0_0 ? /* cadence map_to_mux */ mux2cell_in1_0 : mux2cell_in2_0;
; VERILOG: assign out_0 = [[OUT1]];
; VERILOG: assign out_1 = [[OUT2]];

module Mux2Cell:
input cond: UInt<1>
input low: UInt<1>
output out: UInt<1>[2]
wire w: UInt<1>
inst mux2 of Mux2
inst ext of Val
w <= ext.v
mux2.sel <= xor(cond, UInt<1>(1))
mux2.high <= w
mux2.low <= low
out[0] <= mux2.out
inst mux2_1 of Mux2
mux2_1.sel <= xor(cond, UInt<1>(1))
mux2_1.high <= ext.v
mux2_1.low <= low
out[1] <= mux2_1.out

0 comments on commit b903f48

Please sign in to comment.