Merged
47 changes: 30 additions & 17 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
[4, 8, 16],
[F16, BF16,
I8, SI8, UI8,
I<4>, SI<4>, UI<4>,
F8E4M3FN, F8E5M2]>]>;
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;

@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :

The negateA, negateB, and negateC flags are only supported for double-precision
operations on gfx94x.

Example:
```mlir
%0 = amdgpu.mfma %matA * %matB + %matC
{ abid = 1 : i32, cbsz = 1 : i32,
m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
blgp = bcast_second_32 : f32, f32, vector<32xf32>
```
}];
let assemblyFormat = [{
$sourceA `*` $sourceB `+` $destC
@@ -982,36 +988,43 @@ def AMDGPU_WMMAOp :
AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>, "0">:$subwordOffset,
UnitAttr:$unsignedA,
UnitAttr:$unsignedB,
UnitAttr:$clamp)>,
Results<(outs WMMAOutTypes: $destD)> {
let summary = "MLIR wrapper for RDNA3 wmma instructions";
let summary = "MLIR wrapper for wmma instructions";
let description = [{
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
perform a 16x16 * 16x16 matrix multiplication for different data types.
Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
integer inputs.
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
instructions in the AMDGPU architecture, which perform matrix multiplication.
Note that all wmma intrinsics have M=N=16 dimensions but vary in their allowed K
dimensions.

On gfx11/RDNA3, when emitting an f16->f16 (or bf16->bf16) wmma, the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
- If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
- If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all
the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
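The gfx11 interleaving described above can be made concrete with a rough standalone sketch (plain C++, not part of the patch); `validOutputIndices` is a hypothetical helper, not an MLIR API:

```cpp
#include <cassert>
#include <vector>

// For the gfx11 f16/bf16 -> f16/bf16 case: given subwordOffset (0 or 1),
// return the indices in the 16-element output vector that hold the 8 valid
// results (even indices for offset 0, odd indices for offset 1).
std::vector<int> validOutputIndices(int subwordOffset) {
  assert(subwordOffset == 0 || subwordOffset == 1);
  std::vector<int> indices;
  for (int i = 0; i < 8; ++i)
    indices.push_back(2 * i + subwordOffset); // 0,2,...,14 or 1,3,...,15
  return indices;
}
```

On gfx12 and later, all lanes of the (shorter) result vector are valid, so no such index mapping applies.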

The `unsignedA` and `unsignedB` flags indicate that the `int8` LLVM inputs are unsigned.

The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
in case of overflow.
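A minimal illustration of the saturation the `clamp` flag requests; the min() side is a symmetric assumption, since the description only guarantees clamping to `numeric_limits<T>::max()`:

```cpp
#include <cstdint>
#include <limits>

// Saturating narrowing: values outside T's range clamp to T's limits instead
// of wrapping. (Underflow handling is assumed symmetric for illustration.)
template <typename T>
T saturatingClamp(long long value) {
  if (value > static_cast<long long>(std::numeric_limits<T>::max()))
    return std::numeric_limits<T>::max();
  if (value < static_cast<long long>(std::numeric_limits<T>::min()))
    return std::numeric_limits<T>::min();
  return static_cast<T>(value);
}
```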

Example:
```mlir
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
[Review comment — Member]: While the syntax is okay, it is weird that the mfma instructions encode this stuff as an attribute dict while wmma does it as a custom parser
[Reply — Member (Author)]: We can give mfma the same syntax, I didn't want to make too many changes in the same PR though
```
}];
let assemblyFormat = [{
$sourceA `*` $sourceB `+` $destC
custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
attr-dict
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
}];
25 changes: 24 additions & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file declares a dialect for MLIR wrappers around AMDGPU-specific
// intrinssics and for other AMD GPU-specific functionality.
// intrinsics and for other AMD GPU-specific functionality.
//
//===----------------------------------------------------------------------===//

@@ -26,6 +26,29 @@

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"

namespace mlir::amdgpu {
/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
/// WMMAOp.
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
IntegerAttr &n, IntegerAttr &k);
inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
IntegerAttr &m, IntegerAttr &n,
IntegerAttr &k) {
return parseMNKDimensionList(parser, m, n, k);
}

/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
/// WMMAOp.
inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
IntegerAttr n, IntegerAttr k) {
printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
}
inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
IntegerAttr m, IntegerAttr n, IntegerAttr k) {
printMNKDimensionList(printer, m, n, k);
}
} // namespace mlir::amdgpu

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"

5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -804,6 +804,11 @@ def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>;

class IntValidAlignment<Attr attr>: ConfinedAttr<attr, [IntPositivePowerOf2]>;

class IntIsOneOf<list<int> values> : AttrConstraint<
CPred<"::llvm::is_contained({" # !interleave(!foreach(val, values, val), ", ") #
"}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
"whose value is one of {" # !interleave(!foreach(val, values, val), ", ") # "}">;
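The `IntIsOneOf` constraint above expands to a `CPred` calling `llvm::is_contained` on the attribute's integer value. A standalone C++ approximation of that membership check (using `std::find` so the sketch carries no LLVM dependency):

```cpp
#include <algorithm>
#include <initializer_list>

// Approximates the predicate IntIsOneOf<[16, 32]> generates: the attribute's
// integer value must appear in the allowed list.
bool intIsOneOf(std::initializer_list<long> allowed, long value) {
  return std::find(allowed.begin(), allowed.end(), value) != allowed.end();
}
```

This is what lets WMMAOp constrain `k` to exactly 16 or 32 rather than a min/max range.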

class ArrayMaxCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>,
"with at most " # n # " elements">;
70 changes: 40 additions & 30 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
auto elemSourceType = sourceVectorType.getElementType();
auto elemBSourceType = sourceBVectorType.getElementType();
auto elemDestType = destVectorType.getElementType();

if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
if (chipset.majorVersion == 11) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
Type elemSourceType = sourceVectorType.getElementType();
Type elemBSourceType = sourceBVectorType.getElementType();
Type elemDestType = destVectorType.getElementType();

const uint32_t k = wmma.getK();

if (k == 16) {
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
if (chipset.majorVersion == 11) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
}
}
if (chipset.majorVersion >= 12) {
if (chipset.majorVersion < 12)
return std::nullopt;

// gfx12+
if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
bool isWave64 = destVectorType.getNumElements() == 4;
// This is the ambiguous case. 8 inputs to the wave64 version means that
// we want the 16x16x32 version, but for wave32 they mean the short form.
bool has8Inputs = sourceVectorType.getNumElements() == 8;
if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
}

return std::nullopt;
}
return std::nullopt;
if (k == 32) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
return std::nullopt;
}

llvm_unreachable("unhandled WMMA case");
}
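The K-then-type dispatch above can be condensed into a standalone sketch. The string keys for element types are illustrative stand-ins for the real MLIR `Type` checks; the returned names match the ROCDL operations referenced in the code:

```cpp
#include <string>

// Simplified intrinsic selection: pick a wmma intrinsic name from the K
// dimension and the source/destination element types. Empty string plays the
// role of std::nullopt in the real code.
std::string selectWmmaIntrinsic(int k, const std::string &src,
                                const std::string &dst) {
  if (k == 16 && src == "f16" && dst == "f32")
    return "rocdl.wmma.f32.16x16x16.f16";
  if (k == 16 && src == "i8" && dst == "i32")
    return "rocdl.wmma.i32.16x16x16.iu8";
  if (k == 16 && src == "i4" && dst == "i32")
    return "rocdl.wmma.i32.16x16x16.iu4";
  if (k == 32 && src == "i4" && dst == "i32")
    return "rocdl.wmma.i32.16x16x32.iu4"; // 4-bit ints get the longer K form
  return "";
}
```

Note how making K explicit in the op removes the old vector-length heuristic that disambiguated the wave32/wave64 iu4 cases.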

namespace {
62 changes: 35 additions & 27 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
LogicalResult WMMAOp::verify() {
Type sourceAType = getSourceA().getType();
Type sourceBType = getSourceB().getType();
Type destType = getDestC().getType();

VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
VectorType destVectorType = dyn_cast<VectorType>(destType);
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
IntegerAttr &m, IntegerAttr &n,
IntegerAttr &k) {
SmallVector<int64_t, 3> dimensions;
if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false, /*withTrailingX=*/false))
return failure();
if (dimensions.size() != 3)
return parser.emitError(parser.getCurrentLocation())
<< "expected 3 dimensions in MNK dimension list";

Type sourceAElemType = sourceVectorAType.getElementType();
Type sourceBElemType = sourceVectorBType.getElementType();
Type destElemType = destVectorType.getElementType();
m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
return success();
}
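A rough standalone sketch of the shape `custom<MNKDimensionList>` accepts. The real parser delegates to `OpAsmParser::parseDimensionList`; `parseMNK` here is a hypothetical helper that simply splits on `x` and requires exactly three static entries:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Accepts dimension lists like "16x16x32"; rejects anything without exactly
// three entries, mirroring the "expected 3 dimensions" error above.
bool parseMNK(const std::string &text, int &m, int &n, int &k) {
  std::vector<int> dims;
  std::stringstream ss(text);
  std::string part;
  while (std::getline(ss, part, 'x'))
    dims.push_back(std::stoi(part));
  if (dims.size() != 3)
    return false;
  m = dims[0];
  n = dims[1];
  k = dims[2];
  return true;
}
```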

if (sourceVectorAType.getNumElements() !=
sourceVectorBType.getNumElements()) {
LogicalResult WMMAOp::verify() {
auto sourceAType = cast<VectorType>(getSourceA().getType());
auto sourceBType = cast<VectorType>(getSourceB().getType());
auto destType = cast<VectorType>(getDestC().getType());

Type sourceAElemType = sourceAType.getElementType();
Type sourceBElemType = sourceBType.getElementType();
if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
return emitOpError("source vectors have different lengths: ")
<< sourceVectorAType << " vs. " << sourceVectorBType;
<< sourceAType << " vs. " << sourceBType;
}

bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
bool isSrcFloat =
isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
sourceAElemType);

if (isDestFloat && !isSrcFloat) {
return emitOpError("Expected float sources with float destination");
}
bool isDestFloat = destType.getElementType().isFloat();
bool isSrcFloat = sourceAElemType.isFloat();

if (!isDestFloat && isSrcFloat) {
return emitOpError("Expected int sources with int destination");
}
if (isDestFloat && !isSrcFloat)
return emitOpError("expected float sources with float destination");
if (!isDestFloat && isSrcFloat)
return emitOpError("expected int sources with int destination");

if (sourceAElemType != sourceBElemType &&
!(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types must match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}

if (!sourceAElemType.isInteger(4) && getK() != 16) {
return emitOpError("K dimension must be 16 for source element type ")
<< sourceAElemType;
}
return success();
}
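The verifier rules above, condensed into a standalone predicate for illustration (element types modeled as strings rather than MLIR `Type`s; not the actual API):

```cpp
#include <string>

// Condensed restatement of WMMAOp::verify:
//  1. source vectors must have equal length,
//  2. float sources require a float destination (and vice versa),
//  3. source element types must match, except the fp8 mixed cases,
//  4. only 4-bit integer sources may use K != 16.
bool verifyWmma(int lenA, int lenB, bool srcFloat, bool dstFloat,
                const std::string &elemA, const std::string &elemB, int k) {
  if (lenA != lenB)
    return false;
  if (srcFloat != dstFloat)
    return false;
  bool isFp8 = (elemA == "f8E4M3FN" || elemA == "f8E5M2");
  if (!isFp8 && elemA != elemB)
    return false;
  if (elemA != "i4" && k != 16)
    return false;
  return true;
}
```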

@@ -1,35 +1,36 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
%arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
%arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
%arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
// CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>

func.return
}