diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 383fd9d94661a..19a8a53556a73 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1511,6 +1511,28 @@ class OpenMPIRBuilder { : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {} }; + /// A struct to pack static and dynamic dependency information for a task. + /// + /// For fixed-count (non-iterator) dependencies, callers populate \p Deps + /// and the builder allocates and fills the kmp_depend_info array internally. + /// For iterator-based dependencies, the caller pre-builds the array and + /// sets \p NumDeps and \p DepArray directly. + struct DependenciesInfo { + SmallVector Deps; // vector of dependencies + Value *NumDeps; // number of kmp_depend_info entries (used by iterator path) + Value *DepArray; // kmp_depend_info array (used by iterator path) + + DependenciesInfo() : Deps(), NumDeps(nullptr), DepArray(nullptr) {} + DependenciesInfo(SmallVector D) + : Deps(std::move(D)), NumDeps(nullptr), DepArray(nullptr) {} + + bool empty() const { return Deps.empty() && DepArray == nullptr; } + }; + + /// Store one kmp_depend_info entry at the given \p Entry pointer. + LLVM_ABI void emitTaskDependency(IRBuilderBase &Builder, Value *Entry, + const DependData &Dep); + /// Return the LLVM struct type matching runtime `kmp_task_affinity_info_t`. /// `{ kmp_intptr_t base_addr; size_t len; flags (bitfield storage as i32) }` LLVM_ABI llvm::StructType *getKmpTaskAffinityInfoTy(); @@ -1579,8 +1601,8 @@ class OpenMPIRBuilder { /// cannot be resumed until execution of the structured /// block that is associated with the generated task is /// completed. - /// \param Dependencies Vector of DependData objects holding information of - /// dependencies as specified by the 'depend' clause. 
+ /// \param Dependencies Dependencies info holding either a vector of + /// DependData objects or a pre-built dependency array. /// \param Affinities AffinityData object holding information of accumulated /// affinities as specified by the 'affinity' clause. /// \param EventHandle If present, signifies the event handle as part of @@ -1591,7 +1613,7 @@ class OpenMPIRBuilder { LLVM_ABI InsertPointOrErrorTy createTask( const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied = true, Value *Final = nullptr, - Value *IfCondition = nullptr, SmallVector Dependencies = {}, + Value *IfCondition = nullptr, const DependenciesInfo &Dependencies = {}, AffinityData Affinities = {}, bool Mergeable = false, Value *EventHandle = nullptr, Value *Priority = nullptr); @@ -2888,15 +2910,14 @@ class OpenMPIRBuilder { /// \param DeviceID Identifier for the device via the 'device' clause. /// \param RTLoc Source location identifier /// \param AllocaIP The insertion point to be used for alloca instructions. - /// \param Dependencies Vector of DependData objects holding information of - /// dependencies as specified by the 'depend' clause. + /// \param Dependencies Dependencies info as specified by the 'depend' clause. /// \param HasNoWait True if the target construct had 'nowait' on it, false /// otherwise - LLVM_ABI InsertPointOrErrorTy emitTargetTask( - TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, Value *RTLoc, - OpenMPIRBuilder::InsertPointTy AllocaIP, - const SmallVector &Dependencies, - const TargetDataRTArgs &RTArgs, bool HasNoWait); + LLVM_ABI InsertPointOrErrorTy + emitTargetTask(TargetTaskBodyCallbackTy TaskBodyCB, Value *DeviceID, + Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP, + const DependenciesInfo &Dependencies, + const TargetDataRTArgs &RTArgs, bool HasNoWait); /// Emit the arguments to be passed to the runtime library based on the /// arrays of base pointers, pointers, sizes, map types, and mappers. 
If @@ -3537,7 +3558,7 @@ class OpenMPIRBuilder { TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, CustomMapperCallbackTy CustomMapperCB, - const SmallVector &Dependencies, bool HasNowait = false, + const DependenciesInfo &Dependencies = {}, bool HasNowait = false, Value *DynCGroupMem = nullptr, omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback = omp::OMPDynGroupprivateFallbackType::Abort); diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index d06ebbaca9f08..a71e267802656 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1965,6 +1965,29 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) { emitTaskyieldImpl(Loc); } +void OpenMPIRBuilder::emitTaskDependency(IRBuilderBase &Builder, Value *Entry, + const DependData &Dep) { + // Store the pointer to the variable + Value *Addr = Builder.CreateStructGEP( + DependInfo, Entry, + static_cast(RTLDependInfoFields::BaseAddr)); + Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, SizeTy); + Builder.CreateStore(DepValPtr, Addr); + // Store the size of the variable + Value *Size = Builder.CreateStructGEP( + DependInfo, Entry, static_cast(RTLDependInfoFields::Len)); + Builder.CreateStore( + ConstantInt::get(SizeTy, + M.getDataLayout().getTypeStoreSize(Dep.DepValueType)), + Size); + // Store the dependency kind + Value *Flags = Builder.CreateStructGEP( + DependInfo, Entry, static_cast(RTLDependInfoFields::Flags)); + Builder.CreateStore(ConstantInt::get(Builder.getInt8Ty(), + static_cast(Dep.DepKind)), + Flags); +} + // Processes the dependencies in Dependencies and does the following // - Allocates space on the stack of an array of DependInfo objects // - Populates each DependInfo object with relevant information of @@ -1978,7 +2001,7 @@ static Value *emitTaskDependencies( return nullptr; // Given a vector of DependData objects, in this function we create an - 
// array on the stack that holds kmp_dep_info objects corresponding + // array on the stack that holds kmp_depend_info objects corresponding // to each dependency. This is then passed to the OpenMP runtime. // For example, if there are 'n' dependencies then the following psedo // code is generated. Assume the first dependence is on a variable 'a' @@ -1995,7 +2018,6 @@ static Value *emitTaskDependencies( IRBuilderBase &Builder = OMPBuilder.Builder; Type *DependInfo = OMPBuilder.DependInfo; - Module &M = OMPBuilder.M; Value *DepArray = nullptr; OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP(); @@ -2010,26 +2032,7 @@ static Value *emitTaskDependencies( for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) { Value *Base = Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx); - // Store the pointer to the variable - Value *Addr = Builder.CreateStructGEP( - DependInfo, Base, - static_cast(RTLDependInfoFields::BaseAddr)); - Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty()); - Builder.CreateStore(DepValPtr, Addr); - // Store the size of the variable - Value *Size = Builder.CreateStructGEP( - DependInfo, Base, static_cast(RTLDependInfoFields::Len)); - Builder.CreateStore( - Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)), - Size); - // Store the dependency kind - Value *Flags = Builder.CreateStructGEP( - DependInfo, Base, - static_cast(RTLDependInfoFields::Flags)); - Builder.CreateStore( - ConstantInt::get(Builder.getInt8Ty(), - static_cast(Dep.DepKind)), - Flags); + OMPBuilder.emitTaskDependency(Builder, Base, Dep); } return DepArray; } @@ -2450,7 +2453,7 @@ llvm::StructType *OpenMPIRBuilder::getKmpTaskAffinityInfoTy() { OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask( const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition, - SmallVector Dependencies, AffinityData Affinities, + const DependenciesInfo 
&Dependencies, AffinityData Affinities, bool Mergeable, Value *EventHandle, Value *Priority) { if (!updateToLocation(Loc)) @@ -2629,7 +2632,15 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask( Builder.CreateStore(Priority, CmplrData); } - Value *DepArray = emitTaskDependencies(*this, Dependencies); + Value *DepArray = nullptr; + Value *NumDeps = nullptr; + if (Dependencies.DepArray) { + DepArray = Dependencies.DepArray; + NumDeps = Dependencies.NumDeps; + } else if (!Dependencies.Deps.empty()) { + DepArray = emitTaskDependencies(*this, Dependencies.Deps); + NumDeps = Builder.getInt32(Dependencies.Deps.size()); + } // In the presence of the `if` clause, the following IR is generated: // ... @@ -2660,12 +2671,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask( &ElseTI); Builder.SetInsertPoint(ElseTI); - if (Dependencies.size()) { + if (DepArray) { Function *TaskWaitFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps); createRuntimeFunctionCall( TaskWaitFn, - {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray, + {Ident, ThreadID, NumDeps, DepArray, ConstantInt::get(Builder.getInt32Ty(), 0), ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))}); } @@ -2684,13 +2695,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask( Builder.SetInsertPoint(ThenTI); } - if (Dependencies.size()) { + if (DepArray) { Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps); createRuntimeFunctionCall( TaskFn, - {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()), - DepArray, ConstantInt::get(Builder.getInt32Ty(), 0), + {Ident, ThreadID, TaskData, NumDeps, DepArray, + ConstantInt::get(Builder.getInt32Ty(), 0), ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))}); } else { @@ -8813,8 +8824,8 @@ static Error emitTargetOutlinedFunction( OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( TargetTaskBodyCallbackTy TaskBodyCB, 
Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP, - const SmallVector &Dependencies, - const TargetDataRTArgs &RTArgs, bool HasNoWait) { + const DependenciesInfo &Dependencies, const TargetDataRTArgs &RTArgs, + bool HasNoWait) { // The following explains the code-gen scenario for the `target` directive. A // similar scneario is followed for other device-related directives (e.g. @@ -9131,7 +9142,15 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( } } - Value *DepArray = emitTaskDependencies(*this, Dependencies); + Value *DepArray = nullptr; + Value *NumDeps = nullptr; + if (Dependencies.DepArray) { + DepArray = Dependencies.DepArray; + NumDeps = Dependencies.NumDeps; + } else if (!Dependencies.Deps.empty()) { + DepArray = emitTaskDependencies(*this, Dependencies.Deps); + NumDeps = Builder.getInt32(Dependencies.Deps.size()); + } // --------------------------------------------------------------- // V5.2 13.8 target construct @@ -9148,7 +9167,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( createRuntimeFunctionCall( TaskWaitFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, - /*ndeps=*/Builder.getInt32(Dependencies.size()), + /*ndeps=*/NumDeps, /*dep_list=*/DepArray, /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0), /*noalias_dep_list=*/ @@ -9171,8 +9190,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps); createRuntimeFunctionCall( TaskFn, - {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()), - DepArray, ConstantInt::get(Builder.getInt32Ty(), 0), + {Ident, ThreadID, TaskData, NumDeps, DepArray, + ConstantInt::get(Builder.getInt32Ty(), 0), ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))}); } else { // Emit the @__kmpc_omp_task runtime call to spawn the task @@ -9207,19 +9226,19 @@ Error OpenMPIRBuilder::emitOffloadingArraysAndArgs( return Error::success(); } -static void 
emitTargetCall( - OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, - OpenMPIRBuilder::InsertPointTy AllocaIP, - OpenMPIRBuilder::TargetDataInfo &Info, - const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, - const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, - Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, - SmallVectorImpl &Args, - OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, - OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB, - const SmallVector &Dependencies, - bool HasNoWait, Value *DynCGroupMem, - OMPDynGroupprivateFallbackType DynCGroupMemFallback) { +static void +emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, + OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::TargetDataInfo &Info, + const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, + const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, + Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, + SmallVectorImpl &Args, + OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, + OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB, + const OpenMPIRBuilder::DependenciesInfo &Dependencies, + bool HasNoWait, Value *DynCGroupMem, + OMPDynGroupprivateFallbackType DynCGroupMemFallback) { // Generate a function call to the host fallback implementation of the target // region. This is called by the host when no offload entry was generated for // the target region and when the offloading call fails at runtime. 
@@ -9234,7 +9253,7 @@ static void emitTargetCall( return Builder.saveIP(); }; - bool HasDependencies = Dependencies.size() > 0; + bool HasDependencies = !Dependencies.empty(); bool RequiresOuterTargetTask = HasNoWait || HasDependencies; OpenMPIRBuilder::TargetKernelArgs KArgs; @@ -9412,9 +9431,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - CustomMapperCallbackTy CustomMapperCB, - const SmallVector &Dependencies, bool HasNowait, - Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) { + CustomMapperCallbackTy CustomMapperCB, const DependenciesInfo &Dependencies, + bool HasNowait, Value *DynCGroupMem, + OMPDynGroupprivateFallbackType DynCGroupMemFallback) { if (!updateToLocation(Loc)) return InsertPointTy(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 7d25ac9cadedb..4f17b54b07e83 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -7350,7 +7350,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) { OMPBuilder.createTask( Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB, - /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS)); + /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, + OpenMPIRBuilder::DependenciesInfo{std::move(DDS)})); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); Builder.CreateRetVoid(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2e15f4de4545d..c425fabeea17f 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -333,11 +333,6 @@ 
static LogicalResult checkImplementationStatus(Operation &op) { if (!op.getDependVars().empty() || op.getDependKinds()) result = todo("depend"); }; - auto checkDependIteratorModifier = [&todo](auto op, LogicalResult &result) { - if (!op.getDependIterated().empty() || - (op.getDependIteratedKinds() && !op.getDependIteratedKinds()->empty())) - result = todo("depend with iterator modifier"); - }; auto checkHint = [](auto op, LogicalResult &) { if (op.getHint()) op.emitWarning("hint clause discarded"); @@ -410,7 +405,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::TaskOp op) { checkAllocate(op, result); - checkDependIteratorModifier(op, result); checkInReduction(op, result); }) .Case([&](omp::TaskgroupOp op) { @@ -445,7 +439,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::TargetOp op) { checkAllocate(op, result); checkBare(op, result); - checkDependIteratorModifier(op, result); checkInReduction(op, result); checkThreadLimit(op, result); }) @@ -2097,33 +2090,35 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, return success(); } -static void -buildDependData(std::optional dependKinds, OperandRange dependVars, - LLVM::ModuleTranslation &moduleTranslation, - SmallVectorImpl &dds) { +static llvm::omp::RTLDependenceKindTy +convertDependKind(mlir::omp::ClauseTaskDepend kind) { + switch (kind) { + case mlir::omp::ClauseTaskDepend::taskdependin: + return llvm::omp::RTLDependenceKindTy::DepIn; + // The OpenMP runtime requires that the codegen for 'depend' clause for + // 'out' dependency kind must be the same as codegen for 'depend' clause + // with 'inout' dependency. 
+ case mlir::omp::ClauseTaskDepend::taskdependout: + case mlir::omp::ClauseTaskDepend::taskdependinout: + return llvm::omp::RTLDependenceKindTy::DepInOut; + case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset: + return llvm::omp::RTLDependenceKindTy::DepMutexInOutSet; + case mlir::omp::ClauseTaskDepend::taskdependinoutset: + return llvm::omp::RTLDependenceKindTy::DepInOutSet; + } + llvm_unreachable("unhandled depend kind"); +} + +static void buildDependDataLocator( + std::optional dependKinds, OperandRange dependVars, + LLVM::ModuleTranslation &moduleTranslation, + SmallVectorImpl &dds) { if (dependVars.empty()) return; for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) { - llvm::omp::RTLDependenceKindTy type; - switch ( - cast(std::get<1>(dep)).getValue()) { - case mlir::omp::ClauseTaskDepend::taskdependin: - type = llvm::omp::RTLDependenceKindTy::DepIn; - break; - // The OpenMP runtime requires that the codegen for 'depend' clause for - // 'out' dependency kind must be the same as codegen for 'depend' clause - // with 'inout' dependency. 
- case mlir::omp::ClauseTaskDepend::taskdependout: - case mlir::omp::ClauseTaskDepend::taskdependinout: - type = llvm::omp::RTLDependenceKindTy::DepInOut; - break; - case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset: - type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet; - break; - case mlir::omp::ClauseTaskDepend::taskdependinoutset: - type = llvm::omp::RTLDependenceKindTy::DepInOutSet; - break; - }; + auto kind = + cast(std::get<1>(dep)).getValue(); + llvm::omp::RTLDependenceKindTy type = convertDependKind(kind); llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep)); llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal); dds.emplace_back(dd); @@ -2616,6 +2611,92 @@ buildAffinityData(mlir::omp::TaskOp &taskOp, llvm::IRBuilderBase &builder, return mlir::success(); } +// With no iterated deps, only fills taskDeps.Deps. Otherwise heap-allocates a +// single kmp_depend_info array, fills the locator (non-iterated) entries first, +// then runs an iterator loop for each iterator modifier object. +static mlir::LogicalResult +buildDependData(OperandRange dependVars, std::optional dependKinds, + OperandRange dependIterated, + std::optional dependIteratedKinds, + llvm::IRBuilderBase &builder, + mlir::LLVM::ModuleTranslation &moduleTranslation, + llvm::OpenMPIRBuilder::DependenciesInfo &taskDeps) { + if (dependIterated.empty()) { + buildDependDataLocator(dependKinds, dependVars, moduleTranslation, + taskDeps.Deps); + return mlir::success(); + } + + llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder(); + llvm::Type *dependInfoTy = ompBuilder.DependInfo; + unsigned numLocator = dependVars.size(); + + // Compute total count: locator deps + sum of iterator trip counts. 
+ llvm::Value *totalCount = + llvm::ConstantInt::get(builder.getInt64Ty(), numLocator); + + llvm::SmallVector iterInfos; + for (auto iter : dependIterated) { + auto itersOp = iter.getDefiningOp(); + assert(itersOp && "depend_iterated value must be defined by omp.iterator"); + iterInfos.emplace_back(itersOp, moduleTranslation, builder); + totalCount = + builder.CreateAdd(totalCount, iterInfos.back().getTotalTrips()); + } + + // Heap-allocate the kmp_depend_info array so we don't risk + // dynamic-sized alloca outside the entry block (e.g. inside loops). + llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(dependInfoTy); + llvm::Value *depArray = + builder.CreateMalloc(ompBuilder.SizeTy, dependInfoTy, allocSize, + totalCount, /*MallocF=*/nullptr, ".dep.arr.addr"); + + // Fill non-iterated entries at indices [0, numLocator). + if (numLocator > 0) { + SmallVector dds; + buildDependDataLocator(dependKinds, dependVars, moduleTranslation, dds); + for (auto [i, dd] : llvm::enumerate(dds)) { + llvm::Value *idx = llvm::ConstantInt::get(builder.getInt64Ty(), i); + llvm::Value *entry = + builder.CreateInBoundsGEP(dependInfoTy, depArray, idx); + ompBuilder.emitTaskDependency(builder, entry, dd); + } + } + + // Fill iterated entries starting at index numLocator. 
+ llvm::Value *offset = + llvm::ConstantInt::get(builder.getInt64Ty(), numLocator); + for (auto [i, iterInfo] : llvm::enumerate(iterInfos)) { + auto kindAttr = cast( + dependIteratedKinds->getValue()[i]); + llvm::omp::RTLDependenceKindTy rtlKind = + convertDependKind(kindAttr.getValue()); + + auto itersOp = dependIterated[i].getDefiningOp(); + if (failed(fillIteratorLoop( + itersOp, builder, moduleTranslation, iterInfo, "dep_iterator", + [&](llvm::Value *linearIV, mlir::omp::YieldOp yield) { + llvm::Value *addr = + moduleTranslation.lookupValue(yield.getResults()[0]); + llvm::Value *idx = builder.CreateAdd(offset, linearIV); + llvm::Value *entry = + builder.CreateInBoundsGEP(dependInfoTy, depArray, idx); + ompBuilder.emitTaskDependency( + builder, entry, + llvm::OpenMPIRBuilder::DependData{rtlKind, addr->getType(), + addr}); + }))) + return mlir::failure(); + + // Advance offset by the trip count of this iterator. + offset = builder.CreateAdd(offset, iterInfo.getTotalTrips()); + } + + taskDeps.DepArray = depArray; + taskDeps.NumDeps = builder.CreateTrunc(totalCount, builder.getInt32Ty()); + return mlir::success(); +} + /// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder. 
static LogicalResult convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, @@ -2828,16 +2909,19 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp, llvm::omp::Directive::OMPD_taskgroup); - SmallVector dds; - buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(), - moduleTranslation, dds); + llvm::OpenMPIRBuilder::DependenciesInfo dependencies; + if (failed(buildDependData(taskOp.getDependVars(), taskOp.getDependKinds(), + taskOp.getDependIterated(), + taskOp.getDependIteratedKinds(), builder, + moduleTranslation, dependencies))) + return failure(); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTask( ompLoc, allocaIP, bodyCB, !taskOp.getUntied(), moduleTranslation.lookupValue(taskOp.getFinal()), - moduleTranslation.lookupValue(taskOp.getIfExpr()), dds, ad, + moduleTranslation.lookupValue(taskOp.getIfExpr()), dependencies, ad, taskOp.getMergeable(), moduleTranslation.lookupValue(taskOp.getEventHandle()), moduleTranslation.lookupValue(taskOp.getPriority())); @@ -2849,6 +2933,10 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder, popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get()); builder.restoreIP(*afterIP); + + if (dependencies.DepArray) + builder.CreateFree(dependencies.DepArray); + return success(); } @@ -7139,12 +7227,16 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, kernelInput.push_back(mapData.OriginalValue[i]); } - SmallVector dds; - buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(), - moduleTranslation, dds); - llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + + llvm::OpenMPIRBuilder::DependenciesInfo dds; + if (failed(buildDependData( + targetOp.getDependVars(), targetOp.getDependKinds(), + 
targetOp.getDependIterated(), targetOp.getDependIteratedKinds(), + builder, moduleTranslation, dds))) + return failure(); + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::TargetDataInfo info( @@ -7175,6 +7267,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, builder.restoreIP(*afterIP); + if (dds.DepArray) + builder.CreateFree(dds.DepArray); + // Remap access operations to declare target reference pointers for the // device, essentially generating extra loadop's as necessary if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice()) diff --git a/mlir/test/Target/LLVMIR/openmp-iterator.mlir b/mlir/test/Target/LLVMIR/openmp-iterator.mlir index faadfbdc7202f..50afb68c4ce99 100644 --- a/mlir/test/Target/LLVMIR/openmp-iterator.mlir +++ b/mlir/test/Target/LLVMIR/openmp-iterator.mlir @@ -1,4 +1,12 @@ -// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s +// RUN: split-file %s %t +// RUN: mlir-translate --mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=CHECK +// RUN: mlir-translate --mlir-to-llvmir %t/target.mlir | FileCheck %s --check-prefix=TARGET + +//--- host.mlir + +// -------------------------------------------------------------------- +// Affinity clause +// -------------------------------------------------------------------- llvm.func @task_affinity_iterator_1d(%arr: !llvm.ptr {llvm.nocapture}) { %c1 = llvm.mlir.constant(1 : i64) : i64 @@ -293,3 +301,215 @@ llvm.func @task_affinity_iterator_negative_step(%arr: !llvm.ptr {llvm.nocapture} // CHECK: [[ENTRY:%.*]] = getelementptr inbounds { i64, i64, i32 }, ptr [[AFFLIST]], i64 %omp_iterator.iv // CHECK: [[LENPTR:%.*]] = getelementptr inbounds nuw { i64, i64, i32 }, ptr [[ENTRY]], i32 0, i32 1 // CHECK: store i64 [[PHYSIV]], ptr [[LENPTR]] + +// -------------------------------------------------------------------- +// Depend clause +// -------------------------------------------------------------------- + +llvm.func 
@omp_task_depend_iterator_simple(%addr : !llvm.ptr) { + %c1 = llvm.mlir.constant(1 : i64) : i64 + %c10 = llvm.mlir.constant(10 : i64) : i64 + %step = llvm.mlir.constant(1 : i64) : i64 + + %it = omp.iterator(%iv: i64) = (%c1 to %c10 step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_task_depend_iterator_simple +// CHECK-SAME: (ptr %[[ADDR:[0-9]+]]) +// CHECK: %[[DEP_ARR:.*]] = tail call ptr @malloc(i64 %mallocsize) +// +// Iterator loop: preheader -> header -> cond -> body -> inc -> header... +// CHECK: omp_dep_iterator.header: +// CHECK: %[[IV:.*]] = phi i64 [ 0, %omp_dep_iterator.preheader ], [ %[[NEXT:.*]], %omp_dep_iterator.inc ] +// CHECK: omp_dep_iterator.cond: +// CHECK: %[[CMP:.*]] = icmp ult i64 %[[IV]], 10 +// CHECK: br i1 %[[CMP]], label %omp_dep_iterator.body, label %omp_dep_iterator.exit +// +// Body: store kmp_dep_info at depArray[0 + linearIV] +// CHECK: omp_dep_iterator.body: +// CHECK: %[[IDX:.*]] = add i64 0, %[[IV]] +// CHECK: %[[ENTRY:.*]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR]], i64 %[[IDX]] +// CHECK: %[[BASE_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 0 +// CHECK: %[[PTRINT:.*]] = ptrtoint ptr %[[ADDR]] to i64 +// CHECK: store i64 %[[PTRINT]], ptr %[[BASE_GEP]] +// CHECK: %[[LEN_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 1 +// CHECK: store i64 8, ptr %[[LEN_GEP]] +// CHECK: %[[FLAGS_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 2 +// depKind = 1 (DepIn) +// CHECK: store i8 1, ptr %[[FLAGS_GEP]] +// +// CHECK: omp_dep_iterator.inc: +// CHECK: %[[NEXT]] = add nuw i64 %[[IV]], 1 +// +// Task creation with deps, then free +// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.*}}, i32 %{{.*}}, ptr %{{.*}}, i32 10, ptr %[[DEP_ARR]], i32 0, ptr null) +// 
CHECK: tail call void @free(ptr %[[DEP_ARR]]) + +llvm.func @omp_task_depend_iterator_mixed(%addr : !llvm.ptr, %plain : !llvm.ptr) { + %c1 = llvm.mlir.constant(1 : i64) : i64 + %c10 = llvm.mlir.constant(10 : i64) : i64 + %step = llvm.mlir.constant(1 : i64) : i64 + + %it = omp.iterator(%iv: i64) = (%c1 to %c10 step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependout -> %plain : !llvm.ptr, taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_task_depend_iterator_mixed +// CHECK-SAME: (ptr %[[ADDR2:[0-9]+]], ptr %[[PLAIN:[0-9]+]]) +// CHECK: %[[DEP_ARR2:.*]] = tail call ptr @malloc(i64 %mallocsize) +// +// Plain entry at index 0 +// CHECK: %[[PLAIN_ENTRY:.*]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR2]], i64 0 +// CHECK: getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[PLAIN_ENTRY]], i32 0, i32 0 +// CHECK: %[[PLAIN_PTRINT:.*]] = ptrtoint ptr %[[PLAIN]] to i64 +// CHECK: store i64 %[[PLAIN_PTRINT]], ptr +// depKind = 3 (DepInOut/out) +// CHECK: store i8 3, ptr +// +// Iterator loop for iterated entry starting at offset 1 +// CHECK: omp_dep_iterator.body: +// startIdx(1) + linearIV +// CHECK: add i64 1, %omp_dep_iterator.iv +// CHECK: getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR2]] +// depKind = 1 (DepIn) +// CHECK: store i8 1, ptr +// +// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.*}}, i32 %{{.*}}, ptr %{{.*}}, i32 11, ptr %[[DEP_ARR2]], i32 0, ptr null) +// CHECK: tail call void @free(ptr %[[DEP_ARR2]]) + +// Dynamic bounds: iterator bounds are function arguments, so the trip count +// and dep-array size are computed at runtime. The malloc must be emitted +// after the trip-count computation (not an alloca hoisted to the entry block) +// to avoid "instruction does not dominate all uses" errors. 
+llvm.func @omp_task_depend_iterator_dynamic(%addr : !llvm.ptr, + %lb : i64, %ub : i64, %step : i64) { + %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_task_depend_iterator_dynamic +// +// Tripcount computation from dynamic bounds +// CHECK: %[[DIFF:.*]] = sub i64 %{{.*}}, %{{.*}} +// CHECK: %[[DIV:.*]] = sdiv i64 %[[DIFF]], %{{.*}} +// CHECK: %[[TRIPS:.*]] = add i64 %[[DIV]], 1 +// CHECK: %[[SCALED:.*]] = mul i64 1, %[[TRIPS]] +// Dynamic total = 0 + scaled trip count +// CHECK: %[[TOTAL:.*]] = add i64 0, %[[SCALED]] +// +// Malloc with dynamic size +// CHECK: %[[DEP_ARR:.*]] = tail call ptr @malloc(i64 %mallocsize) +// CHECK: omp_dep_iterator.body: +// CHECK: getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR]] +// NumDeps is truncated to i32 for the runtime call +// CHECK: %[[NDEPS:.*]] = trunc i64 %[[TOTAL]] to i32 +// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.*}}, i32 %{{.*}}, ptr %{{.*}}, i32 %[[NDEPS]], ptr %[[DEP_ARR]], i32 0, ptr null) +// CHECK: tail call void @free(ptr %[[DEP_ARR]]) + +// Dynamic bounds with mixed plain + iterated depends. 
+llvm.func @omp_task_depend_iterator_dynamic_mixed(%addr : !llvm.ptr, + %plain : !llvm.ptr, %lb : i64, %ub : i64, %step : i64) { + %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependout -> %plain : !llvm.ptr, taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// CHECK-LABEL: define void @omp_task_depend_iterator_dynamic_mixed +// CHECK: %[[TRIPS2:.*]] = mul i64 1, %{{.*}} +// total = 1 (plain) + dynamic trip count +// CHECK: %[[TOTAL2:.*]] = add i64 1, %[[TRIPS2]] +// CHECK: %[[DEP_ARR2:.*]] = tail call ptr @malloc(i64 %mallocsize) +// Plain entry at index 0 +// CHECK: %[[PLAIN_ENTRY:.*]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR2]], i64 0 +// CHECK: store i8 3, ptr +// Iterator loop +// CHECK: omp_dep_iterator.body: +// CHECK: add i64 1, %omp_dep_iterator.iv +// CHECK: %[[NDEPS2:.*]] = trunc i64 %[[TOTAL2]] to i32 +// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.*}}, i32 %{{.*}}, ptr %{{.*}}, i32 %[[NDEPS2]], ptr %[[DEP_ARR2]], i32 0, ptr null) +// CHECK: tail call void @free(ptr %[[DEP_ARR2]]) + +//--- target.mlir + +// -------------------------------------------------------------------- +// Depend clause on target construct +// -------------------------------------------------------------------- + +// Target construct with iterator-based depend clause. +// The iterator(i=1:10) should allocate a kmp_dep_info[10] array, fill it via +// a dep_iterator loop, then emit __kmpc_omp_wait_deps with ndeps=10. 
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} { + llvm.func @omp_target_depend_iterator(%addr: !llvm.ptr) { + %c1 = llvm.mlir.constant(1 : i64) : i64 + %c10 = llvm.mlir.constant(10 : i64) : i64 + %step = llvm.mlir.constant(1 : i64) : i64 + + %it = omp.iterator(%iv: i64) = (%c1 to %c10 step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + %map = omp.map.info var_ptr(%addr : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = "data"} + omp.target depend(taskdependin -> %it : !omp.iterated) map_entries(%map -> %arg0 : !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// TARGET-LABEL: define void @omp_target_depend_iterator +// TARGET-SAME: (ptr %[[ADDR:[0-9]+]]) +// TARGET-DAG: %[[DEP_ARR:.*]] = tail call ptr @malloc(i64 %mallocsize) +// +// Iterator loop: preheader -> header -> cond -> body -> inc -> header...; the IV starts at 0 and the loop continues while IV is unsigned-less-than 10 +// TARGET: omp_dep_iterator.header: +// TARGET: %[[IV:.*]] = phi i64 [ 0, %omp_dep_iterator.preheader ], [ %[[NEXT:.*]], %omp_dep_iterator.inc ] +// TARGET: omp_dep_iterator.cond: +// TARGET: %[[CMP:.*]] = icmp ult i64 %[[IV]], 10 +// TARGET: br i1 %[[CMP]], label %omp_dep_iterator.body, label %omp_dep_iterator.exit +// +// Body: store one kmp_dep_info entry (struct fields 0, 1 and 2: base address, length, flags) at depArray[0 + linearIV] +// TARGET: omp_dep_iterator.body: +// TARGET: %[[IDX:.*]] = add i64 0, %[[IV]] +// TARGET: %[[ENTRY:.*]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[DEP_ARR]], i64 %[[IDX]] +// TARGET: %[[BASE_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 0 +// TARGET: %[[PTRINT:.*]] = ptrtoint ptr %[[ADDR]] to i64 +// TARGET: store i64 %[[PTRINT]], ptr %[[BASE_GEP]] +// TARGET: %[[LEN_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 1 +// TARGET: store i64 8, ptr %[[LEN_GEP]] +// TARGET: %[[FLAGS_GEP:.*]] = getelementptr inbounds nuw %struct.kmp_dep_info, ptr %[[ENTRY]], i32 0, i32 2 +// depKind = 1 (DepIn), matching the taskdependin kind used in the depend clause of the omp.target above +// TARGET: store i8 1, 
ptr %[[FLAGS_GEP]] +// +// TARGET: omp_dep_iterator.inc: +// TARGET: %[[NEXT]] = add nuw i64 %[[IV]], 1 +// +// Target task: __kmpc_omp_wait_deps with ndeps=10, then task_begin_if0 / proxy function / task_complete_if0, and finally the dependency array is freed +// TARGET: call void @__kmpc_omp_wait_deps(ptr @{{.*}}, i32 %{{.*}}, i32 10, ptr %[[DEP_ARR]], i32 0, ptr null) +// TARGET: call void @__kmpc_omp_task_begin_if0 +// TARGET: call void @.omp_target_task_proxy_func +// TARGET: call void @__kmpc_omp_task_complete_if0 +// TARGET: tail call void @free(ptr %[[DEP_ARR]]) diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index e0872226531e6..af5da3dc8c3a4 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -190,36 +190,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) { // ----- -llvm.func @task_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64, - %addr : !llvm.ptr) { - %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { - omp.yield(%addr : !llvm.ptr) - } -> !omp.iterated - // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.task operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.task}} - omp.task depend(taskdependin -> %it : !omp.iterated) { - omp.terminator - } - llvm.return -} - -// ----- - -llvm.func @target_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64, - %addr : !llvm.ptr) { - %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { - omp.yield(%addr : !llvm.ptr) - } -> !omp.iterated - // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target depend(taskdependin -> %it : !omp.iterated) { - omp.terminator - } - llvm.return -} - -// ----- - llvm.func @target_enter_data_depend(%x: !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause depend in 
omp.target_enter_data operation}} // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}