diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index bbba495e613b2..0d1ef4dc89829 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -984,7 +984,13 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> { def SCFToControlFlowPass : Pass<"convert-scf-to-cf"> { let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured" " control flow with a CFG"; - let dependentDialects = ["cf::ControlFlowDialect"]; + let dependentDialects = ["cf::ControlFlowDialect", "LLVM::LLVMDialect"]; + + let options = [Option<"enableVectorizeHints", "enable-vectorize-hints", + "bool", + /*default=*/"false", + "Add vectorization hints when converting SCF parallel " + "loops to ControlFlow dialect">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h index 2def01d208f72..ca1185d6bb3b5 100644 --- a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h +++ b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h @@ -20,7 +20,8 @@ class RewritePatternSet; /// Collect a set of patterns to convert SCF operations to CFG branch-based /// operations within the ControlFlow dialect. 
-void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns); +void populateSCFToControlFlowConversionPatterns( + RewritePatternSet &patterns, bool enableVectorizeHints = false); } // namespace mlir diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 114d634629d77..071067adcdb4b 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -38,6 +38,7 @@ namespace { struct SCFToControlFlowPass : public impl::SCFToControlFlowPassBase { + using Base::Base; void runOnOperation() override; }; @@ -212,6 +213,11 @@ struct ExecuteRegionLowering : public OpRewritePattern { struct ParallelLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + bool enableVectorizeHints; + + ParallelLowering(mlir::MLIRContext *ctx, bool enableVectorizeHints) + : OpRewritePattern(ctx), enableVectorizeHints(enableVectorizeHints) {} + LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override; }; @@ -487,6 +493,13 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, return failure(); } + auto vecAttr = LLVM::LoopVectorizeAttr::get( + rewriter.getContext(), + /*disable=*/rewriter.getBoolAttr(false), {}, {}, {}, {}, {}, {}); + auto loopAnnotation = LLVM::LoopAnnotationAttr::get( + rewriter.getContext(), {}, /*vectorize=*/vecAttr, {}, {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {}, {}); + // For a parallel loop, we essentially need to create an n-dimensional loop // nest. We do this by translating to scf.for ops and have those lowered in // a further rewrite. 
If a parallel loop contains reductions (and thus returns @@ -517,6 +530,11 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, rewriter.create(loc, forOp.getResults()); } + if (enableVectorizeHints) + forOp->setAttr(LLVM::BrOp::getLoopAnnotationAttrName(OperationName( + LLVM::BrOp::getOperationName(), getContext())), + loopAnnotation); + rewriter.setInsertionPointToStart(forOp.getBody()); } @@ -706,16 +724,18 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, } void mlir::populateSCFToControlFlowConversionPatterns( - RewritePatternSet &patterns) { - patterns.add( + RewritePatternSet &patterns, bool enableVectorizeHints) { + patterns.add( patterns.getContext()); + patterns.add(patterns.getContext(), enableVectorizeHints); patterns.add(patterns.getContext(), /*benefit=*/2); } void SCFToControlFlowPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateSCFToControlFlowConversionPatterns(patterns); + populateSCFToControlFlowConversionPatterns(patterns, + enableVectorizeHints.getValue()); // Configure conversion to lower out SCF operations. 
ConversionTarget target(getContext()); diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index 9ea0093eff786..8db76d6c7466e 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -1,4 +1,9 @@ -// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s --check-prefixes=CHECK,NO-VEC +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf="enable-vectorize-hints=true" \ +// RUN: -split-input-file %s | FileCheck %s --check-prefixes=CHECK,VEC + +// VEC: #loop_vectorize = #llvm.loop_vectorize +// VEC-NEXT: #[[$VEC_ATTR:.+]] = #llvm.loop_annotation // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index) @@ -332,7 +337,8 @@ func.func @simple_parallel_reduce_loop(%arg0: index, %arg1: index, // variable and the current partially reduced value. // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32 // CHECK: %[[COMP:.*]] = arith.cmpi slt, %[[ITER]], %[[UB]] - // CHECK: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] + // NO-VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] + // VEC: cf.cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] {loop_annotation = #[[$VEC_ATTR]]} // Bodies of scf.reduce operations are folded into the main loop body. 
The // result of this partial reduction is passed as argument to the condition @@ -366,11 +372,13 @@ func.func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index, // CHECK: %[[INIT2:.*]] = arith.constant 42 // CHECK: cf.br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]] // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64 - // CHECK: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] + // NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] + // VEC: cf.cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] {loop_annotation = #[[$VEC_ATTR]]} // CHECK: ^[[BODY_OUT]]: // CHECK: cf.br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]] // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64 - // CHECK: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] + // NO-VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] + // VEC: cf.cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] {loop_annotation = #[[$VEC_ATTR]]} // CHECK: ^[[BODY_IN]]: // CHECK: %[[REDUCE1:.*]] = arith.addf %[[ITER_ARG1_IN]], %{{.*}} // CHECK: %[[REDUCE2:.*]] = arith.ori %[[ITER_ARG2_IN]], %{{.*}} @@ -551,7 +559,8 @@ func.func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, // CHECK: cf.br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index) // CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index): // CHECK: %[[LOOP_COND:.*]] = arith.cmpi slt, %[[LOOP_IV]], %[[ARG1]] : index - // CHECK: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] + // NO-VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] + // VEC: cf.cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] {loop_annotation = #[[$VEC_ATTR]]} // CHECK: ^[[LOOP_BODY]]: // CHECK: cf.cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]] // CHECK: ^[[IF1_THEN]]: @@ -660,7 +669,8 @@ func.func @index_switch(%i: index, %a: i32, %b: i32, %c: i32) -> i32 { // CHECK: cf.br ^[[bb1:.*]](%[[c0]] : index) // 
CHECK: ^[[bb1]](%[[arg0:.*]]: index): // CHECK: %[[cmpi:.*]] = arith.cmpi slt, %[[arg0]], %[[num_threads]] -// CHECK: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] +// NO-VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] +// VEC: cf.cond_br %[[cmpi]], ^[[bb2:.*]], ^[[bb3:.*]] {loop_annotation = #[[$VEC_ATTR]]} // CHECK: ^[[bb2]]: // CHECK: "test.foo"(%[[arg0]]) // CHECK: %[[addi:.*]] = arith.addi %[[arg0]], %[[c1]]