Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,33 @@ struct LLVMOpAsmDialectInterface : public OpAsmDialectInterface {
// DialectInlinerInterface
//===----------------------------------------------------------------------===//

/// Hoists every `llvm.alloca` with a constant array size out of the first
/// inlined block (the callee's former entry block) and into the entry block of
/// the region that now contains it.
///
/// For each hoisted alloca a fresh `llvm.mlir.constant` carrying the matched
/// size attribute is created in that entry block, the alloca is placed right
/// after it, and the alloca's array-size operand is rewired to the new
/// constant.
static void moveConstantAllocasToEntryBlock(
    iterator_range<Region::iterator> inlinedBlocks) {
  Block &inlinedEntry = *inlinedBlocks.begin();
  Block &regionEntry = inlinedEntry.getParent()->front();
  // The first inlined block already is the region's entry block, so there is
  // nothing to relocate.
  if (&inlinedEntry == &regionEntry)
    return;

  // Only operations directly in the former callee entry block are considered.
  // Nested regions are deliberately left alone: their contents may run
  // conditionally or carry semantics this code knows nothing about.
  // Candidates are gathered first and relocated in a second pass.
  SmallVector<std::pair<LLVM::AllocaOp, IntegerAttr>> hoistable;
  for (LLVM::AllocaOp candidate : inlinedEntry.getOps<LLVM::AllocaOp>()) {
    IntegerAttr constantSize;
    if (!matchPattern(candidate.getArraySize(), m_Constant(&constantSize)))
      continue;
    hoistable.emplace_back(candidate, constantSize);
  }

  // New constants are created at the very start of the region's entry block.
  OpBuilder builder(&regionEntry, regionEntry.begin());
  for (auto &entry : hoistable) {
    LLVM::AllocaOp hoisted = entry.first;
    auto sizeConstant = builder.create<LLVM::ConstantOp>(
        hoisted->getLoc(), hoisted.getArraySize().getType(), entry.second);
    hoisted->moveAfter(sizeConstant);
    hoisted.getArraySizeMutable().assign(sizeConstant.getResult());
  }
}

namespace {
struct LLVMInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
Expand Down Expand Up @@ -2885,7 +2912,7 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
return false;
return true;
})
.Case<LLVM::CallOp>([](auto) { return true; })
.Case<LLVM::CallOp, LLVM::AllocaOp>([](auto) { return true; })
.Default([](auto) { return false; });
}

Expand Down Expand Up @@ -2918,6 +2945,17 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
dst.replaceAllUsesWith(src);
}

/// Post-processes the blocks produced by inlining a call by handing them to
/// `moveConstantAllocasToEntryBlock`. `inlinedBlocks` is forwarded unchanged;
/// `call` is not used by this override.
void processInlinedCallBlocks(
Operation *call,
iterator_range<Region::iterator> inlinedBlocks) const override {
// Alloca operations with a constant size that were in the entry block of
// the callee should be moved to the entry block of the caller, as this will
// fold into prologue/epilogue code during code generation.
// This is not implemented as a standalone pattern because we need to know
// which newly inlined block was previously the entry block of the callee.
moveConstantAllocasToEntryBlock(inlinedBlocks);
}

