Skip to content

Commit 11058ac

Browse files
committed
[MLIR][Affine] Check dependences during MDG init
Check affine dependences precisely during MDG init before adding edges. We were conservatively only checking for memref-level conflicts. Leads to more/better fusion as a result. Fixes: #156421
1 parent e09fe9b commit 11058ac

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,17 @@ struct MemRefDependenceGraph {
153153

154154
MemRefDependenceGraph(Block &block) : block(block) {}
155155

156-
// Initializes the dependence graph based on operations in `block'.
157-
// Returns true on success, false otherwise.
158-
bool init();
156+
// Initializes the data dependence graph by walking operations in the MDG's
157+
// `block`. A `Node` is created for every top-level op except for
158+
// side-effect-free operations with zero results and no regions. Assigns each
159+
// node in the graph a node id based on the order in block. Fails if certain
160+
// kinds of operations, for which `Node` creation isn't supported, are
161+
// encountered (unknown region holding ops). If `fullAffineDependences` is
162+
// set, affine memory dependence analysis is performed before concluding that
163+
conflicting affine memory accesses lead to a dependence; otherwise, a
164+
// pair of conflicting affine memory accesses (where one of them is a store
165+
// and they are to the same memref) always leads to an edge (conservatively).
166+
bool init(bool fullAffineDependences = true);
159167

160168
// Returns the graph node for 'id'.
161169
const Node *getNode(unsigned id) const;

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,55 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
241241
return &node;
242242
}
243243

244-
bool MemRefDependenceGraph::init() {
244+
/// Returns true if there may be a dependence on `memref` from srcNode's
245+
/// memory ops to dstNode's memory ops, using the affine memory
246+
/// dependence analysis checks. The method assumes that there is at least one
247+
/// memory op in srcNode's loads and stores on `memref`, and similarly for
248+
/// `dstNode`. `srcNode.op` and `dstNode.op` are expected to be nested in the
249+
/// same block and so the dependences are tested at the depth of that block.
250+
static bool mayDependence(const Node &srcNode, const Node &dstNode,
251+
Value memref) {
252+
assert(srcNode.op->getBlock() == dstNode.op->getBlock());
253+
if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
254+
return true;
255+
256+
// Non-affine stores; can't check precisely. Conservatively, return true.
257+
if (!srcNode.memrefStores.empty())
258+
return true;
259+
if (!dstNode.memrefStores.empty())
260+
return true;
261+
262+
// Non-affine loads with a store in the other.
263+
if (!srcNode.memrefLoads.empty() && !dstNode.stores.empty())
264+
return true;
265+
if (!dstNode.memrefLoads.empty() && !srcNode.stores.empty())
266+
return true;
267+
268+
// Affine load/store pairs. We don't need to check for locally allocated
269+
// memrefs since the dependence analysis here is between mem ops from
270+
// srcNode's for op to dstNode's for op at the depth at which those
271+
// `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
272+
// `d` is the number of common surrounding loops.
273+
for (auto *srcMemOp :
274+
llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
275+
MemRefAccess srcAcc(srcMemOp);
276+
if (srcAcc.memref != memref)
277+
continue;
278+
for (auto *destMemOp :
279+
llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
280+
MemRefAccess destAcc(destMemOp);
281+
if (destAcc.memref != memref)
282+
continue;
283+
// Check for a top-level dependence between srcNode and destNode's ops.
284+
if (!noDependence(checkMemrefAccessDependence(
285+
srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
286+
return true;
287+
}
288+
}
289+
return false;
290+
}
291+
292+
bool MemRefDependenceGraph::init(bool fullAffineDependences) {
245293
LDBG() << "--- Initializing MDG ---";
246294
// Map from a memref to the set of ids of the nodes that have ops accessing
247295
// the memref.
@@ -344,8 +392,12 @@ bool MemRefDependenceGraph::init() {
344392
Node *dstNode = getNode(dstId);
345393
bool dstHasStoreOrFree =
346394
dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
347-
if (srcHasStoreOrFree || dstHasStoreOrFree)
348-
addEdge(srcId, dstId, srcMemRef);
395+
if ((srcHasStoreOrFree || dstHasStoreOrFree)) {
396+
// Check precise affine deps if asked for; otherwise, conservative.
397+
if (!fullAffineDependences ||
398+
mayDependence(*srcNode, *dstNode, srcMemRef))
399+
addEdge(srcId, dstId, srcMemRef);
400+
}
349401
}
350402
}
351403
}
@@ -562,13 +614,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
562614
}
563615

564616
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
565-
SmallPtrSet<Operation *, 2> srcDepInsts;
617+
llvm::SmallPtrSet<Operation *, 2> srcDepInsts;
566618
for (auto &outEdge : outEdges.lookup(srcId))
567619
if (outEdge.id != dstId)
568620
srcDepInsts.insert(getNode(outEdge.id)->op);
569621

570622
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
571-
SmallPtrSet<Operation *, 2> dstDepInsts;
623+
llvm::SmallPtrSet<Operation *, 2> dstDepInsts;
572624
for (auto &inEdge : inEdges.lookup(dstId))
573625
if (inEdge.id != srcId)
574626
dstDepInsts.insert(getNode(inEdge.id)->op);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s
2+
3+
// Test cases specifically for sibling fusion. Note that sibling fusion test
4+
// cases also exist in loop-fusion*.mlir.
5+
6+
// CHECK-LABEL: func @disjoint_stores
7+
func.func @disjoint_stores(%0: memref<8xf32>) {
8+
%alloc_1 = memref.alloc() : memref<16xf32>
9+
// The affine stores below are to different parts of the memrefs. Sibling
10+
// fusion helps improve reuse and is valid.
11+
affine.for %arg2 = 0 to 8 {
12+
%2 = affine.load %0[%arg2] : memref<8xf32>
13+
affine.store %2, %alloc_1[%arg2] : memref<16xf32>
14+
}
15+
affine.for %arg2 = 0 to 8 {
16+
%2 = affine.load %0[%arg2] : memref<8xf32>
17+
%3 = arith.negf %2 : f32
18+
affine.store %3, %alloc_1[%arg2 + 8] : memref<16xf32>
19+
}
20+
// CHECK: affine.for
21+
// CHECK-NOT: affine.for
22+
return
23+
}

0 commit comments

Comments
 (0)