292 changes: 164 additions & 128 deletions clang/test/OpenMP/cancel_codegen.cpp

Large diffs are not rendered by default.

30 changes: 28 additions & 2 deletions clang/test/OpenMP/irbuilder_nested_openmp_parallel_empty.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ void nested_parallel_0(void) {

// ALL-LABEL: @_Z17nested_parallel_1Pfid(
// ALL-NEXT: entry:
// ALL-NEXT: [[STRUCTARG14:%.*]] = alloca { { i32*, double*, float** }*, i32*, double*, float** }, align 8
// ALL-NEXT: [[STRUCTARG:%.*]] = alloca { i32*, double*, float** }, align 8
// ALL-NEXT: [[R_ADDR:%.*]] = alloca float*, align 8
// ALL-NEXT: [[A_ADDR:%.*]] = alloca i32, align 4
// ALL-NEXT: [[B_ADDR:%.*]] = alloca double, align 8
Expand All @@ -42,7 +44,15 @@ void nested_parallel_0(void) {
// ALL-NEXT: [[OMP_GLOBAL_THREAD_NUM:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB1]])
// ALL-NEXT: br label [[OMP_PARALLEL:%.*]]
// ALL: omp_parallel:
// ALL-NEXT: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @[[GLOB1]], i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, double*, float**)* @_Z17nested_parallel_1Pfid..omp_par.2 to void (i32*, i32*, ...)*), i32* [[A_ADDR]], double* [[B_ADDR]], float** [[R_ADDR]])
// ALL-NEXT: [[GEP_STRUCTARG:%.*]] = getelementptr { { i32*, double*, float** }*, i32*, double*, float** }, { { i32*, double*, float** }*, i32*, double*, float** }* [[STRUCTARG14]], i32 0, i32 0
// ALL-NEXT: store { i32*, double*, float** }* [[STRUCTARG]], { i32*, double*, float** }** [[GEP_STRUCTARG]], align 8
// ALL-NEXT: [[GEP_A_ADDR15:%.*]] = getelementptr { { i32*, double*, float** }*, i32*, double*, float** }, { { i32*, double*, float** }*, i32*, double*, float** }* [[STRUCTARG14]], i32 0, i32 1
// ALL-NEXT: store i32* [[A_ADDR]], i32** [[GEP_A_ADDR15]], align 8
// ALL-NEXT: [[GEP_B_ADDR16:%.*]] = getelementptr { { i32*, double*, float** }*, i32*, double*, float** }, { { i32*, double*, float** }*, i32*, double*, float** }* [[STRUCTARG14]], i32 0, i32 2
// ALL-NEXT: store double* [[B_ADDR]], double** [[GEP_B_ADDR16]], align 8
// ALL-NEXT: [[GEP_R_ADDR17:%.*]] = getelementptr { { i32*, double*, float** }*, i32*, double*, float** }, { { i32*, double*, float** }*, i32*, double*, float** }* [[STRUCTARG14]], i32 0, i32 3
// ALL-NEXT: store float** [[R_ADDR]], float*** [[GEP_R_ADDR17]], align 8
// ALL-NEXT: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @[[GLOB1]], i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, { { i32*, double*, float** }*, i32*, double*, float** }*)* @_Z17nested_parallel_1Pfid..omp_par.2 to void (i32*, i32*, ...)*), { { i32*, double*, float** }*, i32*, double*, float** }* [[STRUCTARG14]])
// ALL-NEXT: br label [[OMP_PAR_OUTLINED_EXIT13:%.*]]
// ALL: omp.par.outlined.exit13:
// ALL-NEXT: br label [[OMP_PAR_EXIT_SPLIT:%.*]]
Expand All @@ -61,6 +71,10 @@ void nested_parallel_1(float *r, int a, double b) {

// ALL-LABEL: @_Z17nested_parallel_2Pfid(
// ALL-NEXT: entry:
// ALL-NEXT: [[STRUCTARG68:%.*]] = alloca { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, align 8
// ALL-NEXT: [[STRUCTARG64:%.*]] = alloca { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }, align 8
// ALL-NEXT: [[STRUCTARG59:%.*]] = alloca { i32*, double*, float** }, align 8
// ALL-NEXT: [[STRUCTARG:%.*]] = alloca { i32*, double*, float** }, align 8
// ALL-NEXT: [[R_ADDR:%.*]] = alloca float*, align 8
// ALL-NEXT: [[A_ADDR:%.*]] = alloca i32, align 4
// ALL-NEXT: [[B_ADDR:%.*]] = alloca double, align 8
Expand All @@ -70,7 +84,19 @@ void nested_parallel_1(float *r, int a, double b) {
// ALL-NEXT: [[OMP_GLOBAL_THREAD_NUM:%.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GLOB1]])
// ALL-NEXT: br label [[OMP_PARALLEL:%.*]]
// ALL: omp_parallel:
// ALL-NEXT: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @[[GLOB1]], i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, double*, float**)* @_Z17nested_parallel_2Pfid..omp_par.5 to void (i32*, i32*, ...)*), i32* [[A_ADDR]], double* [[B_ADDR]], float** [[R_ADDR]])
// ALL-NEXT: [[GEP_A_ADDR:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 0
// ALL-NEXT: store i32* [[A_ADDR]], i32** [[GEP_A_ADDR]], align 8
// ALL-NEXT: [[GEP_B_ADDR:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 1
// ALL-NEXT: store double* [[B_ADDR]], double** [[GEP_B_ADDR]], align 8
// ALL-NEXT: [[GEP_R_ADDR:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 2
// ALL-NEXT: store float** [[R_ADDR]], float*** [[GEP_R_ADDR]], align 8
// ALL-NEXT: [[GEP_STRUCTARG64:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 3
// ALL-NEXT: store { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG64]], { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }** [[GEP_STRUCTARG64]], align 8
// ALL-NEXT: [[GEP_STRUCTARG69:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 4
// ALL-NEXT: store { i32*, double*, float** }* [[STRUCTARG]], { i32*, double*, float** }** [[GEP_STRUCTARG69]], align 8
// ALL-NEXT: [[GEP_STRUCTARG5970:%.*]] = getelementptr { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]], i32 0, i32 5
// ALL-NEXT: store { i32*, double*, float** }* [[STRUCTARG59]], { i32*, double*, float** }** [[GEP_STRUCTARG5970]], align 8
// ALL-NEXT: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* @[[GLOB1]], i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }*)* @_Z17nested_parallel_2Pfid..omp_par.5 to void (i32*, i32*, ...)*), { i32*, double*, float**, { i32*, double*, float**, { i32*, double*, float** }*, { i32*, double*, float** }* }*, { i32*, double*, float** }*, { i32*, double*, float** }* }* [[STRUCTARG68]])
// ALL-NEXT: br label [[OMP_PAR_OUTLINED_EXIT55:%.*]]
// ALL: omp.par.outlined.exit55:
// ALL-NEXT: br label [[OMP_PAR_EXIT_SPLIT:%.*]]
Expand Down
168 changes: 110 additions & 58 deletions clang/test/OpenMP/irbuilder_nested_parallel_for.c

Large diffs are not rendered by default.

605 changes: 316 additions & 289 deletions clang/test/OpenMP/parallel_codegen.cpp

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ class OpenMPIRBuilder {
/// Finalize the underlying module, e.g., by outlining regions.
/// \param Fn The function to be finalized. If not used,
/// all functions are finalized.
/// \param AllowExtractorSinking Flag to include sinking instructions,
/// emitted by CodeExtractor, in the
/// outlined region. Default is false.
void finalize(Function *Fn = nullptr, bool AllowExtractorSinking = false);
void finalize(Function *Fn = nullptr);

/// Add attributes known for \p FnID to \p Fn.
void addAttributes(omp::RuntimeFunction FnID, Function &Fn);
Expand Down Expand Up @@ -772,6 +769,7 @@ class OpenMPIRBuilder {
using PostOutlineCBTy = std::function<void(Function &)>;
PostOutlineCBTy PostOutlineCB;
BasicBlock *EntryBB, *ExitBB;
SmallVector<Value *, 2> ExcludeArgsFromAggregate;

/// Collect all blocks in between EntryBB and ExitBB in both the given
/// vector and set.
Expand Down
8 changes: 7 additions & 1 deletion llvm/include/llvm/Transforms/Utils/CodeExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class CodeExtractorAnalysisCache {
///
/// Based on the blocks used when constructing the code extractor,
/// determine whether it is eligible for extraction.
///
///
/// Checks that varargs handling (with vastart and vaend) is only done in
/// the outlined blocks.
bool isEligible() const;
Expand Down Expand Up @@ -214,6 +214,10 @@ class CodeExtractorAnalysisCache {
/// original block will be added to the outline region.
BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);

/// Exclude a value from aggregate argument passing when extracting a code
/// region, passing it instead as a scalar.
void excludeArgFromAggregate(Value *Arg);

private:
struct LifetimeMarkerInfo {
bool SinkLifeStart = false;
Expand All @@ -222,6 +226,8 @@ class CodeExtractorAnalysisCache {
Instruction *LifeEnd = nullptr;
};

ValueSet ExcludeArgsFromAggregate;

LifetimeMarkerInfo
getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
Instruction *Addr, BasicBlock *ExitBlock) const;
Expand Down
47 changes: 26 additions & 21 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

void OpenMPIRBuilder::finalize(Function *Fn, bool AllowExtractorSinking) {
void OpenMPIRBuilder::finalize(Function *Fn) {
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
SmallVector<BasicBlock *, 32> Blocks;
SmallVector<OutlineInfo, 16> DeferredOutlines;
Expand All @@ -193,7 +193,7 @@ void OpenMPIRBuilder::finalize(Function *Fn, bool AllowExtractorSinking) {
Function *OuterFn = OI.getFunction();
CodeExtractorAnalysisCache CEAC(*OuterFn);
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ false,
/* AggregateArgs */ true,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
Expand All @@ -207,6 +207,9 @@ void OpenMPIRBuilder::finalize(Function *Fn, bool AllowExtractorSinking) {
assert(Extractor.isEligible() &&
"Expected OpenMP outlining to be possible!");

for (auto *V : OI.ExcludeArgsFromAggregate)
Extractor.excludeArgFromAggregate(V);

Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
Expand All @@ -225,25 +228,25 @@ void OpenMPIRBuilder::finalize(Function *Fn, bool AllowExtractorSinking) {
BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
if (AllowExtractorSinking) {
// Move instructions from the to-be-deleted ArtificialEntry to the entry
// basic block of the parallel region. CodeExtractor may have sunk
// allocas/bitcasts for values that are solely used in the outlined
// region and do not escape.
assert(!ArtificialEntry.empty() &&
"Expected instructions to sink in the outlined region");
for (BasicBlock::iterator It = ArtificialEntry.begin(),
End = ArtificialEntry.end();
It != End;) {
Instruction &I = *It;
It++;

if (I.isTerminator())
continue;

I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
}
// Move instructions from the to-be-deleted ArtificialEntry to the entry
// basic block of the parallel region. CodeExtractor generates
// instructions to unwrap the aggregate argument and may sink
// allocas/bitcasts for values that are solely used in the outlined region
// and do not escape.
assert(!ArtificialEntry.empty() &&
"Expected instructions to add in the outlined region entry");
for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
End = ArtificialEntry.rend();
It != End;) {
Instruction &I = *It;
It++;

if (I.isTerminator())
continue;

I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
}

OI.EntryBB->moveBefore(&ArtificialEntry);
ArtificialEntry.eraseFromParent();
}
Expand Down Expand Up @@ -811,8 +814,10 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);

auto PrivHelper = [&](Value &V) {
if (&V == TIDAddr || &V == ZeroAddr)
if (&V == TIDAddr || &V == ZeroAddr) {
OI.ExcludeArgsFromAggregate.push_back(&V);
return;
}

SetVector<Use *> Uses;
for (Use &U : V.uses())
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Transforms/IPO/OpenMPOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,8 +1074,7 @@ struct OpenMPOpt {
BranchInst::Create(AfterBB, AfterIP.getBlock());

// Perform the actual outlining.
OMPInfoCache.OMPBuilder.finalize(OriginalFn,
/* AllowExtractorSinking */ true);
OMPInfoCache.OMPBuilder.finalize(OriginalFn);

Function *OutlinedFn = MergableCIs.front()->getCaller();

Expand Down
179 changes: 110 additions & 69 deletions llvm/lib/Transforms/Utils/CodeExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
default: RetTy = Type::getInt16Ty(header->getContext()); break;
}

std::vector<Type *> paramTy;
std::vector<Type *> ParamTy;
std::vector<Type *> AggParamTy;
ValueSet StructValues;

// Add the types of the input values to the function's argument list
for (Value *value : inputs) {
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
paramTy.push_back(value->getType());
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
AggParamTy.push_back(value->getType());
StructValues.insert(value);
} else
ParamTy.push_back(value->getType());
}

// Add the types of the output values to the function's argument list.
for (Value *output : outputs) {
LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
if (AggregateArgs)
paramTy.push_back(output->getType());
else
paramTy.push_back(PointerType::getUnqual(output->getType()));
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
AggParamTy.push_back(output->getType());
StructValues.insert(output);
} else
ParamTy.push_back(PointerType::getUnqual(output->getType()));
}

assert(
(ParamTy.size() + AggParamTy.size()) ==
(inputs.size() + outputs.size()) &&
"Number of scalar and aggregate params does not match inputs, outputs");
assert(StructValues.empty() ||
AggregateArgs && "Expeced StructValues only with AggregateArgs set");

// Concatenate scalar and aggregate params in ParamTy.
size_t NumScalarParams = ParamTy.size();
StructType *StructTy = nullptr;
if (AggregateArgs && !AggParamTy.empty()) {
StructTy = StructType::get(M->getContext(), AggParamTy);
ParamTy.push_back(PointerType::getUnqual(StructTy));
}

LLVM_DEBUG({
dbgs() << "Function type: " << *RetTy << " f(";
for (Type *i : paramTy)
for (Type *i : ParamTy)
dbgs() << *i << ", ";
dbgs() << ")\n";
});

StructType *StructTy = nullptr;
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
StructTy = StructType::get(M->getContext(), paramTy);
paramTy.clear();
paramTy.push_back(PointerType::getUnqual(StructTy));
}
FunctionType *funcType =
FunctionType::get(RetTy, paramTy,
AllowVarArgs && oldFunction->isVarArg());
FunctionType *funcType = FunctionType::get(
RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());

std::string SuffixToUse =
Suffix.empty()
Expand Down Expand Up @@ -981,24 +996,27 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
newFunction->getBasicBlockList().push_back(newRootNode);

// Create an iterator to name all of the arguments we inserted.
Function::arg_iterator AI = newFunction->arg_begin();
// Create scalar and aggregate iterators to name all of the arguments we
// inserted.
Function::arg_iterator ScalarAI = newFunction->arg_begin();
Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);

// Rewrite all users of the inputs in the extracted region to use the
// arguments (or appropriate addressing into struct) instead.
for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
Value *RewriteVal;
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(inputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
Instruction *TI = newFunction->begin()->getTerminator();
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), TI);
++aggIdx;
} else
RewriteVal = &*AI++;
RewriteVal = &*ScalarAI++;

std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
for (User *use : Users)
Expand All @@ -1008,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}

// Set names for input and output arguments.
if (!AggregateArgs) {
AI = newFunction->arg_begin();
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
AI->setName(inputs[i]->getName());
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
AI->setName(outputs[i]->getName()+".out");
if (NumScalarParams) {
ScalarAI = newFunction->arg_begin();
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(inputs[i]))
ScalarAI->setName(inputs[i]->getName());
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(outputs[i]))
ScalarAI->setName(outputs[i]->getName() + ".out");
}

// Rewrite branches to basic blocks outside of the loop to new dummy blocks
Expand Down Expand Up @@ -1121,44 +1141,48 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
ValueSet &outputs) {
// Emit a call to the new function, passing in: *pointer to struct (if
// aggregating parameters), or plain inputs and allocated memory for outputs
std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
std::vector<Value *> params, ReloadOutputs, Reloads;
ValueSet StructValues;

Module *M = newFunction->getParent();
LLVMContext &Context = M->getContext();
const DataLayout &DL = M->getDataLayout();
CallInst *call = nullptr;

// Add inputs as params, or to be filled into the struct
unsigned ArgNo = 0;
unsigned ScalarInputArgNo = 0;
SmallVector<unsigned, 1> SwiftErrorArgs;
for (Value *input : inputs) {
if (AggregateArgs)
StructValues.push_back(input);
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
StructValues.insert(input);
else {
params.push_back(input);
if (input->isSwiftError())
SwiftErrorArgs.push_back(ArgNo);
SwiftErrorArgs.push_back(ScalarInputArgNo);
}
++ArgNo;
++ScalarInputArgNo;
}

// Create allocas for the outputs
unsigned ScalarOutputArgNo = 0;
for (Value *output : outputs) {
if (AggregateArgs) {
StructValues.push_back(output);
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
StructValues.insert(output);
} else {
AllocaInst *alloca =
new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
nullptr, output->getName() + ".loc",
&codeReplacer->getParent()->front().front());
ReloadOutputs.push_back(alloca);
params.push_back(alloca);
++ScalarOutputArgNo;
}
}

StructType *StructArgTy = nullptr;
AllocaInst *Struct = nullptr;
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
unsigned NumAggregatedInputs = 0;
if (AggregateArgs && !StructValues.empty()) {
std::vector<Type *> ArgTypes;
for (Value *V : StructValues)
ArgTypes.push_back(V->getType());
Expand All @@ -1170,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
params.push_back(Struct);

for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
codeReplacer->getInstList().push_back(GEP);
new StoreInst(StructValues[i], GEP, codeReplacer);
// Store aggregated inputs in the struct.
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
if (inputs.contains(StructValues[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
codeReplacer->getInstList().push_back(GEP);
new StoreInst(StructValues[i], GEP, codeReplacer);
NumAggregatedInputs++;
}
}
}

Expand All @@ -1200,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
}

Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
unsigned FirstOut = inputs.size();
if (!AggregateArgs)
std::advance(OutputArgBegin, inputs.size());

// Reload the outputs passed in by reference.
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
// Reload the outputs passed in by reference, use the struct if output is in
// the aggregate or reload from the scalar argument.
for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
aggIdx = NumAggregatedInputs;
i != e; ++i) {
Value *Output = nullptr;
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(outputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
codeReplacer->getInstList().push_back(GEP);
Output = GEP;
++aggIdx;
} else {
Output = ReloadOutputs[i];
Output = ReloadOutputs[scalarIdx];
++scalarIdx;
}
LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
outputs[i]->getName() + ".reload",
Expand Down Expand Up @@ -1299,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Store the arguments right after the definition of output value.
// This must be done after creating the exit stubs to ensure that the invoke
// result restore is placed in the outlined function.
Function::arg_iterator OAI = OutputArgBegin;
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);

for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
++i) {
auto *OutI = dyn_cast<Instruction>(outputs[i]);
if (!OutI)
continue;
Expand All @@ -1320,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
assert((InsertBefore->getFunction() == newFunction ||
Blocks.count(InsertBefore->getParent())) &&
"InsertPt should be in new function");
assert(OAI != newFunction->arg_end() &&
"Number of output arguments should match "
"the amount of defined values");
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(outputs[i])) {
assert(AggOutputArgBegin != newFunction->arg_end() &&
"Number of aggregate output arguments should match "
"the number of defined values");
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
InsertBefore);
new StoreInst(outputs[i], GEP, InsertBefore);
++aggIdx;
// Since there should be only one struct argument aggregating
// all the output values, we shouldn't increment OAI, which always
// points to the struct argument, in this case.
// all the output values, we shouldn't increment AggOutputArgBegin, which
// always points to the struct argument, in this case.
} else {
new StoreInst(outputs[i], &*OAI, InsertBefore);
++OAI;
assert(ScalarOutputArgBegin != newFunction->arg_end() &&
"Number of scalar output arguments should match "
"the number of defined values");
new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
++ScalarOutputArgBegin;
}
}

