Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7b64631
Update documentation
amd-eochoalo Nov 10, 2025
08e96b1
Fix verifiers
amd-eochoalo Nov 10, 2025
d0932cc
[mlir][amdgpu] Convert scaled_ext_packed816 to rocdl
amd-eochoalo Nov 10, 2025
0d1d668
Create skeleton for pattern
amd-eochoalo Nov 10, 2025
163b15a
Initial conversion
amd-eochoalo Nov 10, 2025
6d7e2a6
Add first test
amd-eochoalo Nov 10, 2025
fc5d858
Adds two more cases
amd-eochoalo Nov 10, 2025
d9a254f
Add case for pk8.bf16.fp4
amd-eochoalo Nov 11, 2025
c5eb698
Add conversion for pk8.bf16.bf8
amd-eochoalo Nov 11, 2025
cec5f04
Fix and add new case
amd-eochoalo Nov 11, 2025
7dc3442
Add case for pk8.f32.fp4
amd-eochoalo Nov 11, 2025
0ba6b94
Add case for pk8.f32.fp8
amd-eochoalo Nov 11, 2025
551849e
Add case for pk8.f32.bf8
amd-eochoalo Nov 11, 2025
c1e10c8
Add case for pk16.f16.bf6
amd-eochoalo Nov 11, 2025
e958d56
Add case for pk16.bf16.fp6
amd-eochoalo Nov 11, 2025
1f79bdd
Add case for pk16.bf16.bf6
amd-eochoalo Nov 11, 2025
c5628e6
Add case for pk16.f32.fp6
amd-eochoalo Nov 11, 2025
db56c98
Add case for pk16.f32.bf6
amd-eochoalo Nov 11, 2025
73ed4b7
Refactor NFC
amd-eochoalo Nov 11, 2025
94cc740
Refactor NFC
amd-eochoalo Nov 11, 2025
4f27e04
Use method instead of isa
amd-eochoalo Nov 11, 2025
0f8f3c4
Hoist variable
amd-eochoalo Nov 11, 2025
a7a853e
Refactor NFC
amd-eochoalo Nov 11, 2025
9e4ab0e
Hoist variable. NFC
amd-eochoalo Nov 11, 2025
728686f
Comments. NFC
amd-eochoalo Nov 11, 2025
47dc32e
refactor. nfc
amd-eochoalo Nov 11, 2025
3cfea7e
Keep conventions
amd-eochoalo Nov 11, 2025
2b010cd
Less of exhaustive enumeration
amd-eochoalo Nov 17, 2025
6f07ef0
Correct types
amd-eochoalo Nov 17, 2025
6978793
Reflow comment
amd-eochoalo Nov 17, 2025
ed66571
superfluous empty line
amd-eochoalo Nov 17, 2025
33ef57e
Using
amd-eochoalo Nov 17, 2025
a83cec9
Add chipset check and moved tests
amd-eochoalo Nov 17, 2025
34ed3e9
Refactor NFC
amd-eochoalo Nov 17, 2025
b88f7f6
Use operation name
amd-eochoalo Nov 17, 2025
7a7ecaf
Convert result type
amd-eochoalo Nov 17, 2025
1025e2b
Check for type conversion failures
amd-eochoalo Nov 17, 2025
a3db728
Merge branch 'main' into eochoa/2025-11-10/lowerings
amd-eochoalo Nov 17, 2025
7c44f09
Add top-level if condition for each src type
amd-eochoalo Nov 17, 2025
f06a67e
Add chipset constant at beginning of file
amd-eochoalo Nov 17, 2025
1dbcb95
wip
amd-eochoalo Nov 17, 2025
eee0ce9
Add invalid srcElemType case
amd-eochoalo Nov 17, 2025
9860cdd
Update verifiers
amd-eochoalo Nov 17, 2025
0b8f561
Update verifier message
amd-eochoalo Nov 17, 2025
642414a
Fix assertion
amd-eochoalo Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 185 additions & 6 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
constexpr Chipset kGfx1250 = Chipset(12, 5, 0);

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -1149,7 +1150,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
k, isRDNA3);

// Handle gfx1250.
if (chipset == Chipset{12, 5, 0})
if (chipset == kGfx1250)
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);

Expand Down Expand Up @@ -1300,7 +1301,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");

bool isGFX1250 = chipset >= Chipset(12, 5, 0);
bool isGFX1250 = chipset >= kGfx1250;

// The WMMA operations represent vectors of bf16s as vectors of i16s
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
Expand Down Expand Up @@ -1505,6 +1506,19 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPacked816OpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
Expand Down Expand Up @@ -1613,6 +1627,170 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}

int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
int32_t firstScaleLane, int32_t firstScaleByte) {
// When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
// operations, the attributes blockSize, sourceType, firstScaleLane and
// firstScaleByte are merged into a single attribute scaleSel. This is how
// those values are merged together.
assert(llvm::is_contained({16, 32}, blockSize));
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));

const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;

