-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][arith][transforms] Fix f4E2M1FN to f32 cast #160121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@Muzammiluddin-Syed-ECE @krzysz00 @kuhar Could you help review this PR? |
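@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith

Author: Jorn Tuyls (jtuyls)

Changes

The signed i4 bitcast was used when setting the exponent and mantissa; the sign bit should instead be omitted in the comparisons. Without this, for example, the following incorrect conversion from `-0.5` (f4) to `-3.0` (f32) happens:

| Binary | F4E2M1 | f32[23:32] | f32 |
|--------|--------|------------|-----|
| 1001 | -0.5 | 1 1000 000 01 | -3.0 |

Walkthrough:

Bits 23 and 24 are set based on the comparison of the i4 value with `1`: because `1001` (i4) != `1`, they are set to the leading two bits of `1001 << 2`, which are `01`. The correct bits are `00`.

Bits 25 through 31 are set based on the i4 value being greater than or equal to 4: as the sign bit of `1001` is included in the comparison, those bits are incorrectly set to `1000 000` instead of `0111 111`.

Full diff: https://github.com/llvm/llvm-project/pull/160121.diff

2 Files Affected: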
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..54307c9ac843b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+ Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+ Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
// Set last Exponent bit and Mantissa.
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
- Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+ Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
Value isHalf =
- arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,7 +405,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
Value useLargerExp =
- arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
Value bits25To31 =
arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
Value zeroExp =
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f58e65a04589e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,17 @@ func.func @entry() {
%zero = arith.constant 0.0 : f32
%half = arith.constant 0.5 : f32
%one = arith.constant 1.0 : f32
+ %oneAndAHalf = arith.constant 1.5 : f32
+ %two = arith.constant 2.0 : f32
+ %three = arith.constant 3.0 : f32
+ %four = arith.constant 4.0 : f32
%max = arith.constant 6.0 : f32
+ %minHalf = arith.constant -0.5 : f32
+ %minOne = arith.constant -1.0 : f32
+ %minOneAndAHalf = arith.constant -1.5 : f32
+ %minTwo = arith.constant -2.0 : f32
+ %minThree = arith.constant -3.0 : f32
+ %minFour = arith.constant -4.0 : f32
%min = arith.constant -6.0 : f32
%lowerThanMin = arith.constant -1000000.0 : f32
%higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +51,28 @@ func.func @entry() {
func.call @check_truncf(%half) : (f32) -> ()
// CHECK: 2
func.call @check_truncf(%one) : (f32) -> ()
+ // CHECK: 3
+ func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+ // CHECK: 4
+ func.call @check_truncf(%two) : (f32) -> ()
+ // CHECK: 5
+ func.call @check_truncf(%three) : (f32) -> ()
+ // CHECK: 6
+ func.call @check_truncf(%four) : (f32) -> ()
// CHECK: 7
func.call @check_truncf(%max) : (f32) -> ()
+ // CHECK: 9
+ func.call @check_truncf(%minHalf) : (f32) -> ()
+ // CHECK: 10
+ func.call @check_truncf(%minOne) : (f32) -> ()
+ // CHECK: 11
+ func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+ // CHECK: 12
+ func.call @check_truncf(%minTwo) : (f32) -> ()
+ // CHECK: 13
+ func.call @check_truncf(%minThree) : (f32) -> ()
+ // CHECK: 14
+ func.call @check_truncf(%minFour) : (f32) -> ()
// CHECK: 15
func.call @check_truncf(%min) : (f32) -> ()
// CHECK: 7
@@ -60,9 +90,42 @@ func.func @entry() {
// CHECK: 0.5
%halfF4 = arith.truncf %half : f32 to f4E2M1FN
func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+ // CHECK: 1
+ %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+ func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+ // CHECK: 1.5
+ %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: 2
+ %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+ func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+ // CHECK: 3
+ %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+ func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+ // CHECK: 4
+ %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+ func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
// CHECK: 6
%higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+ // CHECK: -0.5
+ %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -1
+ %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+ func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+ // CHECK: -1.5
+ %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+ func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+ // CHECK: -2
+ %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+ func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+ // CHECK: -3
+ %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+ func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+ // CHECK: -4
+ %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+ func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
// CHECK: -6
%lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but wait for either @krzysz00 or @Muzammiluddin-Syed-ECE to confirm before landing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh, that's a pretty big oversight (I suspect there used to be an abs() at some point and it got dropped). Approved
6ae0455
to
57ba6f8
Compare
@kuhar @krzysz00 There was one more issue with the -0.0 (f4) value: it was cast to -0.5 (f32). I have now added a fix and a test for that as well: https://github.com/llvm/llvm-project/pull/160121/files#diff-4a23dbc7eb7e7f10e00c4e31ebb2f21359a4cceffc542169a975e8cc9ceba332R412. All f4 values should be tested now. Could you check again and help merge when it looks good?
LGTM again, yeah |
The signed i4 bitcast was used when setting the exponent and mantissa; the sign bit should instead be omitted in the comparisons.

Without this, for example, the following incorrect conversion from `-0.5` (f4) to `-3.0` (f32) will happen:

| Binary | F4E2M1 | f32[23:32] | f32 |
|--------|--------|------------|-----|
| 1001 | -0.5 | 1 1000 000 01 | -3.0 |

Walkthrough:
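Bits 23 and 24 are set based on:

```cpp
Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
Value isHalf =
    arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
```

Because `1001` (i4) != `1`, bits 23 and 24 are set to the leading two bits of `1001 << 2`, which are `01`. The correct bits are `00`.

Bits 25 through 31 are set based on the i4 value being greater than or equal to 4:

```cpp
Value useLargerExp =
    arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
Value bits25To31 =
    arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
```

As the sign bit of `1001` is still set, the value compares as 9 under the unsigned `uge` predicate, so the condition is true and those bits are incorrectly set to `1000 000` instead of `0111 111`.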