-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 #74153
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN). Full diff: https://github.com/llvm/llvm-project/pull/74153.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 7f445fee5ba6b82..a1c059800752aca 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,8 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+ bool saturateFP8Truncf);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3b..2aa2ad634aeb722 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
}];
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+ let options = [
+ Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+ /*default=*/"false",
+ "Whether truncation to 8-bit float types should be saturating">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7785405eae67be3..d6b916e6e55423d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final
struct TruncfToFloat8RewritePattern final
: public OpRewritePattern<arith::TruncFOp> {
- using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+ bool saturateFP8 = false;
+ TruncfToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+ : OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx),
+ saturateFP8(saturateFP8) {}
LogicalResult match(arith::TruncFOp op) const override;
void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -127,6 +130,60 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
llvm_unreachable("The only 32-bit float type is f32");
}
+static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
+ const APFloat &value, Type type) {
+ if (isa<FloatType>(type))
+ return rewriter.createOrFold<arith::ConstantOp>(
+ loc, type, rewriter.getFloatAttr(type, value));
+ TypedAttr splat = DenseElementsAttr::get(cast<ShapedType>(type), value);
+ return rewriter.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+ Type outElemType, Value source) {
+ Type sourceType = source.getType();
+ const llvm::fltSemantics &sourceSem =
+ cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+ const llvm::fltSemantics &targetSem =
+ cast<FloatType>(outElemType).getFloatSemantics();
+
+ APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+ APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+ bool ignoredLosesInfo = false;
+ // We can ignore conversion failures here because this conversion promotes
+ // from a smaller type to a larger one.
+ (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+ (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+ Value minCst = getMaybeVectorConstant(rewriter, loc, min, sourceType);
+ Value maxCst = getMaybeVectorConstant(rewriter, loc, max, sourceType);
+
+ Value inf = getMaybeVectorConstant(
+ rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/false),
+ sourceType);
+ Value negInf = getMaybeVectorConstant(
+ rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/true), sourceType);
+ Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OEQ, source, inf);
+ Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::OEQ, source, negInf);
+ Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+ loc, arith::CmpFPredicate::UNO, source, source);
+ Value isNonFinite = rewriter.create<arith::OrIOp>(
+ loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+ Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+ Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+ Value res =
+ rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+ return res;
+}
+
LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
Type outType = op.getOut().getType();
if (auto outVecType = outType.dyn_cast<VectorType>()) {
@@ -145,6 +202,8 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+ if (saturateFP8)
+ in = clampInput(rewriter, loc, outElemType, in);
VectorType truncResType = VectorType::get(4, outElemType);
if (!in.getType().isa<VectorType>()) {
Value asFloat = castToF32(in, loc, rewriter);
@@ -196,15 +255,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
}
void mlir::arith::populateArithToAMDGPUConversionPatterns(
- RewritePatternSet &patterns) {
- patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
- patterns.getContext());
+ RewritePatternSet &patterns, bool saturateFP8Truncf) {
+ patterns.add<ExtfOnFloat8RewritePattern>(patterns.getContext());
+ patterns.add<TruncfToFloat8RewritePattern>(patterns.getContext(),
+ saturateFP8Truncf);
}
void ArithToAMDGPUConversionPass::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
- arith::populateArithToAMDGPUConversionPatterns(patterns);
+ arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
new file mode 100644
index 000000000000000..d0c2cd4090117ff
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+ %w = arith.truncf %v : f16 to f8E5M2FNUZ
+ return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C0]] : index]
+// CHECK: [[F1:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C1]] : index]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ>
+// CHECK: return [[W]] : vector<2xf8E4M3FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+ %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ>
+ return %w : vector<2xf8E4M3FNUZ>
+}
|
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some suggestions, but overall the clamping logic looks good to me. Some of the nits for naming come from me being used to different conventions, so feel free to disregard in this PR -- just wanted to surface things that surprised me.
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, | ||
bool saturateFP8Truncf); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a documentation comment?
also nit:
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, | |
bool saturateFP8Truncf); | |
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, | |
bool saturateFP8TruncF); |
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final | |||
|
|||
struct TruncfToFloat8RewritePattern final | |||
: public OpRewritePattern<arith::TruncFOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: public
is redundant here BTW
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern; | ||
bool saturateFP8 = false; | ||
TruncfToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8) | ||
: OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
: OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx), | |
: OpRewritePattern::OpRewritePattern(ctx), |
static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc, | ||
const APFloat &value, Type type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Maybe
reads to me like this could fail, but this is not the case. We have a similar helper in the other arith code where it's called createScalarOrSplatConstant
:
llvm-project/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
Lines 62 to 74 in a9f39ff
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, | |
Location loc, Type type, | |
const APInt &value) { | |
TypedAttr attr; | |
if (dyn_cast<IntegerType>(type)) { | |
attr = rewriter.getIntegerAttr(type, value); | |
} else { | |
auto vecTy = cast<VectorType>(type); | |
attr = SplatElementsAttr::get(vecTy, value); | |
} | |
return rewriter.create<arith::ConstantOp>(loc, attr); | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Utility factored out
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final | |||
|
|||
struct TruncfToFloat8RewritePattern final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This is existing code, but shouldn't this be called TruncF*
? Also everywhere else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ` to yield `240.0`, not `NaN`, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics. To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN). Per review feedback, this commit refactors createScalarOrSplatConstant() to the Arith dialect utilities and uses it in this code. It also fixes naming of existing patterns and switches from vector.extractelement/insertelement to vector.extract/insert.
048e75d
to
f113249
Compare
Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect
truncf 256.0 : f32 to f8E4M3FNUZ
to yield240.0
, notNaN
, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics.To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN).
Per review feedback, this commit refactors
createScalarOrSplatConstant() to the Arith dialect utilities and uses
it in this code. It also fixes naming of existing patterns and
switches from vector.extractelement/insertelement to
vector.extract/insert.