if (!is_fp8) {
int bit_0 = is_block_16;
assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
int bit_1 = (firstScaleByte == 2) << 1;
assert(llvm::is_contained({0, 1}, firstScaleLane));
int bit_2 = firstScaleLane << 2;
return bit_2 | bit_1 | bit_0;
}

int bit_0 = is_block_16;
// firstScaleByte is guaranteed to be defined by two bits.
assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
int bit_2_and_1 = firstScaleByte << 1;
assert(llvm::is_contained({0, 1}, firstScaleLane));
int bit_3 = firstScaleLane << 3;
int bits = bit_3 | bit_2_and_1 | bit_0;
// These are invalid cases.
assert(!llvm::is_contained(
{0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
return bits;
}

static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
using fp4 = Float4E2M1FNType;
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
if (isa<fp4>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
return std::nullopt;
}
if (isa<fp8>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
return std::nullopt;
}
if (isa<bf8>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
return std::nullopt;
}
if (isa<fp6>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
return std::nullopt;
}
if (isa<bf6>(srcElemType)) {
if (destElemType.isF16())
return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
if (destElemType.isBF16())
return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
if (destElemType.isF32())
return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
return std::nullopt;
}
llvm_unreachable("invalid combination of element types for packed conversion "
"instructions");
}

LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
using fp4 = Float4E2M1FNType;
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;
using fp6 = Float6E2M3FNType;
using bf6 = Float6E3M2FNType;
Location loc = op.getLoc();
if (chipset != kGfx1250) {
return rewriter.notifyMatchFailure(
loc,
"Scaled fp packed conversion instructions are not available on target "
"architecture and their emulation is not implemented");
}
int32_t firstScaleLane = op.getFirstScaleLane();
int32_t firstScaleByte = op.getFirstScaleByte();
int32_t blockSize = op.getBlockSize();
auto sourceType = cast<VectorType>(op.getSource().getType());
auto srcElemType = cast<FloatType>(sourceType.getElementType());
unsigned bitWidth = srcElemType.getWidth();
int32_t scaleSel =
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);

auto targetType = cast<VectorType>(op.getResult().getType());
auto destElemType = cast<FloatType>(targetType.getElementType());
IntegerType i32 = rewriter.getI32Type();
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());

Value source = adaptor.getSource();
Type llvmResultType = typeConverter->convertType(op.getResult().getType());
Type packedType = nullptr;
if (isa<fp4>(srcElemType)) {
packedType = i32;
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<fp8, bf8>(srcElemType)) {
packedType = VectorType::get(2, i32);
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<fp6, bf6>(srcElemType)) {
packedType = VectorType::get(3, i32);
packedType = getTypeConverter()->convertType(packedType);
} else {
llvm_unreachable("invalid element type for packed scaled ext");
}

if (!packedType || !llvmResultType) {
return rewriter.notifyMatchFailure(op, "type conversion failed");
}

Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);

std::optional<StringRef> maybeIntrinsic =
scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
if (!maybeIntrinsic.has_value())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Post-commit review: this failure happens after IR has been created, which is a "pleane don't do that"

Can we swap this and the bitcaste above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! #168542

return op.emitOpError(
"no intrinsic matching packed scaled conversion on the given chipset");

OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes({llvmResultType});
loweredOp.addOperands({castedSource, castedScale});

SmallVector<NamedAttribute, 1> attrs;
attrs.push_back(
NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));

loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);
rewriter.replaceOp(op, lowered);

return success();
}

LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -2151,9 +2329,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering,
AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
45 changes: 29 additions & 16 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
LogicalResult ScaledExtPacked816Op::verify() {
int blockSize = getBlockSize();
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");

int firstScaleByte = getFirstScaleByte();
int firstScaleLane = getFirstScaleLane();
auto sourceType = cast<VectorType>(getSource().getType());
Type elementType = sourceType.getElementType();
auto floatType = cast<FloatType>(elementType);
int bitWidth = floatType.getWidth();
unsigned bitWidth = floatType.getWidth();

if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
!llvm::is_contained({0, 1}, firstScaleByte)) {
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
"for f4 and f6.");
}
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
!llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
"for f4 and f6.");
}
if (bitWidth == 8 && blockSize == 16 &&
!llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError(
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));

const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;

if (!is_fp8) {
if (is_block_16) {
if (!llvm::is_contained({0, 1}, firstScaleByte)) {
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
"or 1 for f4 and f6.");
}
} else {
if (!llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
"or 2 for f4 and f6.");
}
}
} else {
if (is_block_16) {
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
((firstScaleLane == 1) && (firstScaleByte == 2));
if (!is_valid) {
return emitOpError("blockSize of 16 can only have (firstScaleLane, "
"firstScaleByte) be (0, 0) or (1, 2) for f8.");
}
}
}

return success();
Expand Down
1 change: 1 addition & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,4 @@ func.func @sched_barrier() {
amdgpu.sched_barrier allow = <valu|all_vmem>
func.return
}

Loading
Loading