Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4071,6 +4071,23 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
SPIRV_KHR_CML_RowMajor, SPIRV_KHR_CML_ColumnMajor
]>;

// Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
def SPIRV_KHR_CMO_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 16>;

def SPIRV_KHR_CooperativeMatrixOperandsAttr :
SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
"valid SPIR-V Cooperative Matrix Operands (KHR)",
"cooperative_matrix_operands_khr", [
SPIRV_KHR_CMO_None, SPIRV_KHR_CMO_MatrixA_Signed,
SPIRV_KHR_CMO_MatrixB_Signed, SPIRV_KHR_CMO_MatrixC_Signed,
SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
]>;

//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4447,6 +4464,7 @@ def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 445
def SPIRV_OC_OpTypeCooperativeMatrixKHR : I32EnumAttrCase<"OpTypeCooperativeMatrixKHR", 4456>;
def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrixLoadKHR", 4457>;
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
Expand Down Expand Up @@ -4548,7 +4566,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat,
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
SPIRV_OC_OpCooperativeMatrixLengthKHR,
SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
SPIRV_OC_OpCooperativeMatrixLengthNV,
Expand Down
108 changes: 107 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,112 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
let results = (outs);
}

// -----

def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMulAdd",
[Pure, AllTypesMatch<["c", "result"]>]> {
let summary = "Returns the result of `(A x B) + C` of matrices A, B, and C";

let description = [{
Linear-algebraic matrix multiply of A by B and then component-wise add C.
The order of the operations is implementation-dependent. The internal
precision of floating-point operations is defined by the client API. Integer
operations used in the multiplication of A by B are performed at the
precision of the Result Type and the resulting value will equal the
low-order N bits of the correct result R, where N is the result width and R
is computed with enough precision to avoid overflow and underflow if the
SaturatingAccumulation Cooperative Matrix Operand is not present. If the
SaturatingAccumulation Cooperative Matrix Operand is present and overflow or
underflow occurs as part of calculating that intermediate result, the result
of the instruction is undefined. Integer additions of the elements of that
intermediate result with those of C are performed at the precision of Result
Type, are exact, and are saturating if the SaturatingAccumulation
Cooperative Matrix Operand is present, with the signedness of the saturation
being that of the components of Result Type. If the SaturatingAccumulation
Cooperative Matrix Operand is not present then the resulting value will
equal the low-order N bits of the correct result R, where N is the result
width and R is computed with enough precision to avoid overflow and
underflow.

Result Type must be a cooperative matrix type with M rows and N columns
whose Use must be MatrixAccumulatorKHR.

A is a cooperative matrix with M rows and K columns whose Use must be
MatrixAKHR.

B is a cooperative matrix with K rows and N columns whose Use must be
MatrixBKHR.

C is a cooperative matrix with M rows and N columns whose Use must be
MatrixAccumulatorKHR.

The values of M, N, and K must be consistent across the result and operands.
This is referred to as an MxNxK matrix multiply.

A, B, C, and Result Type must have the same scope, and this defines the
scope of the operation. A, B, C, and Result Type need not necessarily have
the same component type, this is defined by the client API.

If the Component Type of any matrix operand is an integer type, then its
components are treated as signed if the Matrix{A,B,C,Result}SignedComponents
Cooperative Matrix Operand is present and are treated as unsigned otherwise.

Cooperative Matrix Operands is an optional Cooperative Matrix Operand
literal. If not present, it is the same as specifying the Cooperative Matrix
Operand None.

For a given dynamic instance of this instruction, all invocations in a given
scope instance must be active or all must be inactive (where the scope is
the scope of the operation).

``` {.ebnf}
cooperative-matrixmuladd-op ::= ssa-id `=` `spirv.KHR.CooperativeMatrixMulAdd`
ssa-use `,` ssa-use `,` ssa-use
(`<` matrix-operands `>`)? `:`
a-cooperative-matrix-type `,`
b-cooperative-matrix-type `->`
result-cooperative-matrix-type
```

#### Example:

```
%0 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC :
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
!spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>

%1 = spirv.KHR.CooperativeMatrixMulAdd %matA, %matB, %matC, <ASigned | AccSat> :
!spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
!spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
!spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
```
}];

let assemblyFormat = [{
$a `,` $b `,` $c ( `,` $matrix_operands^ )? attr-dict `:`
type($a) `,` type($b) `->` type($c)
}];

let availability = [
MinVersion<SPIRV_V_1_6>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_KHR_cooperative_matrix]>,
Capability<[SPIRV_C_CooperativeMatrixKHR]>
];

let arguments = (ins
SPIRV_AnyCooperativeMatrix:$a,
SPIRV_AnyCooperativeMatrix:$b,
SPIRV_AnyCooperativeMatrix:$c,
OptionalAttr<SPIRV_KHR_CooperativeMatrixOperandsAttr>:$matrix_operands
);

let results = (outs
SPIRV_AnyCooperativeMatrix:$result
);
}

//===----------------------------------------------------------------------===//
// SPV_NV_cooperative_matrix extension ops.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -380,7 +486,7 @@ def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAd
}];

let assemblyFormat = [{
operands attr-dict`:` type($a) `,` type($b) `->` type($c)
operands attr-dict `:` type($a) `,` type($b) `->` type($c)
}];

let availability = [
Expand Down
55 changes: 55 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
//===----------------------------------------------------------------------===//

#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>

using namespace mlir::spirv::AttrNames;

Expand Down Expand Up @@ -151,6 +154,58 @@ LogicalResult KHRCooperativeMatrixStoreOp::verify() {
getObject().getType());
}

//===----------------------------------------------------------------------===//
// spirv.KHR.CooperativeMatrixMulAdd
//===----------------------------------------------------------------------===//

LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType());
auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType());
auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType());

// Check element types. ODS enforces that `type(c) == type(result)`, so no
// need to check it here.

// Check the 'use' part of the type against the operands and the result.
if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
return emitOpError("operand #0 must be of use 'MatrixA'");
if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
return emitOpError("operand #1 must be of use 'MatrixB'");
if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
return emitOpError("operand #2 must be of use 'MatrixAcc'");

// Check the 'scope' part of the type.
if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
return emitOpError("matrix scope mismatch");

// Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'.
if (typeA.getRows() != typeC.getRows())
return emitOpError("matrix size mismatch on dimension 'M'");
if (typeB.getColumns() != typeC.getColumns())
return emitOpError("matrix size mismatch on dimension 'N'");
if (typeA.getColumns() != typeB.getRows())
return emitOpError("matrix size mismatch on dimension 'K'");

// The spec does not restrict the element types:
// > A, B, C, and Result Type need not necessarily have the same component
// > type, this is defined by the client API.

// Check that if Cooperative Matrix Operands are provided, the element type
// is integer.
if (getMatrixOperands()) {
Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
typeC.getElementType()};
if (!llvm::all_of(elementTypes,
[](Type ty) { return isa<IntegerType>(ty); })) {
return emitOpError("Matrix Operands require all matrix element types to "
"be Integer Types");
}
}

// Any further requirements need to be checked against VCE.
return success();
}

//===----------------------------------------------------------------------===//
// spirv.NV.CooperativeMatrixLength
//===----------------------------------------------------------------------===//
Expand Down
Loading