diff --git a/lib/Conversion/SeqToSV/FirRegLowering.cpp b/lib/Conversion/SeqToSV/FirRegLowering.cpp index 39765d5f2e5..9c500416efc 100644 --- a/lib/Conversion/SeqToSV/FirRegLowering.cpp +++ b/lib/Conversion/SeqToSV/FirRegLowering.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Debug.h" +#include using namespace circt; using namespace hw; @@ -20,34 +21,69 @@ using llvm::MapVector; #define DEBUG_TYPE "lower-seq-firreg" -// Reimplemented from SliceAnalysis to use a worklist rather than recursion and -// non-insert ordered set. -static void -getForwardSliceSimple(Operation *root, - llvm::DenseSet &forwardSlice, - llvm::function_ref filter = nullptr) { - SmallVector worklist({root}); +std::function OpUserInfo::opAllowsReachability = + [](const Operation *op) -> bool { + return (isa(op)); +}; + +bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp, + comb::MuxOp muxOp) { + return llvm::any_of(regOp.getResult().getUsers(), [&](Operation *user) { + if (!OpUserInfo::opAllowsReachability(user)) + return false; + buildReachabilityFrom(user); + return reachableMuxes[user].contains(muxOp); + }); +} - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); +void ReachableMuxes::buildReachabilityFrom(Operation *startNode) { + // This is a backward dataflow analysis. + // First build a graph rooted at the `startNode`. Every user of an operation + // that does not block the reachability is a child node. Then, the ops that + // are reachable from a node are computed as the union of the Reachability of + // all its child nodes. + // The dataflow can be expressed as: for every child in Children(node), + // Reachability(node) = node + Union{Reachability(child)} + if (visited.contains(startNode)) + return; - if (!op) - continue; + // The stack to record enough information for an iterative post-order + // traversal. 
+ llvm::SmallVector stk; - if (filter && !filter(op)) - continue; + stk.emplace_back(startNode); + + while (!stk.empty()) { + auto &info = stk.back(); + Operation *currentNode = info.op; + + // Node is being visited for the first time. + if (info.getAndSetUnvisited()) + visited.insert(currentNode); + + if (info.userIter != info.userEnd) { + Operation *child = *info.userIter; + ++info.userIter; + if (!visited.contains(child)) + stk.emplace_back(child); - for (Region &region : op->getRegions()) - for (Block &block : region) - for (Operation &blockOp : block) - if (forwardSlice.insert(&blockOp).second) - worklist.push_back(&blockOp); - for (Value result : op->getResults()) - for (Operation *userOp : result.getUsers()) - if (forwardSlice.insert(userOp).second) - worklist.push_back(userOp); - - forwardSlice.insert(op); + } else { // All children of the node have been visited + // Any op is reachable from itself. + reachableMuxes[currentNode].insert(currentNode); + + for (auto *childOp : llvm::make_filter_range( + info.op->getUsers(), OpUserInfo::opAllowsReachability)) { + reachableMuxes[currentNode].insert(childOp); + // Propagate the reachability backwards from childOp to currentNode. + auto iter = reachableMuxes.find(childOp); + assert(iter != reachableMuxes.end()); + + // Add all the muxes that were reachable from childOp to currentNode. + reachableMuxes[currentNode].insert(iter->getSecond().begin(), + iter->getSecond().end()); + } + stk.pop_back(); + } } } @@ -70,6 +106,17 @@ void FirRegLowering::addToIfBlock(OpBuilder &builder, Value cond, } } +FirRegLowering::FirRegLowering(TypeConverter &typeConverter, + hw::HWModuleOp module, + bool disableRegRandomization, + bool emitSeparateAlwaysBlocks) + : typeConverter(typeConverter), module(module), + disableRegRandomization(disableRegRandomization), + emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks) { + + reachableMuxes = std::make_unique(module); +} + void FirRegLowering::lower() { // Find all registers to lower in the module. 
auto regs = module.getOps(); @@ -358,10 +405,6 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term, // want to create if/else structure for logic unrelated to the register's // enable. auto firReg = term.getDefiningOp(); - DenseSet regMuxFanout; - getForwardSliceSimple(firReg, regMuxFanout, [&](Operation *op) { - return op == firReg || !isa(op); - }); SmallVector> worklist; auto addToWorklist = [&](Value reg, Value term, Value next) { @@ -389,7 +432,8 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term, // If this is a two-state mux within the fanout from the register, we use // if/else structure for proper enable inference. auto mux = next.getDefiningOp(); - if (mux && mux.getTwoState() && regMuxFanout.contains(mux)) { + if (mux && mux.getTwoState() && + reachableMuxes->isMuxReachableFrom(firReg, mux)) { addToIfBlock( builder, mux.getCond(), [&]() { addToWorklist(reg, term, mux.getTrueValue()); }, diff --git a/lib/Conversion/SeqToSV/FirRegLowering.h b/lib/Conversion/SeqToSV/FirRegLowering.h index a65c00f2b11..a78c406e53b 100644 --- a/lib/Conversion/SeqToSV/FirRegLowering.h +++ b/lib/Conversion/SeqToSV/FirRegLowering.h @@ -10,26 +10,82 @@ #ifndef CONVERSION_SEQTOSV_FIRREGLOWERING_H #define CONVERSION_SEQTOSV_FIRREGLOWERING_H +#include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/SV/SVOps.h" #include "circt/Dialect/Seq/SeqOps.h" #include "circt/Support/LLVM.h" #include "circt/Support/Namespace.h" #include "circt/Support/SymCache.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include +#include +#include namespace circt { + +using namespace hw; +// This class computes the set of muxes that are reachable from an op. +// The heuristic propagates the reachability only through the 3 ops, mux, +// array_create and array_get. All other ops block the reachability. 
+// This analysis is built lazily on every query. +// The query "is a mux reachable from a reg" results in a DFS traversal +// of the IR rooted at the register. This traversal is completed and the +// result is cached in a map, for faster retrieval on any future query of any +// op in this subgraph. +class ReachableMuxes { +public: + ReachableMuxes(HWModuleOp m) : module(m) {} + + bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp); + +private: + void buildReachabilityFrom(Operation *startNode); + HWModuleOp module; + llvm::DenseMap> reachableMuxes; + llvm::SmallPtrSet visited; +}; + +// The op, and the information about its users, that needs to be tracked on the +// stack for an iterative DFS traversal. +struct OpUserInfo { + Operation *op; + using ValidUsersIterator = + llvm::filter_iterator>; + + ValidUsersIterator userIter, userEnd; + static std::function opAllowsReachability; + + OpUserInfo(Operation *op) + : op(op), userIter(op->getUsers().begin(), op->getUsers().end(), + opAllowsReachability), + userEnd(op->getUsers().end(), op->getUsers().end(), + opAllowsReachability) {} + + bool getAndSetUnvisited() { + if (unvisited) { + unvisited = false; + return true; + } + return false; + } + +private: + bool unvisited = true; +}; + /// Lower FirRegOp to `sv.reg` and `sv.always`. 
class FirRegLowering { public: FirRegLowering(TypeConverter &typeConverter, hw::HWModuleOp module, bool disableRegRandomization = false, - bool emitSeparateAlwaysBlocks = false) - : typeConverter(typeConverter), module(module), - disableRegRandomization(disableRegRandomization), - emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks){}; + bool emitSeparateAlwaysBlocks = false); void lower(); - bool needsRegRandomization() const { return needsRandom; } unsigned numSubaccessRestored = 0; @@ -87,6 +143,7 @@ class FirRegLowering { llvm::SmallDenseMap constantCache; llvm::SmallDenseMap, Value> arrayIndexCache; + std::unique_ptr reachableMuxes; TypeConverter &typeConverter; hw::HWModuleOp module; diff --git a/test/Dialect/Seq/firreg.mlir b/test/Dialect/Seq/firreg.mlir index 4557ff2507a..16cfeb5f58b 100644 --- a/test/Dialect/Seq/firreg.mlir +++ b/test/Dialect/Seq/firreg.mlir @@ -290,16 +290,12 @@ hw.module private @InitReg1(in %clock: !seq.clock, in %reset: i1, in %io_d: i32, // COMMON-NEXT: %5 = comb.add %3, %4 : i33 // COMMON-NEXT: %6 = comb.extract %5 from 1 : (i33) -> i32 // COMMON-NEXT: %7 = comb.mux bin %io_en, %io_d, %6 : i32 - // COMMON-NEXT: sv.always posedge %clock, posedge %reset { + // COMMON-NEXT: sv.always posedge %clock, posedge %reset { // COMMON-NEXT: sv.if %reset { // COMMON-NEXT: sv.passign %reg, %c0_i32 : i32 // COMMON-NEXT: sv.passign %reg3, %c1_i32 : i32 // COMMON-NEXT: } else { - // COMMON-NEXT: sv.if %io_en { - // COMMON-NEXT: sv.passign %reg, %io_d : i32 - // COMMON-NEXT: } else { - // COMMON-NEXT: sv.passign %reg, %6 : i32 - // COMMON-NEXT: } + // COMMON-NEXT: sv.passign %reg, %7 : i32 // COMMON-NEXT: sv.passign %reg3, %2 : i32 // COMMON-NEXT: } // COMMON-NEXT: } @@ -915,3 +911,20 @@ hw.module @RegMuxInlining3(in %clock: !seq.clock, in %c: i1, out out: i8) { %0 = comb.mux bin %c, %r2, %r3 : i8 hw.output %r1 : i8 } + + // CHECK-LABEL: hw.module @SharedMux + hw.module @SharedMux(in %clock: !seq.clock, in %cond : i1, out o: i2){ + %mux = comb.mux bin 
%cond, %r1, %r2 : i2 + %r1 = seq.firreg %mux clock %clock : i2 + %r2 = seq.firreg %mux clock %clock : i2 + hw.output %r2: i2 + //CHECK: %r1 = sv.reg : !hw.inout + //CHECK: %[[V1:.+]] = sv.read_inout %r1 : !hw.inout + //CHECK: %r2 = sv.reg : !hw.inout + //CHECK: %[[V2:.+]] = sv.read_inout %r2 : !hw.inout + //CHECK: sv.always posedge %clock { + //CHECK: sv.if %cond { + //CHECK: sv.passign %r2, %[[V1]] : i2 + //CHECK: } else { + //CHECK: sv.passign %r1, %[[V2]] : i2 +}