diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td index 01fe12a4660af..7aeabd297211a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td @@ -441,6 +441,38 @@ def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_F // ----- +def SPIRV_GLMatrixInverseOp : SPIRV_GLOp<"MatrixInverse", 34, + [Pure, SameOperandsAndResultType]> { + let summary = "Compute the inverse of a matrix"; + + let description = [{ + Result is the inverse of the operand. The operand x must be a square matrix + of floating-point type. + + Result Type and the type of x must be the same type. + + #### Example: + + ```mlir + %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<4 x vector<4xf32>> + ``` + }]; + + let arguments = (ins + SPIRV_MatrixOf:$matrix + ); + + let results = (outs + SPIRV_MatrixOf:$result + ); + + let assemblyFormat = "$matrix attr-dict `:` type($matrix)"; + + let hasVerifier = 1; +} + +// ----- + def SPIRV_GLLogOp : SPIRV_GLUnaryArithmeticOp<"Log", 28, SPIRV_Float16or32> { let summary = "Natural logarithm of the operand"; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 9300483a0f92f..befbb2841fc8b 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2054,6 +2054,19 @@ LogicalResult spirv::GLLdexpOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// spirv.GL.MatrixInverse +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GLMatrixInverseOp::verify() { + auto matrixType = cast(getMatrix().getType()); + if (matrixType.getNumColumns() != matrixType.getNumRows()) + return emitOpError("matrix must be square, got ") + << matrixType.getNumColumns() << " columns and " + << matrixType.getNumRows() << " rows"; + return success(); +} + //===----------------------------------------------------------------------===// // spirv.ShiftLeftLogicalOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir index eea80ca3798a6..fde1bda15589a 100644 --- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir @@ -127,6 +127,32 @@ func.func @inversesqrtvec(%arg0 : vector<3xf16>) -> () { // ----- +//===----------------------------------------------------------------------===// +// spirv.GL.MatrixInverse +//===----------------------------------------------------------------------===// + +func.func @matrix_inverse(%matrix : !spirv.matrix<4 x vector<4xf32>>) -> () { + // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<4 x vector<4xf32>> + %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<4 x vector<4xf32>> + return +} + +func.func @matrix_inverse_2x2(%matrix : !spirv.matrix<2 x vector<2xf32>>) -> () { + // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<2 x vector<2xf32>> + %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<2 x vector<2xf32>> + return +} + +// ----- + +func.func @matrix_inverse_non_square(%matrix : !spirv.matrix<3 x vector<4xf32>>) -> () { + // expected-error @+1 {{matrix must be square, got 3 columns and 4 rows}} + %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<3 x vector<4xf32>> + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.GL.Sqrt //===----------------------------------------------------------------------===//