Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LowerFirReg] Reimplement the mux reachability analysis #6709

Merged
merged 18 commits into from
Apr 9, 2024
Merged
102 changes: 73 additions & 29 deletions lib/Conversion/SeqToSV/FirRegLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
#include <cassert>

using namespace circt;
using namespace hw;
Expand All @@ -20,34 +21,69 @@ using llvm::MapVector;

#define DEBUG_TYPE "lower-seq-firreg"

// Reimplemented from SliceAnalysis to use a worklist rather than recursion and
// non-insert ordered set.
static void
getForwardSliceSimple(Operation *root,
llvm::DenseSet<Operation *> &forwardSlice,
llvm::function_ref<bool(Operation *)> filter = nullptr) {
SmallVector<Operation *> worklist({root});
std::function<bool(const Operation *op)> OpUserInfo::opAllowsReachability =
[](const Operation *op) -> bool {
return (isa<comb::MuxOp, ArrayGetOp, ArrayCreateOp>(op));
};

bool ReachableMuxes::isMuxReachableFrom(seq::FirRegOp regOp,
comb::MuxOp muxOp) {
return llvm::any_of(regOp.getResult().getUsers(), [&](Operation *user) {
if (!OpUserInfo::opAllowsReachability(user))
return false;
buildReachabilityFrom(user);
return reachableMuxes[user].contains(muxOp);
});
}

while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
void ReachableMuxes::buildReachabilityFrom(Operation *startNode) {
// This is a backward dataflow analysis.
// First build a graph rooted at the `startNode`. Every user of an operation
// that does not block the reachability is a child node. Then, the ops that
// are reachable from a node is computed as the union of the Reachability of
// all its child nodes.
// The dataflow can be expressed as, for all child in the Children(node)
// Reachability(node) = node + Union{Reachability(child)}
if (visited.contains(startNode))
return;

if (!op)
continue;
// The stack to record enough information for an iterative post-order
// traversal.
llvm::SmallVector<OpUserInfo> stk;

if (filter && !filter(op))
continue;
stk.emplace_back(startNode);

while (!stk.empty()) {
auto &info = stk.back();
Operation *currentNode = info.op;

// Node is being visited for the first time.
if (info.getAndSetUnvisited())
visited.insert(currentNode);
prithayan marked this conversation as resolved.
Show resolved Hide resolved

if (info.userIter != info.userEnd) {
Operation *child = *info.userIter;
++info.userIter;
if (!visited.contains(child))
stk.emplace_back(child);

for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &blockOp : block)
if (forwardSlice.insert(&blockOp).second)
worklist.push_back(&blockOp);
for (Value result : op->getResults())
for (Operation *userOp : result.getUsers())
if (forwardSlice.insert(userOp).second)
worklist.push_back(userOp);

forwardSlice.insert(op);
} else { // All children of the node have been visited
// Any op is reachable from itself.
reachableMuxes[currentNode].insert(currentNode);

for (auto *childOp : llvm::make_filter_range(
info.op->getUsers(), OpUserInfo::opAllowsReachability)) {
reachableMuxes[currentNode].insert(childOp);
// Propagate the reachability backwards from m to currentNode.
auto iter = reachableMuxes.find(childOp);
assert(iter != reachableMuxes.end());

// Add all the mux that was reachable from childOp, to currentNode.
reachableMuxes[currentNode].insert(iter->getSecond().begin(),
iter->getSecond().end());
}
stk.pop_back();
}
}
}

Expand All @@ -70,6 +106,17 @@ void FirRegLowering::addToIfBlock(OpBuilder &builder, Value cond,
}
}

FirRegLowering::FirRegLowering(TypeConverter &typeConverter,
hw::HWModuleOp module,
bool disableRegRandomization,
bool emitSeparateAlwaysBlocks)
: typeConverter(typeConverter), module(module),
disableRegRandomization(disableRegRandomization),
emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks) {

reachableMuxes = std::make_unique<ReachableMuxes>(module);
prithayan marked this conversation as resolved.
Show resolved Hide resolved
}

