diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 0dd8de4f70039..df4145db90a61 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -153,9 +153,17 @@ struct MemRefDependenceGraph {
 
   MemRefDependenceGraph(Block &block) : block(block) {}
 
-  // Initializes the dependence graph based on operations in `block'.
-  // Returns true on success, false otherwise.
-  bool init();
+  // Initializes the data dependence graph by iterating over the operations
+  // of the MDG's `block`. A `Node` is created for every top-level op except
+  // for side-effect-free ops with zero results and no regions. Assigns each
+  // node in the graph a node id based on its order in the block. Fails if
+  // certain kinds of operations, for which `Node` creation isn't supported,
+  // are encountered (unknown region-holding ops). If `fullAffineDependences`
+  // is set, affine memory dependence analysis is performed before concluding
+  // that conflicting affine memory accesses lead to an edge; otherwise, a
+  // pair of conflicting affine accesses (where one of them is a store and
+  // they are to the same memref) always conservatively leads to an edge.
+  bool init(bool fullAffineDependences = true);
 
   // Returns the graph node for 'id'.
   const Node *getNode(unsigned id) const;
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 99ea20bf13b49..f38493bc9a96e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IntegerSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -241,7 +242,98 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
   return &node;
 }
 
-bool MemRefDependenceGraph::init() {
+/// Returns the memref being read/written by a memref/affine load/store op.
+static Value getMemRef(Operation *memOp) {
+  if (auto memrefLoad = dyn_cast<memref::LoadOp>(memOp))
+    return memrefLoad.getMemRef();
+  if (auto affineLoad = dyn_cast<AffineReadOpInterface>(memOp))
+    return affineLoad.getMemRef();
+  if (auto memrefStore = dyn_cast<memref::StoreOp>(memOp))
+    return memrefStore.getMemRef();
+  if (auto affineStore = dyn_cast<AffineWriteOpInterface>(memOp))
+    return affineStore.getMemRef();
+  llvm_unreachable("unexpected op");
+}
+
+/// Returns true if there may be a dependence on `memref` from srcNode's
+/// memory ops to dstNode's memory ops, using affine memory dependence
+/// analysis checks. The method assumes that there is at least one memory op
+/// in srcNode's loads and stores on `memref`, and similarly for `dstNode`.
+/// `srcNode.op` and `dstNode.op` are expected to be nested in the same
+/// block, and so the dependences are tested at the depth of that block.
+static bool mayDependence(const Node &srcNode, const Node &dstNode,
+                          Value memref) {
+  assert(srcNode.op->getBlock() == dstNode.op->getBlock());
+  if (!isa<AffineForOp>(srcNode.op) || !isa<AffineForOp>(dstNode.op))
+    return true;
+
+  // Conservatively handle dependences involving non-affine loads/stores:
+  // return true if there exists a conflicting read/write access involving one.
+
+  // Check whether there is a dependence from a source read/write op to a
+  // destination read/write one; all expected to be memref/affine load/stores.
+  auto hasNonAffineDep = [&](ArrayRef<Operation *> srcMemOps,
+                             ArrayRef<Operation *> dstMemOps) {
+    return llvm::any_of(srcMemOps, [&](Operation *srcOp) {
+      Value srcMemref = getMemRef(srcOp);
+      if (srcMemref != memref)
+        return false;
+      return llvm::find_if(dstMemOps, [&](Operation *dstOp) {
+               return srcMemref == getMemRef(dstOp);
+             }) != dstMemOps.end();
+    });
+  };
+
+  SmallVector<Operation *> dstOps;
+  // Between non-affine src stores and dst loads/stores.
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.loads, dstNode.stores,
+                                 dstNode.memrefLoads, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.memrefStores, dstOps))
+    return true;
+  // Between non-affine src loads and dst stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.stores, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.memrefLoads, dstOps))
+    return true;
+  // Between affine src stores and non-affine dst loads/stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, llvm::concat<Operation *const>(
+                                 dstNode.memrefLoads, dstNode.memrefStores));
+  if (hasNonAffineDep(srcNode.stores, dstOps))
+    return true;
+  // Between affine src loads and non-affine dst stores.
+  dstOps.clear();
+  llvm::append_range(dstOps, dstNode.memrefStores);
+  if (hasNonAffineDep(srcNode.loads, dstOps))
+    return true;
+
+  // Affine load/store pairs. We don't need to check for locally allocated
+  // memrefs since the dependence analysis here is between mem ops from
+  // srcNode's `affine.for` to dstNode's `affine.for` at the depth at which
+  // those `affine.for` ops are nested, i.e., dependences at depth `d + 1`,
+  // where `d` is the number of common surrounding loops.
+  for (auto *srcMemOp :
+       llvm::concat<Operation *const>(srcNode.stores, srcNode.loads)) {
+    MemRefAccess srcAcc(srcMemOp);
+    if (srcAcc.memref != memref)
+      continue;
+    for (auto *destMemOp :
+         llvm::concat<Operation *const>(dstNode.stores, dstNode.loads)) {
+      MemRefAccess destAcc(destMemOp);
+      if (destAcc.memref != memref)
+        continue;
+      // Check for a dependence between srcNode's and dstNode's ops.
+      if (!noDependence(checkMemrefAccessDependence(
+              srcAcc, destAcc, getNestingDepth(srcNode.op) + 1)))
+        return true;
+    }
+  }
+  return false;
+}
+
+bool MemRefDependenceGraph::init(bool fullAffineDependences) {
   LDBG() << "--- Initializing MDG ---";
   // Map from a memref to the set of ids of the nodes that have ops accessing
   // the memref.
@@ -344,8 +436,12 @@ bool MemRefDependenceGraph::init() {
         Node *dstNode = getNode(dstId);
         bool dstHasStoreOrFree =
             dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
-        if (srcHasStoreOrFree || dstHasStoreOrFree)
-          addEdge(srcId, dstId, srcMemRef);
+        if (srcHasStoreOrFree || dstHasStoreOrFree) {
+          // Check precise affine deps if asked for; else be conservative.
+          if (!fullAffineDependences ||
+              mayDependence(*srcNode, *dstNode, srcMemRef))
+            addEdge(srcId, dstId, srcMemRef);
+        }
       }
     }
   }
@@ -562,13 +658,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
   }
 
   // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
-  SmallPtrSet<Operation *, 2> srcDepInsts;
+  llvm::SmallPtrSet<Operation *, 2> srcDepInsts;
   for (auto &outEdge : outEdges.lookup(srcId))
     if (outEdge.id != dstId)
       srcDepInsts.insert(getNode(outEdge.id)->op);
 
   // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
-  SmallPtrSet<Operation *, 2> dstDepInsts;
+  llvm::SmallPtrSet<Operation *, 2> dstDepInsts;
   for (auto &inEdge : inEdges.lookup(dstId))
     if (inEdge.id != srcId)
       dstDepInsts.insert(getNode(inEdge.id)->op);
diff --git a/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
new file mode 100644
index 0000000000000..937c855b86b50
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-fusion-sibling.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' | FileCheck %s
+
+// Test cases specifically for sibling fusion. Note that sibling fusion test
+// cases also exist in loop-fusion*.mlir.
+
+// CHECK-LABEL: func @disjoint_stores
+func.func @disjoint_stores(%0: memref<8xf32>) {
+  %alloc_1 = memref.alloc() : memref<16xf32>
+  // The affine stores below are to disjoint parts of the same memref. Sibling
+  // fusion helps improve reuse (of %0) and is valid.
+  affine.for %arg2 = 0 to 8 {
+    %2 = affine.load %0[%arg2] : memref<8xf32>
+    affine.store %2, %alloc_1[%arg2] : memref<16xf32>
+  }
+  affine.for %arg2 = 0 to 8 {
+    %2 = affine.load %0[%arg2] : memref<8xf32>
+    %3 = arith.negf %2 : f32
+    affine.store %3, %alloc_1[%arg2 + 8] : memref<16xf32>
+  }
+  // CHECK: affine.for
+  // CHECK-NOT: affine.for
+  return
+}
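
Usage note (not part of the patch): a minimal sketch of how a client could
drive the new flag, assuming only the public API in the header above. The
`buildGraph` helper and its `precise` parameter are hypothetical names used
purely for illustration.

  #include "mlir/Dialect/Affine/Analysis/Utils.h"

  using namespace mlir;
  using namespace mlir::affine;

  // Builds the memref dependence graph over `block`. With `precise` set (the
  // default for init()), conflicting affine accesses are only turned into an
  // edge if affine dependence analysis confirms a dependence; with it unset,
  // any store/free conflict on a common memref conservatively adds an edge.
  // Returns false if graph construction fails (unknown region-holding ops).
  static bool buildGraph(Block &block, bool precise) {
    MemRefDependenceGraph mdg(block);
    return mdg.init(/*fullAffineDependences=*/precise);
  }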
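
The new test exercises the precise path: both loops access %alloc_1, but the
stores write disjoint halves of it ([0, 8) and [8, 16)), so the dependence
check at depth 1 proves independence, no edge is added, and sibling fusion can
proceed. Below is a sketch of that pairwise query using the existing affine
analysis API; the helper name `accessesMayConflict` is illustrative.

  #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
  #include "mlir/IR/Operation.h"

  using namespace mlir;
  using namespace mlir::affine;

  // Returns true if two affine accesses may conflict when tested just inside
  // their common surrounding loops, mirroring the affine load/store pair
  // check in mayDependence().
  static bool accessesMayConflict(Operation *srcOp, Operation *dstOp,
                                  unsigned numCommonLoops) {
    MemRefAccess srcAcc(srcOp), dstAcc(dstOp);
    if (srcAcc.memref != dstAcc.memref)
      return false;
    DependenceResult result = checkMemrefAccessDependence(
        srcAcc, dstAcc, /*loopDepth=*/numCommonLoops + 1);
    return !noDependence(result);
  }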