diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h index 32dcdd587f3b3..f8812e7955b82 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h @@ -277,6 +277,16 @@ enum class RTLDependenceKindTy { DepOmpAllMem = 0x80, }; +/// A type of worksharing loop construct +enum class WorksharingLoopType { + // Worksharing `for`-loop + ForStaticLoop, + // Worksharing `distrbute`-loop + DistributeStaticLoop, + // Worksharing `distrbute parallel for`-loop + DistributeForStaticLoop +}; + } // end namespace omp } // end namespace llvm diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 334eaf01a59c9..abbef03d02cb1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -900,6 +900,28 @@ class OpenMPIRBuilder { omp::OpenMPOffloadMappingFlags MemberOfFlag); private: + /// Modifies the canonical loop to be a statically-scheduled workshare loop + /// which is executed on the device + /// + /// This takes a \p CLI representing a canonical loop, such as the one + /// created by \see createCanonicalLoop and emits additional instructions to + /// turn it into a workshare loop. In particular, it calls to an OpenMP + /// runtime function in the preheader to call OpenMP device rtl function + /// which handles worksharing of loop body interations. + /// + /// \param DL Debug location for instructions added for the + /// workshare-loop construct itself. + /// \param CLI A descriptor of the canonical loop to workshare. + /// \param AllocaIP An insertion point for Alloca instructions usable in the + /// preheader of the loop. + /// \param LoopType Information about type of loop worksharing. + /// It corresponds to type of loop workshare OpenMP pragma. + /// + /// \returns Point where to insert code after the workshare construct. + InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI, + InsertPointTy AllocaIP, + omp::WorksharingLoopType LoopType); + /// Modifies the canonical loop to be a statically-scheduled workshare loop. /// /// This takes a \p LoopInfo representing a canonical loop, such as the one @@ -1012,6 +1034,8 @@ class OpenMPIRBuilder { /// present in the schedule clause. /// \param HasOrderedClause Whether the (parameterless) ordered clause is /// present. + /// \param LoopType Information about type of loop worksharing. + /// It corresponds to type of loop workshare OpenMP pragma. /// /// \returns Point where to insert code after the workshare construct. InsertPointTy applyWorkshareLoop( @@ -1020,7 +1044,9 @@ class OpenMPIRBuilder { llvm::omp::ScheduleKind SchedKind = llvm::omp::OMP_SCHEDULE_Default, Value *ChunkSize = nullptr, bool HasSimdModifier = false, bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false, - bool HasOrderedClause = false); + bool HasOrderedClause = false, + omp::WorksharingLoopType LoopType = + omp::WorksharingLoopType::ForStaticLoop); /// Tile a loop nest. /// diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 690d6cbaa67b3..38af4e3cf94ba 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -2674,11 +2674,242 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop( return {DispatchAfter, DispatchAfter->getFirstInsertionPt()}; } +// Returns an LLVM function to call for executing an OpenMP static worksharing +// for loop depending on `type`. Only i32 and i64 are supported by the runtime. +// Always interpret integers as unsigned similarly to CanonicalLoopInfo. +static FunctionCallee +getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder, + WorksharingLoopType LoopType) { + unsigned Bitwidth = Ty->getIntegerBitWidth(); + Module &M = OMPBuilder->M; + switch (LoopType) { + case WorksharingLoopType::ForStaticLoop: + if (Bitwidth == 32) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u); + if (Bitwidth == 64) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u); + break; + case WorksharingLoopType::DistributeStaticLoop: + if (Bitwidth == 32) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u); + if (Bitwidth == 64) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u); + break; + case WorksharingLoopType::DistributeForStaticLoop: + if (Bitwidth == 32) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u); + if (Bitwidth == 64) + return OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u); + break; + } + if (Bitwidth != 32 && Bitwidth != 64) { + llvm_unreachable("Unknown OpenMP loop iterator bitwidth"); + } + llvm_unreachable("Unknown type of OpenMP worksharing loop"); +} + +// Inserts a call to proper OpenMP Device RTL function which handles +// loop worksharing. +static void createTargetLoopWorkshareCall( + OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType, + BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg, + Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) { + Type *TripCountTy = TripCount->getType(); + Module &M = OMPBuilder->M; + IRBuilder<> &Builder = OMPBuilder->Builder; + FunctionCallee RTLFn = + getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType); + SmallVector RealArgs; + RealArgs.push_back(Ident); + RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr)); + RealArgs.push_back(LoopBodyArg); + RealArgs.push_back(TripCount); + if (LoopType == WorksharingLoopType::DistributeStaticLoop) { + RealArgs.push_back(ConstantInt::get(TripCountTy, 0)); + Builder.CreateCall(RTLFn, RealArgs); + return; + } + FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction( + M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads); + Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())}); + Value *NumThreads = Builder.CreateCall(RTLNumThreads, {}); + + RealArgs.push_back( + Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast")); + RealArgs.push_back(ConstantInt::get(TripCountTy, 0)); + if (LoopType == WorksharingLoopType::DistributeForStaticLoop) { + RealArgs.push_back(ConstantInt::get(TripCountTy, 0)); + } + + Builder.CreateCall(RTLFn, RealArgs); +} + +static void +workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder, + CanonicalLoopInfo *CLI, Value *Ident, + Function &OutlinedFn, Type *ParallelTaskPtr, + const SmallVector &ToBeDeleted, + WorksharingLoopType LoopType) { + IRBuilder<> &Builder = OMPIRBuilder->Builder; + BasicBlock *Preheader = CLI->getPreheader(); + Value *TripCount = CLI->getTripCount(); + + // After loop body outling, the loop body contains only set up + // of loop body argument structure and the call to the outlined + // loop body function. Firstly, we need to move setup of loop body args + // into loop preheader. + Preheader->splice(std::prev(Preheader->end()), CLI->getBody(), + CLI->getBody()->begin(), std::prev(CLI->getBody()->end())); + + // The next step is to remove the whole loop. We do not it need anymore. + // That's why make an unconditional branch from loop preheader to loop + // exit block + Builder.restoreIP({Preheader, Preheader->end()}); + Preheader->getTerminator()->eraseFromParent(); + Builder.CreateBr(CLI->getExit()); + + // Delete dead loop blocks + OpenMPIRBuilder::OutlineInfo CleanUpInfo; + SmallPtrSet RegionBlockSet; + SmallVector BlocksToBeRemoved; + CleanUpInfo.EntryBB = CLI->getHeader(); + CleanUpInfo.ExitBB = CLI->getExit(); + CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved); + DeleteDeadBlocks(BlocksToBeRemoved); + + // Find the instruction which corresponds to loop body argument structure + // and remove the call to loop body function instruction. + Value *LoopBodyArg; + User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser(); + assert(OutlinedFnUser && + "Expected unique undroppable user of outlined function"); + CallInst *OutlinedFnCallInstruction = dyn_cast(OutlinedFnUser); + assert(OutlinedFnCallInstruction && "Expected outlined function call"); + assert((OutlinedFnCallInstruction->getParent() == Preheader) && + "Expected outlined function call to be located in loop preheader"); + // Check in case no argument structure has been passed. + if (OutlinedFnCallInstruction->arg_size() > 1) + LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1); + else + LoopBodyArg = Constant::getNullValue(Builder.getPtrTy()); + OutlinedFnCallInstruction->eraseFromParent(); + + createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident, + LoopBodyArg, ParallelTaskPtr, TripCount, + OutlinedFn); + + for (auto &ToBeDeletedItem : ToBeDeleted) + ToBeDeletedItem->eraseFromParent(); + CLI->invalidate(); +} + +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI, + InsertPointTy AllocaIP, + WorksharingLoopType LoopType) { + uint32_t SrcLocStrSize; + Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize); + Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); + + OutlineInfo OI; + OI.OuterAllocaBB = CLI->getPreheader(); + Function *OuterFn = CLI->getPreheader()->getParent(); + + // Instructions which need to be deleted at the end of code generation + SmallVector ToBeDeleted; + + OI.OuterAllocaBB = AllocaIP.getBlock(); + + // Mark the body loop as region which needs to be extracted + OI.EntryBB = CLI->getBody(); + OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(), + "omp.prelatch", true); + + // Prepare loop body for extraction + Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()}); + + // Insert new loop counter variable which will be used only in loop + // body. + AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, ""); + Instruction *NewLoopCntLoad = + Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt); + // New loop counter instructions are redundant in the loop preheader when + // code generation for workshare loop is finshed. That's why mark them as + // ready for deletion. + ToBeDeleted.push_back(NewLoopCntLoad); + ToBeDeleted.push_back(NewLoopCnt); + + // Analyse loop body region. Find all input variables which are used inside + // loop body region. + SmallPtrSet ParallelRegionBlockSet; + SmallVector Blocks; + OI.collectBlocks(ParallelRegionBlockSet, Blocks); + SmallVector BlocksT(ParallelRegionBlockSet.begin(), + ParallelRegionBlockSet.end()); + + CodeExtractorAnalysisCache CEAC(*OuterFn); + CodeExtractor Extractor(Blocks, + /* DominatorTree */ nullptr, + /* AggregateArgs */ true, + /* BlockFrequencyInfo */ nullptr, + /* BranchProbabilityInfo */ nullptr, + /* AssumptionCache */ nullptr, + /* AllowVarArgs */ true, + /* AllowAlloca */ true, + /* AllocationBlock */ CLI->getPreheader(), + /* Suffix */ ".omp_wsloop", + /* AggrArgsIn0AddrSpace */ true); + + BasicBlock *CommonExit = nullptr; + SetVector Inputs, Outputs, SinkingCands, HoistingCands; + + // Find allocas outside the loop body region which are used inside loop + // body + Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + + // We need to model loop body region as the function f(cnt, loop_arg). + // That's why we replace loop induction variable by the new counter + // which will be one of loop body function argument + for (auto Use = CLI->getIndVar()->user_begin(); + Use != CLI->getIndVar()->user_end(); ++Use) { + if (Instruction *Inst = dyn_cast(*Use)) { + if (ParallelRegionBlockSet.count(Inst->getParent())) { + Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad); + } + } + } + // Make sure that loop counter variable is not merged into loop body + // function argument structure and it is passed as separate variable + OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad); + + // PostOutline CB is invoked when loop body function is outlined and + // loop body is replaced by call to outlined function. We need to add + // call to OpenMP device rtl inside loop preheader. OpenMP device rtl + // function will handle loop control logic. + // + OI.PostOutlineCB = [=, ToBeDeletedVec = + std::move(ToBeDeleted)](Function &OutlinedFn) { + workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr, + ToBeDeletedVec, LoopType); + }; + addOutlineInfo(std::move(OI)); + return CLI->getAfterIP(); +} + OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop( DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP, - bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind, - llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier, - bool HasNonmonotonicModifier, bool HasOrderedClause) { + bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize, + bool HasSimdModifier, bool HasMonotonicModifier, + bool HasNonmonotonicModifier, bool HasOrderedClause, + WorksharingLoopType LoopType) { + if (Config.isTargetDevice()) + return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType); OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType( SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier, HasNonmonotonicModifier, HasOrderedClause); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 2876aa49da402..193ada3a4ea32 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -2228,9 +2228,77 @@ TEST_F(OpenMPIRBuilderTest, UnrollLoopHeuristic) { EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable")); } +TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + std::string oldDLStr = M->getDataLayoutStr(); + M->setDataLayout( + "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:" + "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:" + "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8"); + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = true; + OMPBuilder.initialize(); + IRBuilder<> Builder(BB); + OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); + InsertPointTy AllocaIP = Builder.saveIP(); + + Type *LCTy = Type::getInt32Ty(Ctx); + Value *StartVal = ConstantInt::get(LCTy, 10); + Value *StopVal = ConstantInt::get(LCTy, 52); + Value *StepVal = ConstantInt::get(LCTy, 2); + auto LoopBodyGen = [&](InsertPointTy, Value *) {}; + + CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop( + Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false); + BasicBlock *Preheader = CLI->getPreheader(); + Value *TripCount = CLI->getTripCount(); + + Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); + + IRBuilder<>::InsertPoint AfterIP = OMPBuilder.applyWorkshareLoop( + DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false, + false, false, WorksharingLoopType::ForStaticLoop); + Builder.restoreIP(AfterIP); + Builder.CreateRetVoid(); + + OMPBuilder.finalize(); + EXPECT_FALSE(verifyModule(*M, &errs())); + + CallInst *WorkshareLoopRuntimeCall = nullptr; + int WorkshareLoopRuntimeCallCnt = 0; + for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) { + CallInst *Call = dyn_cast(Inst); + if (!Call) + continue; + if (!Call->getCalledFunction()) + continue; + + if (Call->getCalledFunction()->getName() == "__kmpc_for_static_loop_4u") { + WorkshareLoopRuntimeCall = Call; + WorkshareLoopRuntimeCallCnt++; + } + } + EXPECT_NE(WorkshareLoopRuntimeCall, nullptr); + // Verify that there is only one call to workshare loop function + EXPECT_EQ(WorkshareLoopRuntimeCallCnt, 1); + // Check that pointer to loop body function is passed as second argument + Value *LoopBodyFuncArg = WorkshareLoopRuntimeCall->getArgOperand(1); + EXPECT_EQ(Builder.getPtrTy(), LoopBodyFuncArg->getType()); + Function *ArgFunction = dyn_cast(LoopBodyFuncArg); + EXPECT_NE(ArgFunction, nullptr); + EXPECT_EQ(ArgFunction->arg_size(), 1u); + EXPECT_EQ(ArgFunction->getArg(0)->getType(), TripCount->getType()); + // Check that no variables except for loop counter are used in loop body + EXPECT_EQ(Constant::getNullValue(Builder.getPtrTy()), + WorkshareLoopRuntimeCall->getArgOperand(2)); + // Check loop trip count argument + EXPECT_EQ(TripCount, WorkshareLoopRuntimeCall->getArgOperand(3)); +} + TEST_F(OpenMPIRBuilderTest, StaticWorkShareLoop) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = false; OMPBuilder.initialize(); IRBuilder<> Builder(BB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); @@ -2331,6 +2399,7 @@ TEST_P(OpenMPIRBuilderTestWithIVBits, StaticChunkedWorkshareLoop) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = false; BasicBlock *Body; CallInst *Call; @@ -2405,6 +2474,7 @@ INSTANTIATE_TEST_SUITE_P(IVBits, OpenMPIRBuilderTestWithIVBits, TEST_P(OpenMPIRBuilderTestWithParams, DynamicWorkShareLoop) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = false; OMPBuilder.initialize(); IRBuilder<> Builder(BB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL}); @@ -2562,6 +2632,7 @@ INSTANTIATE_TEST_SUITE_P( TEST_F(OpenMPIRBuilderTest, DynamicWorkShareLoopOrdered) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.Config.IsTargetDevice = false; OMPBuilder.initialize(); IRBuilder<> Builder(BB); OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});