diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index b6e101952676a..1ecf7d97bf51b 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -72,12 +72,34 @@ LogicalResult EmulateFloatPattern::matchAndRewrite( rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes, op->getAttrs(), op->getSuccessors(), /*regions=*/{}); SmallVector newResults(expandedOp->getResults()); - for (auto [res, oldType, newType] : llvm::zip_equal( - MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { + for (auto [res, origResult, oldType, newType] : + llvm::zip_equal(MutableArrayRef{newResults}, op->getResults(), + op->getResultTypes(), resultTypes)) { if (oldType != newType) { - auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res); - truncFOp.setFastmath(arith::FastMathFlags::contract); - res = truncFOp.getResult(); + // If all uses of the original result are arith.extf ops that extend from + // the unsupported type to the target (wider) type, we can skip the + // intermediate truncf round-trip and directly replace those extf ops with + // the wider emulated value. This avoids emitting arith.extf on the + // unsupported type in the output, which cannot be lowered to LLVM for + // types that lack native hardware support (e.g. fp8 variants). + Type targetType = newType; + bool allUsersAreExtFToTargetType = + !origResult.use_empty() && + llvm::all_of(origResult.getUsers(), [targetType](Operation *user) { + auto extFOp = dyn_cast(user); + return extFOp && extFOp.getType() == targetType; + }); + if (allUsersAreExtFToTargetType) { + // Replace all extf users directly with the wider emulated value. + for (Operation *user : + llvm::make_early_inc_range(origResult.getUsers())) + rewriter.replaceOp(user, res); + // No truncf needed; res already has the target (wider) type. + } else { + auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res); + truncFOp.setFastmath(arith::FastMathFlags::contract); + res = truncFOp.getResult(); + } } } rewriter.replaceOp(op, newResults); diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index fcd004ac554aa..00fb6d282f4a6 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=bf16,f8E4M3FNUZ target-type=f32" %s | FileCheck %s +// RUN: mlir-opt --split-input-file --arith-emulate-unsupported-floats="source-types=f8E4M3FNUZ target-type=f32" --convert-arith-to-llvm %s | FileCheck %s --check-prefix=LLVM func.func @basic_expansion(%x: bf16) -> bf16 { // CHECK-LABEL: @basic_expansion @@ -60,14 +61,20 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) { // ----- +// When the result of an emulated op is only used by an extf back to the +// target type, the pass skips the truncf/extf round-trip and uses the +// wider emulated value directly. This avoids emitting arith.extf on the +// unsupported type, which cannot be lowered to LLVM. func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { // CHECK-LABEL: @vectors // CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ> // CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath : vector<4xf8E4M3FNUZ> to vector<4xf32> -// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> -// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath : vector<4xf32> to vector<4xf8E4M3FNUZ> -// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32> +// CHECK: [[RET:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32> +// CHECK-NOT: arith.truncf +// CHECK-NOT: arith.extf {{%.+}} : vector<4xf8E4M3FNUZ> // CHECK: return [[RET]] +// LLVM-LABEL: @vectors +// LLVM-NOT: llvm.fpext {{.*}} : vector<4xi8> %b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ> %ret = arith.extf %b : vector<4xf8E4M3FNUZ> to vector<4xf32> func.return %ret : vector<4xf32> @@ -75,6 +82,23 @@ func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> { // ----- +// When an emulated op's result has mixed users (not all are arith.extf to the +// target type), the pass falls back to the truncf/extf round-trip. +func.func @mixed_users(%a: bf16) -> (f32, bf16) { +// CHECK-LABEL: @mixed_users +// CHECK-SAME: [[A:%.+]]: bf16 +// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath : bf16 to f32 +// CHECK: [[PROD:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : f32 +// CHECK: [[TRUNC:%.+]] = arith.truncf [[PROD]] fastmath : f32 to bf16 +// CHECK: [[EXT:%.+]] = arith.extf [[TRUNC]] : bf16 to f32 +// CHECK: return [[EXT]], [[TRUNC]] + %b = arith.mulf %a, %a : bf16 + %ext = arith.extf %b : bf16 to f32 + func.return %ext, %b : f32, bf16 +} + +// ----- + func.func @no_expansion(%x: f32) -> f32 { // CHECK-LABEL: @no_expansion // CHECK-SAME: [[X:%.+]]: f32