diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index f19f9d5a3083c..61ba8f7b991c8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -2073,9 +2073,9 @@ def LLVM_ConstantOp Unlike LLVM IR, MLIR does not have first-class constant values. Therefore, all constants must be created as SSA values before being used in other operations. `llvm.mlir.constant` creates such values for scalars, vectors, - strings, and structs. It has a mandatory `value` attribute whose type - depends on the type of the constant value. The type of the constant value - must correspond to the attribute type converted to LLVM IR type. + strings, structs, and array of structs. It has a mandatory `value` attribute + whose type depends on the type of the constant value. The type of the constant + value must correspond to the attribute type converted to LLVM IR type. When creating constant scalars, the `value` attribute must be either an integer attribute or a floating point attribute. The type of the attribute @@ -2097,6 +2097,11 @@ def LLVM_ConstantOp must correspond to the type of the corresponding attribute element converted to LLVM IR. + When creating an array of structs, the `value` attribute must be an array + attribute, itself containing zero, or undef, or array attributes for each + potential nested array type, and the elements of the leaf array attributes + for must match the struct element types or be zero or undef attributes. + Examples: ```mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index c757f3ceb90e3..d8abf6fd41301 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3142,6 +3142,74 @@ static bool hasScalableVectorType(Type t) { return false; } +/// Verifies the constant array represented by `arrayAttr` matches the provided +/// `arrayType`. +static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, + LLVM::LLVMArrayType arrayType, + ArrayAttr arrayAttr, int dim) { + if (arrayType.getNumElements() != arrayAttr.size()) + return op.emitOpError() + << "array attribute size does not match array type size in " + "dimension " + << dim << ": " << arrayAttr.size() << " vs. " + << arrayType.getNumElements(); + + llvm::DenseSet elementsVerified; + + // Recursively verify sub-dimensions for multidimensional arrays. + if (auto subArrayType = + dyn_cast(arrayType.getElementType())) { + for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) + if (elementsVerified.insert(elementAttr).second) { + if (isa(elementAttr)) + continue; + auto subArrayAttr = dyn_cast(elementAttr); + if (!subArrayAttr) + return op.emitOpError() + << "nested attribute for sub-array in dimension " << dim + << " at index " << idx + << " must be a zero, or undef, or array attribute"; + if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr, + dim + 1))) + return failure(); + } + return success(); + } + + // Forbid usages of ArrayAttr for simple array types that should use + // DenseElementsAttr instead. Note that there would be a use case for such + // array types when one element value is obtained via a ptr-to-int conversion + // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR + // user needs this so far, and it seems better to avoid people misusing the + // ArrayAttr for simple types. + auto structType = dyn_cast(arrayType.getElementType()); + if (!structType) + return op.emitOpError() << "for array with an array attribute must have a " + "struct element type"; + + // Shallow verification that leaf attributes are appropriate as struct initial + // value. + size_t numStructElements = structType.getBody().size(); + for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) { + if (elementsVerified.insert(elementAttr).second) { + if (isa(elementAttr)) + continue; + auto subArrayAttr = dyn_cast(elementAttr); + if (!subArrayAttr) + return op.emitOpError() + << "nested attribute for struct element at index " << idx + << " must be a zero, or undef, or array attribute"; + if (subArrayAttr.size() != numStructElements) + return op.emitOpError() + << "nested array attribute size for struct element at index " + << idx << " must match struct size: " << subArrayAttr.size() + << " vs. " << numStructElements; + } + } + + return success(); +} + LogicalResult LLVM::ConstantOp::verify() { if (StringAttr sAttr = llvm::dyn_cast(getValue())) { auto arrayType = llvm::dyn_cast(getType()); @@ -3208,7 +3276,7 @@ LogicalResult LLVM::ConstantOp::verify() { if (isa(getType()) && !getType().isInteger(floatWidth)) { return emitOpError() << "expected integer type of width " << floatWidth; } - } else if (isa(getValue())) { + } else if (isa(getValue())) { if (hasScalableVectorType(getType())) { // The exact number of elements of a scalable vector is unknown, so we // allow only splat attributes. @@ -3221,15 +3289,20 @@ LogicalResult LLVM::ConstantOp::verify() { if (!isa(getType())) return emitOpError() << "expected vector or array type"; // The number of elements of the attribute and the type must match. - int64_t attrNumElements; - if (auto elementsAttr = dyn_cast(getValue())) - attrNumElements = elementsAttr.getNumElements(); - else - attrNumElements = cast(getValue()).size(); - if (getNumElements(getType()) != attrNumElements) - return emitOpError() - << "type and attribute have a different number of elements: " - << getNumElements(getType()) << " vs. " << attrNumElements; + if (auto elementsAttr = dyn_cast(getValue())) { + int64_t attrNumElements = elementsAttr.getNumElements(); + if (getNumElements(getType()) != attrNumElements) + return emitOpError() + << "type and attribute have a different number of elements: " + << getNumElements(getType()) << " vs. " << attrNumElements; + } + } else if (auto arrayAttr = dyn_cast(getValue())) { + auto arrayType = dyn_cast(getType()); + if (!arrayType) + return emitOpError() << "expected array type"; + // When the attribute is an ArrayAttr, check that its nesting matches the + // corresponding ArrayType or VectorType nesting. + return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0); } else { return emitOpError() << "only supports integer, float, string or elements attributes"; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 1168b9f339904..229682dff7a24 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -553,8 +553,10 @@ static llvm::Constant *convertDenseResourceElementsAttr( llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation) { - if (!attr) + if (!attr || isa(attr)) return llvm::UndefValue::get(llvmType); + if (isa(attr)) + return llvm::Constant::getNullValue(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { auto arrayAttr = dyn_cast(attr); if (!arrayAttr) { @@ -713,6 +715,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( ArrayRef{stringAttr.getValue().data(), stringAttr.getValue().size()}); } + + // Handle arrays of structs that cannot be represented as DenseElementsAttr + // in MLIR. + if (auto arrayAttr = dyn_cast(attr)) { + if (auto *arrayTy = dyn_cast(llvmType)) { + llvm::Type *elementType = arrayTy->getElementType(); + Attribute previousElementAttr; + llvm::Constant *elementCst = nullptr; + SmallVector constants; + constants.reserve(arrayTy->getNumElements()); + for (Attribute elementAttr : arrayAttr) { + // Arrays with a single value or with repeating values are quite common. + // Short-circuit the translation when the element value is the same as + // the previous one. + if (!previousElementAttr || previousElementAttr != elementAttr) { + previousElementAttr = elementAttr; + elementCst = + getLLVMConstant(elementType, elementAttr, loc, moduleTranslation); + if (!elementCst) + return nullptr; + } + constants.push_back(elementCst); + } + return llvm::ConstantArray::get(arrayTy, constants); + } + } + emitError(loc, "unsupported constant value"); return nullptr; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index f9ea066a63624..f5adf4b3bf33d 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1850,3 +1850,35 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) { llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)> llvm.return } + +// ----- + +llvm.mlir.global @bad_struct_array_init_size() : !llvm.array<2x!llvm.struct<(i32, f32)>> { + // expected-error@below {{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}} + %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>> +} + +// ----- + +llvm.mlir.global @bad_struct_array_init_nesting() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> { + // expected-error@below {{'llvm.mlir.constant' op nested attribute for sub-array in dimension 1 at index 0 must be a zero, or undef, or array attribute}} + %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> + llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> +} + +// ----- + +llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<(i32, f32)>> { + // expected-error@below {{'llvm.mlir.constant' op nested array attribute size for struct element at index 0 must match struct size: 1 vs. 2}} + %0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>> +} + +// ---- + +llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> { + // expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}} + %0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64> + llvm.return %0 : !llvm.array<2 x f64> +} diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index 90c0f5ac55cb1..24a7b42557278 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 { // ----- -// expected-error @below{{unsupported constant value}} -llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64> - -// ----- - // expected-error @below{{LLVM attribute 'readonly' does not expect a value}} llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 4ef68fa83a70d..ccc2bec223113 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -3000,3 +3000,29 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} { llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> () llvm.return } + +// ----- + +// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }] +llvm.mlir.global @test_array_attr_2() : !llvm.array<2 x !llvm.struct<(i32, f32)>> { + %0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2 x !llvm.struct<(i32, f32)>> + llvm.return %0 : !llvm.array<2 x !llvm.struct<(i32, f32)>> +} + +// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 } +llvm.mlir.global @test_array_attr_3() : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> { + %0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32], [2 : i32, 1.000000e+00 : f32], [3 : i32, 1.000000e+00 : f32]], [[4 : i32, 1.000000e+00 : f32], [5 : i32, 1.000000e+00 : f32], [6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> + llvm.return %0 : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> +} + +// CHECK: @test_array_attr_struct_with_ptr = internal constant [2 x { ptr }] [{ ptr } zeroinitializer, { ptr } undef] +llvm.mlir.global internal constant @test_array_attr_struct_with_ptr() : !llvm.array<2 x struct<(ptr)>> { + %0 = llvm.mlir.constant([[#llvm.zero], [#llvm.undef]]) : !llvm.array<2 x struct<(ptr)>> + llvm.return %0 : !llvm.array<2 x struct<(ptr)>> +} + +// CHECK: @test_array_attr_struct_with_struct = internal constant [3 x { i32, float }] [{ i32, float } zeroinitializer, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } undef] +llvm.mlir.global internal constant @test_array_attr_struct_with_struct() : !llvm.array<3 x struct<(i32, f32)>> { + %0 = llvm.mlir.constant([#llvm.zero, [2 : i32, 1.0 : f32], #llvm.undef]) : !llvm.array<3 x struct<(i32, f32)>> + llvm.return %0 : !llvm.array<3 x struct<(i32, f32)>> +}