diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 24b2cf1dc422bc..30940fcc891f9f 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -355,11 +355,11 @@ def ConvertPDLToPDLInterp : Pass<"convert-pdl-to-pdl-interp", "ModuleOp"> { // SCFToOpenMP //===----------------------------------------------------------------------===// -def ConvertSCFToOpenMP : FunctionPass<"convert-scf-to-openmp"> { +def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> { let summary = "Convert SCF parallel loop to OpenMP parallel + workshare " "constructs."; let constructor = "mlir::createConvertSCFToOpenMPPass()"; - let dependentDialects = ["omp::OpenMPDialect"]; + let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h index 349c4e1efc8331..4000bc1df46b2e 100644 --- a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h +++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h @@ -12,11 +12,11 @@ #include namespace mlir { -class FuncOp; +class ModuleOp; template class OperationPass; -std::unique_ptr> createConvertSCFToOpenMPPass(); +std::unique_ptr> createConvertSCFToOpenMPPass(); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 08396a345c6e66..05d406d09fc677 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -317,12 +317,12 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> { The optional $device parameter specifies the device number for the target region. The optional $thread_limit specifies the limit on the number of threads - + The optional $nowait elliminates the implicit barrier so the parent task can make progress even if the target task is not yet completed. - + TODO: private, map, is_device_ptr, firstprivate, depend, defaultmap, in_reduction - + }]; let arguments = (ins Optional:$if_expr, diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt index 1ef4b74da6dd72..1a75a3549f9795 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -11,7 +11,9 @@ add_mlir_conversion_library(MLIRSCFToOpenMP Core LINK_LIBS PUBLIC + MLIRLLVMIR MLIROpenMP MLIRSCF + MLIRStandard MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index a7d4a99c9d5b59..9c6fc6fed91394 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -13,26 +13,311 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +/// Matches a block containing a "simple" reduction. The expected shape of the +/// block is as follows. +/// +/// ^bb(%arg0, %arg1): +/// %0 = OpTy(%arg0, %arg1) +/// scf.reduce.return %0 +template +static bool matchSimpleReduction(Block &block) { + if (block.empty() || llvm::hasSingleElement(block) || + std::next(block.begin(), 2) != block.end()) + return false; + return isa(block.front()) && + isa(block.back()) && + block.front().getOperands() == block.getArguments() && + block.back().getOperand(0) == block.front().getResult(0); +} + +/// Matches a block containing a select-based min/max reduction. The types of +/// select and compare operations are provided as template arguments. The +/// comparison predicates suitable for min and max are provided as function +/// arguments. If a reduction is matched, `ifMin` will be set if the reduction +/// compute the minimum and unset if it computes the maximum, otherwise it +/// remains unmodified. The expected shape of the block is as follows. +/// +/// ^bb(%arg0, %arg1): +/// %0 = CompareOpTy(, %arg0, %arg1) +/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. +/// scf.reduce.return %1 +template < + typename CompareOpTy, typename SelectOpTy, + typename Predicate = decltype(std::declval().predicate())> +static bool +matchSelectReduction(Block &block, ArrayRef lessThanPredicates, + ArrayRef greaterThanPredicates, bool &isMin) { + static_assert(llvm::is_one_of::value, + "only std and llvm select ops are supported"); + + // Expect exactly three operations in the block. + if (block.empty() || llvm::hasSingleElement(block) || + std::next(block.begin(), 2) == block.end() || + std::next(block.begin(), 3) != block.end()) + return false; + + // Check op kinds. + auto compare = dyn_cast(block.front()); + auto select = dyn_cast(block.front().getNextNode()); + auto terminator = dyn_cast(block.back()); + if (!compare || !select || !terminator) + return false; + + // Block arguments must be compared. + if (compare->getOperands() != block.getArguments()) + return false; + + // Detect whether the comparison is less-than or greater-than, otherwise bail. + bool isLess; + if (llvm::find(lessThanPredicates, compare.predicate()) != + lessThanPredicates.end()) { + isLess = true; + } else if (llvm::find(greaterThanPredicates, compare.predicate()) != + greaterThanPredicates.end()) { + isLess = false; + } else { + return false; + } + + if (select.condition() != compare.getResult()) + return false; + + // Detect if the operands are swapped between cmpf and select. Match the + // comparison type with the requested type or with the opposite of the + // requested type if the operands are swapped. Use generic accessors because + // std and LLVM versions of select have different operand names but identical + // positions. + constexpr unsigned kTrueValue = 1; + constexpr unsigned kFalseValue = 2; + bool sameOperands = select.getOperand(kTrueValue) == compare.lhs() && + select.getOperand(kFalseValue) == compare.rhs(); + bool swappedOperands = select.getOperand(kTrueValue) == compare.rhs() && + select.getOperand(kFalseValue) == compare.lhs(); + if (!sameOperands && !swappedOperands) + return false; + + if (select.getResult() != terminator.result()) + return false; + + // The reduction is a min if it uses less-than predicates with same operands + // or greather-than predicates with swapped operands. Similarly for max. + isMin = (isLess && sameOperands) || (!isLess && swappedOperands); + return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); +} + +/// Returns the float semantics for the given float type. +static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { + if (type.isF16()) + return llvm::APFloat::IEEEhalf(); + if (type.isF32()) + return llvm::APFloat::IEEEsingle(); + if (type.isF64()) + return llvm::APFloat::IEEEdouble(); + if (type.isF128()) + return llvm::APFloat::IEEEquad(); + if (type.isBF16()) + return llvm::APFloat::BFloat(); + if (type.isF80()) + return llvm::APFloat::x87DoubleExtended(); + llvm_unreachable("unknown float type"); +} + +/// Returns an attribute with the minimum (if `min` is set) or the maximum value +/// (otherwise) for the given float type. +static Attribute minMaxValueForFloat(Type type, bool min) { + auto fltType = type.cast(); + return FloatAttr::get( + type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); +} + +/// Returns an attribute with the signed integer minimum (if `min` is set) or +/// the maximum value (otherwise) for the given integer type, regardless of its +/// signedness semantics (only the width is considered). +static Attribute minMaxValueForSignedInt(Type type, bool min) { + auto intType = type.cast(); + unsigned bitwidth = intType.getWidth(); + return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) + : llvm::APInt::getSignedMaxValue(bitwidth)); +} + +/// Returns an attribute with the unsigned integer minimum (if `min` is set) or +/// the maximum value (otherwise) for the given integer type, regardless of its +/// signedness semantics (only the width is considered). +static Attribute minMaxValueForUnsignedInt(Type type, bool min) { + auto intType = type.cast(); + unsigned bitwidth = intType.getWidth(); + return IntegerAttr::get(type, min ? llvm::APInt::getNullValue(bitwidth) + : llvm::APInt::getAllOnesValue(bitwidth)); +} + +/// Creates an OpenMP reduction declaration and inserts it into the provided +/// symbol table. The declaration has a constant initializer with the neutral +/// value `initValue`, and the reduction combiner carried over from `reduce`. +static omp::ReductionDeclareOp createDecl(PatternRewriter &builder, + SymbolTable &symbolTable, + scf::ReduceOp reduce, + Attribute initValue) { + OpBuilder::InsertionGuard guard(builder); + auto decl = builder.create( + reduce.getLoc(), "__scf_reduction", reduce.operand().getType()); + symbolTable.insert(decl); + + Type type = reduce.operand().getType(); + builder.createBlock(&decl.initializerRegion(), decl.initializerRegion().end(), + {type}); + builder.setInsertionPointToEnd(&decl.initializerRegion().back()); + Value init = + builder.create(reduce.getLoc(), type, initValue); + builder.create(reduce.getLoc(), init); + + Operation *terminator = &reduce.getRegion().front().back(); + assert(isa(terminator) && + "expected reduce op to be terminated by redure return"); + builder.setInsertionPoint(terminator); + builder.replaceOpWithNewOp(terminator, + terminator->getOperands()); + builder.inlineRegionBefore(reduce.getRegion(), decl.reductionRegion(), + decl.reductionRegion().end()); + return decl; +} + +/// Adds an atomic reduction combiner to the given OpenMP reduction declaration +/// using llvm.atomicrmw of the given kind. +static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, + LLVM::AtomicBinOp atomicKind, + omp::ReductionDeclareOp decl, + scf::ReduceOp reduce) { + OpBuilder::InsertionGuard guard(builder); + Type type = reduce.operand().getType(); + Type ptrType = LLVM::LLVMPointerType::get(type); + builder.createBlock(&decl.atomicReductionRegion(), + decl.atomicReductionRegion().end(), {ptrType, ptrType}); + Block *atomicBlock = &decl.atomicReductionRegion().back(); + builder.setInsertionPointToEnd(atomicBlock); + Value loaded = builder.create(reduce.getLoc(), + atomicBlock->getArgument(1)); + builder.create(reduce.getLoc(), type, atomicKind, + atomicBlock->getArgument(0), loaded, + LLVM::AtomicOrdering::monotonic); + builder.create(reduce.getLoc(), ArrayRef()); + return decl; +} + +/// Creates an OpenMP reduction declaration that corresponds to the given SCF +/// reduction and returns it. Recognizes common reductions in order to identify +/// the neutral value, necessary for the OpenMP declaration. If the reduction +/// cannot be recognized, returns null. +static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, + scf::ReduceOp reduce) { + Operation *container = SymbolTable::getNearestSymbolTable(reduce); + SymbolTable symbolTable(container); + + // Insert reduction declarations in the symbol-table ancestor before the + // ancestor of the current insertion point. + Operation *insertionPoint = reduce; + while (insertionPoint->getParentOp() != container) + insertionPoint = insertionPoint->getParentOp(); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(insertionPoint); + + assert(llvm::hasSingleElement(reduce.getRegion()) && + "expected reduction region to have a single element"); + + // Match simple binary reductions that can be expressed with atomicrmw. + Type type = reduce.operand().getType(); + Block &reduction = reduce.getRegion().front(); + if (matchSimpleReduction(reduction)) { + omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, + builder.getFloatAttr(type, 0.0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); + } + if (matchSimpleReduction(reduction)) { + omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); + } + if (matchSimpleReduction(reduction)) { + omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); + } + if (matchSimpleReduction(reduction)) { + omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, + builder.getIntegerAttr(type, 0)); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); + } + if (matchSimpleReduction(reduction)) { + omp::ReductionDeclareOp decl = createDecl( + builder, symbolTable, reduce, + builder.getIntegerAttr( + type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth()))); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); + } + + // Match simple binary reductions that cannot be expressed with atomicrmw. + // TODO: add atomic region using cmpxchg (which needs atomic load to be + // available as an op). + if (matchSimpleReduction(reduction)) { + return createDecl(builder, symbolTable, reduce, + builder.getFloatAttr(type, 1.0)); + } + + // Match select-based min/max reductions. + bool isMin; + if (matchSelectReduction( + reduction, {CmpFPredicate::OLT, CmpFPredicate::OLE}, + {CmpFPredicate::OGT, CmpFPredicate::OGE}, isMin) || + matchSelectReduction( + reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, + {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { + return createDecl(builder, symbolTable, reduce, + minMaxValueForFloat(type, !isMin)); + } + if (matchSelectReduction( + reduction, {CmpIPredicate::slt, CmpIPredicate::sle}, + {CmpIPredicate::sgt, CmpIPredicate::sge}, isMin) || + matchSelectReduction( + reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, + {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { + omp::ReductionDeclareOp decl = createDecl( + builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); + return addAtomicRMW(builder, + isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, + decl, reduce); + } + if (matchSelectReduction( + reduction, {CmpIPredicate::ult, CmpIPredicate::ule}, + {CmpIPredicate::ugt, CmpIPredicate::uge}, isMin) || + matchSelectReduction( + reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, + {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { + omp::ReductionDeclareOp decl = createDecl( + builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); + return addAtomicRMW( + builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, + decl, reduce); + } + + return nullptr; +} + namespace { -/// Converts SCF parallel operation into an OpenMP workshare loop construct. struct ParallelOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override { - // TODO: add support for reductions when OpenMP loops have them. - if (parallelOp.getNumResults() != 0) - return rewriter.notifyMatchFailure( - parallelOp, - "OpenMP dialect does not yet support loops with reductions"); - // Replace SCF yield with OpenMP yield. { OpBuilder::InsertionGuard guard(rewriter); @@ -43,47 +328,118 @@ struct ParallelOpLowering : public OpRewritePattern { parallelOp.getBody()->getTerminator(), ValueRange()); } - // Replace the loop. - auto omp = rewriter.create(parallelOp.getLoc()); - Block *block = rewriter.createBlock(&omp.getRegion()); - rewriter.setInsertionPointToStart(block); - auto loop = rewriter.create( - parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), - parallelOp.step()); - rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), - loop.region().begin()); - rewriter.create(parallelOp.getLoc()); - - rewriter.eraseOp(parallelOp); + // Declare reductions. + // TODO: consider checking it here is already a compatible reduction + // declaration and use it instead of redeclaring. + SmallVector reductionDeclSymbols; + for (auto reduce : parallelOp.getOps()) { + omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); + if (!decl) + return failure(); + reductionDeclSymbols.push_back( + SymbolRefAttr::get(rewriter.getContext(), decl.sym_name())); + } + + // Allocate reduction variables. Make sure the we don't overflow the stack + // with local `alloca`s by saving and restoring the stack pointer. + Location loc = parallelOp.getLoc(); + Value one = rewriter.create( + loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); + SmallVector reductionVariables; + reductionVariables.reserve(parallelOp.getNumReductions()); + Value token = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getIntegerType(8))); + for (Value init : parallelOp.initVals()) { + assert((LLVM::isCompatibleType(init.getType()) || + init.getType().isa()) && + "cannot create a reduction variable if the type is not an LLVM " + "pointer element"); + Value storage = rewriter.create( + loc, LLVM::LLVMPointerType::get(init.getType()), one, 0); + rewriter.create(loc, init, storage); + reductionVariables.push_back(storage); + } + + // Replace the reduction operations contained in this loop. Must be done + // here rather than in a separate pattern to have access to the list of + // reduction variables. + for (auto pair : + llvm::zip(parallelOp.getOps(), reductionVariables)) { + OpBuilder::InsertionGuard guard(rewriter); + scf::ReduceOp reduceOp = std::get<0>(pair); + rewriter.setInsertionPoint(reduceOp); + rewriter.replaceOpWithNewOp( + reduceOp, reduceOp.operand(), std::get<1>(pair)); + } + + // Create the parallel wrapper. + auto ompParallel = rewriter.create(loc); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.createBlock(&ompParallel.region()); + + // Replace SCF yield with OpenMP yield. + { + OpBuilder::InsertionGuard innerGuard(rewriter); + rewriter.setInsertionPointToEnd(parallelOp.getBody()); + assert(llvm::hasSingleElement(parallelOp.region()) && + "expected scf.parallel to have one block"); + rewriter.replaceOpWithNewOp( + parallelOp.getBody()->getTerminator(), ValueRange()); + } + + // Replace the loop. + auto loop = rewriter.create( + parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), + parallelOp.step()); + rewriter.create(loc); + + rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), + loop.region().begin()); + if (!reductionVariables.empty()) { + loop.reductionsAttr( + ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); + loop.reduction_varsMutable().append(reductionVariables); + } + } + + // Load loop results. + SmallVector results; + results.reserve(reductionVariables.size()); + for (Value variable : reductionVariables) { + Value res = rewriter.create(loc, variable); + results.push_back(res); + } + rewriter.replaceOp(parallelOp, results); + + rewriter.create(loc, token); return success(); } }; /// Applies the conversion patterns in the given function. -static LogicalResult applyPatterns(FuncOp func) { - ConversionTarget target(*func.getContext()); - target.addIllegalOp(); - target.addDynamicallyLegalOp( - [](scf::YieldOp op) { return !isa(op->getParentOp()); }); - target.addLegalDialect(); - - RewritePatternSet patterns(func.getContext()); - patterns.add(func.getContext()); +static LogicalResult applyPatterns(ModuleOp module) { + ConversionTarget target(*module.getContext()); + target.addIllegalOp(); + target.addLegalDialect(); + + RewritePatternSet patterns(module.getContext()); + patterns.add(module.getContext()); FrozenRewritePatternSet frozen(std::move(patterns)); - return applyPartialConversion(func, target, frozen); + return applyPartialConversion(module, target, frozen); } /// A pass converting SCF operations to OpenMP operations. struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase { /// Pass entry point. - void runOnFunction() override { - if (failed(applyPatterns(getFunction()))) + void runOnOperation() override { + if (failed(applyPatterns(getOperation()))) signalPassFailure(); } }; } // end namespace -std::unique_ptr> mlir::createConvertSCFToOpenMPPass() { +std::unique_ptr> mlir::createConvertSCFToOpenMPPass() { return std::make_unique(); } diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir new file mode 100644 index 00000000000000..bbc7d61a33a4b4 --- /dev/null +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -0,0 +1,194 @@ +// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s + +// CHECK: omp.reduction.declare @[[$REDF:.*]] : f32 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(0.000000e+00 : f32) +// CHECK: omp.yield(%[[INIT]] : f32) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK: %[[RES:.*]] = addf %[[ARG0]], %[[ARG1]] +// CHECK: omp.yield(%[[RES]] : f32) + +// CHECK: atomic +// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): +// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] +// CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic + +// CHECK-LABEL: @reduction1 +func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) { + // CHECK: %[[CST:.*]] = constant 0.0 + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 + // CHECK: llvm.intr.stacksave + // CHECK: %[[BUF:.*]] = llvm.alloca %[[ONE]] x f32 + // CHECK: llvm.store %[[CST]], %[[BUF]] + %step = constant 1 : index + %zero = constant 0.0 : f32 + // CHECK: omp.parallel + // CHECK: omp.wsloop + // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]] + scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) init (%zero) -> (f32) { + // CHECK: %[[CST_INNER:.*]] = constant 1.0 + %one = constant 1.0 : f32 + // CHECK: omp.reduction %[[CST_INNER]], %[[BUF]] + scf.reduce(%one) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %res = addf %lhs, %rhs : f32 + scf.reduce.return %res : f32 + } + // CHECK: omp.yield + } + // CHECK: omp.terminator + // CHECK: llvm.load %[[BUF]] + // CHECK: llvm.intr.stackrestore + return +} + +// ----- + +// Only check the declaration here, the rest is same as above. +// CHECK: omp.reduction.declare @{{.*}} : f32 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(1.000000e+00 : f32) +// CHECK: omp.yield(%[[INIT]] : f32) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK: %[[RES:.*]] = mulf %[[ARG0]], %[[ARG1]] +// CHECK: omp.yield(%[[RES]] : f32) + +// CHECK-NOT: atomic + +// CHECK-LABEL: @reduction2 +func @reduction2(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) { + %step = constant 1 : index + %zero = constant 0.0 : f32 + scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) init (%zero) -> (f32) { + %one = constant 1.0 : f32 + scf.reduce(%one) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %res = mulf %lhs, %rhs : f32 + scf.reduce.return %res : f32 + } + } + return +} + +// ----- + +// Only check the declaration here, the rest is same as above. +// CHECK: omp.reduction.declare @{{.*}} : f32 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4 +// CHECK: omp.yield(%[[INIT]] : f32) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]] +// CHECK: omp.yield(%[[RES]] : f32) + +// CHECK-NOT: atomic + +// CHECK-LABEL: @reduction3 +func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) { + %step = constant 1 : index + %zero = constant 0.0 : f32 + scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) init (%zero) -> (f32) { + %one = constant 1.0 : f32 + scf.reduce(%one) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %cmp = cmpf oge, %lhs, %rhs : f32 + %res = select %cmp, %lhs, %rhs : f32 + scf.reduce.return %res : f32 + } + } + return +} + +// ----- + +// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4 +// CHECK: omp.yield(%[[INIT]] : f32) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK: %[[CMP:.*]] = cmpf oge, %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG0]], %[[ARG1]] +// CHECK: omp.yield(%[[RES]] : f32) + +// CHECK-NOT: atomic + +// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant +// CHECK: omp.yield(%[[INIT]] : i64) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) +// CHECK: %[[CMP:.*]] = cmpi slt, %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = select %[[CMP]], %[[ARG1]], %[[ARG0]] +// CHECK: omp.yield(%[[RES]] : i64) + +// CHECK: atomic +// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): +// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] +// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic + +// CHECK-LABEL: @reduction4 +func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) -> (f32, i64) { + %step = constant 1 : index + // CHECK: %[[ZERO:.*]] = constant 0.0 + %zero = constant 0.0 : f32 + // CHECK: %[[IONE:.*]] = constant 1 + %ione = constant 1 : i64 + // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32 + // CHECK: llvm.store %[[ZERO]], %[[BUF1]] + // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64 + // CHECK: llvm.store %[[IONE]], %[[BUF2]] + + // CHECK: omp.parallel + // CHECK: omp.wsloop + // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] + // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] + %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) init (%zero, %ione) -> (f32, i64) { + %one = constant 1.0 : f32 + // CHECK: omp.reduction %{{.*}}, %[[BUF1]] + scf.reduce(%one) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %cmp = cmpf oge, %lhs, %rhs : f32 + %res = select %cmp, %lhs, %rhs : f32 + scf.reduce.return %res : f32 + } + // CHECK: fptosi + %1 = fptosi %one : f32 to i64 + // CHECK: omp.reduction %{{.*}}, %[[BUF2]] + scf.reduce(%1) : i64 { + ^bb1(%lhs: i64, %rhs: i64): + %cmp = cmpi slt, %lhs, %rhs : i64 + %res = select %cmp, %rhs, %lhs : i64 + scf.reduce.return %res : i64 + } + // CHECK: omp.yield + } + // CHECK: omp.terminator + // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] + // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] + // CHECK: return %[[RES1]], %[[RES2]] + return %res#0, %res#1 : f32, i64 +} diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir index 44059a27b32955..1507f927b9f007 100644 --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -27,7 +27,6 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () "test.payload"(%i, %j) : (index, index) -> () - // CHECK: omp.yield // CHECK: } } // CHECK: omp.yield @@ -38,6 +37,7 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, return } +// CHECK-LABEL: @adjacent_loops func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 225703d0f3a3e3..f486701c1dfcdf 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4316,9 +4316,11 @@ cc_library( deps = [ ":ConversionPassIncGen", ":IR", + ":LLVMDialect", ":OpenMPDialect", ":Pass", ":SCFDialect", + ":StandardOps", ":Support", ":Transforms", ],