Support transpose mode for gpu.subgroup WMMA ops
Add support for loading, computing, and storing `gpu.subgroup` WMMA ops
in transpose mode as well. Update the GPU-to-NVVM lowerings to support
`transpose` mode and update the integration tests accordingly.
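
For illustration, a minimal sketch of the transposed path at the `gpu` dialect
level (the buffers %buf/%out, the indices, and the accumulator %c are assumed
to be defined elsewhere; shapes are hypothetical):

// Load A in transposed (column-major) form; B is loaded row-major.
%a = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
%b = gpu.subgroup_mma_load_matrix %buf[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
// a_transpose records that A was loaded transposed.
%d = gpu.subgroup_mma_compute %a, %b, %c {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// Store the result transposed as well.
gpu.subgroup_mma_store_matrix %d, %out[%i, %j] {leadDimension = 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>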

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D139021
navdeepkk-polymagelabs authored and bondhugula committed Dec 5, 2022
1 parent 03b3017 commit 3d35546
Showing 7 changed files with 131 additions and 64 deletions.
35 changes: 22 additions & 13 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1141,7 +1141,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
`!gpu.mma_matrix` is the source value containing the data to be stored into the
destination memref which can be in global or shared memory. The store address
is determined using the indices provided. The `leadDimension` attribute
specifies the leading dimension of the destination matrix. If the
`transpose` attribute is present, then the op does a transposed store.

This op is often meant to be used along with `gpu.subgroup_mma_load_matrix` and
`gpu.subgroup_mma_compute`.
@@ -1157,16 +1158,17 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
Arg<GPU_MMAMemRef, "",[MemWrite]>:$dstMemref,
Variadic<Index>:$indices,
IndexAttr:$leadDimension);
IndexAttr:$leadDimension,
OptionalAttr<UnitAttr>:$transpose);

let assemblyFormat = [{
$src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
}];
let hasVerifier = 1;
}
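
For example, a plain and a transposed store are written as follows (SSA values
and shapes are illustrative, not taken from this patch):

// Default row-major store.
gpu.subgroup_mma_store_matrix %src, %dst[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
// Transposed store enabled by the new unit attribute.
gpu.subgroup_mma_store_matrix %src, %dst[%i, %j] {leadDimension = 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>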

def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
[Pure, AllTypesMatch<["opC", "res"]>]>{
def GPU_SubgroupMmaComputeOp
: GPU_Op<"subgroup_mma_compute", [Pure, AllTypesMatch<["opC", "res"]>]> {

let summary = "GPU warp synchronous matrix multiply accumulate";

@@ -1175,9 +1177,14 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
operation using all the threads in a subgroup.

This operation takes three `!gpu.mma_matrix`s as arguments: these hold `A`,
`B` and `C` operands for the mma operation. The operation performed is represented
as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
the operation held by all threads in a subgroup. `a_transpose` or
`b_transpose`, if present, signify that the respective operand was loaded in a
transposed manner. The transpose operands are required to map to the correct
underlying intrinsics, but they currently do not seem to affect correctness
even if they are absent, given that the operands were loaded correctly using
the `transpose` attribute in the `gpu.subgroup_mma_load_matrix` op.

This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
`gpu.subgroup_mma_load_matrix` ops.
@@ -1193,9 +1200,11 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",

let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$opA,
Arg<MMAMatrixOf<[F16, F32]>>:$opB,
Arg<MMAMatrixOf<[F16, F32]>>:$opC);
Arg<MMAMatrixOf<[F16, F32]>>:$opC,
OptionalAttr<UnitAttr>:$a_transpose,
OptionalAttr<UnitAttr>:$b_transpose);

let results = (outs GPU_MMAMatrix:$res);
let results = (outs GPU_MMAMatrix : $res);

let assemblyFormat = [{
$opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
@@ -1215,11 +1224,11 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
The `gpu.subgroup_mma_constant_matrix` creates a `!gpu.mma_matrix` with
constant elements.

The operation takes a scalar input and returns a `!gpu.mma_matrix` where
each element is equal to the operand constant. The destination
mma_matrix type must have an element type equal to the constant type. Since
the layout of `!gpu.mma_matrix` is opaque, this only supports setting all the
elements to the same value.

This op is meant to be used along with `gpu.subgroup_mma_compute`.

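Taken together, the updated definitions also accept a compute with either or
both operands marked as transposed, e.g. (hypothetical 16x16 f16 tiles, with
%a, %b and %c produced by earlier loads):

// Both A and B were loaded in transposed form.
%d = gpu.subgroup_mma_compute %a, %b, %c {a_transpose, b_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
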
23 changes: 14 additions & 9 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -77,12 +77,11 @@ struct WmmaLoadOpToNVVMLowering
if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
return failure();

// TODO: Support transposed mma loads.
if (subgroupMmaLoadMatrixOp.getTranspose())
return failure();

// Get the shape of the MMAMatrix type being returned. The shape will
// choose which intrinsic this op will be lowered to.
NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
? NVVM::MMALayout::col
: NVVM::MMALayout::row;
gpu::MMAMatrixType retType =
subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> retTypeShape = retType.getShape();
@@ -105,7 +104,6 @@ struct WmmaLoadOpToNVVMLowering
n = retTypeShape[1];
k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
}
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMAFrag frag = convertOperand(retType.getOperand());
// Check that there is an existing instruction for the combination we need.
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
@@ -154,7 +152,9 @@ struct WmmaStoreOpToNVVMLowering
gpu::MMAMatrixType srcType =
subgroupMmaStoreMatrixOp.getSrc().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> srcTypeShape = srcType.getShape();
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
? NVVM::MMALayout::col
: NVVM::MMALayout::row;
NVVM::MMATypes eltype = getElementType(srcType);
int64_t m = srcTypeShape[0];
int64_t n = srcTypeShape[1];
@@ -224,10 +224,15 @@ struct WmmaMmaOpToNVVMLowering
int64_t m = cTypeShape[0];
int64_t n = cTypeShape[1];
int64_t k = aTypeShape[1];
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
? NVVM::MMALayout::col
: NVVM::MMALayout::row;
NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
? NVVM::MMALayout::col
: NVVM::MMALayout::row;
NVVM::MMATypes sourceType = getElementType(aType);
NVVM::MMATypes destType = getElementType(cType);
if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType,
if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
destType) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

Expand All @@ -236,7 +241,7 @@ struct WmmaMmaOpToNVVMLowering
unpackOp(adaptor.getOpC());

rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
op, adaptor.getOpC().getType(), m, n, k, layout, layout, sourceType,
op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
destType, unpackedOps);
return success();
}
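The net effect of the NVVM changes is that a gpu-level `transpose`,
`a_transpose` or `b_transpose` now selects the column-major layout of the
matching intrinsic. Roughly, for the 16x16 f16 case exercised by the tests
below (pointer and index values elided):

%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// ... lowers under -convert-gpu-to-nvvm to a column-major WMMA load:
%frag = nvvm.wmma.load %address, %ldm {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
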
13 changes: 7 additions & 6 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -87,10 +87,9 @@ struct WmmaLoadOpToSPIRVLowering
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
bool useColMajor =
static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
bool isColMajor = static_cast<bool>(subgroupMmaLoadMatrixOp.getTranspose());
auto columnMajor = rewriter.create<spirv::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
loc, rewriter.getI1Type(), rewriter.getBoolAttr(isColMajor));
rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, columnMajor,
spirv::MemoryAccessAttr());
@@ -118,11 +117,13 @@ struct WmmaStoreOpToSPIRVLowering
auto i32Type = rewriter.getI32Type();
auto strideValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, IntegerAttr::get(i32Type, stride));
auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
bool useColMajor =
static_cast<bool>(subgroupMmaStoreMatrixOp.getTranspose());
auto columnMajor = rewriter.create<spirv::ConstantOp>(
loc, rewriter.getI1Type(), rewriter.getBoolAttr(useColMajor));
rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
coloumnMajor, spirv::MemoryAccessAttr());
columnMajor, spirv::MemoryAccessAttr());
return success();
}
};
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -473,9 +473,9 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
assert(stride);
OpBuilder b(op);
Value matrix = valueMapping.find(op.getVector())->second;
b.create<gpu::SubgroupMmaStoreMatrixOp>(op.getLoc(), matrix, op.getSource(),
op.getIndices(),
b.getIndexAttr(*stride));
b.create<gpu::SubgroupMmaStoreMatrixOp>(
op.getLoc(), matrix, op.getSource(), op.getIndices(),
b.getIndexAttr(*stride), /*transpose=*/UnitAttr());
op.erase();
}

@@ -800,8 +800,9 @@ static void convertContractOp(vector::ContractionOp op,
Value opA = valueMapping.find(op.getLhs())->second;
Value opB = valueMapping.find(op.getRhs())->second;
Value opC = valueMapping.find(op.getAcc())->second;
Value matmul = b.create<gpu::SubgroupMmaComputeOp>(op.getLoc(), opC.getType(),
opA, opB, opC);
Value matmul = b.create<gpu::SubgroupMmaComputeOp>(
op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
/*b_transpose=*/UnitAttr());
valueMapping[op.getResult()] = matmul;
}

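Note that the vector-to-GPU conversion still emits only the non-transposed
form: the new optional attributes are passed as empty `UnitAttr`s, so the
generated ops simply print without them, e.g. (illustrative values):

%d = gpu.subgroup_mma_compute %a, %b, %c : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %d, %dst[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
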
16 changes: 8 additions & 8 deletions mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -10,7 +10,7 @@ gpu.module @test_module {
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i64, array<2 x i64>, array<2 x i64>)>
Expand All @@ -20,7 +20,7 @@ gpu.module @test_module {
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
Expand All @@ -32,7 +32,7 @@ gpu.module @test_module {
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]]
// CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, frag = #nvvm.mma_frag<a>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<f16, 3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK32: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
}
Expand All @@ -50,7 +50,7 @@ gpu.module @test_module {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index
%j = arith.constant 16 : index
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
// CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64
// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
// CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
Expand All @@ -64,7 +64,7 @@ gpu.module @test_module {
// CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i64) -> !llvm.ptr<f16, 3>
// CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK: llvm.return

// CHECK32: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
Expand All @@ -80,7 +80,7 @@ gpu.module @test_module {
// CHECK32: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
// CHECK32: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32
// CHECK32: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]]
// CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK32-SAME: {eltype = #nvvm.mma_type<f16>, k = 16 : i32, layout = #nvvm.mma_layout<col>, m = 16 : i32, n = 16 : i32} : !llvm.ptr<f16, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
// CHECK32: llvm.return
return
}
Expand All @@ -93,7 +93,7 @@ gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_mma_op
// CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
func.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
%D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
%D = gpu.subgroup_mma_compute %A, %B, %C {a_transpose} : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
Expand All @@ -115,7 +115,7 @@ gpu.module @test_module {
// CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[RES:.*]] = nvvm.wmma.mma %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]]
// CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
// CHECK-SAME: {eltypeA = #nvvm.mma_type<f16>, eltypeB = #nvvm.mma_type<f16>, k = 16 : i32, layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<row>, m = 16 : i32, n = 16 : i32} : (
// CHECK-SAME: vector<2xf16>, {{.*}}) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
return %D : !gpu.mma_matrix<16x16xf16, "COp">
