From 7b64631ffa0f4e2880d0a839443e450120da7a68 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 13:23:20 -0500 Subject: [PATCH 01/44] Update documentation --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 45cb67f0eee4a..4820b7a747ac2 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -127,7 +127,7 @@ def AMDGPU_ScaledExtPacked816Op FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale, ConfinedAttr:$blockSize, ConfinedAttr, IntMaxValue<1>]>:$firstScaleLane, - ConfinedAttr, IntMaxValue<2>]>:$firstScaleByte)>, + ConfinedAttr, IntMaxValue<3>]>:$firstScaleByte)>, Results<( outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>, FixedVectorOfShapeAndType<[8], F16>, @@ -139,17 +139,21 @@ def AMDGPU_ScaledExtPacked816Op let summary = "Extend a vector of packed floating point values"; let description = [{ - The scales applied to the input microfloats are stored in two bytes which + The scales applied to the input microfloats are stored in bytes which come from the `scales` input provided in a *half* of the wave identified - by `firstScaleLane`. The pair of bytes used is selected by - `firstScaleByte`. The 16 vectors in consecutive lanes starting from + by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends + on the type of `source`. The 16 vectors in consecutive lanes starting from `firstScaleLane` (which we'll call the scale vectors) will be used by both - halves of the wave (with lane L reading from L % 16'th scale vector), but - each half will use a different byte. + halves of the wave (with lane L reading from L % 16'th scale vector). + + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN each half of the + wave will use a different byte. The first one being `firstScaleByte` and + the second one being `firstScaleByte` + 1. When the block size is 32, + `firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors. + Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read + from `firstScaleByte` + 1. + - When the block size is 32, `firstScaleByte` can be either 0 or 2, - selecting halves of the scale vectors. Lanes 0-15 will read from - `firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1. For example: ```mlir // Input: 8-element vector of F8E4M3FN, converting to F32 @@ -165,7 +169,8 @@ def AMDGPU_ScaledExtPacked816Op : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> ``` - However, when the block size is 16, `firstScaleByte` can be 0 or 1. + When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and + the block size is 16, `firstScaleByte` can be 0 or 1. Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors, while lanes 16-31 read from `firstScaleByte` + 2. For example: @@ -187,6 +192,16 @@ def AMDGPU_ScaledExtPacked816Op instructions use for matix scales. These selection operands allows one to choose portions of the matrix to convert. + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32, + then the same byte will be used by both halves of the wave. + In this case, `firstScaleByte` can be any value from 0 to 3. + + When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16, + following combinations are allowed: + * `firstScaleLane(0), firstScaleByte(0)` + * `firstScaleLane(1), firstScaleByte(2)` + all other combinations are reserved. + Available on gfx1250+. }]; From 08e96b19369451dd5ec4e72ed2905bd0b2e0cf71 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 13:44:08 -0500 Subject: [PATCH 02/44] Fix verifiers --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 23 +++++++++++++++----- mlir/test/Dialect/AMDGPU/invalid.mlir | 20 ++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index df955fc90b45f..5c35823678576 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -344,14 +344,27 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( LogicalResult ScaledExtPacked816Op::verify() { int blockSize = getBlockSize(); assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + int firstScaleByte = getFirstScaleByte(); - if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 1."); + auto sourceType = cast(getSource().getType()); + Type elementType = sourceType.getElementType(); + auto floatType = cast(elementType); + int bitWidth = floatType.getWidth(); + + if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 && + !llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 " + "for f4 and f6."); + } + if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 && + !llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 " + "for f4 and f6."); } - if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) { + if (bitWidth == 8 && blockSize == 16 && + !llvm::is_contained({0, 2}, firstScaleByte)) { return emitOpError( - "blockSize of 32 can only have firstScaleByte be 0 or 2."); + "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8."); } return success(); diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 4c6f62a045405..5c8cc8b67c4b3 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -333,17 +333,25 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> +func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> func.return } // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> +func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> func.return } From d0932cc2935840b7e86dc700f0ec056ccabebdea Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 11:13:57 -0500 Subject: [PATCH 03/44] [mlir][amdgpu] Convert scaled_ext_packed816 to rocdl --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3a307a0756d93..48c1b17a2203a 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1492,6 +1492,19 @@ struct ExtPackedFp8OpLowering final ConversionPatternRewriter &rewriter) const override; }; +struct ScaledExtPacked816OpLowering final + : public ConvertOpToLLVMPattern { + ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter, + Chipset chipset) + : ConvertOpToLLVMPattern(converter), + chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct PackedTrunc2xFp8OpLowering final : public ConvertOpToLLVMPattern { PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, @@ -1600,6 +1613,12 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } +LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( + ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + return failure(); +} + LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -2138,9 +2157,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicCmpSwap>, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, - WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, - PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering, + ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, + GatherToLDSOpLowering, TransposeLoadOpLowering, + AMDGPUPermlaneLowering>(converter, chipset); patterns.add(converter); } From 0d1d668762b534c90f2f838788b495c09b872d81 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 15:43:41 -0500 Subject: [PATCH 04/44] Create skeleton for pattern --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 48c1b17a2203a..568013bee5ec8 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1613,9 +1613,129 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } +int getScaleSel(int blockSize, int bitWidth, int firstScaleLane, + int firstScaleByte) { + // When lowering amdgpu.scaled_ext_packed816 to + // rocdl.cvt.scale.pk*.f*.f* operations, the + // attributes blockSize, sourceType, firstScaleLane and firstScaleByte + // are merged into a single attribute scaleSel. + // + // This is how those values are merged together. + assert(llvm::is_contained({16, 32}, blockSize)); + assert(llvm::is_contained({4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) { + return 0b000; + } + if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) { + return 0b001; + } + if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) { + return 0b010; + } + if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) { + return 0b011; + } + if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) { + return 0b100; + } + if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) { + return 0b101; + } + if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) { + return 0b110; + } + if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) { + return 0b111; + } + + if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) { + return 0b0000; + } + if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) { + return 0b0001; + } + if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) { + return 0b0010; + } + if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) { + return 0b0100; + } + if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) { + return 0b0110; + } + if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) { + return 0b1010; + } + if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) { + return 0b1100; + } + if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) { + return 0b1101; + } + if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) { + return 0b1110; + } + + llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, " + "blockSize and type."); + return 0; +} + LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + + int firstScaleLane = op.getFirstScaleLane(); + int firstScaleByte = op.getFirstScaleByte(); + int blockSize = op.getBlockSize(); + auto sourceType = cast(op.getSource().getType()); + auto srcElemType = cast(sourceType.getElementType()); + int bitWidth = srcElemType.getWidth(); + int scaleSel = + getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte); + + auto targetType = cast(op.getResult().getType()); + auto tgtElemType = cast(targetType.getElementType()); + Location loc = op.getLoc(); + // Ok, so we need to construct ops depending on the sourceType and targetType. + // smallT = [Fp4, Fp8, Bf8] + // largeT = [F16, Bf16, F32] + // CvtPkScalePk{8}${largeT}${smallT} + if (isa(srcElemType) and isa(tgtElemType)) { + ROCDL::CvtPkScalePk8F16Fp4Op::create( + rewriter, loc, op.getResult().getType(), adaptor.getSource(), + adaptor.getScale(), scaleSel); + return success(); + } + /* + CvtPkScalePk8F16Fp8Op + CvtPkScalePk8F16Bf8Op + + CvtPkScalePk8Bf16Fp4Op + CvtPkScalePk8Bf16Fp8Op + CvtPkScalePk8Bf16Bf8Op + + CvtPkScalePk8F32Fp4Op + CvtPkScalePk8F32Fp8Op + CvtPkScalePk8F32Bf8Op + + // smallT = [Fp6, Bf6] + // largeT = [F16, Bf16, F32] + // CvtPkScalePk{16}${largeT}${smallT} + CvtPkScale16F16Fp6Op + CvtPkScale16F16Bf6Op + + CvtPkScale16Bf16Fp6Op + CvtPkScale16Bf16Bf6Op + + CvtPkScale16F32Fp6Op + CvtPkScale16F32Bf6Op + */ + return failure(); } From 163b15aef1795ed5d30e627a7f781406b00a37d2 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 16:15:34 -0500 Subject: [PATCH 05/44] Initial conversion --- .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 568013bee5ec8..d93332a3f3c40 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1701,14 +1701,24 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( auto targetType = cast(op.getResult().getType()); auto tgtElemType = cast(targetType.getElementType()); Location loc = op.getLoc(); + // %scale: vector<4xf8E8M0FNU> + // =========================== + // %scale: i32 + IntegerType i32 = rewriter.getI32Type(); + Value castedScale = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); + // Ok, so we need to construct ops depending on the sourceType and targetType. // smallT = [Fp4, Fp8, Bf8] // largeT = [F16, Bf16, F32] // CvtPkScalePk{8}${largeT}${smallT} if (isa(srcElemType) and isa(tgtElemType)) { - ROCDL::CvtPkScalePk8F16Fp4Op::create( - rewriter, loc, op.getResult().getType(), adaptor.getSource(), - adaptor.getScale(), scaleSel); + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); return success(); } /* From 6d7e2a65171ae253430067e3db23b94ca6a29c27 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 16:26:39 -0500 Subject: [PATCH 06/44] Add first test --- .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 2fd3df6dcfa71..840187d1f36d7 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -456,3 +456,16 @@ func.func @sched_barrier() { amdgpu.sched_barrier allow = func.return } + +// CHECK-LABEL: @scaled_ext_packed816_fp4 +// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return %ret0: vector<8xf16> +} + From fc5d8587ac8f3ea9142d3810d1dfea7ae42b4ff1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 10 Nov 2025 17:18:58 -0500 Subject: [PATCH 07/44] Adds two more cases --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 28 +++++++++++++++++++ .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 25 ++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index d93332a3f3c40..605bcf38204ba 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1721,6 +1721,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // vector<8xf8E4M3FN> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } + if (isa(srcElemType) and isa(tgtElemType)) { + // vector<8xf8E5M2> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* CvtPkScalePk8F16Fp8Op CvtPkScalePk8F16Bf8Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 840187d1f36d7..8edd6c038af1e 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -458,7 +458,7 @@ func.func @sched_barrier() { } // CHECK-LABEL: @scaled_ext_packed816_fp4 -// CHECK: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> @@ -469,3 +469,26 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E func.return %ret0: vector<8xf16> } +// CHECK-LABEL: @scaled_ext_packed816_fp8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return %ret0 : vector<8xf16> +} + +// CHECK-LABEL: @scaled_ext_packed816_bf8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return %ret0 : vector<8xf16> +} From d9a254f629fe430c961ef407e2d529a1a10f691a Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:14:28 -0500 Subject: [PATCH 08/44] Add case for pk8.bf16.fp4 --- .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 18 +++++++++++++++--- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 11 ++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 605bcf38204ba..5fd58c2b2906b 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1713,6 +1713,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // largeT = [F16, Bf16, F32] // CvtPkScalePk{8}${largeT}${smallT} if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F16Fp4Op + // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create( @@ -1722,6 +1724,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( return success(); } if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F16Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1736,6 +1739,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( return success(); } if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F16Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1749,11 +1753,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8Bf16Fp4Op + // i32 + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScalePk8F16Fp8Op - CvtPkScalePk8F16Bf8Op - CvtPkScalePk8Bf16Fp4Op CvtPkScalePk8Bf16Fp8Op CvtPkScalePk8Bf16Bf8Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 8edd6c038af1e..d2099d2f60eff 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -459,14 +459,19 @@ func.func @sched_barrier() { // CHECK-LABEL: @scaled_ext_packed816_fp4 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { +func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 - // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> + // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return %ret0: vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16> } // CHECK-LABEL: @scaled_ext_packed816_fp8 From c5eb6989affbd32c2100909943e790af851a5a43 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:28:50 -0500 Subject: [PATCH 09/44] Add conversion for pk8.bf16.bf8 --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 31 ++++++++++++------- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 5fd58c2b2906b..ede420606b7b4 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1764,9 +1764,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8Bf16Fp8Op + // vector<8xf8E5M2> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8Bf16Fp8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScalePk8Bf16Fp8Op CvtPkScalePk8Bf16Bf8Op CvtPkScalePk8F32Fp4Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index d2099d2f60eff..e6b2ef7bc3f79 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -474,6 +474,23 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16> } +// CHECK-LABEL: @scaled_ext_packed816_bf8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16> +} + // CHECK-LABEL: @scaled_ext_packed816_fp8 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { @@ -481,19 +498,9 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return %ret0 : vector<8xf16> -} -// CHECK-LABEL: @scaled_ext_packed816_bf8 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { - // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> func.return %ret0 : vector<8xf16> } + From cec5f045c639adeb2fda944751768b69fd9a5b9d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:41:25 -0500 Subject: [PATCH 10/44] Fix and add new case --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 24 ++++++++++++-- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 32 +++++++++++-------- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ede420606b7b4..8a55f41e8495c 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1710,8 +1710,12 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // Ok, so we need to construct ops depending on the sourceType and targetType. // smallT = [Fp4, Fp8, Bf8] + // Bf8 = E5M2 + // Fp8 = E4M3 + // // largeT = [F16, Bf16, F32] // CvtPkScalePk{8}${largeT}${smallT} + if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F16Fp4Op // i32 @@ -1764,9 +1768,9 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } - if (isa(srcElemType) and isa(tgtElemType)) { + if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8Bf16Fp8Op - // vector<8xf8E5M2> + // vector<8xf8E4M3FN> Value source = adaptor.getSource(); // vector<2xi32> @@ -1779,9 +1783,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8Bf16Bf8Op + // vector<8xf8E5M2> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScalePk8Bf16Bf8Op CvtPkScalePk8F32Fp4Op CvtPkScalePk8F32Fp8Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index e6b2ef7bc3f79..e248856ed472e 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -474,6 +474,24 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16> } +// CHECK-LABEL: @scaled_ext_packed816_fp8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16> +} + // CHECK-LABEL: @scaled_ext_packed816_bf8 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { @@ -486,21 +504,9 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16> } -// CHECK-LABEL: @scaled_ext_packed816_fp8 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>) { - // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - - func.return %ret0 : vector<8xf16> -} From 7dc34425abb1ff3271b039cbd95d008531c6a940 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:45:32 -0500 Subject: [PATCH 11/44] Add case for pk8.f32.fp4 --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +++++++++++-- .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 8a55f41e8495c..5c5aad91cd8e0 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1798,9 +1798,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F32Fp4Op + // i32 + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - - CvtPkScalePk8F32Fp4Op CvtPkScalePk8F32Fp8Op CvtPkScalePk8F32Bf8Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index e248856ed472e..13f109b787d4d 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -459,7 +459,7 @@ func.func @sched_barrier() { // CHECK-LABEL: @scaled_ext_packed816_fp4 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { +func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 @@ -471,7 +471,12 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - func.return %ret0, %ret1: vector<8xf16>, vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32> } // CHECK-LABEL: @scaled_ext_packed816_fp8 From 0ba6b949e626bbafdb5fdc5b94cea342b4e7aab5 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:50:01 -0500 Subject: [PATCH 12/44] Add case for pk8.f32.fp8 --- .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 17 +++++++++++++++-- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 5c5aad91cd8e0..a347a822aba66 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1809,9 +1809,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F32Fp8Op + // vector<8xf8E4M3FN> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScalePk8F32Fp4Op - CvtPkScalePk8F32Fp8Op CvtPkScalePk8F32Bf8Op // smallT = [Fp6, Bf6] diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 13f109b787d4d..917d25b9bc8f2 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -481,7 +481,7 @@ func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E // CHECK-LABEL: @scaled_ext_packed816_fp8 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { +func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 @@ -494,7 +494,12 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } // CHECK-LABEL: @scaled_ext_packed816_bf8 From 551849e8496343a293005c32e9f3c6cc6429eba6 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 09:54:18 -0500 Subject: [PATCH 13/44] Add case for pk8.f32.bf8 --- .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++++++- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index a347a822aba66..72617f76e6c91 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1824,8 +1824,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScalePk8F32Bf8Op + // vector<8xf8E5M2> + Value source = adaptor.getSource(); + + // vector<2xi32> + VectorType v2xi32 = VectorType::get(2, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); + + auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScalePk8F32Bf8Op // smallT = [Fp6, Bf6] // largeT = [F16, Bf16, F32] diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 917d25b9bc8f2..220759b6fab39 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -504,7 +504,7 @@ func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E // CHECK-LABEL: @scaled_ext_packed816_bf8 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>) { +func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 @@ -516,7 +516,12 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> - func.return %ret0, %ret1 : vector<8xf16>, vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> } From c1e10c8e4b9b6a6ba9ce210aabc7bc97a5608cab Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:26:48 -0500 Subject: [PATCH 14/44] Add case for pk16.f16.bf6 --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 38 ++++++++++++++++--- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 26 +++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 72617f76e6c91..af068fbfd957f 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1714,7 +1714,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // Fp8 = E4M3 // // largeT = [F16, Bf16, F32] - // CvtPkScalePk{8}${largeT}${smallT} + // CvtPkScalePk8${largeT}${smallT} if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F16Fp4Op @@ -1839,13 +1839,39 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } - /* - // smallT = [Fp6, Bf6] // largeT = [F16, Bf16, F32] - // CvtPkScalePk{16}${largeT}${smallT} - CvtPkScale16F16Fp6Op - CvtPkScale16F16Bf6Op + // + // Fp6 = Float6E2M3FN + // Bf6 = Float6E3M2FN + + // CvtPkScalePk16${largeT}${smallT} + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16F16Fp6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); + + auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16F16Bf6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); + + auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } + + /* CvtPkScale16Bf16Fp6Op CvtPkScale16Bf16Bf6Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 220759b6fab39..349440053a646 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -525,3 +525,29 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M } +// CHECK-LABEL: @scaled_ext_packed816_fp6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + return %ret0: vector<16xf16> +} + +// CHECK-LABEL: @scaled_ext_packed816_bf6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + return %ret0: vector<16xf16> +} + From e958d56ca7e5fa32767a10186fe17905adf11ecf Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:30:17 -0500 Subject: [PATCH 15/44] Add case for pk16.bf16.fp6 --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++- .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index af068fbfd957f..cbe3251268f78 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1870,10 +1870,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16Bf16Fp6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); + + auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScale16Bf16Fp6Op CvtPkScale16Bf16Bf6Op CvtPkScale16F32Fp6Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 349440053a646..76e26e0cab40a 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M // CHECK-LABEL: @scaled_ext_packed816_fp6 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { +func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) { // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> @@ -535,7 +535,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8 // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - return %ret0: vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + return %ret0, %ret1: vector<16xf16>, vector<16xbf16> } // CHECK-LABEL: @scaled_ext_packed816_bf6 From 1f79bdd99637a7be7765df808c3e3a63199ca524 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:38:08 -0500 Subject: [PATCH 16/44] Add case for pk16.bf16.bf6 --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++- .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index cbe3251268f78..322c8efcf3aea 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1882,10 +1882,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16Bf16Bf6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); + + auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScale16Bf16Bf6Op CvtPkScale16F32Fp6Op CvtPkScale16F32Bf6Op diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 76e26e0cab40a..b31116538228e 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -545,7 +545,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8 // CHECK-LABEL: @scaled_ext_packed816_bf6 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { +func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) { // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> @@ -553,6 +553,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8 // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - return %ret0: vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + return %ret0, %ret1: vector<16xf16>, vector<16xbf16> } From c5628e6ecf7a101a247da33fb55a497cfe33e98c Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:44:19 -0500 Subject: [PATCH 17/44] Add case for pk16.f32.fp6 --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 ++++++++++++- .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 322c8efcf3aea..fb189b24b29e7 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1894,11 +1894,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16F32Fp6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); + + auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } /* - CvtPkScale16F32Fp6Op CvtPkScale16F32Bf6Op */ diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index b31116538228e..dccdb81033738 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -527,7 +527,7 @@ func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M // CHECK-LABEL: @scaled_ext_packed816_fp6 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) { +func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> @@ -540,7 +540,12 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8 // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - return %ret0, %ret1: vector<16xf16>, vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> } // CHECK-LABEL: @scaled_ext_packed816_bf6 From db56c98bfa77ab923d4cfed8ddbb4ed15863b690 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:48:29 -0500 Subject: [PATCH 18/44] Add case for pk16.f32.bf6 --- .../Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 16 +++++++++++----- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 9 +++++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index fb189b24b29e7..26282ec4c2279 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1906,12 +1906,18 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOp(op, newOp); return success(); } + if (isa(srcElemType) and isa(tgtElemType)) { + // CvtPkScale16F32Bf6Op + Value source = adaptor.getSource(); + VectorType v3xi32 = VectorType::get(3, i32); + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - /* - - - CvtPkScale16F32Bf6Op - */ + auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create( + rewriter, loc, op.getResult().getType(), castedSource, castedScale, + scaleSel); + rewriter.replaceOp(op, newOp); + return success(); + } return failure(); } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index dccdb81033738..94a04d98004c7 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -550,7 +550,7 @@ func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8 // CHECK-LABEL: @scaled_ext_packed816_bf6 // CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>) { +func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> @@ -563,6 +563,11 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8 // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - return %ret0, %ret1: vector<16xf16>, vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> } From 73ed4b7b7d1562a0aa50fa7caec5c15426c57c27 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 10:52:19 -0500 Subject: [PATCH 19/44] Refactor NFC --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 74 ++++++++----------- 1 file changed, 31 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 26282ec4c2279..44393033ef442 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1725,9 +1725,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8F16Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1740,9 +1739,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8F16Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1755,9 +1753,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8Bf16Fp4Op // i32 Value castedSource = @@ -1766,9 +1763,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8Bf16Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1781,9 +1777,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8Bf16Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1796,9 +1791,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8F32Fp4Op // i32 Value castedSource = @@ -1807,9 +1801,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8F32Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1822,9 +1815,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScalePk8F32Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); } // smallT = [Fp6, Bf6] // largeT = [F16, Bf16, F32] @@ -1846,7 +1837,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // Bf6 = Float6E3M2FN // CvtPkScalePk16${largeT}${smallT} - if (isa(srcElemType) and isa(tgtElemType)) { + else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16F16Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1856,9 +1848,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16F16Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1868,9 +1859,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16Bf16Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1880,9 +1870,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16Bf16Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1892,9 +1881,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16F32Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1904,9 +1892,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); - } - if (isa(srcElemType) and isa(tgtElemType)) { + } else if (isa(srcElemType) and + isa(tgtElemType)) { // CvtPkScale16F32Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1916,10 +1903,11 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter, loc, op.getResult().getType(), castedSource, castedScale, scaleSel); rewriter.replaceOp(op, newOp); - return success(); + } else { + return failure(); } - return failure(); + return success(); } LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( From 94cc74070c0d91382c3444d96e837eb17f6afc52 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 11:01:43 -0500 Subject: [PATCH 20/44] Refactor NFC --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 90 +++++++------------ 1 file changed, 30 insertions(+), 60 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 44393033ef442..e28b53d8ed1a5 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1721,10 +1721,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); - auto newOp = ROCDL::CvtPkScalePk8F16Fp4Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F16Fp8Op @@ -1735,10 +1733,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8F16Fp8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F16Bf8Op @@ -1749,20 +1745,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8F16Bf8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8Bf16Fp4Op // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); - auto newOp = ROCDL::CvtPkScalePk8Bf16Fp4Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8Bf16Fp8Op @@ -1773,10 +1765,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8Bf16Fp8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8Bf16Bf8Op @@ -1787,20 +1777,16 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8Bf16Bf8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F32Fp4Op // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); - auto newOp = ROCDL::CvtPkScalePk8F32Fp4Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F32Fp8Op @@ -1811,10 +1797,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8F32Fp8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScalePk8F32Bf8Op @@ -1825,10 +1809,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v2xi32 = VectorType::get(2, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - auto newOp = ROCDL::CvtPkScalePk8F32Bf8Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } // smallT = [Fp6, Bf6] // largeT = [F16, Bf16, F32] @@ -1844,10 +1826,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16F16Fp6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScale16F16Bf6Op @@ -1855,10 +1835,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16F16Bf6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScale16Bf16Fp6Op @@ -1866,10 +1844,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16Bf16Fp6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScale16Bf16Bf6Op @@ -1877,10 +1853,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16Bf16Bf6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScale16F32Fp6Op @@ -1888,10 +1862,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16F32Fp6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and isa(tgtElemType)) { // CvtPkScale16F32Bf6Op @@ -1899,10 +1871,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - auto newOp = ROCDL::CvtPkScalePk16F32Bf6Op::create( - rewriter, loc, op.getResult().getType(), castedSource, castedScale, - scaleSel); - rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else { return failure(); } From 4f27e043e39776b908e0ee4e079fd16fc7ef2f3a Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 13:45:30 -0500 Subject: [PATCH 21/44] Use method instead of isa --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 46 +++++++------------ 1 file changed, 16 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index e28b53d8ed1a5..53f7accdb5a54 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1699,7 +1699,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte); auto targetType = cast(op.getResult().getType()); - auto tgtElemType = cast(targetType.getElementType()); + auto destElemType = cast(targetType.getElementType()); Location loc = op.getLoc(); // %scale: vector<4xf8E8M0FNU> // =========================== @@ -1716,15 +1716,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // largeT = [F16, Bf16, F32] // CvtPkScalePk8${largeT}${smallT} - if (isa(srcElemType) and isa(tgtElemType)) { + if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp4Op // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1735,8 +1734,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1747,16 +1745,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Fp4Op // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1767,8 +1763,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1779,16 +1774,14 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Fp4Op // i32 Value castedSource = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Fp8Op // vector<8xf8E4M3FN> Value source = adaptor.getSource(); @@ -1799,8 +1792,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Bf8Op // vector<8xf8E5M2> Value source = adaptor.getSource(); @@ -1819,8 +1811,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // Bf6 = Float6E3M2FN // CvtPkScalePk16${largeT}${smallT} - else if (isa(srcElemType) and - isa(tgtElemType)) { + else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1828,8 +1819,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1837,8 +1827,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1846,8 +1835,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1855,8 +1843,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Fp6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); @@ -1864,8 +1851,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and - isa(tgtElemType)) { + } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Bf6Op Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); From 0f8f3c493817873ca8047b7f2b788d1845716f09 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 13:47:59 -0500 Subject: [PATCH 22/44] Hoist variable --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 53f7accdb5a54..646d27a164830 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1715,6 +1715,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // // largeT = [F16, Bf16, F32] // CvtPkScalePk8${largeT}${smallT} + Value source = adaptor.getSource(); if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp4Op @@ -1726,7 +1727,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp8Op // vector<8xf8E4M3FN> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1737,7 +1737,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Bf8Op // vector<8xf8E5M2> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1755,7 +1754,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Fp8Op // vector<8xf8E4M3FN> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1766,7 +1764,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Bf8Op // vector<8xf8E5M2> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1784,7 +1781,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Fp8Op // vector<8xf8E4M3FN> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1795,7 +1791,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Bf8Op // vector<8xf8E5M2> - Value source = adaptor.getSource(); // vector<2xi32> VectorType v2xi32 = VectorType::get(2, i32); @@ -1813,7 +1808,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // CvtPkScalePk16${largeT}${smallT} else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Fp6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); @@ -1821,7 +1815,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Bf6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); @@ -1829,7 +1822,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Fp6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); @@ -1837,7 +1829,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Bf6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); @@ -1845,7 +1836,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Fp6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); @@ -1853,7 +1843,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Bf6Op - Value source = adaptor.getSource(); VectorType v3xi32 = VectorType::get(3, i32); Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); From a7a853ea525814ef1f7707cdb99a6327fcbbbfa1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 13:59:31 -0500 Subject: [PATCH 23/44] Refactor NFC --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 92 ++++++++----------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 646d27a164830..824e4249088ae 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1708,6 +1708,19 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( Value castedScale = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); + Value source = adaptor.getSource(); + Type packedType; + if (isa(srcElemType)) { + packedType = i32; + } else if (isa(srcElemType) || + isa(srcElemType)) { + packedType = VectorType::get(2, i32); + } else if (isa(srcElemType) || + isa(srcElemType)) { + packedType = VectorType::get(3, i32); + } else { + llvm_unreachable("invalid element type for scaled ext"); + } // Ok, so we need to construct ops depending on the sourceType and targetType. // smallT = [Fp4, Fp8, Bf8] // Bf8 = E5M2 @@ -1715,87 +1728,68 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // // largeT = [F16, Bf16, F32] // CvtPkScalePk8${largeT}${smallT} - Value source = adaptor.getSource(); if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp4Op // i32 Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Fp8Op // vector<8xf8E4M3FN> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScalePk8F16Bf8Op // vector<8xf8E5M2> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Fp4Op // i32 Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Fp8Op // vector<8xf8E4M3FN> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScalePk8Bf16Bf8Op // vector<8xf8E5M2> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Fp4Op // i32 Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getSource()); + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Fp8Op // vector<8xf8E4M3FN> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScalePk8F32Bf8Op // vector<8xf8E5M2> - - // vector<2xi32> - VectorType v2xi32 = VectorType::get(2, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v2xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } @@ -1808,44 +1802,38 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // CvtPkScalePk16${largeT}${smallT} else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Fp6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { // CvtPkScale16F16Bf6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Fp6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { // CvtPkScale16Bf16Bf6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Fp6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { // CvtPkScale16F32Bf6Op - VectorType v3xi32 = VectorType::get(3, i32); - Value castedSource = LLVM::BitcastOp::create(rewriter, loc, v3xi32, source); - + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else { From 9e4ab0e7b03009fe50cddaf8f218e93fe0bc82f1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 14:02:38 -0500 Subject: [PATCH 24/44] Hoist variable. NFC --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 56 +------------------ 1 file changed, 2 insertions(+), 54 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 824e4249088ae..3d41d47da6e00 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1728,68 +1728,34 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // // largeT = [F16, Bf16, F32] // CvtPkScalePk8${largeT}${smallT} + Value castedSource = + LLVM::BitcastOp::create(rewriter, loc, packedType, source); if (isa(srcElemType) and destElemType.isF16()) { - // CvtPkScalePk8F16Fp4Op - // i32 - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { - // CvtPkScalePk8F16Fp8Op - // vector<8xf8E4M3FN> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { - // CvtPkScalePk8F16Bf8Op - // vector<8xf8E5M2> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { - // CvtPkScalePk8Bf16Fp4Op - // i32 - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { - // CvtPkScalePk8Bf16Fp8Op - // vector<8xf8E4M3FN> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { - // CvtPkScalePk8Bf16Bf8Op - // vector<8xf8E5M2> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { - // CvtPkScalePk8F32Fp4Op - // i32 - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { - // CvtPkScalePk8F32Fp8Op - // vector<8xf8E4M3FN> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { - // CvtPkScalePk8F32Bf8Op - // vector<8xf8E5M2> - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } @@ -1801,39 +1767,21 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // CvtPkScalePk16${largeT}${smallT} else if (isa(srcElemType) and destElemType.isF16()) { - // CvtPkScale16F16Fp6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF16()) { - // CvtPkScale16F16Bf6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { - // CvtPkScale16Bf16Fp6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isBF16()) { - // CvtPkScale16Bf16Bf6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { - // CvtPkScale16F32Fp6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else if (isa(srcElemType) and destElemType.isF32()) { - // CvtPkScale16F32Bf6Op - Value castedSource = - LLVM::BitcastOp::create(rewriter, loc, packedType, source); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else { From 728686f4ab79a5f64c389c83fc47447f0b69b663 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 14:03:55 -0500 Subject: [PATCH 25/44] Comments. NFC --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 3d41d47da6e00..0affa2ead9f78 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1721,7 +1721,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( } else { llvm_unreachable("invalid element type for scaled ext"); } - // Ok, so we need to construct ops depending on the sourceType and targetType. // smallT = [Fp4, Fp8, Bf8] // Bf8 = E5M2 // Fp8 = E4M3 @@ -1760,11 +1759,10 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } // smallT = [Fp6, Bf6] + // Fp6 = Float6E2M3FN + // Bf6 = Float6E3M2FN // largeT = [F16, Bf16, F32] // - // Fp6 = Float6E2M3FN - // Bf6 = Float6E3M2FN - // CvtPkScalePk16${largeT}${smallT} else if (isa(srcElemType) and destElemType.isF16()) { rewriter.replaceOpWithNewOp( From 47dc32e6698b7ea8f3926bc39d97273b32a012df Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 14:09:01 -0500 Subject: [PATCH 26/44] refactor. nfc --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 0affa2ead9f78..ec889879fd8d4 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1701,9 +1701,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( auto targetType = cast(op.getResult().getType()); auto destElemType = cast(targetType.getElementType()); Location loc = op.getLoc(); - // %scale: vector<4xf8E8M0FNU> - // =========================== - // %scale: i32 IntegerType i32 = rewriter.getI32Type(); Value castedScale = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); @@ -1730,31 +1727,31 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( Value castedSource = LLVM::BitcastOp::create(rewriter, loc, packedType, source); - if (isa(srcElemType) and destElemType.isF16()) { + if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } @@ -1764,22 +1761,22 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( // largeT = [F16, Bf16, F32] // // CvtPkScalePk16${largeT}${smallT} - else if (isa(srcElemType) and destElemType.isF16()) { + else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) and destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else { From 3cfea7e2abf631970b8a115c3ed1f7d56267ad38 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 11 Nov 2025 14:14:44 -0500 Subject: [PATCH 27/44] Keep conventions --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ec889879fd8d4..6dbf57342cb3f 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1709,12 +1709,15 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( Type packedType; if (isa(srcElemType)) { packedType = i32; + packedType = getTypeConverter()->convertType(packedType); } else if (isa(srcElemType) || isa(srcElemType)) { packedType = VectorType::get(2, i32); + packedType = getTypeConverter()->convertType(packedType); } else if (isa(srcElemType) || isa(srcElemType)) { packedType = VectorType::get(3, i32); + packedType = getTypeConverter()->convertType(packedType); } else { llvm_unreachable("invalid element type for scaled ext"); } From 2b010cd04d4644b06a864777de0c63edae4be0c6 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 09:35:11 -0500 Subject: [PATCH 28/44] Less of exhaustive enumeration --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 71 +++++-------------- 1 file changed, 19 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 6dbf57342cb3f..a9f58063ac32b 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1627,62 +1627,29 @@ int getScaleSel(int blockSize, int bitWidth, int firstScaleLane, const bool is_fp8 = bitWidth == 8; const bool is_block_16 = blockSize == 16; - if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) { - return 0b000; - } - if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) { - return 0b001; - } - if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) { - return 0b010; - } - if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) { - return 0b011; - } - if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) { - return 0b100; - } - if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) { - return 0b101; - } - if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) { - return 0b110; - } - if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) { - return 0b111; - } - - if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) { - return 0b0000; - } - if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) { - return 0b0001; - } - if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) { - return 0b0010; - } - if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) { - return 0b0100; - } - if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) { - return 0b0110; - } - if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) { - return 0b1010; - } - if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) { - return 0b1100; - } - if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) { - return 0b1101; - } - if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) { - return 0b1110; + if (!is_fp8) { + int bit_0 = is_block_16; + assert(llvm::is_contained({0, 2}, firstScaleByte)); + int bit_1 = (firstScaleByte == 2) << 1; + assert(llvm::is_contained({0, 1}, firstScaleLane)); + int bit_2 = firstScaleLane << 2; + return bit_2 | bit_1 | bit_0; + } else { + int bit_0 = is_block_16; + // firstScaleByte is guaranteed to be defined by two bits. + assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); + int bit_2_and_1 = firstScaleByte << 1; + assert(llvm::is_contained({0, 1}, firstScaleLane)); + int bit_3 = firstScaleLane << 3; + int bits = bit_3 | bit_2_and_1 | bit_0; + // These are invalid cases. + assert(!llvm::is_contained( + {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); + return bits; } llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, " "blockSize and type."); - return 0; } LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( From 6f07ef03b6f778fec82abc85914c6b47095d2312 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 09:39:10 -0500 Subject: [PATCH 29/44] Correct types --- .../lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index a9f58063ac32b..596cac1b76469 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1613,8 +1613,8 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( return success(); } -int getScaleSel(int blockSize, int bitWidth, int firstScaleLane, - int firstScaleByte) { +int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, + int32_t firstScaleLane, int32_t firstScaleByte) { // When lowering amdgpu.scaled_ext_packed816 to // rocdl.cvt.scale.pk*.f*.f* operations, the // attributes blockSize, sourceType, firstScaleLane and firstScaleByte @@ -1656,13 +1656,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - int firstScaleLane = op.getFirstScaleLane(); - int firstScaleByte = op.getFirstScaleByte(); - int blockSize = op.getBlockSize(); + int32_t firstScaleLane = op.getFirstScaleLane(); + int32_t firstScaleByte = op.getFirstScaleByte(); + int32_t blockSize = op.getBlockSize(); auto sourceType = cast(op.getSource().getType()); auto srcElemType = cast(sourceType.getElementType()); - int bitWidth = srcElemType.getWidth(); - int scaleSel = + unsigned bitWidth = srcElemType.getWidth(); + int32_t scaleSel = getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte); auto targetType = cast(op.getResult().getType()); From 69787933cc44ed01faff7abf38078c0f61917bbd Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 09:39:48 -0500 Subject: [PATCH 30/44] Reflow comment --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 596cac1b76469..a7c9a68dd7731 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1615,12 +1615,10 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t firstScaleLane, int32_t firstScaleByte) { - // When lowering amdgpu.scaled_ext_packed816 to - // rocdl.cvt.scale.pk*.f*.f* operations, the - // attributes blockSize, sourceType, firstScaleLane and firstScaleByte - // are merged into a single attribute scaleSel. - // - // This is how those values are merged together. + // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f* + // operations, the attributes blockSize, sourceType, firstScaleLane and + // firstScaleByte are merged into a single attribute scaleSel. This is how + // those values are merged together. assert(llvm::is_contained({16, 32}, blockSize)); assert(llvm::is_contained({4, 6, 8}, bitWidth)); From ed66571310d9fd687e1b549beb5a79704c629cb4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 09:40:29 -0500 Subject: [PATCH 31/44] superfluous empty line --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index a7c9a68dd7731..15c511460a552 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1653,7 +1653,6 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - int32_t firstScaleLane = op.getFirstScaleLane(); int32_t firstScaleByte = op.getFirstScaleByte(); int32_t blockSize = op.getBlockSize(); From 33ef57e0dce2640dc8c3cf3c5623ffc71eb42d18 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 09:55:53 -0500 Subject: [PATCH 32/44] Using --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 58 ++++++++----------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 15c511460a552..b0b0d4a9fe604 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1620,7 +1620,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, // firstScaleByte are merged into a single attribute scaleSel. This is how // those values are merged together. assert(llvm::is_contained({16, 32}, blockSize)); - assert(llvm::is_contained({4, 6, 8}, bitWidth)); + assert(llvm::is_contained(::llvm::ArrayRef{4, 6, 8}, bitWidth)); const bool is_fp8 = bitWidth == 8; const bool is_block_16 = blockSize == 16; @@ -1653,6 +1653,11 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; int32_t firstScaleLane = op.getFirstScaleLane(); int32_t firstScaleByte = op.getFirstScaleByte(); int32_t blockSize = op.getBlockSize(); @@ -1671,79 +1676,64 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( Value source = adaptor.getSource(); Type packedType; - if (isa(srcElemType)) { + if (isa(srcElemType)) { packedType = i32; packedType = getTypeConverter()->convertType(packedType); - } else if (isa(srcElemType) || - isa(srcElemType)) { + } else if (isa(srcElemType)) { packedType = VectorType::get(2, i32); packedType = getTypeConverter()->convertType(packedType); - } else if (isa(srcElemType) || - isa(srcElemType)) { + } else if (isa(srcElemType)) { packedType = VectorType::get(3, i32); packedType = getTypeConverter()->convertType(packedType); } else { llvm_unreachable("invalid element type for scaled ext"); } - // smallT = [Fp4, Fp8, Bf8] - // Bf8 = E5M2 - // Fp8 = E4M3 - // - // largeT = [F16, Bf16, F32] - // CvtPkScalePk8${largeT}${smallT} Value castedSource = LLVM::BitcastOp::create(rewriter, loc, packedType, source); - if (isa(srcElemType) && destElemType.isF16()) { + if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } - // smallT = [Fp6, Bf6] - // Fp6 = Float6E2M3FN - // Bf6 = Float6E3M2FN - // largeT = [F16, Bf16, F32] - // - // CvtPkScalePk16${largeT}${smallT} - else if (isa(srcElemType) && destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { + } else if (isa(srcElemType) && destElemType.isF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { + } else if (isa(srcElemType) && destElemType.isBF16()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { + } else if (isa(srcElemType) && destElemType.isF32()) { rewriter.replaceOpWithNewOp( op, op.getResult().getType(), castedSource, castedScale, scaleSel); } else { From a83cec94451eae1889c92e333dd6fa7b47904bad Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 10:18:27 -0500 Subject: [PATCH 33/44] Add chipset check and moved tests --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 8 +- .../AMDGPUToROCDL/amdgpu-to-rocdl.mlir | 114 ----------------- .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 116 ++++++++++++++++++ 3 files changed, 123 insertions(+), 115 deletions(-) create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index b0b0d4a9fe604..7e30feb520a72 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1658,6 +1658,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( using bf8 = Float8E5M2Type; using fp6 = Float6E2M3FNType; using bf6 = Float6E3M2FNType; + Location loc = op.getLoc(); + if (chipset != Chipset{12, 5, 0}) { + return rewriter.notifyMatchFailure( + loc, + "Scaled fp packed conversion instructions are not available on target " + "architecture and their emulation is not implemented"); + } int32_t firstScaleLane = op.getFirstScaleLane(); int32_t firstScaleByte = op.getFirstScaleByte(); int32_t blockSize = op.getBlockSize(); @@ -1669,7 +1676,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( auto targetType = cast(op.getResult().getType()); auto destElemType = cast(targetType.getElementType()); - Location loc = op.getLoc(); IntegerType i32 = rewriter.getI32Type(); Value castedScale = LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir index 94a04d98004c7..432b8876696a9 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -457,117 +457,3 @@ func.func @sched_barrier() { func.return } -// CHECK-LABEL: @scaled_ext_packed816_fp4 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 - // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 - // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 - // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32> - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> - func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32> -} - -// CHECK-LABEL: @scaled_ext_packed816_fp8 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> - - func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> -} - -// CHECK-LABEL: @scaled_ext_packed816_bf8 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { - // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> - // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> - func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> -} - - -// CHECK-LABEL: @scaled_ext_packed816_fp6 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> - return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> -} - -// CHECK-LABEL: @scaled_ext_packed816_bf6 -// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) -func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { - // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> - // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> - %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> - - // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 - // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> - // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> - %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> - return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> -} - diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir new file mode 100644 index 0000000000000..811a8e49dc5c6 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s + +// CHECK-LABEL: @scaled_ext_packed816_fp4 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi4:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf4E2M1FN> to vector<8xi4> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.bf16.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_i32:.+]] = llvm.bitcast %[[SOURCE_8xi4]] : vector<8xi4> to i32 + // CHECK: rocdl.cvt.scale.pk8.f32.fp4 %[[SOURCE_i32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2: vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed816_fp8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E4M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E4M3FN> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.fp8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32> + + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + +// CHECK-LABEL: @scaled_ext_packed816_bf8 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf8E5M2>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) { + // CHECK: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK: %[[SOURCE_8xi8:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<8xf8E5M2> to vector<8xi8> + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: %[[RES:.+]] = rocdl.cvt.scale.pk8.f16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.bf16.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v2xi32:.+]] = llvm.bitcast %[[SOURCE_8xi8]] : vector<8xi8> to vector<2xi32> + // CHECK: rocdl.cvt.scale.pk8.f32.bf8 %[[SOURCE_v2xi32]], %[[SCALE_i32]][0] : vector<8xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf32> + func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> +} + + +// CHECK-LABEL: @scaled_ext_packed816_fp6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E2M3FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E2M3FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.fp6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + +// CHECK-LABEL: @scaled_ext_packed816_bf6 +// CHECK-SAME: (%[[SOURCE:.+]]: vector<16xf6E3M2FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) +func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) { + // CHECK-DAG: %[[SCALE_4xi8:.+]] = builtin.unrealized_conversion_cast %[[SCALE]] : vector<4xf8E8M0FNU> to vector<4xi8> + // CHECK-DAG: %[[SOURCE_16xi6:.+]] = builtin.unrealized_conversion_cast %[[SOURCE]] : vector<16xf6E3M2FN> to vector<16xi6> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf16> + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.bf16.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xbf16> + %ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xbf16> + + // CHECK: %[[SCALE_i32:.+]] = llvm.bitcast %[[SCALE_4xi8]] : vector<4xi8> to i32 + // CHECK: %[[SOURCE_v3xi32:.+]] = llvm.bitcast %[[SOURCE_16xi6]] : vector<16xi6> to vector<3xi32> + // CHECK: rocdl.cvt.scale.pk16.f32.bf6 %[[SOURCE_v3xi32]], %[[SCALE_i32]][0] : vector<16xf32> + %ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32> + return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> +} + From 34ed3e9384a683e44b967baff53fb952b3320e90 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 10:21:24 -0500 Subject: [PATCH 34/44] Refactor NFC --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 7e30feb520a72..0c5f4ff7f8227 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1632,22 +1632,19 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, assert(llvm::is_contained({0, 1}, firstScaleLane)); int bit_2 = firstScaleLane << 2; return bit_2 | bit_1 | bit_0; - } else { - int bit_0 = is_block_16; - // firstScaleByte is guaranteed to be defined by two bits. - assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); - int bit_2_and_1 = firstScaleByte << 1; - assert(llvm::is_contained({0, 1}, firstScaleLane)); - int bit_3 = firstScaleLane << 3; - int bits = bit_3 | bit_2_and_1 | bit_0; - // These are invalid cases. - assert(!llvm::is_contained( - {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); - return bits; } - llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, " - "blockSize and type."); + int bit_0 = is_block_16; + // firstScaleByte is guaranteed to be defined by two bits. + assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); + int bit_2_and_1 = firstScaleByte << 1; + assert(llvm::is_contained({0, 1}, firstScaleLane)); + int bit_3 = firstScaleLane << 3; + int bits = bit_3 | bit_2_and_1 | bit_0; + // These are invalid cases. + assert(!llvm::is_contained( + {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); + return bits; } LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( From b88f7f6e74c0038873dabec3d31b2bba12b8e6ab Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 10:44:41 -0500 Subject: [PATCH 35/44] Use operation name --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 105 ++++++++++-------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 0c5f4ff7f8227..23660361094c3 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1647,6 +1647,46 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, return bits; } +static std::optional +scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { + using fp4 = Float4E2M1FNType; + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + using fp6 = Float6E2M3FNType; + using bf6 = Float6E3M2FNType; + if (isa(srcElemType) && destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); + if (isa(srcElemType) && destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); + if (isa(srcElemType) && destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); + if (isa(srcElemType) && destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); + if (isa(srcElemType) && destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); + return std::nullopt; +} + LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1694,54 +1734,23 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( Value castedSource = LLVM::BitcastOp::create(rewriter, loc, packedType, source); - if (isa(srcElemType) && destElemType.isF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isBF16()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else if (isa(srcElemType) && destElemType.isF32()) { - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), castedSource, castedScale, scaleSel); - } else { - return failure(); - } + std::optional maybeIntrinsic = + scaledExtPacked816ToIntrinsic(srcElemType, destElemType); + if (!maybeIntrinsic.has_value()) + return op.emitOpError( + "no intrinsic matching packed scaled conversion on the given chipset"); + + OperationState loweredOp(loc, *maybeIntrinsic); + loweredOp.addTypes({op.getResult().getType()}); + loweredOp.addOperands({castedSource, castedScale}); + + SmallVector attrs; + attrs.push_back( + NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); + + loweredOp.addAttributes(attrs); + Operation *lowered = rewriter.create(loweredOp); + rewriter.replaceOp(op, lowered); return success(); } From 7a7ecaf31b5ad267343b871d5f2f44b8eff875b3 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 10:48:38 -0500 Subject: [PATCH 36/44] Convert result type --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 23660361094c3..2e73f7b0d4266 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1741,7 +1741,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( "no intrinsic matching packed scaled conversion on the given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes({op.getResult().getType()}); + Type llvmResultType = typeConverter->convertType(op.getResult().getType()); + loweredOp.addTypes({llvmResultType}); loweredOp.addOperands({castedSource, castedScale}); SmallVector attrs; From 1025e2b999b7f42947842cfb3185b921a0ea534d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 10:58:20 -0500 Subject: [PATCH 37/44] Check for type conversion failures --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 2e73f7b0d4266..ae907b8ffefc3 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1718,7 +1718,8 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); Value source = adaptor.getSource(); - Type packedType; + Type llvmResultType = typeConverter->convertType(op.getResult().getType()); + Type packedType = nullptr; if (isa(srcElemType)) { packedType = i32; packedType = getTypeConverter()->convertType(packedType); @@ -1729,8 +1730,13 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( packedType = VectorType::get(3, i32); packedType = getTypeConverter()->convertType(packedType); } else { - llvm_unreachable("invalid element type for scaled ext"); + llvm_unreachable("invalid element type for packed scaled ext"); + } + + if (!packedType || !llvmResultType) { + return rewriter.notifyMatchFailure(op, "type conversion failed"); } + Value castedSource = LLVM::BitcastOp::create(rewriter, loc, packedType, source); @@ -1741,7 +1747,6 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( "no intrinsic matching packed scaled conversion on the given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); - Type llvmResultType = typeConverter->convertType(op.getResult().getType()); loweredOp.addTypes({llvmResultType}); loweredOp.addOperands({castedSource, castedScale}); From 7c44f0959a85b0a801a687e781605df81e3b1c6d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 11:21:25 -0500 Subject: [PATCH 38/44] Add top-level if condition for each src type --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index f9bbc6535f0e1..d23a89a199131 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1667,37 +1667,53 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { using bf8 = Float8E5M2Type; using fp6 = Float6E2M3FNType; using bf6 = Float6E3M2FNType; - if (isa(srcElemType) && destElemType.isF16()) - return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF16()) - return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF16()) - return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isBF16()) - return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); - if (isa(srcElemType) && destElemType.isBF16()) - return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isBF16()) - return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF32()) - return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF32()) - return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF32()) - return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF16()) - return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF16()) - return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); - if (isa(srcElemType) && destElemType.isBF16()) - return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); - if (isa(srcElemType) && destElemType.isBF16()) - return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF32()) - return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); - if (isa(srcElemType) && destElemType.isF32()) - return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); - return std::nullopt; + if (isa(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); + return std::nullopt; + } + if (isa(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); + return std::nullopt; + } + if (isa(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); + return std::nullopt; + } + if (isa(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); + return std::nullopt; + } + if (isa(srcElemType)) { + if (destElemType.isF16()) + return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); + if (destElemType.isBF16()) + return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); + if (destElemType.isF32()) + return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); + return std::nullopt; + } + llvm_unreachable("invalid combination of element types for packed conversion " + "instructions"); } LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( From f06a67e94e08dd242c6a89dd06023b6ce5d95ad0 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 11:37:30 -0500 Subject: [PATCH 39/44] Add chipset constant at beginning of file --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index d23a89a199131..4d1734fe710c0 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8); constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); constexpr Chipset kGfx942 = Chipset(9, 4, 2); constexpr Chipset kGfx950 = Chipset(9, 5, 0); +constexpr Chipset kGfx1250 = Chipset(12, 5, 0); /// Convert an unsigned number `val` to i32. static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, @@ -1149,7 +1150,7 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, k, isRDNA3); // Handle gfx1250. - if (chipset == Chipset{12, 5, 0}) + if (chipset == kGfx1250) return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); @@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { if (chipset.majorVersion != 11 && chipset.majorVersion != 12) return op->emitOpError("WMMA only supported on gfx11 and gfx12"); - bool isGFX1250 = chipset >= Chipset(12, 5, 0); + bool isGFX1250 = chipset >= kGfx1250; // The WMMA operations represent vectors of bf16s as vectors of i16s // (except on gfx1250), so we need to bitcast bfloats to i16 and then @@ -1725,7 +1726,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( using fp6 = Float6E2M3FNType; using bf6 = Float6E3M2FNType; Location loc = op.getLoc(); - if (chipset != Chipset{12, 5, 0}) { + if (chipset != kGfx1250) { return rewriter.notifyMatchFailure( loc, "Scaled fp packed conversion instructions are not available on target " From 1dbcb95412e274d598c337b700b2524130fd3d25 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 11:42:14 -0500 Subject: [PATCH 40/44] wip --- mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir index 811a8e49dc5c6..87e4eb363d343 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 | FileCheck %s +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics \ +// RUN: | FileCheck %s // CHECK-LABEL: @scaled_ext_packed816_fp4 // CHECK-SAME: (%[[SOURCE:.+]]: vector<8xf4E2M1FN>, %[[SCALE:.+]]: vector<4xf8E8M0FNU>) From eee0ce97137e993bed9450d281536e74d21d8108 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 11:52:55 -0500 Subject: [PATCH 41/44] Add invalid srcElemType case --- .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 40 +++++++++++++++++++ mlir/test/Dialect/AMDGPU/invalid.mlir | 32 --------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir index 87e4eb363d343..73711a2b98ac9 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir @@ -115,3 +115,43 @@ func.func @scaled_ext_packed816_bf6(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8 return %ret0, %ret1, %ret2: vector<16xf16>, vector<16xbf16>, vector<16xf32> } +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> + func.return +} + +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op operand #0 must be}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf16>, vector<4xf8E8M0FNU> -> vector<16xf16> + return %ret0: vector<16xf16> +} + diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 5c8cc8b67c4b3..61fdf29a78cbd 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -333,38 +333,6 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 : // ----- -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return -} - -// ----- - -func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return -} - -// ----- - -func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> - func.return -} - -// ----- - -func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op failed to verify that all of {source, res} have same shape}} - %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16> - func.return -} - -// ----- - func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32> From 9860cdd9ffb108f472a18dff751a3401e13f695e Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 14:28:08 -0500 Subject: [PATCH 42/44] Update verifiers --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 45 ++++++++++++------- .../AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 7 +++ 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 4d1734fe710c0..bf83591bf6047 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1634,7 +1634,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, // firstScaleByte are merged into a single attribute scaleSel. This is how // those values are merged together. assert(llvm::is_contained({16, 32}, blockSize)); - assert(llvm::is_contained(::llvm::ArrayRef{4, 6, 8}, bitWidth)); + assert(llvm::is_contained(llvm::ArrayRef{4, 6, 8}, bitWidth)); const bool is_fp8 = bitWidth == 8; const bool is_block_16 = blockSize == 16; diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 5c35823678576..955de3bb861ba 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// LogicalResult ScaledExtPacked816Op::verify() { int blockSize = getBlockSize(); - assert((blockSize == 16 || blockSize == 32) && "invalid block size"); + assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size"); int firstScaleByte = getFirstScaleByte(); + int firstScaleLane = getFirstScaleLane(); auto sourceType = cast(getSource().getType()); Type elementType = sourceType.getElementType(); auto floatType = cast(elementType); - int bitWidth = floatType.getWidth(); + unsigned bitWidth = floatType.getWidth(); - if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 && - !llvm::is_contained({0, 1}, firstScaleByte)) { - return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 " - "for f4 and f6."); - } - if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 && - !llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 " - "for f4 and f6."); - } - if (bitWidth == 8 && blockSize == 16 && - !llvm::is_contained({0, 2}, firstScaleByte)) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8."); + assert(llvm::is_contained(llvm::ArrayRef{4, 6, 8}, bitWidth)); + + const bool is_fp8 = bitWidth == 8; + const bool is_block_16 = blockSize == 16; + + if (!is_fp8) { + if (is_block_16) { + if (!llvm::is_contained({0, 1}, firstScaleByte)) { + return emitOpError("blockSize of 16 can only have firstScaleByte be 0 " + "or 1 for f4 and f6."); + } + } else { + if (!llvm::is_contained({0, 2}, firstScaleByte)) { + return emitOpError("blockSize of 32 can only have firstScaleByte be 0 " + "or 2 for f4 and f6."); + } + } + } else { + if (is_block_16) { + bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || + ((firstScaleLane == 1) && (firstScaleByte == 2)); + if (!is_valid) { + return emitOpError( + "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8."); + } + } } return success(); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir index 73711a2b98ac9..fbe13a29c53ab 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir @@ -155,3 +155,10 @@ func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>, return %ret0: vector<16xf16> } +// ----- + +func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) { + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op result #0 must be vector}} + %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64> + return %ret0: vector<16xf64> +} From 0b8f561e00d2cea1a6e60e53e13720a617611330 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 14:35:10 -0500 Subject: [PATCH 43/44] Update verifier message --- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 4 ++-- mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 955de3bb861ba..d55f3cec47c1f 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -374,8 +374,8 @@ LogicalResult ScaledExtPacked816Op::verify() { bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) || ((firstScaleLane == 1) && (firstScaleByte == 2)); if (!is_valid) { - return emitOpError( - "blockSize of 16 can only have firstScaleByte be 0 or 2 for f8."); + return emitOpError("blockSize of 16 can only have (firstScaleLane, " + "firstScaleByte) be (0, 0) or (1, 2) for f8."); } } } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir index fbe13a29c53ab..d2391140ce056 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir @@ -134,7 +134,7 @@ func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_3 // ----- func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) { - // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}} + // expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have (firstScaleLane, firstScaleByte) be (0, 0) or (1, 2) for f8.}} %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16> func.return } From 642414a9289719ba9f4551839d441674714f6e0f Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Mon, 17 Nov 2025 14:52:41 -0500 Subject: [PATCH 44/44] Fix assertion --- mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index bf83591bf6047..edc6565f44f00 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -1641,7 +1641,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, if (!is_fp8) { int bit_0 = is_block_16; - assert(llvm::is_contained({0, 2}, firstScaleByte)); + assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); int bit_1 = (firstScaleByte == 2) << 1; assert(llvm::is_contained({0, 1}, firstScaleLane)); int bit_2 = firstScaleLane << 2;