-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][amdgpu] Add lowerings for ScaledExtPacked816 #168123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7b64631
08e96b1
d0932cc
0d1d668
163b15a
6d7e2a6
fc5d858
d9a254f
c5eb698
cec5f04
7dc3442
0ba6b94
551849e
c1e10c8
e958d56
1f79bdd
c5628e6
db56c98
73ed4b7
94cc740
4f27e04
0f8f3c4
a7a853e
9e4ab0e
728686f
47dc32e
3cfea7e
2b010cd
6f07ef0
6978793
ed66571
33ef57e
a83cec9
34ed3e9
b88f7f6
7a7ecaf
1025e2b
a3db728
7c44f09
f06a67e
1dbcb95
eee0ce9
9860cdd
0b8f561
642414a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8); | |
| constexpr Chipset kGfx90a = Chipset(9, 0, 0xa); | ||
| constexpr Chipset kGfx942 = Chipset(9, 4, 2); | ||
| constexpr Chipset kGfx950 = Chipset(9, 5, 0); | ||
| constexpr Chipset kGfx1250 = Chipset(12, 5, 0); | ||
|
|
||
| /// Convert an unsigned number `val` to i32. | ||
| static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, | ||
|
|
@@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, | |
| k, isRDNA3); | ||
|
|
||
| // Handle gfx1250. | ||
| if (chipset == Chipset{12, 5, 0}) | ||
| if (chipset == kGfx1250) | ||
| return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, | ||
| elemDestType, k); | ||
|
|
||
|
|
@@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { | |
| if (chipset.majorVersion != 11 && chipset.majorVersion != 12) | ||
| return op->emitOpError("WMMA only supported on gfx11 and gfx12"); | ||
|
|
||
| bool isGFX1250 = chipset >= Chipset(12, 5, 0); | ||
| bool isGFX1250 = chipset >= kGfx1250; | ||
|
|
||
| // The WMMA operations represent vectors of bf16s as vectors of i16s | ||
| // (except on gfx1250), so we need to bitcast bfloats to i16 and then | ||
|
|
@@ -1505,6 +1506,19 @@ struct ExtPackedFp8OpLowering final | |
| ConversionPatternRewriter &rewriter) const override; | ||
| }; | ||
|
|
||
| struct ScaledExtPacked816OpLowering final | ||
| : public ConvertOpToLLVMPattern<ScaledExtPacked816Op> { | ||
| ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter, | ||
| Chipset chipset) | ||
| : ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter), | ||
| chipset(chipset) {} | ||
| Chipset chipset; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override; | ||
| }; | ||
|
|
||
| struct PackedTrunc2xFp8OpLowering final | ||
| : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> { | ||
| PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter, | ||
|
|
@@ -1613,6 +1627,170 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( | |
| return success(); | ||
| } | ||
|
|
||
| int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, | ||
| int32_t firstScaleLane, int32_t firstScaleByte) { | ||
| // When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f* | ||
| // operations, the attributes blockSize, sourceType, firstScaleLane and | ||
| // firstScaleByte are merged into a single attribute scaleSel. This is how | ||
| // those values are merged together. | ||
| assert(llvm::is_contained({16, 32}, blockSize)); | ||
| assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth)); | ||
|
|
||
| const bool is_fp8 = bitWidth == 8; | ||
| const bool is_block_16 = blockSize == 16; | ||
|
|
||
| if (!is_fp8) { | ||
| int bit_0 = is_block_16; | ||
| assert(llvm::is_contained({0, 1, 2}, firstScaleByte)); | ||
| int bit_1 = (firstScaleByte == 2) << 1; | ||
| assert(llvm::is_contained({0, 1}, firstScaleLane)); | ||
| int bit_2 = firstScaleLane << 2; | ||
| return bit_2 | bit_1 | bit_0; | ||
| } | ||
|
|
||
| int bit_0 = is_block_16; | ||
| // firstScaleByte is guaranteed to be defined by two bits. | ||
| assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte)); | ||
| int bit_2_and_1 = firstScaleByte << 1; | ||
| assert(llvm::is_contained({0, 1}, firstScaleLane)); | ||
| int bit_3 = firstScaleLane << 3; | ||
| int bits = bit_3 | bit_2_and_1 | bit_0; | ||
| // These are invalid cases. | ||
| assert(!llvm::is_contained( | ||
| {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits)); | ||
| return bits; | ||
| } | ||
|
|
||
| static std::optional<StringRef> | ||
| scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) { | ||
| using fp4 = Float4E2M1FNType; | ||
| using fp8 = Float8E4M3FNType; | ||
| using bf8 = Float8E5M2Type; | ||
| using fp6 = Float6E2M3FNType; | ||
| using bf6 = Float6E3M2FNType; | ||
| if (isa<fp4>(srcElemType)) { | ||
| if (destElemType.isF16()) | ||
| return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName(); | ||
| if (destElemType.isBF16()) | ||
| return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName(); | ||
| if (destElemType.isF32()) | ||
| return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName(); | ||
| return std::nullopt; | ||
| } | ||
| if (isa<fp8>(srcElemType)) { | ||
| if (destElemType.isF16()) | ||
| return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName(); | ||
| if (destElemType.isBF16()) | ||
| return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName(); | ||
| if (destElemType.isF32()) | ||
| return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName(); | ||
| return std::nullopt; | ||
| } | ||
| if (isa<bf8>(srcElemType)) { | ||
| if (destElemType.isF16()) | ||
| return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName(); | ||
| if (destElemType.isBF16()) | ||
| return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName(); | ||
| if (destElemType.isF32()) | ||
| return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName(); | ||
| return std::nullopt; | ||
| } | ||
| if (isa<fp6>(srcElemType)) { | ||
| if (destElemType.isF16()) | ||
| return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName(); | ||
| if (destElemType.isBF16()) | ||
| return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName(); | ||
| if (destElemType.isF32()) | ||
| return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName(); | ||
| return std::nullopt; | ||
| } | ||
| if (isa<bf6>(srcElemType)) { | ||
| if (destElemType.isF16()) | ||
| return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName(); | ||
| if (destElemType.isBF16()) | ||
| return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName(); | ||
| if (destElemType.isF32()) | ||
| return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName(); | ||
| return std::nullopt; | ||
| } | ||
| llvm_unreachable("invalid combination of element types for packed conversion " | ||
| "instructions"); | ||
| } | ||
|
|
||
| LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite( | ||
| ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| using fp4 = Float4E2M1FNType; | ||
| using fp8 = Float8E4M3FNType; | ||
| using bf8 = Float8E5M2Type; | ||
| using fp6 = Float6E2M3FNType; | ||
| using bf6 = Float6E3M2FNType; | ||
| Location loc = op.getLoc(); | ||
| if (chipset != kGfx1250) { | ||
| return rewriter.notifyMatchFailure( | ||
| loc, | ||
| "Scaled fp packed conversion instructions are not available on target " | ||
| "architecture and their emulation is not implemented"); | ||
| } | ||
| int32_t firstScaleLane = op.getFirstScaleLane(); | ||
| int32_t firstScaleByte = op.getFirstScaleByte(); | ||
| int32_t blockSize = op.getBlockSize(); | ||
| auto sourceType = cast<VectorType>(op.getSource().getType()); | ||
| auto srcElemType = cast<FloatType>(sourceType.getElementType()); | ||
| unsigned bitWidth = srcElemType.getWidth(); | ||
| int32_t scaleSel = | ||
| getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte); | ||
|
|
||
| auto targetType = cast<VectorType>(op.getResult().getType()); | ||
| auto destElemType = cast<FloatType>(targetType.getElementType()); | ||
| IntegerType i32 = rewriter.getI32Type(); | ||
| Value castedScale = | ||
| LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale()); | ||
|
|
||
| Value source = adaptor.getSource(); | ||
| Type llvmResultType = typeConverter->convertType(op.getResult().getType()); | ||
| Type packedType = nullptr; | ||
| if (isa<fp4>(srcElemType)) { | ||
| packedType = i32; | ||
| packedType = getTypeConverter()->convertType(packedType); | ||
| } else if (isa<fp8, bf8>(srcElemType)) { | ||
| packedType = VectorType::get(2, i32); | ||
| packedType = getTypeConverter()->convertType(packedType); | ||
| } else if (isa<fp6, bf6>(srcElemType)) { | ||
| packedType = VectorType::get(3, i32); | ||
| packedType = getTypeConverter()->convertType(packedType); | ||
| } else { | ||
| llvm_unreachable("invalid element type for packed scaled ext"); | ||
kuhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| if (!packedType || !llvmResultType) { | ||
| return rewriter.notifyMatchFailure(op, "type conversion failed"); | ||
| } | ||
|
|
||
| Value castedSource = | ||
| LLVM::BitcastOp::create(rewriter, loc, packedType, source); | ||
|
|
||
| std::optional<StringRef> maybeIntrinsic = | ||
| scaledExtPacked816ToIntrinsic(srcElemType, destElemType); | ||
| if (!maybeIntrinsic.has_value()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Post-commit review: this failure happens after IR has been created, which is a "pleane don't do that" Can we swap this and the bitcaste above?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! #168542 |
||
| return op.emitOpError( | ||
| "no intrinsic matching packed scaled conversion on the given chipset"); | ||
|
|
||
| OperationState loweredOp(loc, *maybeIntrinsic); | ||
| loweredOp.addTypes({llvmResultType}); | ||
| loweredOp.addOperands({castedSource, castedScale}); | ||
|
|
||
| SmallVector<NamedAttribute, 1> attrs; | ||
| attrs.push_back( | ||
| NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel))); | ||
|
|
||
| loweredOp.addAttributes(attrs); | ||
| Operation *lowered = rewriter.create(loweredOp); | ||
| rewriter.replaceOp(op, lowered); | ||
|
|
||
| return success(); | ||
amd-eochoalo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( | ||
| ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const { | ||
|
|
@@ -2151,9 +2329,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, | |
| ROCDL::RawPtrBufferAtomicCmpSwap>, | ||
| AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, | ||
| SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, | ||
| WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, | ||
| PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, | ||
| PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, | ||
| TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); | ||
| WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering, | ||
| ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, | ||
| PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, | ||
| GatherToLDSOpLowering, TransposeLoadOpLowering, | ||
| AMDGPUPermlaneLowering>(converter, chipset); | ||
| patterns.add<AMDGPUSwizzleBitModeLowering>(converter); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -456,3 +456,4 @@ func.func @sched_barrier() { | |
| amdgpu.sched_barrier allow = <valu|all_vmem> | ||
| func.return | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.