82 changes: 41 additions & 41 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@

// -----

def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
[NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix";

let description = [{
Return number of components owned by the current work-item in
Return number of components owned by the current work-item in
a joint matrix.

Result Type must be an 32-bit unsigned integer type scalar.

Type is a joint matrix type.

``` {.ebnf}
joint-matrix-length-op ::= ssa-id `=` `spv.JointMatrixWorkItemLengthINTEL
joint-matrix-length-op ::= ssa-id `=` `spv.INTEL.JointMatrixWorkItemLength
` : ` joint-matrix-type
```

For example:

```
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<Subgroup, i32, 8, 16>
%0 = spv.INTEL.JointMatrixWorkItemLength : !spv.jointmatrix<Subgroup, i32, 8, 16>
```
}];

Expand All @@ -60,34 +60,34 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE

// -----

def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
let summary = "See extension SPV_INTEL_joint_matrix";

let description = [{
Load a matrix through a pointer.

Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.

Pointer is the pointer to load through. It specifies start of memory region where
Pointer is the pointer to load through. It specifies start of memory region where
elements of the matrix are stored and arranged according to Layout.

Stride is the number of elements in memory between beginnings of successive rows,
Stride is the number of elements in memory between beginnings of successive rows,
columns (or words) in the result. It must be a scalar integer type.

Layout indicates how the values loaded from memory are arranged. It must be the
Layout indicates how the values loaded from memory are arranged. It must be the
result of a constant instruction.

Scope is syncronization scope for operation on the matrix. It must be the result
Scope is syncronization scope for operation on the matrix. It must be the result
of a constant instruction with scalar integer type.

If present, any Memory Operands must begin with a memory operand literal. If not
If present, any Memory Operands must begin with a memory operand literal. If not
present, it is the same as specifying the memory operand None.

#### Example:
```mlir
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride
{memory_access = #spv.memory_access<Volatile>} :
(!spv.ptr<i32, CrossWorkgroup>, i32) ->
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
```
}];
Expand Down Expand Up @@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {

// -----

def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix";

let description = [{
Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is
Multiply matrix A by matrix B and add matrix C to the result
of the multiplication: A*B+C. Here A is a M x K matrix, B is
a K x N matrix and C is a M x N matrix.

Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be
Behavior is undefined if sizes of operands do not meet the
conditions above. All operands and the Result Type must be
OpTypeJointMatrixINTEL.

A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count
A must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to M and Column Count
equals to K

B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count
B must be a OpTypeJointMatrixINTEL whose Component Type is a
signed numerical type, Row Count equals to K and Column Count
equals to N

C and Result Type must be a OpTypeJointMatrixINTEL with Row
C and Result Type must be a OpTypeJointMatrixINTEL with Row
Count equals to M and Column Count equals to N

Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.

#### Example:
```mlir
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c :
!spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
!spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
```

Expand Down Expand Up @@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",

// -----

def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
let summary = "See extension SPV_INTEL_joint_matrix";

let description = [{
Store a matrix through a pointer.

Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must
Pointer is the pointer to store through. It specifies
start of memory region where elements of the matrix must
be stored and arranged according to Layout.

Object is the matrix to store. It must be
Object is the matrix to store. It must be
OpTypeJointMatrixINTEL.

Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must
Stride is the number of elements in memory between beginnings
of successive rows, columns (or words) of the Object. It must
be a scalar integer type.

Layout indicates how the values stored to memory are arranged.
Layout indicates how the values stored to memory are arranged.
It must be the result of a constant instruction.

Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
Scope is syncronization scope for operation on the matrix.
It must be the result of a constant instruction with scalar
integer type.

If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory
If present, any Memory Operands must begin with a memory operand
literal. If not present, it is the same as specifying the memory
operand None.

#### Example:
```mlir
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
spv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride
{memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
```

Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"

// -----

def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
let summary = "TBD";

let description = [{
Expand All @@ -27,13 +27,13 @@ def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
<!-- End of AutoGen section -->

```
assumetruekhr-op ::= `spv.AssumeTrueKHR` ssa-use
assumetruekhr-op ::= `spv.KHR.AssumeTrue` ssa-use
```mlir

#### Example:

```
spv.AssumeTrueKHR %arg
spv.KHR.AssumeTrue %arg
```
}];

Expand Down
60 changes: 30 additions & 30 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1332,18 +1332,18 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
}

//===----------------------------------------------------------------------===//
// spv.AtomicFAddEXTOp
// spv.EXT.AtomicFAddOp
//===----------------------------------------------------------------------===//

LogicalResult spirv::AtomicFAddEXTOp::verify() {
LogicalResult spirv::EXTAtomicFAddOp::verify() {
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
}

ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return ::parseAtomicUpdateOp(parser, result, true);
}
void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
::printAtomicUpdateOp(*this, p);
}

Expand Down Expand Up @@ -2643,10 +2643,10 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
}

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockReadINTEL
// spv.INTEL.SubgroupBlockRead
//===----------------------------------------------------------------------===//

ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
Expand All @@ -2669,22 +2669,22 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
return success();
}

void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << " : " << getType();
}

LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();

return success();
}

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockWriteINTEL
// spv.INTEL.SubgroupBlockWrite
//===----------------------------------------------------------------------===//

ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
Expand All @@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
return success();
}

