diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index 1750c0c9a3cfd3..f2b50669250335 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -351,8 +351,7 @@ SetVector getTopologicallySortedBlocks(Region ®ion); /// report it to `loc` and return nullptr. llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation, - bool isTopLevel = true); + const ModuleTranslation &moduleTranslation); /// Creates a call to an LLVM IR intrinsic function with the given arguments. llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 127e7e15ccab93..b0e231ded6d54b 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -214,7 +214,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, (type.isa() || hasVectorElementType)) { llvm::Constant *splatValue = LLVM::detail::getLLVMConstant( innermostLLVMType, denseElementsAttr.getSplatValue(), loc, - moduleTranslation, /*isTopLevel=*/false); + moduleTranslation); llvm::Constant *splatVector = llvm::ConstantDataVector::getSplat(0, splatValue); SmallVector constants(numAggregates, splatVector); @@ -272,22 +272,22 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, /// report it to `loc` and return nullptr. llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *llvmType, Attribute attr, Location loc, - const ModuleTranslation &moduleTranslation, bool isTopLevel) { + const ModuleTranslation &moduleTranslation) { if (!attr) return llvm::UndefValue::get(llvmType); if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) { - if (!isTopLevel) { - emitError(loc, "nested struct types are not supported in constants"); + auto arrayAttr = attr.dyn_cast(); + if (!arrayAttr || arrayAttr.size() != 2) { + emitError(loc, "expected struct type to be a complex number"); return nullptr; } - auto arrayAttr = attr.cast(); llvm::Type *elementType = structType->getElementType(0); - llvm::Constant *real = getLLVMConstant(elementType, arrayAttr[0], loc, - moduleTranslation, false); + llvm::Constant *real = + getLLVMConstant(elementType, arrayAttr[0], loc, moduleTranslation); if (!real) return nullptr; - llvm::Constant *imag = getLLVMConstant(elementType, arrayAttr[1], loc, - moduleTranslation, false); + llvm::Constant *imag = + getLLVMConstant(elementType, arrayAttr[1], loc, moduleTranslation); if (!imag) return nullptr; return llvm::ConstantStruct::get(structType, {real, imag}); @@ -336,7 +336,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( elementType, elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), - loc, moduleTranslation, false); + loc, moduleTranslation); if (!child) return nullptr; if (llvmType->isVectorTy()) @@ -367,7 +367,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( llvm::Type *innermostType = getInnermostElementType(llvmType); for (auto n : elementsAttr.getValues()) { constants.push_back( - getLLVMConstant(innermostType, n, loc, moduleTranslation, false)); + getLLVMConstant(innermostType, n, loc, moduleTranslation)); if (!constants.back()) return nullptr; } diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir index fdbbf9e8fcc98a..ba23c8700c48dd 100644 --- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir @@ -41,14 +41,22 @@ llvm.func @invalid_align(%arg0 : f32 {llvm.align = 4}) -> f32 { // ----- -llvm.func @no_nested_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { - // expected-error @+1 {{nested struct types are not supported in constants}} +llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> { + // expected-error @+1 {{expected struct type to be a complex number}} %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> } // ----- +llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> { + // expected-error @+1 {{expected struct type to be a complex number}} + %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> + llvm.return %0 : !llvm.array<2 x array<2 x array<2 x struct<(i32, i32, i32)>>>> +} + +// ----- + llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> { // expected-error @+1 {{FloatAttr does not match expected type of the constant}} %0 = llvm.mlir.constant([1.0 : f32, 1.0 : f32]) : !llvm.struct<(f64, f64)> diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index cd14641944c300..b4a2dbcf02d8a7 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1122,6 +1122,18 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> { llvm.return %1 : !llvm.struct<(i32, i32)> } +llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> { + %1 = llvm.mlir.constant(dense<(0, 1)> : tensor>) : !llvm.array<2 x !llvm.struct<(i32, i32)>> + // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }] + llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>> +} + +llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> { + %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>> + // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]] + llvm.return %1 : !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> +} + llvm.func @noreach() { // CHECK: unreachable llvm.unreachable