From bc38211280388cdc856c3e5cd8ae2219ea0f4786 Mon Sep 17 00:00:00 2001
From: Tim Gymnich
Date: Thu, 26 Jun 2025 09:54:55 +0000
Subject: [PATCH 1/2] [mlir][amdgpu] Add conversion for arith.scaling_extf to amdgpu

---
 .../ArithToAMDGPU/ArithToAMDGPU.cpp | 272 ++++++++++++++++++
 1 file changed, 272 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..cf9bb3a000050 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -14,7 +14,10 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
                                 PatternRewriter &rewriter) const override;
 };
 
+struct ScalingExtFRewritePattern final
+    : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingExtFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final
+    : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingTruncFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+/// Get the broadcasted / splatted value for a chain of ops.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
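// Editor's illustrative sketch (not part of the patch): on the IR from the
// @conversion_broadcast test below, getOriginalVectorValue runs on the scale
// operand of arith.scaling_extf, i.e. on %cast2 in:
//
//   %bc    = vector.broadcast %scale
//              : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
//   %cast2 = vector.shape_cast %bc
//              : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
//
// The walk looks through the shape_cast (its lambda returns true, so the loop
// continues) and stops at the broadcast (returns false) after stepping to its
// source, yielding %scale : vector<8x2xf8E8M0FNU>. That original shape is
// what the rewrites below compare against the input shape to derive the
// per-block scale layout.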
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
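// Editor's illustrative sketch (not part of the patch): for the splat case
// exercised by @conversion_splat in scaling-extf.mlir, the pattern above
// rewrites
//
//   %ext = arith.scaling_extf %in, %splat
//            : vector<4xf8E5M2>, vector<4xf8E8M0FNU> to vector<4xf32>
//
// into two packed ops, each covering one opWidth = 2 slice of the input and
// reusing the single scale extended to f32:
//
//   %c0 = vector.extract_strided_slice %in {offsets = [0], sizes = [2], strides = [1]}
//   %e0 = amdgpu.scaled_ext_packed %c0[0], %s : vector<2xf8E5M2> to vector<2xf32>
//   %c1 = vector.extract_strided_slice %in {offsets = [2], sizes = [2], strides = [1]}
//   %e1 = amdgpu.scaled_ext_packed %c1[0], %s : vector<2xf8E5M2> to vector<2xf32>
//
// and stitches %e0 / %e1 back together with vector.insert_strided_slice.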
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+
+  if (!outVecType) {
+    Type inVecType = VectorType::get(1, inType);
+    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    // TODO: replace this with non-packed ScaledTruncOp
+    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    scaleTrunc =
+        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+          loc, truncScaleResultType, slice, uniformScale, 0,
+          /*existing=*/nullptr);
+      int64_t packedWidth =
+          cast<VectorType>(scaleTrunc.getType()).getNumElements();
+      if (packedWidth != opWidth)
+        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleTrunc, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleTrunc, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+  }
 }
 
 void 
ArithToAMDGPUConversionPass::runOnOperation() { From efc6194b7b664ed2cb9aa175ae92692b87c2fdfe Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 30 Jun 2025 12:18:36 +0000 Subject: [PATCH 2/2] add tests --- .../ArithToAMDGPU/scaling-extf.mlir | 262 ++++++++++++++++++ .../ArithToAMDGPU/scaling-truncf.mlir | 193 +++++++++++++ 2 files changed, 455 insertions(+) create mode 100644 mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir create mode 100644 mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir new file mode 100644 index 0000000000000..095f3e575eca8 --- /dev/null +++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir @@ -0,0 +1,262 @@ +// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s + +// CHECK-LABEL: @conversion_f8_f32_fallback +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: 
%[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf32> +func.func @conversion_f8_f32_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> { + %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32> + return %ext : vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @conversion_f4_f32_fallback +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf4E2M1FN> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf4E2M1FN> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to 
vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf4E2M1FN> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf4E2M1FN> to vector<2xf32> +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf32> to vector<1x1xf32> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf32> into vector<2x2xf32> +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf32> +func.func @conversion_f4_f32_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> { + %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32> + return %ext : vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @conversion_f8_f16_fallback +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> +// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf8E5M2> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf16> to vector<1x1xf16> +// 
CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf8E5M2> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf8E5M2> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf8E5M2> to vector<1x1xf8E5M2> +// CHECK-NEXT: %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf8E5M2> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf16> +func.func @conversion_f8_f16_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf16> { + %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf16> + return %ext : vector<2x2xf16> +} + +// ----- + +// CHECK-LABEL: @conversion_f4_f16_fallback +// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf16> +// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to 
vector<2x2xf32> +// CHECK-NEXT: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_00]][0], %[[SCALE_SCALAR_00]] : vector<1xf4E2M1FN> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_01]][0], %[[SCALE_SCALAR_01]] : vector<1xf4E2M1FN> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_01]], %[[ACC_A]] {offsets = [0, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_10]][0], %[[SCALE_SCALAR_10]] : vector<1xf4E2M1FN> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_VEC_10]], %[[ACC_B]] {offsets = [1, 0], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} : vector<2x2xf4E2M1FN> to vector<1x1xf4E2M1FN> +// CHECK-NEXT: %[[IN_VEC_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] : vector<1x1xf4E2M1FN> to vector<1xf4E2M1FN> +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] : f32 from vector<2x2xf32> +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.scaled_ext_packed %[[IN_VEC_11]][0], %[[SCALE_SCALAR_11]] : vector<1xf4E2M1FN> to vector<2xf16> +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] {offsets = [0], sizes = [1], 
strides = [1]} : vector<2xf16> to vector<1xf16> +// CHECK-NEXT: %[[OUT_VEC_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] : vector<1xf16> to vector<1x1xf16> +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_VEC_11]], %[[ACC_A]] {offsets = [1, 1], strides = [1, 1]} : vector<1x1xf16> into vector<2x2xf16> +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf16> +func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf16> { + %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf16> + return %ext : vector<2x2xf16> +} + +// ----- + +// CHECK-LABEL: @conversion_broadcast +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32> +// CHECK-DAG: %[[BCAST:.+]] = vector.broadcast %arg1 +// CHECK-DAG: %[[IN_CAST:.+]] = vector.shape_cast %arg0 +// CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]] +// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]] +// CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0] +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.scaled_ext_packed +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.scaled_ext_packed +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]} +// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0] +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.scaled_ext_packed +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.scaled_ext_packed +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} +func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> { + %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU> + %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2> + %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU> + %ext = arith.scaling_extf %cast1, %cast2 : vector<8x2x4xf8E5M2>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32> + %cast3 = vector.shape_cast %ext : vector<8x2x4xf32> to vector<8x8xf32> + return %cast3 : vector<8x8xf32> +} + +// ----- + +// CHECK-LABEL: @conversion_broadcast_odd +// CHECK-NEXT: %[[CST_PARTIAL:.+]] = arith.constant dense<0.000000e+00> : vector<3xf32> +// CHECK-NEXT: %[[CST_FINAL:.+]] = arith.constant dense<0.000000e+00> : vector<6xf32> +// CHECK-NEXT: %[[SCALE_BC:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU> +// CHECK-NEXT: 
%[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BC]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU> +// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32> +// CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32> +// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32> +// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> +// CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32> +// CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2> +// CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32> +// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32> +// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32> +// CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32> +// CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32> +// CHECK-NEXT: return %[[RESULT]] : vector<6xf32> +func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf32> { + %bc = vector.broadcast %scale : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU> + %cast = vector.shape_cast %bc : vector<3x2xf8E8M0FNU> to 
vector<6xf8E8M0FNU> + %ext = arith.scaling_extf %in, %cast : vector<6xf8E5M2>, vector<6xf8E8M0FNU> to vector<6xf32> + return %ext : vector<6xf32> +} + +// ----- +// CHECK-LABEL: @conversion_splat +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU> +// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32> +// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32> +// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32> +func.func @conversion_splat(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> { + %splat = vector.splat %scale : vector<4xf8E8M0FNU> + %ext = arith.scaling_extf %in, %splat : vector<4xf8E5M2>, vector<4xf8E8M0FNU> to vector<4xf32> + return %ext : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @conversion_scalar +// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf8E5M2> +// CHECK-NEXT: %[[PACKED_EXT:.+]] = amdgpu.scaled_ext_packed %[[SPLAT_IN]][0], %[[SCALE_F32]] : vector<1xf8E5M2> to vector<2xf32> +// CHECK-NEXT: %[[RESULT:.+]] = vector.extract %[[PACKED_EXT]][0] : f32 from vector<2xf32> +// CHECK-NEXT: return %[[RESULT]] : f32 +func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 { + %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32 + return %ext : f32 +} diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir new file mode 100644 index 0000000000000..0519050c5ecc4 --- /dev/null +++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir @@ -0,0 +1,193 @@ +// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s + +// CHECK-LABEL: @conversion_f8_fallback +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2> +// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32> +// CHECK: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_00]] into undef[0], %[[SCALE_SCALAR_00]] +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] +// CHECK-NEXT: %[[OUT_SCALAR_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] +// CHECK-NEXT: 
%[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_00]], %[[CST]] +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_01]] into undef[0], %[[SCALE_SCALAR_01]] +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] +// CHECK-NEXT: %[[OUT_SCALAR_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_01]], %[[ACC_A]] +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_10]] into undef[0], %[[SCALE_SCALAR_10]] +// CHECK-NEXT: %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] +// CHECK-NEXT: %[[OUT_SCALAR_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_10]], %[[ACC_B]] +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_11]] into undef[0], %[[SCALE_SCALAR_11]] +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] +// CHECK-NEXT: %[[OUT_SCALAR_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_11]], %[[ACC_A]] +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf8E5M2> +func.func @conversion_f8_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf8E5M2> { + %ext = arith.scaling_truncf %in, %scale : vector<2x2xf32>, vector<2x2xf8E8M0FNU> to vector<2x2xf8E5M2> + return %ext : vector<2x2xf8E5M2> +} + +// ----- + +// CHECK-LABEL: @conversion_f4_fallback +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf4E2M1FN> +// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32> +// CHECK: %[[IN_SLICE_00:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_00:.+]] = vector.shape_cast %[[IN_SLICE_00]] +// CHECK-NEXT: %[[SCALE_SCALAR_00:.+]] = vector.extract %[[SCALE_EXT]][0, 0] +// CHECK-NEXT: %[[PACKED_00:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_00]] into undef[0], %[[SCALE_SCALAR_00]] +// CHECK-NEXT: %[[OUT_SLICE_00:.+]] = vector.extract_strided_slice %[[PACKED_00]] +// CHECK-NEXT: %[[OUT_SCALAR_00:.+]] = vector.shape_cast %[[OUT_SLICE_00]] +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_00]], %[[CST]] +// CHECK-NEXT: %[[IN_SLICE_01:.+]] = vector.extract_strided_slice %arg0 {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_01:.+]] = vector.shape_cast %[[IN_SLICE_01]] +// CHECK-NEXT: %[[SCALE_SCALAR_01:.+]] = vector.extract %[[SCALE_EXT]][0, 1] +// CHECK-NEXT: %[[PACKED_01:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_01]] 
into undef[0], %[[SCALE_SCALAR_01]] +// CHECK-NEXT: %[[OUT_SLICE_01:.+]] = vector.extract_strided_slice %[[PACKED_01]] +// CHECK-NEXT: %[[OUT_SCALAR_01:.+]] = vector.shape_cast %[[OUT_SLICE_01]] +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_01]], %[[ACC_A]] +// CHECK-NEXT: %[[IN_SLICE_10:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 0], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_10:.+]] = vector.shape_cast %[[IN_SLICE_10]] +// CHECK-NEXT: %[[SCALE_SCALAR_10:.+]] = vector.extract %[[SCALE_EXT]][1, 0] +// CHECK-NEXT: %[[PACKED_10:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_10]] into undef[0], %[[SCALE_SCALAR_10]] +// CHECK-NEXT: %[[OUT_SLICE_10:.+]] = vector.extract_strided_slice %[[PACKED_10]] +// CHECK-NEXT: %[[OUT_SCALAR_10:.+]] = vector.shape_cast %[[OUT_SLICE_10]] +// CHECK-NEXT: %[[ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_10]], %[[ACC_B]] +// CHECK-NEXT: %[[IN_SLICE_11:.+]] = vector.extract_strided_slice %arg0 {offsets = [1, 1], sizes = [1, 1], strides = [1, 1]} +// CHECK-NEXT: %[[IN_SCALAR_11:.+]] = vector.shape_cast %[[IN_SLICE_11]] +// CHECK-NEXT: %[[SCALE_SCALAR_11:.+]] = vector.extract %[[SCALE_EXT]][1, 1] +// CHECK-NEXT: %[[PACKED_11:.+]] = amdgpu.packed_scaled_trunc %[[IN_SCALAR_11]] into undef[0], %[[SCALE_SCALAR_11]] +// CHECK-NEXT: %[[OUT_SLICE_11:.+]] = vector.extract_strided_slice %[[PACKED_11]] +// CHECK-NEXT: %[[OUT_SCALAR_11:.+]] = vector.shape_cast %[[OUT_SLICE_11]] +// CHECK-NEXT: %[[ACC_B:.+]] = vector.insert_strided_slice %[[OUT_SCALAR_11]], %[[ACC_A]] +// CHECK-NEXT: return %[[ACC_B]] : vector<2x2xf4E2M1FN> +func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf4E2M1FN> { + %ext = arith.scaling_truncf %in, %scale : vector<2x2xf32>, vector<2x2xf8E8M0FNU> to vector<2x2xf4E2M1FN> + return %ext : vector<2x2xf4E2M1FN> +} + +// ----- + +// CHECK-LABEL: @conversion_broadcast +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf8E5M2> +// CHECK-DAG: %[[BCAST:.+]] = vector.broadcast %arg1 +// CHECK-DAG: %[[IN_CAST:.+]] = vector.shape_cast %arg0 +// CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]] +// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]] +// CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0] +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.packed_scaled_trunc +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.packed_scaled_trunc +// CHECK-NEXT: vector.extract_strided_slice +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]} +// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0] +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: 
amdgpu.packed_scaled_trunc +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]} +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]} +// CHECK-NEXT: amdgpu.packed_scaled_trunc +// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]} +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]} +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} +func.func @conversion_broadcast(%in: vector<8x8xf32>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf8E5M2> { + %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU> + %cast1 = vector.shape_cast %in : vector<8x8xf32> to vector<8x2x4xf32> + %cast2 = vector.shape_cast %bc : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU> + %ext = arith.scaling_truncf %cast1, %cast2 : vector<8x2x4xf32>, vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf8E5M2> + %cast3 = vector.shape_cast %ext : vector<8x2x4xf8E5M2> to vector<8x8xf8E5M2> + return %cast3 : vector<8x8xf8E5M2> +} + +// ----- + +// CHECK-LABEL: @conversion_broadcast_odd +// CHECK-NEXT: %[[CST3:.+]] = arith.constant dense<0.000000e+00> : vector<3xf8E5M2> +// CHECK-NEXT: %[[CST6:.+]] = arith.constant dense<0.000000e+00> : vector<6xf8E5M2> +// CHECK-NEXT: %[[SCALE_BCAST:.+]] = vector.broadcast %arg1 : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU> +// CHECK-NEXT: %[[SCALE_FLAT:.+]] = vector.shape_cast %[[SCALE_BCAST]] : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU> +// CHECK-NEXT: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32> +// CHECK-NEXT: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK-NEXT: %[[SCALE0:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<6xf32> +// CHECK-NEXT: %[[IN_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32> +// CHECK-NEXT: %[[PACKED0_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART0]] into undef[0], %[[SCALE0]] : vector<2xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK0_PART0:.+]] = vector.extract_strided_slice %[[PACKED0_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[ACCUM0_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[IN_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32> +// CHECK-NEXT: %[[PACKED0_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0_PART1]] into undef[0], %[[SCALE0]] : vector<1xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK0_PART1:.+]] = vector.extract_strided_slice %[[PACKED0_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[CHUNK0_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0_PART1]], %[[ACCUM0_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[FINAL_ACCUM_A:.+]] = vector.insert_strided_slice %[[CHUNK0_RES]], %[[CST6]] {offsets = [0], strides = [1]} : vector<3xf8E5M2> into 
vector<6xf8E5M2> +// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK-NEXT: %[[SCALE1:.+]] = vector.extract %[[SCALE_EXTF]][3] : f32 from vector<6xf32> +// CHECK-NEXT: %[[IN_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf32> to vector<2xf32> +// CHECK-NEXT: %[[PACKED1_PART0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART0]] into undef[0], %[[SCALE1]] : vector<2xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK1_PART0:.+]] = vector.extract_strided_slice %[[PACKED1_PART0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[ACCUM1_PART0:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART0]], %[[CST3]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[IN_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[IN_CHUNK1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf32> to vector<1xf32> +// CHECK-NEXT: %[[PACKED1_PART1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1_PART1]] into undef[0], %[[SCALE1]] : vector<1xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK1_PART1:.+]] = vector.extract_strided_slice %[[PACKED1_PART1]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf8E5M2> to vector<1xf8E5M2> +// CHECK-NEXT: %[[CHUNK1_RES:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1_PART1]], %[[ACCUM1_PART0]] {offsets = [2], strides = [1]} : vector<1xf8E5M2> into vector<3xf8E5M2> +// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[CHUNK1_RES]], %[[FINAL_ACCUM_A]] {offsets = [3], strides = [1]} : vector<3xf8E5M2> into vector<6xf8E5M2> +// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<6xf8E5M2> +func.func @conversion_broadcast_odd(%in: vector<6xf32>, %scale: vector<2xf8E8M0FNU>) -> vector<6xf8E5M2> { + %bc = vector.broadcast %scale : vector<2xf8E8M0FNU> to vector<3x2xf8E8M0FNU> + %cast = vector.shape_cast %bc : vector<3x2xf8E8M0FNU> to vector<6xf8E8M0FNU> + %ext = arith.scaling_truncf %in, %cast : vector<6xf32>, vector<6xf8E8M0FNU> to vector<6xf8E5M2> + return %ext : vector<6xf8E5M2> +} + +// ----- + +// CHECK-LABEL: @conversion_splat +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf8E5M2> +// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.splat %arg1 : vector<4xf8E8M0FNU> +// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32> +// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32> +// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK-NEXT: %[[PACKED0:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK0]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> +// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = vector.extract_strided_slice %[[PACKED0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2> +// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32> +// CHECK-NEXT: %[[PACKED1:.+]] = amdgpu.packed_scaled_trunc %[[IN_CHUNK1]] into undef[0], %[[SCALE_SCALAR]] : vector<2xf32> to vector<4xf8E5M2> +// CHECK-NEXT: 
%[[OUT_CHUNK1:.+]] = vector.extract_strided_slice %[[PACKED1]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2> +// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf8E5M2> into vector<4xf8E5M2> +// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf8E5M2> +func.func @conversion_splat(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector<4xf8E5M2> { + %splat = vector.splat %scale : vector<4xf8E8M0FNU> + %ext = arith.scaling_truncf %in, %splat : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E5M2> + return %ext : vector<4xf8E5M2> +} + +// ----- + +// CHECK-LABEL: @conversion_scalar +// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.splat %arg0 : vector<1xf32> +// CHECK-NEXT: %[[PACKED_TRUNC:.+]] = amdgpu.packed_scaled_trunc %[[SPLAT_IN]] into undef[0], %[[SCALE_F32]] +// CHECK-NEXT: %[[RESULT:.+]] = vector.extract %[[PACKED_TRUNC]][0] +// CHECK-NEXT: return %[[RESULT]] : f8E5M2 +func.func @conversion_scalar(%in: f32, %scale: f8E8M0FNU) -> f8E5M2 { + %ext = arith.scaling_truncf %in, %scale : f32, f8E8M0FNU to f8E5M2 + return %ext : f8E5M2 +}
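
// Editor's note (illustrative, not part of the tests): the block size the
// patterns derive comes from computeShapeRatio(inShape, scaleShape). For
// @conversion_broadcast the input is vector<8x2x4xf32> and the original
// scale (before broadcast/shape_cast) is vector<8x2xf8E8M0FNU>, padded to
// 8x2x1; the ratio is therefore 1x1x4 and the block size 4, so each tile is
// truncated in two opWidth = 2 chunks, which is why every tile produces two
// amdgpu.packed_scaled_trunc ops in the CHECK lines above. For
// @conversion_broadcast_odd the ratio is 3, so each block is processed as
// one 2-element chunk plus a 1-element tail, matching the extra
// extract_strided_slice on the packed result.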