diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index 3ef4703fb41d6..a4847288459b4 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -780,13 +780,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { for (auto fn : mod.getOps()) { if (targetCPUAttr) - fn->setAttr("target_cpu", targetCPUAttr); + fn->setAttr("llvm.target_cpu", targetCPUAttr); if (tuneCPUAttr) - fn->setAttr("tune_cpu", tuneCPUAttr); + fn->setAttr("llvm.tune_cpu", tuneCPUAttr); if (targetFeaturesAttr) - fn->setAttr("target_features", targetFeaturesAttr); + fn->setAttr("llvm.target_features", targetFeaturesAttr); convertSignature(fn); } diff --git a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp index 4655ed6ed0d40..857491fc90cc3 100644 --- a/flang/lib/Optimizer/Transforms/FunctionAttr.cpp +++ b/flang/lib/Optimizer/Transforms/FunctionAttr.cpp @@ -15,6 +15,8 @@ #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "llvm/ADT/Twine.h" +#include namespace fir { #define GEN_PASS_DEF_FUNCTIONATTR @@ -25,6 +27,14 @@ namespace fir { namespace { +/// Names of LLVM dialect function properties on `func.func` must use the +/// `llvm.` prefix so convert-func-to-llvm can recognize them and lower them +/// into `llvm.func` properties (bare ODS names are ignored as legacy spellings) +static mlir::StringAttr getLlvmFuncPropertyAttrName(mlir::MLIRContext *ctx, + mlir::StringAttr baseName) { + return mlir::StringAttr::get(ctx, llvm::Twine("llvm.") + baseName.getValue()); +} + class FunctionAttrPass : public fir::impl::FunctionAttrBase { public: FunctionAttrPass(const fir::FunctionAttrOptions &options) : Base{options} {} @@ -73,31 +83,45 @@ void FunctionAttrPass::runOnOperation() { } mlir::MLIRContext *context = &getContext(); - if (framePointerKind != mlir::LLVM::framePointerKind::FramePointerKind::None) - func->setAttr("frame_pointer", mlir::LLVM::FramePointerKindAttr::get( - context, framePointerKind)); - auto llvmFuncOpName = mlir::OperationName(mlir::LLVM::LLVMFuncOp::getOperationName(), context); + + if (framePointerKind != mlir::LLVM::framePointerKind::FramePointerKind::None) + func->setAttr( + getLlvmFuncPropertyAttrName( + context, + mlir::LLVM::LLVMFuncOp::getFramePointerAttrName(llvmFuncOpName)), + mlir::LLVM::FramePointerKindAttr::get(context, framePointerKind)); + if (!instrumentFunctionEntry.empty()) - func->setAttr(mlir::LLVM::LLVMFuncOp::getInstrumentFunctionEntryAttrName( - llvmFuncOpName), - mlir::StringAttr::get(context, instrumentFunctionEntry)); + func->setAttr( + getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getInstrumentFunctionEntryAttrName( + llvmFuncOpName)), + mlir::StringAttr::get(context, instrumentFunctionEntry)); if (!instrumentFunctionExit.empty()) - func->setAttr(mlir::LLVM::LLVMFuncOp::getInstrumentFunctionExitAttrName( - llvmFuncOpName), - mlir::StringAttr::get(context, instrumentFunctionExit)); + func->setAttr( + getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getInstrumentFunctionExitAttrName( + llvmFuncOpName)), + mlir::StringAttr::get(context, instrumentFunctionExit)); if (noSignedZerosFPMath) func->setAttr( - mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName(llvmFuncOpName), + getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getNoSignedZerosFpMathAttrName( + llvmFuncOpName)), mlir::BoolAttr::get(context, true)); if (!reciprocals.empty()) func->setAttr( - mlir::LLVM::LLVMFuncOp::getReciprocalEstimatesAttrName(llvmFuncOpName), + getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getReciprocalEstimatesAttrName( + llvmFuncOpName)), mlir::StringAttr::get(context, reciprocals)); if (!preferVectorWidth.empty()) func->setAttr( - mlir::LLVM::LLVMFuncOp::getPreferVectorWidthAttrName(llvmFuncOpName), + getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getPreferVectorWidthAttrName( + llvmFuncOpName)), mlir::StringAttr::get(context, preferVectorWidth)); LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); diff --git a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp index d0e83effbbc45..740fc9e85cb85 100644 --- a/flang/lib/Optimizer/Transforms/VScaleAttr.cpp +++ b/flang/lib/Optimizer/Transforms/VScaleAttr.cpp @@ -26,18 +26,20 @@ #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include -#include +#include namespace fir { #define GEN_PASS_DEF_VSCALEATTR @@ -48,6 +50,13 @@ namespace fir { namespace { +/// See FunctionAttr.cpp: `llvm.func` properties on `func.func` need the `llvm.` +/// prefix for convert-func-to-llvm. +static mlir::StringAttr getLlvmFuncPropertyAttrName(mlir::MLIRContext *ctx, + mlir::StringAttr baseName) { + return mlir::StringAttr::get(ctx, llvm::Twine("llvm.") + baseName.getValue()); +} + class VScaleAttrPass : public fir::impl::VScaleAttrBase { public: VScaleAttrPass(const fir::VScaleAttrOptions &options) { @@ -80,11 +89,15 @@ void VScaleAttrPass::runOnOperation() { return signalPassFailure(); } - auto context = &getContext(); + mlir::MLIRContext *context = &getContext(); + auto llvmFuncOpName = + mlir::OperationName(mlir::LLVM::LLVMFuncOp::getOperationName(), context); auto intTy = mlir::IntegerType::get(context, 32); - func->setAttr("vscale_range", + func->setAttr(getLlvmFuncPropertyAttrName( + context, mlir::LLVM::LLVMFuncOp::getVscaleRangeAttrName( + llvmFuncOpName)), mlir::LLVM::VScaleRangeAttr::get( context, mlir::IntegerAttr::get(intTy, vscaleMin), mlir::IntegerAttr::get(intTy, vscaleMax))); diff --git a/flang/test/Transforms/vscale-attr.fir b/flang/test/Transforms/vscale-attr.fir index 52b6bbba7c9c9..146b1af7c489f 100644 --- a/flang/test/Transforms/vscale-attr.fir +++ b/flang/test/Transforms/vscale-attr.fir @@ -9,11 +9,11 @@ // RUN: not fir-opt --vscale-attr="vscale-min=16 vscale-max=8" %s 2>&1 | FileCheck %s --check-prefix=VSCALE-MIN-GREATER -// CHECK-DEFAULT: attributes {vscale_range = #llvm.vscale_range} -// CHECK-MIN: attributes {vscale_range = #llvm.vscale_range} -// CHECK-MAX: attributes {vscale_range = #llvm.vscale_range} -// CHECK-BOTH: attributes {vscale_range = #llvm.vscale_range} -// CHECK-EQUAL: attributes {vscale_range = #llvm.vscale_range} +// CHECK-DEFAULT: attributes {llvm.vscale_range = #llvm.vscale_range} +// CHECK-MIN: attributes {llvm.vscale_range = #llvm.vscale_range} +// CHECK-MAX: attributes {llvm.vscale_range = #llvm.vscale_range} +// CHECK-BOTH: attributes {llvm.vscale_range = #llvm.vscale_range} +// CHECK-EQUAL: attributes {llvm.vscale_range = #llvm.vscale_range} // VSCALE-MIN-0: VScaleAttr: vscaleMin has to be a power-of-two greater than 0 // VSCALE-MIN-NO-PO2: VScaleAttr: vscaleMin has to be a power-of-two greater than 0 diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 644d27938177b..c09faa86528eb 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -502,6 +502,41 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> { 1 value is returned, packed into an LLVM IR struct type. Function calls and returns are updated accordingly. Block argument types are updated to use LLVM IR types. + + #### Function discardable attributes + + Discardable attributes on `func.func` are lowered as follows. + + - **LLVM `llvm.func` properties.** Each inherent attribute defined on + `llvm.func` (ODS properties such as `target_cpu`, `linkage`, + `vscale_range`, `passthrough`, and so on) must be attached to `func.func` + using the `llvm.` prefix (for example `llvm.target_cpu`, + `llvm.vscale_range`). The pass strips that prefix, validates the attribute + value the same way as for `llvm.func`, and fills the corresponding fields + on the generated `llvm.func`. Values that fail validation make conversion + fail. + + - **Unprefixed legacy names.** A discardable attribute whose name equals + the bare ODS property name (without `llvm.`) is **not** forwarded: it is + dropped. Only the explicit `llvm.*` spelling is lowered into `llvm.func` + properties so that front ends cannot accidentally rely on ambiguous + short names. + + - **Opaque pass-through.** Any other discardable attribute is copied onto + the `llvm.func` unchanged, so arbitrary metadata can survive the + conversion. That includes names that start with `llvm.` but are **not** + inherent `llvm.func` properties (for example dialect-specific markers): they + are not interpreted as properties and are forwarded as discardable + attributes on the result. + + - **`func.varargs`.** This attribute is interpreted when converting the + function type (variadic LLVM signature). It is not an LLVM IR dialect + property and is handled separately from the `llvm.*` property mapping + above. + + - **`llvm.readnone`.** If present, the pass also sets `memory_effects` on + the `llvm.func` to read-none semantics, in addition to any other attribute + handling. }]; let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 75c47f087f78e..ccb933230b65a 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2066,7 +2066,12 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ CArg<"SymbolRefAttr", "{}">:$comdat, CArg<"ArrayRef", "{}">:$attrs, CArg<"ArrayRef", "{}">:$argAttrs, - CArg<"std::optional", "{}">:$functionEntryCount)> + CArg<"std::optional", "{}">:$functionEntryCount)>, + OpBuilder<(ins "const Properties &":$properties, + CArg<"ArrayRef", "{}">:$discardableAttributes), [{ + $_state.addRegion(); + $_state.getOrAddProperties() = properties; + $_state.addAttributes(discardableAttributes);}]> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 6546e74514c74..88abc4400c9b7 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -35,6 +35,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Type.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include @@ -47,6 +48,7 @@ namespace mlir { using namespace mlir; #define PASS_NAME "convert-func-to-llvm" +#define DEBUG_TYPE PASS_NAME static constexpr StringRef varargsAttrName = "func.varargs"; static constexpr StringRef linkageAttrName = "llvm.linkage"; @@ -59,19 +61,81 @@ static bool shouldUseBarePtrCallConv(Operation *op, typeConverter->getOptions().useBarePtrCallConv; } +static bool isDiscardableAttr(StringRef name) { + return name == linkageAttrName || name == varargsAttrName || + name == LLVM::LLVMDialect::getReadnoneAttrName(); +} + /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. static void filterFuncAttributes(FunctionOpInterface func, SmallVectorImpl &result) { for (const NamedAttribute &attr : func->getDiscardableAttrs()) { - if (attr.getName() == linkageAttrName || - attr.getName() == varargsAttrName || - attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName()) + if (isDiscardableAttr(attr.getName().strref())) continue; result.push_back(attr); } } +/// Add custom lowered funcOp to llvm.func attributes here. +struct LoweredFuncAttrs { + LLVM::LLVMFuncOp::Properties properties; + NamedAttrList discardableAttrs; +}; + +/// Lower discardable function attributes on `func.func` to attributes expected +/// by `llvm.func`. +static FailureOr +lowerFuncAttributes(FunctionOpInterface func) { + MLIRContext *ctx = func->getContext(); + LoweredFuncAttrs lowered; + + llvm::SmallDenseSet odsAttrNames( + LLVM::LLVMFuncOp::getAttributeNames().begin(), + LLVM::LLVMFuncOp::getAttributeNames().end()); + + NamedAttrList inherentAttrs; + + for (const NamedAttribute &attr : func->getDiscardableAttrs()) { + StringRef attrName = attr.getName().strref(); + + if (odsAttrNames.contains(attrName)) { + LDBG() << "LLVM specific attributes: " << attrName + << "should use llvm.* prefix, discarding it"; + continue; + } + + StringRef inherent = attrName; + if (inherent.consume_front("llvm.") && odsAttrNames.contains(inherent)) + inherentAttrs.set(inherent, attr.getValue()); // collect inherent attrs + else + lowered.discardableAttrs.push_back(attr); + } + + // Convert collected inherent attrs into typed properties. + if (!inherentAttrs.empty()) { + DictionaryAttr dict = inherentAttrs.getDictionary(ctx); + auto emitError = [&] { + return func.emitOpError("invalid llvm.func property"); + }; + if (failed(LLVM::LLVMFuncOp::setPropertiesFromAttr(lowered.properties, dict, + emitError))) { + return failure(); + } + } + return lowered; +} + +static void buildLLVMFuncProperties(PatternRewriter &rewriter, + FunctionOpInterface srcFunc, + Type llvmFuncType, + LLVM::LLVMFuncOp::Properties &props) { + MLIRContext *ctx = rewriter.getContext(); + props.sym_name = rewriter.getStringAttr(srcFunc.getName()); + props.function_type = TypeAttr::get(llvmFuncType); + props.setCConv(LLVM::CConvAttr::get(ctx, LLVM::CConv::C)); +} + /// Propagate argument/results attributes. static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType, FunctionOpInterface funcOp, @@ -288,6 +352,7 @@ static void restoreByValRefArgumentType( } } +/// TODO: Refactor this function to be more modular and easier to understand. FailureOr mlir::convertFuncOpToLLVMFuncOp( FunctionOpInterface funcOp, ConversionPatternRewriter &rewriter, const LLVMTypeConverter &converter, SymbolTableCollection *symbolTables) { @@ -320,35 +385,10 @@ FailureOr mlir::convertFuncOpToLLVMFuncOp( return funcOp.emitError("C interface for variadic functions is not " "supported yet."); - // Create an LLVM function, use external linkage by default until MLIR - // functions have linkage. - LLVM::Linkage linkage = LLVM::Linkage::External; - if (funcOp->hasAttr(linkageAttrName)) { - auto attr = - dyn_cast(funcOp->getAttr(linkageAttrName)); - if (!attr) { - funcOp->emitError() << "Contains " << linkageAttrName - << " attribute not of type LLVM::LinkageAttr"; - return rewriter.notifyMatchFailure( - funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr"); - } - linkage = attr.getLinkage(); - } - - // Check for invalid attributes. - StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName(); - if (funcOp->hasAttr(readnoneAttrName)) { - auto attr = funcOp->getAttrOfType(readnoneAttrName); - if (!attr) { - funcOp->emitError() << "Contains " << readnoneAttrName - << " attribute not of type UnitAttr"; - return rewriter.notifyMatchFailure( - funcOp, "Contains readnone attribute not of type UnitAttr"); - } - } - - SmallVector attributes; - filterFuncAttributes(funcOp, attributes); + FailureOr loweredAttrs = lowerFuncAttributes(funcOp); + if (failed(loweredAttrs)) + return rewriter.notifyMatchFailure(funcOp, + "failed to lower func attributes"); Operation *symbolTableOp = funcOp->getParentWithTrait(); @@ -356,11 +396,10 @@ FailureOr mlir::convertFuncOpToLLVMFuncOp( SymbolTable &symbolTable = symbolTables->getSymbolTable(symbolTableOp); symbolTable.remove(funcOp); } - - auto newFuncOp = LLVM::LLVMFuncOp::create( - rewriter, funcOp.getLoc(), funcOp.getName(), llvmType, linkage, - /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, - attributes); + buildLLVMFuncProperties(rewriter, funcOp, llvmType, loweredAttrs->properties); + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, funcOp.getLoc(), + loweredAttrs->properties, + loweredAttrs->discardableAttrs); if (symbolTables && symbolTableOp) { auto ip = rewriter.getInsertionPoint(); @@ -372,6 +411,7 @@ FailureOr mlir::convertFuncOpToLLVMFuncOp( .setVisibility(funcOp.getVisibility()); // Create a memory effect attribute corresponding to readnone. + StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName(); if (funcOp->hasAttr(readnoneAttrName)) { auto memoryAttr = LLVM::MemoryEffectsAttr::get( rewriter.getContext(), {/*other=*/LLVM::ModRefInfo::NoModRef, diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 112d69ce87f7f..0c5bcfe631c6c 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -242,8 +242,8 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { // The switch-resumed API based coroutine should be marked with // presplitcoroutine attribute to mark the function as a coroutine. - func->setAttr("passthrough", builder.getArrayAttr( - StringAttr::get(ctx, "presplitcoroutine"))); + func->setAttr("llvm.passthrough", builder.getArrayAttr(StringAttr::get( + ctx, "presplitcoroutine"))); CoroMachinery machinery; machinery.func = func; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index d7e844b98dc92..ce51884368b69 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3035,6 +3035,15 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, if (functionEntryCount) result.addAttribute(getFunctionEntryCountAttrName(result.name), builder.getI64IntegerAttr(functionEntryCount.value())); +#ifndef NDEBUG + std::optional duplicate = result.attributes.findDuplicate(); + if (duplicate.has_value()) { + llvm::report_fatal_error( + Twine("LLVMFuncOp propagated an attribute that is meant " + "to be constructed by the builder: ") + + duplicate->getName().str()); + } +#endif if (argAttrs.empty()) return; diff --git a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir index 4cb31f8f92661..a648c415bb031 100644 --- a/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir +++ b/mlir/test/Conversion/FuncToLLVM/convert-funcs.mlir @@ -63,12 +63,18 @@ func.func @variadic_func(%arg0: i32) attributes { "func.varargs" = true } { // CHECK-LABEL: llvm.func @target_cpu() // CHECK-SAME: target_cpu = "gfx90a" -func.func private @target_cpu() attributes { "target_cpu" = "gfx90a" } +func.func private @target_cpu() attributes { "llvm.target_cpu" = "gfx90a" } // CHECK-LABEL: llvm.func @target_features() // CHECK-SAME: target_features = #llvm.target_features<["+sme", "+sve"]> func.func private @target_features() attributes { - "target_features" = #llvm.target_features<["+sme", "+sve"]> + "llvm.target_features" = #llvm.target_features<["+sme", "+sve"]> +} + +// CHECK-LABEL: llvm.func @passthrough_attr() +// CHECK-SAME: passthrough = ["presplitcoroutine"] +func.func private @passthrough_attr() attributes { + "llvm.passthrough" = ["presplitcoroutine"] } // ----- @@ -88,7 +94,8 @@ func.func @caller_private_callee(%arg1: f32) -> i32 { // ----- -func.func private @badllvmlinkage(i32) attributes { "llvm.linkage" = 3 : i64 } // expected-error {{Contains llvm.linkage attribute not of type LLVM::LinkageAttr}} +// expected-error@+1{{'func.func' op invalid llvm.func propertyInvalid attribute `linkage` in property conversion: 3 : i64}} +func.func private @badllvmlinkage(i32) attributes { "llvm.linkage" = 3 : i64 } // ----- @@ -103,3 +110,17 @@ func.func @variadic_func(%arg0: i32) attributes { "func.varargs" = true, "llvm.e func.func @empty_res_attrs() attributes {res_attrs = []} { return } + +// ----- + +// Internal `llvm.linkage` must lower correctly +// CHECK-LABEL: llvm.func internal @host_next_to_gpu_module +// CHECK: gpu.module @gpu_mod +func.func @host_next_to_gpu_module() attributes { llvm.linkage = #llvm.linkage } { + return +} +gpu.module @gpu_mod { + gpu.func @gpu_kernel() kernel { + gpu.return + } +}