void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << ", " << value() << " : " << value().getType();
}

LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();

Expand Down Expand Up @@ -3813,10 +3813,10 @@ LogicalResult spirv::VectorShuffleOp::verify() {
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixLoadNV
// spv.NV.CooperativeMatrixLoad
//===----------------------------------------------------------------------===//

ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Expand All @@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
return success();
}

void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
// Print optional memory access attribute.
if (auto memAccess = memory_access())
Expand All @@ -3865,16 +3865,16 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
return success();
}

LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
result().getType());
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixStoreNV
// spv.NV.CooperativeMatrixStore
//===----------------------------------------------------------------------===//

ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Expand All @@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
return success();
}

void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
<< columnmajor();
// Print optional memory access attribute.
Expand All @@ -3905,17 +3905,17 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
}

LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
object().getType());
}

//===----------------------------------------------------------------------===//
// spv.CooperativeMatrixMulAddNV
// spv.NV.CooperativeMatrixMulAdd
//===----------------------------------------------------------------------===//

static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
Expand All @@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success();
}

LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAdd(*this);
}

Expand All @@ -3960,28 +3960,28 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
}

//===----------------------------------------------------------------------===//
// spv.JointMatrixLoadINTEL
// spv.INTEL.JointMatrixLoad
//===----------------------------------------------------------------------===//

LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType());
}

//===----------------------------------------------------------------------===//
// spv.JointMatrixStoreINTEL
// spv.INTEL.JointMatrixStore
//===----------------------------------------------------------------------===//

LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType());
}

//===----------------------------------------------------------------------===//
// spv.JointMatrixMadINTEL
// spv.INTEL.JointMatrixMad
//===----------------------------------------------------------------------===//

static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
Expand All @@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
return success();
}

LogicalResult spirv::JointMatrixMadINTELOp::verify() {
LogicalResult spirv::INTELJointMatrixMadOp::verify() {
return verifyJointMatrixMad(*this);
}

Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -240,35 +240,35 @@ func.func @atomic_xor(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32
// -----

//===----------------------------------------------------------------------===//
// spv.AtomicFAddEXT
// spv.EXT.AtomicFAdd
//===----------------------------------------------------------------------===//

func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f32) -> f32 {
// CHECK: spv.AtomicFAddEXT "Device" "None" %{{.*}}, %{{.*}} : !spv.ptr<f32, StorageBuffer>
%0 = spv.AtomicFAddEXT "Device" "None" %ptr, %value : !spv.ptr<f32, StorageBuffer>
// CHECK: spv.EXT.AtomicFAdd "Device" "None" %{{.*}}, %{{.*}} : !spv.ptr<f32, StorageBuffer>
%0 = spv.EXT.AtomicFAdd "Device" "None" %ptr, %value : !spv.ptr<f32, StorageBuffer>
return %0 : f32
}

// -----

func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32 {
// expected-error @+1 {{pointer operand must point to an float value, found 'i32'}}
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
%0 = "spv.EXT.AtomicFAdd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
return %0 : f32
}

// -----

func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f64) -> f64 {
// expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'f32', but found 'f64'}}
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Device>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
%0 = "spv.EXT.AtomicFAdd"(%ptr, %value) {memory_scope = #spv.scope<Device>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
return %0 : f64
}

// -----

func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f32) -> f32 {
// expected-error @+1 {{expected at most one of these four memory constraints to be set: `Acquire`, `Release`,`AcquireRelease` or `SequentiallyConsistent`}}
%0 = spv.AtomicFAddEXT "Device" "Acquire|Release" %ptr, %value : !spv.ptr<f32, StorageBuffer>
%0 = spv.EXT.AtomicFAdd "Device" "Acquire|Release" %ptr, %value : !spv.ptr<f32, StorageBuffer>
return %0 : f32
}
46 changes: 23 additions & 23 deletions mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,51 @@

// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}

// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
// CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
// CHECK: spv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
spv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
// CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: spv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_length
spv.func @cooperative_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLength : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLength : !spv.coopmatrix<8x16xi32, Subgroup>
spv.ReturnValue %0 : i32
}

// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x32xi8, Subgroup>, %b : !spv.coopmatrix<32x8xi8, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<8x32xi8, Subgroup>, !spv.coopmatrix<32x8xi8, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

Expand Down Expand Up @@ -112,47 +112,47 @@ spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32
// -----

spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<16x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}}
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<16x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
// expected-error @+1 {{'spv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<16x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

// -----

spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<8x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}}
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
// expected-error @+1 {{'spv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<8x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

// -----

spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Workgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix scope must match}}
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Workgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
// expected-error @+1 {{'spv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Workgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

// -----

spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// expected-error @+1 {{matrix element type must match}}
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xf32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<8x16xf32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

// -----

spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, StorageBuffer>, %stride : i32, %b : i1) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<!spv.struct<(f32 [0])>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spv.ptr<!spv.struct<(f32 [0])>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// -----

spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
// expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
30 changes: 15 additions & 15 deletions mlir/test/Dialect/SPIRV/IR/group-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

//===----------------------------------------------------------------------===//
// spv.SubgroupBallotKHR
// spv.KHR.SubgroupBallot
//===----------------------------------------------------------------------===//

func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32>
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
// CHECK: %{{.*}} = spv.KHR.SubgroupBallot %{{.*}} : vector<4xi32>
%0 = spv.KHR.SubgroupBallot %predicate: vector<4xi32>
return %0: vector<4xi32>
}

Expand Down Expand Up @@ -65,50 +65,50 @@ func.func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi
// -----

//===----------------------------------------------------------------------===//
// spv.SubgroupBallotKHR
// spv.KHR.SubgroupBallot
//===----------------------------------------------------------------------===//

func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
%0 = spv.KHR.SubgroupBallot %predicate: vector<4xi32>
return %0: vector<4xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockReadINTEL
// spv.INTEL.SubgroupBlockRead
//===----------------------------------------------------------------------===//

func.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
// CHECK: spv.INTEL.SubgroupBlockRead %{{.*}} : i32
%0 = spv.INTEL.SubgroupBlockRead "StorageBuffer" %ptr : i32
return %0: i32
}

// -----

func.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
// CHECK: spv.INTEL.SubgroupBlockRead %{{.*}} : vector<3xi32>
%0 = spv.INTEL.SubgroupBlockRead "StorageBuffer" %ptr : vector<3xi32>
return %0: vector<3xi32>
}

// -----

//===----------------------------------------------------------------------===//
// spv.SubgroupBlockWriteINTEL
// spv.INTEL.SubgroupBlockWrite
//===----------------------------------------------------------------------===//

func.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
// CHECK: spv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32
spv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32
return
}

// -----

func.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
// CHECK: spv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32>
spv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32>
return
}
46 changes: 23 additions & 23 deletions mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,98 +2,98 @@

