diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 8badb84a879fa..c903da7c1b6f7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4346,6 +4346,8 @@ class SPIRV_MatrixOf : def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; +def SPIRV_ScalarOrVectorOrPtrOrCoopMatrix : + AnyTypeOf<[SPIRV_ScalarOrVectorOrPtr, SPIRV_AnyCooperativeMatrix]>; class SPIRV_Vec4 : VectorOfRankAndLengthAndType<[1], [4], [type]>; def SPIRV_IntVec4 : SPIRV_Vec4; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td index f05030dfb5d57..a434764750a56 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -40,11 +40,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> { let summary = "Bit pattern-preserving type conversion."; let description = [{ - Result Type must be an OpTypePointer, or a scalar or vector of - numerical-type. + Result Type must be an OpTypePointer, cooperative matrix, or a scalar or + vector of numerical-type. - Operand must have a type of OpTypePointer, or a scalar or vector of - numerical-type. It must be a different type than Result Type. + Operand must have a type of OpTypePointer, cooperative matrix, or a scalar + or vector of numerical-type. It must be a different type than Result Type. If either Result Type or Operand is a pointer, the other must be a pointer (diverges from the SPIR-V spec). @@ -71,11 +71,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> { }]; let arguments = (ins - SPIRV_ScalarOrVectorOrPtr:$operand + SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$operand ); let results = (outs - SPIRV_ScalarOrVectorOrPtr:$result + SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$result ); let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index a5330dc56d48f..ea8cb8de74536 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -87,6 +87,36 @@ LogicalResult BitcastOp::verify() { if (operandType == resultType) { return emitError("result type must be different from operand type"); } + + auto operandCoopMatrixType = + dyn_cast(operandType); + auto resultCoopMatrixType = + dyn_cast(resultType); + if (operandCoopMatrixType || resultCoopMatrixType) { + if (!operandCoopMatrixType || !resultCoopMatrixType) + return emitError("unhandled bit cast conversion from cooperative matrix " + "type to non-cooperative matrix type"); + + if (operandCoopMatrixType.getRows() != resultCoopMatrixType.getRows() || + operandCoopMatrixType.getColumns() != resultCoopMatrixType.getColumns()) + return emitError("cooperative matrix dimensions must match"); + + if (operandCoopMatrixType.getScope() != resultCoopMatrixType.getScope()) + return emitError("cooperative matrix scope must match"); + + if (operandCoopMatrixType.getUse() != resultCoopMatrixType.getUse()) + return emitError("cooperative matrix use must match"); + + unsigned operandBitWidth = + getBitWidth(operandCoopMatrixType.getElementType()); + unsigned resultBitWidth = + getBitWidth(resultCoopMatrixType.getElementType()); + if (operandBitWidth != resultBitWidth) + return emitOpError("mismatch in result and operand type bitwidth"); + + return success(); + } + if (isa(operandType) && !isa(resultType)) { return emitError( diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir index 4480a1f3720f2..6614a49a1253e 100644 --- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir @@ -40,6 +40,12 @@ func.func @cast6(%arg0 : vector<4xf32>) { return } +func.func @cast_coop_matrix(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA> + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA> + return +} + // ----- func.func @cast1(%arg0 : f32) { @@ -82,6 +88,46 @@ func.func @cast3(%arg0 : i64) { // ----- +func.func @cast_coop_matrix_size_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // expected-error @+1 {{cooperative matrix dimensions must match}} + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x2xi32, Subgroup, MatrixA> + return +} + +// ----- + +func.func @cast_coop_matrix_scope_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // expected-error @+1 {{cooperative matrix scope must match}} + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Workgroup, MatrixA> + return +} + +// ----- + +func.func @cast_coop_matrix_use_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // expected-error @+1 {{cooperative matrix use must match}} + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixB> + return +} + +// ----- + +func.func @cast_coop_matrix_bitwidth_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // expected-error @+1 {{mismatch in result and operand type bitwidth}} + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi64, Subgroup, MatrixA> + return +} + +// ----- + +func.func @cast_coop_to_non_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) { + // expected-error @+1 {{unhandled bit cast conversion from cooperative matrix type to non-cooperative matrix type}} + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to i32 + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.ConvertFToS //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir index 4f29610f928c4..317264d37b8c8 100644 --- a/mlir/test/Target/SPIRV/cast-ops.mlir +++ b/mlir/test/Target/SPIRV/cast-ops.mlir @@ -19,6 +19,16 @@ spirv.module Logical GLSL450 requires #spirv.vce { // ----- +spirv.module Logical Vulkan requires #spirv.vce { + spirv.func @bit_cast_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) "None" { + // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA> + %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA> + spirv.Return + } +} + +// ----- + spirv.module Logical GLSL450 requires #spirv.vce { spirv.func @convert_f_to_s(%arg0 : f32) -> i32 "None" { // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : f32 to i32