void FirRegLowering::lower() {
// Find all registers to lower in the module.
auto regs = module.getOps<seq::FirRegOp>();
Expand Down Expand Up @@ -358,10 +405,6 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term,
// want to create if/else structure for logic unrelated to the register's
// enable.
auto firReg = term.getDefiningOp<seq::FirRegOp>();
DenseSet<Operation *> regMuxFanout;
getForwardSliceSimple(firReg, regMuxFanout, [&](Operation *op) {
return op == firReg || !isa<sv::RegOp, seq::FirRegOp, hw::InstanceOp>(op);
});

SmallVector<std::tuple<Block *, Value, Value, Value>> worklist;
auto addToWorklist = [&](Value reg, Value term, Value next) {
Expand Down Expand Up @@ -389,7 +432,8 @@ void FirRegLowering::createTree(OpBuilder &builder, Value reg, Value term,
// If this is a two-state mux within the fanout from the register, we use
// if/else structure for proper enable inference.
auto mux = next.getDefiningOp<comb::MuxOp>();
if (mux && mux.getTwoState() && regMuxFanout.contains(mux)) {
if (mux && mux.getTwoState() &&
reachableMuxes->isMuxReachableFrom(firReg, mux)) {
addToIfBlock(
builder, mux.getCond(),
[&]() { addToWorklist(reg, term, mux.getTrueValue()); },
Expand Down
67 changes: 62 additions & 5 deletions lib/Conversion/SeqToSV/FirRegLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,82 @@
#ifndef CONVERSION_SEQTOSV_FIRREGLOWERING_H
#define CONVERSION_SEQTOSV_FIRREGLOWERING_H

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/SV/SVOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Namespace.h"
#include "circt/Support/SymCache.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <mlir/IR/ValueRange.h>
#include <stack>
#include <unordered_set>

namespace circt {

using namespace hw;
// This class computes the set of muxes that are reachable from an op.
// The heuristic propagates the reachability only through the 3 ops, mux,
// array_create and array_get. All other ops block the reachability.
// This analysis is built lazily on every query.
// The query: is a mux is reachable from a reg, results in a DFS traversal
// of the IR rooted at the register. This traversal is completed and the
// result is cached in a Map, for faster retrieval on any future query of any
// op in this subgraph.
class ReachableMuxes {
prithayan marked this conversation as resolved.
Show resolved Hide resolved
public:
ReachableMuxes(HWModuleOp m) : module(m) {}

bool isMuxReachableFrom(seq::FirRegOp regOp, comb::MuxOp muxOp);

private:
void buildReachabilityFrom(Operation *startNode);
HWModuleOp module;
llvm::DenseMap<Operation *, llvm::SmallDenseSet<Operation *>> reachableMuxes;
llvm::SmallPtrSet<Operation *, 16> visited;
};

// The op and its users information that needs to be tracked on the stack
// for an iterative DFS traversal.
struct OpUserInfo {
Operation *op;
using ValidUsersIterator =
llvm::filter_iterator<Operation::user_iterator,
std::function<bool(const Operation *)>>;

ValidUsersIterator userIter, userEnd;
static std::function<bool(const Operation *op)> opAllowsReachability;

OpUserInfo(Operation *op)
: op(op), userIter(op->getUsers().begin(), op->getUsers().end(),
opAllowsReachability),
userEnd(op->getUsers().end(), op->getUsers().end(),
opAllowsReachability) {}

bool getAndSetUnvisited() {
if (unvisited) {
unvisited = false;
return true;
}
return false;
}

private:
bool unvisited = true;
};

/// Lower FirRegOp to `sv.reg` and `sv.always`.
class FirRegLowering {
public:
FirRegLowering(TypeConverter &typeConverter, hw::HWModuleOp module,
bool disableRegRandomization = false,
bool emitSeparateAlwaysBlocks = false)
: typeConverter(typeConverter), module(module),
disableRegRandomization(disableRegRandomization),
emitSeparateAlwaysBlocks(emitSeparateAlwaysBlocks){};
bool emitSeparateAlwaysBlocks = false);

void lower();

bool needsRegRandomization() const { return needsRandom; }

unsigned numSubaccessRestored = 0;
Expand Down Expand Up @@ -87,6 +143,7 @@ class FirRegLowering {

llvm::SmallDenseMap<APInt, hw::ConstantOp> constantCache;
llvm::SmallDenseMap<std::pair<Value, unsigned>, Value> arrayIndexCache;
std::unique_ptr<ReachableMuxes> reachableMuxes;

TypeConverter &typeConverter;
hw::HWModuleOp module;
Expand Down
25 changes: 19 additions & 6 deletions test/Dialect/Seq/firreg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,12 @@ hw.module private @InitReg1(in %clock: !seq.clock, in %reset: i1, in %io_d: i32,
// COMMON-NEXT: %5 = comb.add %3, %4 : i33
// COMMON-NEXT: %6 = comb.extract %5 from 1 : (i33) -> i32
// COMMON-NEXT: %7 = comb.mux bin %io_en, %io_d, %6 : i32
// COMMON-NEXT: sv.always posedge %clock, posedge %reset {
// COMMON-NEXT: sv.always posedge %clock, posedge %reset {
// COMMON-NEXT: sv.if %reset {
// COMMON-NEXT: sv.passign %reg, %c0_i32 : i32
// COMMON-NEXT: sv.passign %reg3, %c1_i32 : i32
// COMMON-NEXT: } else {
// COMMON-NEXT: sv.if %io_en {
// COMMON-NEXT: sv.passign %reg, %io_d : i32
// COMMON-NEXT: } else {
// COMMON-NEXT: sv.passign %reg, %6 : i32
// COMMON-NEXT: }
// COMMON-NEXT: sv.passign %reg, %7 : i32
// COMMON-NEXT: sv.passign %reg3, %2 : i32
// COMMON-NEXT: }
// COMMON-NEXT: }
Expand Down Expand Up @@ -915,3 +911,20 @@ hw.module @RegMuxInlining3(in %clock: !seq.clock, in %c: i1, out out: i8) {
%0 = comb.mux bin %c, %r2, %r3 : i8
hw.output %r1 : i8
}

// CHECK-LABEL: hw.module @SharedMux
hw.module @SharedMux(in %clock: !seq.clock, in %cond : i1, out o: i2){
%mux = comb.mux bin %cond, %r1, %r2 : i2
%r1 = seq.firreg %mux clock %clock : i2
%r2 = seq.firreg %mux clock %clock : i2
hw.output %r2: i2
//CHECK: %r1 = sv.reg : !hw.inout<i2>
//CHECK: %[[V1:.+]] = sv.read_inout %r1 : !hw.inout<i2>
//CHECK: %r2 = sv.reg : !hw.inout<i2>
//CHECK: %[[V2:.+]] = sv.read_inout %r2 : !hw.inout<i2>
//CHECK: sv.always posedge %clock {
//CHECK: sv.if %cond {
//CHECK: sv.passign %r2, %[[V1]] : i2
//CHECK: } else {
//CHECK: sv.passign %r1, %[[V2]] : i2
}