Expand Down Expand Up @@ -1835,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
}
return false;
}

/// Register \p Arg in the set of values that are kept out of the aggregate
/// argument; values in this set are passed to the outlined function as
/// separate scalar parameters when the code region is extracted.
void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
  // The set ignores duplicates, so the insertion result is not needed.
  (void)ExcludeArgsFromAggregate.insert(Arg);
}
682 changes: 440 additions & 242 deletions llvm/test/Transforms/OpenMP/parallel_region_merging.ll

Large diffs are not rendered by default.

94 changes: 72 additions & 22 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,64 @@ template <typename InstTy> static Value *findStoredValue(Value *AllocaValue) {
return Store->getValueOperand();
}

// Returns the value stored in the aggregate argument of an outlined function,
// or nullptr if it is not found.
//
// The lookup proceeds in three steps:
//  1. Among the users of \p Aggregate, find the GEP whose third operand is the
//     i32 constant \p Idx (exactly one is expected).
//  2. Take the single store through that GEP; its value operand is the pointer
//     that was placed in the aggregate slot.
//  3. Among the users of that pointer, find the store that writes *to* it and
//     return the value being written.
//
// Every unmet expectation is reported through a non-fatal EXPECT_*, after
// which the function returns nullptr instead of dereferencing a null pointer.
static Value *findStoredValueInAggregateAt(LLVMContext &Ctx, Value *Aggregate,
                                           unsigned Idx) {
  // The index constant is loop-invariant; build it once.
  Value *IdxConstant = ConstantInt::get(Type::getInt32Ty(Ctx), Idx);

  GetElementPtrInst *GEPAtIdx = nullptr;
  // Find GEP instruction at that index.
  for (User *Usr : Aggregate->users()) {
    auto *GEP = dyn_cast<GetElementPtrInst>(Usr);
    if (!GEP)
      continue;

    // A GEP with fewer than three operands has no member index to compare;
    // reading operand 2 on it would be out of range.
    if (GEP->getNumOperands() < 3 || GEP->getOperand(2) != IdxConstant)
      continue;

    EXPECT_EQ(GEPAtIdx, nullptr);
    GEPAtIdx = GEP;
  }

  // EXPECT_* does not abort the test, so bail out explicitly on failure.
  EXPECT_NE(GEPAtIdx, nullptr);
  if (!GEPAtIdx)
    return nullptr;

  EXPECT_EQ(GEPAtIdx->getNumUses(), 1U);
  if (GEPAtIdx->user_empty())
    return nullptr;

  // Find the value stored to the aggregate.
  auto *StoreToAgg = dyn_cast<StoreInst>(*GEPAtIdx->user_begin());
  EXPECT_NE(StoreToAgg, nullptr);
  if (!StoreToAgg)
    return nullptr;
  Value *StoredAggValue = StoreToAgg->getValueOperand();

  Value *StoredValue = nullptr;

  // Find the value that is stored through the pointer kept in the aggregate.
  for (User *Usr : StoredAggValue->users()) {
    auto *Store = dyn_cast<StoreInst>(Usr);
    if (!Store)
      continue;

    // Skip stores that merely use the pointer as the stored value (such as the
    // store into the aggregate itself).
    if (Store->getPointerOperand() != StoredAggValue)
      continue;

    EXPECT_EQ(StoredValue, nullptr);
    StoredValue = Store->getValueOperand();
  }

  return StoredValue;
}

