diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 920d6f86a355e..80a0b58b6ab63 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -293,7 +293,9 @@ void createHLFIRToFIRPassPipeline(mlir::PassManager &pm, pm, hlfir::createInlineElementals); if (optLevel.isOptimizingForSpeed()) { addCanonicalizerPassWithoutRegionSimplification(pm); - pm.addPass(mlir::createCSEPass()); + mlir::CSEPassOptions options; + options.hoistPureOps = false; + pm.addPass(mlir::createCSEPass(options)); // Run SimplifyHLFIRIntrinsics pass late after CSE, // and allow introducing operations with new side effects. addNestedPassToAllTopLevelOperations(pm, [&]() { diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h index 4a87d585e0eb9..b930b78cb641f 100644 --- a/mlir/include/mlir/Transforms/CSE.h +++ b/mlir/include/mlir/Transforms/CSE.h @@ -32,7 +32,8 @@ void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed = nullptr, int64_t *numCSE = nullptr, - int64_t *numDCE = nullptr); + int64_t *numDCE = nullptr, + bool hoistPureOps = true); /// Eliminate common subexpressions within the given region. /// @@ -41,7 +42,8 @@ void eliminateCommonSubExpressions(RewriterBase &rewriter, /// DCE counts are needed. void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Region &region, - bool *changed = nullptr); + bool *changed = nullptr, + bool hoistPureOps = true); } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 74ac370ea950b..e7781c8c9e1bb 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -94,6 +94,10 @@ def CSEPass : Pass<"cse"> { operations. 
See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination) for more general details on this optimization. }]; + let options = [ + Option<"hoistPureOps", "hoist-pure-ops", "bool", /*default=*/"true", + "Allow hoisting of pure operations out of regions">, + ]; let statistics = [ Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">, Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd"> diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index f7afa03e2f02b..4ed90d49acea7 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/Passes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" @@ -23,11 +24,13 @@ namespace mlir { #include "mlir/Transforms/Passes.h.inc" } // namespace mlir +#define DEBUG_TYPE "cse" using namespace mlir; namespace { /// CSE pass. struct CSE : public impl::CSEPassBase { + using impl::CSEPassBase::CSEPassBase; void runOnOperation() override; }; } // namespace @@ -41,7 +44,7 @@ void CSE::runOnOperation() { int64_t cseCount = 0; int64_t dceCount = 0; eliminateCommonSubExpressions(rewriter, domInfo, getOperation(), &changed, - &cseCount, &dceCount); + &cseCount, &dceCount, hoistPureOps); numCSE = cseCount; numDCE = dceCount; diff --git a/mlir/lib/Transforms/Utils/CSE.cpp b/mlir/lib/Transforms/Utils/CSE.cpp index 90444e6201891..934908bda7883 100644 --- a/mlir/lib/Transforms/Utils/CSE.cpp +++ b/mlir/lib/Transforms/Utils/CSE.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/Allocator.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/RecyclingAllocator.h" #include @@ -52,8 +53,9 @@ namespace { /// Simple common sub-expression elimination. 
class CSEDriver { public: - CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) - : rewriter(rewriter), domInfo(domInfo) {} + CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo, + bool hoistPureOps = true) + : rewriter(rewriter), domInfo(domInfo), hoistPureOps(hoistPureOps) {} /// Simplify all operations within the given op. void simplify(Operation *op, bool *changed = nullptr); @@ -97,10 +99,13 @@ class CSEDriver { /// Attempt to eliminate a redundant operation. Returns success if the /// operation was marked for removal, failure otherwise. - LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op, + LogicalResult simplifyOperation(ScopedMapTy &knownValues, + ScopedMapTy &knownPureOps, Operation *op, bool hasSSADominance); - void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance); - void simplifyRegion(ScopedMapTy &knownValues, Region &region); + void simplifyBlock(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps, + Block *bb, bool hasSSADominance); + void simplifyRegion(ScopedMapTy &knownValues, ScopedMapTy &knownPureOps, + Region &region); /// Erase all operations queued for deletion by the simplification routines. void eraseDeadOps(bool *changed); @@ -112,20 +117,137 @@ class CSEDriver { /// between the two operations. bool hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp); + LogicalResult hoistPureOp(Operation *existing, Operation *op); + /// A rewriter for modifying the IR. RewriterBase &rewriter; /// Operations marked as dead and to be erased. std::vector opsToErase; + DominanceInfo *domInfo = nullptr; MemEffectsCache memEffectsCache; // Various statistics. int64_t numCSE = 0; int64_t numDCE = 0; + + bool hoistPureOps = true; + + // The map uses region op as the key and a list of operations as the + // value. This list describes the dependencies of the region op, as the + // operations within the region op consume results from the value in + // the map. 
Therefore, it is necessary to consider these dependent operations + // when hoisting a region op. + DenseMap> hoistBlockingDeps; + + // The keys of this map are existing ops, and the values are lists of + // operations. Each entry describes an existing op that has been hoisted along + // with its associated operations. When an existing op and its equivalent + // operations are hoisted for the first time, they are added to this map. + // During subsequent hoistings of the same existing op (hoisting existing op + // multiple times), the operations stored in the map's value will also be + // hoisted together with it. + DenseMap> hoistOpsSet; }; } // namespace +/// Returns true if the path between block 'a' and block 'b' in the region +/// hierarchy crosses an operation with the 'IsIsolatedFromAbove' trait. +static bool isBlockCrossIsIsolatedFromAbove(DominanceInfo *dominate, Block *a, + Block *b) { + if (a == b) + return false; + if (a->getParent() == b->getParent()) + return false; + if (dominate->dominates(b, a)) + std::swap(b, a); + while (b && b->getParentOp()) { + Operation *parentOp = b->getParentOp(); + if (parentOp->mightHaveTrait()) + return true; + b = parentOp->getBlock(); + if (b == a) + return false; + } + return false; +} + +/// Hoist the pure ops to the location of the Nearest Common Dominator. 
+LogicalResult CSEDriver::hoistPureOp(Operation *existing, Operation *op) { + Block *ancestorBlock = + domInfo->findNearestCommonDominator(existing->getBlock(), op->getBlock()); + if (!ancestorBlock) { + LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions()) + << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " failed"; + return failure(); + } + + if (isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock, + existing->getBlock()) || + isBlockCrossIsIsolatedFromAbove(domInfo, ancestorBlock, op->getBlock())) + return failure(); + + if (existing->getParentOp() != ancestorBlock->getParentOp() && + !existing->use_empty()) { + LDBG() << "add " + << OpWithFlags(existing->getParentOp(), + OpPrintingFlags().skipRegions()) + << " dependents " + << OpWithFlags(existing, OpPrintingFlags().skipRegions()); + hoistBlockingDeps[existing->getParentOp()].push_back(existing); + } + + // Find the insertion point based on dominance relationships. When hoisting a + // region op, we must consider not only its operands but also the dominance + // relationships of the operations within the region when determining the + // insertion point + Operation *insertPoint = nullptr; + SmallVector dependentOperands(existing->getOperands()); + if (hoistBlockingDeps.contains(existing) && + !hoistBlockingDeps[existing].empty()) { + for (Operation *dependentOp : hoistBlockingDeps[existing]) + dependentOperands.append(dependentOp->getResults().begin(), + dependentOp->getResults().end()); + } + + for (Value operand : dependentOperands) { + if (domInfo->properlyDominates(operand, &ancestorBlock->front())) + continue; + if (!insertPoint) { + insertPoint = operand.getDefiningOp(); + } else { + insertPoint = domInfo->dominates(insertPoint, operand.getDefiningOp()) + ? operand.getDefiningOp() + : insertPoint; + } + } + + // We hoist both `op` and `existing` here because if they are identical + // regionOps and we only hoist existing, the two would no longer be congruent. 
+ // This would lead to a missed optimization opportunity in subsequent CSE + // passes. The test @cse_multiple_regions's `%r2` tests it. + if (!insertPoint) { + rewriter.moveOpBefore(existing, ancestorBlock, ancestorBlock->begin()); + rewriter.moveOpAfter(op, existing); + } else { + rewriter.moveOpAfter(existing, insertPoint); + rewriter.moveOpAfter(op, existing); + } + + // When hoisting an existing op multiple times, we must also hoist the + // operations that were previously hoisted alongside it. + if (hoistOpsSet.contains(existing) && !hoistOpsSet[existing].empty()) + for (Operation *op : hoistOpsSet[existing]) + rewriter.moveOpAfter(op, existing); + hoistOpsSet[existing].push_back(op); + LDBG() << "hoist " << OpWithFlags(existing, OpPrintingFlags().skipRegions()) + << " and " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " success"; + return success(); +} + void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, Operation *existing, bool hasSSADominance) { @@ -136,6 +258,21 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, // If the region has SSA dominance, then we are guaranteed to have not // visited any use of the current operation. // Replace all uses, but do not remove the operation yet. + if (!domInfo->properlyDominates(existing, op)) { + if (!hoistPureOps || failed(hoistPureOp(existing, op))) + return; + } else { + // Hoist `op` even though `existing` already dominates it, because + // hoisting op may create further CSE optimization opportunities for + // subsequent region operations. The test @cse_multiple_regions's `%r3` + // tests it. 
+ rewriter.moveOpAfter(op, existing); + } + LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " with " + << OpWithFlags(existing, OpPrintingFlags().skipRegions()); + LDBG() << "add " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " to opsToErase"; rewriter.replaceAllOpUsesWith(op, existing->getResults()); opsToErase.push_back(op); } else { @@ -150,11 +287,19 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, if (all_of(v.getUses(), wasVisited)) rewriteListener->notifyOperationReplaced(op, existing); + if (!domInfo->properlyDominates(existing, op)) { + if (!hoistPureOps || failed(hoistPureOp(existing, op))) + return; + } + // Replace all uses, but do not remove the operation yet. This does not + // notify the listener because the original op is not erased. + LDBG() << "replace " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " with " + << OpWithFlags(existing, OpPrintingFlags().skipRegions()); // Replace all uses, but do not remove the operation yet. This does not // notify the listener because the original op is not erased. rewriter.replaceUsesWithIf(op->getResults(), existing->getResults(), wasVisited); - // There may be some remaining uses of the operation. if (op->use_empty()) opsToErase.push_back(op); @@ -247,8 +392,11 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, /// Attempt to eliminate a redundant operation. LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, + ScopedMapTy &knownPureOps, Operation *op, bool hasSSADominance) { + LDBG() << "visit operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // Don't simplify terminator operations. 
if (op->hasTrait()) return failure(); @@ -279,6 +427,8 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, return success(); } } + LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " to map"; knownValues.insert(op, op); return failure(); } @@ -289,13 +439,29 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, return success(); } - // Otherwise, we add this operation to the known values map. - knownValues.insert(op, op); + if (auto *existing = knownPureOps.lookup(op)) { + replaceUsesAndDelete(knownPureOps, op, existing, hasSSADominance); + return success(); + } + + if (mlir::isPure(op)) { + LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " to pureMap"; + knownPureOps.insert(op, op); + } else { + // Otherwise, we add this operation to the known values map. + LDBG() << "insert op: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " to map"; + knownValues.insert(op, op); + } return failure(); } -void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, +void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, + ScopedMapTy &knownPureOps, Block *bb, bool hasSSADominance) { + LDBG() << "visit block #" << bb->computeBlockNumber() << " of " + << OpWithFlags(bb->getParentOp(), OpPrintingFlags().skipRegions()); for (auto &op : llvm::make_early_inc_range(*bb)) { // If the operation is already trivially dead just add it to the erase list. // This also avoids calling `simplifyRegion` on dead region ops @@ -313,34 +479,42 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, // implicit captures in explicit capture only regions. 
if (op.mightHaveTrait()) { ScopedMapTy nestedKnownValues; + ScopedMapTy nestedKnownPureOps; + ScopedMapTy::ScopeTy scope(nestedKnownValues); + ScopedMapTy::ScopeTy pureScope(nestedKnownPureOps); for (auto &region : op.getRegions()) - simplifyRegion(nestedKnownValues, region); + simplifyRegion(nestedKnownValues, nestedKnownPureOps, region); } else { // Otherwise, process nested regions normally. for (auto &region : op.getRegions()) - simplifyRegion(knownValues, region); + simplifyRegion(knownValues, knownPureOps, region); } } // If the operation is simplified, we don't process any held regions. - if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance))) + if (succeeded( + simplifyOperation(knownValues, knownPureOps, &op, hasSSADominance))) continue; } // Clear the MemoryEffects cache since its usage is by block only. memEffectsCache.clear(); } -void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) { +void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, + ScopedMapTy &knownPureOps, Region &region) { // If the region is empty there is nothing to do. if (region.empty()) return; + LDBG() << "visit region #" << region.getRegionNumber() << " of " + << OpWithFlags(region.getParentOp(), OpPrintingFlags().skipRegions()); + bool hasSSADominance = domInfo->hasSSADominance(&region); // If the region only contains one block, then simplify it directly. if (region.hasOneBlock()) { ScopedMapTy::ScopeTy scope(knownValues); - simplifyBlock(knownValues, &region.front(), hasSSADominance); + simplifyBlock(knownValues, knownPureOps, &region.front(), hasSSADominance); return; } @@ -368,7 +542,7 @@ void CSEDriver::simplifyRegion(ScopedMapTy &knownValues, Region &region) { // Check to see if we need to process this node. 
if (!currentNode->processed) { currentNode->processed = true; - simplifyBlock(knownValues, currentNode->node->getBlock(), + simplifyBlock(knownValues, knownPureOps, currentNode->node->getBlock(), hasSSADominance); } @@ -402,24 +576,38 @@ void CSEDriver::eraseDeadOps(bool *changed) { } void CSEDriver::simplify(Operation *op, bool *changed) { - // Simplify all regions. - ScopedMapTy knownValues; - for (auto &region : op->getRegions()) - simplifyRegion(knownValues, region); + /// Simplify all regions. Added a new scope using curly braces to release the + /// knownPureOps scope before deleting the operation. + { + /// The entry point for CSE simplification. A top-level scope is added for + /// 'knownPureOps' to track pure operations across the entire operation's + /// regions, enabling potential hoisting opportunities. Since only pure + /// operations are candidates for hoisting, 'knownValues' does not require + /// a corresponding top-level scope here. + ScopedMapTy knownValues; + ScopedMapTy knownPureOps; + ScopedMapTy::ScopeTy scope(knownPureOps); + for (auto &region : op->getRegions()) + simplifyRegion(knownValues, knownPureOps, region); + } eraseDeadOps(changed); } void CSEDriver::simplify(Region &region, bool *changed) { - ScopedMapTy knownValues; - simplifyRegion(knownValues, region); + { + ScopedMapTy knownValues; + ScopedMapTy knownPureOps; + ScopedMapTy::ScopeTy scope(knownPureOps); + simplifyRegion(knownValues, knownPureOps, region); + } eraseDeadOps(changed); } void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed, int64_t *numCSE, - int64_t *numDCE) { - CSEDriver driver(rewriter, &domInfo); + int64_t *numDCE, bool hoistPureOps) { + CSEDriver driver(rewriter, &domInfo, hoistPureOps); driver.simplify(op, changed); if (numCSE) *numCSE = driver.getNumCSE(); @@ -429,7 +617,7 @@ void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, 
DominanceInfo &domInfo, Region &region, - bool *changed) { - CSEDriver driver(rewriter, &domInfo); + bool *changed, bool hoistPureOps) { + CSEDriver driver(rewriter, &domInfo, hoistPureOps); driver.simplify(region, changed); } diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index 2a183cb4d056a..7fd66035a4140 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -82,6 +82,8 @@ func.func @use_too_many_tiles() { // AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]]) // AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // // AFTER-LLVM-LOWERING-NOT: scf.for @@ -104,8 +106,6 @@ func.func @use_too_many_tiles() { // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { -// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] -// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] // AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> // AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> @@ -122,8 +122,6 @@ func.func @use_too_many_tiles() { // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] { -// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] -// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] // 
AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> // AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}> @@ -156,6 +154,8 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m // AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index // AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]]) // AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref +// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] +// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // /// 1. Swap %useAllTiles and %tile - note that this will only swap one 32-bit @@ -163,8 +163,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] { -// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] -// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] // Read ZA tile slice -> vector // AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> @@ -182,8 +180,6 @@ func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: m // AFTER-LLVM-LOWERING: scf.for // AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] { -// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]] -// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1] // AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]] /// Read ZA tile slice -> vector // AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}> diff --git 
a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir index 3929f5be3b4ef..8dc6364fddb2e 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir @@ -199,27 +199,20 @@ func.func @eleminate_multiple_ops(%t: tensor {bufferization.buffer_layout { %cst1 = arith.constant 0.0: f32 %cst2 = arith.constant 1.0: f32 - - // CHECK: %[[r:.*]] = scf.if %{{.*}} -> (memref + // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] + // CHECK: scf.if %{{.*}} %if = scf.if %c -> tensor { - // CHECK: %[[T_SUBVIEW_1:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a1 = tensor.empty(%sz) : tensor // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_1]] : memref) -> tensor - // CHECK: scf.yield %[[T_SUBVIEW_1]] scf.yield %f1 : tensor } else { - // CHECK: %[[T_SUBVIEW_2:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] %a2 = tensor.empty(%sz) : tensor - // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW_2]] : memref) -> tensor - // CHECK: scf.yield %[[T_SUBVIEW_2]] scf.yield %f2 : tensor } - - // Self-copy could canonicalize away later. 
- // CHECK: %[[T_SUBVIEW_3:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1] - // CHECK: memref.copy %[[r]], %[[T_SUBVIEW_3]] + // CHECK: return %[[FUNC_ARG]] %r1 = tensor.insert_slice %if into %t[42][%sz][1]: tensor into tensor return %r1: tensor } diff --git a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir index 6cab25b50460d..8aa59b2bf0f5f 100644 --- a/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir +++ b/mlir/test/Dialect/Linalg/matmul-shared-memory-padding.mlir @@ -5,14 +5,13 @@ // CHECK-NOT: memref.copy // CHECK: linalg.fill // CHECK: scf.for +// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1> // CHECK: memref.alloc() : memref<128x16xf32, 3> // CHECK: scf.forall -// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1> // CHECK: vector.transfer_read // CHECK: vector.transfer_write // CHECK: memref.alloc() : memref<16x128xf32, 3> // CHECK: scf.forall -// CHECK: vector.constant_mask [16, 4] : vector<128x4xi1> // CHECK: vector.transfer_read // CHECK: vector.transfer_write // CHECK: memref.alloc() : memref<128x128xf32, 3> diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir index 51bf4a23406d4..9b6716f09df37 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -315,7 +315,7 @@ module attributes {transform.with_named_sequence} { // Test dynamic padding using `use_prescribed_tensor_shapes` -// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)> +// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 7) * 7)> // CHECK: @use_prescribed_tensor_shapes // CHECK: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<12x?xf32> func.func @use_prescribed_tensor_shapes(%arg0: tensor, diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir index c26ba56347299..f6aa91ad2b10f 100644 --- 
a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -3,118 +3,117 @@ #DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }> // CHECK-LABEL: func.func @fill_zero_after_alloc( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr { -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[ZERO:.*]] = llvm.mlir.zero : !llvm.ptr -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 262144 : i64 -// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64> -// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref -// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64> -// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64> -// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex> -// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref -// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex> -// CHECK: memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex> -// CHECK: %[[VAL_16:.*]] = memref.alloca() : memref<2xindex> -// CHECK: %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref -// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex> -// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex> -// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], 
%[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[ZERO]]) : (memref, memref, memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64> -// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref -// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1> -// CHECK: %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref -// CHECK: %[[VAL_24:.*]] = memref.alloc() : memref<300xindex> -// CHECK: %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref -// CHECK: linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>) -// CHECK: linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>) -// CHECK-DAG: %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref -// CHECK-DAG: %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref -// CHECK-DAG: %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref -// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref -// 
CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] { -// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref -// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index -// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref -// CHECK: %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) { -// CHECK: %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index -// CHECK: %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index -// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1 -// CHECK: scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[MLIR_0:.*]] = llvm.mlir.zero : !llvm.ptr +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : i32 +// CHECK: %[[CONSTANT_3:.*]] = arith.constant true +// CHECK: %[[CONSTANT_4:.*]] = arith.constant false +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 1 : index +// CHECK: %[[CONSTANT_6:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_7:.*]] = arith.constant 100 : index +// CHECK: %[[CONSTANT_8:.*]] = arith.constant 300 : index +// CHECK: %[[CONSTANT_9:.*]] = arith.constant 262144 : i64 +// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<2xi64> +// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOCA_0]] : memref<2xi64> to memref +// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_6]]] : 
memref<2xi64> +// CHECK: memref.store %[[CONSTANT_9]], %[[ALLOCA_0]]{{\[}}%[[CONSTANT_5]]] : memref<2xi64> +// CHECK: %[[ALLOCA_1:.*]] = memref.alloca() : memref<2xindex> +// CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOCA_1]] : memref<2xindex> to memref +// CHECK: memref.store %[[CONSTANT_7]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex> +// CHECK: memref.store %[[CONSTANT_8]], %[[ALLOCA_1]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex> +// CHECK: %[[ALLOCA_2:.*]] = memref.alloca() : memref<2xindex> +// CHECK: %[[CAST_2:.*]] = memref.cast %[[ALLOCA_2]] : memref<2xindex> to memref +// CHECK: memref.store %[[CONSTANT_6]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex> +// CHECK: memref.store %[[CONSTANT_5]], %[[ALLOCA_2]]{{\[}}%[[CONSTANT_5]]] : memref<2xindex> +// CHECK: %[[VAL_0:.*]] = call @newSparseTensor(%[[CAST_1]], %[[CAST_1]], %[[CAST_0]], %[[CAST_2]], %[[CAST_2]], %[[CONSTANT_2]], %[[CONSTANT_2]], %[[CONSTANT_1]], %[[CONSTANT_2]], %[[MLIR_0]]) : (memref, memref, memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr +// CHECK: %[[ALLOC_0:.*]] = memref.alloc() : memref<300xf64> +// CHECK: %[[CAST_3:.*]] = memref.cast %[[ALLOC_0]] : memref<300xf64> to memref +// CHECK: %[[ALLOC_1:.*]] = memref.alloc() : memref<300xi1> +// CHECK: %[[CAST_4:.*]] = memref.cast %[[ALLOC_1]] : memref<300xi1> to memref +// CHECK: %[[ALLOC_2:.*]] = memref.alloc() : memref<300xindex> +// CHECK: %[[CAST_5:.*]] = memref.cast %[[ALLOC_2]] : memref<300xindex> to memref +// CHECK: linalg.fill ins(%[[CONSTANT_0]] : f64) outs(%[[ALLOC_0]] : memref<300xf64>) +// CHECK: linalg.fill ins(%[[CONSTANT_4]] : i1) outs(%[[ALLOC_1]] : memref<300xi1>) +// CHECK: %[[VAL_1:.*]] = call @sparseValuesF64(%[[ARG0]]) : (!llvm.ptr) -> memref +// CHECK: %[[VAL_2:.*]] = call @sparseValuesF64(%[[ARG1]]) : (!llvm.ptr) -> memref +// CHECK: %[[VAL_3:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_4:.*]] = call 
@sparseCoordinates0(%[[ARG0]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_5:.*]] = call @sparsePositions0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_6:.*]] = call @sparseCoordinates0(%[[ARG0]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_7:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_8:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_6]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_9:.*]] = call @sparsePositions0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[VAL_10:.*]] = call @sparseCoordinates0(%[[ARG1]], %[[CONSTANT_5]]) : (!llvm.ptr, index) -> memref +// CHECK: %[[LOAD_0:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_6]]] : memref +// CHECK: %[[LOAD_1:.*]] = memref.load %[[VAL_3]]{{\[}}%[[CONSTANT_5]]] : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[LOAD_0]] to %[[LOAD_1]] step %[[CONSTANT_5]] { +// CHECK: %[[LOAD_2:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[LOAD_3:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_11]], %[[CONSTANT_5]] : index +// CHECK: %[[LOAD_4:.*]] = memref.load %[[VAL_5]]{{\[}}%[[ADDI_0]]] : memref +// CHECK: %[[LOAD_5:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_6]]] : memref +// CHECK: %[[LOAD_6:.*]] = memref.load %[[VAL_7]]{{\[}}%[[CONSTANT_5]]] : memref +// CHECK: %[[WHILE_0:.*]]:3 = scf.while (%[[VAL_12:.*]] = %[[LOAD_3]], %[[VAL_13:.*]] = %[[LOAD_5]], %[[VAL_14:.*]] = %[[CONSTANT_6]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_12]], %[[LOAD_4]] : index +// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_13]], %[[LOAD_6]] : index +// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1 +// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] : index, index, index // CHECK: } do { -// CHECK: 
^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index): -// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref -// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref -// CHECK: %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index -// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index -// CHECK: %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index -// CHECK: %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index -// CHECK: %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1 -// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) { -// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref -// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref -// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index -// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref -// CHECK: %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) { -// CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref -// CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64> -// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref -// CHECK: %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64 -// CHECK: %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64 -// CHECK: %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1> -// CHECK: %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1 -// CHECK: %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) { -// CHECK: memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1> -// CHECK: memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex> -// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index -// CHECK: 
scf.yield %[[VAL_78]] : index +// CHECK: ^bb0(%[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index): +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_16]], %[[CONSTANT_5]] : index +// CHECK: %[[LOAD_7:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[LOAD_8:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_8]], %[[LOAD_7]] : index +// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_8]], %[[LOAD_7]] : index +// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_7]], %[[SELECT_0]] : index +// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_8]], %[[SELECT_0]] : index +// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1 +// CHECK: %[[IF_0:.*]] = scf.if %[[ANDI_1]] -> (index) { +// CHECK: %[[LOAD_9:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[LOAD_10:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[LOAD_11:.*]] = memref.load %[[VAL_9]]{{\[}}%[[ADDI_1]]] : memref +// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_18:.*]] = %[[LOAD_10]] to %[[LOAD_11]] step %[[CONSTANT_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_17]]) -> (index) { +// CHECK: %[[LOAD_12:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[LOAD_13:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64> +// CHECK: %[[LOAD_14:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_9]], %[[LOAD_14]] : f64 +// CHECK: %[[ADDF_0:.*]] = arith.addf %[[LOAD_13]], %[[MULF_0]] : f64 +// CHECK: %[[LOAD_15:.*]] = memref.load %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1> +// CHECK: %[[CMPI_5:.*]] = arith.cmpi eq, %[[LOAD_15]], %[[CONSTANT_4]] : i1 +// CHECK: %[[IF_1:.*]] = scf.if %[[CMPI_5]] -> (index) { +// CHECK: memref.store %[[CONSTANT_3]], %[[ALLOC_1]]{{\[}}%[[LOAD_12]]] : memref<300xi1> +// CHECK: memref.store %[[LOAD_12]], %[[ALLOC_2]]{{\[}}%[[VAL_19]]] : memref<300xindex> +// CHECK: 
%[[ADDI_2:.*]] = arith.addi %[[VAL_19]], %[[CONSTANT_5]] : index +// CHECK: scf.yield %[[ADDI_2]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_69]] : index +// CHECK: scf.yield %[[VAL_19]] : index // CHECK: } -// CHECK: memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64> -// CHECK: scf.yield %[[VAL_77]] : index -// CHECK: } -// CHECK: scf.yield %[[VAL_67]] : index +// CHECK: memref.store %[[ADDF_0]], %[[ALLOC_0]]{{\[}}%[[LOAD_12]]] : memref<300xf64> +// CHECK: scf.yield %[[IF_1]] : index +// CHECK: } {"Emitted from" = "linalg.generic"} +// CHECK: scf.yield %[[FOR_0]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_54]] : index +// CHECK: scf.yield %[[VAL_17]] : index // CHECK: } -// CHECK: %[[VAL_79:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index -// CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_59]], %[[VAL_79]], %[[VAL_52]] : index -// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index -// CHECK: %[[VAL_82:.*]] = arith.select %[[VAL_60]], %[[VAL_81]], %[[VAL_53]] : index -// CHECK: scf.yield %[[VAL_80]], %[[VAL_82]], %[[VAL_62]] : index, index, index +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[VAL_15]], %[[CONSTANT_5]] : index +// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_3]], %[[VAL_15]] : index +// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_1]], %[[VAL_16]] : index +// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]], %[[IF_0]] : index, index, index // CHECK: } -// CHECK: %[[VAL_83:.*]] = memref.alloca() : memref<2xindex> -// CHECK: %[[VAL_84:.*]] = memref.cast %[[VAL_83]] : memref<2xindex> to memref -// CHECK: memref.store %[[VAL_39]], %[[VAL_83]]{{\[}}%[[VAL_5]]] : memref<2xindex> -// CHECK: func.call @expInsertF64(%[[VAL_19]], %[[VAL_84]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_85:.*]]#2) : (!llvm.ptr, memref, memref, memref, memref, index) -> () -// CHECK: } -// CHECK: memref.dealloc %[[VAL_20]] : memref<300xf64> -// CHECK: memref.dealloc %[[VAL_22]] : memref<300xi1> 
-// CHECK: memref.dealloc %[[VAL_24]] : memref<300xindex> -// CHECK: call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr) -> () -// CHECK: return %[[VAL_19]] : !llvm.ptr -// CHECK: } +// CHECK: %[[ALLOCA_3:.*]] = memref.alloca() : memref<2xindex> +// CHECK: %[[CAST_6:.*]] = memref.cast %[[ALLOCA_3]] : memref<2xindex> to memref +// CHECK: memref.store %[[LOAD_2]], %[[ALLOCA_3]]{{\[}}%[[CONSTANT_6]]] : memref<2xindex> +// CHECK: func.call @expInsertF64(%[[VAL_0]], %[[CAST_6]], %[[CAST_3]], %[[CAST_4]], %[[CAST_5]], %[[VAL_20:.*]]#2) : (!llvm.ptr, memref, memref, memref, memref, index) -> () +// CHECK: } {"Emitted from" = "linalg.generic"} +// CHECK: memref.dealloc %[[ALLOC_0]] : memref<300xf64> +// CHECK: memref.dealloc %[[ALLOC_1]] : memref<300xi1> +// CHECK: memref.dealloc %[[ALLOC_2]] : memref<300xindex> +// CHECK: call @endLexInsert(%[[VAL_0]]) : (!llvm.ptr) -> () +// CHECK: return %[[VAL_0]] : !llvm.ptr +// CHECK: } func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>, %arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> { %0 = tensor.empty() : tensor<100x300xf64, #DCSR> diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir index f6f7f396adab5..c4d86b6b6931f 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir @@ -79,76 +79,72 @@ func.func @sqsum(%arg0: tensor) -> tensor { // ITER: } // CHECK-LABEL: func.func @add( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> { -// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32> -// CHECK: %[[VAL_6:.*]] = bufferization.to_buffer %[[VAL_5]] : tensor<10xi32> to 
memref<10xi32> -// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>) -// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) { -// CHECK: %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index -// CHECK: %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index -// CHECK: %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1 -// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xi32, {{.*}}>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<10xi32, {{.*}}>) -> tensor<10xi32> { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_2:.*]] = arith.constant dense<0> : tensor<10xi32> +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32 +// CHECK: %[[VALUES_0:.*]] = sparse_tensor.values %[[ARG1]] : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[VALUES_1:.*]] = sparse_tensor.values %[[ARG0]] : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[CONSTANT_2]] : tensor<10xi32> to 
memref<10xi32> +// CHECK: linalg.fill ins(%[[CONSTANT_3]] : i32) outs(%[[TO_BUFFER_0]] : memref<10xi32>) +// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_1]]] : memref +// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_0]]] : memref +// CHECK: %[[POSITIONS_1:.*]] = sparse_tensor.positions %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[COORDINATES_1:.*]] = sparse_tensor.coordinates %[[ARG1]] {level = 0 : index} : tensor<10xi32, {{.*}}> to memref +// CHECK: %[[LOAD_2:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_1]]] : memref +// CHECK: %[[LOAD_3:.*]] = memref.load %[[POSITIONS_1]]{{\[}}%[[CONSTANT_0]]] : memref +// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[LOAD_2]]) : (index, index) -> (index, index) { +// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index +// CHECK: %[[CMPI_1:.*]] = arith.cmpi ult, %[[VAL_1]], %[[LOAD_3]] : index +// CHECK: %[[ANDI_0:.*]] = arith.andi %[[CMPI_0]], %[[CMPI_1]] : i1 +// CHECK: scf.condition(%[[ANDI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index // CHECK: } do { -// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index): -// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref -// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref -// CHECK: %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index -// CHECK: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index -// CHECK: %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index -// CHECK: %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index -// CHECK: %[[VAL_29:.*]] = arith.andi %[[VAL_27]], 
%[[VAL_28]] : i1 -// CHECK: scf.if %[[VAL_29]] { -// CHECK: %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref -// CHECK: %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref -// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32 -// CHECK: memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32> +// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index): +// CHECK: %[[LOAD_4:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[LOAD_5:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[CMPI_2:.*]] = arith.cmpi ult, %[[LOAD_5]], %[[LOAD_4]] : index +// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_2]], %[[LOAD_5]], %[[LOAD_4]] : index +// CHECK: %[[CMPI_3:.*]] = arith.cmpi eq, %[[LOAD_4]], %[[SELECT_0]] : index +// CHECK: %[[CMPI_4:.*]] = arith.cmpi eq, %[[LOAD_5]], %[[SELECT_0]] : index +// CHECK: %[[ANDI_1:.*]] = arith.andi %[[CMPI_3]], %[[CMPI_4]] : i1 +// CHECK: scf.if %[[ANDI_1]] { +// CHECK: %[[LOAD_6:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[LOAD_7:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_6]], %[[LOAD_7]] : i32 +// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32> // CHECK: } else { -// CHECK: scf.if %[[VAL_27]] { -// CHECK: %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref -// CHECK: memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32> +// CHECK: scf.if %[[CMPI_3]] { +// CHECK: %[[LOAD_8:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_2]]] : memref +// CHECK: memref.store 
%[[LOAD_8]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32> // CHECK: } else { -// CHECK: scf.if %[[VAL_28]] { -// CHECK: %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref -// CHECK: memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32> +// CHECK: scf.if %[[CMPI_4]] { +// CHECK: %[[LOAD_9:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_3]]] : memref +// CHECK: memref.store %[[LOAD_9]], %[[TO_BUFFER_0]]{{\[}}%[[SELECT_0]]] : memref<10xi32> // CHECK: } // CHECK: } // CHECK: } -// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index -// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index -// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index -// CHECK: %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index -// CHECK: scf.yield %[[VAL_40]], %[[VAL_42]] : index, index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_0]] : index +// CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_3]], %[[ADDI_1]], %[[VAL_2]] : index +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_0]] : index +// CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_4]], %[[ADDI_2]], %[[VAL_3]] : index +// CHECK: scf.yield %[[SELECT_1]], %[[SELECT_2]] : index, index // CHECK: } -// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] { -// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref -// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref -// CHECK: memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32> +// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#0 to %[[LOAD_1]] step %[[CONSTANT_0]] { +// CHECK: %[[LOAD_10:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_4]]] : 
memref +// CHECK: %[[LOAD_11:.*]] = memref.load %[[VALUES_1]]{{\[}}%[[VAL_4]]] : memref +// CHECK: memref.store %[[LOAD_11]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_10]]] : memref<10xi32> // CHECK: } -// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref -// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] { -// CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref -// CHECK: memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32> +// CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_7:.*]]#1 to %[[LOAD_3]] step %[[CONSTANT_0]] { +// CHECK: %[[LOAD_12:.*]] = memref.load %[[COORDINATES_1]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[LOAD_13:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_6]]] : memref +// CHECK: memref.store %[[LOAD_13]], %[[TO_BUFFER_0]]{{\[}}%[[LOAD_12]]] : memref<10xi32> // CHECK: } -// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32> -// CHECK: return %[[VAL_53]] : tensor<10xi32> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<10xi32> to tensor<10xi32> +// CHECK: return %[[TO_TENSOR_0]] : tensor<10xi32> // CHECK: } func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> { %cst = arith.constant dense<0> : tensor<10xi32> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir index e9587edef4678..165d0835b5824 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir @@ -58,56 +58,55 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8 return %r : tensor<8xi64> } -// CHECK-LABEL: func.func @sparse_index_1d_disj( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> { -// CHECK-DAG: 
%[[VAL_1:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true -// CHECK-DAG: %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64> -// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref -// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_buffer %[[VAL_7]] : tensor<8xi64> to memref<8xi64> -// CHECK-DAG: linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>) -// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref -// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref -// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) { -// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_13]] : index -// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index +// CHECK-LABEL: func.func @sparse_index_1d_disj( +// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> { +// CHECK: %[[CONSTANT_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> +// CHECK: %[[CONSTANT_1:.*]] = arith.constant true +// CHECK: %[[CONSTANT_2:.*]] = arith.constant 1 : index +// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_4:.*]] = arith.constant 8 : index +// CHECK: %[[CONSTANT_5:.*]] = arith.constant 0 : i64 +// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<8xi64> +// CHECK: 
%[[VALUES_0:.*]] = sparse_tensor.values %[[ARG0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref +// CHECK: %[[TO_BUFFER_0:.*]] = bufferization.to_buffer %[[EMPTY_0]] : tensor<8xi64> to memref<8xi64> +// CHECK: linalg.fill ins(%[[CONSTANT_5]] : i64) outs(%[[TO_BUFFER_0]] : memref<8xi64>) +// CHECK: %[[POSITIONS_0:.*]] = sparse_tensor.positions %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref +// CHECK: %[[COORDINATES_0:.*]] = sparse_tensor.coordinates %[[ARG0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref +// CHECK: %[[LOAD_0:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_3]]] : memref +// CHECK: %[[LOAD_1:.*]] = memref.load %[[POSITIONS_0]]{{\[}}%[[CONSTANT_2]]] : memref +// CHECK: %[[WHILE_0:.*]]:2 = scf.while (%[[VAL_0:.*]] = %[[LOAD_0]], %[[VAL_1:.*]] = %[[CONSTANT_3]]) : (index, index) -> (index, index) { +// CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[VAL_0]], %[[LOAD_1]] : index +// CHECK: scf.condition(%[[CMPI_0]]) %[[VAL_0]], %[[VAL_1]] : index, index // CHECK: } do { -// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index): -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index -// CHECK: scf.if %[[VAL_21]] { -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : index to i64 -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : i64 -// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64> +// CHECK: ^bb0(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index): +// CHECK: %[[INDEX_CAST_0:.*]] = arith.index_cast %[[VAL_3]] : index to i64 +// CHECK: %[[LOAD_2:.*]] = memref.load %[[COORDINATES_0]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[CMPI_1:.*]] = arith.cmpi eq, %[[LOAD_2]], %[[VAL_3]] : index +// CHECK: scf.if %[[CMPI_1]] { +// CHECK: %[[LOAD_3:.*]] = memref.load %[[VALUES_0]]{{\[}}%[[VAL_2]]] : 
memref +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[LOAD_3]], %[[INDEX_CAST_0]] : i64 +// CHECK: memref.store %[[ADDI_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64> // CHECK: } else { -// CHECK: scf.if %[[VAL_6]] { -// CHECK: %[[VAL_25:.*]] = arith.index_cast %[[VAL_19]] : index to i64 -// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64> +// CHECK: scf.if %[[CONSTANT_1]] { +// CHECK: memref.store %[[INDEX_CAST_0]], %[[TO_BUFFER_0]]{{\[}}%[[VAL_3]]] : memref<8xi64> // CHECK: } else { // CHECK: } // CHECK: } -// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index -// CHECK: %[[VAL_27:.*]] = arith.select %[[VAL_21]], %[[VAL_26]], %[[VAL_18]] : index -// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index -// CHECK: scf.yield %[[VAL_27]], %[[VAL_28]] : index, index +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[VAL_2]], %[[CONSTANT_2]] : index +// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_1]], %[[ADDI_1]], %[[VAL_2]] : index +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[VAL_3]], %[[CONSTANT_2]] : index +// CHECK: scf.yield %[[SELECT_0]], %[[ADDI_2]] : index, index // CHECK: } attributes {"Emitted from" = "linalg.generic"} -// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_30:.*]]#1 to %[[VAL_1]] step %[[VAL_1]] { -// CHECK: %[[VAL_31:.*]] = affine.min #map(%[[VAL_1]], %[[VAL_29]]){{\[}}%[[VAL_1]]] -// CHECK: %[[VAL_32:.*]] = vector.create_mask %[[VAL_31]] : vector<8xi1> -// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex> -// CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex> -// CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64> -// CHECK: vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64> +// CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_5:.*]]#1 to %[[CONSTANT_4]] step %[[CONSTANT_4]] { +// CHECK: %[[MIN_0:.*]] = affine.min #{{.*}}(%[[CONSTANT_4]], 
%[[VAL_4]]){{\[}}%[[CONSTANT_4]]] +// CHECK: %[[CREATE_MASK_0:.*]] = vector.create_mask %[[MIN_0]] : vector<8xi1> +// CHECK: %[[BROADCAST_0:.*]] = vector.broadcast %[[VAL_4]] : index to vector<8xindex> +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[BROADCAST_0]], %[[CONSTANT_0]] : vector<8xindex> +// CHECK: %[[INDEX_CAST_1:.*]] = arith.index_cast %[[ADDI_3]] : vector<8xindex> to vector<8xi64> +// CHECK: vector.maskedstore %[[TO_BUFFER_0]]{{\[}}%[[VAL_4]]], %[[CREATE_MASK_0]], %[[INDEX_CAST_1]] : memref<8xi64>, vector<8xi1>, vector<8xi64> // CHECK: } {"Emitted from" = "linalg.generic"} -// CHECK: %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64> -// CHECK: return %[[VAL_36]] : tensor<8xi64> +// CHECK: %[[TO_TENSOR_0:.*]] = bufferization.to_tensor %[[TO_BUFFER_0]] : memref<8xi64> to tensor<8xi64> +// CHECK: return %[[TO_TENSOR_0]] : tensor<8xi64> // CHECK: } func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> { %init = tensor.empty() : tensor<8xi64> diff --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir index 360b347043722..0da7011d33207 100644 --- a/mlir/test/Pass/ir-printing.mlir +++ b/mlir/test/Pass/ir-printing.mlir @@ -16,50 +16,50 @@ func.func @bar() { return } -// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE: func @foo() -// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE: func @bar() // BEFORE-NOT: // -----// IR Dump Before{{.*}}CanonicalizerPass: canonicalize //----- // // BEFORE-NOT: // -----// IR Dump After -// BEFORE_ALL: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE_ALL: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE_ALL: func @foo() // BEFORE_ALL: // -----// IR Dump Before{{.*}}CanonicalizerPass: canonicalize{{.*}} //----- // // BEFORE_ALL: func 
@foo() -// BEFORE_ALL: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE_ALL: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE_ALL: func @bar() // BEFORE_ALL: // -----// IR Dump Before{{.*}}CanonicalizerPass: canonicalize{{.*}} //----- // // BEFORE_ALL: func @bar() // BEFORE_ALL-NOT: // -----// IR Dump After // AFTER-NOT: // -----// IR Dump Before -// AFTER: // -----// IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // AFTER: func @foo() -// AFTER: // -----// IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // AFTER: func @bar() // AFTER-NOT: // -----// IR Dump After{{.*}}CanonicalizerPass: canonicalize{{.*}} //----- // // AFTER_ALL-NOT: // -----// IR Dump Before -// AFTER_ALL: // -----// IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER_ALL: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // AFTER_ALL: func @foo() // AFTER_ALL: // -----// IR Dump After{{.*}}CanonicalizerPass: canonicalize{{.*}} //----- // // AFTER_ALL: func @foo() -// AFTER_ALL: // -----// IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER_ALL: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // AFTER_ALL: func @bar() // AFTER_ALL: // -----// IR Dump After{{.*}}CanonicalizerPass: canonicalize{{.*}} //----- // // AFTER_ALL: func @bar() -// BEFORE_MODULE: // -----// IR Dump Before{{.*}}CSEPass: cse ('func.func' operation: @foo) //----- // +// BEFORE_MODULE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} ('func.func' operation: @foo) //----- // // BEFORE_MODULE: func @foo() // BEFORE_MODULE: func @bar() -// BEFORE_MODULE: // -----// IR Dump Before{{.*}}CSEPass: cse ('func.func' operation: @bar) //----- // +// BEFORE_MODULE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} ('func.func' operation: @bar) //----- // // BEFORE_MODULE: func @foo() // BEFORE_MODULE: func @bar() -// AFTER_ALL_CHANGE: // -----// 
IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER_ALL_CHANGE: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // AFTER_ALL_CHANGE: func @foo() -// AFTER_ALL_CHANGE-NOT: // -----// IR Dump After{{.*}}CSEPass: cse //----- // +// AFTER_ALL_CHANGE-NOT: // -----// IR Dump After{{.*}}CSEPass: cse{{.*}} //----- // // We expect that only 'foo' changed during CSE, and the second run of CSE did // nothing. diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir index 5b3b21090d169..b086a7b529364 100644 --- a/mlir/test/Pass/run-reproducer.mlir +++ b/mlir/test/Pass/run-reproducer.mlir @@ -16,19 +16,19 @@ func.func @bar() { verify_each: true, // CHECK: builtin.module( // CHECK-NEXT: func.func( - // CHECK-NEXT: cse, + // CHECK-NEXT: cse{hoist-pure-ops=true}, // CHECK-NEXT: canonicalize{cse-between-iterations=false max-iterations=1 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=false} // CHECK-NEXT: ) // CHECK-NEXT: ) - pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))", + pipeline: "builtin.module(func.func(cse{hoist-pure-ops=1},canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=normal top-down=false}))", disable_threading: true } } #-} -// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE: func @foo() -// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse //----- // +// BEFORE: // -----// IR Dump Before{{.*}}CSEPass: cse{{.*}} //----- // // BEFORE: func @bar() // BEFORE-NOT: // -----// IR Dump Before{{.*}}CanonicalizerPass: canonicalize //----- // // BEFORE-NOT: // -----// IR Dump After diff --git a/mlir/test/Transforms/composite-pass.mlir b/mlir/test/Transforms/composite-pass.mlir index 03c540d72185b..61e18fc7c4cab 100644 --- a/mlir/test/Transforms/composite-pass.mlir +++ b/mlir/test/Transforms/composite-pass.mlir @@ 
-4,7 +4,7 @@ // Ensure the composite pass correctly prints its options. // PIPELINE: builtin.module( // PIPELINE-NEXT: composite-fixed-point-pass{max-iterations=10 name=TestCompositePass -// PIPELINE-SAME: pipeline=canonicalize{cse-between-iterations=false max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse} +// PIPELINE-SAME: pipeline=canonicalize{cse-between-iterations=false max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true},cse{hoist-pure-ops=true}} // CHECK-LABEL: running `TestCompositePass` // CHECK: running `CanonicalizerPass` diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 4b2907287d89e..55ab37eec53a4 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -163,16 +163,14 @@ func.func @down_propagate() -> i32 { /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_for func.func @up_propagate_for() -> i32 { + // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 // CHECK: affine.for {{.*}} = 0 to 4 { - affine.for %i = 0 to 4 { - // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 + affine.for %i = 0 to 4 { // CHECK-NEXT: "foo"(%[[VAR_c1_i32_0]]) : (i32) -> () %0 = arith.constant 1 : i32 "foo"(%0) : (i32) -> () } - - // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 - // CHECK-NEXT: return %[[VAR_c1_i32]] : i32 + // CHECK: return %[[VAR_c1_i32_0]] : i32 %1 = arith.constant 1 : i32 return %1 : i32 } @@ -181,7 +179,8 @@ func.func @up_propagate_for() -> i32 { // CHECK-LABEL: func @up_propagate func.func @up_propagate() -> i32 { - // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32 + // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 + // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32 %0 = arith.constant 0 : i32 // CHECK-NEXT: 
%[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true @@ -191,17 +190,15 @@ func.func @up_propagate() -> i32 { cf.cond_br %cond, ^bb1, ^bb2(%0 : i32) ^bb1: // CHECK: ^bb1: - // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %1 = arith.constant 1 : i32 // CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32) cf.br ^bb2(%1 : i32) ^bb2(%arg : i32): // CHECK: ^bb2 - // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %2 = arith.constant 1 : i32 - // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32_0]] : i32 + // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32]] : i32 %add = arith.addi %arg, %2 : i32 // CHECK-NEXT: return %[[VAR_1]] : i32 @@ -216,6 +213,7 @@ func.func @up_propagate() -> i32 { func.func @up_propagate_region() -> i32 { // CHECK-NEXT: {{.*}} "foo.region" %0 = "foo.region"() ({ + // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32 // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true // CHECK-NEXT: cf.cond_br @@ -225,15 +223,13 @@ func.func @up_propagate_region() -> i32 { cf.cond_br %true, ^bb1, ^bb2(%1 : i32) ^bb1: // CHECK: ^bb1: - // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 // CHECK-NEXT: cf.br %c1_i32 = arith.constant 1 : i32 cf.br ^bb2(%c1_i32 : i32) ^bb2(%arg : i32): // CHECK: ^bb2(%[[VAR_1:.*]]: i32): - // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 - // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32_0]] : i32 + // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32]] : i32 // CHECK-NEXT: "foo.yield"(%[[VAR_2]]) : (i32) -> () %c1_i32_0 = arith.constant 1 : i32 @@ -484,7 +480,7 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor // ----- -func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { +func.func 
@cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>, tensor<5xf32>) { %r1 = scf.if %c -> (tensor<5xf32>) { %0 = tensor.empty() : tensor<5xf32> scf.yield %0 : tensor<5xf32> @@ -497,17 +493,76 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te } else { scf.yield %t : tensor<5xf32> } - return %r1, %r2 : tensor<5xf32>, tensor<5xf32> + %r3 = scf.if %c -> (tensor<5xf32>) { + %0 = tensor.empty() : tensor<5xf32> + scf.yield %0 : tensor<5xf32> + } else { + scf.yield %t : tensor<5xf32> + } + return %r1, %r2, %r3 : tensor<5xf32>, tensor<5xf32>, tensor<5xf32> } // CHECK-LABEL: func @cse_multiple_regions -// CHECK: %[[if:.*]] = scf.if {{.*}} { -// CHECK: tensor.empty +// CHECK: tensor.empty +// CHECK: %[[if:.*]] = scf.if {{.*}} // CHECK: scf.yield // CHECK: } else { // CHECK: scf.yield // CHECK: } // CHECK-NOT: scf.if -// CHECK: return %[[if]], %[[if]] +// CHECK: return %[[if]], %[[if]], %[[if]] + +// ----- + +func.func @cse_multiple_regions(%c: i1, %t: i32) -> (i32, i32) { + %init = "test.producer"() : () -> i32 + %r1 = scf.if %c -> (i32) { + %r11 = scf.if %c -> (i32) { + %0 = arith.addi %init, %init : i32 + %1 = arith.muli %0, %0 : i32 + scf.yield %1 : i32 + } else { + %0 = arith.addi %init, %init : i32 + %1 = arith.muli %0, %0 : i32 + scf.yield %1 : i32 + } + scf.yield %r11 : i32 + } else { + scf.yield %t : i32 + } + %r2 = scf.if %c -> (i32) { + %r11 = scf.if %c -> (i32) { + %0 = arith.addi %init, %init : i32 + %1 = arith.muli %0, %0 : i32 + scf.yield %1 : i32 + } else { + %0 = arith.addi %init, %init : i32 + %1 = arith.muli %0, %0 : i32 + scf.yield %1 : i32 + } + scf.yield %r11 : i32 + } else { + scf.yield %t : i32 + } + return %r1, %r2 : i32, i32 +} +// CHECK-LABEL: func @cse_multiple_regions +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i32 +// CHECK: %[[VAL_0:.*]] = "test.producer"() : () -> i32 +// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_0]], %[[VAL_0]] : i32 +// CHECK: %[[MULI_0:.*]] = 
arith.muli %[[ADDI_0]], %[[ADDI_0]] : i32 +// CHECK: %[[IF_0:.*]] = scf.if %[[ARG0]] -> (i32) { +// CHECK: scf.yield %[[MULI_0]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[MULI_0]] : i32 +// CHECK: } +// CHECK: %[[IF_1:.*]] = scf.if %[[ARG0]] -> (i32) { +// CHECK: scf.yield %[[IF_0]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[ARG1]] : i32 +// CHECK: } +// CHECK: return %[[IF_1]], %[[IF_1]] : i32, i32 +// CHECK: } // ----- @@ -683,3 +738,58 @@ func.func @cse_pointer_write_does_not_block_non_addressable_read() -> i32 { %2 = arith.addi %0, %1 : i32 return %2 : i32 } + +// ----- + +func.func @cse_hoist_blocked_by_isolated_region() -> (i32, i32) { + %0 = "test.always_speculatable_op"() : () -> i32 + %1 = "test.isolated_one_region_op"() ({ + %1 = "test.always_speculatable_op"() : () -> i32 + %2 = "test.always_speculatable_op"() : () -> i32 + %3 = arith.addi %1, %2 : i32 + "test.region_yield"(%3) : (i32) -> () + }) : () -> (i32) + return %0, %1 : i32, i32 +} +// CHECK-LABEL: func @cse_hoist_blocked_by_isolated_region +// CHECK: %[[PURE_0:.*]] = "test.always_speculatable_op"() +// CHECK: %[[ISOLATED_ONE_REGION_OP:.*]] = test.isolated_one_region_op { +// CHECK: %[[PURE_1:.*]] = "test.always_speculatable_op"() +// CHECK: %[[ADDI:.*]] = arith.addi %[[PURE_1]], %[[PURE_1]] +// CHECK: test.region_yield %[[ADDI]] +// CHECK: } : -> i32 +// CHECK: return %[[PURE_0]], %[[ISOLATED_ONE_REGION_OP]] +// CHECK: } + +// ----- + +func.func @cse_no_hoist_opportunity_with_nested_isolated_regions() -> (i32) { + %1 = "test.always_speculatable_op"() : () -> i32 + test.isolated_regions { + %2 = "test.always_speculatable_op"() : () -> i32 + test.region_yield %2 : i32 + }, { + %2 = "test.always_speculatable_op"() : () -> i32 + test.isolated_regions { + %3 = "test.always_speculatable_op"() : () -> i32 + test.region_yield %3 : i32 + } + test.region_yield %2 : i32 + } + return %1 : i32 +} +// CHECK-LABEL: func @cse_no_hoist_opportunity_with_nested_isolated_regions +// CHECK:
%[[PURE_0:.*]] = "test.always_speculatable_op"() +// CHECK: test.isolated_regions { +// CHECK: %[[PURE_1:.*]] = "test.always_speculatable_op"() +// CHECK: test.region_yield %[[PURE_1]] +// CHECK: }, { +// CHECK: %[[PURE_2:.*]] = "test.always_speculatable_op"() +// CHECK: test.isolated_regions { +// CHECK: %[[PURE_3:.*]] = "test.always_speculatable_op"() +// CHECK: test.region_yield %[[PURE_3]] +// CHECK: } +// CHECK: test.region_yield %[[PURE_2]] +// CHECK: } +// CHECK: return %[[PURE_0]] +// CHECK: } diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 39b57d321033a..5427547037e29 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -115,10 +115,10 @@ def testAdd(): pm = PassManager("any", Context()) # CHECK: pm: 'any()' log(f"pm: '{pm}'") - # CHECK: pm: 'any(cse)' + # CHECK: pm: 'any(cse{hoist-pure-ops=true})' pm.add("cse") log(f"pm: '{pm}'") - # CHECK: pm: 'any(cse,cse)' + # CHECK: pm: 'any(cse{hoist-pure-ops=true},cse{hoist-pure-ops=true})' pm.add("cse") log(f"pm: '{pm}'")