From 2205b3cc0996b34365705785db76000418f8c427 Mon Sep 17 00:00:00 2001 From: Prithayan Barua Date: Wed, 13 Mar 2024 10:18:36 -0400 Subject: [PATCH] [InferRW] Remove dependence of WMODE on EN --- .../FIRRTL/Transforms/InferReadWrite.cpp | 66 +++++++++++++++++++ test/Dialect/FIRRTL/inferRW.mlir | 20 ++++++ 2 files changed, 86 insertions(+) diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 39087e5f3ad6..14824b74c10a 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -45,6 +45,7 @@ struct InferReadWritePass : public InferReadWriteBase { for (MemOp memOp : llvm::make_early_inc_range( getOperation().getBodyBlock()->getOps())) { inferUnmasked(memOp, opsToErase); + simplifyWmode(memOp); size_t nReads, nWrites, nRWs, nDbgs; memOp.getNumPorts(nReads, nWrites, nRWs, nDbgs); // Run the analysis only for Seq memories (latency=1) and a single read @@ -209,6 +210,7 @@ struct InferReadWritePass : public InferReadWriteBase { opsToErase.push_back(sf); } } + simplifyWmode(rwMem); // All uses for all results of mem removed, now erase the memOp. opsToErase.push_back(memOp); } @@ -312,6 +314,70 @@ struct InferReadWritePass : public InferReadWriteBase { return {}; } + // Remove redundant dependence of wmode on the enable signal. wmode can assume + // the enable signal be true. + void simplifyWmode(MemOp &memOp) { + + // Iterate over all results, and find the enable and wmode fields of the RW + // port. + for (const auto &portIt : llvm::enumerate(memOp.getResults())) { + auto portKind = memOp.getPortKind(portIt.index()); + if (portKind != MemOp::PortKind::ReadWrite) + continue; + Value enableDriver, wmodeDriver; + Value portVal = portIt.value(); + // Iterate over all users of the rw port. + for (Operation *u : portVal.getUsers()) + if (auto sf = dyn_cast(u)) { + // Get the field name. + auto fName = + sf.getInput().getType().base().getElementName(sf.getFieldIndex()); + // Record the enable and wmode fields. + if (fName.contains("en")) + enableDriver = getConnectSrc(sf.getResult()); + if (fName.contains("wmode")) + wmodeDriver = getConnectSrc(sf.getResult()); + } + + if (enableDriver && wmodeDriver) { + ImplicitLocOpBuilder builder(memOp.getLoc(), memOp); + auto constOne = builder.create( + UIntType::get(builder.getContext(), 1), APInt(1, 1)); + setEnable(enableDriver, wmodeDriver, constOne); + } + } + } + + // Replace any occurence of enable on the expression tree of wmode with a + // constant one. + void setEnable(Value enableDriver, Value wmodeDriver, Value constOne) { + auto getDriverOp = [&](Value dst) -> Operation * { + // Look through one level of wire to get the driver op. + auto *defOp = dst.getDefiningOp(); + if (defOp) { + if (isa(defOp)) + dst = getConnectSrc(dst); + if (dst) + defOp = dst.getDefiningOp(); + } + return defOp; + }; + SmallVector stack; + stack.push_back(wmodeDriver); + while (!stack.empty()) { + auto driver = stack.pop_back_val(); + auto *defOp = getDriverOp(driver); + if (!defOp) + continue; + for (auto operand : llvm::enumerate(defOp->getOperands())) { + if (operand.value() == enableDriver) + defOp->setOperand(operand.index(), constOne); + else + stack.push_back(operand.value()); + } + } + } + void inferUnmasked(MemOp &memOp, SmallVector &opsToErase) { bool isMasked = true; diff --git a/test/Dialect/FIRRTL/inferRW.mlir b/test/Dialect/FIRRTL/inferRW.mlir index 018df6ab68d4..1c3e959f74db 100644 --- a/test/Dialect/FIRRTL/inferRW.mlir +++ b/test/Dialect/FIRRTL/inferRW.mlir @@ -282,5 +282,25 @@ firrtl.circuit "TLRAM" { firrtl.connect %auto_0, %11 : !firrtl.uint<8>, !firrtl.uint<8> } + + // CHECK-LABEL: firrtl.module @SimplifyWMODE + firrtl.module @SimplifyWMODE(in %rwPort_enable: !firrtl.uint<1>, in %rwPort_isWrite: !firrtl.uint<1>) attributes {} { + %mem_rwPort_readData_rw = firrtl.mem Undefined {depth = 64 : i64, name = "t", portNames = ["rw"], prefix = "", readLatency = 1 : i32, writeLatency = 1 : i32} : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %mem_rwPort_readData_rw_wmode = firrtl.wire : !firrtl.uint<1> + %0 = firrtl.subfield %mem_rwPort_readData_rw[addr] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %1 = firrtl.subfield %mem_rwPort_readData_rw[en] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + firrtl.strictconnect %1, %rwPort_enable : !firrtl.uint<1> + %2 = firrtl.subfield %mem_rwPort_readData_rw[clk] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %3 = firrtl.subfield %mem_rwPort_readData_rw[rdata] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %6 = firrtl.subfield %mem_rwPort_readData_rw[wmode] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + firrtl.strictconnect %6, %mem_rwPort_readData_rw_wmode : !firrtl.uint<1> + %7 = firrtl.subfield %mem_rwPort_readData_rw[wdata] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %9 = firrtl.subfield %mem_rwPort_readData_rw[wmask] : !firrtl.bundle, en: uint<1>, clk: clock, rdata flip: uint<10>, wmode: uint<1>, wdata: uint<10>, wmask: uint<5>> + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %18 = firrtl.mux(%rwPort_enable, %rwPort_isWrite, %c0_ui1) : (!firrtl.uint<1>, !firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1> + // CHECK: %[[c1_ui1:.+]] = firrtl.constant 1 : !firrtl.uint<1> + // CHECK: %[[v7:.+]] = firrtl.mux(%[[c1_ui1]], %rwPort_isWrite, %c0_ui1) + firrtl.strictconnect %mem_rwPort_readData_rw_wmode, %18 : !firrtl.uint<1> + } }