diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 344ffe977caf5c..8d411d5964c5d7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -62,12 +62,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { let summary = "Lower the operations from the linalg dialect into affine " "loops"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; - let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"]; } @@ -75,12 +69,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> { def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { let summary = "Lower the operations from the linalg dialect into loops"; let constructor = "mlir::createConvertLinalgToLoopsPass()"; - let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "linalg::LinalgDialect", "scf::SCFDialect", @@ -103,12 +91,6 @@ def LinalgLowerToParallelLoops let summary = "Lower the operations from the linalg dialect into parallel " "loops"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; - let options = [ - ListOption<"interchangeVector", "interchange-vector", "unsigned", - "Permute the loops in the nest following the given " - "interchange vector", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> - ]; let dependentDialects = [ "AffineDialect", "linalg::LinalgDialect", diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 251a2f8e6d0347..2338198b5f2e7d 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -338,28 +338,16 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op, /// Emits a loop nest of `LoopTy` with the proper body for `op`. template -Optional -linalgLowerOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); - -/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated -/// loop nest will follow the `interchangeVector`-permutated iterator order. If -/// `interchangeVector` is empty, then no permutation happens. -LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); - -/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The -/// generated loop nest will follow the `interchangeVector`-permutated -// iterator order. If `interchangeVector` is empty, then no permutation happens. -LogicalResult -linalgOpToParallelLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +Optional linalgLowerOpToLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `scf.for` with the proper body for `op`. +LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op); -/// Emits a loop nest of `affine.for` with the proper body for `op`. The -/// generated loop nest will follow the `interchangeVector`-permutated -// iterator order. If `interchangeVector` is empty, then no permutation happens. -LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector = {}); +/// Emits a loop nest of `scf.parallel` with the proper body for `op`. +LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op); + +/// Emits a loop nest of `affine.for` with the proper body for `op`. +LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op); //===----------------------------------------------------------------------===// // Preconditions that ensure the corresponding transformation succeeds and can @@ -808,10 +796,9 @@ struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern( MLIRContext *context, LinalgLoweringType loweringType, LinalgTransformationFilter filter = LinalgTransformationFilter(), - ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) + PatternBenefit benefit = 1) : RewritePattern(OpTy::getOperationName(), benefit, context), - filter(filter), loweringType(loweringType), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + filter(filter), loweringType(loweringType) {} // TODO: Move implementation to .cpp once named ops are auto-generated. LogicalResult matchAndRewrite(Operation *op, @@ -827,15 +814,15 @@ struct LinalgLoweringPattern : public RewritePattern { // TODO: Move lowering to library calls here. return failure(); case LinalgLoweringType::Loops: - if (failed(linalgOpToLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToLoops(rewriter, op))) return failure(); break; case LinalgLoweringType::AffineLoops: - if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToAffineLoops(rewriter, op))) return failure(); break; case LinalgLoweringType::ParallelLoops: - if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector))) + if (failed(linalgOpToParallelLoops(rewriter, op))) return failure(); break; } @@ -850,8 +837,6 @@ struct LinalgLoweringPattern : public RewritePattern { /// Controls whether the pattern lowers to library calls, scf.for, affine.for /// or scf.parallel. LinalgLoweringType loweringType; - /// Permutated loop order in the generated loop nest. - SmallVector interchangeVector; }; /// Linalg generalization patterns diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index c85f4a9abd3836..f19493c3cca9ca 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -457,9 +457,8 @@ static void emitScalarImplementation(ArrayRef allIvs, } template -static Optional -linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, - ArrayRef interchangeVector) { +static Optional linalgOpToLoopsImpl(Operation *op, + OpBuilder &builder) { using IndexedValueTy = typename GenerateLoopNest::IndexedValueTy; ScopedContext scope(builder, op->getLoc()); @@ -472,13 +471,6 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); - if (!interchangeVector.empty()) { - assert(interchangeVector.size() == loopRanges.size()); - assert(interchangeVector.size() == iteratorTypes.size()); - applyPermutationToVector(loopRanges, interchangeVector); - applyPermutationToVector(iteratorTypes, interchangeVector); - } - SmallVector allIvs; GenerateLoopNest::doit( loopRanges, /*iterInitArgs=*/{}, iteratorTypes, @@ -511,11 +503,10 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, } /// Replace the index operations in the body of the loop nest by the matching -/// induction variables. If available use the interchange vector to map the -/// interchanged induction variables to the dimension of the index operation. -static void replaceIndexOpsByInductionVariables( - LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef loopOps, - ArrayRef interchangeVector) { +/// induction variables. +static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, + PatternRewriter &rewriter, + ArrayRef loopOps) { // Extract the induction variables of the loop nest from outer to inner. SmallVector allIvs; for (Operation *loopOp : loopOps) { @@ -538,16 +529,8 @@ static void replaceIndexOpsByInductionVariables( if (!loopOps.empty()) { LoopLikeOpInterface loopOp = loopOps.back(); for (IndexOp indexOp : - llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) { - // Search the indexing dimension in the interchange vector if available. - assert(interchangeVector.empty() || - interchangeVector.size() == linalgOp.getNumLoops()); - const auto *it = llvm::find(interchangeVector, indexOp.dim()); - uint64_t dim = it != interchangeVector.end() - ? std::distance(interchangeVector.begin(), it) - : indexOp.dim(); - rewriter.replaceOp(indexOp, allIvs[dim]); - } + llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) + rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); } } @@ -555,39 +538,31 @@ namespace { template class LinalgRewritePattern : public RewritePattern { public: - LinalgRewritePattern(MLIRContext *context, - ArrayRef interchangeVector) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} + LinalgRewritePattern(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto linalgOp = dyn_cast(op); if (!isa(op)) return failure(); - Optional loopOps = - linalgOpToLoopsImpl(op, rewriter, interchangeVector); + Optional loopOps = linalgOpToLoopsImpl(op, rewriter); if (!loopOps.hasValue()) return failure(); - replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(), - interchangeVector); + replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue()); rewriter.eraseOp(op); return success(); } - -private: - SmallVector interchangeVector; }; struct FoldAffineOp; } // namespace template -static void lowerLinalgToLoopsImpl(FuncOp funcOp, - ArrayRef interchangeVector) { +static void lowerLinalgToLoopsImpl(FuncOp funcOp) { MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); - patterns.add>(context, interchangeVector); + patterns.add>(context); memref::DimOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); @@ -639,7 +614,7 @@ struct LowerToAffineLoops registry.insert(); } void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; @@ -648,14 +623,14 @@ struct LowerToLoops : public LinalgLowerToLoopsBase { registry.insert(); } void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; struct LowerToParallelLoops : public LinalgLowerToParallelLoopsBase { void runOnFunction() override { - lowerLinalgToLoopsImpl(getFunction(), interchangeVector); + lowerLinalgToLoopsImpl(getFunction()); } }; } // namespace @@ -676,43 +651,38 @@ mlir::createConvertLinalgToAffineLoopsPass() { /// Emits a loop nest with the proper body for `op`. template -Optional -mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - return linalgOpToLoopsImpl(op, builder, interchangeVector); +Optional mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op) { + return linalgOpToLoopsImpl(op, builder); } -template Optional mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); -template Optional mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); template Optional -mlir::linalg::linalgLowerOpToLoops( - OpBuilder &builder, Operation *op, ArrayRef interchangeVector); +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); +template Optional +mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, + Operation *op); /// Emits a loop nest of `affine.for` with the proper body for `op`. -LogicalResult -mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); +LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, + Operation *op) { + Optional loops = linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `scf.for` with the proper body for `op`. -LogicalResult -mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { - Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); +LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) { + Optional loops = linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } /// Emits a loop nest of `scf.parallel` with the proper body for `op`. -LogicalResult -mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op, - ArrayRef interchangeVector) { +LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, + Operation *op) { Optional loops = - linalgLowerOpToLoops(builder, op, interchangeVector); + linalgLowerOpToLoops(builder, op); return loops ? success() : failure(); } diff --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir deleted file mode 100644 index c572967e6d1014..00000000000000 --- a/mlir/test/Dialect/Linalg/loop-order.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s -// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s -// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s - -func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) { - linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32> - return -} - -// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1 -// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1 - -// PARALLEL: scf.parallel -// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) - -// AFFINE: affine.for %{{.*}} = 0 to 5 -// AFFINE: affine.for %{{.*}} = 0 to 1 -// AFFINE: affine.for %{{.*}} = 0 to 4 -// AFFINE: affine.for %{{.*}} = 0 to 2 -// AFFINE: affine.for %{{.*}} = 0 to 3 - -// ----- - -#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)> -func @generic(%output: memref<1x2x3x4x5xindex>) { - linalg.generic {indexing_maps = [#map], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} - outs(%output : memref<1x2x3x4x5xindex>) { - ^bb0(%arg0 : index): - %i = linalg.index 0 : index - %j = linalg.index 1 : index - %k = linalg.index 2 : index - %l = linalg.index 3 : index - %m = linalg.index 4 : index - %0 = addi %i, %j : index - %1 = addi %0, %k : index - %2 = addi %1, %l : index - %3 = addi %2, %m : index - linalg.yield %3: index - } - return -} - -// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1 -// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1 -// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1 -// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1 -// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1 -// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index -// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index - -// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) = -// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3) -// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index -// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index - -// AFFINE: affine.for %[[m:.*]] = 0 to 5 -// AFFINE: affine.for %[[i:.*]] = 0 to 1 -// AFFINE: affine.for %[[l:.*]] = 0 to 4 -// AFFINE: affine.for %[[j:.*]] = 0 to 2 -// AFFINE: affine.for %[[k:.*]] = 0 to 3 -// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index -// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index