
[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 #74153

Merged: 1 commit, Jan 23, 2024

Conversation

krzysz00
Contributor

@krzysz00 krzysz00 commented Dec 1, 2023

Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect truncf 256.0 : f32 to f8E4M3FNUZ to yield 240.0, not NaN, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics.

To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN).
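The clamp semantics can be sketched in a few lines of Python (a model of the intended behavior only, not the MLIR lowering; the ±240.0 bounds used below are the largest finite values of f8E4M3FNUZ):

```python
import math

def saturating_clamp(x: float, lo: float, hi: float) -> float:
    # Model of the clamp emitted under saturate-fp8-truncf: finite inputs
    # are clamped into [lo, hi]; Inf and NaN pass through unchanged, so the
    # subsequent hardware truncation turns them into NaN (FNUZ formats have
    # no Inf representation).
    if math.isnan(x) or math.isinf(x):
        return x
    return max(lo, min(hi, x))

# With the f8E4M3FNUZ bounds:
print(saturating_clamp(256.0, -240.0, 240.0))    # 240.0
print(saturating_clamp(-1.0e6, -240.0, 240.0))   # -240.0
print(saturating_clamp(100.0, -240.0, 240.0))    # 100.0
```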

Per review feedback, this commit refactors
createScalarOrSplatConstant() into the Arith dialect utilities and uses
it in this code. It also fixes the naming of existing patterns and
switches from vector.extractelement/insertelement to
vector.extract/insert.

@llvmbot
Collaborator

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

Many machine-learning applications (and most software written at AMD) expect the operation that truncates floats to 8-bit floats to be saturating. That is, they expect truncf 256.0 : f32 to f8E4M3FNUZ to yield 240.0, not NaN, and similarly for negative numbers. However, the underlying hardware instruction that can be used for this truncation implements overflow-to-NaN semantics.

To enable handling this use case, we add the saturate-fp8-truncf option to ArithToAMDGPU (off by default), which causes the requisite clamping code to be emitted. Said clamping code ensures that Inf and NaN are passed through exactly (and thus truncate to NaN).


Full diff: https://github.com/llvm/llvm-project/pull/74153.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h (+2-1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+6)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+65-5)
  • (added) mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir (+57)
diff --git a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
index 7f445fee5ba6b82..a1c059800752aca 100644
--- a/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
+++ b/mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -20,7 +20,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 namespace arith {
-void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns);
+void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
+                                             bool saturateFP8Truncf);
 } // namespace arith
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3b..2aa2ad634aeb722 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -125,6 +125,12 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   }];
 
   let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
+
+  let options = [
+    Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
+           /*default=*/"false",
+           "Whether truncation to 8-bit float types should be saturating">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 7785405eae67be3..d6b916e6e55423d 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final
 
 struct TruncfToFloat8RewritePattern final
     : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
+  bool saturateFP8 = false;
+  TruncfToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
+      : OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx),
+        saturateFP8(saturateFP8) {}
 
   LogicalResult match(arith::TruncFOp op) const override;
   void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
@@ -127,6 +130,60 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
+static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
+                                    const APFloat &value, Type type) {
+  if (isa<FloatType>(type))
+    return rewriter.createOrFold<arith::ConstantOp>(
+        loc, type, rewriter.getFloatAttr(type, value));
+  TypedAttr splat = DenseElementsAttr::get(cast<ShapedType>(type), value);
+  return rewriter.createOrFold<arith::ConstantOp>(loc, type, splat);
+}
+
+// If `in` is a finite value, clamp it between the maximum and minimum values
+// of `outElemType` so that subsequent conversion instructions don't
+// overflow those out-of-range values to NaN. These semantics are commonly
+// used in machine-learning contexts where failure to clamp would lead to
+// excessive NaN production.
+static Value clampInput(PatternRewriter &rewriter, Location loc,
+                        Type outElemType, Value source) {
+  Type sourceType = source.getType();
+  const llvm::fltSemantics &sourceSem =
+      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
+  const llvm::fltSemantics &targetSem =
+      cast<FloatType>(outElemType).getFloatSemantics();
+
+  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
+  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
+  bool ignoredLosesInfo = false;
+  // We can ignore conversion failures here because this conversion promotes
+  // from a smaller type to a larger one.
+  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
+
+  Value minCst = getMaybeVectorConstant(rewriter, loc, min, sourceType);
+  Value maxCst = getMaybeVectorConstant(rewriter, loc, max, sourceType);
+
+  Value inf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/false),
+      sourceType);
+  Value negInf = getMaybeVectorConstant(
+      rewriter, loc, APFloat::getInf(sourceSem, /*Negative=*/true), sourceType);
+  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, inf);
+  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::OEQ, source, negInf);
+  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
+      loc, arith::CmpFPredicate::UNO, source, source);
+  Value isNonFinite = rewriter.create<arith::OrIOp>(
+      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
+
+  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
+  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
+  Value res =
+      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
+  return res;
+}
+
 LogicalResult TruncfToFloat8RewritePattern::match(arith::TruncFOp op) const {
   Type outType = op.getOut().getType();
   if (auto outVecType = outType.dyn_cast<VectorType>()) {
@@ -145,6 +202,8 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
+  if (saturateFP8)
+    in = clampInput(rewriter, loc, outElemType, in);
   VectorType truncResType = VectorType::get(4, outElemType);
   if (!in.getType().isa<VectorType>()) {
     Value asFloat = castToF32(in, loc, rewriter);
@@ -196,15 +255,16 @@ void TruncfToFloat8RewritePattern::rewrite(arith::TruncFOp op,
 }
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<ExtfOnFloat8RewritePattern, TruncfToFloat8RewritePattern>(
-      patterns.getContext());
+    RewritePatternSet &patterns, bool saturateFP8Truncf) {
+  patterns.add<ExtfOnFloat8RewritePattern>(patterns.getContext());
+  patterns.add<TruncfToFloat8RewritePattern>(patterns.getContext(),
+                                             saturateFP8Truncf);
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
   Operation *op = getOperation();
   RewritePatternSet patterns(op->getContext());
-  arith::populateArithToAMDGPUConversionPatterns(patterns);
+  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
new file mode 100644
index 000000000000000..d0c2cd4090117ff
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt --split-input-file %s \
+// RUN: --pass-pipeline='builtin.module(func.func(convert-arith-to-amdgpu{saturate-fp8-truncf=true}))' \
+// RUN: | FileCheck %s
+
+// CHECK-LABEL: func.func @scalar_trunc
+// CHECK-SAME: ([[V:%.+]]: f16)
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant -5.734400e+04 : f16
+// CHECK-DAG: [[CMax:%.+]] = arith.constant 5.734400e+04 : f16
+// CHECK-DAG: [[CInf:%.+]] = arith.constant 0x7C00 : f16
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant 0xFC00 : f16
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[FLOAT:%.+]] = arith.extf [[SATURATED]] : f16 to f32
+// CHECK: [[TRUNCV:%.+]] = amdgpu.packed_trunc_2xfp8 [[FLOAT]], undef into undef[word 0] : f32 to vector<4xf8E5M2FNUZ>
+// CHECK: [[W:%.+]] = vector.extractelement [[TRUNCV]]{{\[}}[[C0]] : index] : vector<4xf8E5M2FNUZ>
+// CHECK: return [[W]] : f8E5M2FNUZ
+func.func @scalar_trunc(%v: f16) -> f8E5M2FNUZ {
+  %w = arith.truncf %v : f16 to f8E5M2FNUZ
+  return %w : f8E5M2FNUZ
+}
+
+// No 0-D test because arith.truncf hasn't been extended to support it.
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc
+// CHECK-SAME: ([[V:%.+]]: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+// CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.+]] = arith.constant 1 : index
+// CHECK-DAG: [[CMin:%.+]] = arith.constant dense<-2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CMax:%.+]] = arith.constant dense<2.400000e+02> : vector<2xf32>
+// CHECK-DAG: [[CInf:%.+]] = arith.constant dense<0x7F800000> : vector<2xf32>
+// CHECK-DAG: [[CNegInf:%.+]] = arith.constant dense<0xFF800000> : vector<2xf32>
+// CHECK: [[ISINF:%.+]] = arith.cmpf oeq, [[V]], [[CInf]]
+// CHECK: [[ISNEGINF:%.+]] = arith.cmpf oeq, [[V]], [[CNegInf]]
+// CHECK: [[ISNAN:%.+]] = arith.cmpf uno, [[V]], [[V]]
+// CHECK: [[ISNONFINITE_1:%.+]] = arith.ori [[ISINF]], [[ISNEGINF]]
+// CHECK: [[ISNONFINITE:%.+]] = arith.ori [[ISNONFINITE_1]], [[ISNAN]]
+// CHECK: [[CLAMPEDBELOW:%.+]] = arith.maximumf [[V]], [[CMin]]
+// CHECK: [[CLAMPED:%.+]] = arith.minimumf [[CLAMPEDBELOW]], [[CMax]]
+// CHECK: [[SATURATED:%.+]] = arith.select [[ISNONFINITE]], [[V]], [[CLAMPED]]
+// CHECK: [[F0:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C0]] : index]
+// CHECK: [[F1:%.+]] = vector.extractelement [[SATURATED]]{{\[}}[[C1]] : index]
+// CHECK: [[W0:%.+]] = amdgpu.packed_trunc_2xfp8 [[F0]], [[F1]] into undef[word 0] : f32 to vector<4xf8E4M3FNUZ>
+// CHECK: [[W:%.+]] = vector.extract_strided_slice [[W0]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E4M3FNUZ> to vector<2xf8E4M3FNUZ>
+// CHECK: return [[W]] : vector<2xf8E4M3FNUZ>
+func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FNUZ>
+  return %w : vector<2xf8E4M3FNUZ>
+}
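
The constants checked in the tests above (240.0 for f8E4M3FNUZ, 57344.0 for f8E5M2FNUZ) are the largest finite values of those FNUZ formats. As a sanity check, they follow directly from the bit layout (this helper assumes the standard FNUZ bias of 2^(exp_bits - 1)):

```python
def fnuz_largest(exp_bits: int, mantissa_bits: int) -> float:
    # FNUZ formats reserve only the negative-zero bit pattern for NaN, so
    # even the all-ones exponent encodes finite values. The bias is
    # 2**(exp_bits - 1), one more than in the corresponding IEEE layout.
    bias = 1 << (exp_bits - 1)
    max_unbiased_exp = ((1 << exp_bits) - 1) - bias
    max_significand = 2.0 - 2.0 ** (-mantissa_bits)  # 1.11...1 in binary
    return (2.0 ** max_unbiased_exp) * max_significand

print(fnuz_largest(4, 3))  # 240.0   (f8E4M3FNUZ)
print(fnuz_largest(5, 2))  # 57344.0 (f8E5M2FNUZ)
```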

@krzysz00
Contributor Author

krzysz00 commented Jan 2, 2024

Ping

@kuhar kuhar self-requested a review January 4, 2024 22:29
Member

@kuhar kuhar left a comment

Left some suggestions, but overall the clamping logic looks good to me. Some of the nits for naming come from me being used to different conventions, so feel free to disregard in this PR -- just wanted to surface things that surprised me.

Comment on lines 23 to 24
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool saturateFP8Truncf);
Member

Could you add a documentation comment?
also nit:

Suggested change
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool saturateFP8Truncf);
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool saturateFP8TruncF);

mlir/include/mlir/Conversion/Passes.td (outdated comment, resolved)
@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final

struct TruncfToFloat8RewritePattern final
: public OpRewritePattern<arith::TruncFOp> {
Member

nit: public is redundant here BTW

using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
bool saturateFP8 = false;
TruncfToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
: OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx),
Member

Suggested change
: OpRewritePattern<arith::TruncFOp>::OpRewritePattern(ctx),
: OpRewritePattern::OpRewritePattern(ctx),

Comment on lines 133 to 134
static Value getMaybeVectorConstant(PatternRewriter &rewriter, Location loc,
const APFloat &value, Type type) {
Member

nit: Maybe reads to me like this could fail, but this is not the case. We have a similar helper in the other arith code where it's called createScalarOrSplatConstant:

static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
TypedAttr attr;
if (dyn_cast<IntegerType>(type)) {
attr = rewriter.getIntegerAttr(type, value);
} else {
auto vecTy = cast<VectorType>(type);
attr = SplatElementsAttr::get(vecTy, value);
}
return rewriter.create<arith::ConstantOp>(loc, attr);
}

Contributor Author

Utility factored out

@@ -44,7 +44,10 @@ struct ExtfOnFloat8RewritePattern final

struct TruncfToFloat8RewritePattern final
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This is existing code, but shouldn't this be called TruncF*? Also everywhere else.

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (two outdated comments, resolved)
Member

@kuhar kuhar left a comment

LGTM

Many machine-learning applications (and most software written at AMD)
expect the operation that truncates floats to 8-bit floats to be
saturating. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ`
to yield `240.0`, not `NaN`, and similarly for negative numbers.
However, the underlying hardware instruction that can be used for this
truncation implements overflow-to-NaN semantics.

To enable handling this use case, we add the saturate-fp8-truncf option
to ArithToAMDGPU (off by default), which causes the requisite clamping
code to be emitted. Said clamping code ensures that Inf and NaN are
passed through exactly (and thus truncate to NaN).

Per review feedback, this commit refactors
createScalarOrSplatConstant() into the Arith dialect utilities and uses
it in this code. It also fixes the naming of existing patterns and
switches from vector.extractelement/insertelement to
vector.extract/insert.
@krzysz00 krzysz00 merged commit 750e90e into llvm:main Jan 23, 2024
3 of 4 checks passed