Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2848,31 +2848,75 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
// DialectInlinerInterface
//===----------------------------------------------------------------------===//

// Check whether the given alloca is an input to a lifetime intrinsic,
// optionally passing through one or more casts on the way.
static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) {
SmallVector<Operation *, 2> stack(allocaOp->getUsers().begin(),
allocaOp->getUsers().end());
while (!stack.empty()) {
Operation *op = stack.pop_back_val();
if (isa<LLVM::LifetimeStartOp, LLVM::LifetimeEndOp>(op))
return true;
if (isa<LLVM::BitcastOp>(op))
stack.append(op->getUsers().begin(), op->getUsers().end());
}
return false;
}

/// Move all alloca operations with a constant size in the former entry block of
/// the newly inlined callee into the entry block of the caller.
/// the newly inlined callee into the entry block of the caller, and insert
/// lifetime intrinsics that limit their scope to the inlined blocks.
static void moveConstantAllocasToEntryBlock(
iterator_range<Region::iterator> inlinedBlocks) {
Block *calleeEntryBlock = &(*inlinedBlocks.begin());
Block *callerEntryBlock = &(*calleeEntryBlock->getParent()->begin());
if (calleeEntryBlock == callerEntryBlock)
// Nothing to do.
return;
SmallVector<std::pair<LLVM::AllocaOp, IntegerAttr>> allocasToMove;
SmallVector<std::tuple<LLVM::AllocaOp, IntegerAttr, bool>> allocasToMove;
bool shouldInsertLifetimes = false;
// Conservatively only move alloca operations that are part of the entry block
// and do not inspect nested regions, since they may execute conditionally or
// have other unknown semantics.
for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) {
IntegerAttr arraySize;
if (matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize)))
allocasToMove.emplace_back(allocaOp, arraySize);
if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize)))
continue;
bool shouldInsertLifetime =
arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp);
shouldInsertLifetimes |= shouldInsertLifetime;
allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime);
}
OpBuilder builder(callerEntryBlock, callerEntryBlock->begin());
for (auto &[allocaOp, arraySize] : allocasToMove) {
for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
auto newConstant = builder.create<LLVM::ConstantOp>(
allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
// Insert a lifetime start intrinsic where the alloca was before moving it.
if (shouldInsertLifetime) {
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPoint(allocaOp);
builder.create<LLVM::LifetimeStartOp>(
allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
allocaOp->moveAfter(newConstant);
allocaOp.getArraySizeMutable().assign(newConstant.getResult());
}
if (!shouldInsertLifetimes)
return;
// Insert a lifetime end intrinsic before each return in the callee function.
for (Block &block : inlinedBlocks) {
if (!block.back().hasTrait<OpTrait::ReturnLike>())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there should be block->getTerminator() which is a bit more idiomatic here. Also can we have anything else than a LLVM_ReturnOp? If not I would prefer:

if (isa<LLVM::ReturnOp>(block->getTerminator()))

Copy link
Collaborator Author

@definelicht definelicht Jan 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought getTerminator() would crash, because I saw this in some other scenarios (due to terminators not having been handled yet at this point), but it seems to work.

Regarding LLVM::ReturnOp, that seems to be making it less robust for no reason? Why change it so it breaks if there's a different flavor of return here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I am fine with hasTrait. It is not very commonly used outside of the core infra though...

continue;
builder.setInsertionPoint(&block.back());
for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) {
if (!shouldInsertLifetime)
continue;
builder.create<LLVM::LifetimeEndOp>(
allocaOp.getLoc(), arraySize.getValue().getLimitedValue(),
allocaOp.getResult());
}
}
}

namespace {
Expand Down Expand Up @@ -2912,7 +2956,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
return false;
return true;
})
.Case<LLVM::CallOp, LLVM::AllocaOp>([](auto) { return true; })
.Case<LLVM::CallOp, LLVM::AllocaOp, LLVM::LifetimeStartOp,
LLVM::LifetimeEndOp>([](auto) { return true; })
.Default([](auto) { return false; });
}

