Skip to content

Commit 18b9390

Browse files
committed
[MLIR][Affine] Check dependences during MDG init
Check affine dependences precisely during MDG init before adding edges. We were conservatively checking only for memref-level conflicts. This leads to more and better fusion as a result. Fixes: #156421
1 parent e09fe9b commit 18b9390

File tree

3 files changed

+132
-8
lines changed

3 files changed

+132
-8
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,17 @@ struct MemRefDependenceGraph {
153153

154154
MemRefDependenceGraph(Block &block) : block(block) {}
155155

156-
// Initializes the dependence graph based on operations in `block'.
157-
// Returns true on success, false otherwise.
158-
bool init();
// Initializes the data dependence graph by iterating over the operations of
// the MDG's `block`. A `Node` is created for every top-level op except for
// side-effect-free operations with zero results and no regions. Assigns each
// node in the graph a node id based on the order in block. Fails if certain
// kinds of operations, for which `Node` creation isn't supported, are
// encountered (unknown region holding ops). If `fullAffineDependences` is
// set, affine memory dependence analysis is performed before concluding that
// conflicting affine memory accesses lead to a dependence; otherwise, a
// pair of conflicting affine memory accesses (where one of them is a store
// and they are to the same memref) always leads to an edge (conservatively).
166+
bool init(bool fullAffineDependences = true);
159167

160168
// Returns the graph node for 'id'.
161169
const Node *getNode(unsigned id) const;

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1919
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2020
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2122
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2223
#include "mlir/IR/IntegerSet.h"
2324
#include "llvm/ADT/SetVector.h"
@@ -241,7 +242,95 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
241242
return &node;
242243
}
243244