// Returns the aggregate that the value is originating from, or nullptr if
// \p V does not have the expected shape.
//
// \p V is expected to be a load whose pointer operand is a GEP; the pointer
// operand of that GEP is returned. A mismatch is reported through a non-fatal
// EXPECT_NE and nullptr is returned rather than dereferencing a null pointer.
// NOTE(review): callers that pass the result straight to isa<> should switch
// to isa_and_nonnull<> (or check for null) to benefit from the null return.
static Value *findAggregateFromValue(Value *V) {
  // Expects a load instruction that loads from the aggregate.
  auto *Load = dyn_cast<LoadInst>(V);
  EXPECT_NE(Load, nullptr);
  if (!Load)
    return nullptr;

  // Find the GEP instruction used in the load instruction.
  auto *GEP = dyn_cast<GetElementPtrInst>(Load->getPointerOperand());
  EXPECT_NE(GEP, nullptr);
  if (!GEP)
    return nullptr;

  // The aggregate is the base pointer of the GEP instruction.
  return GEP->getPointerOperand();
}

TEST_F(OpenMPIRBuilderTest, CreateBarrier) {
OpenMPIRBuilder OMPBuilder(*M);
OMPBuilder.initialize();
Expand Down Expand Up @@ -581,8 +639,9 @@ TEST_F(OpenMPIRBuilderTest, ParallelSimple) {
EXPECT_EQ(ForkCI->getArgOperand(1),
ConstantInt::get(Type::getInt32Ty(Ctx), 1U));
EXPECT_EQ(ForkCI->getArgOperand(2), Usr);
EXPECT_EQ(findStoredValue<AllocaInst>(ForkCI->getArgOperand(3)),
F->arg_begin());
Value *StoredValue =
findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0);
EXPECT_EQ(StoredValue, F->arg_begin());
}