Expand Down
78 changes: 78 additions & 0 deletions mlir/test/Dialect/LLVMIR/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,9 @@ llvm.func @test_inline(%cond : i1, %size : i32) -> f32 {
// CHECK: ^{{.+}}:
^bb1:
// CHECK-NOT: llvm.call @static_alloca
// CHECK: llvm.intr.lifetime.start
%0 = llvm.call @static_alloca() : () -> f32
// CHECK: llvm.intr.lifetime.end
// CHECK: llvm.br
llvm.br ^bb3(%0: f32)
// CHECK: ^{{.+}}:
Expand Down Expand Up @@ -275,3 +277,79 @@ llvm.func @test_inline(%cond : i1) -> f32 {
%0 = llvm.call @static_alloca_not_in_entry(%cond) : (i1) -> f32
llvm.return %0 : f32
}

// -----

llvm.func @static_alloca(%cond: i1) -> f32 {
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
%2 = llvm.load %1 : !llvm.ptr -> f32
llvm.return %2 : f32
^bb2:
%3 = llvm.mlir.constant(3.14192 : f32) : f32
llvm.return %3 : f32
}

// CHECK-LABEL: llvm.func @test_inline
llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
// CHECK-NOT: llvm.cond_br
// CHECK: %[[PTR:.+]] = llvm.alloca
// CHECK: llvm.cond_br %{{.+}}, ^[[BB1:.+]], ^{{.+}}
llvm.cond_br %cond0, ^bb1, ^bb2
// CHECK: ^[[BB1]]
^bb1:
// Make sure the lifetime begin intrinsic has been inserted where the call
// used to be, even though the alloca has been moved to the entry block.
// CHECK-NEXT: llvm.intr.lifetime.start 4, %[[PTR]]
%0 = llvm.call @static_alloca(%cond1) : (i1) -> f32
// CHECK: llvm.cond_br %{{.+}}, ^[[BB2:.+]], ^[[BB3:.+]]
llvm.br ^bb3(%0: f32)
// Make sure the lifetime end intrinsic has been inserted at both former
// return sites of the callee.
// CHECK: ^[[BB2]]:
// CHECK-NEXT: llvm.load
// CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
// CHECK: ^[[BB3]]:
// CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
^bb2:
llvm.br ^bb3(%funcArg: f32)
^bb3(%blockArg: f32):
llvm.return %blockArg : f32
}

// -----

llvm.func @alloca_with_lifetime(%cond: i1) -> f32 {
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
llvm.intr.lifetime.start 4, %1 : !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> f32
llvm.intr.lifetime.end 4, %1 : !llvm.ptr
%3 = llvm.fadd %2, %2 : f32
llvm.return %3 : f32
}

// CHECK-LABEL: llvm.func @test_inline
llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 {
// CHECK-NOT: llvm.cond_br
// CHECK: %[[PTR:.+]] = llvm.alloca
// CHECK: llvm.cond_br %{{.+}}, ^[[BB1:.+]], ^{{.+}}
llvm.cond_br %cond0, ^bb1, ^bb2
// CHECK: ^[[BB1]]
^bb1:
// Make sure the original lifetime intrinsic has been preserved, rather than
// inserting a new one with a larger scope.
// CHECK: llvm.intr.lifetime.start 4, %[[PTR]]
// CHECK-NEXT: llvm.load %[[PTR]]
// CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]]
// CHECK: llvm.fadd
// CHECK-NOT: llvm.intr.lifetime.end
%0 = llvm.call @alloca_with_lifetime(%cond1) : (i1) -> f32
llvm.br ^bb3(%0: f32)
^bb2:
llvm.br ^bb3(%funcArg: f32)
^bb3(%blockArg: f32):
llvm.return %blockArg : f32
}