diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 6acfc2c15af42..85578c22799c6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -8,6 +8,8 @@ #include +#include "llvm/ADT/TypeSwitch.h" + #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Analysis/DataFlowFramework.h" @@ -356,6 +358,16 @@ struct NarrowElementwise final : OpTraitRewritePattern { if (castKind == CastKind::None) break; } + // For operations that explicitly treat the values as signed, we should + // only do signed casts, if those are deemed possible as such based on the + // value range. + auto castKindForOp = + llvm::TypeSwitch(op) + .Case([](auto) { return CastKind::Signed; }) + .Default(CastKind::Both); + castKind = mergeCastKinds(castKind, castKindForOp); if (castKind == CastKind::None) continue; Type targetType = getTargetType(srcType, targetBitwidth); @@ -414,12 +426,26 @@ struct NarrowCmpI final : OpRewritePattern { const ConstantIntRanges &lhsRange = ranges[0]; const ConstantIntRanges &rhsRange = ranges[1]; + auto isSignedCmpPredicate = [](arith::CmpIPredicate pred) -> bool { + return pred == arith::CmpIPredicate::sge || + pred == arith::CmpIPredicate::sgt || + pred == arith::CmpIPredicate::sle || + pred == arith::CmpIPredicate::slt; + }; + // If we're to narrow the input values via a cast, we should preserve the + // sign. + CastKind predicateBasedCastRestriction = + isSignedCmpPredicate(op.getPredicate()) ? CastKind::Signed + : CastKind::Both; + Type srcType = lhs.getType(); for (unsigned targetBitwidth : targetBitwidths) { CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth); CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth); CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind); - // Note: this includes target width > src width. + castKind = mergeCastKinds(castKind, predicateBasedCastRestriction); + // Note: this includes target width > src width, as well as the unsigned + // truncatability & signed predicate scenario. if (castKind == CastKind::None) continue; diff --git a/mlir/test/Dialect/Arith/int-range-narrowing.mlir b/mlir/test/Dialect/Arith/int-range-narrowing.mlir index c3b0d280b1350..7ba22af0c0f1b 100644 --- a/mlir/test/Dialect/Arith/int-range-narrowing.mlir +++ b/mlir/test/Dialect/Arith/int-range-narrowing.mlir @@ -79,6 +79,52 @@ func.func @test_cmpi() -> i1 { return %2 : i1 } +// CHECK-LABEL: func @test_cmpi_si_pred_out_of_signed_bounds +// CHECK-NOT: arith.cmpi slt, {{.*}} : i32 +// CHECK-NOT: arith.cmpi sgt, {{.*}} : i32 +// CHECK-NOT: arith.cmpi sle, {{.*}} : i32 +// CHECK-NOT: arith.cmpi sge, {{.*}} : i32 +// CHECK: %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index +// CHECK: %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index +// CHECK: %[[SLT:.*]] = arith.cmpi slt, %[[B]], %[[A]] : index +// CHECK: %[[C:.*]] = test.with_bounds {smax = 2147483648 : index, smin = 0 : index, umax = 2147483648 : index, umin = 0 : index} : index +// CHECK: %[[ZERO:.*]] = test.with_bounds {smax = 0 : index, smin = 0 : index, umax = 0 : index, umin = 0 : index} : index +// CHECK: %[[SGT:.*]] = arith.cmpi sgt, %[[C]], %[[ZERO]] : index +// CHECK: %[[SLE:.*]] = arith.cmpi sle, %[[A]], %[[ZERO]] : index +// CHECK: %[[SGE:.*]] = arith.cmpi sge, %[[C]], %[[ZERO]] : index +// CHECK: %[[AND0:.*]] = arith.andi %[[SLT]], %[[SGT]] : i1 +// CHECK: %[[AND1:.*]] = arith.andi %[[SLE]], %[[SGE]] : i1 +// CHECK: %[[AND2:.*]] = arith.andi %[[AND0]], %[[AND1]] : i1 +// CHECK: return %[[AND2]] : i1 +func.func @test_cmpi_si_pred_out_of_signed_bounds() -> i1 { + %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index + %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index + %2 = arith.cmpi slt, %1, %0 : index + %3 = test.with_bounds { umin = 0 : index, umax = 2147483648 : index, smin = 0 : index, smax = 2147483648 : index } : index + %4 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } : index + %5 = arith.cmpi sgt, %3, %4 : index + %6 = arith.cmpi sle, %0, %4 : index + %7 = arith.cmpi sge, %3, %4 : index + %8 = arith.andi %2, %5 : i1 + %9 = arith.andi %6, %7 : i1 + %10 = arith.andi %8, %9 : i1 + return %10 : i1 +} + +// CHECK-LABEL: func @test_cmpi_ui_pred_out_of_signed_bounds +// CHECK: %[[A:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index +// CHECK: %[[B:.*]] = test.with_bounds {smax = 4292870144 : index, smin = 0 : index, umax = 4292870144 : index, umin = 0 : index} : index +// CHECK: %[[A_I32:.*]] = arith.index_castui %[[A]] : index to i32 +// CHECK: %[[B_I32:.*]] = arith.index_castui %[[B]] : index to i32 +// CHECK: %[[RES:.*]] = arith.cmpi ult, %[[A_I32]], %[[B_I32]] : i32 +// CHECK: return %[[RES]] : i1 +func.func @test_cmpi_ui_pred_out_of_signed_bounds() -> i1 { + %0 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index + %1 = test.with_bounds { umin = 0 : index, umax = 4292870144 : index, smin = 0 : index, smax = 4292870144 : index } : index + %2 = arith.cmpi ult, %0, %1 : index + return %2 : i1 +} + // CHECK-LABEL: func @test_cmpi_vec // CHECK: %[[A:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex> // CHECK: %[[B:.*]] = test.with_bounds {smax = 10 : index, smin = 0 : index, umax = 10 : index, umin = 0 : index} : vector<4xindex> @@ -224,6 +270,57 @@ func.func @subi_mixed_ext_i8(%lhs: i8, %rhs: i8) -> i32 { return %r : i32 } +// CHECK-LABEL: func.func @signed_ops_out_of_narrowed_signed_range +// CHECK-NOT: arith.divsi {{.*}} : i32 +// CHECK-NOT: arith.ceildivsi {{.*}} : i32 +// CHECK-NOT: arith.floordivsi {{.*}} : i32 +// CHECK-NOT: arith.remsi {{.*}} : i32 +// CHECK-NOT: arith.maxsi {{.*}} : i32 +// CHECK-NOT: arith.minsi {{.*}} : i32 +// CHECK-NOT: arith.shrsi {{.*}} : i32 +// CHECK: %[[DIV_I64:.*]] = arith.divsi {{.*}} : i64 +// CHECK: %[[CEIL_I64:.*]] = arith.ceildivsi {{.*}} : i64 +// CHECK: %[[FLOOR_I64:.*]] = arith.floordivsi {{.*}} : i64 +// CHECK: %[[REM_I64:.*]] = arith.remsi {{.*}} : i64 +// CHECK: %[[MAX_I64:.*]] = arith.maxsi {{.*}} : i64 +// CHECK: %[[MIN_I64:.*]] = arith.minsi {{.*}} : i64 +// CHECK: %[[SHR_I64:.*]] = arith.shrsi {{.*}} : i64 +// CHECK: return %{{.*}} : i64, i64, i64, i64, i64, i64, i64 +func.func @signed_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64, i64) { + %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64 + %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64 + %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64 + %3 = arith.divsi %0, %1 : i64 + %4 = arith.ceildivsi %0, %1 : i64 + %5 = arith.floordivsi %0, %1 : i64 + %6 = arith.remsi %0, %1 : i64 + %7 = arith.maxsi %0, %2 : i64 + %8 = arith.minsi %0, %2 : i64 + %9 = arith.shrsi %0, %1 : i64 + return %3, %4, %5, %6, %7, %8, %9 : i64, i64, i64, i64, i64, i64, i64 +} + +// CHECK-LABEL: func.func @unsigned_ops_out_of_narrowed_signed_range +// CHECK: arith.divui {{.*}} : i32 +// CHECK: arith.ceildivui {{.*}} : i32 +// CHECK: arith.remui {{.*}} : i32 +// CHECK: arith.maxui {{.*}} : i32 +// CHECK: arith.minui {{.*}} : i32 +// CHECK: arith.shrui {{.*}} : i32 +// CHECK: return %{{.*}} : i64, i64, i64, i64, i64, i64 +func.func @unsigned_ops_out_of_narrowed_signed_range() -> (i64, i64, i64, i64, i64, i64) { + %0 = test.with_bounds { umin = 0 : i64, umax = 4292870144 : i64, smin = 0 : i64, smax = 4292870144 : i64 } : i64 + %1 = test.with_bounds { umin = 1 : i64, umax = 8 : i64, smin = 1 : i64, smax = 8 : i64 } : i64 + %2 = test.with_bounds { umin = 0 : i64, umax = 0 : i64, smin = 0 : i64, smax = 0 : i64 } : i64 + %3 = arith.divui %0, %1 : i64 + %4 = arith.ceildivui %0, %1 : i64 + %5 = arith.remui %0, %1 : i64 + %6 = arith.maxui %0, %2 : i64 + %7 = arith.minui %0, %2 : i64 + %8 = arith.shrui %0, %1 : i64 + return %3, %4, %5, %6, %7, %8 : i64, i64, i64, i64, i64, i64 +} + //===----------------------------------------------------------------------===// // arith.muli //===----------------------------------------------------------------------===//