private:
/// Returns true if all attributes of `callOp` are handled during inlining.
[[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
Expand Down
92 changes: 82 additions & 10 deletions mlir/test/Dialect/LLVMIR/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ func.func @test_inline(%ptr : !llvm.ptr) -> i32 {

// -----

func.func @inner_func_not_inlinable() -> !llvm.ptr<f64> {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.alloca %0 x f64 : (i32) -> !llvm.ptr<f64>
return %1 : !llvm.ptr<f64>
func.func @inner_func_not_inlinable() -> i32 {
%0 = llvm.inline_asm has_side_effects "foo", "bar" : () -> i32
return %0 : i32
}

// CHECK-LABEL: func.func @test_not_inline() -> !llvm.ptr<f64> {
// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> !llvm.ptr<f64>
// CHECK-NEXT: return %[[RES]] : !llvm.ptr<f64>
func.func @test_not_inline() -> !llvm.ptr<f64> {
%0 = call @inner_func_not_inlinable() : () -> !llvm.ptr<f64>
return %0 : !llvm.ptr<f64>
// CHECK-LABEL: func.func @test_not_inline() -> i32 {
// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> i32
// CHECK-NEXT: return %[[RES]] : i32
func.func @test_not_inline() -> i32 {
%0 = call @inner_func_not_inlinable() : () -> i32
return %0 : i32
}

// -----
Expand Down Expand Up @@ -203,3 +202,76 @@ llvm.func @caller() {
llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
llvm.return
}

// -----

// Single-block callee whose alloca takes its array size directly from an
// llvm.mlir.constant.
llvm.func @static_alloca() -> f32 {
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> f32
llvm.return %2 : f32
}

// Single-block callee whose alloca array size is the result of an llvm.add on
// the function argument rather than a constant.
llvm.func @dynamic_alloca(%size : i32) -> f32 {
%0 = llvm.add %size, %size : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> f32
llvm.return %2 : f32
}

// Caller that invokes @static_alloca from ^bb1 and @dynamic_alloca from ^bb2,
// so both call sites sit outside the caller's entry block.
// CHECK-LABEL: llvm.func @test_inline
llvm.func @test_inline(%cond : i1, %size : i32) -> f32 {
// Check that the static alloca was moved to the entry block after inlining
// with its size defined by a constant.
// CHECK-NOT: ^{{.+}}:
// CHECK-NEXT: llvm.mlir.constant
// CHECK-NEXT: llvm.alloca
// CHECK: llvm.cond_br
llvm.cond_br %cond, ^bb1, ^bb2
// CHECK: ^{{.+}}:
^bb1:
// CHECK-NOT: llvm.call @static_alloca
%0 = llvm.call @static_alloca() : () -> f32
// CHECK: llvm.br
llvm.br ^bb3(%0: f32)
// CHECK: ^{{.+}}:
^bb2:
// Check that the dynamic alloca was inlined, but that it was not moved to the
// entry block.
// CHECK: llvm.add
// CHECK-NEXT: llvm.alloca
// CHECK-NOT: llvm.call @dynamic_alloca
%1 = llvm.call @dynamic_alloca(%size) : (i32) -> f32
// CHECK: llvm.br
llvm.br ^bb3(%1: f32)
// CHECK: ^{{.+}}:
^bb3(%arg : f32):
llvm.return %arg : f32
}

// -----

// Callee with constant-sized allocas in ^bb1 and ^bb2 only; its entry block
// holds just the conditional branch, so no alloca lives in the entry block.
llvm.func @static_alloca_not_in_entry(%cond : i1) -> f32 {
llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
llvm.br ^bb3(%1: !llvm.ptr)
^bb2:
%2 = llvm.mlir.constant(8 : i32) : i32
%3 = llvm.alloca %2 x f32 : (i32) -> !llvm.ptr
llvm.br ^bb3(%3: !llvm.ptr)
^bb3(%ptr : !llvm.ptr):
%4 = llvm.load %ptr : !llvm.ptr -> f32
llvm.return %4 : f32
}

// Single-block caller of @static_alloca_not_in_entry. The directives below
// require that no llvm.alloca appears in the output before the first
// llvm.cond_br, and that one does appear after it.
// CHECK-LABEL: llvm.func @test_inline
llvm.func @test_inline(%cond : i1) -> f32 {
// Make sure the alloca was not moved to the entry block.
// CHECK-NOT: llvm.alloca
// CHECK: llvm.cond_br
// CHECK: llvm.alloca
%0 = llvm.call @static_alloca_not_in_entry(%cond) : (i1) -> f32
llvm.return %0 : f32
}