From ae36117ffda3dea173c79337f8ea2865ffa4ad6a Mon Sep 17 00:00:00 2001 From: Johannes de Fine Licht Date: Wed, 15 Feb 2023 15:50:00 +0100 Subject: [PATCH] [MLIR][LLVM] Disallow inlining for selected function attributes. This loosens the requirement of no passthrough function attribute being present to checking for specific attributes that prevent inlining. Since these attributes are no longer strictly passthrough, they should eventually be upgraded to some form of addressable attributes. --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 75 ++++++++-------------- mlir/test/Dialect/LLVMIR/inlining.mlir | 60 ++++++++--------- 2 files changed, 53 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index d5069214c8f0f..16c61f3969d63 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2850,15 +2850,37 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { auto funcOp = dyn_cast(callable); if (!callOp || !funcOp) return false; - return isLegalToInlineCallAttributes(callOp) && - isLegalToInlineFuncAttributes(funcOp); + // TODO: Handle argument and result attributes; + if (funcOp.getArgAttrs() || funcOp.getResAttrs()) + return false; + // TODO: Handle exceptions. + if (funcOp.getPersonality()) + return false; + if (funcOp.getPassthrough()) { + // TODO: Used attributes should not be passthrough. + DenseSet disallowed( + {StringAttr::get(funcOp->getContext(), "noduplicate"), + StringAttr::get(funcOp->getContext(), "noinline"), + StringAttr::get(funcOp->getContext(), "optnone"), + StringAttr::get(funcOp->getContext(), "presplitcoroutine"), + StringAttr::get(funcOp->getContext(), "returns_twice"), + StringAttr::get(funcOp->getContext(), "strictfp")}); + if (llvm::any_of(*funcOp.getPassthrough(), [&](Attribute attr) { + auto stringAttr = dyn_cast(attr); + if (!stringAttr) + return false; + return disallowed.contains(stringAttr); + })) + return false; + } + return true; } bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { return true; } - /// Conservative allowlist-based inlining of operations supported so far. + /// Conservative allowlist of operations supported so far. bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final { if (isPure(op)) return true; @@ -2919,53 +2941,6 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { // which newly inlined block was previously the entry block of the callee. moveConstantAllocasToEntryBlock(inlinedBlocks); } - -private: - /// Returns true if all attributes of `callOp` are handled during inlining. - [[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) { - return all_of(callOp.getAttributeNames(), [&](StringRef attrName) { - return llvm::StringSwitch(attrName) - // TODO: Propagate and update branch weights. - .Case("branch_weights", !callOp.getBranchWeights()) - .Case("callee", true) - .Case("fastmathFlags", true) - .Default(false); - }); - } - - /// Returns true if all attributes of `funcOp` are handled during inlining. - [[nodiscard]] static bool - isLegalToInlineFuncAttributes(LLVM::LLVMFuncOp funcOp) { - return all_of(funcOp.getAttributeNames(), [&](StringRef attrName) { - return llvm::StringSwitch(attrName) - .Case("CConv", true) - .Case("arg_attrs", ([&]() { - if (!funcOp.getArgAttrs()) - return true; - return llvm::all_of(funcOp.getArgAttrs().value(), - [&](Attribute) { - // TODO: Handle argument attributes. - return false; - }); - })()) - .Case("dso_local", true) - .Case("function_entry_count", true) - .Case("function_type", true) - // TODO: Once the garbage collector attribute is supported on - // LLVM::CallOp, make sure that the garbage collector matches. - .Case("garbageCollector", !funcOp.getGarbageCollector()) - .Case("linkage", true) - .Case("memory", true) - .Case("passthrough", !funcOp.getPassthrough()) - // Exception handling is not yet supported, so bail out if the - // personality is set. - .Case("personality", !funcOp.getPersonality()) - // TODO: Handle result attributes. - .Case("res_attrs", !funcOp.getResAttrs()) - .Case("sym_name", true) - .Default(false); - }); - } }; } // end anonymous namespace diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir index 65632434b1349..ab28f4236af97 100644 --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -83,7 +83,7 @@ llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = // CHECK-NEXT: llvm.return %[[CST]] llvm.func @caller() -> (i32) { // Include all call attributes that don't prevent inlining. - %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath } : () -> (i32) + %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath, branch_weights = dense<42> : vector<1xi32> } : () -> (i32) llvm.return %0 : i32 } @@ -147,32 +147,42 @@ llvm.func @caller() -> (i32) { // ----- -llvm.func @callee() -> (i32) attributes { passthrough = ["foo"] } { - %0 = llvm.mlir.constant(42 : i32) : i32 - llvm.return %0 : i32 +llvm.func @callee() attributes { passthrough = ["foo", "bar"] } { + llvm.return } // CHECK-LABEL: llvm.func @caller -// CHECK-NEXT: llvm.call @callee -// CHECK-NEXT: return -llvm.func @caller() -> (i32) { - %0 = llvm.call @callee() : () -> (i32) - llvm.return %0 : i32 +// CHECK-NEXT: llvm.return +llvm.func @caller() { + llvm.call @callee() : () -> () + llvm.return } // ----- -llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } { - %0 = llvm.mlir.constant(42 : i32) : i32 - llvm.return %0 : i32 -} +llvm.func @callee_noinline() attributes { passthrough = ["noinline"] } +llvm.func @callee_optnone() attributes { passthrough = ["optnone"] } +llvm.func @callee_noduplicate() attributes { passthrough = ["noduplicate"] } +llvm.func @callee_presplitcoroutine() attributes { passthrough = ["presplitcoroutine"] } +llvm.func @callee_returns_twice() attributes { passthrough = ["returns_twice"] } +llvm.func @callee_strictfp() attributes { passthrough = ["strictfp"] } // CHECK-LABEL: llvm.func @caller -// CHECK-NEXT: llvm.call @callee -// CHECK-NEXT: return -llvm.func @caller() -> (i32) { - %0 = llvm.call @callee() : () -> (i32) - llvm.return %0 : i32 +// CHECK-NEXT: llvm.call @callee_noinline +// CHECK-NEXT: llvm.call @callee_optnone +// CHECK-NEXT: llvm.call @callee_noduplicate +// CHECK-NEXT: llvm.call @callee_presplitcoroutine +// CHECK-NEXT: llvm.call @callee_returns_twice +// CHECK-NEXT: llvm.call @callee_strictfp +// CHECK-NEXT: llvm.return +llvm.func @caller() { + llvm.call @callee_noinline() : () -> () + llvm.call @callee_optnone() : () -> () + llvm.call @callee_noduplicate() : () -> () + llvm.call @callee_presplitcoroutine() : () -> () + llvm.call @callee_returns_twice() : () -> () + llvm.call @callee_strictfp() : () -> () + llvm.return } // ----- @@ -191,20 +201,6 @@ llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) { // ----- -llvm.func @callee() { - llvm.return -} - -// CHECK-LABEL: llvm.func @caller -// CHECK-NEXT: llvm.call @callee -// CHECK-NEXT: llvm.return -llvm.func @caller() { - llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> () - llvm.return -} - -// ----- - llvm.func @static_alloca() -> f32 { %0 = llvm.mlir.constant(4 : i32) : i32 %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr