diff --git a/src/Bounds.cpp b/src/Bounds.cpp index f4493474c49f..3fbffbd16775 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -319,9 +319,42 @@ class Bounds : public IRVisitor { bool could_overflow = true; if (to.can_represent(from) || to.is_float()) { could_overflow = false; + } else if (from.is_float() && to.is_int() && to.bits() >= 32) { + // For int32+ destinations, float-to-signed-int is + // implementation-defined on out-of-range values. Halide's + // convention is to assume the value fits, so carry the + // float bounds through the cast UNLESS at least one + // endpoint is a constant we can prove exceeds the + // destination's range. In that case IntImm::make sign- + // extends the low bits and can invert the resulting + // interval, which is strictly worse than bounds_of_type. + // + // (Float-to-signed-int for bits <= 16 saturates, so it is + // handled by the narrower `a.is_bounded()` path below; we + // don't need to treat it specially here.) + Interval s = simplify(a); + double lo = -std::ldexp(1.0, to.bits() - 1); + double hi_exclusive = std::ldexp(1.0, to.bits() - 1); + bool const_oob = false; + if (s.has_lower_bound()) { + if (auto fmin = as_const_float(s.min)) { + const_oob = const_oob || !std::isfinite(*fmin) || *fmin < lo; + } + } + if (s.has_upper_bound()) { + if (auto fmax = as_const_float(s.max)) { + const_oob = const_oob || !std::isfinite(*fmax) || *fmax >= hi_exclusive; + } + } + if (!const_oob) { + could_overflow = false; + a = s; + } } else if (to.is_int() && to.bits() >= 32) { // If we cast to an int32 or greater, assume that it won't - // overflow. Signed 32-bit integer overflow is undefined. + // overflow. Signed 32-bit integer overflow is undefined, and + // Halide treats uint-to-signed-int narrowing as "assumed not + // to overflow" too (see PR #7814 context). could_overflow = false; } else if (a.is_bounded()) { if (from.can_represent(to)) { @@ -742,7 +775,20 @@ class Bounds : public IRVisitor { Type t = op->type.element_of(); - // Mod is always positive + if (t.is_float()) { + // Halide mod is supposed to be non-negative (Euclidean), but + // the current float-mod lowering in CodeGen_LLVM is + // `a - b * floor(a/b)`, which can produce a negative result + // when b is negative (e.g. fmod(5, -3) = -1 under that + // formula). Until the lowering is tightened to enforce the + // Euclidean invariant, fall back to unbounded here -- the + // integer reasoning below would unsoundly claim the result + // is >= 0. + bounds_of_type(t); + return; + } + + // Mod is always non-negative for integer types. interval.min = make_zero(t); interval.max = Interval::pos_inf(); @@ -765,11 +811,6 @@ class Bounds : public IRVisitor { // x % [-8, 10] -> [0,9] interval.max = Max::make(interval.min, b.max - make_one(t)); interval.max = Max::make(interval.max, make_const(t, -1) - b.min); - } else if (b.max.type().is_float()) { - // The floating point version has the same sign rules, - // but can reach all the way up to the original value, - // so there's no -1. - interval.max = Max::make(b.max, -b.min); } } } diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 631686ee0bfc..0d1b278a1e85 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -110,13 +110,20 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { } else if (cast && op->type.is_int_or_uint() && cast->type.is_int_or_uint() && + cast->value.type().is_int_or_uint() && op->type.bits() <= cast->type.bits() && op->type.bits() <= op->value.type().bits()) { // If this is a cast between integer types, where the // outer cast is narrower than the inner cast and the // inner cast's argument, the inner cast can be // eliminated. The inner cast is either a sign extend - // or a zero extend, and the outer cast truncates the extended bits + // or a zero extend, and the outer cast truncates the extended bits. + // The requirement that cast->value is itself int-or-uint is crucial: + // a float source makes `cast` an fp-to-int conversion, whose low + // bits are not the same as an fp-to-int conversion of a narrower + // type. For example, int32(uint64(float64(-21))) evaluates to 0 + // (float-to-uint of a negative value saturates to 0 in Halide), + // while the stripped form int32(float64(-21)) evaluates to -21. if (op->type == cast->value.type()) { return mutate(cast->value, info); } else { diff --git a/src/Solve.cpp b/src/Solve.cpp index a0b91be7a287..05fa508c4926 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -156,7 +156,12 @@ class SolveExpression : public IRMutator { } } else if (a_uses_var && b_uses_var) { if (equal(a, b)) { - expr = mutate(a * 2); + // Use Mul::make + make_const rather than operator*(Expr, int) + // because the latter rejects constants that don't fit in a's + // type (e.g. `2` in UInt(1)). make_const truncates modulo the + // width, so `UInt(1) * 2` becomes `UInt(1) * 0`, which is + // the correct modular result of `a + a` for UInt(1). + expr = mutate(Mul::make(a, make_const(a.type(), 2))); } else if (add_a && !a_failed) { // (f(x) + a) + g(x) -> (f(x) + g(x)) + a expr = mutate((add_a->a + b) + add_a->b); @@ -181,11 +186,15 @@ class SolveExpression : public IRMutator { } else if (mul_b && equal(mul_b->a, a)) { // f(x) + f(x)*a -> f(x) * (a + 1) expr = mutate(a * (mul_b->b + 1)); - } else if (div_a && !a_failed) { - // f(x)/a + g(x) -> (f(x) + g(x) * a) / b + } else if (div_a && !a_failed && no_overflow_int(op->type)) { + // f(x)/a + g(x) -> (f(x) + g(x) * a) / a + // Only valid when multiplication and division don't wrap: + // under modular arithmetic g(x)*a can overflow and the + // rewrite changes the value. Gated to Int(32)+. expr = mutate((div_a->a + b * div_a->b) / div_a->b); - } else if (div_b && !b_failed) { + } else if (div_b && !b_failed && no_overflow_int(op->type)) { // f(x) + g(x)/b -> (f(x) * b + g(x)) / b + // Same overflow concern as above. expr = mutate((a * div_b->b + div_b->a) / div_b->b); } else { expr = fail(a + b); @@ -269,8 +278,9 @@ class SolveExpression : public IRMutator { } else if (mul_a && mul_b && equal(mul_a->b, mul_b->b)) { // f(x)*a - g(x)*a -> (f(x) - g(x))*a; expr = mutate((mul_a->a - mul_b->a) * mul_a->b); - } else if (div_a && !a_failed) { - // f(x)/a - g(x) -> (f(x) - g(x) * a) / b + } else if (div_a && !a_failed && no_overflow_int(op->type)) { + // f(x)/a - g(x) -> (f(x) - g(x) * a) / a + // Same overflow concern as the analogous Add rewrite. expr = mutate((div_a->a - b * div_a->b) / div_a->b); } else { expr = fail(a - b); @@ -609,11 +619,20 @@ class SolveExpression : public IRMutator { Expr expr; if (a_uses_var && !b_uses_var) { - // We have f(x) < y. Try to unwrap f(x) - if (add_a && !a_failed) { + // We have f(x) < y. Try to unwrap f(x). + // + // Several of these rewrites rearrange the comparison by adding + // or subtracting on both sides. That's only sound under an + // assumption of no integer overflow -- for types that wrap + // (unsigned and narrow signed), ordering comparisons flip + // under wrap even though equality is preserved. So gate the + // rewrite on no_overflow_int for LT/LE/GT/GE but allow EQ/NE + // for all types (modular arithmetic preserves equality). + const bool safe_to_rearrange = no_overflow_int(a.type()) || is_eq || is_ne; + if (add_a && !a_failed && safe_to_rearrange) { // f(x) + b < c -> f(x) < c - b expr = mutate(Cmp::make(add_a->a, (b - add_a->b))); - } else if (sub_a && !a_failed) { + } else if (sub_a && !a_failed && safe_to_rearrange) { // f(x) - b < c -> f(x) < c + b expr = mutate(Cmp::make(sub_a->a, (b + sub_a->b))); } else if (mul_a) { @@ -631,11 +650,14 @@ class SolveExpression : public IRMutator { // check is true, but put an assertion anyway. internal_assert(!b.type().is_uint()) << "Negating unsigned is not legal\n"; expr = mutate(Opp::make(mul_a->a * negate(mul_a->b), negate(b))); - } else { - // Don't use operator/ and operator % to sneak - // past the division-by-zero check. We'll only - // actually use these when mul_a->b is a positive - // or negative constant. + } else if (is_positive_const(mul_a->b) && no_overflow_int(a.type())) { + // The rewrites below divide by mul_a->b, so require + // it to be a nonzero constant of known sign. + // no_overflow_int also rules out unsigned and narrow + // signed types, for which `a*c == b <=> a == b/c && + // b%c == 0` fails under modular arithmetic (consider + // uint8 with c = 3, b = 7: the rewrite misses the + // solutions that arise from wrap). Expr div = Div::make(b, mul_a->b); Expr rem = Mod::make(b, mul_a->b); if (is_eq) { @@ -644,16 +666,14 @@ class SolveExpression : public IRMutator { } else if (is_ne) { // f(x) * c != b -> f(x) != b/c || b%c != 0 expr = mutate((mul_a->a != div) || (rem != 0)); - } else if (is_positive_const(mul_a->b)) { - if (is_le) { - expr = mutate(mul_a->a <= div); - } else if (is_lt) { - expr = mutate(mul_a->a <= (b - 1) / mul_a->b); - } else if (is_gt) { - expr = mutate(mul_a->a > div); - } else if (is_ge) { - expr = mutate(mul_a->a > (b - 1) / mul_a->b); - } + } else if (is_le) { + expr = mutate(mul_a->a <= div); + } else if (is_lt) { + expr = mutate(mul_a->a <= (b - 1) / mul_a->b); + } else if (is_gt) { + expr = mutate(mul_a->a > div); + } else if (is_ge) { + expr = mutate(mul_a->a > (b - 1) / mul_a->b); } } } else if (div_a) { @@ -663,7 +683,7 @@ class SolveExpression : public IRMutator { } else if (is_negative_const(div_a->b)) { expr = mutate(Opp::make(div_a->a, b * div_a->b)); } - } else if (a.type().is_int() && a.type().bits() >= 32) { + } else if (no_overflow_int(a.type())) { if (is_eq || is_ne) { // Can't do anything with this } else if (is_negative_const(div_a->b)) { @@ -689,7 +709,7 @@ class SolveExpression : public IRMutator { } } } - } else if (a_uses_var && b_uses_var && a.type().is_int() && a.type().bits() >= 32) { + } else if (a_uses_var && b_uses_var && no_overflow_int(a.type())) { // Convert to f(x) - g(x) == 0 and let the subtract mutator clean up. // Only safe if the type is not subject to overflow. expr = mutate(Cmp::make(a - b, make_zero(a.type()))); @@ -1184,301 +1204,5 @@ Expr and_condition_over_domain(const Expr &e, const Scope &varying) { Expr or_condition_over_domain(const Expr &c, const Scope &varying) { return simplify(!and_condition_over_domain(simplify(!c), varying)); } - -// Testing code - -namespace { - -void check_solve(const Expr &a, const Expr &b) { - SolverResult solved = solve_expression(a, "x"); - internal_assert(equal(solved.result, b)) - << "Expression: " << a << "\n" - << " solved to " << solved.result << "\n" - << " instead of " << b << "\n"; -} - -void check_interval(const Expr &a, const Interval &i, bool outer) { - Interval result = - outer ? solve_for_outer_interval(a, "x") : solve_for_inner_interval(a, "x"); - result.min = simplify(result.min); - result.max = simplify(result.max); - internal_assert(equal(result.min, i.min) && equal(result.max, i.max)) - << "Expression " << a << " solved to the interval:\n" - << " min: " << result.min << "\n" - << " max: " << result.max << "\n" - << " instead of:\n" - << " min: " << i.min << "\n" - << " max: " << i.max << "\n"; -} - -void check_outer_interval(const Expr &a, const Expr &min, const Expr &max) { - check_interval(a, Interval(min, max), true); -} - -void check_inner_interval(const Expr &a, const Expr &min, const Expr &max) { - check_interval(a, Interval(min, max), false); -} - -void check_and_condition(const Expr &orig, const Expr &result, const Interval &i) { - Scope s; - s.push("x", i); - Expr cond = and_condition_over_domain(orig, s); - internal_assert(equal(cond, result)) - << "Expression " << orig - << " reduced to " << cond - << " instead of " << result << "\n"; -} -} // namespace - -void solve_test() { - using ConciseCasts::i16; - - Expr x = Variable::make(Int(32), "x"); - Expr y = Variable::make(Int(32), "y"); - Expr z = Variable::make(Int(32), "z"); - - // Check some simple cases - check_solve(3 - 4 * x, x * (-4) + 3); - check_solve(min(5, x), min(x, 5)); - check_solve(max(5, (5 + x) * y), max(x * y + 5 * y, 5)); - check_solve(5 * y + 3 * x == 2, ((x == ((2 - (5 * y)) / 3)) && (((2 - (5 * y)) % 3) == 0))); - check_solve(min(min(z, x), min(x, y)), min(x, min(y, z))); - check_solve(min(x + y, x + 5), x + min(y, 5)); - - // Check solver with expressions containing division - check_solve(x + (x * 2) / 2, x * 2); - check_solve(x + (x * 2 + y) / 2, x * 2 + (y / 2)); - check_solve(x + (x * 2 - y) / 2, x * 2 - (y / 2)); - check_solve(x + (-(x * 2) / 2), x * 0 + 0); - check_solve(x + (-(x * 2 + -3)) / 2, x * 0 + 1); - check_solve(x + (z - (x * 2 + -3)) / 2, x * 0 + (z - (-3)) / 2); - check_solve(x + (y * 16 + (z - (x * 2 + -1))) / 2, - (x * 0) + (((z - -1) + (y * 16)) / 2)); - - check_solve((x * 9 + 3) / 4 - x * 2, (x * 1 + 3) / 4); - check_solve((x * 9 + 3) / 4 + x * 2, (x * 17 + 3) / 4); - check_solve(x * 2 + (x * 9 + 3) / 4, (x * 17 + 3) / 4); - - // Check the solver doesn't perform transformations that change integer overflow behavior. - check_solve(i16(x + y) * i16(2) / i16(2), i16(x + y) * i16(2) / i16(2)); - - // A let statement - check_solve(Let::make("z", 3 + 5 * x, y + z < 8), - x <= (((8 - (3 + y)) - 1) / 5)); - - // A let statement where the variable gets used twice. - check_solve(Let::make("z", 3 + 5 * x, y + (z + z) < 8), - x <= (((8 - (6 + y)) - 1) / 10)); - - // Something where we expect a let in the output. - { - Expr e = y + 1; - for (int i = 0; i < 10; i++) { - e *= (e + 1); - } - SolverResult solved = solve_expression(x + e < e * e, "x"); - internal_assert(solved.fully_solved && solved.result.as()); - } - - // Solving inequalities for integers is a pain to get right with - // all the rounding rules. Check we didn't make a mistake with - // brute force. - for (int den = -3; den <= 3; den++) { - if (den == 0) { - continue; - } - for (int num = 5; num <= 10; num++) { - Expr in[] = { - {x * den < num}, - {x * den <= num}, - {x * den == num}, - {x * den != num}, - {x * den >= num}, - {x * den > num}, - {x / den < num}, - {x / den <= num}, - {x / den == num}, - {x / den != num}, - {x / den >= num}, - {x / den > num}, - }; - for (const auto &e : in) { - SolverResult solved = solve_expression(e, "x"); - internal_assert(solved.fully_solved) << "Error: failed to solve for x in " << e << "\n"; - Expr out = simplify(solved.result); - for (int i = -10; i < 10; i++) { - Expr in_val = substitute("x", i, e); - Expr out_val = substitute("x", i, out); - in_val = simplify(in_val); - out_val = simplify(out_val); - internal_assert(equal(in_val, out_val)) - << "Error: " - << e << " is not equivalent to " - << out << " when x == " << i << "\n"; - } - } - } - } - - // Check for combinatorial explosion - Expr e = x + y; - for (int i = 0; i < 20; i++) { - e += (e + 1) * y; - } - SolverResult solved = solve_expression(e, "x"); - internal_assert(solved.fully_solved && solved.result.defined()); - - // Check some things that we don't expect to work. - - // Quadratics: - internal_assert(!solve_expression(x * x < 4, "x").fully_solved); - - // Function calls, cast nodes, or multiplications by unknown sign - // don't get inverted, but the bit containing x still gets moved - // leftwards. - check_solve(4.0f > sqrt(x), sqrt(x) < 4.0f); - - check_solve(4 > y * x, x * y < 4); - - // Now test solving for an interval - check_inner_interval(x > 0, 1, Interval::pos_inf()); - check_inner_interval(x < 100, Interval::neg_inf(), 99); - check_outer_interval(x > 0 && x < 100, 1, 99); - check_inner_interval(x > 0 && x < 100, 1, 99); - - Expr c = Variable::make(Bool(), "c"); - check_outer_interval(Let::make("y", 0, x > y && x < 100), 1, 99); - check_outer_interval(Let::make("c", x > 0, c && x < 100), 1, 99); - - check_outer_interval((x >= 10 && x <= 90) && sin(x) > 0.5f, 10, 90); - check_inner_interval((x >= 10 && x <= 90) && sin(x) > 0.6f, Interval::pos_inf(), Interval::neg_inf()); - - check_inner_interval(x == 10, 10, 10); - check_outer_interval(x == 10, 10, 10); - - check_inner_interval(!(x != 10), 10, 10); - check_outer_interval(!(x != 10), 10, 10); - - check_inner_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); - check_outer_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); - - check_inner_interval(min(x, y) > 17, 18, y); - check_outer_interval(min(x, y) > 17, 18, Interval::pos_inf()); - - check_inner_interval(x / 5 < 17, Interval::neg_inf(), 84); - check_outer_interval(x / 5 < 17, Interval::neg_inf(), 84); - - // Test anding a condition over a domain - check_and_condition(x > 0, const_true(), Interval(1, y)); - check_and_condition(x > 0, const_true(), Interval(5, y)); - check_and_condition(x > 0, const_false(), Interval(-5, y)); - check_and_condition(x > 0 && x < 10, const_true(), Interval(1, 9)); - check_and_condition(x > 0 || sin(x) == 0.5f, const_true(), Interval(100, 200)); - - check_and_condition(x <= 0, const_true(), Interval(-100, 0)); - check_and_condition(x <= 0, const_false(), Interval(-100, 1)); - - check_and_condition(x <= 0 || y > 2, const_true(), Interval(-100, 0)); - check_and_condition(x > 0 || y > 2, 2 < y, Interval(-100, 0)); - - check_and_condition(x == 0, const_true(), Interval(0, 0)); - check_and_condition(x == 0, const_false(), Interval(-10, 10)); - check_and_condition(x != 0, const_false(), Interval(-10, 10)); - check_and_condition(x != 0, const_true(), Interval(-20, -10)); - - check_and_condition(y == 0, y == 0, Interval(-10, 10)); - check_and_condition(y != 0, y != 0, Interval(-10, 10)); - check_and_condition((x == 5) && (y != 0), const_false(), Interval(-10, 10)); - check_and_condition((x == 5) && (y != 3), y != 3, Interval(5, 5)); - check_and_condition((x != 0) && (y != 0), const_false(), Interval(-10, 10)); - check_and_condition((x != 0) && (y != 0), y != 0, Interval(-20, -10)); - - { - // This case used to break due to signed integer overflow in - // the simplifier. - Expr a16 = Load::make(Int(16), "a", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); - Expr b16 = Load::make(Int(16), "b", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); - Expr lhs = pow(cast(a16), 2) + pow(cast(b16), 2); - - Scope s; - s.push("x", Interval(-10, 10)); - Expr cond = and_condition_over_domain(lhs < 0, s); - internal_assert(!is_const_one(simplify(cond))); - } - - { - // This cause use to cause infinite recursion: - Expr t = Variable::make(Int(32), "t"); - Expr test = (x <= min(max((y - min(((z * x) + t), t)), 1), 0)); - Interval result = solve_for_outer_interval(test, "z"); - } - - { - // This case caused exponential behavior - Expr t = Variable::make(Int(32), "t"); - for (int i = 0; i < 50; i++) { - t = min(t, Variable::make(Int(32), unique_name('v'))); - t = max(t, Variable::make(Int(32), unique_name('v'))); - } - solve_for_outer_interval(t <= 5, "t"); - solve_for_inner_interval(t <= 5, "t"); - } - - // Check for partial results - check_solve(max(min(y, x), x), max(min(x, y), x)); - check_solve(min(y, x) + max(y, 2 * x), min(x, y) + max(x * 2, y)); - check_solve((min(x, y) + min(y, x)) * max(y, x), (min(x, y) * 2) * max(x, y)); - check_solve(max((min((y * x), x) + min((1 + y), x)), (y + 2 * x)), - max((min((x * y), x) + min(x, (1 + y))), (x * 2 + y))); - - { - Expr x = Variable::make(UInt(32), "x"); - Expr y = Variable::make(UInt(32), "y"); - Expr z = Variable::make(UInt(32), "z"); - check_solve(5 - (4 - 4 * x), x * (4) + 1); - check_solve(z - (y - x), x + (z - y)); - check_solve(z - (y - x) == 2, x == 2 - (z - y)); - - check_solve(x - (x - y), (x - x) + y); - - // This is used to cause infinite recursion - Expr expr = Add::make(z, Sub::make(x, y)); - SolverResult solved = solve_expression(expr, "y"); - } - - // This case was incorrect due to canonicalization of the multiply - // occurring after unpacking the LHS. - check_solve((y - z) * x, x * (y - z)); - - // These cases were incorrectly not flipping min/max when moving - // it out of the RHS of a subtract. - check_solve(min(x - y, x - z), x - max(y, z)); - check_solve(min(x - y, x), x - max(y, 0)); - check_solve(min(x, x - y), x - max(y, 0)); - check_solve(max(x - y, x - z), x - min(y, z)); - check_solve(max(x - y, x), x - min(y, 0)); - check_solve(max(x, x - y), x - min(y, 0)); - - // Check mixed add/sub - check_solve(min(x - y, x + z), x + min(0 - y, z)); - check_solve(max(x - y, x + z), x + max(0 - y, z)); - check_solve(min(x + y, x - z), x + min(y, 0 - z)); - check_solve(max(x + y, x - z), x + max(y, 0 - z)); - - check_solve((5 * Broadcast::make(x, 4) + y) / 5, - Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5)); - - // Select negates the condition to move x leftward - check_solve(select(y < z, z, x), - select(z <= y, x, z)); - - // Select negates the condition and then mutates it, moving x - // leftward (despite the simplifier preferring < to >). - check_solve(select(x < 10, 10, x), - select(x >= 10, x, 10)); - - std::cout << "Solve test passed\n"; -} - } // namespace Internal } // namespace Halide diff --git a/src/Solve.h b/src/Solve.h index 4d06fda47d6b..8a8cd1164c75 100644 --- a/src/Solve.h +++ b/src/Solve.h @@ -54,8 +54,6 @@ Expr and_condition_over_domain(const Expr &c, const Scope &varying); * provide a better response than simply const_true(). */ Expr or_condition_over_domain(const Expr &c, const Scope &varying); -void solve_test(); - } // namespace Internal } // namespace Halide diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index a25b077a1abb..5ab072a96188 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -293,6 +293,7 @@ tests(GROUPS correctness sliding_over_guard_with_if.cpp sliding_reduction.cpp sliding_window.cpp + solve.cpp sort_exprs.cpp specialize.cpp specialize_to_gpu.cpp diff --git a/test/correctness/solve.cpp b/test/correctness/solve.cpp new file mode 100644 index 000000000000..f95f9c438518 --- /dev/null +++ b/test/correctness/solve.cpp @@ -0,0 +1,647 @@ +#include "Halide.h" + +#include +#include +#include +#include + +using namespace Halide; +using namespace Halide::Internal; + +namespace { + +// Assert that solve_expression produces exactly the given expected expression. +void check_solve(const Expr &in, const Expr &expected) { + SolverResult solved = solve_expression(in, "x"); + if (!equal(solved.result, expected)) { + std::cerr << "solve_expression produced unexpected result:\n" + << " input: " << in << "\n" + << " expected: " << expected << "\n" + << " actual: " << solved.result << "\n"; + std::abort(); + } +} + +void check_interval(const Expr &a, const Interval &i, bool outer) { + Interval result = + outer ? solve_for_outer_interval(a, "x") : solve_for_inner_interval(a, "x"); + result.min = simplify(result.min); + result.max = simplify(result.max); + if (!equal(result.min, i.min) || !equal(result.max, i.max)) { + std::cerr << "Expression " << a << " solved to the interval:\n" + << " min: " << result.min << "\n" + << " max: " << result.max << "\n" + << " instead of:\n" + << " min: " << i.min << "\n" + << " max: " << i.max << "\n"; + std::abort(); + } +} + +void check_outer_interval(const Expr &a, const Expr &min, const Expr &max) { + check_interval(a, Interval(min, max), true); +} + +void check_inner_interval(const Expr &a, const Expr &min, const Expr &max) { + check_interval(a, Interval(min, max), false); +} + +void check_and_condition(const Expr &orig, const Expr &result, const Interval &i) { + Scope s; + s.push("x", i); + Expr cond = and_condition_over_domain(orig, s); + if (!equal(cond, result)) { + std::cerr << "Expression " << orig + << " reduced to " << cond + << " instead of " << result << "\n"; + std::abort(); + } +} + +// Assert that solve_expression produces a result that is semantically +// equivalent to the input under the given substitution. This is used for +// cases where we care about preserved meaning, not exact syntactic form. +void check_solve_equivalent(const Expr &in, const std::map &vars) { + SolverResult solved = solve_expression(in, "x"); + Expr in_v = simplify(substitute(vars, in)); + Expr out_v = simplify(substitute(vars, solved.result)); + if (!equal(in_v, out_v)) { + std::cerr << "solve_expression changed value under substitution:\n" + << " input: " << in << "\n" + << " solved: " << solved.result << "\n"; + for (const auto &[name, val] : vars) { + std::cerr << " " << name << " = " << val << "\n"; + } + std::cerr << " input evaluated: " << in_v << "\n" + << " solved evaluated: " << out_v << "\n"; + std::abort(); + } +} + +// Bug #1: the solver was rewriting `f(x) + b @ c` to `f(x) @ c - b` for +// every comparison @, but for unsigned types the subtraction wraps, which +// does not preserve the *ordering* comparisons LT/LE/GT/GE (the EQ/NE +// rewrite is still valid under modular arithmetic, so those stay). +void test_unsigned_ordering_not_rearranged() { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + + // A concrete substitution that demonstrates the wrap: with x = 4 and + // y = (uint32_t)-14 = 4294967282, x + y = 4294967286, so + // `x + y < 1641646169` is false. The buggy rewrite `x < 1641646169 - y` + // underflows 1641646169 - 4294967282 to 1641646183, making it true. + std::map vars{ + {"x", UIntImm::make(UInt(32), 4)}, + {"y", UIntImm::make(UInt(32), 4294967282u)}, + }; + + Expr c = UIntImm::make(UInt(32), 1641646169u); + check_solve_equivalent(x + y < c, vars); + check_solve_equivalent(x + y <= c, vars); + check_solve_equivalent(x + y > c, vars); + check_solve_equivalent(x + y >= c, vars); + + // The symmetric subtraction form must be preserved too. + check_solve_equivalent(x - y < c, vars); + check_solve_equivalent(x - y <= c, vars); + check_solve_equivalent(x - y > c, vars); + check_solve_equivalent(x - y >= c, vars); +} + +// Bug #1 corollary: EQ/NE rewrites are still safe under modular arithmetic +// (modular equivalence preserves equality), so these should continue to be +// rewritten to isolate x on the left. +void test_unsigned_equality_still_rearranged() { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + Expr c = UIntImm::make(UInt(32), 2u); + + // `x + y == c` should solve to `x == c - y`, matching existing tests + // in src/Solve.cpp's solve_test() for unsigned rewrites. + check_solve(x + y == c, x == (c - y)); + check_solve(x + y != c, x != (c - y)); +} + +// Bug #2: the solver was rewriting `f(x) * y @ b` to forms involving `b / y` +// and `b % y` even when `y` was a non-constant expression. When `y` evaluates +// to zero the rewrite changes the expression's value even though Halide +// defines div/mod-by-zero to return zero -- `a * 0 == b` becomes `a == b/0 && +// b%0 == 0` which collapses to `a == 0 && b == 0`, losing the original +// "always false when b != 0" semantics. +void test_nonconstant_multiplier_not_rewritten() { + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + + // At y = 0, `x * y == 1` is the well-defined `0 == 1 == false`. + // The buggy rewrite `x == 1/y && 1%y == 0` evaluates to + // `x == 0 && true == true`, which is true at x = 0 -- changing the + // value of the expression. + std::map vars_zero{ + {"x", Expr(7)}, + {"y", Expr(0)}, + }; + check_solve_equivalent(x * y == 1, vars_zero); + check_solve_equivalent(x * y != 1, vars_zero); + + // Non-zero y: must still be semantically preserved. + std::map vars_nonzero{ + {"x", Expr(7)}, + {"y", Expr(3)}, + }; + check_solve_equivalent(x * y == 1, vars_nonzero); + check_solve_equivalent(x * y != 1, vars_nonzero); +} + +// The guarded form of the Mul rewrite -- a positive constant multiplier -- +// must continue to work after the fix. visit(Div) constant-folds when both +// operands are const, so `Div::make(7, 3)` reduces to 2 during mutation; +// there is no analogous fold for Mod so the Mod node stays. +void test_positive_const_multiplier_still_rewritten() { + Expr x = Variable::make(Int(32), "x"); + Expr seven = Expr(7); + Expr three = Expr(3); + check_solve(3 * x == 7, + (x == 2) && (Mod::make(seven, three) == 0)); + check_solve(3 * x != 7, + (x != 2) || (Mod::make(seven, three) != 0)); +} + +// Solver used to rewrite `f(x) + f(x) -> f(x) * 2` via `operator*(Expr, int)`, +// which rejects constants that don't fit in the expression type. For UInt(1), +// the literal 2 isn't representable, aborting the whole solve. Use Mul::make +// directly with make_const (which truncates modulo width) so the rewrite +// applies soundly for every integer type -- for UInt(1), `a * 2` correctly +// becomes `a * 0`, matching the modular value of `a + a`. +void test_solve_does_not_abort_on_narrow_self_add() { + Expr x = Variable::make(UInt(1), "x"); + // This used to abort with + // "Integer constant 2 will be implicitly coerced to type uint1..." + SolverResult s = solve_expression(x + x, "x"); + // The actual rewritten form is unimportant here -- the test just locks + // in that solve_expression doesn't abort on this shape. + if (!s.result.defined()) { + std::cerr << "solve_expression returned undefined on `x + x` (UInt(1))\n"; + std::abort(); + } +} + +// Solver's `f(x)/a + g(x) -> (f(x) + g(x) * a) / a` rewrite is only valid +// under non-wrapping arithmetic: modularly, g(x)*a can overflow the width +// and the rewrite changes the computed value. Guard it on no_overflow_int. +void test_narrow_div_add_equivalence() { + // Reproduced from the fuzzer (seed 9414558261169807111, minimized): + // `(uint8(a4)/137) + uint8(a4)` at a4=-13 (uint8 243) is + // 243/137 + 243 = 1 + 243 = 244 (uint8, no wrap). + // The previous rewrite would convert this to + // (uint8(a4) * 138) / 137 + // which at uint8 243 gives (243*138 mod 256)/137 = 254/137 = 1. + Expr a = Variable::make(Int(32), "a"); + Expr u = Cast::make(UInt(8), a); + Expr input = u / UIntImm::make(UInt(8), 137) + u; + SolverResult s = solve_expression(input, "a"); + // Verify by concrete substitution: the solved expression must evaluate + // to the same value as the input at a = -13. + std::map subst{{"a", Expr(-13)}}; + Expr in_v = simplify(substitute(subst, input)); + Expr out_v = simplify(substitute(subst, s.result)); + if (!equal(in_v, out_v)) { + std::cerr << "solve_expression changed value on narrow div+add:\n" + << " input: " << input << " -> " << in_v << "\n" + << " solved: " << s.result << " -> " << out_v << "\n"; + std::abort(); + } +} + +// bounds_of_expr_in_scope for `float % float` was applying integer-mod +// semantics and claiming the result is always in `[0, max(|b|, -b.min)]`. +// Halide mod is defined to be non-negative (Euclidean), but the current +// float-mod lowering in CodeGen_LLVM is `a - b * floor(a/b)`, which does +// not always produce a non-negative result when b is negative -- e.g. +// fmod(5, -3) lowers to 5 - (-3)*floor(5/-3) = 5 - (-3)*(-2) = -1. The +// fuzzer exercises this via a concrete substitution and compares the +// rewritten form's simplified value against bounds_of's claim, and the +// two disagree. Until the lowering is fixed to enforce Euclidean mod, +// bounds_of falls back to unbounded for float mod. When the lowering +// is fixed, this test and the float-mod branch in Bounds.cpp can both +// go away. +void test_bounds_of_float_mod_matches_lowering() { + Expr a = Variable::make(Float(64), "a"); + Expr b = Variable::make(Float(64), "b"); + Expr e = Mod::make(a, b); + + Scope scope; + scope.push("a", Interval(Expr(-10.0), Expr(-3.0))); + scope.push("b", Interval(Expr(-9.0), Expr(4.0))); + + Interval bounds = bounds_of_expr_in_scope(e, scope); + // Under the current lowering the result can be negative, so + // bounds.min must not be provably non-negative. + if (bounds.has_lower_bound()) { + Expr provably_nonneg = simplify(bounds.min >= Expr(0.0)); + if (is_const_one(provably_nonneg)) { + std::cerr << "bounds_of float % float claims result is non-negative, " + << "but the current lowering can produce negative results:\n" + << " expr: " << e << "\n" + << " bounds.min: " << simplify(bounds.min) << "\n" + << " bounds.max: " << simplify(bounds.max) << "\n"; + std::abort(); + } + } +} + +// Simplify_Cast was applying a cast-chain simplification +// int32(uint64(X)) -> int32(X) +// whenever widths and the two outer types all lined up for the +// "sign-extend then truncate" shape. The rule's correctness depends on +// the *inner* cast actually being a sign/zero extend, which only holds +// when its source is an integer. For `int32(uint64(float64(a)))` the +// inner cast is an fp-to-uint conversion, which has entirely different +// semantics -- so the stripped form `int32(float64(a))` evaluates to a +// different value (fp-to-int vs fp-to-uint-then-truncate). +// bounds_of was taking a shortcut for `intN(float_expr)` casts (N >= 32) that +// assumed the float-to-signed-int cast "truncates in place and preserves +// interval orientation", so it carried the source float bounds through the +// cast unchanged. But the simplifier folds out-of-range float constants by +// wrapping (IntImm::make sign-extends the low 32 bits), which can invert the +// resulting int interval -- e.g. floats {22e9, 24e9} both wrap, and the two +// wrapped values happen to land on opposite sides of zero. The shortcut then +// produced an inverted (empty) interval, which downstream reasoning treated +// as a vacuously-satisfied constraint. +void test_bounds_of_float_to_int_cast_is_sound() { + Expr a = Variable::make(Int(32), "a"); + Expr f = -1712582016.0f * cast(a); + Expr e = cast(f); + + Scope scope; + scope.push("a", Interval(cast(-14), cast(-13))); + + Interval bounds = bounds_of_expr_in_scope(e, scope); + + // Concrete evaluations at each endpoint. + std::map sub_min{{"a", Expr(-14)}}; + std::map sub_max{{"a", Expr(-13)}}; + Expr v_min = simplify(substitute(sub_min, e)); + Expr v_max = simplify(substitute(sub_max, e)); + + // bounds_of must produce an interval that contains both observed values. + // Either the interval is unbounded on that side, or the bound must + // provably hold. + auto must_hold = [&](const Expr &claim, const char *msg) { + Expr simplified = simplify(claim); + if (!is_const_one(simplified)) { + std::cerr << "bounds_of int32(float) unsound: " << msg << "\n" + << " expr: " << e << "\n" + << " bounds.min: " << simplify(bounds.min) << "\n" + << " bounds.max: " << simplify(bounds.max) << "\n" + << " v(a=-14): " << v_min << "\n" + << " v(a=-13): " << v_max << "\n"; + std::abort(); + } + }; + + if (bounds.has_lower_bound()) { + must_hold(bounds.min <= v_min, "v(a=-14) below bounds.min"); + must_hold(bounds.min <= v_max, "v(a=-13) below bounds.min"); + } + if (bounds.has_upper_bound()) { + must_hold(bounds.max >= v_min, "v(a=-14) above bounds.max"); + must_hold(bounds.max >= v_max, "v(a=-13) above bounds.max"); + } +} + +// Regression guard for the precision side of the above fix. Carrying +// *symbolic* (non-constant) float bounds through an int32+ cast is load- +// bearing for bilateral_grid and similar pipelines: bounds often end up +// as `select(r_sigma > 0, 0/r_sigma, 1/r_sigma) + 0.5` etc., where we +// can't prove the range fits at bounds-inference time but the programmer +// has opted into the convention that float-to-int out-of-range is +// undefined. The soundness fix above must only reject *provably* +// out-of-range constant bounds, not symbolic ones. +void test_bounds_of_float_to_int_preserves_symbolic_bounds() { + Expr rs = Variable::make(Float(32), "r_sigma"); + Expr val = Variable::make(Float(32), "val"); + Expr e = cast(val / rs + 0.5f); + + Scope scope; + scope.push("val", Interval(Expr(0.0f), Expr(1.0f))); + // r_sigma has a bounded positive estimate (what apply_param_estimates + // or a scope binding effectively produces after frontend lowering). + scope.push("r_sigma", Interval(Expr(0.1f), Expr(1.0f))); + + Interval bounds = bounds_of_expr_in_scope(e, scope); + if (!bounds.is_bounded()) { + std::cerr << "bounds_of int32(float_expr) dropped to unbounded " + << "for a case that should carry through:\n" + << " expr: " << e << "\n" + << " bounds.min: " << (bounds.has_lower_bound() ? simplify(bounds.min) : Expr("neg_inf")) << "\n" + << " bounds.max: " << (bounds.has_upper_bound() ? simplify(bounds.max) : Expr("pos_inf")) << "\n"; + std::abort(); + } +} + +void test_simplify_preserves_float_to_uint_cast_chain() { + Expr a = Variable::make(Int(32), "a"); + Expr chained = Cast::make(Int(32), + Cast::make(UInt(64), + Cast::make(Float(64), a))); + Expr simplified = simplify(chained); + + // At a = -21, the two forms must agree. + std::map subst{{"a", Expr(-21)}}; + Expr v1 = simplify(substitute(subst, chained)); + Expr v2 = simplify(substitute(subst, simplified)); + if (!equal(v1, v2)) { + std::cerr << "simplify changed the value of a cast chain:\n" + << " original: " << chained << " -> " << v1 << "\n" + << " simplified: " << simplified << " -> " << v2 << "\n"; + std::abort(); + } +} + +// Previously lived as `solve_test()` at the bottom of src/Solve.cpp and +// was invoked from test/internal.cpp. Moved here so all solver tests are +// in one place. +void test_original_solve_test_cases() { + using ConciseCasts::i16; + + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + Expr z = Variable::make(Int(32), "z"); + + // Check some simple cases + check_solve(3 - 4 * x, x * (-4) + 3); + check_solve(min(5, x), min(x, 5)); + check_solve(max(5, (5 + x) * y), max(x * y + 5 * y, 5)); + check_solve(5 * y + 3 * x == 2, ((x == ((2 - (5 * y)) / 3)) && (((2 - (5 * y)) % 3) == 0))); + check_solve(min(min(z, x), min(x, y)), min(x, min(y, z))); + check_solve(min(x + y, x + 5), x + min(y, 5)); + + // Check solver with expressions containing division + check_solve(x + (x * 2) / 2, x * 2); + check_solve(x + (x * 2 + y) / 2, x * 2 + (y / 2)); + check_solve(x + (x * 2 - y) / 2, x * 2 - (y / 2)); + check_solve(x + (-(x * 2) / 2), x * 0 + 0); + check_solve(x + (-(x * 2 + -3)) / 2, x * 0 + 1); + check_solve(x + (z - (x * 2 + -3)) / 2, x * 0 + (z - (-3)) / 2); + check_solve(x + (y * 16 + (z - (x * 2 + -1))) / 2, + (x * 0) + (((z - -1) + (y * 16)) / 2)); + + check_solve((x * 9 + 3) / 4 - x * 2, (x * 1 + 3) / 4); + check_solve((x * 9 + 3) / 4 + x * 2, (x * 17 + 3) / 4); + check_solve(x * 2 + (x * 9 + 3) / 4, (x * 17 + 3) / 4); + + // Check the solver doesn't perform transformations that change integer overflow behavior. + check_solve(i16(x + y) * i16(2) / i16(2), i16(x + y) * i16(2) / i16(2)); + + // A let statement + check_solve(Let::make("z", 3 + 5 * x, y + z < 8), + x <= (((8 - (3 + y)) - 1) / 5)); + + // A let statement where the variable gets used twice. + check_solve(Let::make("z", 3 + 5 * x, y + (z + z) < 8), + x <= (((8 - (6 + y)) - 1) / 10)); + + // Something where we expect a let in the output. + { + Expr e = y + 1; + for (int i = 0; i < 10; i++) { + e *= (e + 1); + } + SolverResult solved = solve_expression(x + e < e * e, "x"); + if (!(solved.fully_solved && solved.result.as())) { + std::cerr << "Expected fully-solved Let-bearing result\n"; + std::abort(); + } + } + + // Solving inequalities for integers is a pain to get right with + // all the rounding rules. Check we didn't make a mistake with + // brute force. + for (int den = -3; den <= 3; den++) { + if (den == 0) { + continue; + } + for (int num = 5; num <= 10; num++) { + Expr in[] = { + {x * den < num}, + {x * den <= num}, + {x * den == num}, + {x * den != num}, + {x * den >= num}, + {x * den > num}, + {x / den < num}, + {x / den <= num}, + {x / den == num}, + {x / den != num}, + {x / den >= num}, + {x / den > num}, + }; + for (const auto &e : in) { + SolverResult solved = solve_expression(e, "x"); + if (!solved.fully_solved) { + std::cerr << "Error: failed to solve for x in " << e << "\n"; + std::abort(); + } + Expr out = simplify(solved.result); + for (int i = -10; i < 10; i++) { + Expr in_val = substitute("x", i, e); + Expr out_val = substitute("x", i, out); + in_val = simplify(in_val); + out_val = simplify(out_val); + if (!equal(in_val, out_val)) { + std::cerr << "Error: " + << e << " is not equivalent to " + << out << " when x == " << i << "\n"; + std::abort(); + } + } + } + } + } + + // Check for combinatorial explosion + { + Expr e = x + y; + for (int i = 0; i < 20; i++) { + e += (e + 1) * y; + } + SolverResult solved = solve_expression(e, "x"); + if (!(solved.fully_solved && solved.result.defined())) { + std::cerr << "Expected fully-solved defined result for combinatorial case\n"; + std::abort(); + } + } + + // Check some things that we don't expect to work. + + // Quadratics: + if (solve_expression(x * x < 4, "x").fully_solved) { + std::cerr << "Expected quadratic to not be fully solved\n"; + std::abort(); + } + + // Function calls, cast nodes, or multiplications by unknown sign + // don't get inverted, but the bit containing x still gets moved + // leftwards. + check_solve(4.0f > sqrt(x), sqrt(x) < 4.0f); + + check_solve(4 > y * x, x * y < 4); + + // Now test solving for an interval + check_inner_interval(x > 0, 1, Interval::pos_inf()); + check_inner_interval(x < 100, Interval::neg_inf(), 99); + check_outer_interval(x > 0 && x < 100, 1, 99); + check_inner_interval(x > 0 && x < 100, 1, 99); + + Expr c = Variable::make(Bool(), "c"); + check_outer_interval(Let::make("y", 0, x > y && x < 100), 1, 99); + check_outer_interval(Let::make("c", x > 0, c && x < 100), 1, 99); + + check_outer_interval((x >= 10 && x <= 90) && sin(x) > 0.5f, 10, 90); + check_inner_interval((x >= 10 && x <= 90) && sin(x) > 0.6f, Interval::pos_inf(), Interval::neg_inf()); + + check_inner_interval(x == 10, 10, 10); + check_outer_interval(x == 10, 10, 10); + + check_inner_interval(!(x != 10), 10, 10); + check_outer_interval(!(x != 10), 10, 10); + + check_inner_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); + check_outer_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); + + check_inner_interval(min(x, y) > 17, 18, y); + check_outer_interval(min(x, y) > 17, 18, Interval::pos_inf()); + + check_inner_interval(x / 5 < 17, Interval::neg_inf(), 84); + check_outer_interval(x / 5 < 17, Interval::neg_inf(), 84); + + // Test anding a condition over a domain + check_and_condition(x > 0, const_true(), Interval(1, y)); + check_and_condition(x > 0, const_true(), Interval(5, y)); + check_and_condition(x > 0, const_false(), Interval(-5, y)); + check_and_condition(x > 0 && x < 10, const_true(), Interval(1, 9)); + check_and_condition(x > 0 || sin(x) == 0.5f, const_true(), Interval(100, 200)); + + check_and_condition(x <= 0, const_true(), Interval(-100, 0)); + check_and_condition(x <= 0, const_false(), Interval(-100, 1)); + + check_and_condition(x <= 0 || y > 2, const_true(), Interval(-100, 0)); + check_and_condition(x > 0 || y > 2, 2 < y, Interval(-100, 0)); + + check_and_condition(x == 0, const_true(), Interval(0, 0)); + check_and_condition(x == 0, const_false(), Interval(-10, 10)); + check_and_condition(x != 0, const_false(), Interval(-10, 10)); + check_and_condition(x != 0, const_true(), Interval(-20, -10)); + + check_and_condition(y == 0, y == 0, Interval(-10, 10)); + check_and_condition(y != 0, y != 0, Interval(-10, 10)); + check_and_condition((x == 5) && (y != 0), const_false(), Interval(-10, 10)); + check_and_condition((x == 5) && (y != 3), y != 3, Interval(5, 5)); + check_and_condition((x != 0) && (y != 0), const_false(), Interval(-10, 10)); + check_and_condition((x != 0) && (y != 0), y != 0, Interval(-20, -10)); + + { + // This case used to break due to signed integer overflow in + // the simplifier. + Expr a16 = Load::make(Int(16), "a", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); + Expr b16 = Load::make(Int(16), "b", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); + Expr lhs = pow(cast(a16), 2) + pow(cast(b16), 2); + + Scope s; + s.push("x", Interval(-10, 10)); + Expr cond = and_condition_over_domain(lhs < 0, s); + if (is_const_one(simplify(cond))) { + std::cerr << "Expected cond to not simplify to const_one\n"; + std::abort(); + } + } + + { + // This cause use to cause infinite recursion: + Expr t = Variable::make(Int(32), "t"); + Expr test = (x <= min(max((y - min(((z * x) + t), t)), 1), 0)); + Interval result = solve_for_outer_interval(test, "z"); + } + + { + // This case caused exponential behavior + Expr t = Variable::make(Int(32), "t"); + for (int i = 0; i < 50; i++) { + t = min(t, Variable::make(Int(32), unique_name('v'))); + t = max(t, Variable::make(Int(32), unique_name('v'))); + } + solve_for_outer_interval(t <= 5, "t"); + solve_for_inner_interval(t <= 5, "t"); + } + + // Check for partial results + check_solve(max(min(y, x), x), max(min(x, y), x)); + check_solve(min(y, x) + max(y, 2 * x), min(x, y) + max(x * 2, y)); + check_solve((min(x, y) + min(y, x)) * max(y, x), (min(x, y) * 2) * max(x, y)); + check_solve(max((min((y * x), x) + min((1 + y), x)), (y + 2 * x)), + max((min((x * y), x) + min(x, (1 + y))), (x * 2 + y))); + + { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + Expr z = Variable::make(UInt(32), "z"); + check_solve(5 - (4 - 4 * x), x * (4) + 1); + check_solve(z - (y - x), x + (z - y)); + check_solve(z - (y - x) == 2, x == 2 - (z - y)); + + check_solve(x - (x - y), (x - x) + y); + + // This is used to cause infinite recursion + Expr expr = Add::make(z, Sub::make(x, y)); + SolverResult solved = solve_expression(expr, "y"); + } + + // This case was incorrect due to canonicalization of the multiply + // occurring after unpacking the LHS. + check_solve((y - z) * x, x * (y - z)); + + // These cases were incorrectly not flipping min/max when moving + // it out of the RHS of a subtract. + check_solve(min(x - y, x - z), x - max(y, z)); + check_solve(min(x - y, x), x - max(y, 0)); + check_solve(min(x, x - y), x - max(y, 0)); + check_solve(max(x - y, x - z), x - min(y, z)); + check_solve(max(x - y, x), x - min(y, 0)); + check_solve(max(x, x - y), x - min(y, 0)); + + // Check mixed add/sub + check_solve(min(x - y, x + z), x + min(0 - y, z)); + check_solve(max(x - y, x + z), x + max(0 - y, z)); + check_solve(min(x + y, x - z), x + min(y, 0 - z)); + check_solve(max(x + y, x - z), x + max(y, 0 - z)); + + check_solve((5 * Broadcast::make(x, 4) + y) / 5, + Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5)); + + // Select negates the condition to move x leftward + check_solve(select(y < z, z, x), + select(z <= y, x, z)); + + // Select negates the condition and then mutates it, moving x + // leftward (despite the simplifier preferring < to >). + check_solve(select(x < 10, 10, x), + select(x >= 10, x, 10)); +} + +} // namespace + +int main(int argc, char **argv) { + test_original_solve_test_cases(); + test_unsigned_ordering_not_rearranged(); + test_unsigned_equality_still_rearranged(); + test_nonconstant_multiplier_not_rewritten(); + test_positive_const_multiplier_still_rewritten(); + test_solve_does_not_abort_on_narrow_self_add(); + test_narrow_div_add_equivalence(); + test_bounds_of_float_mod_matches_lowering(); + test_bounds_of_float_to_int_cast_is_sound(); + test_bounds_of_float_to_int_preserves_symbolic_bounds(); + test_simplify_preserves_float_to_uint_cast_chain(); + std::printf("Success!\n"); + return 0; +} diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index af4ba4fad84d..95716c2c134c 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -11,6 +11,7 @@ tests(GROUPS fuzz cse.cpp lossless_cast.cpp simplify.cpp + solve.cpp widening_lerp.cpp # By default, the libfuzzer harness runs with a timeout of 1200 seconds. # Let's dial that back: diff --git a/test/fuzz/random_expr_generator.h b/test/fuzz/random_expr_generator.h index cc196ba76f5a..a87bf7c7205f 100644 --- a/test/fuzz/random_expr_generator.h +++ b/test/fuzz/random_expr_generator.h @@ -200,8 +200,8 @@ class RandomExpressionGenerator { return fuzz.PickValueInArray(make_bin_op)(a, b); }); } - if (gen_bitwise) { - // Bitwise + if (gen_bitwise && !t.is_float()) { + // Bitwise -- not valid on float types, so skip when t is float. ops.push_back([&] { static make_bin_op_fn make_bin_op[] = { make_bitwise_or, diff --git a/test/fuzz/solve.cpp b/test/fuzz/solve.cpp new file mode 100644 index 000000000000..0e15db202d37 --- /dev/null +++ b/test/fuzz/solve.cpp @@ -0,0 +1,524 @@ +#include "Halide.h" +#include + +#include "IRGraphCXXPrinter.h" +#include "fuzz_helpers.h" +#include "random_expr_generator.h" + +// Test the solver in Halide by generating random expressions and verifying that +// solve_expression, solve_for_inner_interval, solve_for_outer_interval, and +// and_condition_over_domain satisfy their respective contracts under random +// concrete substitutions. +namespace { + +using std::map; +using std::string; +using namespace Halide; +using namespace Halide::Internal; + +// Wrap a call that may throw InternalError in an std::variant so callers can +// report the failure with context rather than aborting the whole fuzzer. +template +struct SafeResult : std::variant { + using std::variant::variant; + bool ok() const { + return this->index() == 0; + } + bool failed() const { + return this->index() == 1; + } + const T &value() const { + return std::get(*this); + } +}; + +SafeResult safe_simplify(const Expr &e) { + try { + return simplify(e); + } catch (InternalError &err) { + std::cerr << "simplify threw on:\n" + << e << "\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_expression(const Expr &e, const string &var) { + try { + return solve_expression(e, var); + } catch (InternalError &err) { + std::cerr << "solve_expression threw on:\n" + << e << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_for_inner_interval(const Expr &c, const string &var) { + try { + return solve_for_inner_interval(c, var); + } catch (InternalError &err) { + std::cerr << "solve_for_inner_interval threw on:\n" + << c << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_for_outer_interval(const Expr &c, const string &var) { + try { + return solve_for_outer_interval(c, var); + } catch (InternalError &err) { + std::cerr << "solve_for_outer_interval threw on:\n" + << c << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_and_condition_over_domain(const Expr &c, const Scope &scope) { + try { + return and_condition_over_domain(c, scope); + } catch (InternalError &err) { + std::cerr << "and_condition_over_domain threw on:\n" + << c << "\n" + << err.what() << "\n"; + return err; + } +} + +Expr random_int_val(FuzzingContext &fuzz, int lo, int hi) { + return cast(Int(32), fuzz.ConsumeIntegralInRange(lo, hi)); +} + +// Returns true if the expression, under the given substitution, contains a +// division or modulo whose divisor simplifies to zero. Halide defines +// div/mod-by-zero to return zero, but the simplifier doesn't always fold +// that consistently across syntactically-different forms -- so solve can +// rearrange an expression into an equivalent shape whose simplified value +// at a concrete substitution differs only because one side gets the +// "returns zero" fold applied while the other doesn't. Skip those samples +// when checking equivalence. Solve often emits Let bindings, so inline +// them first (otherwise Div::b is a variable reference and we can't see +// whether it's zero). +bool has_div_or_mod_by_zero(const Expr &e, const map &vars) { + Expr inlined = substitute_in_all_lets(e); + bool found = false; + auto check_denom = [&](const Expr &denom) { + if (found) return; + if (SafeResult r = safe_simplify(substitute(vars, denom)); r.ok()) { + if (Internal::is_const_zero(r.value())) { + found = true; + } + } + }; + visit_with( + inlined, + [&](auto *self, const Div *op) { + check_denom(op->b); + self->visit_base(op); + }, + [&](auto *self, const Mod *op) { + check_denom(op->b); + self->visit_base(op); + }); + return found; +} + +// Returns true if the expression, under the given substitution, contains a +// narrowing cast whose source value doesn't fit in the destination type. +// Halide's bounds analysis assumes such casts don't overflow (see PR #7814 +// discussion) -- that's a programmer-level contract that the fuzzer's +// random value substitutions can easily violate, and the resulting +// runtime wrap then disagrees with bounds_of's "assumed-fits" prediction. +// Skip those samples when checking contracts that rely on bounds_of. +bool has_overflowing_cast(const Expr &e, const map &vars) { + Expr inlined = substitute_in_all_lets(e); + bool found = false; + auto check_cast = [&](const Cast *op) { + if (found) return; + Type to = op->type; + Type from = op->value.type(); + // Only care about casts between integer/unsigned types that could + // overflow the destination. + if (!(to.is_int_or_uint() && from.is_int_or_uint())) return; + if (to.can_represent(from)) return; + SafeResult r = safe_simplify(substitute(vars, op->value)); + if (!r.ok()) return; + if (auto iv = as_const_int(r.value())) { + if (!to.can_represent(*iv)) found = true; + } else if (auto uv = as_const_uint(r.value())) { + if (!to.can_represent(*uv)) found = true; + } + }; + visit_with( + inlined, + [&](auto *self, const Cast *op) { + check_cast(op); + self->visit_base(op); + }); + return found; +} + +// Test that solve_expression(test, var) produces an expression equivalent to +// `test` under random concrete substitutions of all variables. Modeled after +// the brute-force check at the bottom of Solve.cpp's solve_test(). +bool test_solve_expression_equivalence(RandomExpressionGenerator ®, + const Expr &test, + const string &var, + int samples) { + SafeResult res = safe_solve_expression(test, var); + if (res.failed()) { + return false; + } + Expr solved = res.value().result; + if (!solved.defined()) { + std::cerr << "solve_expression returned an undefined Expr for:\n" + << test << "\n"; + return false; + } + + // Solving again should not throw. + if (safe_solve_expression(solved, var).failed()) { + return false; + } + + map vars; + for (const auto &v : reg.fuzz_vars) { + vars[v.name()] = Expr(); + } + + for (int i = 0; i < samples; i++) { + for (auto &[name, val] : vars) { + val = random_int_val(reg.fuzz, -32, 32); + } + + // Skip samples that invoke div/mod-by-zero in the input: Halide + // defines the result as zero, but the simplifier may apply the + // fold asymmetrically across two syntactically-distinct forms + // that are otherwise semantically equivalent. We don't skip + // based on the *solved* form -- solve must never introduce new + // div/mod-by-zero that wasn't already in the input. + if (has_div_or_mod_by_zero(test, vars) || + has_overflowing_cast(test, vars)) { + continue; + } + + SafeResult test_v = safe_simplify(substitute(vars, test)); + SafeResult solved_v = safe_simplify(substitute(vars, solved)); + if (test_v.failed() || solved_v.failed()) { + return false; + } + + // If either side didn't simplify to a constant, there's likely UB + // (e.g. signed integer overflow) somewhere -- skip this sample. + if (!Internal::is_const(test_v.value()) || !Internal::is_const(solved_v.value())) { + continue; + } + + if (!equal(test_v.value(), solved_v.value())) { + std::cerr << "solve_expression produced a non-equivalent result:\n"; + for (const auto &[name, val] : vars) { + std::cerr << " " << name << " = " << val << "\n"; + } + std::cerr << " variable being solved: " << var << "\n"; + std::cerr << " original: " << test << " -> " << test_v.value() << "\n"; + std::cerr << " solved: " << solved << " -> " << solved_v.value() << "\n"; + return false; + } + } + return true; +} + +// Substitute the given variables and simplify. +Expr subst_and_simplify(const map &vars, const Expr &e) { + return simplify(substitute(vars, e)); +} + +// Returns 1 if `c` simplifies to a true constant, 0 if a false constant, -1 +// otherwise. Used to handle partial results from the simplifier safely. +int try_resolve_bool(const Expr &c) { + Expr s; + if (SafeResult r = safe_simplify(c); r.ok()) { + s = r.value(); + } else { + return -1; + } + if (is_const_one(s)) { + return 1; + } + if (is_const_zero(s)) { + return 0; + } + return -1; +} + +// Test the contracts of solve_for_inner_interval and solve_for_outer_interval +// by sampling values of `var` and checking: +// - if sample is inside the inner interval, the condition must be true +// - if sample is outside the outer interval, the condition must be false +// Non-solving variables are given concrete random values before sampling. +bool test_solve_intervals(RandomExpressionGenerator ®, + const Expr &cond, + const string &var, + int samples) { + internal_assert(cond.type().is_bool()); + + SafeResult inner_res = safe_solve_for_inner_interval(cond, var); + SafeResult outer_res = safe_solve_for_outer_interval(cond, var); + if (inner_res.failed() || outer_res.failed()) { + return false; + } + Interval inner = inner_res.value(); + Interval outer = outer_res.value(); + + map other_vars; + for (const auto &v : reg.fuzz_vars) { + if (v.name() != var) { + other_vars[v.name()] = Expr(); + } + } + + for (int i = 0; i < samples; i++) { + for (auto &[name, val] : other_vars) { + val = random_int_val(reg.fuzz, -16, 16); + } + // Skip substitutions that violate the "assumed not to overflow" + // contract for narrowing int casts. + if (has_overflowing_cast(cond, other_vars)) { + continue; + } + + Expr inner_min_v, inner_max_v, outer_min_v, outer_max_v; + if (inner.has_lower_bound()) inner_min_v = subst_and_simplify(other_vars, inner.min); + if (inner.has_upper_bound()) inner_max_v = subst_and_simplify(other_vars, inner.max); + if (outer.has_lower_bound()) outer_min_v = subst_and_simplify(other_vars, outer.min); + if (outer.has_upper_bound()) outer_max_v = subst_and_simplify(other_vars, outer.max); + Expr cond_sub = substitute(other_vars, cond); + + int val = reg.fuzz.ConsumeIntegralInRange(-64, 64); + Expr var_val = cast(Int(32), val); + int cond_truth = try_resolve_bool(substitute(var, var_val, cond_sub)); + if (cond_truth < 0) { + // Can't resolve (symbolic leftover or UB) -- skip. + continue; + } + + // Inner interval: var_val in [inner.min, inner.max] => cond is true. + // An empty inner interval is a trivial (vacuously true) claim. + int in_inner = inner.is_empty() ? 0 : 1; + if (in_inner == 1 && inner.has_lower_bound()) { + int r = try_resolve_bool(var_val >= inner_min_v); + if (r < 0) { + in_inner = -1; + } else if (r == 0) { + in_inner = 0; + } + } + if (in_inner == 1 && inner.has_upper_bound()) { + int r = try_resolve_bool(var_val <= inner_max_v); + if (r < 0) { + in_inner = -1; + } else if (r == 0) { + in_inner = 0; + } + } + if (in_inner == 1 && cond_truth == 0) { + std::cerr << "solve_for_inner_interval violation\n" + << " cond: " << cond << "\n" + << " var: " << var << " = " << val << "\n" + << " inner interval: [" << inner.min << ", " << inner.max << "]\n"; + for (const auto &[name, v] : other_vars) { + std::cerr << " " << name << " = " << v << "\n"; + } + return false; + } + + // Outer interval: var_val NOT in [outer.min, outer.max] => cond is false. + // An empty outer interval means cond is unsatisfiable, so any sample + // that evaluates to true is a violation. + int out_lb = 0, out_ub = 0; + if (outer.is_empty()) { + out_lb = 1; + } + if (outer.has_lower_bound()) { + int r = try_resolve_bool(var_val < outer_min_v); + if (r < 0) { + out_lb = -1; + } else { + out_lb = r; + } + } + if (outer.has_upper_bound()) { + int r = try_resolve_bool(var_val > outer_max_v); + if (r < 0) { + out_ub = -1; + } else { + out_ub = r; + } + } + if ((out_lb == 1 || out_ub == 1) && cond_truth == 1) { + std::cerr << "solve_for_outer_interval violation\n" + << " cond: " << cond << "\n" + << " var: " << var << " = " << val << "\n" + << " outer interval: [" << outer.min << ", " << outer.max << "]\n"; + for (const auto &[name, v] : other_vars) { + std::cerr << " " << name << " = " << v << "\n"; + } + return false; + } + } + return true; +} + +// Test that and_condition_over_domain(c, scope) implies c on the domain: +// for any concrete assignment of vars within their scope intervals, if the +// weakened condition is true, the original condition must also be true. +bool test_and_condition_over_domain(RandomExpressionGenerator ®, + const Expr &cond, + int samples) { + internal_assert(cond.type().is_bool()); + + Scope scope; + map> ranges; + for (const auto &v : reg.fuzz_vars) { + int a = reg.fuzz.ConsumeIntegralInRange(-16, 16); + int b = reg.fuzz.ConsumeIntegralInRange(-16, 16); + if (a > b) std::swap(a, b); + ranges[v.name()] = {a, b}; + scope.push(v.name(), Interval(cast(Int(32), a), cast(Int(32), b))); + } + + SafeResult weakened_res = safe_and_condition_over_domain(cond, scope); + if (weakened_res.failed()) { + return false; + } + Expr weakened = weakened_res.value(); + + for (int i = 0; i < samples; i++) { + map vars; + for (const auto &[name, r] : ranges) { + vars[name] = random_int_val(reg.fuzz, r.first, r.second); + } + // Skip substitutions that violate the "assumed not to overflow" + // contract for narrowing int casts (see has_overflowing_cast). + if (has_overflowing_cast(cond, vars)) { + continue; + } + int cond_truth = try_resolve_bool(substitute(vars, cond)); + int weak_truth = try_resolve_bool(substitute(vars, weakened)); + if (cond_truth < 0 || weak_truth < 0) { + continue; + } + if (weak_truth == 1 && cond_truth == 0) { + std::cerr << "and_condition_over_domain violation (result does not imply input):\n" + << " cond: " << cond << "\n" + << " weakened: " << weakened << "\n"; + for (const auto &[n, v] : vars) { + std::cerr << " " << n << " = " << v << " (in [" << ranges[n].first + << ", " << ranges[n].second << "])\n"; + } + return false; + } + } + return true; +} + +Expr random_comparison(RandomExpressionGenerator ®, int depth) { + using make_bin_op_fn = Expr (*)(Expr, Expr); + static make_bin_op_fn ops[] = { + EQ::make, + NE::make, + LT::make, + LE::make, + GT::make, + GE::make, + }; + Expr a = reg.random_expr(Int(32), depth); + Expr b = reg.random_expr(Int(32), depth); + return reg.fuzz.PickValueInArray(ops)(a, b); +} + +} // namespace + +FUZZ_TEST(solve, FuzzingContext &fuzz) { + // Depth of the randomly generated expression trees. + constexpr int depth = 6; + // Number of samples to test each invariant at. + constexpr int samples = 20; + + RandomExpressionGenerator reg{fuzz}; + reg.fuzz_types = {Int(8), Int(16), Int(32), Int(64), + UInt(1), UInt(8), UInt(16), UInt(32), UInt(64), + Float(32), Float(64)}; + // Leave gen_shuffles / gen_vector_reduce / gen_reinterpret off for now + // -- those exercise Deinterleaver / shuffle lowering more than solve + // proper. gen_broadcast_of_vector and gen_ramp_of_vector are on so the + // solver sees vector-typed expressions. + reg.gen_shuffles = false; + reg.gen_vector_reduce = false; + reg.gen_reinterpret = false; + + // Pick one of the generator's variables to solve for. + const string var = reg.fuzz_vars[fuzz.ConsumeIntegralInRange(0, reg.fuzz_vars.size() - 1)].name(); + + // solve_expression: arithmetic equivalence. Pick a random width so the + // generator's Broadcast/Ramp lambdas actually fire (they're no-ops on + // scalar types). Vector subtrees containing scalar variables exercise + // the solver's vector-aware handling (see e.g. the Broadcast case in + // src/Solve.cpp's solve_test). + int width = fuzz.PickValueInArray({1, 2, 3, 4, 6, 8}); + Expr test_expr = reg.random_expr(Int(32).with_lanes(width), depth); + if (!test_solve_expression_equivalence(reg, test_expr, var, samples)) { + std::cerr << "Failing expression (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(test_expr); + std::cerr << "Expr final_expr = " << printer.node_names[test_expr.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // solve_expression: also handle comparisons (the solver inverts these). + Expr cmp = random_comparison(reg, depth); + if (!test_solve_expression_equivalence(reg, cmp, var, samples)) { + std::cerr << "Failing comparison (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(cmp); + std::cerr << "Expr final_expr = " << printer.node_names[cmp.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // solve_for_inner_interval / solve_for_outer_interval. + if (!test_solve_intervals(reg, cmp, var, samples)) { + std::cerr << "Failing condition (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(cmp); + std::cerr << "Expr final_expr = " << printer.node_names[cmp.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // Also exercise solve_for_*_interval with compound boolean conditions. + Expr cmp2 = random_comparison(reg, depth); + Expr compound = fuzz.ConsumeBool() ? (cmp && cmp2) : (cmp || cmp2); + if (!test_solve_intervals(reg, compound, var, samples)) { + std::cerr << "Failing compound condition (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(compound); + std::cerr << "Expr final_expr = " << printer.node_names[compound.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // and_condition_over_domain. + if (!test_and_condition_over_domain(reg, compound, samples)) { + std::cerr << "Failing condition (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(compound); + std::cerr << "Expr final_expr = " << printer.node_names[compound.get()] << ";\n"; + return 1; + } + + return 0; +} diff --git a/test/internal.cpp b/test/internal.cpp index 08283fa9cf54..f64bdfbca1a8 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -31,7 +31,6 @@ int main(int argc, const char **argv) { deinterleave_vector_test(); modulus_remainder_test(); cse_test(); - solve_test(); target_test(); cplusplus_mangle_test(); is_monotonic_test();