Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1596,10 +1596,61 @@ static llvm::Expected<llvm::Value *> initPrivateVar(
return phis[0];
}

/// Beginning with \p startBlock, this function visits all reachable successor
/// blocks. For each such block, static alloca instructions (i.e. non-array
/// allocas) are collected. Then, these collected alloca instructions are moved
/// to the \p allocaIP insertion point.
///
/// This is useful in cases where, for example, more than one allocatable or
/// array are privatized. In such cases, we allocate a number of temporary
/// descriptors to handle the initialization logic. Additonally, for each
/// private value, there is branching logic based on the value of the origianl
/// private variable's allocation state. Therefore, we end up with descriptor
/// alloca instructions preceded by conditional branches which casues runtime
/// issues at least on the GPU.
static void hoistStaticAllocasToAllocaIP(
llvm::BasicBlock *startBlock,
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
llvm::SmallVector<llvm::BasicBlock *> inlinedBlocks{startBlock};
llvm::SmallPtrSet<llvm::BasicBlock *, 4> seenBlocks;
llvm::SmallVector<llvm::Instruction *> staticAllocas;

while (!inlinedBlocks.empty()) {
llvm::BasicBlock *curBlock = inlinedBlocks.front();
inlinedBlocks.erase(inlinedBlocks.begin());
llvm::Instruction *terminator = curBlock->getTerminator();

for (llvm::Instruction &inst : *curBlock) {
if (auto *allocaInst = mlir::dyn_cast<llvm::AllocaInst>(&inst)) {
if (!allocaInst->isArrayAllocation()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are array allocations special? I would have thought the case worth worrying about is array allocations which have a size determined dynamically - a statically sized array allocation should work okay.

Would multiple of these dynamically sized arrays still crash the GPU code?

Copy link
Member Author

@ergawy ergawy Dec 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For descriptors, the problematic allocations are the allocations of the temporary descriptor structures used to initialize the private storage. All such allocations are static allocations because they are just structs.

The dynamic arrays are allocated only when the original value is allocated, so these allocations has to be maintained in the proper branch since we read the shape from the original value.

The problem when we have many descriptors is that:

  1. We inline the init region of descriptor number 1 which includes temp allocations + the if-else branch for initialization.
  2. We do the same for descriptor number 2.
  3. .....
    Because of that, such temp allocations are emitted between if-else branching.

I think the backend is smart enough to keep the dynamic allocations since they are obviously protected/tucked inside a branch while some of the static allocations after the branching joins again are problamtic since they are supposed to be uncoditional.

#ifdef EXPENSIVE_CHECKS
assert(llvm::count(staticInitAllocas, allocaInst) == 0);
#endif
staticAllocas.push_back(allocaInst);
}
}
}

if (!terminator || !terminator->isTerminator() ||
terminator->getNumSuccessors() == 0)
continue;

for (llvm::BasicBlock *successor : llvm::successors(terminator))
if (!seenBlocks.contains(successor)) {
inlinedBlocks.push_back(successor);
seenBlocks.insert(successor);
}
}

for (llvm::Instruction *staticAlloca : staticAllocas)
staticAlloca->moveBefore(allocaIP.getPoint());
}

static llvm::Error
initPrivateVars(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
PrivateVarsInfo &privateVarsInfo,
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
if (privateVarsInfo.blockArgs.empty())
return llvm::Error::success();
Expand All @@ -1624,6 +1675,8 @@ initPrivateVars(llvm::IRBuilderBase &builder,
setInsertPointForPossiblyEmptyBlock(builder);
}

hoistStaticAllocasToAllocaIP(privInitBlock, allocaIP);

return llvm::Error::success();
}

Expand Down Expand Up @@ -2575,7 +2628,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
deferredStores, isByRef)))
return failure();

if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
allocaIP),
opInst)
.failed())
return failure();
Expand Down Expand Up @@ -2764,9 +2818,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
assert(afterAllocas.get()->getSinglePredecessor());
builder.restoreIP(codeGenIP);

