171 changes: 69 additions & 102 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,62 +1456,25 @@ void ModulePrinter::printAttribute(Attribute attr,
}
}

/// Print the integer element of the given DenseElementsAttr at 'index'.
static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os,
unsigned index, bool isSigned) {
APInt value = *std::next(attr.int_value_begin(), index);
/// Print the integer element of a DenseElementsAttr.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
bool isSigned) {
if (value.getBitWidth() == 1)
os << (value.getBoolValue() ? "true" : "false");
else
value.print(os, isSigned);
}

/// Print the float element of the given DenseElementsAttr at 'index'.
static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
unsigned index, bool isSigned) {
assert(isSigned && "floating point values are always signed");
APFloat value = *std::next(attr.float_value_begin(), index);
printFloatValue(value, os);
}

static void printDenseStringElement(DenseStringElementsAttr attr,
raw_ostream &os, unsigned index) {
os << "\"";
printEscapedString(attr.getRawStringData()[index], os);
os << "\"";
}

void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
printDenseStringElementsAttr(stringAttr);
return;
}

printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
allowHex);
}

void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex) {
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
bool isSigned = !type.getElementType().isUnsignedInteger();

// The function used to print elements of this attribute.
auto printEltFn = type.getElementType().isIntOrIndex()
? printDenseIntElement
: printDenseFloatElement;

static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
function_ref<void(unsigned)> printEltFn) {
// Special case for 0-d and splat tensors.
if (attr.isSplat()) {
printEltFn(attr, os, 0, isSigned);
return;
}
if (isSplat)
return printEltFn(0);

// Special case for degenerate tensors.
auto numElements = type.getNumElements();
int64_t rank = type.getRank();
if (numElements == 0) {
for (int i = 0; i < rank; ++i)
os << '[';
Expand All @@ -1520,14 +1483,6 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
return;
}

// Check to see if we should format this attribute as a hex string.
if (allowHex && shouldPrintElementsAttrWithHex(numElements)) {
ArrayRef<char> rawData = attr.getRawData();
os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
<< "\"";
return;
}

// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
Expand All @@ -1537,7 +1492,8 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;

auto bumpCounter = [&]() {
auto shape = type.getShape();
auto bumpCounter = [&] {
// Bump the least significant digit.
++counter[rank - 1];
// Iterate backwards bubbling back the increment.
Expand All @@ -1557,68 +1513,79 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
printEltFn(attr, os, idx, isSigned);
printEltFn(idx);
bumpCounter();
}
while (openBrackets-- > 0)
os << ']';
}

void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
return printDenseStringElementsAttr(stringAttr);

// Special case for 0-d and splat tensors.
if (attr.isSplat()) {
printDenseStringElement(attr, os, 0);
return;
}
printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
allowHex);
}

// Special case for degenerate tensors.
void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex) {
auto type = attr.getType();
auto elementType = type.getElementType();

// Check to see if we should format this attribute as a hex string.
auto numElements = type.getNumElements();
if (numElements == 0) {
for (int i = 0; i < rank; ++i)
os << '[';
for (int i = 0; i < rank; ++i)
os << ']';
if (!attr.isSplat() && allowHex &&
shouldPrintElementsAttrWithHex(numElements)) {
ArrayRef<char> rawData = attr.getRawData();
os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
<< "\"";
return;
}

// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
auto printComplexValue = [&](auto complexValues, auto printFn,
raw_ostream &os, auto &&... params) {
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(complexValues.begin() + index);
os << "(";
printFn(complexValue.real(), os, params...);
os << ",";
printFn(complexValue.imag(), os, params...);
os << ")";
});
};

// The mixed-radix counter, with radices in 'shape'.
SmallVector<unsigned, 4> counter(rank, 0);
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;
Type complexElementType = complexTy.getElementType();
if (complexElementType.isa<IntegerType>())
printComplexValue(attr.getComplexIntValues(), printDenseIntElement, os,
/*isSigned=*/!complexElementType.isUnsignedInteger());
else
printComplexValue(attr.getComplexFloatValues(), printFloatValue, os);
} else if (elementType.isIntOrIndex()) {
bool isSigned = !elementType.isUnsignedInteger();
auto intValues = attr.getIntValues();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printDenseIntElement(*(intValues.begin() + index), os, isSigned);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
auto floatValues = attr.getFloatValues();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
printFloatValue(*(floatValues.begin() + index), os);
});
}
}