// CHECK-LABEL: @joint_matrix_load
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
spv.Return
}

// -----
// CHECK-LABEL: @joint_matrix_load_memaccess
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, CrossWorkgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, CrossWorkgroup>, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
spv.Return
}

// CHECK-LABEL: @joint_matrix_load_diff_ptr_type
spv.func @joint_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<vector<4xi32>, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>
spv.Return
}

// CHECK-LABEL: @joint_matrix_store
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
// CHECK: spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Workgroup>, i32)
spv.Return
}

// CHECK-LABEL: @joint_matrix_store_memaccess
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
// CHECK: spv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> {{%.*}}, {{%.*}}, {{%.*}} {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.INTEL.JointMatrixStore <Subgroup> <ColumnMajor> %ptr, %m, %stride {Volatile} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.Return
}

// CHECK-LABEL: @joint_matrix_length
spv.func @joint_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixWorkItemLength : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
%0 = spv.INTEL.JointMatrixWorkItemLength : !spv.jointmatrix<8x16xi32, PackedB, Subgroup>
spv.ReturnValue %0 : i32
}

// CHECK-LABEL: @joint_matrix_muladd
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, %b : !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
// expected-error @+1 {{'spv.INTEL.JointMatrixMad' op matrix size must match}}
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix size must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
// expected-error @+1 {{'spv.INTEL.JointMatrixMad' op matrix size must match}}
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{'spv.JointMatrixMadINTEL' op matrix scope must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
// expected-error @+1 {{'spv.INTEL.JointMatrixMad' op matrix scope must match}}
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// expected-error @+1 {{matrix element type must match}}
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, %stride : i32) "None" {
// expected-error @+1 {{Pointer must point to a scalar or vector type}}
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<!spv.struct<(f32 [0])>, Workgroup>, i32)-> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}

// -----

spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32) "None" {
// expected-error @+1 {{Pointer storage class must be Workgroup or CrossWorkgroup}}
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Function>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SPIRV/IR/misc-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func.func @undef() -> () {
// -----

func.func @assume_true(%arg : i1) -> () {
// CHECK: spv.AssumeTrueKHR %{{.*}}
spv.AssumeTrueKHR %arg
// CHECK: spv.KHR.AssumeTrue %{{.*}}
spv.KHR.AssumeTrue %arg
spv.Return
}

Expand All @@ -41,6 +41,6 @@ func.func @assume_true(%arg : i1) -> () {
func.func @assume_true(%arg : f32) -> () {
// expected-error @+2{{use of value '%arg' expects different type than prior uses: 'i1' vs 'f32'}}
// expected-note @-2 {{prior use here}}
spv.AssumeTrueKHR %arg
spv.KHR.AssumeTrue %arg
spv.Return
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SPIRV/IR/target-env.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under
// GroupNonUniform capability.

// spv.SubgroupBallotKHR is available under in all SPIR-V versions under
// spv.KHR.SubgroupBallot is available under in all SPIR-V versions under
// SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.

// The GeometryPointSize capability implies the Geometry capability, which
Expand Down Expand Up @@ -130,7 +130,7 @@ func.func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attr
func.func @subgroup_ballot_suitable_extension(%predicate: i1) -> vector<4xi32> attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [SubgroupBallotKHR], [SPV_KHR_shader_ballot]>, #spv.resource_limits<>>
} {
// CHECK: spv.SubgroupBallotKHR
// CHECK: spv.KHR.SubgroupBallot
%0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
return %0: vector<4xi32>
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ spv.module Logical GLSL450 attributes {
//===----------------------------------------------------------------------===//

// Test deducing minimal extensions.
// spv.SubgroupBallotKHR requires the SPV_KHR_shader_ballot extension.
// spv.KHR.SubgroupBallot requires the SPV_KHR_shader_ballot extension.

// CHECK: requires #spv.vce<v1.0, [SubgroupBallotKHR, Shader], [SPV_KHR_shader_ballot]>
spv.module Logical GLSL450 attributes {
Expand All @@ -159,7 +159,7 @@ spv.module Logical GLSL450 attributes {
[SPV_KHR_shader_ballot, SPV_KHR_shader_clock, SPV_KHR_variable_pointers]>, #spv.resource_limits<>>
} {
spv.func @subgroup_ballot(%predicate : i1) -> vector<4xi32> "None" {
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
%0 = spv.KHR.SubgroupBallot %predicate: vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Target/SPIRV/atomic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {

// CHECK-LABEL: @test_float_atomics
spv.func @test_float_atomics(%ptr: !spv.ptr<f32, Workgroup>, %value: f32) -> f32 "None" {
// CHECK: spv.AtomicFAddEXT "Workgroup" "Acquire" %{{.*}}, %{{.*}} : !spv.ptr<f32, Workgroup>
%0 = spv.AtomicFAddEXT "Workgroup" "Acquire" %ptr, %value : !spv.ptr<f32, Workgroup>
// CHECK: spv.EXT.AtomicFAdd "Workgroup" "Acquire" %{{.*}}, %{{.*}} : !spv.ptr<f32, Workgroup>
%0 = spv.EXT.AtomicFAdd "Workgroup" "Acquire" %ptr, %value : !spv.ptr<f32, Workgroup>
spv.ReturnValue %0: f32
}
}
24 changes: 12 additions & 12 deletions mlir/test/Target/SPIRV/cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,43 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
// CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
// CHECK: spv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
spv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
// CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: spv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}

// CHECK-LABEL: @cooperative_matrix_length
spv.func @cooperative_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.CooperativeMatrixLengthNV : !spv.coopmatrix<8x16xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixLength : !spv.coopmatrix<8x16xi32, Subgroup>
%0 = spv.NV.CooperativeMatrixLength : !spv.coopmatrix<8x16xi32, Subgroup>
spv.ReturnValue %0 : i32
}

// CHECK-LABEL: @cooperative_matrix_muladd
spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.CooperativeMatrixMulAddNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
// CHECK: {{%.*}} = spv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
%r = spv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spv.coopmatrix<8x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>
spv.Return
}

Expand Down
20 changes: 10 additions & 10 deletions mlir/test/Target/SPIRV/group-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @subgroup_ballot
spv.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> "None" {
// CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}}: vector<4xi32>
%0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
// CHECK: %{{.*}} = spv.KHR.SubgroupBallot %{{.*}}: vector<4xi32>
%0 = spv.KHR.SubgroupBallot %predicate: vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
// CHECK-LABEL: @group_broadcast_1
Expand All @@ -21,26 +21,26 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @subgroup_block_read_intel
spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : i32
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
// CHECK: spv.INTEL.SubgroupBlockRead %{{.*}} : i32
%0 = spv.INTEL.SubgroupBlockRead "StorageBuffer" %ptr : i32
spv.ReturnValue %0: i32
}
// CHECK-LABEL: @subgroup_block_read_intel_vector
spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
// CHECK: spv.SubgroupBlockReadINTEL %{{.*}} : vector<3xi32>
%0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
// CHECK: spv.INTEL.SubgroupBlockRead %{{.*}} : vector<3xi32>
%0 = spv.INTEL.SubgroupBlockRead "StorageBuffer" %ptr : vector<3xi32>
spv.ReturnValue %0: vector<3xi32>
}
// CHECK-LABEL: @subgroup_block_write_intel
spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : i32
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
// CHECK: spv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : i32
spv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : i32
spv.Return
}
// CHECK-LABEL: @subgroup_block_write_intel_vector
spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
// CHECK: spv.SubgroupBlockWriteINTEL %{{.*}}, %{{.*}} : vector<3xi32>
spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
// CHECK: spv.INTEL.SubgroupBlockWrite %{{.*}}, %{{.*}} : vector<3xi32>
spv.INTEL.SubgroupBlockWrite "StorageBuffer" %ptr, %value : vector<3xi32>
spv.Return
}
}
24 changes: 12 additions & 12 deletions mlir/test/Target/SPIRV/joint-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,43 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [JointMatrixINTEL], [SPV_INTEL_joint_matrix]> {
// CHECK-LABEL: @joint_matrix_load
spv.func @joint_matrix_load(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>
spv.Return
}

// CHECK-LABEL: @joint_matrix_load_memaccess
spv.func @joint_matrix_load_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32) "None" {
// CHECK: {{%.*}} = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.INTEL.JointMatrixLoad <Subgroup> <RowMajor> %ptr, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, i32) -> !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.Return
}

// CHECK-LABEL: @joint_matrix_store
spv.func @joint_matrix_store(%ptr : !spv.ptr<i32, Workgroup>, %stride : i32, %m : !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
// CHECK: spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Workgroup>, i32)
spv.Return
}

// CHECK-LABEL: @joint_matrix_store_memaccess
spv.func @joint_matrix_store_memaccess(%ptr : !spv.ptr<i32, Workgroup>, %m : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %stride : i32) "None" {
// CHECK: spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.JointMatrixStoreINTEL <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
// CHECK: spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> {{%.*}}, {{%.*}}, {{%.*}} {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.INTEL.JointMatrixStore <Subgroup> <RowMajor> %ptr, %m, %stride {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
spv.Return
}

// CHECK-LABEL: @joint_matrix_length
spv.func @joint_matrix_length() -> i32 "None" {
// CHECK: {{%.*}} = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.JointMatrixWorkItemLengthINTEL : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixWorkItemLength : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
%0 = spv.INTEL.JointMatrixWorkItemLength : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
spv.ReturnValue %0 : i32
}

// CHECK-LABEL: @joint_matrix_muladd
spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
// CHECK: {{%.*}} = spv.JointMatrixMadINTEL <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
// CHECK: {{%.*}} = spv.INTEL.JointMatrixMad <Subgroup> {{%.*}}, {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
%r = spv.INTEL.JointMatrixMad <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
spv.Return
}
}
4 changes: 2 additions & 2 deletions mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,14 @@ ConvertToModule::matchAndRewrite(Operation *op,

ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
: RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
{"spv.SubgroupBallotKHR"}) {}
{"spv.KHR.SubgroupBallot"}) {}

