Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 25 additions & 50 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2850,15 +2850,37 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!callOp || !funcOp)
return false;
return isLegalToInlineCallAttributes(callOp) &&
isLegalToInlineFuncAttributes(funcOp);
// TODO: Handle argument and result attributes;
if (funcOp.getArgAttrs() || funcOp.getResAttrs())
return false;
// TODO: Handle exceptions.
if (funcOp.getPersonality())
return false;
if (funcOp.getPassthrough()) {
// TODO: Used attributes should not be passthrough.
DenseSet<StringAttr> disallowed(
{StringAttr::get(funcOp->getContext(), "noduplicate"),
StringAttr::get(funcOp->getContext(), "noinline"),
StringAttr::get(funcOp->getContext(), "optnone"),
StringAttr::get(funcOp->getContext(), "presplitcoroutine"),
StringAttr::get(funcOp->getContext(), "returns_twice"),
StringAttr::get(funcOp->getContext(), "strictfp")});
if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) {
auto stringAttr = dyn_cast<StringAttr>(attr);
if (!stringAttr)
return false;
return disallowed.contains(stringAttr);
}))
return false;
}
return true;
}

bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
return true;
}

/// Conservative allowlist-based inlining of operations supported so far.
/// Conservative allowlist of operations supported so far.
bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
if (isPure(op))
return true;
Expand Down Expand Up @@ -2919,53 +2941,6 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
// which newly inlined block was previously the entry block of the callee.
moveConstantAllocasToEntryBlock(inlinedBlocks);
}

private:
/// Returns true if all attributes of `callOp` are handled during inlining.
[[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
return all_of(callOp.getAttributeNames(), [&](StringRef attrName) {
return llvm::StringSwitch<bool>(attrName)
// TODO: Propagate and update branch weights.
.Case("branch_weights", !callOp.getBranchWeights())
.Case("callee", true)
.Case("fastmathFlags", true)
.Default(false);
});
}

/// Returns true if all attributes of `funcOp` are handled during inlining.
[[nodiscard]] static bool
isLegalToInlineFuncAttributes(LLVM::LLVMFuncOp funcOp) {
return all_of(funcOp.getAttributeNames(), [&](StringRef attrName) {
return llvm::StringSwitch<bool>(attrName)
.Case("CConv", true)
.Case("arg_attrs", ([&]() {
if (!funcOp.getArgAttrs())
return true;
return llvm::all_of(funcOp.getArgAttrs().value(),
[&](Attribute) {
// TODO: Handle argument attributes.
return false;
});
})())
.Case("dso_local", true)
.Case("function_entry_count", true)
.Case("function_type", true)
// TODO: Once the garbage collector attribute is supported on
// LLVM::CallOp, make sure that the garbage collector matches.
.Case("garbageCollector", !funcOp.getGarbageCollector())
.Case("linkage", true)
.Case("memory", true)
.Case("passthrough", !funcOp.getPassthrough())
// Exception handling is not yet supported, so bail out if the
// personality is set.
.Case("personality", !funcOp.getPersonality())
// TODO: Handle result attributes.
.Case("res_attrs", !funcOp.getResAttrs())
.Case("sym_name", true)
.Default(false);
});
}
};
} // end anonymous namespace

Expand Down
60 changes: 28 additions & 32 deletions mlir/test/Dialect/LLVMIR/inlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count =
// CHECK-NEXT: llvm.return %[[CST]]
llvm.func @caller() -> (i32) {
// Include all call attributes that don't prevent inlining.
%0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf> } : () -> (i32)
%0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf>, branch_weights = dense<42> : vector<1xi32> } : () -> (i32)
llvm.return %0 : i32
}

Expand Down Expand Up @@ -147,32 +147,42 @@ llvm.func @caller() -> (i32) {

// -----

llvm.func @callee() -> (i32) attributes { passthrough = ["foo"] } {
%0 = llvm.mlir.constant(42 : i32) : i32
llvm.return %0 : i32
llvm.func @callee() attributes { passthrough = ["foo", "bar"] } {
llvm.return
}

// CHECK-LABEL: llvm.func @caller
// CHECK-NEXT: llvm.call @callee
// CHECK-NEXT: return
llvm.func @caller() -> (i32) {
%0 = llvm.call @callee() : () -> (i32)
llvm.return %0 : i32
// CHECK-NEXT: llvm.return
llvm.func @caller() {
llvm.call @callee() : () -> ()
llvm.return
}

// -----

llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } {
%0 = llvm.mlir.constant(42 : i32) : i32
llvm.return %0 : i32
}
llvm.func @callee_noinline() attributes { passthrough = ["noinline"] }
llvm.func @callee_optnone() attributes { passthrough = ["optnone"] }
llvm.func @callee_noduplicate() attributes { passthrough = ["noduplicate"] }
llvm.func @callee_presplitcoroutine() attributes { passthrough = ["presplitcoroutine"] }
llvm.func @callee_returns_twice() attributes { passthrough = ["returns_twice"] }
llvm.func @callee_strictfp() attributes { passthrough = ["strictfp"] }

// CHECK-LABEL: llvm.func @caller
// CHECK-NEXT: llvm.call @callee
// CHECK-NEXT: return
llvm.func @caller() -> (i32) {
%0 = llvm.call @callee() : () -> (i32)
llvm.return %0 : i32
// CHECK-NEXT: llvm.call @callee_noinline
// CHECK-NEXT: llvm.call @callee_optnone
// CHECK-NEXT: llvm.call @callee_noduplicate
// CHECK-NEXT: llvm.call @callee_presplitcoroutine
// CHECK-NEXT: llvm.call @callee_returns_twice
// CHECK-NEXT: llvm.call @callee_strictfp
// CHECK-NEXT: llvm.return
llvm.func @caller() {
llvm.call @callee_noinline() : () -> ()
llvm.call @callee_optnone() : () -> ()
llvm.call @callee_noduplicate() : () -> ()
llvm.call @callee_presplitcoroutine() : () -> ()
llvm.call @callee_returns_twice() : () -> ()
llvm.call @callee_strictfp() : () -> ()
llvm.return
}

// -----
Expand All @@ -191,20 +201,6 @@ llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) {

// -----

llvm.func @callee() {
llvm.return
}

// CHECK-LABEL: llvm.func @caller
// CHECK-NEXT: llvm.call @callee
// CHECK-NEXT: llvm.return
llvm.func @caller() {
llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
llvm.return
}

// -----

llvm.func @static_alloca() -> f32 {
%0 = llvm.mlir.constant(4 : i32) : i32
%1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
Expand Down