Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ Expr Simplify::visit(const Add *op, ExprInfo *info) {

if (rewrite(IRMatcher::Overflow() + x, a) ||
rewrite(x + IRMatcher::Overflow(), b) ||
rewrite(x + 0, x) ||
rewrite(0 + x, x)) {
rewrite(x + 0, a) ||
rewrite(0 + x, b)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
9 changes: 5 additions & 4 deletions src/Simplify_Call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,11 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) {
// pattern minus x. We get more information that way than just
// counting the leading zeros or ones.
Expr e = mutate(make_const(op->type, (int64_t)(-1), nullptr) - a, info);
// If the result of this happens to be a constant, we may as well
// return it. This is redundant with the constant folding below, but
// the constant folding below still needs to happen when info is
// nullptr.
// If the result of this happens to fold to a constant, we may as
// well return it immediately. This only happens if a is a constant
// uint or int, in which case the logic below would produce the
// exact same constant Expr with the exact same info as we're
// already holding.
if (info->bounds.is_single_point()) {
return e;
}
Expand Down
8 changes: 7 additions & 1 deletion src/Simplify_Cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) {
// know.
*info = ExprInfo{};
} else {
int64_t old_min = value_info.bounds.min;
value_info.cast_to(op->type);
if (op->type.is_uint() && op->type.bits() == 64 && old_min > 0) {
// It's impossible for a cast *to* a uint64 in Halide to lower the
// min. Casts to uint64_t don't overflow for any source type.
value_info.bounds.min = old_min;
}
value_info.trim_bounds_using_alignment();
if (info) {
*info = value_info;
}
// It's possible we just reduced to a constant. E.g. if we cast an
// even number to uint1 we get zero.
if (value_info.bounds.is_single_point()) {
return make_const(op->type, value_info.bounds.min, nullptr);
return make_const(op->type, value_info.bounds.min, info);
}
}

Expand Down
80 changes: 46 additions & 34 deletions src/Simplify_Div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,44 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) {
Expr a = mutate(op->a, &a_info);
Expr b = mutate(op->b, &b_info);

if (info) {
if (op->type.is_int_or_uint()) {
// ConstantInterval division is integer division, so we can't use
// this code path for floats.
info->bounds = a_info.bounds / b_info.bounds;
info->alignment = a_info.alignment / b_info.alignment;
info->cast_to(op->type);
info->trim_bounds_using_alignment();

// Bounded numerator divided by constantish bounded denominator can
// sometimes collapse things to a constant at this point. This
// mostly happens when the denominator is a constant and the
// numerator span is small (e.g. [23, 29]/10 = 2), but there are
// also cases with a bounded denominator (e.g. [5, 7]/[4, 5] = 1).
if (info->bounds.is_single_point()) {
if (op->type.can_represent(info->bounds.min)) {
return make_const(op->type, info->bounds.min, nullptr);
} else {
// Even though this is 'no-overflow-int', if the result
// we calculate can't fit into the destination type,
// we're better off returning an overflow condition than
// a known-wrong value. (Note that no_overflow_int() should
// only be true for signed integers.)
internal_assert(no_overflow_int(op->type)) << op->type << " " << info->bounds;
clear_expr_info(info);
return make_signed_integer_overflow(op->type);
}
ExprInfo div_info;

if (op->type.is_int_or_uint()) {
// ConstantInterval division is integer division, so we can't use
// this code path for floats.
div_info.bounds = a_info.bounds / b_info.bounds;
div_info.alignment = a_info.alignment / b_info.alignment;
div_info.cast_to(op->type);
div_info.trim_bounds_using_alignment();

// Bounded numerator divided by constantish bounded denominator can
// sometimes collapse things to a constant at this point. This
// mostly happens when the denominator is a constant and the
// numerator span is small (e.g. [23, 29]/10 = 2), but there are
// also cases with a bounded denominator (e.g. [5, 7]/[4, 5] = 1).
if (div_info.bounds.is_single_point()) {
if (op->type.can_represent(div_info.bounds.min)) {
return make_const(op->type, div_info.bounds.min, info);
} else {
// Even though this is 'no-overflow-int', if the result
// we calculate can't fit into the destination type,
// we're better off returning an overflow condition than
// a known-wrong value. (Note that no_overflow_int() should
// only be true for signed integers.)
internal_assert(no_overflow_int(op->type)) << op->type << " " << div_info.bounds;
clear_expr_info(info);
return make_signed_integer_overflow(op->type);
}
} else {
// TODO: Tracking constant integer bounds of floating point values
// isn't so useful right now, but if we want integer bounds for
// floating point division later, here's the place to put it.
clear_expr_info(info);
}
} else {
// TODO: Tracking constant integer bounds of floating point values isn't
// so useful right now, but if we want integer bounds for floating point
// division later, here's the place to put it. Just leave div_info empty
// for now (i.e. nothing is known).
}

if (info) {
*info = div_info;
}

bool denominator_non_zero =
Expand All @@ -55,9 +59,17 @@ Expr Simplify::visit(const Div *op, ExprInfo *info) {

if (rewrite(IRMatcher::Overflow() / x, a) ||
rewrite(x / IRMatcher::Overflow(), b) ||
rewrite(x / 1, x) ||
rewrite(0 / x, 0) ||
rewrite(x / 1, a) ||
rewrite(0 / x, a) ||
false) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
8 changes: 7 additions & 1 deletion src/Simplify_Exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ Expr Simplify::visit(const IntImm *op, ExprInfo *info) {
}

Expr Simplify::visit(const UIntImm *op, ExprInfo *info) {
if (info && Int(64).can_represent(op->value)) {
if (info) {
// Pretend it's an int constant that has been cast to uint.
int64_t v = (int64_t)(op->value);
info->bounds = ConstantInterval::single_point(v);
info->alignment = ModulusRemainder(0, v);
// If it's not representable as an int64, this will wrap the alignment appropriately:
info->cast_to(op->type);
// Be as informative as we can with bounds for out-of-range uint64s
if ((int64_t)op->value < 0) {
info->bounds = ConstantInterval::bounded_below(INT64_MAX);
}
} else {
clear_expr_info(info);
}
Expand Down
14 changes: 13 additions & 1 deletion src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,15 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

// Truncate the bounds to the new type.
// For UInt64 constants, the remainder might not be representable as an int64
if (t.bits() == 64 && t.is_uint() &&
alignment.modulus == 0 && alignment.remainder < 0) {
// Forget the leading two bits to get a representable modulus
// and remainder.
alignment.modulus = (int64_t)1 << 62;
alignment.remainder = alignment.remainder & (alignment.modulus - 1);
}

bounds.cast_to(t);
}

Expand Down Expand Up @@ -241,6 +249,10 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
// We never want to return make_const anything in the simplifier without
// also setting the ExprInfo, so shadow the global make_const.
Expr make_const(const Type &t, int64_t c, ExprInfo *info) {
if (t.is_uint() && c < 0) {
// Wrap it around
return make_const(t, (uint64_t)c, info);
}
c = normalize_constant(t, c);
set_expr_info_to_constant(info, c);
return Halide::Internal::make_const(t, c);
Expand Down
17 changes: 10 additions & 7 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
if (max_info.bounds.is_single_point()) {
// This is possible when, for example, the largest number in the type
// that satisfies the alignment of the left-hand-side is smaller than
// the min value of the right-hand-side.
return make_const(op->type, max_info.bounds.min, nullptr);
// the min value of the right-hand-side. Reinferring the info can
// potentially give us something tighter than what was computed above if
// it's a large uint64.
return make_const(op->type, max_info.bounds.min, info);
}

auto strip_likely = [](const Expr &e) {
Expand Down Expand Up @@ -65,10 +67,10 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
return rewrite.result;
}

// Cases where one side dominates. All of these must reduce to a or b in the
// RHS for ExprInfo to update correctly.
if (EVAL_IN_LAMBDA //
(rewrite(max(x, x), a) ||
rewrite(max(c0, c1), fold(max(c0, c1))) ||
// Cases where one side dominates:
rewrite(max(x, c0), b, is_max_value(c0)) ||
rewrite(max(x, c0), a, is_min_value(c0)) ||
rewrite(max((x / c0) * c0, x), b, c0 > 0) ||
Expand Down Expand Up @@ -148,16 +150,17 @@ Expr Simplify::visit(const Max *op, ExprInfo *info) {
// than just applying max to two constant intervals.
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else if (rewrite.result.same_as(b)) {
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}

return rewrite.result;
}

if (EVAL_IN_LAMBDA //
(rewrite(max(max(x, c0), c1), max(x, fold(max(c0, c1)))) ||
(rewrite(max(c0, c1), fold(max(c0, c1))) ||
rewrite(max(max(x, c0), c1), max(x, fold(max(c0, c1)))) ||
rewrite(max(max(x, c0), y), max(max(x, y), c0)) ||
rewrite(max(max(x, y), max(x, z)), max(max(y, z), x)) ||
rewrite(max(max(y, x), max(x, z)), max(max(y, z), x)) ||
Expand Down
12 changes: 7 additions & 5 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
// This is possible when, for example, the smallest number in the type
// that satisfies the alignment of the left-hand-side is greater than
// the max value of the right-hand-side.
return make_const(op->type, min_info.bounds.min, nullptr);
return make_const(op->type, min_info.bounds.min, info);
}

// Early out when the bounds tells us one side or the other is smaller
Expand Down Expand Up @@ -66,10 +66,10 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
return rewrite.result;
}

// Cases where one side dominates. All of these must reduce to a or b in the
// RHS for ExprInfo to update correctly.
if (EVAL_IN_LAMBDA //
(rewrite(min(x, x), a) ||
rewrite(min(c0, c1), fold(min(c0, c1))) ||
// Cases where one side dominates:
rewrite(min(x, c0), b, is_min_value(c0)) ||
rewrite(min(x, c0), a, is_max_value(c0)) ||
rewrite(min((x / c0) * c0, x), a, c0 > 0) ||
Expand Down Expand Up @@ -148,15 +148,17 @@ Expr Simplify::visit(const Min *op, ExprInfo *info) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else if (rewrite.result.same_as(b)) {
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

if (EVAL_IN_LAMBDA //
(rewrite(min(min(x, c0), c1), min(x, fold(min(c0, c1)))) ||
(rewrite(min(c0, c1), fold(min(c0, c1))) ||
rewrite(min(min(x, c0), c1), min(x, fold(min(c0, c1)))) ||
rewrite(min(min(x, c0), y), min(min(x, y), c0)) ||
rewrite(min(min(x, y), min(x, z)), min(min(y, z), x)) ||
rewrite(min(min(y, x), min(x, z)), min(min(y, z), x)) ||
Expand Down
16 changes: 12 additions & 4 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ Expr Simplify::visit(const Mul *op, ExprInfo *info) {
return rewrite.result;
}

if (rewrite(0 * x, 0) ||
rewrite(1 * x, x) ||
rewrite(x * 0, 0) ||
rewrite(x * 1, x)) {
if (rewrite(0 * x, a) ||
rewrite(1 * x, b) ||
rewrite(x * 0, b) ||
rewrite(x * 1, a)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
3 changes: 2 additions & 1 deletion src/Simplify_Select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Expr Simplify::visit(const Select *op, ExprInfo *info) {
if (info) {
if (rewrite.result.same_as(true_value)) {
*info = t_info;
} else if (rewrite.result.same_as(false_value)) {
} else {
internal_assert(rewrite.result.same_as(false_value));
*info = f_info;
}
}
Expand Down
10 changes: 9 additions & 1 deletion src/Simplify_Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ Expr Simplify::visit(const Sub *op, ExprInfo *info) {

if (rewrite(IRMatcher::Overflow() - x, a) ||
rewrite(x - IRMatcher::Overflow(), b) ||
rewrite(x - 0, x)) {
rewrite(x - 0, a)) {
if (info) {
if (rewrite.result.same_as(a)) {
info->intersect(a_info);
} else {
internal_assert(rewrite.result.same_as(b));
info->intersect(b_info);
}
}
return rewrite.result;
}

Expand Down
2 changes: 0 additions & 2 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ tests(GROUPS correctness
fused_where_inner_extent_is_zero.cpp
fuzz_float_stores.cpp
fuzz_schedule.cpp
fuzz_simplify.cpp
gameoflife.cpp
gather.cpp
gpu_alloc_group_profiling.cpp
Expand Down Expand Up @@ -356,7 +355,6 @@ tests(GROUPS correctness
vectorized_initialization.cpp
vectorized_load_from_vectorized_allocation.cpp
vectorized_reduction_bug.cpp
widening_lerp.cpp
widening_reduction.cpp
# keep-sorted end
)
Expand Down
Loading
Loading