diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index ead6c0341cd69d..a95ed18ca4e7b0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3109,6 +3109,7 @@ def SPV_OC_OpTypeBool : I32EnumAttrCase<"OpTypeBool", 20>; def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; +def SPV_OC_OpTypeMatrix : I32EnumAttrCase<"OpTypeMatrix", 24>; def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; @@ -3250,15 +3251,15 @@ def SPV_OpcodeAttr : SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, - SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, - SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, - SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, - SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, - SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, - SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, - SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, - SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, + SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeMatrix, + SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, + SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, + SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, + SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, + SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, + SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, + SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, + SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 71eba72e5e84d0..b7180399a837fe 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -13,6 +13,8 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ #define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" @@ -56,9 +58,11 @@ namespace detail { struct ArrayTypeStorage; struct CooperativeMatrixTypeStorage; struct ImageTypeStorage; +struct MatrixTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; struct StructTypeStorage; + } // namespace detail namespace TypeKind { @@ -66,6 +70,7 @@ enum Kind { Array = Type::FIRST_SPIRV_TYPE, CooperativeMatrix, Image, + Matrix, Pointer, RuntimeArray, Struct, @@ -366,6 +371,36 @@ class CooperativeMatrixNVType Optional storage = llvm::None); }; +// SPIR-V matrix type +class MatrixType : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; } + + static MatrixType get(Type columnType, uint32_t columnCount); + + static MatrixType getChecked(Type columnType, uint32_t columnCount, + Location location); + + static LogicalResult verifyConstructionInvariants(Location loc, + Type columnType, + uint32_t columnCount); + + /// Returns true if the matrix elements are vectors of float elements + static bool isValidColumnType(Type columnType); + + Type getElementType() const; + + unsigned getNumElements() const; + + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage = llvm::None); + void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage = llvm::None); +}; + } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 8c4d0ebe99a73c..455064f58ce699 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -116,8 +116,8 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); + addTypes(); addAttributes(); @@ -197,6 +197,42 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, return type; } +static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + Type type; + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + if (parser.parseType(type)) + return Type(); + + if (auto t = type.dyn_cast()) { + if (t.getRank() != 1) { + parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; + return Type(); + } + if (t.getNumElements() > 4 || t.getNumElements() < 2) { + parser.emitError(typeLoc, + "matrix columns size has to be less than or equal " + "to 4 and greater than or equal 2, but found ") + << t.getNumElements(); + return Type(); + } + + if (!t.getElementType().isa()) { + parser.emitError(typeLoc, "matrix columns' elements must be of " + "Float type, got ") + << t.getElementType(); + return Type(); + } + } else { + parser.emitError(typeLoc, "matrix must be composed using vector " + "type, got ") + << type; + return Type(); + } + + return type; +} + /// Parses an optional `, stride = N` assembly segment. If no parsing failure /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if /// missing. @@ -279,7 +315,7 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, return Type(); if (dims.size() != 2) { - parser.emitError(countLoc, "expected rows and columns size."); + parser.emitError(countLoc, "expected rows and columns size"); return Type(); } @@ -350,6 +386,40 @@ static Type parseRuntimeArrayType(SPIRVDialect const &dialect, return RuntimeArrayType::get(elementType, stride); } +// matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>` +static Type parseMatrixType(SPIRVDialect const &dialect, + DialectAsmParser &parser) { + if (parser.parseLess()) + return Type(); + + SmallVector countDims; + llvm::SMLoc countLoc = parser.getCurrentLocation(); + if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) + return Type(); + if (countDims.size() != 1) { + parser.emitError(countLoc, "expected single unsigned " + "integer for number of columns"); + return Type(); + } + + int64_t columnCount = countDims[0]; + // According to the specification, Matrices can have 2, 3, or 4 columns + if (columnCount < 2 || columnCount > 4) { + parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 " + "columns"); + return Type(); + } + + Type columnType = parseAndVerifyMatrixType(dialect, parser); + if (!columnType) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return MatrixType::get(columnType, columnCount); +} + // Specialize this function to parse each of the parameters that define an // ImageType. By default it assumes this is an enum type. template @@ -567,7 +637,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const { return parseRuntimeArrayType(*this, parser); if (keyword == "struct") return parseStructType(*this, parser); - + if (keyword == "matrix") + return parseMatrixType(*this, parser); parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; return Type(); } @@ -635,6 +706,11 @@ static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { os << ">"; } +static void print(MatrixType type, DialectAsmPrinter &os) { + os << "matrix<" << type.getNumElements() << " x " << type.getElementType(); + os << ">"; +} + void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { case TypeKind::Array: @@ -655,6 +731,9 @@ void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { case TypeKind::Struct: print(type.cast(), os); return; + case TypeKind::Matrix: + print(type.cast(), os); + return; default: llvm_unreachable("unhandled SPIR-V type"); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 49b39ec7843530..4ba17f3a1240fc 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -159,6 +159,7 @@ bool CompositeType::classof(Type type) { switch (type.getKind()) { case TypeKind::Array: case TypeKind::CooperativeMatrix: + case TypeKind::Matrix: case TypeKind::RuntimeArray: case TypeKind::Struct: return true; @@ -180,6 +181,8 @@ Type CompositeType::getElementType(unsigned index) const { return cast().getElementType(); case spirv::TypeKind::CooperativeMatrix: return cast().getElementType(); + case spirv::TypeKind::Matrix: + return cast().getElementType(); case spirv::TypeKind::RuntimeArray: return cast().getElementType(); case spirv::TypeKind::Struct: @@ -198,6 +201,8 @@ unsigned CompositeType::getNumElements() const { case spirv::TypeKind::CooperativeMatrix: llvm_unreachable( "invalid to query number of elements of spirv::CooperativeMatrix type"); + case spirv::TypeKind::Matrix: + return cast().getNumElements(); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -230,6 +235,9 @@ void CompositeType::getExtensions( case spirv::TypeKind::CooperativeMatrix: cast().getExtensions(extensions, storage); break; + case spirv::TypeKind::Matrix: + cast().getExtensions(extensions, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getExtensions(extensions, storage); break; @@ -255,6 +263,9 @@ void CompositeType::getCapabilities( case spirv::TypeKind::CooperativeMatrix: cast().getCapabilities(capabilities, storage); break; + case spirv::TypeKind::Matrix: + cast().getCapabilities(capabilities, storage); + break; case spirv::TypeKind::RuntimeArray: cast().getCapabilities(capabilities, storage); break; @@ -823,10 +834,12 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, scalarType.getExtensions(extensions, storage); } else if (auto compositeType = dyn_cast()) { compositeType.getExtensions(extensions, storage); - } else if (auto ptrType = dyn_cast()) { - ptrType.getExtensions(extensions, storage); } else if (auto imageType = dyn_cast()) { imageType.getExtensions(extensions, storage); + } else if (auto matrixType = dyn_cast()) { + matrixType.getExtensions(extensions, storage); + } else if (auto ptrType = dyn_cast()) { + ptrType.getExtensions(extensions, storage); } else { llvm_unreachable("invalid SPIR-V Type to getExtensions"); } @@ -839,10 +852,12 @@ void SPIRVType::getCapabilities( scalarType.getCapabilities(capabilities, storage); } else if (auto compositeType = dyn_cast()) { compositeType.getCapabilities(capabilities, storage); - } else if (auto ptrType = dyn_cast()) { - ptrType.getCapabilities(capabilities, storage); } else if (auto imageType = dyn_cast()) { imageType.getCapabilities(capabilities, storage); + } else if (auto matrixType = dyn_cast()) { + matrixType.getCapabilities(capabilities, storage); + } else if (auto ptrType = dyn_cast()) { + ptrType.getCapabilities(capabilities, storage); } else { llvm_unreachable("invalid SPIR-V Type to getCapabilities"); } @@ -1000,3 +1015,89 @@ void StructType::getCapabilities( for (Type elementType : getElementTypes()) elementType.cast().getCapabilities(capabilities, storage); } + +//===----------------------------------------------------------------------===// +// MatrixType +//===----------------------------------------------------------------------===// + +struct spirv::detail::MatrixTypeStorage : public TypeStorage { + MatrixTypeStorage(Type columnType, uint32_t columnCount) + : TypeStorage(), columnType(columnType), columnCount(columnCount) {} + + using KeyTy = std::tuple; + + static MatrixTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + + // Initialize the memory using placement new. + return new (allocator.allocate()) + MatrixTypeStorage(std::get<0>(key), std::get<1>(key)); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(columnType, columnCount); + } + + Type columnType; + const uint32_t columnCount; +}; + +MatrixType MatrixType::get(Type columnType, uint32_t columnCount) { + return Base::get(columnType.getContext(), TypeKind::Matrix, columnType, + columnCount); +} + +MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount, + Location location) { + return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount); +} + +LogicalResult MatrixType::verifyConstructionInvariants(Location loc, + Type columnType, + uint32_t columnCount) { + if (columnCount < 2 || columnCount > 4) + return emitError(loc, "matrix can have 2, 3, or 4 columns only"); + + if (!isValidColumnType(columnType)) + return emitError(loc, "matrix columns must be vectors of floats"); + + /// The underlying vectors (columns) must be of size 2, 3, or 4 + ArrayRef columnShape = columnType.cast().getShape(); + if (columnShape.size() != 1) + return emitError(loc, "matrix columns must be 1D vectors"); + + if (columnShape[0] < 2 || columnShape[0] > 4) + return emitError(loc, "matrix columns must be of size 2, 3, or 4"); + + return success(); +} + +/// Returns true if the matrix elements are vectors of float elements +bool MatrixType::isValidColumnType(Type columnType) { + if (auto vectorType = columnType.dyn_cast()) { + if (vectorType.getElementType().isa()) + return true; + } + return false; +} + +Type MatrixType::getElementType() const { return getImpl()->columnType; } + +unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; } + +void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, + Optional storage) { + getElementType().cast().getExtensions(extensions, storage); +} + +void MatrixType::getCapabilities( + SPIRVType::CapabilityArrayRefVector &capabilities, + Optional storage) { + { + static const Capability caps[] = {Capability::Matrix}; + ArrayRef ref(caps, llvm::array_lengthof(caps)); + capabilities.push_back(ref); + } + // Add any capabilities associated with the underlying vectors (i.e., columns) + getElementType().cast().getCapabilities(capabilities, storage); +} diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 87f233580b75ac..750dddfa6dc4f3 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -225,6 +225,8 @@ class Deserializer { LogicalResult processStructType(ArrayRef operands); + LogicalResult processMatrixType(ArrayRef operands); + //===--------------------------------------------------------------------===// // Constant //===--------------------------------------------------------------------===// @@ -1170,6 +1172,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return processRuntimeArrayType(operands); case spirv::Opcode::OpTypeStruct: return processStructType(operands); + case spirv::Opcode::OpTypeMatrix: + return processMatrixType(operands); default: return emitError(unknownLoc, "unhandled type instruction"); } @@ -1333,6 +1337,25 @@ LogicalResult Deserializer::processStructType(ArrayRef operands) { return success(); } +LogicalResult Deserializer::processMatrixType(ArrayRef operands) { + if (operands.size() != 3) { + // Three operands are needed: result_id, column_type, and column_count + return emitError(unknownLoc, "OpTypeMatrix must have 3 operands" + " (result_id, column_type, and column_count)"); + } + // Matrix columns must be of vector type + Type elementTy = getType(operands[1]); + if (!elementTy) { + return emitError(unknownLoc, + "OpTypeMatrix references undefined column type.") + << operands[1]; + } + + uint32_t colsCount = operands[2]; + typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount); + return success(); +} + //===----------------------------------------------------------------------===// // Constant //===----------------------------------------------------------------------===// @@ -2238,6 +2261,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpTypeInt: case spirv::Opcode::OpTypeFloat: case spirv::Opcode::OpTypeVector: + case spirv::Opcode::OpTypeMatrix: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: case spirv::Opcode::OpTypeRuntimeArray: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 8ea0c4f4711bfe..0b1c970589b122 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1111,6 +1111,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return success(); } + if (auto matrixType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypeMatrix; + operands.push_back(elementTypeID); + operands.push_back(matrixType.getNumElements()); + return success(); + } + // TODO(ravishankarm) : Handle other types. return emitError(loc, "unhandled type in serialization: ") << type; } diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir new file mode 100644 index 00000000000000..b27702bf50d8b1 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + spv.func @matrix_type(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32) "None" { + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, StorageBuffer> + %2 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, StorageBuffer> + spv.Return + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr>, StorageBuffer> + spv.globalVariable @var0 : !spv.ptr>, StorageBuffer> + + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr>, StorageBuffer> + spv.globalVariable @var1 : !spv.ptr>, StorageBuffer> + + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr>, StorageBuffer> + spv.globalVariable @var2 : !spv.ptr>, StorageBuffer> +} diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index 697177b0b98e74..1d1a1868ea3c5f 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -347,3 +347,87 @@ func @missing_scope(!spv.coopmatrix<8x16xi32>) -> () // expected-error @+1 {{expected rows and columns size}} func @missing_count(!spv.coopmatrix<8xi32, Subgroup>) -> () +// ----- + +//===----------------------------------------------------------------------===// +// Matrix +//===----------------------------------------------------------------------===// +// CHECK: func @matrix_type(!spv.matrix<2 x vector<2xf16>>) +func @matrix_type(!spv.matrix<2 x vector<2xf16>>) -> () + +// ----- + +// CHECK: func @matrix_type(!spv.matrix<3 x vector<3xf32>>) +func @matrix_type(!spv.matrix<3 x vector<3xf32>>) -> () + +// ----- + +// CHECK: func @matrix_type(!spv.matrix<4 x vector<4xf16>>) +func @matrix_type(!spv.matrix<4 x vector<4xf16>>) -> () + +// ----- + +// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}} +func @matrix_invalid_size(!spv.matrix<5 x vector<3xf32>>) -> () + +// ----- + +// expected-error @+1 {{matrix is expected to have 2, 3, or 4 columns}} +func @matrix_invalid_size(!spv.matrix<1 x vector<3xf32>>) -> () + +// ----- + +// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 5}} +func @matrix_invalid_columns_size(!spv.matrix<3 x vector<5xf32>>) -> () + +// ----- + +// expected-error @+1 {{matrix columns size has to be less than or equal to 4 and greater than or equal 2, but found 1}} +func @matrix_invalid_columns_size(!spv.matrix<3 x vector<1xf32>>) -> () + +// ----- + +// expected-error @+1 {{expected '<'}} +func @matrix_invalid_format(!spv.matrix 3 x vector<3xf32>>) -> () + +// ----- + +// expected-error @+1 {{unbalanced ')' character in pretty dialect name}} +func @matrix_invalid_format(!spv.matrix< 3 x vector<3xf32>) -> () + +// ----- + +// expected-error @+1 {{expected 'x' in dimension list}} +func @matrix_invalid_format(!spv.matrix<2 vector<3xi32>>) -> () + +// ----- + +// expected-error @+1 {{matrix must be composed using vector type, got 'i32'}} +func @matrix_invalid_type(!spv.matrix< 3 x i32>) -> () + +// ----- + +// expected-error @+1 {{matrix must be composed using vector type, got '!spv.array<16 x f32>'}} +func @matrix_invalid_type(!spv.matrix< 3 x !spv.array<16 x f32>>) -> () + +// ----- + +// expected-error @+1 {{matrix must be composed using vector type, got '!spv.rtarray'}} +func @matrix_invalid_type(!spv.matrix< 3 x !spv.rtarray>) -> () + +// ----- + +// expected-error @+1 {{matrix columns' elements must be of Float type, got 'i32'}} +func @matrix_invalid_type(!spv.matrix<2 x vector<3xi32>>) -> () + +// ----- + +// expected-error @+1 {{expected single unsigned integer for number of columns}} +func @matrix_size_type(!spv.matrix< x vector<3xi32>>) -> () + +// ----- + +// expected-error @+1 {{expected single unsigned integer for number of columns}} +func @matrix_size_type(!spv.matrix<2.0 x vector<3xi32>>) -> () + +// ----- \ No newline at end of file