diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 93602a286c9c4..2c167d23a7d9a 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -63,6 +63,12 @@ mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMFunctionTypeGetName(void); +/// Returns `true` if the type is an LLVM dialect function type. +MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMFunctionType(MlirType type); + +/// Returns the TypeID of an LLVM function type. +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMFunctionTypeGetTypeID(void); + /// Returns the number of input types. MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); @@ -70,6 +76,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos); +/// Returns `true` if the function type is variadic. +MLIR_CAPI_EXPORTED bool mlirLLVMFunctionTypeIsVarArg(MlirType type); + /// Returns the return type of the function type. MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); @@ -190,6 +199,7 @@ enum MlirLLVMCConv { MlirLLVMCConvAMDGPU_Gfx = 100, MlirLLVMCConvM68k_INTR = 101, }; + typedef enum MlirLLVMCConv MlirLLVMCConv; /// Creates a LLVM CConv attribute. @@ -205,6 +215,7 @@ enum MlirLLVMComdat { MlirLLVMComdatNoDeduplicate = 3, MlirLLVMComdatSameSize = 4, }; + typedef enum MlirLLVMComdat MlirLLVMComdat; /// Creates a LLVM Comdat attribute. @@ -226,6 +237,7 @@ enum MlirLLVMLinkage { MlirLLVMLinkageExternWeak = 9, MlirLLVMLinkageCommon = 10, }; + typedef enum MlirLLVMLinkage MlirLLVMLinkage; /// Creates a LLVM Linkage attribute. @@ -274,6 +286,7 @@ enum MlirLLVMTypeEncoding { MlirLLVMTypeEncodingLoUser = 0x80, MlirLLVMTypeEncodingHiUser = 0xff, }; + typedef enum MlirLLVMTypeEncoding MlirLLVMTypeEncoding; /// Creates a LLVM DIBasicType attribute. @@ -337,6 +350,7 @@ enum MlirLLVMDIEmissionKind { MlirLLVMDIEmissionKindLineTablesOnly = 2, MlirLLVMDIEmissionKindDebugDirectivesOnly = 3, }; + typedef enum MlirLLVMDIEmissionKind MlirLLVMDIEmissionKind; enum MlirLLVMDINameTableKind { @@ -345,6 +359,7 @@ enum MlirLLVMDINameTableKind { MlirLLVMDINameTableKindNone = 2, MlirLLVMDINameTableKindApple = 3, }; + typedef enum MlirLLVMDINameTableKind MlirLLVMDINameTableKind; /// Creates a LLVM DICompileUnit attribute. @@ -456,6 +471,69 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMDIImportedEntityAttrGetName(void); MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule); +//===----------------------------------------------------------------------===// +// Metadata Attributes +//===----------------------------------------------------------------------===// + +/// Creates an LLVM MDStringAttr. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDStringAttrGet(MlirContext ctx, + MlirStringRef value); + +/// Returns `true` if the attribute is an LLVM MDStringAttr. +MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDStringAttr(MlirAttribute attr); + +/// Returns the TypeID of MDStringAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDStringAttrGetTypeID(void); + +/// Returns the string value of an LLVM MDStringAttr. +MLIR_CAPI_EXPORTED MlirStringRef +mlirLLVMMDStringAttrGetValue(MlirAttribute attr); + +/// Creates an LLVM MDConstantAttr wrapping an attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMMDConstantAttrGet(MlirContext ctx, MlirAttribute valueAttr); + +/// Returns `true` if the attribute is an LLVM MDConstantAttr. +MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDConstantAttr(MlirAttribute attr); + +/// Returns the TypeID of MDConstantAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDConstantAttrGetTypeID(void); + +/// Returns the attribute value of an LLVM MDConstantAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMMDConstantAttrGetValue(MlirAttribute attr); + +/// Creates an LLVM MDFuncAttr referencing a function symbol. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDFuncAttrGet(MlirContext ctx, + MlirAttribute name); + +/// Returns `true` if the attribute is an LLVM MDFuncAttr. +MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDFuncAttr(MlirAttribute attr); + +/// Returns the TypeID of MDFuncAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDFuncAttrGetTypeID(void); + +/// Returns the symbol name of an LLVM MDFuncAttr. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDFuncAttrGetName(MlirAttribute attr); + +/// Creates an LLVM MDNodeAttr. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMMDNodeAttrGet( + MlirContext ctx, intptr_t nOperands, MlirAttribute const *operands); + +/// Returns `true` if the attribute is an LLVM MDNodeAttr. +MLIR_CAPI_EXPORTED bool mlirLLVMAttrIsAMDNodeAttr(MlirAttribute attr); + +/// Returns the TypeID of MDNodeAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirLLVMMDNodeAttrGetTypeID(void); + +/// Returns the number of operands in an LLVM MDNodeAttr. +MLIR_CAPI_EXPORTED intptr_t +mlirLLVMMDNodeAttrGetNumOperands(MlirAttribute attr); + +/// Returns the operand at the given index of an LLVM MDNodeAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMMDNodeAttrGetOperand(MlirAttribute attr, intptr_t index); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 36acf244865eb..14a4f888bd51d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1686,4 +1686,75 @@ def UWTableKindAttr : LLVM_Attr<"UWTableKind", "uwtableKind"> { let assemblyFormat = "`<` $uwtableKind `>`"; } +//===----------------------------------------------------------------------===// +// Metadata Attributes +//===----------------------------------------------------------------------===// +// +// These attributes model LLVM IR metadata nodes (llvm::Metadata and its +// subclasses). They can be nested to form arbitrary metadata trees and are +// translated to their LLVM IR counterparts during MLIR-to-LLVM-IR conversion. + +def LLVM_MDStringAttr : LLVM_Attr<"MDString", "md_string"> { + let summary = "LLVM metadata string"; + let description = [{ + Wraps a string as an LLVM metadata node, corresponding to + `llvm::MDString` in LLVM IR. + + Example: + ```mlir + #llvm.md_string<"foo.buffer"> + ``` + }]; + let parameters = (ins "StringAttr":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def LLVM_MDConstantAttr : LLVM_Attr<"MDConstant", "md_const"> { + let summary = "LLVM constant-as-metadata"; + let description = [{ + Wraps an attribute as an LLVM metadata node, corresponding to + `llvm::ConstantAsMetadata` wrapping a `llvm::Constant*` in LLVM IR. + Currently, only integers/IntegerAttrs supported. + + Example: + ```mlir + #llvm.md_const<42 : i32> + ``` + }]; + let parameters = (ins "Attribute":$value); + let assemblyFormat = "`<` $value `>`"; +} + +def LLVM_MDFuncAttr : LLVM_Attr<"MDFunc", "md_func"> { + let summary = "LLVM function-as-metadata"; + let description = [{ + References a function (or global) symbol as LLVM metadata, corresponding + to `llvm::ValueAsMetadata::get(function)` in LLVM IR. + + Example: + ```mlir + #llvm.md_func<@my_kernel> + ``` + }]; + let parameters = (ins "FlatSymbolRefAttr":$name); + let assemblyFormat = "`<` $name `>`"; +} + +def LLVM_MDNodeAttr : LLVM_Attr<"MDNode", "md_node"> { + let summary = "LLVM metadata node"; + let description = [{ + Represents an LLVM metadata node. The operands + can be any combination of metadata attributes: `#llvm.md_string`, + `#llvm.md_const`, `#llvm.md_func`, or nested `#llvm.md_node`. + + Example: + ```mlir + #llvm.md_node<#llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">> + #llvm.md_node<> + ``` + }]; + let parameters = (ins OptionalArrayRefParameter<"Attribute">:$operands); + let assemblyFormat = "`<` (`>`) : ($operands^ `>`)?"; +} + #endif // LLVMIR_ATTRDEFS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index e781d4c876315..75c47f087f78e 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2576,4 +2576,53 @@ def LLVM_ModuleFlagsOp let hasVerifier = 1; } +//===--------------------------------------------------------------------===// +// NamedMetadataOp +//===--------------------------------------------------------------------===// + +def LLVM_NamedMetadataOp + : LLVM_Op<"named_metadata"> { + let summary = "Module-level named metadata"; + let description = [{ + Represents an LLVM named metadata node (`llvm::NamedMDNode`). Named + metadata nodes are module-level metadata that associate a name string + with a list of metadata nodes. Each operand must be an `#llvm.md_node`. + + Note: cyclic metadata graphs are not supported. Because metadata attributes + are represented as MLIR attributes (which form a tree), there is no way to + express a metadata node that directly or transitively references itself. + LLVM IR permits such cycles (e.g. `!0 = !{!0}`), but they cannot be + represented here and will not round-trip through this op. + + Example: + ```mlir + llvm.named_metadata "foo.version" [ + #llvm.md_node<#llvm.md_const<2 : i32>, + #llvm.md_const<9 : i32>, + #llvm.md_const<0 : i32> + > + ] + llvm.named_metadata "foo.kernel" [ + #llvm.md_node< + #llvm.md_func<@my_kernel>, + #llvm.md_node<>, + #llvm.md_node< + #llvm.md_node<#llvm.md_const<0 : i32>, + #llvm.md_string<"foo.buffer"> + > + > + > + ] + ``` + }]; + let arguments = (ins StrAttr:$metadata_name, ArrayAttr:$nodes); + let assemblyFormat = [{ + $metadata_name $nodes attr-dict + }]; + + let llvmBuilder = [{ + convertNamedMetadataOp($metadata_name, $nodes, builder, moduleTranslation); + }]; +} + #endif // LLVMIR_OPS diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 5c79f515c49eb..7e4f24b556613 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include +#include #include "mlir-c/Dialect/LLVM.h" #include "mlir-c/IR.h" @@ -222,10 +223,175 @@ struct PointerType : PyConcreteType { } }; +//===--------------------------------------------------------------------===// +// FunctionType +//===--------------------------------------------------------------------===// + +struct FunctionType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsALLVMFunctionType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMFunctionTypeGetTypeID; + static constexpr const char *pyClassName = "FunctionType"; + static inline const MlirStringRef name = mlirLLVMFunctionTypeGetName(); + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &resultType, const std::vector &argumentTypes, + bool isVarArg) { + std::vector argTypes(argumentTypes.size()); + std::copy(argumentTypes.begin(), argumentTypes.end(), + argTypes.begin()); + return FunctionType( + resultType.getContext(), + mlirLLVMFunctionTypeGet(resultType, argTypes.size(), + argTypes.data(), isVarArg)); + }, + "result_type"_a, "argument_types"_a, nb::kw_only(), + "is_var_arg"_a = false); + c.def_prop_ro("return_type", [](const FunctionType &type) { + return mlirLLVMFunctionTypeGetReturnType(type); + }); + c.def_prop_ro("num_inputs", [](const FunctionType &type) { + return mlirLLVMFunctionTypeGetNumInputs(type); + }); + c.def_prop_ro("inputs", [](const FunctionType &type) { + nb::list inputs; + for (intptr_t i = 0, e = mlirLLVMFunctionTypeGetNumInputs(type); i < e; + ++i) { + inputs.append(mlirLLVMFunctionTypeGetInput(type, i)); + } + return inputs; + }); + c.def_prop_ro("is_var_arg", [](const FunctionType &type) { + return mlirLLVMFunctionTypeIsVarArg(type); + }); + } +}; + +//===--------------------------------------------------------------------===// +// Metadata Attributes +//===--------------------------------------------------------------------===// + +struct MDStringAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDStringAttr; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMMDStringAttrGetTypeID; + static constexpr const char *pyClassName = "MDStringAttr"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &value, DefaultingPyMlirContext context) { + return MDStringAttr( + context->getRef(), + mlirLLVMMDStringAttrGet( + context.get()->get(), + mlirStringRefCreate(value.data(), value.size()))); + }, + "value"_a, nb::kw_only(), "context"_a = nb::none()); + c.def_prop_ro("value", [](const MDStringAttr &self) { + MlirStringRef ref = mlirLLVMMDStringAttrGetValue(self); + return nb::str(ref.data, ref.length); + }); + } +}; + +struct MDConstantAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDConstantAttr; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMMDConstantAttrGetTypeID; + static constexpr const char *pyClassName = "MDConstantAttr"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyAttribute &valueAttr, DefaultingPyMlirContext context) { + return MDConstantAttr( + context->getRef(), + mlirLLVMMDConstantAttrGet(context.get()->get(), valueAttr)); + }, + "value"_a, nb::kw_only(), "context"_a = nb::none()); + c.def_prop_ro("value", [](const MDConstantAttr &self) { + return mlirLLVMMDConstantAttrGetValue(self); + }); + } +}; + +struct MDFuncAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDFuncAttr; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMMDFuncAttrGetTypeID; + static constexpr const char *pyClassName = "MDFuncAttr"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &name, DefaultingPyMlirContext context) { + MlirAttribute symRef = mlirFlatSymbolRefAttrGet( + context.get()->get(), + mlirStringRefCreate(name.data(), name.size())); + return MDFuncAttr( + context->getRef(), + mlirLLVMMDFuncAttrGet(context.get()->get(), symRef)); + }, + "name"_a, nb::kw_only(), "context"_a = nb::none()); + c.def_prop_ro("name", [](const MDFuncAttr &self) { + MlirAttribute symRef = mlirLLVMMDFuncAttrGetName(self); + MlirStringRef ref = mlirFlatSymbolRefAttrGetValue(symRef); + return nb::str(ref.data, ref.length); + }); + } +}; + +struct MDNodeAttr : PyConcreteAttribute { + static constexpr IsAFunctionTy isaFunction = mlirLLVMAttrIsAMDNodeAttr; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirLLVMMDNodeAttrGetTypeID; + static constexpr const char *pyClassName = "MDNodeAttr"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &operands, + DefaultingPyMlirContext context) { + std::vector operands_(operands.size()); + std::copy(operands.begin(), operands.end(), operands_.begin()); + return MDNodeAttr(context->getRef(), + mlirLLVMMDNodeAttrGet(context.get()->get(), + operands_.size(), + operands_.data())); + }, + "operands"_a, nb::kw_only(), "context"_a = nb::none()); + c.def_prop_ro("num_operands", [](const MDNodeAttr &self) { + return mlirLLVMMDNodeAttrGetNumOperands(self); + }); + c.def("__getitem__", [](const MDNodeAttr &self, intptr_t index) { + intptr_t n = mlirLLVMMDNodeAttrGetNumOperands(self); + if (index < 0 || index >= n) + throw nb::index_error("MDNodeAttr operand index out of range"); + return mlirLLVMMDNodeAttrGetOperand(self, index); + }); + c.def("__len__", [](const MDNodeAttr &self) { + return mlirLLVMMDNodeAttrGetNumOperands(self); + }); + } +}; + static void populateDialectLLVMSubmodule(nanobind::module_ &m) { StructType::bind(m); ArrayType::bind(m); PointerType::bind(m); + FunctionType::bind(m); + MDStringAttr::bind(m); + MDConstantAttr::bind(m); + MDFuncAttr::bind(m); + MDNodeAttr::bind(m); m.def( "translate_module_to_llvmir", diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index eb30f169e4289..154ebbc51e961 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -85,6 +85,14 @@ MlirStringRef mlirLLVMFunctionTypeGetName(void) { return wrap(LLVMFunctionType::name); } +bool mlirTypeIsALLVMFunctionType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirLLVMFunctionTypeGetTypeID(void) { + return wrap(LLVM::LLVMFunctionType::getTypeID()); +} + intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) { return llvm::cast(unwrap(type)).getNumParams(); } @@ -99,6 +107,10 @@ MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { return wrap(llvm::cast(unwrap(type)).getReturnType()); } +bool mlirLLVMFunctionTypeIsVarArg(MlirType type) { + return llvm::cast(unwrap(type)).isVarArg(); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } @@ -523,3 +535,82 @@ MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name, MlirStringRef mlirLLVMDIAnnotationAttrGetName(void) { return wrap(DIAnnotationAttr::name); } + +//===----------------------------------------------------------------------===// +// Metadata Attributes +//===----------------------------------------------------------------------===// + +MlirAttribute mlirLLVMMDStringAttrGet(MlirContext ctx, MlirStringRef value) { + return wrap(MDStringAttr::get(unwrap(ctx), + StringAttr::get(unwrap(ctx), unwrap(value)))); +} + +bool mlirLLVMAttrIsAMDStringAttr(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirTypeID mlirLLVMMDStringAttrGetTypeID(void) { + return wrap(MDStringAttr::getTypeID()); +} + +MlirStringRef mlirLLVMMDStringAttrGetValue(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getValue().getValue()); +} + +MlirAttribute mlirLLVMMDConstantAttrGet(MlirContext ctx, + MlirAttribute valueAttr) { + return wrap(MDConstantAttr::get(unwrap(ctx), unwrap(valueAttr))); +} + +bool mlirLLVMAttrIsAMDConstantAttr(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirTypeID mlirLLVMMDConstantAttrGetTypeID(void) { + return wrap(MDConstantAttr::getTypeID()); +} + +MlirAttribute mlirLLVMMDConstantAttrGetValue(MlirAttribute attr) { + return wrap((Attribute)cast(unwrap(attr)).getValue()); +} + +MlirAttribute mlirLLVMMDFuncAttrGet(MlirContext ctx, MlirAttribute name) { + return wrap( + MDFuncAttr::get(unwrap(ctx), cast(unwrap(name)))); +} + +bool mlirLLVMAttrIsAMDFuncAttr(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirTypeID mlirLLVMMDFuncAttrGetTypeID(void) { + return wrap(MDFuncAttr::getTypeID()); +} + +MlirAttribute mlirLLVMMDFuncAttrGetName(MlirAttribute attr) { + return wrap((Attribute)cast(unwrap(attr)).getName()); +} + +MlirAttribute mlirLLVMMDNodeAttrGet(MlirContext ctx, intptr_t nOperands, + MlirAttribute const *operands) { + SmallVector attrStorage; + attrStorage.reserve(nOperands); + return wrap(MDNodeAttr::get(unwrap(ctx), + unwrapList(nOperands, operands, attrStorage))); +} + +bool mlirLLVMAttrIsAMDNodeAttr(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirTypeID mlirLLVMMDNodeAttrGetTypeID(void) { + return wrap(MDNodeAttr::getTypeID()); +} + +intptr_t mlirLLVMMDNodeAttrGetNumOperands(MlirAttribute attr) { + return cast(unwrap(attr)).getOperands().size(); +} + +MlirAttribute mlirLLVMMDNodeAttrGetOperand(MlirAttribute attr, intptr_t index) { + return wrap(cast(unwrap(attr)).getOperands()[index]); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 21f7954fd338a..e60a682b42ea1 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" @@ -215,6 +216,54 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, return success(); } +/// Recursively converts an MLIR metadata attribute to an LLVM metadata node. +static llvm::Metadata * +convertMetadataAttr(Attribute attr, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + return llvm::TypeSwitch(attr) + .Case([&](auto a) -> llvm::Metadata * { + return llvm::MDString::get(builder.getContext(), + a.getValue().getValue()); + }) + .Case([&](auto a) -> llvm::Metadata * { + IntegerAttr intAttr = llvm::dyn_cast(a.getValue()); + if (!intAttr) + return nullptr; + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + llvm::Type::getIntNTy(builder.getContext(), + intAttr.getType().getIntOrFloatBitWidth()), + intAttr.getValue())); + }) + .Case([&](auto a) -> llvm::Metadata * { + if (llvm::Function *fn = + moduleTranslation.lookupFunction(a.getName().getValue())) + return llvm::ValueAsMetadata::get(fn); + return nullptr; + }) + .Case([&](auto a) -> llvm::Metadata * { + SmallVector operands; + for (Attribute op : a.getOperands()) + operands.push_back( + convertMetadataAttr(op, builder, moduleTranslation)); + return llvm::MDNode::get(builder.getContext(), operands); + }) + .Default([](auto) -> llvm::Metadata * { return nullptr; }); +} + +static void convertNamedMetadataOp(StringRef metadataName, ArrayAttr nodes, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::NamedMDNode *namedMD = + llvmModule->getOrInsertNamedMetadata(metadataName); + for (Attribute nodeAttr : nodes) { + llvm::Metadata *md = + convertMetadataAttr(nodeAttr, builder, moduleTranslation); + if (auto *mdNode = llvm::dyn_cast_or_null(md)) + namedMD->addOperand(mdNode); + } +} + static void convertLinkerOptionsOp(ArrayAttr options, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py index 1fd7e64251e61..23ed23997c3ba 100644 --- a/mlir/python/mlir/dialects/llvm.py +++ b/mlir/python/mlir/dialects/llvm.py @@ -6,7 +6,7 @@ from ._llvm_ops_gen import _Dialect from ._llvm_enum_gen import * from .._mlir_libs._mlirDialectsLLVM import * -from ..ir import Value +from ..ir import Value, IntegerType, IntegerAttr from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results @@ -14,3 +14,16 @@ def mlir_constant(value, *, loc=None, ip=None) -> Value: return _get_op_result_or_op_results( ConstantOp(res=value.type, value=value, loc=loc, ip=ip) ) + + +def md_const(val, *, width=32, context=None): + if not isinstance(val, int): + raise NotImplementedError( + f"{val=} not supported; only integers currently supported." + ) + i_type = IntegerType.get_signless(width, context=context) + return MDConstantAttr.get(IntegerAttr.get(i_type, val), context=context) + + +def md_str(s, *, context=None): + return MDStringAttr.get(s, context=context) diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 437d06841f7e4..500988ea894de 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -1117,3 +1117,39 @@ llvm.func @escapedtypename() { %1 = llvm.alloca %0 x !llvm.struct<"bucket::Iterator", (ptr, i64, i64)> {alignment = 8 : i64} : (i32) -> !llvm.ptr llvm.return } + +// Metadata attributes and llvm.named_metadata op. + +llvm.func @md_kernel() { + llvm.return +} + +// CHECK: llvm.named_metadata "foo.version" [#llvm.md_node<#llvm.md_const<1 : i32>, #llvm.md_const<0 : i32>, #llvm.md_const<0 : i32>>] +llvm.named_metadata "foo.version" [ + #llvm.md_node< + #llvm.md_const<1 : i32>, + #llvm.md_const<0 : i32>, + #llvm.md_const<0 : i32> + > +] + +// CHECK: llvm.named_metadata "foo.language" [#llvm.md_node<#llvm.md_string<"Bar">, #llvm.md_const<1 : i32>, #llvm.md_const<2 : i32>>] +llvm.named_metadata "foo.language" [ + #llvm.md_node< + #llvm.md_string<"Bar">, + #llvm.md_const<1 : i32>, + #llvm.md_const<2 : i32> + > +] + +// CHECK: llvm.named_metadata "foo.kernel" [#llvm.md_node<#llvm.md_func<@md_kernel>, #llvm.md_node<>, #llvm.md_node<#llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">>>] +llvm.named_metadata "foo.kernel" [ + #llvm.md_node< + #llvm.md_func<@md_kernel>, + #llvm.md_node<>, + #llvm.md_node< + #llvm.md_const<0 : i32>, + #llvm.md_string<"foo.buffer"> + > + > +] diff --git a/mlir/test/Target/LLVMIR/llvmir-named-metadata.mlir b/mlir/test/Target/LLVMIR/llvmir-named-metadata.mlir new file mode 100644 index 0000000000000..493616430c822 --- /dev/null +++ b/mlir/test/Target/LLVMIR/llvmir-named-metadata.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// Tests LLVM named metadata translation with deeply nested metadata trees. + +// CHECK: !foo.version = !{![[VERSION:[0-9]+]]} +// CHECK: !foo.language_version = !{![[LANG:[0-9]+]]} +// CHECK: !foo.kernel = !{![[KERNEL:[0-9]+]]} + +llvm.func @my_kernel() { + llvm.return +} + +llvm.named_metadata "foo.version" [ + #llvm.md_node<#llvm.md_const<1 : i32>, + #llvm.md_const<0 : i32>, + #llvm.md_const<0 : i32>> +] +// CHECK-DAG: ![[VERSION]] = !{i32 1, i32 0, i32 0} + +llvm.named_metadata "foo.language_version" [ + #llvm.md_node<#llvm.md_string<"Bar">, + #llvm.md_const<1 : i32>, + #llvm.md_const<2 : i32>, + #llvm.md_const<3 : i32>> +] +// CHECK-DAG: ![[LANG]] = !{!"Bar", i32 1, i32 2, i32 3} + +#buf0 = #llvm.md_node< + #llvm.md_const<0 : i32>, #llvm.md_string<"foo.buffer">, + #llvm.md_string<"foo.idx">, #llvm.md_const<0 : i32>, + #llvm.md_const<1 : i32>, #llvm.md_string<"foo.read">, + #llvm.md_string<"foo.address_space">, #llvm.md_const<1 : i32>, + #llvm.md_string<"foo.size">, #llvm.md_const<4 : i32>, + #llvm.md_string<"foo.align_size">, #llvm.md_const<4 : i32>> +// CHECK-DAG: ![[A0:[0-9]+]] = !{i32 0, !"foo.buffer", !"foo.idx", i32 0, i32 1, !"foo.read", !"foo.address_space", i32 1, !"foo.size", i32 4, !"foo.align_size", i32 4} + +llvm.named_metadata "foo.kernel" [ + #llvm.md_node< + #llvm.md_func<@my_kernel>, + #llvm.md_node<>, + #llvm.md_node<#buf0>> +] +// CHECK-DAG: ![[KERNEL]] = !{ptr @my_kernel, ![[EMPTY:[0-9]+]], ![[ARGS:[0-9]+]]} +// CHECK-DAG: ![[EMPTY]] = !{} +// CHECK-DAG: ![[ARGS]] = !{![[A0]]} diff --git a/mlir/test/python/dialects/llvm.py b/mlir/test/python/dialects/llvm.py index 7a2b8e1809c47..1ed77cf3b84b3 100644 --- a/mlir/test/python/dialects/llvm.py +++ b/mlir/test/python/dialects/llvm.py @@ -215,3 +215,140 @@ def testTranslateToLLVMIR(): # CHECK: ret i64 %3 # CHECK: } print(llvm.translate_module_to_llvmir(module.operation)) + + +# CHECK-LABEL: testMetadataAttrs +@constructAndPrintInModule +def testMetadataAttrs(): + # MDStringAttr + md_str = llvm.MDStringAttr.get("foo.buffer") + # CHECK: #llvm.md_string<"foo.buffer"> + print(md_str) + assert md_str.value == "foo.buffer" + + # MDConstantAttr + i32 = IntegerType.get_signless(32) + md_const = llvm.MDConstantAttr.get(IntegerAttr.get(i32, 42)) + # CHECK: #llvm.md_const<42 : i32> + print(md_const) + + # MDFuncAttr + md_func = llvm.MDFuncAttr.get("my_kernel") + # CHECK: #llvm.md_func<@my_kernel> + print(md_func) + assert md_func.name == "my_kernel" + + # MDNodeAttr - empty + md_empty = llvm.MDNodeAttr.get([]) + # CHECK: #llvm.md_node<> + print(md_empty) + assert len(md_empty) == 0 + + # MDNodeAttr - with operands + md_node = llvm.MDNodeAttr.get([md_const, md_str]) + # CHECK: #llvm.md_node<#llvm.md_const<42 : i32>, #llvm.md_string<"foo.buffer">> + print(md_node) + assert len(md_node) == 2 + + # MDNodeAttr - __getitem__ + # CHECK: #llvm.md_const<42 : i32> + print(md_node[0]) + # CHECK: #llvm.md_string<"foo.buffer"> + print(md_node[1]) + assert str(md_node[0]) == str(md_const) + assert str(md_node[1]) == str(md_str) + + # MDNodeAttr - nested + md_nested = llvm.MDNodeAttr.get([md_node, md_empty]) + # CHECK: #llvm.md_node<#llvm.md_node<#llvm.md_const<42 : i32>, #llvm.md_string<"foo.buffer">>, #llvm.md_node<>> + print(md_nested) + assert len(md_nested) == 2 + + +# CHECK-LABEL: testNamedMetadata +@constructAndPrintInModule +def testNamedMetadata(): + void = Type.parse("!llvm.void") + func_ty = llvm.FunctionType.get(void, []) + + llvm.LLVMFuncOp("my_kernel", TypeAttr.get(func_ty)) + # CHECK-LABEL: llvm.func @my_kernel() + + llvm.NamedMetadataOp( + metadata_name="foo.version", + nodes=ArrayAttr.get( + [ + llvm.MDNodeAttr.get( + [llvm.md_const(1), llvm.md_const(0), llvm.md_const(0)] + ) + ] + ), + ) + # CHECK: llvm.named_metadata "foo.version" [#llvm.md_node<#llvm.md_const<1 : i32>, #llvm.md_const<0 : i32>, #llvm.md_const<0 : i32>>] + + llvm.NamedMetadataOp( + metadata_name="foo.language_version", + nodes=ArrayAttr.get( + [ + llvm.MDNodeAttr.get( + [ + llvm.md_str("Bar"), + llvm.md_const(1), + llvm.md_const(2), + llvm.md_const(3), + ] + ) + ] + ), + ) + # CHECK: llvm.named_metadata "foo.language_version" [#llvm.md_node<#llvm.md_string<"Bar">, #llvm.md_const<1 : i32>, #llvm.md_const<2 : i32>, #llvm.md_const<3 : i32>>] + + buf0 = llvm.MDNodeAttr.get( + [ + llvm.md_const(0), + llvm.md_str("foo.buffer"), + llvm.md_str("foo.idx"), + llvm.md_const(0), + llvm.md_const(1), + llvm.md_str("foo.read"), + llvm.md_str("foo.address_space"), + llvm.md_const(1), + llvm.md_str("foo.size"), + llvm.md_const(4), + llvm.md_str("foo.align_size"), + llvm.md_const(4), + ] + ) + + llvm.NamedMetadataOp( + metadata_name="foo.kernel", + nodes=ArrayAttr.get( + [ + llvm.MDNodeAttr.get( + [ + llvm.MDFuncAttr.get("my_kernel"), + llvm.MDNodeAttr.get([]), + buf0, + ] + ) + ] + ), + ) + # CHECK: llvm.named_metadata "foo.kernel" [ + # CHECK-SAME: #llvm.md_node< + # CHECK-SAME: #llvm.md_func<@my_kernel>, + # CHECK-SAME: #llvm.md_node<>, + # CHECK-SAME: #llvm.md_node< + # CHECK-SAME: #llvm.md_const<0 : i32>, + # CHECK-SAME: #llvm.md_string<"foo.buffer">, + # CHECK-SAME: #llvm.md_string<"foo.idx">, + # CHECK-SAME: #llvm.md_const<0 : i32>, + # CHECK-SAME: #llvm.md_const<1 : i32>, + # CHECK-SAME: #llvm.md_string<"foo.read">, + # CHECK-SAME: #llvm.md_string<"foo.address_space">, + # CHECK-SAME: #llvm.md_const<1 : i32>, + # CHECK-SAME: #llvm.md_string<"foo.size">, + # CHECK-SAME: #llvm.md_const<4 : i32>, + # CHECK-SAME: #llvm.md_string<"foo.align_size">, + # CHECK-SAME: #llvm.md_const<4 : i32>> + # CHECK-SAME: >]