auto bumpCounter = [&]() {
// Bump the least significant digit.
++counter[rank - 1];
// Iterate backwards bubbling back the increment.
for (unsigned i = rank - 1; i > 0; --i)
if (counter[i] >= shape[i]) {
// Index 'i' is rolled over. Bump (i-1) and close a bracket.
counter[i] = 0;
++counter[i - 1];
--openBrackets;
os << ']';
}
void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
ArrayRef<StringRef> data = attr.getRawStringData();
auto printFn = [&](unsigned index) {
os << "\"";
printEscapedString(data[index], os);
os << "\"";
};

for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
if (idx != 0)
os << ", ";
while (openBrackets++ < rank)
os << '[';
openBrackets = rank;
printDenseStringElement(attr, os, idx);
bumpCounter();
}
while (openBrackets-- > 0)
os << ']';
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}

void ModulePrinter::printType(Type type) {
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/IR/AttributeDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ struct TypeAttributeStorage : public AttributeStorage {

/// Return the bit width which DenseElementsAttr should use for this type.
inline size_t getDenseElementBitWidth(Type eltType) {
// Align the width for complex to 8 to make storage and interpretation easier.
if (ComplexType comp = eltType.dyn_cast<ComplexType>())
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
// FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
// with double semantics.
if (eltType.isBF16())
Expand Down
230 changes: 193 additions & 37 deletions mlir/lib/IR/Attributes.cpp

Large diffs are not rendered by default.

179 changes: 91 additions & 88 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1956,29 +1956,13 @@ class TensorLiteralParser {
ArrayRef<int64_t> getShape() const { return shape; }

private:
enum class ElementKind { Boolean, Integer, Float, String };

/// Return a string to represent the given element kind.
const char *getElementKindStr(ElementKind kind) {
switch (kind) {
case ElementKind::Boolean:
return "'boolean'";
case ElementKind::Integer:
return "'integer'";
case ElementKind::Float:
return "'float'";
case ElementKind::String:
return "'string'";
}
llvm_unreachable("unknown element kind");
}

/// Build a Dense Integer attribute for the given type.
DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
/// Get the parsed elements for an integer attribute.
ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
std::vector<APInt> &intValues);

/// Build a Dense Float attribute for the given type.
DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
FloatType eltTy);
/// Get the parsed elements for a float attribute.
ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
std::vector<APFloat> &floatValues);

/// Build a Dense String attribute for the given type.
DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
Expand Down Expand Up @@ -2011,9 +1995,6 @@ class TensorLiteralParser {
/// Storage used when parsing elements, this is a pair of <is_negated, token>.
std::vector<std::pair<bool, Token>> storage;

/// A flag that indicates the type of elements that have been parsed.
Optional<ElementKind> knownEltKind;

/// Storage used when parsing elements that were stored as hex values.
Optional<Token> hexStorage;
};
Expand Down Expand Up @@ -2041,7 +2022,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
Type eltType = type.getElementType();

// Check to see if we parse the literal from a hex string.
if (hexStorage.hasValue() && eltType.isIntOrFloat())
if (hexStorage.hasValue() &&
(eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
return getHexAttr(loc, type);

// Check that the parsed storage size has the same number of elements to the
Expand All @@ -2052,75 +2034,94 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
return nullptr;
}

// If the type is an integer, build a set of APInt values from the storage
// with the correct bitwidth.
if (auto intTy = eltType.dyn_cast<IntegerType>())
return getIntAttr(loc, type, intTy);
if (auto indexTy = eltType.dyn_cast<IndexType>())
return getIntAttr(loc, type, indexTy);
// Handle complex types in the specific element type cases below.
bool isComplex = false;
if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
eltType = complexTy.getElementType();
isComplex = true;
}

// If parsing a floating point type.
if (auto floatTy = eltType.dyn_cast<FloatType>())
return getFloatAttr(loc, type, floatTy);
// Handle integer and index types.
if (eltType.isIntOrIndex()) {
std::vector<APInt> intValues;
if (failed(getIntAttrElements(loc, eltType, intValues)))
return nullptr;
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::makeArrayRef(
reinterpret_cast<std::complex<APInt> *>(intValues.data()),
intValues.size() / 2);
return DenseElementsAttr::get(type, complexData);
}
return DenseElementsAttr::get(type, intValues);
}
// Handle floating point types.
if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
std::vector<APFloat> floatValues;
if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
return nullptr;
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::makeArrayRef(
reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
floatValues.size() / 2);
return DenseElementsAttr::get(type, complexData);
}
return DenseElementsAttr::get(type, floatValues);
}

// Other types are assumed to be string representations.
return getStringAttr(loc, type, type.getElementType());
}

/// Build a Dense Integer attribute for the given type.
DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
ShapedType type, Type eltTy) {
std::vector<APInt> intElements;
intElements.reserve(storage.size());
auto isUintType = type.getElementType().isUnsignedInteger();
ParseResult
TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
std::vector<APInt> &intValues) {
intValues.reserve(storage.size());
bool isUintType = eltTy.isUnsignedInteger();
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
auto tokenLoc = token.getLoc();

if (isNegative && isUintType) {
p.emitError(tokenLoc)
<< "expected unsigned integer elements, but parsed negative value";
return nullptr;
return p.emitError(tokenLoc)
<< "expected unsigned integer elements, but parsed negative value";
}

// Check to see if floating point values were parsed.
if (token.is(Token::floatliteral)) {
p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
return nullptr;
return p.emitError(tokenLoc)
<< "expected integer elements, but parsed floating-point";
}

assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
"unexpected token type");
if (token.isAny(Token::kw_true, Token::kw_false)) {
if (!eltTy.isInteger(1)) {
p.emitError(tokenLoc)
<< "expected i1 type for 'true' or 'false' values";
return nullptr;
return p.emitError(tokenLoc)
<< "expected i1 type for 'true' or 'false' values";
}
APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
intElements.push_back(apInt);
intValues.push_back(apInt);
continue;
}

// Create APInt values for each element with the correct bitwidth.
Optional<APInt> apInt =
buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
if (!apInt)
return (p.emitError(tokenLoc, "integer constant out of range for type"),
nullptr);
intElements.push_back(*apInt);
return p.emitError(tokenLoc, "integer constant out of range for type");
intValues.push_back(*apInt);
}

return DenseElementsAttr::get(type, intElements);
return success();
}

/// Build a Dense Float attribute for the given type.
DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
ShapedType type,
FloatType eltTy) {
std::vector<APFloat> floatValues;
ParseResult
TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
std::vector<APFloat> &floatValues) {
floatValues.reserve(storage.size());
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
Expand All @@ -2129,34 +2130,31 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
if (isNegative) {
p.emitError(token.getLoc())
<< "hexadecimal float literal should not have a leading minus";
return nullptr;
return p.emitError(token.getLoc())
<< "hexadecimal float literal should not have a leading minus";
}
auto val = token.getUInt64IntegerValue();
if (!val.hasValue()) {
p.emitError("hexadecimal float constant out of range for attribute");
return nullptr;
return p.emitError(
"hexadecimal float constant out of range for attribute");
}
Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
if (!apVal)
return nullptr;
return failure();
floatValues.push_back(*apVal);
continue;
}

// Check to see if any decimal integers or booleans were parsed.
if (!token.is(Token::floatliteral)) {
p.emitError() << "expected floating-point elements, but parsed integer";
return nullptr;
}
if (!token.is(Token::floatliteral))
return p.emitError()
<< "expected floating-point elements, but parsed integer";

// Build the float values from tokens.
auto val = token.getFloatingPointValue();
if (!val.hasValue()) {
p.emitError("floating point value too large for attribute");
return nullptr;
}
if (!val.hasValue())
return p.emitError("floating point value too large for attribute");

// Treat BF16 as double because it is not supported in LLVM's APFloat.
APFloat apVal(isNegative ? -*val : *val);
if (!eltTy.isBF16() && !eltTy.isF64()) {
Expand All @@ -2166,8 +2164,7 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
}
floatValues.push_back(apVal);
}

