[MLIR] Generalize affine fusion to work on Block instead of FuncOp
The affine fusion pass can work at the top level of a `Block` and doesn't
need to be called on a `FuncOp`. Remove this restriction and generalize
the pass to work on any `Block`. This allows fusion to be performed, for
example, on the multiple blocks of a FuncOp, on any region-holding op
like scf.while or scf.if, or even at an inner depth of an affine.for or
affine.if op. This generalization has no effect on existing
functionality: no changes to the fusion logic or its transformational
power were needed.

Update the fusion pass to be a generic operation pass (instead of a
FuncOp pass) and remove references to, and assumptions about, the parent
being a FuncOp.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D139293
bondhugula committed Dec 14, 2022
1 parent dc44acc commit fe9d0a4
Showing 8 changed files with 148 additions and 97 deletions.
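
For context, here is a minimal sketch (not part of this commit; the helper name and pipeline setup are illustrative) of scheduling the pass through the relaxed factory type:

    #include "mlir/Dialect/Affine/Passes.h"
    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/Pass/PassManager.h"

    // createLoopFusionPass() now returns std::unique_ptr<Pass> instead of
    // OperationPass<func::FuncOp>, so the pass is no longer pinned to
    // functions; nesting it on func::FuncOp reproduces the old behavior.
    void buildFusionPipeline(mlir::PassManager &pm) {
      pm.addNestedPass<mlir::func::FuncOp>(mlir::createLoopFusionPass(
          /*fastMemorySpace=*/0, /*localBufSizeThreshold=*/0,
          /*maximalFusion=*/false));
    }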
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Affine/Passes.h
@@ -73,10 +73,11 @@ createAffineScalarReplacementPass();
/// bounds into a single loop.
std::unique_ptr<OperationPass<func::FuncOp>> createLoopCoalescingPass();

/// Creates a loop fusion pass which fuses loops according to type of fusion
/// Creates a loop fusion pass which fuses affine loop nests at the top-level of
/// the operation the pass is created on according to the type of fusion
/// specified in `fusionMode`. Buffers of size less than or equal to
/// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`.
std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<Pass>
createLoopFusionPass(unsigned fastMemorySpace = 0,
uint64_t localBufSizeThreshold = 0,
bool maximalFusion = false,
30 changes: 16 additions & 14 deletions mlir/include/mlir/Dialect/Affine/Passes.td
@@ -43,22 +43,24 @@ def AffineDataCopyGeneration : Pass<"affine-data-copy-generate", "func::FuncOp">
];
}

def AffineLoopFusion : Pass<"affine-loop-fusion", "func::FuncOp"> {
def AffineLoopFusion : Pass<"affine-loop-fusion"> {
let summary = "Fuse affine loop nests";
let description = [{
This pass performs fusion of loop nests using a slicing-based approach. It
combines two fusion strategies: producer-consumer fusion and sibling fusion.
Producer-consumer fusion is aimed at fusing pairs of loops where the first
one writes to a memref that the second reads. Sibling fusion targets pairs
of loops that share no dependences between them but that load from the same
memref. The fused loop nests, when possible, are rewritten to access
significantly smaller local buffers instead of the original memref's, and
the latter are often either completely optimized away or contracted. This
transformation leads to enhanced locality and lower memory footprint through
the elimination or contraction of temporaries/intermediate memref's. These
benefits are sometimes achieved at the expense of redundant computation
through a cost model that evaluates available choices such as the depth at
which a source slice should be materialized in the destination slice.
This pass performs fusion of loop nests using a slicing-based approach. The
transformation works at the granularity of an MLIR `Block` and applies to all
blocks of the operation the pass is run on. It combines two fusion strategies:
producer-consumer fusion and sibling fusion. Producer-consumer fusion is
aimed at fusing pairs of loops where the first one writes to a memref that
the second reads. Sibling fusion targets pairs of loops that share no
dependences between them but that load from the same memref. The fused loop
nests, when possible, are rewritten to access significantly smaller local
buffers instead of the original memref's, and the latter are often either
completely optimized away or contracted. This transformation leads to
enhanced locality and lower memory footprint through the elimination or
contraction of temporaries/intermediate memref's. These benefits are
sometimes achieved at the expense of redundant computation through a cost
model that evaluates available choices such as the depth at which a source
slice should be materialized in the destination slice.

Example 1: Producer-consumer fusion.
Input:
152 changes: 84 additions & 68 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements loop fusion.
// This file implements affine fusion.
//
//===----------------------------------------------------------------------===//

@@ -19,7 +19,6 @@
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
@@ -65,12 +64,13 @@ struct LoopFusion : public impl::AffineLoopFusionBase<LoopFusion> {
this->affineFusionMode = affineFusionMode;
}

void runOnBlock(Block *block);
void runOnOperation() override;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<Pass>
mlir::createLoopFusionPass(unsigned fastMemorySpace,
uint64_t localBufSizeThreshold, bool maximalFusion,
enum FusionMode affineFusionMode) {
@@ -104,7 +104,7 @@ struct LoopNestStateCollector {
};

// MemRefDependenceGraph is a graph data structure where graph nodes are
// top-level operations in a FuncOp which contain load/store ops, and edges
// top-level operations in a `Block` which contain load/store ops, and edges
// are memref dependences between the nodes.
// TODO: Add a more flexible dependence graph representation.
// TODO: Add a depth parameter to dependence graph construction.
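// A hedged usage sketch (the helper below is hypothetical, not part of this
// diff): the graph is now built per `Block` rather than per `FuncOp`,
// mirroring the pass driver further down (see LoopFusion::runOnBlock).
static void analyzeBlock(Block *block) {
  MemRefDependenceGraph g(*block);
  if (!g.init(block))
    return; // init fails on blocks the analysis cannot model
  g.dump(); // nodes are top-level loop nests / memref ops; edges are
            // memref dependences
}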
@@ -207,11 +207,11 @@ struct MemRefDependenceGraph {
// The next unique identifier to use for newly created graph nodes.
unsigned nextNodeId = 0;

MemRefDependenceGraph() = default;
MemRefDependenceGraph(Block &block) : block(block) {}

// Initializes the dependence graph based on operations in 'f'.
// Returns true on success, false otherwise.
bool init(func::FuncOp f);
bool init(Block *block);

// Returns the graph node for 'id'.
Node *getNode(unsigned id) {
@@ -258,7 +258,7 @@ struct MemRefDependenceGraph {
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the function/block. Returns false otherwise.
// argument to) the block. Returns false otherwise.
bool writesToLiveInOrEscapingMemrefs(unsigned id) {
Node *node = getNode(id);
for (auto *storeOpInst : node->stores) {
@@ -267,7 +267,8 @@
// Return true if 'memref' is a block argument.
if (!op)
return true;
// Return true if any use of 'memref' escapes the function.
// Return true if any use of 'memref' does not dereference it in an affine
// way.
for (auto *user : memref.getUsers())
if (!isa<AffineMapAccessInterface>(*user))
return true;
@@ -597,6 +598,9 @@
}
}
void dump() const { print(llvm::errs()); }

/// The block for which this graph is created to perform fusion.
Block &block;
};

/// Returns true if node 'srcId' can be removed after fusing it with node
@@ -710,21 +714,22 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
producerConsumerMemrefs);
}

/// A memref escapes the function if either:
/// A memref escapes in the context of the fusion pass if either:
/// 1. it (or its alias) is a block argument, or
/// 2. created by an op not known to guarantee alias freedom,
/// 3. it (or its alias) is used by a non-affine op (e.g., call op, memref
/// load/store ops, alias creating ops, unknown ops, etc.); such ops
/// do not dereference the memref in an affine way.
static bool isEscapingMemref(Value memref) {
/// 3. it (or its alias) is used by ops other than affine dereferencing ops
/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops,
/// terminator ops, etc.); such ops do not dereference the memref in an affine
/// way.
static bool isEscapingMemref(Value memref, Block *block) {
Operation *defOp = memref.getDefiningOp();
// Check if 'memref' is a block argument.
if (!defOp)
return true;

// Check if this is defined to be an alias of another memref.
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp))
if (isEscapingMemref(viewOp.getViewSource()))
if (isEscapingMemref(viewOp.getViewSource(), block))
return true;

// Any op besides allocating ops wouldn't guarantee alias freedom
@@ -733,44 +738,44 @@ static bool isEscapingMemref(Value memref) {

// Check if 'memref' is used by a non-dereferencing op (including unknown ones)
// (e.g., call ops, alias creating ops, etc.).
for (Operation *user : memref.getUsers())
for (Operation *user : memref.getUsers()) {
// Ignore users outside of `block`.
if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
continue;
if (!isa<AffineMapAccessInterface>(*user))
return true;
}
return false;
}
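// A hypothetical IR illustration (ours, hedged) of the rules above:
//   %m = memref.alloc() : memref<8xf32>
//   func.call @use(%m) : (memref<8xf32>) -> ()  // non-affine user in block
//   %p = memref.alloc() : memref<8xf32>
//   affine.store %cst, %p[0] : memref<8xf32>    // affine dereference only
// isEscapingMemref(%m, block) returns true because of the call, while %p,
// whose only in-block user is an affine access, does not escape.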

/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
/// that escape the function or are accessed by non-affine ops.
/// that escape the block or are accessed in a non-affine way.
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
DenseSet<Value> &escapingMemRefs) {
auto *node = mdg->getNode(id);
for (Operation *storeOp : node->stores) {
auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
if (escapingMemRefs.count(memref))
continue;
if (isEscapingMemref(memref))
if (isEscapingMemref(memref, &mdg->block))
escapingMemRefs.insert(memref);
}
}

} // namespace

// Initializes the data dependence graph by walking operations in 'f'.
// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in `block`.
// TODO: Add support for taking a Block arg to construct the
// dependence graph at a different depth.
bool MemRefDependenceGraph::init(func::FuncOp f) {
bool MemRefDependenceGraph::init(Block *block) {
LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
// Map from a memref to the set of ids of the nodes that have ops accessing
// the memref.
DenseMap<Value, SetVector<unsigned>> memrefAccesses;

// TODO: support multi-block functions.
if (!llvm::hasSingleElement(f))
return false;

DenseMap<Operation *, unsigned> forToNodeMap;
for (auto &op : f.front()) {
for (Operation &op : *block) {
if (auto forOp = dyn_cast<AffineForOp>(op)) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
@@ -845,14 +850,18 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
// Stores don't define SSA values, skip them.
if (!node.stores.empty())
continue;
auto *opInst = node.op;
for (auto value : opInst->getResults()) {
for (auto *user : value.getUsers()) {
Operation *opInst = node.op;
for (Value value : opInst->getResults()) {
for (Operation *user : value.getUsers()) {
// Ignore users outside of the block.
if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() !=
block)
continue;
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
if (loops.empty())
continue;
assert(forToNodeMap.count(loops[0]) > 0);
assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping");
unsigned userLoopNestId = forToNodeMap[loops[0]];
addEdge(node.id, userLoopNestId, value);
}
@@ -918,7 +927,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// Create builder to insert alloc op just before 'forOp'.
OpBuilder b(forInst);
// Builder to create constants at the top level.
OpBuilder top(forInst->getParentOfType<func::FuncOp>().getBody());
OpBuilder top(forInst->getParentRegion());
// Create new memref type based on slice bounds.
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
@@ -979,7 +988,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
// a constant shape.
// TODO: Create/move alloc ops for private memrefs closer to their
// consumer loop nests to reduce their live range. Currently they are added
// at the beginning of the function, because loop nests can be reordered
// at the beginning of the block, because loop nests can be reordered
// during the fusion pass.
Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
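// A hedged side note (hypothetical helper, not from this commit): an
// OpBuilder constructed from a Region has its insertion point at the start
// of the region's entry block; that is how `top` above places private
// allocs at the beginning of the block being fused. Sketch:
//   static memref::AllocOp allocAtRegionStart(Region *region,
//                                             MemRefType type, Location loc) {
//     OpBuilder top(region);
//     return top.create<memref::AllocOp>(loc, type);
//   }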

@@ -1508,8 +1517,8 @@ struct GreedyFusion {
}))
continue;

// Gather memrefs in 'srcNode' that are written and escape to the
// function (e.g., memref function arguments, returned memrefs,
// Gather memrefs in 'srcNode' that are written and escape out of the
// block (e.g., memref block arguments, returned memrefs,
// memrefs passed to function calls, etc.).
DenseSet<Value> srcEscapingMemRefs;
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
@@ -1829,7 +1838,7 @@ struct GreedyFusion {
}
}

// Searches block argument uses and the graph from 'dstNode' looking for a
// Searches block argument uses and the graph from 'dstNode' looking for a
// fusion candidate sibling node which shares no dependences with 'dstNode'
// but which loads from the same memref. Returns true and sets
// 'idAndMemrefToFuse' on success. Returns false otherwise.
@@ -1874,36 +1883,37 @@ struct GreedyFusion {
return true;
};

// Search for siblings which load the same memref function argument.
auto fn = dstNode->op->getParentOfType<func::FuncOp>();
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto *user : fn.getArgument(i).getUsers()) {
if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
// Skip 'use' if it is not within a loop nest.
if (loops.empty())
continue;
Node *sibNode = mdg->getForOpNode(loops[0]);
assert(sibNode != nullptr);
// Skip 'use' if it is not a sibling to 'dstNode'.
if (sibNode->id == dstNode->id)
continue;
// Skip 'use' if it has been visited.
if (visitedSibNodeIds->count(sibNode->id) > 0)
continue;
// Skip 'use' if it does not load from the same memref as 'dstNode'.
auto memref = loadOp.getMemRef();
if (dstNode->getLoadOpCount(memref) == 0)
continue;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
if (canFuseWithSibNode(sibNode, memref)) {
visitedSibNodeIds->insert(sibNode->id);
idAndMemrefToFuse->first = sibNode->id;
idAndMemrefToFuse->second = memref;
return true;
}
// Search for siblings which load the same memref block argument.
Block *block = dstNode->op->getBlock();
for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) {
for (Operation *user : block->getArgument(i).getUsers()) {
auto loadOp = dyn_cast<AffineReadOpInterface>(user);
if (!loadOp)
continue;
// Gather loops surrounding 'use'.
SmallVector<AffineForOp, 4> loops;
getLoopIVs(*user, &loops);
// Skip 'use' if it is not within a loop nest.
if (loops.empty())
continue;
Node *sibNode = mdg->getForOpNode(loops[0]);
assert(sibNode != nullptr);
// Skip 'use' if it is not a sibling to 'dstNode'.
if (sibNode->id == dstNode->id)
continue;
// Skip 'use' if it has been visited.
if (visitedSibNodeIds->count(sibNode->id) > 0)
continue;
// Skip 'use' if it does not load from the same memref as 'dstNode'.
auto memref = loadOp.getMemRef();
if (dstNode->getLoadOpCount(memref) == 0)
continue;
// Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'.
if (canFuseWithSibNode(sibNode, memref)) {
visitedSibNodeIds->insert(sibNode->id);
idAndMemrefToFuse->first = sibNode->id;
idAndMemrefToFuse->second = memref;
return true;
}
}
}
@@ -1968,8 +1978,7 @@ struct GreedyFusion {
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
dstLoopCollector.storeOpInsts);
// Remove old sibling loop nest if it no longer has outgoing dependence
// edges, and it does not write to a memref which escapes the
// function.
// edges, and it does not write to a memref which escapes the block.
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
Operation *op = sibNode->op;
mdg->removeNode(sibNode->id);
Expand All @@ -1996,9 +2005,10 @@ struct GreedyFusion {

} // namespace

void LoopFusion::runOnOperation() {
MemRefDependenceGraph g;
if (!g.init(getOperation()))
/// Run fusion on `block`.
void LoopFusion::runOnBlock(Block *block) {
MemRefDependenceGraph g(*block);
if (!g.init(block))
return;

Optional<unsigned> fastMemorySpaceOpt;
@@ -2015,3 +2025,9 @@ void LoopFusion::runOnOperation() {
else
fusion.runGreedyFusion();
}

void LoopFusion::runOnOperation() {
for (Region &region : getOperation()->getRegions())
for (Block &block : region.getBlocks())
runOnBlock(&block);
}
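
A hedged driver sketch (setup assumed: an existing MLIRContext `context` and a parsed ModuleOp `module`); the textual flag "affine-loop-fusion" comes from Passes.td above, and because the pass is now generic it can also be anchored under ops other than func.func:

    mlir::PassManager pm(&context);
    if (mlir::failed(
            mlir::parsePassPipeline("func.func(affine-loop-fusion)", pm)))
      llvm::errs() << "failed to parse pipeline\n";
    else if (mlir::failed(pm.run(module)))
      llvm::errs() << "fusion failed\n";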
