diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h index a66284562b765..228a6b5718269 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -30,6 +30,7 @@ class RewritePatternSet; class Operation; class Value; class ValueRange; +class PatternRewriter; namespace scf { @@ -140,7 +141,21 @@ struct PipeliningOption { using AnnotationlFnType = std::function; AnnotationlFnType annotateFn = nullptr; - // TODO: add option to decide if the prologue/epilogue should be peeled. + + /// Control whether the epilogue should be peeled out of the loop or + /// operations should be predicated to skip the early stages in the last loop + /// iterations. If the epilogue is predicated; the user needs to provide a + /// lambda to generate the predicated version of operations. + bool peelEpilogue = true; + + // Lamdba to predicate operations when the prologue or epilogue are not + // peeled. This takes the original operation, an i1 predicate value and the + // pattern rewriter. + using PredicateOpFn = + std::function; + PredicateOpFn predicateFn = nullptr; + + // TODO: add option to decide if the prologue should be peeled. }; /// Populate patterns for SCF software pipelining transformation. diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index fa16e90f7530c..659d248b3b7ed 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -41,6 +41,8 @@ struct LoopPipelinerInternal { int64_t lb; int64_t step; PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + PipeliningOption::PredicateOpFn predicateFn = nullptr; // When peeling the kernel we generate several version of each value for // different stage of the prologue. This map tracks the mapping between @@ -91,6 +93,10 @@ bool LoopPipelinerInternal::initializeLoopInfo( ub = upperBoundCst.value(); lb = lowerBoundCst.value(); step = stepCst.value(); + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if (!peelEpilogue && predicateFn == nullptr) + return false; int64_t numIteration = ceilDiv(ub - lb, step); std::vector> schedule; options.getScheduleFn(forOp, schedule); @@ -226,10 +232,13 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( } } - // Create the new kernel loop. Since we need to peel `numStages - 1` - // iteration we change the upper bound to remove those iterations. - Value newUb = rewriter.create(forOp.getLoc(), - ub - maxStage * step); + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) + newUb = rewriter.create(forOp.getLoc(), + ub - maxStage * step); auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); @@ -252,6 +261,18 @@ void LoopPipelinerInternal::createKernel( for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) { mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + for (unsigned i = 0; i < maxStage; i++) { + Value c = rewriter.create( + newForOp.getLoc(), ub - (maxStage - i) * step); + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } for (Operation *op : opOrder) { int64_t useStage = stages[op]; auto *newOp = rewriter.clone(*op, mapping); @@ -300,6 +321,13 @@ void LoopPipelinerInternal::createKernel( newOp->setOperand(operand.getOperandNumber(), newForOp.getRegionIterArgs()[remap->second]); } + if (predicates[useStage]) { + newOp = predicateFn(newOp, predicates[useStage], rewriter); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); if (annotateFn) annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); } @@ -455,10 +483,13 @@ struct ForLoopPipelining : public OpRewritePattern { // operands. pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, rewriter); - // 4. Emit the epilogue after the new forOp. - rewriter.setInsertionPointAfter(newForOp); - llvm::SmallVector returnValues = pipeliner.emitEpilogue(rewriter); - + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. + rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); + } // 5. Erase the original loop and replace the uses with the epilogue output. if (forOp->getNumResults() > 0) rewriter.replaceOp(forOp, returnValues); diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index d7e8827c46a3d..0246231f5b743 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -test-scf-pipelining -split-input-file | FileCheck %s // RUN: mlir-opt %s -test-scf-pipelining=annotate -split-input-file | FileCheck %s --check-prefix ANNOTATE +// RUN: mlir-opt %s -test-scf-pipelining=no-epilogue-peeling -split-input-file | FileCheck %s --check-prefix NOEPILOGUE // CHECK-LABEL: simple_pipeline( // CHECK-SAME: %[[A:.*]]: memref, %[[R:.*]]: memref) { @@ -114,6 +115,44 @@ func.func @simple_pipeline_step(%A: memref, %result: memref) { // ANNOTATE: arith.addf {{.*}} {__test_pipelining_iteration = 0 : i32, __test_pipelining_part = "epilogue"} // ANNOTATE: memref.store {{.*}} {__test_pipelining_iteration = 1 : i32, __test_pipelining_part = "epilogue"} +// NOEPILOGUE-LABEL: three_stage( +// NOEPILOGUE-SAME: %[[A:.*]]: memref, %[[R:.*]]: memref) { +// NOEPILOGUE-DAG: %[[C0:.*]] = arith.constant 0 : index +// NOEPILOGUE-DAG: %[[C1:.*]] = arith.constant 1 : index +// NOEPILOGUE-DAG: %[[C2:.*]] = arith.constant 2 : index +// NOEPILOGUE-DAG: %[[C3:.*]] = arith.constant 3 : index +// NOEPILOGUE-DAG: %[[C4:.*]] = arith.constant 4 : index +// NOEPILOGUE-DAG: %[[CF:.*]] = arith.constant 0.000000e+00 : f32 +// Prologue: +// NOEPILOGUE: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref +// NOEPILOGUE-NEXT: %[[ADD0:.*]] = arith.addf %[[L0]], %{{.*}} : f32 +// NOEPILOGUE-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref +// Kernel: +// NOEPILOGUE-NEXT: %[[LR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] +// NOEPILOGUE-SAME: step %[[C1]] iter_args(%[[ADDARG:.*]] = %[[ADD0]], +// NOEPILOGUE-SAME: %[[LARG:.*]] = %[[L1]]) -> (f32, f32) { +// NOEPILOGUE-DAG: %[[S0:.*]] = arith.cmpi slt, %[[IV]], %[[C2]] : index +// NOEPILOGUE-DAG: %[[S1:.*]] = arith.cmpi slt, %[[IV]], %[[C3]] : index +// NOEPILOGUE-NEXT: memref.store %[[ADDARG]], %[[R]][%[[IV]]] : memref +// NOEPILOGUE-NEXT: %[[ADD1:.*]] = scf.if %[[S1]] -> (f32) { +// NOEPILOGUE-NEXT: %[[PADD:.*]] = arith.addf %[[LARG]], %{{.*}} : f32 +// NOEPILOGUE-NEXT: scf.yield %[[PADD]] : f32 +// NOEPILOGUE-NEXT: } else { +// NOEPILOGUE-NEXT: scf.yield %[[CF]] : f32 +// NOEPILOGUE-NEXT: } +// NOEPILOGUE-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index +// NOEPILOGUE-NEXT: %[[L3:.*]] = scf.if %[[S0]] -> (f32) { +// NOEPILOGUE-NEXT: %[[PL:.*]] = memref.load %[[A]][%[[IV2]]] : memref +// NOEPILOGUE-NEXT: scf.yield %[[PL]] : f32 +// NOEPILOGUE-NEXT: } else { +// NOEPILOGUE-NEXT: scf.yield %[[CF]] : f32 +// NOEPILOGUE-NEXT: } +// NOEPILOGUE-NEXT: scf.yield %[[ADD1]], %[[L3]] : f32, f32 +// NOEPILOGUE-NEXT: } +// No epilogue should be generated. +// NOEPILOGUE-NOT: memref.store +// NOEPILOGUE: return + func.func @three_stage(%A: memref, %result: memref) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index b52b4f524b808..4f6bdaff81aa2 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -123,6 +123,11 @@ struct TestSCFPipeliningPass llvm::cl::desc("Annote operations during loop pipelining transformation"), llvm::cl::init(false)}; + Option noEpiloguePeeling{ + *this, "no-epilogue-peeling", + llvm::cl::desc("Use predicates instead of peeling the epilogue."), + llvm::cl::init(false)}; + static void getSchedule(scf::ForOp forOp, std::vector> &schedule) { @@ -141,6 +146,29 @@ struct TestSCFPipeliningPass }); } + /// Helper to generate "predicated" version of `op`. For simplicity we just + /// wrap the operation in a scf.ifOp operation. + static Operation *predicateOp(Operation *op, Value pred, + PatternRewriter &rewriter) { + Location loc = op->getLoc(); + auto ifOp = + rewriter.create(loc, op->getResultTypes(), pred, true); + // True branch. + op->moveBefore(&ifOp.getThenRegion().front(), + ifOp.getThenRegion().front().end()); + rewriter.setInsertionPointAfter(op); + rewriter.create(loc, op->getResults()); + // False branch. + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + SmallVector zeros; + for (Type type : op->getResultTypes()) { + zeros.push_back( + rewriter.create(loc, rewriter.getZeroAttr(type))); + } + rewriter.create(loc, zeros); + return ifOp.getOperation(); + } + static void annotate(Operation *op, mlir::scf::PipeliningOption::PipelinerPart part, unsigned iteration) { @@ -170,6 +198,10 @@ struct TestSCFPipeliningPass options.getScheduleFn = getSchedule; if (annotatePipeline) options.annotateFn = annotate; + if (noEpiloguePeeling) { + options.peelEpilogue = false; + options.predicateFn = predicateOp; + } scf::populateSCFLoopPipeliningPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); getOperation().walk([](Operation *op) {