diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index d910f2e02d090..896859d5ee375 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -11,7 +11,7 @@ include "mlir/Pass/PassBase.td" -def SCFBufferize : Pass<"scf-bufferize", "func::FuncOp"> { +def SCFBufferize : Pass<"scf-bufferize"> { let summary = "Bufferize the scf dialect."; let constructor = "mlir::createSCFBufferizePass()"; let dependentDialects = ["bufferization::BufferizationDialect", @@ -21,14 +21,14 @@ def SCFBufferize : Pass<"scf-bufferize", "func::FuncOp"> { // Note: Making these canonicalization patterns would require a dependency // of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa. def SCFForLoopCanonicalization - : Pass<"scf-for-loop-canonicalization", "func::FuncOp"> { + : Pass<"scf-for-loop-canonicalization"> { let summary = "Canonicalize operations within scf.for loop bodies"; let constructor = "mlir::createSCFForLoopCanonicalizationPass()"; let dependentDialects = ["AffineDialect", "tensor::TensorDialect", "memref::MemRefDialect"]; } -def SCFForLoopPeeling : Pass<"scf-for-loop-peeling", "func::FuncOp"> { +def SCFForLoopPeeling : Pass<"scf-for-loop-peeling"> { let summary = "Peel `for` loops at their upper bounds."; let constructor = "mlir::createForLoopPeelingPass()"; let options = [ @@ -40,7 +40,7 @@ def SCFForLoopPeeling : Pass<"scf-for-loop-peeling", "func::FuncOp"> { let dependentDialects = ["AffineDialect"]; } -def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization", "func::FuncOp"> { +def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> { let summary = "Specialize `for` loops for vectorization"; let constructor = "mlir::createForLoopSpecializationPass()"; } @@ -64,12 +64,12 @@ def SCFParallelLoopCollapsing : Pass<"scf-parallel-loop-collapsing"> { } def SCFParallelLoopSpecialization - : Pass<"scf-parallel-loop-specialization", "func::FuncOp"> { + : Pass<"scf-parallel-loop-specialization"> { let summary = "Specialize parallel loops for vectorization"; let constructor = "mlir::createParallelLoopSpecializationPass()"; } -def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling", "func::FuncOp"> { +def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling"> { let summary = "Tile parallel loops"; let constructor = "mlir::createParallelLoopTilingPass()"; let options = [ @@ -88,7 +88,7 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> { let constructor = "mlir::createForLoopRangeFoldingPass()"; } -def SCFForToWhileLoop : Pass<"scf-for-to-while", "func::FuncOp"> { +def SCFForToWhileLoop : Pass<"scf-for-to-while"> { let summary = "Convert SCF for loops to SCF while loops"; let constructor = "mlir::createForToWhileLoopPass()"; let description = [{ diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 17cc6f3773390..14eb075d8c897 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -100,11 +100,11 @@ struct ForLoopLoweringPattern : public OpRewritePattern { struct ForToWhileLoop : public SCFForToWhileLoopBase { void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *ctx = funcOp.getContext(); + auto *parentOp = getOperation(); + MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index eda6bc6e1cf8b..18d43d72e210b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -195,11 +195,11 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { struct SCFForLoopCanonicalization : public SCFForLoopCanonicalizationBase { void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *ctx = funcOp.getContext(); + auto *parentOp = getOperation(); + MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); scf::populateSCFForLoopCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 1195f7f8a5672..aa0056b683907 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -237,7 +237,7 @@ namespace { struct ParallelLoopSpecialization : public SCFParallelLoopSpecializationBase { void runOnOperation() override { - getOperation().walk( + getOperation()->walk( [](ParallelOp op) { specializeParallelLoopForUnrolling(op); }); } }; @@ -245,20 +245,20 @@ struct ParallelLoopSpecialization struct ForLoopSpecialization : public SCFForLoopSpecializationBase { void runOnOperation() override { - getOperation().walk([](ForOp op) { specializeForLoopForUnrolling(op); }); + getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); }); } }; struct ForLoopPeeling : public SCFForLoopPeelingBase { void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *ctx = funcOp.getContext(); + auto *parentOp = getOperation(); + MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx, skipPartial); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); // Drop the markers. - funcOp.walk([](Operation *op) { + parentOp->walk([](Operation *op) { op->removeAttr(kPeeledLoopLabel); op->removeAttr(kPartialIterationLabel); }); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index f20764647d4b4..c39d3afae25ce 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -195,8 +195,9 @@ struct ParallelLoopTiling } void runOnOperation() override { + auto *parentOp = getOperation(); SmallVector innermostPloops; - getInnermostParallelLoops(getOperation().getOperation(), innermostPloops); + getInnermostParallelLoops(parentOp, innermostPloops); for (ParallelOp ploop : innermostPloops) { // FIXME: Add reduction support. if (ploop.getNumReductions() == 0)