diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 0e43480e82926..b61138ad4678b 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -603,10 +603,16 @@ static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
   // Allocate the new value on the stack.
   Value allocaOp;
   {
-    // Since this is a static alloca, we can put it directly in the entry block,
-    // so they can be absorbed into the prologue/epilogue at code generation.
+    // Walk up from the call site to find the innermost AutomaticAllocationScope
+    // (e.g. an llvm.func or scf.forall). Placing the alloca at the entry block
+    // of that scope keeps it inside parallel regions rather than hoisting it
+    // out, while still landing at the function entry block for the common
+    // non-parallel case.
     OpBuilder::InsertionGuard insertionGuard(builder);
-    Block *entryBlock = &(*argument.getParentRegion()->begin());
+    Operation *scope = builder.getInsertionBlock()->getParentOp();
+    if (!scope->mightHaveTrait<OpTrait::AutomaticAllocationScope>())
+      scope = scope->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+    Block *entryBlock = &scope->getRegion(0).front();
     builder.setInsertionPointToStart(entryBlock);
     Value one = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(),
                                          builder.getI64IntegerAttr(1));
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
index 70ce7ca20986b..cc3600af431ea 100644
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -570,6 +570,32 @@ llvm.func @test_byval_global() {
 
 // -----
 
+// Check that inlining does not hoist byval allocas out of automatic allocation
+// scopes, such as parallel forall regions. Each parallel iteration must have
+// its own private copy of the byval argument.
+
+llvm.func @byval_in_parallel(%ptr : !llvm.ptr { llvm.byval = f32 }) {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_in_parallel_region
+// CHECK-SAME: %[[PTR:[a-zA-Z0-9_]+]]: !llvm.ptr
+llvm.func @test_byval_in_parallel_region(%ptr : !llvm.ptr) {
+  %c0 = arith.constant 0 : index
+  // Verify the alloca is not hoisted out of the allocation scope.
+  // CHECK-NOT: llvm.alloca
+  // CHECK: test.alloca_scope_region
+  test.alloca_scope_region {
+    // CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x f32
+    // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[PTR]]
+    llvm.call @byval_in_parallel(%ptr) : (!llvm.ptr) -> ()
+    test.region_yield %c0 : index
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) {
   llvm.return %ptr : !llvm.ptr
 }