From fb85ab46fa1eef855294681638b8c1865402eec1 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran
Date: Thu, 30 Oct 2025 17:01:59 -0500
Subject: [PATCH] [MLIR][LLVMIR] Add scan lowering to LLVM on the MLIR side

---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  38 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 378 ++++++++++++++----
 .../Target/LLVMIR/openmp-reduction-scan.mlir  | 130 ++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  69 +++-
 4 files changed, 531 insertions(+), 84 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86ee01355104..c54b3210cbebc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2326,12 +2326,40 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::ScanOp
 genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-          semantics::SemanticsContext &semaCtx, mlir::Location loc,
-          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+          semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+          mlir::Location loc, const ConstructQueue &queue,
+          ConstructQueue::const_iterator item) {
   mlir::omp::ScanOperands clauseOps;
   genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
-  return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
-                                   converter.getCurrentLocation(), clauseOps);
+  mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+      converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+  // All loop indices should be loaded after the scan construct; otherwise the
+  // index variable would be used across the scan directive.
+  // (`Intra-iteration dependences from a statement in the structured
+  // block sequence that precede a scan directive to a statement in the
+  // structured block sequence that follows a scan directive must not exist,
+  // except for dependences for the list items specified in an inclusive or
+  // exclusive clause.`).
+  // TODO: Nested loops are not handled.
+  mlir::omp::LoopNestOp loopNestOp =
+      scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+  assert(loopNestOp.getNumLoops() == 1 &&
+         "Scan directive inside nested do loops is not handled yet.");
+  mlir::Region &region = loopNestOp->getRegion(0);
+  mlir::Value indexVal = fir::getBase(region.getArgument(0));
+  lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<parser::LoopControl>>(doStmt->t);
+  const parser::LoopControl::Bounds *bounds =
+      std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+  mlir::Operation *storeOp =
+      setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+  firOpBuilder.setInsertionPointAfter(storeOp);
+  return scanOp;
 }
 
 static mlir::omp::SectionsOp
@@ -3416,7 +3444,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                            loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+    newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d6cc1b2..5bdc7cb0eaca2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
+#include 
 #include 
 #include 
 #include 
@@ -77,6 +78,22 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// ModuleTranslation stack frame for OpenMP operations. This keeps track of
+/// the alloca insertion point of the parent of the current parallel region,
+/// which is used to allocate variables that are shared by the threads
+/// executing the parallel region. Lowering of scan reductions requires
+/// declaring a shared pointer to the temporary buffer in which the scan runs.
+class OpenMPParallelAllocaStackFrame
+    : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+  explicit OpenMPParallelAllocaStackFrame(
+      llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+      : allocaInsertPoint(allocaIP) {}
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
 /// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
 /// collapsed canonical loop information corresponding to an \c omp.loop_nest
 /// operation.
@@ -84,7 +101,13 @@ class OpenMPLoopInfoStackFrame
     : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  /// For constructs like scan, one loop-info frame can contain multiple
+  /// canonical loops because a single omp.loop_nest op is split into an input
+  /// loop and a scan loop.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  llvm::ScanInfo *scanInfo = nullptr;
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      new llvm::DenseMap<llvm::Value *, llvm::Type *>();
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -323,6 +346,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDistScheduleChunkSize())
       result = todo("dist_schedule with chunk_size");
   };
+  auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+    if (!op.getExclusiveVars().empty())
+      result = todo("exclusive");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -371,9 +398,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (!op.getReductionVars().empty() || op.getReductionByref() ||
         op.getReductionSyms())
       result = todo("reduction");
-    if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+    if (op.getReductionMod()) {
+      if (isa<omp::WsloopOp>(op)) {
+        if (op.getReductionMod().value() == omp::ReductionModifier::task)
+          result = todo("reduction with task modifier");
+      } else {
+        result = todo("reduction with modifier");
+      }
+    }
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -397,6 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+      .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
       .Case([&](omp::SectionsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
@@ -531,15 +564,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// Find the loop information structure for the loop nest being translated. It
 /// will return a `null` value unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        loopInfos = frame.loopInfos;
+        return WalkResult::interrupt();
+      });
+  return loopInfos;
+}
+
+// The loop frame stores the ScanInfo used for scan reductions. Upon
+// encountering an `inscan` reduction modifier, `scanInfoInitialize` creates
+// the ScanInfo, which is then used when the scan directive is encountered in
+// the body of the loop nest.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::ScanInfo *scanInfo = nullptr;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        scanInfo = frame.scanInfo;
+        return WalkResult::interrupt();
+      });
+  return scanInfo;
+}
+
+// The types of the reduction variables are needed to lower a scan directive
+// that appears in the body of the loop. They are stored in the loop frame
+// when the reduction clause is processed and looked up when the scan
+// directive is encountered.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        reductionVarToType = frame.reductionVarToType;
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return reductionVarToType;
+}
+
+// Scan reductions require a shared buffer in which the reduction is
+// performed. OpenMPParallelAllocaStackFrame holds the alloca insertion point
+// where such shared allocations can be made.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+  moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+      [&](OpenMPParallelAllocaStackFrame &frame) {
+        parallelAllocaIP = frame.allocaInsertPoint;
+        return WalkResult::interrupt();
+      });
+  return parallelAllocaIP;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -1254,11 +1331,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
   for (auto [data, addr] : deferredStores)
     builder.CreateStore(data, addr);
 
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
   for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *> phis;
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
 
     // map block argument to initializer region
     mapInitializationArgs(op, moduleTranslation, reductionDecls,
@@ -1330,15 +1413,20 @@ static void collectReductionInfo(
 
   // Collect the reduction information.
   reductionInfos.reserve(numReductions);
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   for (unsigned i = 0; i < numReductions; ++i) {
     llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
     if (owningAtomicReductionGens[i])
      atomicGen = owningAtomicReductionGens[i];
     llvm::Value *variable =
         moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
     reductionInfos.push_back(
-        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
-         privateReductionVariables[i],
+        {reductionType, variable, privateReductionVariables[i],
         /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
         owningReductionGens[i],
         /*ReductionGenClang=*/nullptr, atomicGen});
@@ -2543,6 +2631,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
   bool loopNeedsBarrier = !wsloopOp.getNowait();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                     mlir::omp::ReductionModifier::inscan);
 
   // The only legal way for the direct parent to be omp.distribute is that this
   // represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2574,20 +2665,76 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+
+  const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+                                   bool noLoopMode, bool inputScanLoop) {
+    // Emit Initialization and Update IR for linear variables
+    if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                              loopInfo->getPreheader());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      builder.restoreIP(*afterBarrierIP);
+      linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                            loopInfo->getIndVar());
+      linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                        loopInfo->getExit());
+    }
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType, noLoopMode);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
 
-  // Emit Initialization and Update IR for linear variables
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
-                                            loopInfo->getPreheader());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+    // Emit finalization and in-place rewrites for linear vars.
+    if (!isInScanRegion && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+      if (!loopInfo->getLastIter())
+        return failure();
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+                                                  loopInfo->getLastIter());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+        linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+                                             index);
+      builder.restoreIP(oldIP);
+    }
+    if (!inputScanLoop || !isInScanRegion)
+      popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
+
+    return llvm::success();
+  };
+
+  if (isInScanRegion) {
+    auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+    builder.restoreIP(inputLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+                                      scanInfo);
+    if (failed(handleError(redIP, opInst)))
       return failure();
-    builder.restoreIP(*afterBarrierIP);
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
-    linearClauseProcessor.outlineLinearFinalizationBB(builder,
-                                                      loopInfo->getExit());
+
+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
   }
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2612,42 +2759,37 @@
     }
   }
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, noLoopMode);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Emit finalization and in-place rewrites for linear vars.
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
-    assert(loopInfo->getLastIter() &&
-           "`lastiter` in CanonicalLoopInfo is nullptr");
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
-                                                loopInfo->getLastIter());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+  // For scan loops, the cancellation callback must not be popped for the
+  // input (first) loop, so `inputScanLoop` is true only for that loop.
+  bool inputScanLoop = isInScanRegion;
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    // TODO: Linear clause support needs to be enabled for scan reduction.
+    if (isInScanRegion)
+      assert(wsloopOp.getLinearVars().empty() &&
+             "Linear clause support is not enabled with scan reduction");
+    if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
       return failure();
-    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
-      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
-                                           index);
-    builder.restoreIP(oldIP);
+    inputScanLoop = false;
   }
 
-  // Set the correct branch target for task cancellation
-  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(
-          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
-          privateReductionVariables, isByRef, wsloopOp.getNowait(),
-          /*isTeamsReduction=*/false)))
-    return failure();
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(
+            wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+            privateReductionVariables, isByRef, wsloopOp.getNowait(),
+            /*isTeamsReduction=*/false)))
+      return failure();
+  }
 
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                             privateVarsInfo.llvmVars,
@@ -2815,6 +2957,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  LLVM::ModuleTranslation::SaveStack<OpenMPParallelAllocaStackFrame> frame(
+      moduleTranslation, allocaIP);
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
@@ -2935,12 +3079,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    ompBuilder->applySimd(
+        loopInfo, alignedVars,
+        simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                           : nullptr,
+        order, simdlen, safelen);
+  }
 
   // We now need to reduce the per-simd-lane reduction variable into the
   // original variable. This works a bit differently to other reductions (e.g.
@@ -2991,6 +3138,40 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                             privateVarsInfo.privatizers);
 }
 
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  SmallVector<llvm::Value *> llvmScanVars;
+  SmallVector<llvm::Type *> llvmScanVarsType;
+  mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars();
+  if (!isInclusive)
+    mlirScanVars = scanOp.getExclusiveVars();
+
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
+  for (auto val : mlirScanVars) {
+    llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
+    llvmScanVars.push_back(llvmVal);
+    llvmScanVarsType.push_back((*reductionVarToType)[llvmVal]);
+  }
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findParallelAllocaIP(moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive,
+          scanInfo);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+  builder.restoreIP(*afterIP);
+  return success();
+}
+
 /// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -3052,14 +3233,50 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
       computeIP = loopInfos.front()->getPreheaderIP();
     }
 
+    bool isInScanRegion = false;
+    if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>())
+      isInScanRegion =
+          wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                         mlir::omp::ReductionModifier::inscan);
+    if (isInScanRegion) {
+      // TODO: Handle nesting when the scan loop is itself nested in a loop.
+      assert(loopOp.getNumLoops() == 1 &&
+             "Scan directive inside nested do loops is not handled yet.");
+      llvm::Expected<llvm::ScanInfo *> res = ompBuilder->scanInfoInitialize();
+      if (failed(handleError(res, *loopOp)))
+        return failure();
+      llvm::ScanInfo *scanInfo = res.get();
+      moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+          [&](OpenMPLoopInfoStackFrame &frame) {
+            frame.scanInfo = scanInfo;
+            return WalkResult::interrupt();
+          });
+      llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+          ompBuilder->createCanonicalScanLoops(
+              loc, bodyGen, lowerBound, upperBound, step,
+              /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+              scanInfo);
+
+      if (failed(handleError(loopResults, *loopOp)))
+        return failure();
+      llvm::CanonicalLoopInfo *inputLoop = loopResults.get().front();
+      llvm::CanonicalLoopInfo *scanLoop = loopResults.get().back();
+      moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+          [&](OpenMPLoopInfoStackFrame &frame) {
+            frame.loopInfos.push_back(inputLoop);
+            frame.loopInfos.push_back(scanLoop);
+            return WalkResult::interrupt();
+          });
+      builder.restoreIP(scanLoop->getAfterIP());
+      // TODO: Tiling and collapse are not yet implemented for scan reduction.
+      return success();
+    }
     llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
         ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
     if (failed(handleError(loopResult, *loopOp)))
       return failure();
-
     loopInfos.push_back(*loopResult);
   }
 
@@ -3102,7 +3319,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
   assert(newTopLoopInfo && "New top loop information is missing");
moduleTranslation.stackWalk( [&](OpenMPLoopInfoStackFrame &frame) { - frame.loopInfo = newTopLoopInfo; + frame.loopInfos.push_back(newTopLoopInfo); return WalkResult::interrupt(); }); @@ -4965,18 +5182,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, bool loopNeedsBarrier = false; llvm::Value *chunk = nullptr; - llvm::CanonicalLoopInfo *loopInfo = - findCurrentLoopInfo(moduleTranslation); - llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = - ompBuilder->applyWorkshareLoop( - ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, - convertToScheduleKind(schedule), chunk, isSimd, - scheduleMod == omp::ScheduleModifier::monotonic, - scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, - workshareLoopType); - - if (!wsloopIP) - return wsloopIP.takeError(); + SmallVector loopInfos = + findCurrentLoopInfos(moduleTranslation); + for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) { + llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = + ompBuilder->applyWorkshareLoop( + ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, + convertToScheduleKind(schedule), chunk, isSimd, + scheduleMod == omp::ScheduleModifier::monotonic, + scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered, + workshareLoopType); + + if (!wsloopIP) + return wsloopIP.takeError(); + } } if (failed(cleanupPrivateVars(builder, moduleTranslation, @@ -6167,6 +6386,11 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::WsloopOp) { return convertOmpWsloop(*op, builder, moduleTranslation); }) + .Case([&](omp::ScanOp) { + if (failed(checkImplementationStatus(*op))) + return failure(); + return convertOmpScan(*op, builder, moduleTranslation); + }) .Case([&](omp::SimdOp) { return convertOmpSimd(*op, builder, moduleTranslation); }) diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir new file mode 100644 index 0000000000000..ed04a069b998f --- /dev/null +++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir @@ -0,0 +1,130 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s +omp.declare_reduction @add_reduction_i32 : i32 init { +^bb0(%arg0: i32): + %0 = llvm.mlir.constant(0 : i32) : i32 + omp.yield(%0 : i32) +} combiner { +^bb0(%arg0: i32, %arg1: i32): + %0 = llvm.add %arg0, %arg1 : i32 + omp.yield(%0 : i32) +} +// CHECK-LABEL: @scan_reduction +llvm.func @scan_reduction() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr + %8 = llvm.mlir.constant(0 : index) : i64 + %9 = llvm.mlir.constant(1 : index) : i64 + %10 = llvm.mlir.constant(100 : i32) : i32 + %11 = llvm.mlir.constant(1 : i32) : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.mlir.constant(100 : index) : i64 + %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr + %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr + omp.parallel { + %37 = llvm.mlir.constant(1 : i64) : i64 + %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr + %39 = llvm.mlir.constant(1 : i64) : i64 + omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) { + omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) { + llvm.store 
%arg1, %38 : i32, !llvm.ptr + %40 = llvm.load %arg0 : !llvm.ptr -> i32 + %41 = llvm.load %38 : !llvm.ptr -> i32 + %42 = llvm.sext %41 : i32 to i64 + %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + %51 = llvm.load %50 : !llvm.ptr -> i32 + %52 = llvm.add %40, %51 : i32 + llvm.store %52, %arg0 : i32, !llvm.ptr + omp.scan inclusive(%arg0 : !llvm.ptr) + llvm.store %arg1, %38 : i32, !llvm.ptr + %53 = llvm.load %arg0 : !llvm.ptr -> i32 + %54 = llvm.load %38 : !llvm.ptr -> i32 + %55 = llvm.sext %54 : i32 to i64 + %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32 + llvm.store %53, %63 : i32, !llvm.ptr + omp.yield + } + } + omp.terminator + } + llvm.return +} +llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> { + %0 = llvm.mlir.zero : !llvm.array<100 x i32> + llvm.return %0 : !llvm.array<100 x i32> +} +llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> { + %0 = llvm.mlir.zero : !llvm.array<100 x i32> + llvm.return %0 : !llvm.array<100 x i32> +} +llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 { + %0 = llvm.mlir.constant(100 : i32) : i32 + llvm.return %0 : i32 +} +//CHECK: %vla = alloca ptr, align 8 +//CHECK: omp_parallel +//CHECK: store ptr %vla, ptr %gep_vla, align 8 +//CHECK: @__kmpc_fork_call +//CHECK: void @scan_reduction..omp_par +//CHECK: %[[BUFF_PTR:.+]] = load ptr, ptr %gep_vla +//CHECK: @__kmpc_masked +//CHECK: @__kmpc_barrier +//CHECK: @__kmpc_masked +//CHECK: @__kmpc_barrier +//CHECK: omp.scan.loop.cont: +//CHECK: @__kmpc_masked +//CHECK: @__kmpc_barrier +//CHECK: %[[FREE_VAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8 +//CHECK: %[[ARRLAST:.+]] = getelementptr inbounds i32, ptr %[[FREE_VAR]], i32 100 +//CHECK: %[[RES:.+]] = load i32, ptr %[[ARRLAST]], align 4 +//CHECK: store i32 %[[RES]], ptr %loadgep{{.*}}, align 4 +//CHECK: tail call void @free(ptr %[[FREE_VAR]]) +//CHECK: @__kmpc_end_masked +//CHECK: omp.inscan.dispatch{{.*}}: ; preds = %omp_loop.body{{.*}} +//CHECK: %[[BUFFVAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8 +//CHECK: %[[arrayOffset1:.+]] = getelementptr inbounds i32, ptr %[[BUFFVAR]], i32 %{{.*}} +//CHECK: %[[BUFFVAL1:.+]] = load i32, ptr %[[arrayOffset1]], align 4 +//CHECK: store i32 %[[BUFFVAL1]], ptr %{{.*}}, align 4 +//CHECK: %[[LOG:.+]] = call double @llvm.log2.f64(double 1.000000e+02) #0 +//CHECK: %[[CEIL:.+]] = call double @llvm.ceil.f64(double %[[LOG]]) #0 +//CHECK: %[[UB:.+]] = fptoui double %[[CEIL]] to i32 +//CHECK: br label %omp.outer.log.scan.body +//CHECK: omp.outer.log.scan.body: +//CHECK: %[[K:.+]] = phi i32 [ 0, %{{.*}} ], [ %[[NEXTK:.+]], %omp.inner.log.scan.exit ] +//CHECK: %[[I:.+]] = phi i32 [ 1, %{{.*}} ], [ %[[NEXTI:.+]], %omp.inner.log.scan.exit ] +//CHECK: %[[CMP1:.+]] = icmp uge i32 99, %[[I]] +//CHECK: br i1 %[[CMP1]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit +//CHECK: omp.inner.log.scan.exit: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body +//CHECK: %[[NEXTK]] = add nuw i32 %[[K]], 1 +//CHECK: %[[NEXTI]] = shl nuw i32 %[[I]], 1 +//CHECK: %[[CMP2:.+]] = icmp ne i32 %[[NEXTK]], %[[UB]] +//CHECK: br i1 %[[CMP2]], label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit +//CHECK: omp.outer.log.scan.exit: ; preds = %omp.inner.log.scan.exit +//CHECK: @__kmpc_end_masked +//CHECK: omp.inner.log.scan.body: ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body +//CHECK: %[[CNT:.+]] = phi i32 [ 99, %omp.outer.log.scan.body ], [ %[[CNTNXT:.+]], %omp.inner.log.scan.body ] 
+//CHECK: %[[BUFF:.+]] = load ptr, ptr %[[BUFF_PTR]] +//CHECK: %[[IND1:.+]] = add i32 %[[CNT]], 1 +//CHECK: %[[IND1PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND1]] +//CHECK: %[[IND2:.+]] = sub nuw i32 %[[IND1]], %[[I]] +//CHECK: %[[IND2PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND2]] +//CHECK: %[[IND1VAL:.+]] = load i32, ptr %[[IND1PTR]], align 4 +//CHECK: %[[IND2VAL:.+]] = load i32, ptr %[[IND2PTR]], align 4 +//CHECK: %[[REDVAL:.+]] = add i32 %[[IND1VAL]], %[[IND2VAL]] +//CHECK: store i32 %[[REDVAL]], ptr %[[IND1PTR]], align 4 +//CHECK: %[[CNTNXT]] = sub nuw i32 %[[CNT]], 1 +//CHECK: %[[CMP3:.+]] = icmp uge i32 %[[CNTNXT]], %[[I]] +//CHECK: br i1 %[[CMP3]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit +//CHECK: omp.inscan.dispatch: ; preds = %omp_loop.body +//CHECK: br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb +//CHECK: omp.loop_nest.region: ; preds = %omp.before.scan.bb +//CHECK: %[[BUFFER:.+]] = load ptr, ptr %loadgep_vla, align 8 +//CHECK: %[[ARRAYOFFSET2:.+]] = getelementptr inbounds i32, ptr %[[BUFFER]], i32 %{{.*}} +//CHECK-NEXT: %[[REDPRIVVAL:.+]] = load i32, ptr %{{.*}}, align 4 +//CHECK: store i32 %[[REDPRIVVAL]], ptr %[[ARRAYOFFSET2]], align 4 +//CHECK: br label %omp.scan.loop.exit diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 2fa4470bb8300..074ea634b133e 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -129,6 +129,68 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // ----- +omp.declare_reduction @add_f32 : f32 +init { +^bb0(%arg: f32): + %0 = llvm.mlir.constant(0.0 : f32) : f32 + omp.yield (%0 : f32) +} +combiner { +^bb1(%arg0: f32, %arg1: f32): + %1 = llvm.fadd %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} +atomic { +^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr): + %2 = llvm.load %arg3 : !llvm.ptr -> f32 + llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 + omp.yield +} +llvm.func @task_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} + // expected-error@below {{not yet implemented: Unhandled clause reduction with task modifier in omp.wsloop operation}} + omp.wsloop reduction(mod:task, @add_f32 %x -> %prv : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.yield + } + } + llvm.return +} + + +// ----- + +omp.declare_reduction @add_f32 : f32 +init { +^bb0(%arg: f32): + %0 = llvm.mlir.constant(0.0 : f32) : f32 + omp.yield (%0 : f32) +} +combiner { +^bb1(%arg0: f32, %arg1: f32): + %1 = llvm.fadd %arg0, %arg1 : f32 + omp.yield (%1 : f32) +} +atomic { +^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr): + %2 = llvm.load %arg3 : !llvm.ptr -> f32 + llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32 + omp.yield +} +llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { + // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.simd operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.simd}} + omp.simd reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + omp.scan inclusive(%prv : !llvm.ptr) + omp.yield + } + } + llvm.return +} + +// ----- + omp.declare_reduction @add_f32 : f32 init { ^bb0(%arg: f32): @@ -147,17 +209,20 @@ atomic { omp.yield } llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : 
i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) { + // expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}} omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.scan inclusive(%prv : !llvm.ptr) + // expected-error@below {{not yet implemented: Unhandled clause exclusive in omp.scan operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.scan}} + omp.scan exclusive(%prv : !llvm.ptr) omp.yield } } llvm.return } + // ----- llvm.func @single_allocate(%x : !llvm.ptr) {
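
For reference, a minimal Fortran source pattern that exercises the new lowering path is sketched below. It mirrors the openmp-reduction-scan.mlir test added by this patch (x accumulates a prefix sum over a(:) and b(:) records the inclusive scan); the program and variable names are illustrative only and are not part of the patch.

    program scan_example
      integer :: k, x
      integer :: a(100), b(100)
      a = 1
      x = 0
      !$omp parallel do reduction(inscan, +:x)
      do k = 1, 100
        ! Input phase: accumulate into the inscan reduction variable.
        x = x + a(k)
        !$omp scan inclusive(x)
        ! Scan phase: read the running (inclusive) prefix sum.
        b(k) = x
      end do
      !$omp end parallel do
    end program scan_example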