if (handleError(
initPrivateVars(builder, moduleTranslation, privateVarsInfo),
*opInst)
if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
allocaIP),
*opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();

Expand Down Expand Up @@ -2967,7 +3021,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
deferredStores, isByRef)))
return failure();

if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
allocaIP),
opInst)
.failed())
return failure();
Expand Down Expand Up @@ -5288,8 +5343,9 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
if (handleError(afterAllocas, opInst).failed())
return llvm::make_error<PreviouslyReportedError>();

if (handleError(initPrivateVars(builder, moduleTranslation, privVarsInfo),
opInst)
if (handleError(
initPrivateVars(builder, moduleTranslation, privVarsInfo, allocaIP),
opInst)
.failed())
return llvm::make_error<PreviouslyReportedError>();

Expand Down Expand Up @@ -6090,7 +6146,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,

builder.restoreIP(codeGenIP);
if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
&mappedPrivateVars),
allocaIP, &mappedPrivateVars),
*targetOp)
.failed())
return llvm::make_error<PreviouslyReportedError>();
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-private-allloca-hoisting.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Tests that static alloca's in `omp.private ... init` regions are hoisted to
// the parent construct's alloca IP.
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

llvm.func @foo1()
llvm.func @foo2()
llvm.func @foo3()
llvm.func @foo4()

omp.private {type = private} @multi_block.privatizer : f32 init {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
%0 = llvm.mlir.constant(1 : i32) : i32
%alloca1 = llvm.alloca %0 x !llvm.struct<(i64)> {alignment = 8 : i64} : (i32) -> !llvm.ptr

%1 = llvm.load %arg0 : !llvm.ptr -> f32

%c1 = llvm.mlir.constant(1 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%cond1 = llvm.icmp "eq" %c1, %c2 : i32
llvm.cond_br %cond1, ^bb1, ^bb2

^bb1:
llvm.call @foo1() : () -> ()
llvm.br ^bb3

^bb2:
llvm.call @foo2() : () -> ()
llvm.br ^bb3

^bb3:
llvm.store %1, %arg1 : f32, !llvm.ptr

omp.yield(%arg1 : !llvm.ptr)
}

omp.private {type = private} @multi_block.privatizer2 : f32 init {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
%0 = llvm.mlir.constant(1 : i32) : i32
%alloca1 = llvm.alloca %0 x !llvm.struct<(ptr)> {alignment = 8 : i64} : (i32) -> !llvm.ptr

%1 = llvm.load %arg0 : !llvm.ptr -> f32

%c1 = llvm.mlir.constant(1 : i32) : i32
%c2 = llvm.mlir.constant(2 : i32) : i32
%cond1 = llvm.icmp "eq" %c1, %c2 : i32
llvm.cond_br %cond1, ^bb1, ^bb2

^bb1:
llvm.call @foo3() : () -> ()
llvm.br ^bb3

^bb2:
llvm.call @foo4() : () -> ()
llvm.br ^bb3

^bb3:
llvm.store %1, %arg1 : f32, !llvm.ptr

omp.yield(%arg1 : !llvm.ptr)
}

llvm.func @parallel_op_private_multi_block(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
omp.parallel private(@multi_block.privatizer %arg0 -> %arg2,
@multi_block.privatizer2 %arg1 -> %arg3 : !llvm.ptr, !llvm.ptr) {
%0 = llvm.load %arg2 : !llvm.ptr -> f32
%1 = llvm.load %arg3 : !llvm.ptr -> f32
omp.terminator
}
llvm.return
}

// CHECK: define internal void @parallel_op_private_multi_block..omp_par({{.*}}) {{.*}} {
// CHECK: omp.par.entry:
// Varify that both allocas were hoisted to the parallel region's entry block.
// CHECK: %{{.*}} = alloca { i64 }, align 8
// CHECK-NEXT: %{{.*}} = alloca { ptr }, align 8
// CHECK-NEXT: br label %omp.region.after_alloca
// CHECK: omp.region.after_alloca:
// CHECK: }