Skip to content

Commit

Permalink
[mlir][Arith] Fix up integer range inference for truncation
Browse files Browse the repository at this point in the history
Attempting to apply the range analysis to real code revealed that
trunci wasn't correctly handling the case where truncation would
create wider ranges - for example, if we truncate [255, 257] : i16 to
i8, the result can be 255, 0, or 1, which isn't a contiguous range of
values.

The previous implementation would naively map this to [255, 1], which
would cause issues with unsigned ranges and unification.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D130501
  • Loading branch information
krzysz00 committed Aug 1, 2022
1 parent 7fc52d7 commit 938fe9f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 12 deletions.
35 changes: 31 additions & 4 deletions mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp
Expand Up @@ -503,10 +503,37 @@ void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
Type destType) {
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
APInt umin = range.umin().trunc(destWidth);
APInt umax = range.umax().trunc(destWidth);
APInt smin = range.smin().trunc(destWidth);
APInt smax = range.smax().trunc(destWidth);
// If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb],
// the range of the resulting value is not contiguous ind includes 0.
// Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2],
// but you can't truncate [255, 257] similarly.
bool hasUnsignedRollover =
range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth)
: range.umin().trunc(destWidth);
APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth)
: range.umax().trunc(destWidth);

// Signed post-truncation rollover will not occur when either:
// - The high parts of the min and max, plus the sign bit, are the same
// - The high halves + sign bit of the min and max are either all 1s or all 0s
// and you won't create a [positive, negative] range by truncating.
// For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8
// but not [255, 257]_i16 to a range of i8s. You can also truncate
// [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16.
// You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e)
// will truncate to 0x7e, which is greater than 0
APInt sminHighPart = range.smin().ashr(destWidth - 1);
APInt smaxHighPart = range.smax().ashr(destWidth - 1);
bool hasSignedOverflow =
(sminHighPart != smaxHighPart) &&
!(sminHighPart.isAllOnes() &&
(smaxHighPart.isAllOnes() || smaxHighPart.isZero())) &&
!(sminHighPart.isZero() && smaxHighPart.isZero());
APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth)
: range.smin().trunc(destWidth);
APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
: range.smax().trunc(destWidth);
return {umin, umax, smin, smax};
}

Expand Down
83 changes: 75 additions & 8 deletions mlir/test/Dialect/Arithmetic/int-range-interface.mlir
Expand Up @@ -463,14 +463,15 @@ func.func @trunci(%arg0 : i32) -> i1 {
%c-14_i16 = arith.constant -14 : i16
%ci16_smin = arith.constant 0xffff8000 : i32
%0 = arith.minsi %arg0, %c-14_i32 : i32
%1 = arith.trunci %0 : i32 to i16
%2 = arith.cmpi sle, %1, %c-14_i16 : i16
%3 = arith.extsi %1 : i16 to i32
%4 = arith.cmpi sle, %3, %c-14_i32 : i32
%5 = arith.cmpi sge, %3, %ci16_smin : i32
%6 = arith.andi %2, %4 : i1
%7 = arith.andi %6, %5 : i1
func.return %7 : i1
%1 = arith.maxsi %0, %ci16_smin : i32
%2 = arith.trunci %1 : i32 to i16
%3 = arith.cmpi sle, %2, %c-14_i16 : i16
%4 = arith.extsi %2 : i16 to i32
%5 = arith.cmpi sle, %4, %c-14_i32 : i32
%6 = arith.cmpi sge, %4, %ci16_smin : i32
%7 = arith.andi %3, %5 : i1
%8 = arith.andi %7, %6 : i1
func.return %8 : i1
}

// CHECK-LABEL: func @index_cast
Expand Down Expand Up @@ -645,3 +646,69 @@ func.func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1
func.return %8 : i1
}

// Test fon a bug where the noive implementation of trunctation led to the cast
// value being set to [0, 0].
// CHECK-LABEL: func.func @truncation_spillover
// CHECK: %[[unreplaced:.*]] = arith.index_cast
// CHECK: memref.store %[[unreplaced]]
func.func @truncation_spillover(%arg0 : memref<?xi32>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c49 = arith.constant 49 : index
%0 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index {
%1 = arith.divsi %arg2, %c49 : index
%2 = arith.index_cast %1 : index to i32
memref.store %2, %arg0[%c0] : memref<?xi32>
%3 = arith.addi %arg2, %arg1 : index
scf.yield %3 : index
}
func.return %0 : index
}

// CHECK-LABEL: func.func @trunc_catches_overflow
// CHECK: %[[sge:.*]] = arith.cmpi sge
// CHECK: return %[[sge]]
func.func @trunc_catches_overflow(%arg0 : i16) -> i1 {
%c0_i16 = arith.constant 0 : i16
%c130_i16 = arith.constant 130 : i16
%c0_i8 = arith.constant 0 : i8
%0 = arith.maxui %arg0, %c0_i16 : i16
%1 = arith.minui %0, %c130_i16 : i16
%2 = arith.trunci %1 : i16 to i8
%3 = arith.cmpi sge, %2, %c0_i8 : i8
%4 = arith.cmpi uge, %2, %c0_i8 : i8
%5 = arith.andi %3, %4 : i1
func.return %5 : i1
}

// CHECK-LABEL: func.func @trunc_respects_same_high_half
// CHECK: %[[false:.*]] = arith.constant false
// CHECK: return %[[false]]
func.func @trunc_respects_same_high_half(%arg0 : i16) -> i1 {
%c256_i16 = arith.constant 256 : i16
%c257_i16 = arith.constant 257 : i16
%c2_i8 = arith.constant 2 : i8
%0 = arith.maxui %arg0, %c256_i16 : i16
%1 = arith.minui %0, %c257_i16 : i16
%2 = arith.trunci %1 : i16 to i8
%3 = arith.cmpi sge, %2, %c2_i8 : i8
func.return %3 : i1
}

// CHECK-LABEL: func.func @trunc_handles_small_signed_ranges
// CHECK: %[[true:.*]] = arith.constant true
// CHECK: return %[[true]]
func.func @trunc_handles_small_signed_ranges(%arg0 : i16) -> i1 {
%c-2_i16 = arith.constant -2 : i16
%c2_i16 = arith.constant 2 : i16
%c-2_i8 = arith.constant -2 : i8
%c2_i8 = arith.constant 2 : i8
%0 = arith.maxsi %arg0, %c-2_i16 : i16
%1 = arith.minsi %0, %c2_i16 : i16
%2 = arith.trunci %1 : i16 to i8
%3 = arith.cmpi sge, %2, %c-2_i8 : i8
%4 = arith.cmpi sle, %2, %c2_i8 : i8
%5 = arith.andi %3, %4 : i1
func.return %5 : i1
}

0 comments on commit 938fe9f

Please sign in to comment.