Expand Up
@@ -1956,29 +1956,13 @@ class TensorLiteralParser {
ArrayRef<int64_t > getShape () const { return shape; }
private:
enum class ElementKind { Boolean , Integer, Float, String };
// / Return a string to represent the given element kind.
const char *getElementKindStr (ElementKind kind) {
switch (kind) {
case ElementKind::Boolean :
return " 'boolean'" ;
case ElementKind::Integer:
return " 'integer'" ;
case ElementKind::Float:
return " 'float'" ;
case ElementKind::String:
return " 'string'" ;
}
llvm_unreachable (" unknown element kind" );
}
// / Build a Dense Integer attribute for the given type.
DenseElementsAttr getIntAttr (llvm::SMLoc loc, ShapedType type, Type eltTy);
// / Get the parsed elements for an integer attribute.
ParseResult getIntAttrElements (llvm::SMLoc loc, Type eltTy,
std::vector<APInt> &intValues);
// / Build a Dense Float attribute for the given type .
DenseElementsAttr getFloatAttr (llvm::SMLoc loc, ShapedType type ,
FloatType eltTy );
// / Get the parsed elements for a float attribute .
ParseResult getFloatAttrElements (llvm::SMLoc loc, FloatType eltTy ,
std::vector<APFloat> &floatValues );
// / Build a Dense String attribute for the given type.
DenseElementsAttr getStringAttr (llvm::SMLoc loc, ShapedType type, Type eltTy);
Expand Down
Expand Up
@@ -2011,9 +1995,6 @@ class TensorLiteralParser {
// / Storage used when parsing elements, this is a pair of <is_negated, token>.
std::vector<std::pair<bool , Token>> storage;
// / A flag that indicates the type of elements that have been parsed.
Optional<ElementKind> knownEltKind;
// / Storage used when parsing elements that were stored as hex values.
Optional<Token> hexStorage;
};
Expand Down
Expand Up
@@ -2041,7 +2022,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
Type eltType = type.getElementType ();
// Check to see if we parse the literal from a hex string.
if (hexStorage.hasValue () && eltType.isIntOrFloat ())
if (hexStorage.hasValue () &&
(eltType.isIntOrFloat () || eltType.isa <ComplexType>()))
return getHexAttr (loc, type);
// Check that the parsed storage size has the same number of elements to the
Expand All
@@ -2052,75 +2034,94 @@ DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
return nullptr ;
}
// If the type is an integer, build a set of APInt values from the storage
// with the correct bitwidth.
if (auto intTy = eltType.dyn_cast <IntegerType >())
return getIntAttr (loc, type, intTy );
if ( auto indexTy = eltType. dyn_cast <IndexType>())
return getIntAttr (loc, type, indexTy);
// Handle complex types in the specific element type cases below.
bool isComplex = false ;
if (ComplexType complexTy = eltType.dyn_cast <ComplexType >()) {
eltType = complexTy. getElementType ( );
isComplex = true ;
}
// If parsing a floating point type.
if (auto floatTy = eltType.dyn_cast <FloatType>())
return getFloatAttr (loc, type, floatTy);
// Handle integer and index types.
if (eltType.isIntOrIndex ()) {
std::vector<APInt> intValues;
if (failed (getIntAttrElements (loc, eltType, intValues)))
return nullptr ;
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::makeArrayRef (
reinterpret_cast <std::complex<APInt> *>(intValues.data ()),
intValues.size () / 2 );
return DenseElementsAttr::get (type, complexData);
}
return DenseElementsAttr::get (type, intValues);
}
// Handle floating point types.
if (FloatType floatTy = eltType.dyn_cast <FloatType>()) {
std::vector<APFloat> floatValues;
if (failed (getFloatAttrElements (loc, floatTy, floatValues)))
return nullptr ;
if (isComplex) {
// If this is a complex, treat the parsed values as complex values.
auto complexData = llvm::makeArrayRef (
reinterpret_cast <std::complex<APFloat> *>(floatValues.data ()),
floatValues.size () / 2 );
return DenseElementsAttr::get (type, complexData);
}
return DenseElementsAttr::get (type, floatValues);
}
// Other types are assumed to be string representations.
return getStringAttr (loc, type, type.getElementType ());
}
// / Build a Dense Integer attribute for the given type.
DenseElementsAttr TensorLiteralParser::getIntAttr (llvm::SMLoc loc,
ShapedType type , Type eltTy) {
std::vector<APInt> intElements;
intElements .reserve (storage.size ());
auto isUintType = type. getElementType () .isUnsignedInteger ();
ParseResult
TensorLiteralParser::getIntAttrElements (llvm::SMLoc loc , Type eltTy,
std::vector<APInt> &intValues) {
intValues .reserve (storage.size ());
bool isUintType = eltTy .isUnsignedInteger ();
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first ;
const Token &token = signAndToken.second ;
auto tokenLoc = token.getLoc ();
if (isNegative && isUintType) {
p.emitError (tokenLoc)
<< " expected unsigned integer elements, but parsed negative value" ;
return nullptr ;
return p.emitError (tokenLoc)
<< " expected unsigned integer elements, but parsed negative value" ;
}
// Check to see if floating point values were parsed.
if (token.is (Token::floatliteral)) {
p.emitError (tokenLoc)
<< " expected integer elements, but parsed floating-point" ;
return nullptr ;
return p.emitError (tokenLoc)
<< " expected integer elements, but parsed floating-point" ;
}
assert (token.isAny (Token::integer, Token::kw_true, Token::kw_false) &&
" unexpected token type" );
if (token.isAny (Token::kw_true, Token::kw_false)) {
if (!eltTy.isInteger (1 )) {
p.emitError (tokenLoc)
<< " expected i1 type for 'true' or 'false' values" ;
return nullptr ;
return p.emitError (tokenLoc)
<< " expected i1 type for 'true' or 'false' values" ;
}
APInt apInt (1 , token.is (Token::kw_true), /* isSigned=*/ false );
intElements .push_back (apInt);
intValues .push_back (apInt);
continue ;
}
// Create APInt values for each element with the correct bitwidth.
Optional<APInt> apInt =
buildAttributeAPInt (eltTy, isNegative, token.getSpelling ());
if (!apInt)
return (p.emitError (tokenLoc, " integer constant out of range for type" ),
nullptr );
intElements.push_back (*apInt);
return p.emitError (tokenLoc, " integer constant out of range for type" );
intValues.push_back (*apInt);
}
return DenseElementsAttr::get (type, intElements);
return success ();
}
// / Build a Dense Float attribute for the given type.
DenseElementsAttr TensorLiteralParser::getFloatAttr (llvm::SMLoc loc,
ShapedType type,
FloatType eltTy) {
std::vector<APFloat> floatValues;
ParseResult
TensorLiteralParser::getFloatAttrElements (llvm::SMLoc loc, FloatType eltTy,
std::vector<APFloat> &floatValues) {
floatValues.reserve (storage.size ());
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first ;
Expand All
@@ -2129,34 +2130,31 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
// Handle hexadecimal float literals.
if (token.is (Token::integer) && token.getSpelling ().startswith (" 0x" )) {
if (isNegative) {
p.emitError (token.getLoc ())
<< " hexadecimal float literal should not have a leading minus" ;
return nullptr ;
return p.emitError (token.getLoc ())
<< " hexadecimal float literal should not have a leading minus" ;
}
auto val = token.getUInt64IntegerValue ();
if (!val.hasValue ()) {
p.emitError (" hexadecimal float constant out of range for attribute " );
return nullptr ;
return p.emitError (
" hexadecimal float constant out of range for attribute " ) ;
}
Optional<APFloat> apVal = buildHexadecimalFloatLiteral (&p, eltTy, *val);
if (!apVal)
return nullptr ;
return failure () ;
floatValues.push_back (*apVal);
continue ;
}
// Check to see if any decimal integers or booleans were parsed.
if (!token.is (Token::floatliteral)) {
p.emitError () << " expected floating-point elements, but parsed integer" ;
return nullptr ;
}
if (!token.is (Token::floatliteral))
return p.emitError ()
<< " expected floating-point elements, but parsed integer" ;
// Build the float values from tokens.
auto val = token.getFloatingPointValue ();
if (!val.hasValue ()) {
p.emitError (" floating point value too large for attribute" );
return nullptr ;
}
if (!val.hasValue ())
return p.emitError (" floating point value too large for attribute" );
// Treat BF16 as double because it is not supported in LLVM's APFloat.
APFloat apVal (isNegative ? -*val : *val);
if (!eltTy.isBF16 () && !eltTy.isF64 ()) {
Expand All
@@ -2166,8 +2164,7 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
}
floatValues.push_back (apVal);
}
return DenseElementsAttr::get (type, floatValues);
return success ();
}
// / Build a Dense String attribute for the given type.
Expand Down
Expand Up
@@ -2196,31 +2193,26 @@ DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
DenseElementsAttr TensorLiteralParser::getHexAttr (llvm::SMLoc loc,
ShapedType type) {
Type elementType = type.getElementType ();
if (!elementType.isa <FloatType>() && !elementType.isa <IntegerType>()) {
p.emitError (loc) << " expected floating-point or integer element type, got "
<< elementType;
if (!elementType.isIntOrIndexOrFloat () && !elementType.isa <ComplexType>()) {
p.emitError (loc)
<< " expected floating-point, integer, or complex element type, got "
<< elementType;
return nullptr ;
}
std::string data;
if (parseElementAttrHexValues (p, hexStorage.getValue (), data))
return nullptr ;
// Check that the size of the hex data corresponds to the size of the type, or
// a splat of the type.
// TODO: bf16 is currently stored as a double, this should be removed when
// APFloat properly supports it.
int64_t elementWidth =
elementType.isBF16 () ? 64 : elementType.getIntOrFloatBitWidth ();
if (static_cast <int64_t >(data.size () * CHAR_BIT) !=
(type.getNumElements () * elementWidth)) {
ArrayRef<char > rawData (data.data (), data.size ());
bool detectedSplat = false ;
if (!DenseElementsAttr::isValidRawBuffer (type, rawData, detectedSplat)) {
p.emitError (loc) << " elements hex data size is invalid for provided type: "
<< type;
return nullptr ;
}
return DenseElementsAttr::getFromRawBuffer (
type, ArrayRef<char >(data.data (), data.size ()), /* isSplatBuffer=*/ false );
return DenseElementsAttr::getFromRawBuffer (type, rawData, detectedSplat);
}
ParseResult TensorLiteralParser::parseElement () {
Expand All
@@ -2247,6 +2239,17 @@ ParseResult TensorLiteralParser::parseElement() {
storage.emplace_back (/* isNegative=*/ false , p.getToken ());
p.consumeToken ();
break ;
// Parse a complex element of the form '(' element ',' element ')'.
case Token::l_paren:
p.consumeToken (Token::l_paren);
if (parseElement () ||
p.parseToken (Token::comma, " expected ',' between complex elements" ) ||
parseElement () ||
p.parseToken (Token::r_paren, " expected ')' after complex elements" ))
return failure ();
break ;
default :
return p.emitError (" expected element literal of primitive type" );
}
Expand Down