diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index bb23cf4a9a3cb..27b34ef023db7 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -114,6 +114,10 @@ class CodeExtractorAnalysisCache {
     // label, if non-empty, otherwise "extracted".
     std::string Suffix;
 
+    // If true, pass the aggregate argument pointer in address space 0 instead
+    // of the alloca address space. False unless a constructor overrides it.
+    bool ArgsInZeroAddressSpace = false;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -128,13 +132,16 @@ class CodeExtractorAnalysisCache {
     /// Any new allocations will be placed in the AllocationBlock, unless
     /// it is null, in which case it will be placed in the entry block of
     /// the function from which the code is being extracted.
+    /// If ArgsInZeroAddressSpace is true, the aggregate argument pointer of
+    /// the outlined function is declared in address space 0 instead of the
+    /// alloca address space.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                   BranchProbabilityInfo *BPI = nullptr,
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "");
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
 
     /// Create a code extractor for a loop body.
     ///
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ae7ed296c45ea..b251a85cf85f9 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix)
+                             BasicBlock *AllocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
                              BlockFrequencyInfo *BFI,
@@ -866,7 +867,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   StructType *StructTy = nullptr;
   if (AggregateArgs && !AggParamTy.empty()) {
     StructTy = StructType::get(M->getContext(), AggParamTy);
-    ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+    ParamTy.push_back(PointerType::get(
+        StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
   }
 
   LLVM_DEBUG({
@@ -1187,8 +1189,17 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
         StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
         AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
                         : &codeReplacer->getParent()->front().front());
-    params.push_back(Struct);
 
+    // The struct lives in the alloca address space; when the outlined function
+    // takes its aggregate argument in address space 0, pass a cast pointer.
+    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+      auto *StructSpaceCast = new AddrSpaceCastInst(
+          Struct, PointerType::get(Context, 0), "structArg.ascast");
+      StructSpaceCast->insertAfter(Struct);
+      params.push_back(StructSpaceCast);
+    } else {
+      params.push_back(Struct);
+    }
     // Store aggregated inputs in the struct.
     for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
       if (inputs.contains(StructValues[i])) {
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index c142729e2c6f4..528d332393326 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -555,4 +555,64 @@ TEST(CodeExtractor, PartialAggregateArgs) {
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
+
+TEST(CodeExtractor, OpenMPAggregateArgs) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+    target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8"
+    target triple = "amdgcn-amd-amdhsa"
+
+    define void @foo(ptr %0) {
+      %2= alloca ptr, align 8, addrspace(5)
+      %3 = addrspacecast ptr addrspace(5) %2 to ptr
+      store ptr %0, ptr %3, align 8
+      %4 = load ptr, ptr %3, align 8
+      br label %entry
+
+    entry:
+      br label %extract
+
+    extract:
+      store i64 10, ptr %4, align 4
+      br label %exit
+
+    exit:
+      ret void
+    }
+  )ir",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+  // Create the CodeExtractor with argument aggregation enabled.
+  // The outlined function's argument should be declared in address space 0
+  // even though the default alloca address space of this datalayout is 5.
+  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+                   /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
+                   /* BranchProbabilityInfo */ nullptr,
+                   /* AssumptionCache */ nullptr,
+                   /* AllowVarArgs */ true,
+                   /* AllowAlloca */ true,
+                   /* AllocationBlock */ &Func->getEntryBlock(),
+                   /* Suffix */ ".outlined",
+                   /* ArgsInZeroAddressSpace */ true);
+
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  EXPECT_EQ(Outlined->arg_size(), 1U);
+  // The outlined argument must be a pointer in address space 0.
+  EXPECT_EQ(Outlined->getArg(0)->getType(),
+            PointerType::get(M->getContext(), 0));
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
 } // end anonymous namespace