TEST_F(OpenMPIRBuilderTest, ParallelNested) {
Expand Down Expand Up @@ -906,15 +965,16 @@ TEST_F(OpenMPIRBuilderTest, ParallelIfCond) {
EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
EXPECT_EQ(ForkCI->getArgOperand(1),
ConstantInt::get(Type::getInt32Ty(Ctx), 1));
Value *StoredForkArg = findStoredValue<AllocaInst>(ForkCI->getArgOperand(3));
Value *StoredForkArg =
findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0);
EXPECT_EQ(StoredForkArg, F->arg_begin());

EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
EXPECT_EQ(DirectCI->arg_size(), 3U);
EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
Value *StoredDirectArg =
findStoredValue<AllocaInst>(DirectCI->getArgOperand(2));
findStoredValueInAggregateAt(Ctx, DirectCI->getArgOperand(2), 0);
EXPECT_EQ(StoredDirectArg, F->arg_begin());
}

Expand Down Expand Up @@ -1045,6 +1105,8 @@ TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {
Type *I32PtrTy = Type::getInt32PtrTy(M->getContext());
Type *StructTy = StructType::get(I32Ty, I32PtrTy);
Type *StructPtrTy = StructTy->getPointerTo();
StructType *ArgStructTy =
StructType::get(I32PtrTy, StructPtrTy, I32PtrTy, StructPtrTy);
Type *VoidTy = Type::getVoidTy(M->getContext());
FunctionCallee RetI32Func = M->getOrInsertFunction("ret_i32", I32Ty);
FunctionCallee TakeI32Func =
Expand Down Expand Up @@ -1096,21 +1158,7 @@ TEST_F(OpenMPIRBuilderTest, ParallelForwardAsPointers) {

Type *Arg2Type = OutlinedFn->getArg(2)->getType();
EXPECT_TRUE(Arg2Type->isPointerTy());
EXPECT_TRUE(cast<PointerType>(Arg2Type)->isOpaqueOrPointeeTypeMatches(I32Ty));

// Arguments that need to be passed through pointers and reloaded will get
// used earlier in the functions and therefore will appear first in the
// argument list after outlining.
Type *Arg3Type = OutlinedFn->getArg(3)->getType();
EXPECT_TRUE(Arg3Type->isPointerTy());
EXPECT_TRUE(
cast<PointerType>(Arg3Type)->isOpaqueOrPointeeTypeMatches(StructTy));

Type *Arg4Type = OutlinedFn->getArg(4)->getType();
EXPECT_EQ(Arg4Type, I32PtrTy);

Type *Arg5Type = OutlinedFn->getArg(5)->getType();
EXPECT_EQ(Arg5Type, StructPtrTy);
EXPECT_EQ(Arg2Type->getPointerElementType(), ArgStructTy);
}

TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
Expand Down Expand Up @@ -3031,7 +3079,7 @@ static bool isValueReducedToFuncArg(Value *V, BasicBlock *BB) {
return false;

return Store->getPointerOperand() == GlobalLoad->getPointerOperand() &&
isa<Argument>(GlobalLoad->getPointerOperand());
isa<Argument>(findAggregateFromValue(GlobalLoad->getPointerOperand()));
}

/// Finds among users of Ptr a pair of GEP instructions with indices [0, 0] and
Expand Down Expand Up @@ -3328,9 +3376,11 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
auto *SecondAtomic =
findSingleUserInBlock<AtomicRMWInst>(SecondLoad, AtomicBB);
ASSERT_NE(FirstAtomic, nullptr);
EXPECT_TRUE(isa<Argument>(FirstAtomic->getPointerOperand()));
Value *AtomicStorePointer = FirstAtomic->getPointerOperand();
EXPECT_TRUE(isa<Argument>(findAggregateFromValue(AtomicStorePointer)));
ASSERT_NE(SecondAtomic, nullptr);
EXPECT_TRUE(isa<Argument>(SecondAtomic->getPointerOperand()));
AtomicStorePointer = SecondAtomic->getPointerOperand();
EXPECT_TRUE(isa<Argument>(findAggregateFromValue(AtomicStorePointer)));

// Check that the separate reduction function also performs (non-atomic)
// reductions after extracting reduction variables from its arguments.
Expand Down
54 changes: 52 additions & 2 deletions llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);

EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
Expand Down Expand Up @@ -245,7 +245,7 @@ TEST(CodeExtractor, ExitBlockOrdering) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);

EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
Expand Down Expand Up @@ -504,4 +504,54 @@ TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}

// Checks that a value excluded from the aggregate is passed as a separate
// argument while the remaining inputs are packed into a single struct
// argument. Preconditions that are dereferenced afterwards use ASSERT_* so a
// failure stops the test instead of dereferencing a null pointer or reading
// an empty input set.
TEST(CodeExtractor, PartialAggregateArgs) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
declare void @use(i32)
define void @foo(i32 %a, i32 %b, i32 %c) {
entry:
br label %extract
extract:
call void @use(i32 %a)
call void @use(i32 %b)
call void @use(i32 %c)
br label %exit
exit:
ret void
}
)ir",
                                                Err, Ctx));
  // Parsing must succeed before the module is dereferenced.
  ASSERT_TRUE(M != nullptr);

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func != nullptr);
  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};

  // Create the CodeExtractor with arguments aggregation enabled.
  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
                   /* AggregateArgs */ true);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  BasicBlock *CommonExit = nullptr;
  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
  // Indexing Inputs[0] below requires at least one discovered input.
  ASSERT_FALSE(Inputs.empty());
  // Exclude the first input from the argument aggregate.
  CE.excludeArgFromAggregate(Inputs[0]);

  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
  // Fatal check: Outlined is dereferenced by every statement that follows.
  ASSERT_TRUE(Outlined);
  // Expect 2 arguments in the outlined function: the excluded input and the
  // struct aggregate for the remaining inputs.
  EXPECT_EQ(Outlined->arg_size(), 2U);
  EXPECT_FALSE(verifyFunction(*Outlined));
  EXPECT_FALSE(verifyFunction(*Func));
}
}
} // end anonymous namespace