From 5c5d3d92a8dd42f8f20fd7ea9e75ac48febadc55 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 11 Sep 2024 18:23:39 +0000 Subject: [PATCH 1/4] [mlir][AMDGPU] Remove an old bf16 workaround The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL. As a result of this change, we discovered that, whel the code for MFMA and WMMA intrinsics was mainly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations out from the i16 they're natively represented as. This commit also fixes that issue. --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 27 ++++++++++--------- mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir | 1 + .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 14 +++++----- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index c2785f34564e3..31d35390a7e7f 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -671,18 +671,25 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type outType = typeConverter->convertType(op.getDestD().getType()); + auto outType = + cast(typeConverter->convertType(op.getDestD().getType())); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); + // The WMMA operations represent vectors of bf16s as vectors of i16s, so we + // need to bitcast bfloats to i16 and then bitcast them back. + VectorType rawOutType = outType; + if (outType.getElementType().isBF16()) + rawOutType = outType.clone(rewriter.getI16Type()); + std::optional maybeIntrinsic = wmmaOpToIntrinsic(op, chipset); if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching WMMA on the given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(outType); + loweredOp.addTypes(rawOutType); SmallVector operands; wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), @@ -694,7 +701,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { loweredOp.addOperands(operands); Operation *lowered = rewriter.create(loweredOp); - rewriter.replaceOp(op, lowered->getResults()); + + Operation *maybeCastBack = lowered; + if (rawOutType != outType) + maybeCastBack = + rewriter.create(loc, outType, lowered->getResult(0)); + rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); } @@ -1033,15 +1045,6 @@ struct ConvertAMDGPUToROCDLPass void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { - converter.addConversion([](BFloat16Type t) -> Type { - return IntegerType::get(t.getContext(), 16); - }); - converter.addConversion([&converter](VectorType t) -> std::optional { - if (!t.getElementType().isBF16()) - return std::nullopt; - return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16))); - }); - patterns .add, RawBufferOpLowering, diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir index 1a4ef33db2aed..9ca89a0babd95 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir @@ -16,6 +16,7 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16> amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index 56b65beb03695..3fa9fa5e935d2 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -445,22 +445,22 @@ gpu.module @test_module { // ----- -// Test that the bf16 type is lowered away on this target. +// Test that the bf16 type is passed through to LLVM. gpu.module @test_module { // CHECK-LABEL: func @bf16_id func.func @bf16_id(%arg0 : bf16) -> bf16 { - // CHECK-SAME: (%[[ARG0:.+]]: i16) - // CHECK-SAME: -> i16 - // CHECK: return %[[ARG0]] : i16 + // CHECK-SAME: (%[[ARG0:.+]]: bf16) + // CHECK-SAME: -> bf16 + // CHECK: return %[[ARG0]] : bf16 func.return %arg0 : bf16 } // CHECK-LABEL: func @bf16x4_id func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> { - // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>) - // CHECK-SAME: -> vector<4xi16> - // CHECK: return %[[ARG0]] : vector<4xi16> + // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>) + // CHECK-SAME: -> vector<4xbf16> + // CHECK: return %[[ARG0]] : vector<4xbf16> func.return %arg0 : vector<4xbf16> } From 5f842ca428bc401fda31228f32e828a45d1d5b37 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 12 Sep 2024 18:12:45 +0000 Subject: [PATCH 2/4] Review comments --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 5 ++++- mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir | 7 ++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 31d35390a7e7f..42d22d9dfae24 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -672,7 +672,10 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto outType = - cast(typeConverter->convertType(op.getDestD().getType())); + typeConverter->convertType(op.getDestD().getType()); + if (!outType) + return rewriter.notifyMatchFailure( + op, "wmma output doesn't convert to a vector for no clear reason"); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir index 9ca89a0babd95..7b144809235d5 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir @@ -15,10 +15,11 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> - // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16> + // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> + // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16> amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> + // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32> From d7f7973804d5fd38bcb0b8253ee949f707e71de0 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 12 Sep 2024 14:02:26 -0500 Subject: [PATCH 3/4] Shorten error message on type conversion failure Co-authored-by: Jakub Kuderski --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 42d22d9dfae24..28991b6b9ed2b 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -675,7 +675,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { typeConverter->convertType(op.getDestD().getType()); if (!outType) return rewriter.notifyMatchFailure( - op, "wmma output doesn't convert to a vector for no clear reason"); + op, "type conversion failed"); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); From 2dfd30bdef32976282a93dd8e4b1c7fc53d9381f Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 12 Sep 2024 19:08:00 +0000 Subject: [PATCH 4/4] Fix clang-format from suggestion --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 28991b6b9ed2b..f80d2793eaef5 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -674,8 +674,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { auto outType = typeConverter->convertType(op.getDestD().getType()); if (!outType) - return rewriter.notifyMatchFailure( - op, "type conversion failed"); + return rewriter.notifyMatchFailure(op, "type conversion failed"); if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12");