diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 6e7421daeb223..adeb50b6da628 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter); Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter); Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter); + Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter); + + Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7); // Set last Exponent bit and Mantissa. Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter); - Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2); + Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2); Value isHalf = - arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1); bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24); bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24); bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014); @@ -402,11 +405,11 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter); Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter); Value useLargerExp = - arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4); Value bits25To31 = arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits); Value zeroExp = - arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0); bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31); // Set sign. diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir index 9c310d80d4c2d..f2970618d5b6e 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir @@ -28,7 +28,18 @@ func.func @entry() { %zero = arith.constant 0.0 : f32 %half = arith.constant 0.5 : f32 %one = arith.constant 1.0 : f32 + %oneAndAHalf = arith.constant 1.5 : f32 + %two = arith.constant 2.0 : f32 + %three = arith.constant 3.0 : f32 + %four = arith.constant 4.0 : f32 %max = arith.constant 6.0 : f32 + %minZero = arith.constant -0.0 : f32 + %minHalf = arith.constant -0.5 : f32 + %minOne = arith.constant -1.0 : f32 + %minOneAndAHalf = arith.constant -1.5 : f32 + %minTwo = arith.constant -2.0 : f32 + %minThree = arith.constant -3.0 : f32 + %minFour = arith.constant -4.0 : f32 %min = arith.constant -6.0 : f32 %lowerThanMin = arith.constant -1000000.0 : f32 %higherThanMax = arith.constant 1000000.0 : f32 @@ -41,8 +52,28 @@ func.func @entry() { func.call @check_truncf(%half) : (f32) -> () // CHECK: 2 func.call @check_truncf(%one) : (f32) -> () + // CHECK: 3 + func.call @check_truncf(%oneAndAHalf) : (f32) -> () + // CHECK: 4 + func.call @check_truncf(%two) : (f32) -> () + // CHECK: 5 + func.call @check_truncf(%three) : (f32) -> () + // CHECK: 6 + func.call @check_truncf(%four) : (f32) -> () // CHECK: 7 func.call @check_truncf(%max) : (f32) -> () + // CHECK: 9 + func.call @check_truncf(%minHalf) : (f32) -> () + // CHECK: 10 + func.call @check_truncf(%minOne) : (f32) -> () + // CHECK: 11 + func.call @check_truncf(%minOneAndAHalf) : (f32) -> () + // CHECK: 12 + func.call @check_truncf(%minTwo) : (f32) -> () + // CHECK: 13 + func.call @check_truncf(%minThree) : (f32) -> () + // CHECK: 14 + func.call @check_truncf(%minFour) : (f32) -> () // CHECK: 15 func.call @check_truncf(%min) : (f32) -> () // CHECK: 7 @@ -60,9 +91,45 @@ func.func @entry() { // CHECK: 0.5 %halfF4 = arith.truncf %half : f32 to f4E2M1FN func.call @check_extf(%halfF4) : (f4E2M1FN) -> () + // CHECK: 1 + %oneF4 = arith.truncf %one : f32 to f4E2M1FN + func.call @check_extf(%oneF4) : (f4E2M1FN) -> () + // CHECK: 1.5 + %oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN + func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> () + // CHECK: 2 + %twoF4 = arith.truncf %two : f32 to f4E2M1FN + func.call @check_extf(%twoF4) : (f4E2M1FN) -> () + // CHECK: 3 + %threeF4 = arith.truncf %three : f32 to f4E2M1FN + func.call @check_extf(%threeF4) : (f4E2M1FN) -> () + // CHECK: 4 + %fourF4 = arith.truncf %four : f32 to f4E2M1FN + func.call @check_extf(%fourF4) : (f4E2M1FN) -> () // CHECK: 6 %higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> () + // CHECK: -0 + %minZeroF4 = arith.truncf %minZero : f32 to f4E2M1FN + func.call @check_extf(%minZeroF4) : (f4E2M1FN) -> () + // CHECK: -0.5 + %minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN + func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> () + // CHECK: -1 + %minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN + func.call @check_extf(%minOneF4) : (f4E2M1FN) -> () + // CHECK: -1.5 + %minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN + func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> () + // CHECK: -2 + %minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN + func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> () + // CHECK: -3 + %minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN + func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> () + // CHECK: -4 + %minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN + func.call @check_extf(%minFourF4) : (f4E2M1FN) -> () // CHECK: -6 %lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()