diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index e51050b5a594a..0ec7e570395e8 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -333,16 +333,18 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { return rewriter.create(loc, ty, base, cv); } - // Find the LLVMFuncOp in whose entry block the alloca should be inserted. - // The order to find the LLVMFuncOp is as follows: - // 1. The parent operation of the current block if it is a LLVMFuncOp. - // 2. The first ancestor that is a LLVMFuncOp. - mlir::LLVM::LLVMFuncOp - getFuncForAllocaInsert(mlir::ConversionPatternRewriter &rewriter) const { - mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); - return mlir::isa(parentOp) - ? mlir::cast(parentOp) - : parentOp->getParentOfType(); + // Find the Block in which the alloca should be inserted, starting the + // search at \p op and recursing through its parent operations: + // 1. If \p op is an OpenMP op that will be outlined, use its alloca block. + // 2. If \p op is a LLVMFuncOp, use the first block of the function. + // 3. Otherwise, repeat the search on the parent operation of \p op. + static mlir::Block *getBlockForAllocaInsert(mlir::Operation *op) { + if (auto iface = + mlir::dyn_cast(op)) + return iface.getAllocaBlock(); + if (auto llvmFuncOp = mlir::dyn_cast(op)) + return &llvmFuncOp.front(); + return getBlockForAllocaInsert(op->getParentOp()); } // Generate an alloca of size 1 and type \p toTy. 
@@ -350,8 +352,9 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern { genAllocaWithType(mlir::Location loc, mlir::Type toTy, unsigned alignment, mlir::ConversionPatternRewriter &rewriter) const { auto thisPt = rewriter.saveInsertionPoint(); - mlir::LLVM::LLVMFuncOp func = getFuncForAllocaInsert(rewriter); - rewriter.setInsertionPointToStart(&func.front()); + mlir::Operation *parentOp = rewriter.getInsertionBlock()->getParentOp(); + mlir::Block *insertBlock = getBlockForAllocaInsert(parentOp); + rewriter.setInsertionPointToStart(insertBlock); auto size = genI32Constant(loc, rewriter, 1); auto al = rewriter.create(loc, toTy, size, alignment); rewriter.restoreInsertionPoint(thisPt); diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 06fc1d0edbe2e..8134fb8792d76 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -654,4 +654,42 @@ func.func @_QPs(%arg0: !fir.ref> {fir.bindc_name = "x"}) { %0 = fir.alloca !fir.complex<4> {bindc_name = "v", uniq_name = "_QFsEv"} omp.atomic.read %0 = %arg0 : !fir.ref>, !fir.complex<4> return -} +} + +// ----- +// Test that the llvm.alloca created during conversion of the omp.parallel body is inserted inside the omp.parallel region + +//CHECK: %[[CONST:.*]] = llvm.mlir.constant(1 : i64) : i64 +//CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[CONST]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {bindc_name = "iattr", in_type = !fir.box>, operandSegmentSizes = array, uniq_name = "_QFEiattr"} : (i64) -> !llvm.ptr, i64, i32, i8, i8, i8, i8)>> +//CHECK: omp.parallel { +//CHECK: %[[CONST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 +//CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %[[CONST_1:.*]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr, i64, i32, i8, i8, i8, i8)>> +//CHECK: %[[LOAD:.*]] = llvm.load %[[ALLOCA]] : !llvm.ptr, i64, i32, i8, i8, i8, i8)>> +//CHECK: llvm.store %[[LOAD]], %[[ALLOCA_1]] : !llvm.ptr, i64, i32, 
i8, i8, i8, i8)>> +//CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA_1]][0, 0] : (!llvm.ptr, i64, i32, i8, i8, i8, i8)>>) -> !llvm.ptr> +//CHECK: %[[LOAD_2:.*]] = llvm.load %[[GEP]] : !llvm.ptr> +//CHECK: omp.terminator +//CHECK: } + +func.func @_QQmain() attributes {fir.bindc_name = "mn"} { + %0 = fir.alloca !fir.box> {bindc_name = "iattr", uniq_name = "_QFEiattr"} + %1 = fir.zero_bits !fir.ptr + %2 = fir.embox %1 : (!fir.ptr) -> !fir.box> + fir.store %2 to %0 : !fir.ref>> + %3 = fir.address_of(@_QFEx) : !fir.ref + %4 = fir.alloca i32 {bindc_name = "y", uniq_name = "_QFEy"} + %5 = fir.embox %3 : (!fir.ref) -> !fir.box> + fir.store %5 to %0 : !fir.ref>> + omp.parallel { + %6 = fir.load %4 : !fir.ref + %7 = fir.load %0 : !fir.ref>> + %8 = fir.box_addr %7 : (!fir.box>) -> !fir.ptr + fir.store %6 to %8 : !fir.ptr + omp.terminator + } + return +} +fir.global internal @_QFEx target : i32 { + %0 = fir.zero_bits i32 + fir.has_value %0 : i32 +}