LogicalResult
ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);

rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
op, op->getResult(0).getType(), predicate);
return success();
}
Expand Down
8 changes: 6 additions & 2 deletions mlir/utils/spirv/define_inst.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ file_name=$1
baseclass=$2

case $baseclass in
Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
| LogicalBinaryOp | LogicalUnaryOp \
| CastOp | ControlFlowOp | StructureOp \
| AtomicUpdateOp | AtomicUpdateWithValueOp \
| KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
;;
*)
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
echo "<baseclass> must be one of " \
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
"(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
exit 1;
;;
esac
Expand Down
14 changes: 11 additions & 3 deletions mlir/utils/spirv/gen_spirv_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,22 +730,29 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')
else:
fmt_str = ('def SPV_{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')

vendor_name = ''
inst_category = existing_info.get('inst_category', 'Op')
if inst_category == 'Op':
fmt_str +='\n let arguments = (ins{args});\n\n'\
' let results = (outs{results});\n'
elif inst_category.endswith('VendorOp'):
vendor_name = inst_category.split('VendorOp')[0].upper()
assert len(vendor_name) != 0, 'Invalid instruction category'

fmt_str +='{extras}'\
'}}\n'

opname_src = instruction['opname']
if opname.startswith('Op'):
opname_src = opname_src[2:]
if len(vendor_name) > 0:
assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
opname_src = opname_src[:-len(vendor_name)]

category_args = existing_info.get('category_args', '')

Expand All @@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin

# Format summary. If the summary can fit in the same line, we print it out
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
summary = summary.strip();
summary = summary.strip()
if len(summary) + len(' let summary = "";') <= 80:
summary = '"{}"'.format(summary)
else:
Expand Down Expand Up @@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opcode=instruction['opcode'],
category_args=category_args,
inst_category=inst_category,
vendor_name=vendor_name,
traits=existing_info.get('traits', ''),
summary=summary,
description=description,
Expand Down