diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f795dd89b79a1..2cb60b3836416 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -552,9 +552,14 @@ def MFMAOutTypes : AnyTypeOf<[F64,
                              VectorOfLengthAndType<[4, 16, 32], [I32]>,
                              VectorOfLengthAndType<[4], [F64]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
+                              [4, 8, 16],
+                              [F16, BF16,
+                               I8, SI8, UI8,
+                               I<4>, SI<4>, UI<4>,
+                               F8E4M3FN, F8E5M2]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
-                              VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
+                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -615,8 +620,7 @@ def AMDGPU_MFMAOp :
 def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
-                       AllTypesMatch<["sourceA", "sourceB"]>,
-                       Pure]>,
+                       Pure]>,
     Arguments<(ins
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
@@ -629,13 +633,17 @@ def AMDGPU_WMMAOp :
   let summary = "MLIR wrapper for RDNA3 wmma instructions";
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 architecture, which perform
-    a 16x16 matrix multiplication for different data types.
+    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
+    perform a 16x16 * 16x16 matrix multiplication for different data types.
+    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
+    integer inputs.
 
-    When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
-    containing only 8 valid values:
+    On gfx11/RDNA3, when emitting f16->f16 (or bf16->bf16) wmma, the output is a
+    16xf16 (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
+    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
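The widened `WMMAInTypes`/`WMMAOutTypes` and the removal of `AllTypesMatch` on the sources admit ops like the following (a sketch based on the gfx12 tests further down; SSA names are illustrative):

    // wave32 gfx12: 16x16x32 wmma on 4-bit integers
    %d0 = amdgpu.wmma %a * %b + %acc0 : vector<16xi4>, vector<16xi4>, vector<8xi32>
    // mixed fp8/bf8 sources, legal now that sourceA and sourceB may differ in type
    %d1 = amdgpu.wmma %x * %y + %acc1 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>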
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 673ea480ad3fa..18fec95f700c4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -410,8 +410,11 @@ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
 def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
 // Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
@@ -771,7 +774,7 @@ def ROCDL_CvtScaleF32Bf8Op :
   Arguments<(ins I32:$src, F32: $scale, I32:$byteSel)> {
   let summary = "Scale and convert bf8 to f32";
   let description = [{
-    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value 
+    Scale `src` by the exponent in `scale` then convert 8-bit bf8 value
    from the `byteSel`th bit of `src` to fp32.
  }];
  let assemblyFormat = [{
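For reference, the new `rocdl.wmma.i32.16x16x32.iu4` wrapper is used exactly like the existing 16x16x16 iu4 op; this sketch mirrors the rocdl.mlir test added at the end of this patch (SSA names are illustrative):

    %r = rocdl.wmma.i32.16x16x32.iu4 %signA, %a, %signB, %b, %acc, %clamp :
        (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>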
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b8574bbbee345..201794e2b0e1a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -403,8 +403,11 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
 /// Push an input operand. If it is a float type, nothing to do. If it is
 /// an integer type, then we need to also push its signdness (1 for signed, 0
 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
-/// vector. We also need to convert bfloat inputs to i16 to account for the lack
-/// of bfloat support in the WMMA intrinsics themselves.
+/// vector (or the 8xi8 vector into a 2xi32 one for gfx12+).
+/// We also need to convert bfloat inputs to i16 to account for the bfloat
+/// intrinsics having been defined before the AMD backend supported bfloat. We
+/// similarly need to pack 8-bit float types into integers as if they were i8
+/// (which they are for the backend's purposes).
 static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
@@ -413,12 +416,16 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                  SmallVector<Value, 4> &operands) {
   Type inputType = llvmInput.getType();
   auto vectorType = dyn_cast<VectorType>(inputType);
+  if (!vectorType) {
+    operands.push_back(llvmInput);
+    return;
+  }
   Type elemType = vectorType.getElementType();
 
   if (elemType.isBF16())
     llvmInput = rewriter.create<LLVM::BitcastOp>(
         loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
-  if (!elemType.isInteger(8)) {
+  if (elemType.getIntOrFloatBitWidth() > 8) {
     operands.push_back(llvmInput);
     return;
   }
@@ -427,26 +434,34 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
   // We need to check the type of the input before conversion to properly test
   // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
   // fp8/int8 information is lost during the conversion process.
   auto mlirInputType = cast<VectorType>(mlirInput.getType());
-  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
-  if (isInputInt8) {
+  bool isInputInteger = mlirInputType.getElementType().isInteger();
+  if (isInputInteger) {
     // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
     bool localIsUnsigned = isUnsigned;
-    if (elemType.isUnsignedInteger(8)) {
+    if (elemType.isUnsignedInteger()) {
       localIsUnsigned = true;
-    } else if (elemType.isSignedInteger(8)) {
+    } else if (elemType.isSignedInteger()) {
       localIsUnsigned = false;
     }
     Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
     operands.push_back(sign);
   }
 
-  int64_t numBytes = vectorType.getNumElements();
+  int64_t numBits =
+      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
   Type i32 = rewriter.getI32Type();
-  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
-  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
-  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
-      loc, llvmVectorType32bits, llvmInput);
-  operands.push_back(result);
+  Type intrinsicInType = numBits <= 32
+                             ? (Type)rewriter.getIntegerType(numBits)
+                             : (Type)VectorType::get(numBits / 32, i32);
+  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
+  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
+      loc, llvmIntrinsicInType, llvmInput);
+  // The wave64-mode 16x16x16 intrinsics that take 4-bit integers only need
+  // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
+  // Add in the zeros here.
+  if (numBits < 32)
+    castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+  operands.push_back(castInput);
 }
 
 /// Push the output operand. For many cases this is only pushing the output in
@@ -454,7 +469,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 /// since the same numbers of VGPRs is used, we need to decide if to store the
 /// result in the upper 16 bits of the VGPRs or in the lower part. To store the
 /// result in the lower 16 bits, set subwordOffset to 1, otherwise result will
-/// be stored it in the upper part
+/// be stored in the upper part. The subwordOffset must not be set for gfx12,
+/// as the instructions have been changed to return fewer registers instead.
 static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                   Location loc,
                                   const TypeConverter *typeConverter,
@@ -617,8 +633,10 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
   auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
   auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
   auto elemSourceType = sourceVectorType.getElementType();
+  auto elemBSourceType = sourceBVectorType.getElementType();
   auto elemDestType = destVectorType.getElementType();
 
   if (elemSourceType.isF16() && elemDestType.isF32())
@@ -631,10 +649,33 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
   if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
     return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
-  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
+  if (chipset.majorVersion == 11) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  }
+  if (chipset.majorVersion >= 12) {
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
+    if (isa<Float8E4M3FNType>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
+    if (isa<Float8E5M2Type>(elemSourceType) &&
+        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
+      bool isWave64 = destVectorType.getNumElements() == 4;
+      // This is the ambiguous case. 8 inputs to the wave64 version means that
+      // we want the 16x16x32 version, but for wave32 they mean the short form.
+      bool has8Inputs = sourceVectorType.getNumElements() == 8;
+      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
+  }
   return std::nullopt;
 }
 
@@ -712,6 +753,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
+    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
+      return op.emitOpError("subwordOffset not supported on gfx12+");
+
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(rawOutType);
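To make the `numBits` packing concrete: a vector<16xi4> input is 64 bits and is bitcast to vector<2xi32>; a vector<8xi4> is exactly 32 bits and becomes a single i32; a vector<4xi4> (the wave64 16x16x16 case) is only 16 bits, so it is bitcast to i16 and then zero-extended. Roughly, for the last case (types assumed, names illustrative):

    %bits = llvm.bitcast %input : vector<4xi4> to i16
    %word = llvm.zext %bits : i16 to i32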
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 271ca382e2f0b..4641fbb280bcb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -226,14 +226,23 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
 //===----------------------------------------------------------------------===//
 LogicalResult WMMAOp::verify() {
   Type sourceAType = getSourceA().getType();
+  Type sourceBType = getSourceB().getType();
   Type destType = getDestC().getType();
 
   VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
+  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
   VectorType destVectorType = dyn_cast<VectorType>(destType);
 
   Type sourceAElemType = sourceVectorAType.getElementType();
+  Type sourceBElemType = sourceVectorBType.getElementType();
   Type destElemType = destVectorType.getElementType();
 
+  if (sourceVectorAType.getNumElements() !=
+      sourceVectorBType.getNumElements()) {
+    return emitOpError("source vectors have different lengths: ")
+           << sourceVectorAType << " vs. " << sourceVectorBType;
+  }
+
   bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
   bool isSrcFloat = isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
@@ -247,6 +256,13 @@ LogicalResult WMMAOp::verify() {
     return emitOpError("Expected int sources with int destination");
   }
 
+  if (sourceAElemType != sourceBElemType &&
+      !(isa<Float8E4M3FNType, Float8E5M2Type>(sourceAElemType) &&
+        isa<Float8E4M3FNType, Float8E5M2Type>(sourceBElemType))) {
+    return emitOpError(
+               "source element types must match (except for fp8) but have ")
+           << sourceAType << " and " << sourceBType;
+  }
   return success();
 }
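The relaxed verifier therefore demands equal source vector lengths and matching element types, except that the two fp8 formats may be mixed. A hypothetical op that the new length check would reject:

    // error (roughly): source vectors have different lengths: 'vector<16xf16>' vs. 'vector<8xf16>'
    // amdgpu.wmma %a * %b + %c : vector<16xf16>, vector<8xf16>, vector<8xf32>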
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 7b2b524d4af42..94a1b78d5f040 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,9 +1,68 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
-func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
-  // CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+// CHECK-LABEL: @wmma_to_rocdl
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
+                         %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
+                         %arg4 : vector<8xbf16>, %arg5 : vector<4xbf16>,
+                         %arg6 : vector<8xf8E4M3FN>, %arg7 : vector<4xf8E4M3FN>,
+                         %arg8 : vector<8xf8E5M2>, %arg9 : vector<4xf8E5M2>,
+                         %arg10 : vector<8xi8>, %arg11 : vector<4xi8>,
+                         %arg12 : vector<8xi32>, %arg13 : vector<4xi32>,
+                         %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16>
+  amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16>
+
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
+  amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16>
+  amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
+
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32>
-  // CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
   func.return
 }
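The four trailing iu4 cases pin down the wave32/wave64 disambiguation implemented in wmmaOpToIntrinsic; as a recap of the CHECK lines above:

    // source x accumulator           -> selected intrinsic
    // vector<16xi4> x vector<8xi32>  -> rocdl.wmma.i32.16x16x32.iu4 (wave32)
    // vector<8xi4>  x vector<4xi32>  -> rocdl.wmma.i32.16x16x32.iu4 (wave64)
    // vector<8xi4>  x vector<8xi32>  -> rocdl.wmma.i32.16x16x16.iu4 (wave32)
    // vector<4xi4>  x vector<4xi32>  -> rocdl.wmma.i32.16x16x16.iu4 (wave64)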
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 7b144809235d5..638a7c3f8c1c5 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
-func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
+// CHECK-LABEL: @wmma_to_rocdl
+func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
-                         %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
-                         %arg9 : vector<16xui8>) {
+                         %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
+                         %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
   amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
@@ -21,9 +22,14 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
-  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<8xi32>
+  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+
   func.return
 }
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index e0ca51afb02d6..981ef4848535c 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -793,6 +793,10 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
   %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
 
+  // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+  %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
   // ---- Wave64 -----
 
   // f16 -> f32
@@ -872,8 +876,14 @@ llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vec
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
   %r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
 
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.fp8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
   // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
-  %r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+  %r2 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
+
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
+  %r3 = rocdl.wmma.f32.16x16x16.bf8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
 
   llvm.return %r0 : vector<8 x f32>
 }
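One path not exercised by the tests above is the new gfx12 guard in WMMAOpLowering; a hypothetical input that would trip it when lowering with chipset=gfx1200:

    // expected error: "subwordOffset not supported on gfx12+"
    // amdgpu.wmma %a * %a + %c {subwordOffset = 1 : i32} : vector<8xf16>, vector<8xf16>, vector<8xf16>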