return DenseElementsAttr::get(type, floatValues);
return success();
}

/// Build a Dense String attribute for the given type.
Expand Down Expand Up @@ -2196,31 +2193,26 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
ShapedType type) {
Type elementType = type.getElementType();
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
p.emitError(loc) << "expected floating-point or integer element type, got "
<< elementType;
if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
p.emitError(loc)
<< "expected floating-point, integer, or complex element type, got "
<< elementType;
return nullptr;
}

std::string data;
if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
return nullptr;

// Check that the size of the hex data corresponds to the size of the type, or
// a splat of the type.
// TODO: bf16 is currently stored as a double, this should be removed when
// APFloat properly supports it.
int64_t elementWidth =
elementType.isBF16() ? 64 : elementType.getIntOrFloatBitWidth();
if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
(type.getNumElements() * elementWidth)) {
ArrayRef<char> rawData(data.data(), data.size());
bool detectedSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
p.emitError(loc) << "elements hex data size is invalid for provided type: "
<< type;
return nullptr;
}

return DenseElementsAttr::getFromRawBuffer(
type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false);
return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
}

ParseResult TensorLiteralParser::parseElement() {
Expand All @@ -2247,6 +2239,17 @@ ParseResult TensorLiteralParser::parseElement() {
storage.emplace_back(/*isNegative=*/ false, p.getToken());
p.consumeToken();
break;

// Parse a complex element of the form '(' element ',' element ')'.
case Token::l_paren:
p.consumeToken(Token::l_paren);
if (parseElement() ||
p.parseToken(Token::comma, "expected ',' between complex elements") ||
parseElement() ||
p.parseToken(Token::r_paren, "expected ')' after complex elements"))
return failure();
break;

default:
return p.emitError("expected element literal of primitive type");
}
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/IR/dense-elements-hex.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()

// CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex<f64>>
"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> ()

// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/IR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,22 @@ func @elementsattr_toolarge2() -> () {
// -----
"foo"(){bar = dense<[()]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
// -----
"foo"(){bar = dense<[(10)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ',' between complex elements}}
// -----
"foo"(){bar = dense<[(10,)]> : tensor<complex<i64>>} : () -> () // expected-error {{expected element literal of primitive type}}
// -----
"foo"(){bar = dense<[(10,10]> : tensor<complex<i64>>} : () -> () // expected-error {{expected ')' after complex elements}}
// -----
func @elementsattr_malformed_opaque() -> () {
^bb0:
"foo"(){bar = opaque<10, "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected dialect namespace}}
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/IR/parser.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,15 @@ func @densetensorattr() -> () {
"index"(){bar = dense<1> : tensor<index>} : () -> ()
// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
"index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()

// CHECK: dense<(1,1)> : tensor<complex<i64>>
"complex_attr"(){bar = dense<(1,1)> : tensor<complex<i64>>} : () -> ()
// CHECK: dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>
"complex_attr"(){bar = dense<[(1,1), (2,2)]> : tensor<2xcomplex<i64>>} : () -> ()
// CHECK: dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
"complex_attr"(){bar = dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>} : () -> ()
// CHECK: dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>
"complex_attr"(){bar = dense<[(1.000000e+00,0.000000e+00), (2.000000e+00,2.000000e+00)]> : tensor<2xcomplex<f32>>} : () -> ()
return
}

Expand Down
31 changes: 31 additions & 0 deletions mlir/unittests/IR/AttributeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
auto detectedSplat =
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
EXPECT_EQ(detectedSplat, splat);

for (auto newValue : detectedSplat.template getValues<EltTy>())
EXPECT_TRUE(newValue == splatElt);
}

namespace {
Expand Down Expand Up @@ -162,4 +165,32 @@ TEST(DenseSplatTest, StringAttrSplat) {
testSplat(stringType, stringAttr);
}

TEST(DenseComplexTest, ComplexFloatSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}

TEST(DenseComplexTest, ComplexIntSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}

TEST(DenseComplexTest, ComplexAPFloatSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}

TEST(DenseComplexTest, ComplexAPIntSplat) {
MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
}

} // end namespace