244-
bool MemRefDependenceGraph::init() {
245+
/// Returns the memref being read/written by a memref/affine load/store op.
246+
static Value getMemRef(Operation *memOp) {
247+
if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
248+
return memrefLoad.getMemRef();
249+
if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
250+
return affineLoad.getMemRef();
251+
if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
252+
return memrefStore.getMemRef();
253+
if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
254+
return affineStore.getMemRef();
255+
llvm_unreachable("unexpected op");
256+
}
257+
258+
/// Returns true if there may be a dependence on `memref` from srcNode's
259+
/// memory ops to dstNode's memory ops, while using the affine memory
260+
/// dependence analysis checks. The method assumes that there is at least one
261+
/// memory op in srcNode's loads and stores on `memref`, and similarly for
262+
/// `dstNode`. `srcNode.op` and `destNode.op` are expected to be nested in the
263+
/// same block and so the dependences are tested at the depth of that block.
264+
static bool mayDependence(const Node &srcNode, const Node &dstNode,
265+
Value memref) {
266+
assert(srcNode.op->getBlock() == dstNode.op->getBlock());
267+
if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
268+
return true;
269+
270+
// Conservatively handle dependences involving non-affine load/stores. Return
271+
// true if there exists a conflicting read/write access involving such.
272+
auto hasNonAffineDep = [&](ArrayRef<Operation *> srcOps,
273+
ArrayRef<Operation *> dstOps) {
274+
return llvm::any_of(srcOps, [&](Operation *srcOp) {
275+
Value srcMemref = getMemRef(srcOp);
276+
if (srcMemref != memref)
277+
return false;
278+
return llvm::find_if(dstOps, [&](Operation *dstOp) {
279+
return srcMemref == getMemRef(dstOp);
280+
}) != dstOps.end();
281+
});
282+
};
283+
284+
SmallVector<Operation *> dstOps;
285+
// Between non-affine src stores and dst load/store.
286+
llvm::append_range(dstOps, llvm::concat<Operation *const>(
287+
dstNode.loads, dstNode.stores,
288+
dstNode.memrefLoads, dstNode.memrefStores));
289+
if (hasNonAffineDep(srcNode.memrefStores, dstOps))
290+
return true;
291+
// Between non-affine loads and dst stores.
292+
dstOps.clear();
293+
llvm::append_range(dstOps, llvm::concat<Operation *const>(
294+
dstNode.stores, dstNode.memrefStores));
295+
if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
296+
return true;
297+
// Between affine stores and memref load/stores.
298+
dstOps.clear();
299+
llvm::append_range(dstOps, llvm::concat<Operation *const>(
300+
dstNode.memrefLoads, dstNode.memrefStores));
301+
if (hasNonAffineDep(srcNode.stores, dstOps))
302+
return true;
303+
// Between affine loads and memref stores.
304+
dstOps.clear();
305+
llvm::append_range(dstOps, dstNode.memrefStores);
306+
if (hasNonAffineDep(srcNode.loads, dstOps))
307+
return true;
308+
309+
// Affine load/store pairs. We don't need to check for locally allocated
310+
// memrefs since the dependence analysis here is between mem ops from
311+
// srcNode's for op to dstNode's for op at the depth at which those
312+
// `affine.for` ops are nested, i.e., dependences at depth `d + 1` where
313+
// `d` is the number of common surrounding loops.
314+
for (auto *srcMemOp :
315+
llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
316+
MemRefAccess srcAcc(srcMemOp);
317+
if (srcAcc.memref != memref)
318+
continue;
319+
for (auto *destMemOp :
320+
llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
321+
MemRefAccess destAcc(destMemOp);
322+
if (destAcc.memref != memref)
323+
continue;
324+
// Check for a top-level dependence between srcNode and destNode's ops.
325+
if (!noDependence(checkMemrefAccessDependence(
326+
srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
327+
return true;
328+
}
329+
}
330+
return false;
331+
}
332+
333+
bool MemRefDependenceGraph::init(bool fullAffineDependences) {
245334
LDBG() << "--- Initializing MDG ---";
246335
// Map from a memref to the set of ids of the nodes that have ops accessing
247336
// the memref.
@@ -344,8 +433,12 @@ bool MemRefDependenceGraph::init() {
344433
Node *dstNode = getNode(dstId);
345434
bool dstHasStoreOrFree =
346435
dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
347-
if (srcHasStoreOrFree || dstHasStoreOrFree)
348-
addEdge(srcId, dstId, srcMemRef);
436+
if ((srcHasStoreOrFree || dstHasStoreOrFree)) {
437+
// Check precise affine deps if asked for; otherwise, conservative.
438+
if (!fullAffineDependences ||
439+
mayDependence(*srcNode, *dstNode, srcMemRef))
440+
addEdge(srcId, dstId, srcMemRef);
441+
}
349442
}
350443
}
351444
}
@@ -562,13 +655,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
562655
}
563656

564657
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
565-
SmallPtrSet<Operation *, 2> srcDepInsts;
658+
llvm::SmallPtrSet<Operation *, 2> srcDepInsts;
566659
for (auto &outEdge : outEdges.lookup(srcId))
567660
if (outEdge.id != dstId)
568661
srcDepInsts.insert(getNode(outEdge.id)->op);
569662

570663
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
571-
SmallPtrSet<Operation *, 2> dstDepInsts;
664+
llvm::SmallPtrSet<Operation *, 2> dstDepInsts;
572665
for (auto &inEdge : inEdges.lookup(dstId))
573666
if (inEdge.id != srcId)
574667
dstDepInsts.insert(getNode(inEdge.id)->op);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s

// Test cases specifically for sibling fusion. Note that sibling fusion test
// cases also exist in loop-fusion*.mlir.

// CHECK-LABEL: func @disjoint_stores
func.func @disjoint_stores(%0: memref<8xf32>) {
  %buf = memref.alloc() : memref<16xf32>
  // The two loops below store into disjoint halves of %buf. Sibling
  // fusion improves reuse of %0 and is valid.
  affine.for %i = 0 to 8 {
    %v = affine.load %0[%i] : memref<8xf32>
    affine.store %v, %buf[%i] : memref<16xf32>
  }
  affine.for %i = 0 to 8 {
    %v = affine.load %0[%i] : memref<8xf32>
    %neg = arith.negf %v : f32
    affine.store %neg, %buf[%i + 8] : memref<16xf32>
  }
  // CHECK: affine.for
  // CHECK-NOT: affine.for
  return
}

0 commit comments

Comments
 (0)