
Commit 68830c7

[MLIR][Affine] Check dependences during MDG init (#156422)
Check affine dependences precisely during MDG init before adding edges. Previously, we conservatively checked only for memref-level conflicts. This leads to more and better fusion as a result. Fixes: #156421
1 parent: 7731ecf

File tree

3 files changed: +135 -8 lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 11 additions & 3 deletions
@@ -153,9 +153,17 @@ struct MemRefDependenceGraph {
   MemRefDependenceGraph(Block &block) : block(block) {}

-  // Initializes the dependence graph based on operations in `block'.
-  // Returns true on success, false otherwise.
-  bool init();
+  // Initializes the data dependence graph by iterating over the operations of
+  // the MDG's `block`. A `Node` is created for every top-level op except for
+  // side-effect-free operations with zero results and no regions. Assigns each
+  // node in the graph a node id based on its order in the block. Fails if
+  // certain kinds of operations, for which `Node` creation isn't supported,
+  // are encountered (unknown region-holding ops). If `fullAffineDependences` is
+  // set, affine memory dependence analysis is performed before concluding that
+  // conflicting affine memory accesses lead to a dependence edge; otherwise, a
+  // pair of conflicting affine memory accesses (where one of them is a store
+  // and they are to the same memref) always leads to an edge (conservatively).
+  bool init(bool fullAffineDependences = true);

   // Returns the graph node for 'id'.
   const Node *getNode(unsigned id) const;
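
For context, a minimal usage sketch of the new entry point (illustrative only, not part of the patch; the wrapper name `buildMDG` is made up, and the `mlir::affine` namespace is assumed):

#include "mlir/Dialect/Affine/Analysis/Utils.h"

using namespace mlir;
using namespace mlir::affine;

// Build and initialize the MDG for `block`, using the precise affine
// dependence checks that this commit makes the default.
static bool buildMDG(Block &block) {
  MemRefDependenceGraph mdg(block);
  // Pass `false` to restore the old conservative memref-level conflict test.
  return mdg.init(/*fullAffineDependences=*/true);
}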

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 101 additions & 5 deletions
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -241,7 +242,98 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
   return &node;
 }

-bool MemRefDependenceGraph::init() {
+/// Returns the memref being read/written by a memref/affine load/store op.
+static Value getMemRef(Operation *memOp) {
+  if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
+    return memrefLoad.getMemRef();
+  if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
+    return affineLoad.getMemRef();
+  if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
+    return memrefStore.getMemRef();
+  if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
+    return affineStore.getMemRef();
+  llvm_unreachable("unexpected op");
+}
+
+/// Returns true if there may be a dependence on `memref` from srcNode's
+/// memory ops to dstNode's memory ops, using the affine memory dependence
+/// analysis checks. The method assumes that there is at least one memory op
+/// on `memref` among srcNode's loads and stores, and similarly for `dstNode`.
+/// `srcNode.op` and `dstNode.op` are expected to be nested in the same block,
+/// so dependences are tested at the depth of that block.
+static bool mayDependence(const Node &srcNode, const Node &dstNode,
+                          Value memref) {
+  assert(srcNode.op->getBlock() == dstNode.op->getBlock());
+  if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
+    return true;
+
+  // Conservatively handle dependences involving non-affine loads/stores:
+  // return true if there exists a conflicting read/write access involving
+  // such an op.
+
+  // Check whether there is a dependence from a source read/write op to a
+  // destination read/write one; all are expected to be memref/affine
+  // load/store ops.
+  auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
+                             ArrayRef<Operation *> dstMemOps) {
+    return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
+      Value srcMemref = getMemRef(srcOp);
+      if (srcMemref != memref)
+        return false;
+      return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
+               return srcMemref == getMemRef(dstOp);
+             }) != dstMemOps.end();
+    });
+  };
+
+  SmallVector<Operation *> dstOps;
+  // Between non-affine src stores and dst loads/stores.
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.loads, dstNode.stores,
+                                 dstNode.memrefLoads, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.memrefStores, dstOps))
+    return true;
+  // Between non-affine loads and dst stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.stores, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
+    return true;
+  // Between affine stores and memref loads/stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.memrefLoads, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.stores, dstOps))
+    return true;
+  // Between affine loads and memref stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, dstNode.memrefStores);
+  if (hasNonAffineDep(srcNode.loads, dstOps))
+    return true;
+
+  // Affine load/store pairs. We don't need to check for locally allocated
+  // memrefs since the dependence analysis here is between mem ops from
+  // srcNode's for op to dstNode's for op at the depth at which those
+  // `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
+  // `d` is the number of common surrounding loops.
+  for (auto *srcMemOp :
+       llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
+    MemRefAccess srcAcc(srcMemOp);
+    if (srcAcc.memref != memref)
+      continue;
+    for (auto *destMemOp :
+         llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
+      MemRefAccess destAcc(destMemOp);
+      if (destAcc.memref != memref)
+        continue;
+      // Check for a top-level dependence between srcNode's and dstNode's ops.
+      if (!noDependence(checkMemrefAccessDependence(
+              srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
+        return true;
+    }
+  }
+  return false;
+}
+
+bool MemRefDependenceGraph::init(bool fullAffineDependences) {
   LDBG() << "--- Initializing MDG ---";
   // Map from a memref to the set of ids of the nodes that have ops accessing
   // the memref.

@@ -344,8 +436,12 @@ bool MemRefDependenceGraph::init() {
       Node *dstNode = getNode(dstId);
       bool dstHasStoreOrFree =
          dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
-      if (srcHasStoreOrFree || dstHasStoreOrFree)
-        addEdge(srcId, dstId, srcMemRef);
+      if (srcHasStoreOrFree || dstHasStoreOrFree) {
+        // Check precise affine dependences if requested; otherwise, add the
+        // edge conservatively.
+        if (!fullAffineDependences ||
+            mayDependence(*srcNode, *dstNode, srcMemRef))
+          addEdge(srcId, dstId, srcMemRef);
+      }
     }
   }
 }
@@ -562,13 +658,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
   }

   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
-  SmallPtrSet<Operation *, 2> srcDepInsts;
+  llvm::SmallPtrSet<Operation *, 2> srcDepInsts;
   for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);

   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
-  SmallPtrSet<Operation *, 2> dstDepInsts;
+  llvm::SmallPtrSet<Operation *, 2> dstDepInsts;
   for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
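
The depth convention used by `mayDependence` above can also be exercised standalone. A minimal sketch (illustrative only, not part of the commit; the helper name `mayConflictAtBlockDepth` is made up) built from the same analysis entry points that the patch uses:

// Given two affine load/store ops whose surrounding `affine.for` nests sit
// in the same block, check for a dependence at depth `d + 1`, where `d` is
// the number of loops surrounding that common block.
static bool mayConflictAtBlockDepth(Operation *srcMemOp, Operation *dstMemOp,
                                    Operation *srcForOp) {
  MemRefAccess srcAcc(srcMemOp), dstAcc(dstMemOp);
  // Accesses to different memrefs can never conflict.
  if (srcAcc.memref != dstAcc.memref)
    return false;
  // `getNestingDepth(srcForOp)` counts the loops surrounding `srcForOp`;
  // the dependence check runs one level deeper, inside the two nests.
  return !noDependence(checkMemrefAccessDependence(
      srcAcc, dstAcc, getNestingDepth(srcForOp) + 1));
}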
(new test file)

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s
+
+// Test cases specifically for sibling fusion. Note that sibling fusion test
+// cases also exist in loop-fusion*.mlir.
+
+// CHECK-LABEL: func @disjoint_stores
+func.func @disjoint_stores(%0: memref<8xf32>) {
+  %alloc_1 = memref.alloc() : memref<16xf32>
+  // The affine stores below are to disjoint parts of the same memref.
+  // Sibling fusion helps improve reuse and is valid.
+  affine.for %arg2 = 0 to 8 {
+    %2 = affine.load %0[%arg2] : memref<8xf32>
+    affine.store %2, %alloc_1[%arg2] : memref<16xf32>
+  }
+  affine.for %arg2 = 0 to 8 {
+    %2 = affine.load %0[%arg2] : memref<8xf32>
+    %3 = arith.negf %2 : f32
+    affine.store %3, %alloc_1[%arg2 + 8] : memref<16xf32>
+  }
+  // CHECK:     affine.for
+  // CHECK-NOT: affine.for
+  return
+}
