Skip to content

Conversation

jtuyls
Copy link
Contributor

@jtuyls jtuyls commented Sep 22, 2025

The signed i4 bitcast was used when setting the exponent and mantissa and instead the sign should be omitted in the comparisons.

Without this, for example the following incorrect conversion from -0.5 f4 to -3.0 f32 will happen:

| Binary | F4E2M1 | f32[23:32] | f32
| 1001 | -0.5 | 1 1000 000 01 | -3.0

Walkthrough:
Bits 23 and 24 are set based on:

Value isHalf =
        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);

Because 1001 (i4) != 1, bit 23 and 24 are set to the leading two bits of 1001 << 2, which is 01. The correct bits are 00.

Bits 25 through 31 are set based on the i4 value being greater or equal to 4:

Value useLargerExp =
        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);

As 1001 is a negative i4 value, this is false and those bits are incorrectly set to 1000 000 instead of 0111 111.

@jtuyls
Copy link
Contributor Author

jtuyls commented Sep 22, 2025

@Muzammiluddin-Syed-ECE @krzysz00 @kuhar Could you help review this PR?

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Jorn Tuyls (jtuyls)

Changes

The signed i4 bitcast was used when setting the exponent and mantissa and instead the sign should be omitted in the comparisons.

Without this, for example the following incorrect conversion from -0.5 f4 to -3.0 f32 will happen:

| Binary | F4E2M1 | f32[23:32] | f32
| 1001 | -0.5 | 1 1000 000 01 | -3.0

Walkthrough:
Bits 23 and 24 are set based on:

Value isHalf =
        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);

Because 1001 (i4) != 1, bit 23 and 24 are set to the leading two bits of 1001 &lt;&lt; 2, which is 01.

Bits 25 through 31 are set based on the i4 value being larger than 4:

Value useLargerExp =
        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);

As 1001 is a negative i4 value, this is false and those bits are incorrectly set to 1000 000 instead of 0111 111.


Full diff: https://github.com/llvm/llvm-project/pull/160121.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+6-3)
  • (modified) mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir (+63)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 6e7421daeb223..54307c9ac843b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
     Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
     Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
+    Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
+
+    Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
 
     // Set last Exponent bit and Mantissa.
     Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
-    Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
+    Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
     Value isHalf =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
     bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
     bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
     bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,7 +405,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
     Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
     Value useLargerExp =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
+        arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
     Value bits25To31 =
         arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
     Value zeroExp =
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
index 9c310d80d4c2d..f58e65a04589e 100644
--- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir
@@ -28,7 +28,17 @@ func.func @entry() {
   %zero = arith.constant 0.0 : f32
   %half = arith.constant 0.5 : f32
   %one = arith.constant 1.0 : f32
+  %oneAndAHalf = arith.constant 1.5 : f32
+  %two = arith.constant 2.0 : f32
+  %three = arith.constant 3.0 : f32
+  %four = arith.constant 4.0 : f32
   %max = arith.constant 6.0 : f32
+  %minHalf = arith.constant -0.5 : f32
+  %minOne = arith.constant -1.0 : f32
+  %minOneAndAHalf = arith.constant -1.5 : f32
+  %minTwo = arith.constant -2.0 : f32
+  %minThree = arith.constant -3.0 : f32
+  %minFour = arith.constant -4.0 : f32
   %min = arith.constant -6.0 : f32
   %lowerThanMin = arith.constant -1000000.0 : f32
   %higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +51,28 @@ func.func @entry() {
   func.call @check_truncf(%half) : (f32) -> ()
   // CHECK: 2
   func.call @check_truncf(%one) : (f32) -> ()
+  // CHECK: 3
+  func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
+  // CHECK: 4
+  func.call @check_truncf(%two) : (f32) -> ()
+  // CHECK: 5
+  func.call @check_truncf(%three) : (f32) -> ()
+  // CHECK: 6
+  func.call @check_truncf(%four) : (f32) -> ()
   // CHECK: 7
   func.call @check_truncf(%max) : (f32) -> ()
+  // CHECK: 9
+  func.call @check_truncf(%minHalf) : (f32) -> ()
+  // CHECK: 10
+  func.call @check_truncf(%minOne) : (f32) -> ()
+  // CHECK: 11
+  func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
+  // CHECK: 12
+  func.call @check_truncf(%minTwo) : (f32) -> ()
+  // CHECK: 13
+  func.call @check_truncf(%minThree) : (f32) -> ()
+  // CHECK: 14
+  func.call @check_truncf(%minFour) : (f32) -> ()
   // CHECK: 15
   func.call @check_truncf(%min) : (f32) -> ()
   // CHECK: 7
@@ -60,9 +90,42 @@ func.func @entry() {
   // CHECK: 0.5
   %halfF4 = arith.truncf %half : f32 to f4E2M1FN
   func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
+  // CHECK: 1
+  %oneF4 = arith.truncf %one : f32 to f4E2M1FN
+  func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
+  // CHECK: 1.5
+  %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: 2
+  %twoF4 = arith.truncf %two : f32 to f4E2M1FN
+  func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
+  // CHECK: 3
+  %threeF4 = arith.truncf %three : f32 to f4E2M1FN
+  func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
+  // CHECK: 4
+  %fourF4 = arith.truncf %four : f32 to f4E2M1FN
+  func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
   // CHECK: 6
   %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
   func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
+  // CHECK: -0.5
+  %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -1
+  %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
+  func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
+  // CHECK: -1.5
+  %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
+  func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
+  // CHECK: -2
+  %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
+  func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
+  // CHECK: -3
+  %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
+  func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
+  // CHECK: -4
+  %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
+  func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
   // CHECK: -6
   %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
   func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but wait for either @krzysz00 or @Muzammiluddin-Syed-ECE to confirm before landing

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, that's a pretty big oversight (I suspect there used to be an abs() at some point and it got dropped). Approved

@jtuyls
Copy link
Contributor Author

jtuyls commented Sep 22, 2025

@kuhar @krzysz00 There was one more issue with the -0.0 (f4) value, it was casted to -0.5 f32. I added a fix and test for that as well now: https://github.com/llvm/llvm-project/pull/160121/files#diff-4a23dbc7eb7e7f10e00c4e31ebb2f21359a4cceffc542169a975e8cc9ceba332R412. All f4 values should be tested now.

Could you check again and help merge when it looks good?

@krzysz00
Copy link
Contributor

LGTM again, yeah

@kuhar kuhar merged commit faf5f28 into llvm:main Sep 22, 2025
9 checks passed
@jtuyls jtuyls deleted the fix-fp4e2m1-ext branch September 23, 2025 07:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants