162 changes: 69 additions & 93 deletions clang/test/OpenMP/irbuilder_nested_parallel_for.c

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ class OpenMPIRBuilder {
struct OutlineInfo {
using PostOutlineCBTy = std::function<void(Function &)>;
PostOutlineCBTy PostOutlineCB;
BasicBlock *EntryBB, *ExitBB;
BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
SmallVector<Value *, 2> ExcludeArgsFromAggregate;

/// Collect all blocks in between EntryBB and ExitBB in both the given
Expand Down
13 changes: 11 additions & 2 deletions llvm/include/llvm/Transforms/Utils/CodeExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ class CodeExtractorAnalysisCache {
BranchProbabilityInfo *BPI;
AssumptionCache *AC;

// A block outside of the extraction set where any intermediate
// allocations will be placed inside. If this is null, allocations
// will be placed in the entry block of the function.
BasicBlock *AllocationBlock;

// If true, varargs functions can be extracted.
bool AllowVarArgs;

Expand Down Expand Up @@ -120,11 +125,15 @@ class CodeExtractorAnalysisCache {
/// code is extracted, including vastart. If AllowAlloca is true, then
/// extraction of blocks containing alloca instructions would be possible,
/// however code extractor won't validate whether extraction is legal.
/// Any new allocations will be placed in the AllocationBlock, unless
/// it is null, in which case it will be placed in the entry block of
/// the function from which the code is being extracted.
CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
BranchProbabilityInfo *BPI = nullptr,
AssumptionCache *AC = nullptr,
bool AllowVarArgs = false, bool AllowAlloca = false,
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
bool AllowAlloca = false,
BasicBlock *AllocationBlock = nullptr,
std::string Suffix = "");

/// Create a code extractor for a loop body.
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocaBlock*/ OI.OuterAllocaBB,
/* Suffix */ ".omp_par");

LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
Expand Down Expand Up @@ -878,6 +879,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
FiniCB(PreFiniIP);

OI.OuterAllocaBB = OuterAllocaBlock;
OI.EntryBB = PRegEntryBB;
OI.ExitBB = PRegExitBB;

Expand All @@ -901,6 +903,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocationBlock */ OuterAllocaBlock,
/* Suffix */ ".omp_par");

// Find inputs to, outputs from the code region.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/IPO/HotColdSplitting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ Function *HotColdSplitting::extractColdRegion(
// TODO: Pass BFI and BPI to update profile information.
CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr,
/* BPI */ nullptr, AC, /* AllowVarArgs */ false,
/* AllowAlloca */ false,
/* AllowAlloca */ false, /* AllocaBlock */ nullptr,
/* Suffix */ "cold." + std::to_string(Count));

// Perform a simple cost/benefit analysis to decide whether or not to permit
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/IPO/IROutliner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2679,7 +2679,7 @@ unsigned IROutliner::doOutline(Module &M) {
OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
OS->CE = new (ExtractorAllocator.Allocate())
CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
false, "outlined");
false, nullptr, "outlined");
findAddInputsOutputs(M, *OS, NotSame);
if (!OS->IgnoreRegion)
OutlinedRegions.push_back(OS);
Expand Down Expand Up @@ -2790,7 +2790,7 @@ unsigned IROutliner::doOutline(Module &M) {
OS->Candidate->getBasicBlocks(BlocksInRegion, BE);
OS->CE = new (ExtractorAllocator.Allocate())
CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false,
false, "outlined");
false, nullptr, "outlined");
bool FunctionOutlined = extractSection(*OS);
if (FunctionOutlined) {
unsigned StartIdx = OS->Candidate->getStartIdx();
Expand Down
14 changes: 8 additions & 6 deletions llvm/lib/Transforms/Utils/CodeExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,10 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
bool AggregateArgs, BlockFrequencyInfo *BFI,
BranchProbabilityInfo *BPI, AssumptionCache *AC,
bool AllowVarArgs, bool AllowAlloca,
std::string Suffix)
BasicBlock *AllocationBlock, std::string Suffix)
: DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs),
BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
AllowVarArgs(AllowVarArgs),
Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
Suffix(Suffix) {}

Expand All @@ -257,7 +258,7 @@ CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
BranchProbabilityInfo *BPI, AssumptionCache *AC,
std::string Suffix)
: DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
BPI(BPI), AC(AC), AllowVarArgs(false),
BPI(BPI), AC(AC), AllocationBlock(nullptr), AllowVarArgs(false),
Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
/* AllowVarArgs */ false,
/* AllowAlloca */ false)),
Expand Down Expand Up @@ -1189,9 +1190,10 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,

// Allocate a struct at the beginning of this function
StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
"structArg",
&codeReplacer->getParent()->front().front());
Struct = new AllocaInst(
StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
: &codeReplacer->getParent()->front().front());
params.push_back(Struct);

// Store aggregated inputs in the struct.
Expand Down
41 changes: 41 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-nested.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

module {
llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
llvm.mlir.global internal constant @str0("WG size of kernel = %d X %d\0A\00")

llvm.func @main(%arg0: i32, %arg1: !llvm.ptr<ptr<i8>>) -> i32 {
omp.parallel {
%0 = llvm.mlir.constant(1 : index) : i64
%1 = llvm.mlir.constant(10 : index) : i64
%2 = llvm.mlir.constant(0 : index) : i64
%4 = llvm.mlir.constant(0 : i32) : i32
%12 = llvm.alloca %0 x i64 : (i64) -> !llvm.ptr<i64>
omp.wsloop (%arg2) : i64 = (%2) to (%1) step (%0) {
omp.parallel {
omp.wsloop (%arg3) : i64 = (%2) to (%0) step (%0) {
llvm.store %2, %12 : !llvm.ptr<i64>
omp.yield
}
omp.terminator
}
%19 = llvm.load %12 : !llvm.ptr<i64>
%20 = llvm.trunc %19 : i64 to i32
%5 = llvm.mlir.addressof @str0 : !llvm.ptr<array<29 x i8>>
%6 = llvm.getelementptr %5[%4, %4] : (!llvm.ptr<array<29 x i8>>, i32, i32) -> !llvm.ptr<i8>
%21 = llvm.call @printf(%6, %20, %20) : (!llvm.ptr<i8>, i32, i32) -> i32
omp.yield
}
omp.terminator
}
%a4 = llvm.mlir.constant(0 : i32) : i32
llvm.return %a4 : i32
}

}

// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @1, i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* @[[inner1:.+]] to void (i32*, i32*, ...)*))

// CHECK: define internal void @[[inner1]]
// CHECK: %[[structArg:.+]] = alloca { i64* }
// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @3, i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, { i64* }*)* @[[inner2:.+]] to void (i32*, i32*, ...)*), { i64* }* %[[structArg]])