Skip to content

Commit

Permalink
[flang] Fix the incorrect insertion point for alloca (#65999)
Browse files Browse the repository at this point in the history
While creating a temporary alloca for a box in OpenMp region, the
insertion point should be the OpenMP region block instead of the
function entry block.
  • Loading branch information
erjin committed Sep 13, 2023
1 parent 5a48a82 commit f3fdc96
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
27 changes: 15 additions & 12 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,25 +333,28 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
return rewriter.create<mlir::LLVM::GEPOp>(loc, ty, base, cv);
}

// Find the LLVMFuncOp in whose entry block the alloca should be inserted.
// The order to find the LLVMFuncOp is as follows:
// 1. The parent operation of the current block if it is a LLVMFuncOp.
// 2. The first ancestor that is a LLVMFuncOp.
mlir::LLVM::LLVMFuncOp
getFuncForAllocaInsert(mlir::ConversionPatternRewriter &rewriter) const {
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
return mlir::isa<mlir::LLVM::LLVMFuncOp>(parentOp)
? mlir::cast<mlir::LLVM::LLVMFuncOp>(parentOp)
: parentOp->getParentOfType<mlir::LLVM::LLVMFuncOp>();
// Find the Block in which the alloca should be inserted.
// The order to recursively find the proper block:
// 1. An OpenMP Op that will be outlined.
// 2. A LLVMFuncOp
// 3. The first ancestor that is an OpenMP Op or a LLVMFuncOp
static mlir::Block *getBlockForAllocaInsert(mlir::Operation *op) {
if (auto iface =
mlir::dyn_cast<mlir::omp::OutlineableOpenMPOpInterface>(op))
return iface.getAllocaBlock();
if (auto llvmFuncOp = mlir::dyn_cast<mlir::LLVM::LLVMFuncOp>(op))
return &llvmFuncOp.front();
return getBlockForAllocaInsert(op->getParentOp());
}

// Generate an alloca of size 1 and type \p toTy.
mlir::LLVM::AllocaOp
genAllocaWithType(mlir::Location loc, mlir::Type toTy, unsigned alignment,
mlir::ConversionPatternRewriter &rewriter) const {
auto thisPt = rewriter.saveInsertionPoint();
mlir::LLVM::LLVMFuncOp func = getFuncForAllocaInsert(rewriter);
rewriter.setInsertionPointToStart(&func.front());
mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp();
mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp);
rewriter.setInsertionPointToStart(insertBlock);
auto size = genI32Constant(loc, rewriter, 1);
auto al = rewriter.create<mlir::LLVM::AllocaOp>(loc, toTy, size, alignment);
rewriter.restoreInsertionPoint(thisPt);
Expand Down
40 changes: 39 additions & 1 deletion flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
Expand Up @@ -654,4 +654,42 @@ func.func @_QPs(%arg0: !fir.ref<!fir.complex<4>> {fir.bindc_name = "x"}) {
%0 = fir.alloca !fir.complex<4> {bindc_name = "v", uniq_name = "_QFsEv"}
omp.atomic.read %0 = %arg0 : !fir.ref<!fir.complex<4>>, !fir.complex<4>
return
}
}

// -----
// Test if llvm.alloca is properly inserted in the omp section

//CHECK: %[[CONST:.*]] = llvm.mlir.constant(1 : i64) : i64
//CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CONST]] x !llvm.struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)> {bindc_name = "iattr", in_type = !fir.box<!fir.ptr<i32>>, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFEiattr"} : (i64) -> !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
//CHECK: omp.parallel {
//CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
//CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %[[CONST_1:.*]] x !llvm.struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
//CHECK: %[[LOAD:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
//CHECK: llvm.store %[[LOAD]], %[[ALLOCA_1]] : !llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>
//CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA_1]][0, 0] : (!llvm.ptr<struct<(ptr<i32>, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr<ptr<i32>>
//CHECK: %[[LOAD_2:.*]] = llvm.load %[[GEP]] : !llvm.ptr<ptr<i32>>
//CHECK: omp.terminator
//CHECK: }

func.func @_QQmain() attributes {fir.bindc_name = "mn"} {
%0 = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "iattr", uniq_name = "_QFEiattr"}
%1 = fir.zero_bits !fir.ptr<i32>
%2 = fir.embox %1 : (!fir.ptr<i32>) -> !fir.box<!fir.ptr<i32>>
fir.store %2 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
%3 = fir.address_of(@_QFEx) : !fir.ref<i32>
%4 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"}
%5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
omp.parallel {
%6 = fir.load %4 : !fir.ref<i32>
%7 = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
%8 = fir.box_addr %7 : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
fir.store %6 to %8 : !fir.ptr<i32>
omp.terminator
}
return
}
fir.global internal @_QFEx target : i32 {
%0 = fir.zero_bits i32
fir.has_value %0 : i32
}

0 comments on commit f3fdc96

Please sign in to comment.