From ad2386306cb3ef463ed0e5f20348fcf68f8731e7 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 17 Mar 2026 13:41:57 -0700 Subject: [PATCH] cast(immediate) should be make_const(t, immediate) Doing a Halide cast of a c++ int constructs an immediate Expr (e.g. IntImm) and then eagerly folds it to a different type of immediate in the Cast constructor. It's better to just construct the immediate using the desired type to begin with. Co-authored-by: Claude --- src/AddAtomicMutex.cpp | 4 ++-- src/Associativity.cpp | 4 ++-- src/BoundaryConditions.cpp | 2 +- src/Bounds.cpp | 36 ++++++++++++++++++------------------ src/CodeGen_Hexagon.cpp | 2 +- src/CodeGen_LLVM.cpp | 2 +- src/Generator.cpp | 2 +- src/Generator.h | 2 +- src/OffloadGPULoops.cpp | 8 ++++---- src/Profiling.cpp | 4 ++-- 10 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/AddAtomicMutex.cpp b/src/AddAtomicMutex.cpp index cf3b0ae8bb89..35ff29ba1330 100644 --- a/src/AddAtomicMutex.cpp +++ b/src/AddAtomicMutex.cpp @@ -326,7 +326,7 @@ class AddAtomicMutex : public IRMutator { } if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) { - Expr extent = cast(1); // uint64_t to handle LargeBuffers + Expr extent = make_one(UInt(64)); // uint64_t to handle LargeBuffers for (const Expr &e : op->extents) { extent = extent * e; } @@ -378,7 +378,7 @@ class AddAtomicMutex : public IRMutator { if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) { // All output buffers in a Tuple have the same extent. OutputImageParam output_buffer = Func(f).output_buffers()[0]; - Expr extent = cast(1); // uint64_t to handle LargeBuffers + Expr extent = make_one(UInt(64)); // uint64_t to handle LargeBuffers for (int i = 0; i < output_buffer.dimensions(); i++) { extent *= output_buffer.dim(i).extent(); } diff --git a/src/Associativity.cpp b/src/Associativity.cpp index bd67f0245af6..421fff6278f3 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -537,8 +537,8 @@ void associativity_test() { Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0); for (const Expr &e : {cast(min(cast(x) + y, 255)), - select(x > 255 - y, cast(255), y), - select(x < -y, y, cast(255)), + select(x > 255 - y, make_const(UInt(8), 255), y), + select(x < -y, y, make_const(UInt(8), 255)), saturating_add(x, y), saturating_add(y, x), saturating_cast(widening_add(x, y))}) { diff --git a/src/BoundaryConditions.cpp b/src/BoundaryConditions.cpp index 58f2d558ae83..a809b7cc87a5 100644 --- a/src/BoundaryConditions.cpp +++ b/src/BoundaryConditions.cpp @@ -45,7 +45,7 @@ Func constant_exterior(const Func &source, const Tuple &value, << ") than dimensions (" << args.size() << ") Func " << source.name() << " has.\n"; - Expr out_of_bounds = cast(false); + Expr out_of_bounds = Halide::Internal::make_zero(Bool()); for (size_t i = 0; i < bounds.size(); i++) { const Var &arg_var = args[i]; Expr min = bounds[i].min; diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 41fac3f2dba5..0267485b0a79 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -3623,22 +3623,22 @@ void bounds_test() { scope.pop("x"); // Check some bitwise ops. - check(scope, (cast(x) & cast(7)), u8(0), u8(7)); - check(scope, (cast(3) & cast(2)), u8(2), u8(2)); - check(scope, (cast(1) | cast(2)), u8(3), u8(3)); - check(scope, (cast(3) ^ cast(2)), u8(1), u8(1)); - check(scope, (~cast(3)), u8(0xfc), u8(0xfc)); + check(scope, (cast(x) & make_const(UInt(8), 7)), u8(0), u8(7)); + check(scope, (make_const(UInt(8), 3) & make_const(UInt(8), 2)), u8(2), u8(2)); + check(scope, (make_one(UInt(8)) | make_const(UInt(8), 2)), u8(3), u8(3)); + check(scope, (make_const(UInt(8), 3) ^ make_const(UInt(8), 2)), u8(1), u8(1)); + check(scope, (~make_const(UInt(8), 3)), u8(0xfc), u8(0xfc)); check(scope, cast(x + 5) & cast(x + 3), u8(0), u8(13)); check(scope, cast(x - 5) & cast(x + 3), i8(0), i8(13)); check(scope, cast(2 * x - 5) & cast(x - 3), i8(-128), i8(15)); check(scope, cast(x + 5) | cast(x + 3), u8(5), u8(255)); check(scope, cast(x + 5) | cast(x + 3), i8(3), i8(127)); check(scope, ~cast(x), u8(-11), u8(-1)); - check(scope, (cast(x) >> cast(1)), u8(0), u8(5)); - check(scope, (cast(10) >> cast(1)), u8(5), u8(5)); - check(scope, (cast(x + 3) << cast(1)), u8(6), u8(26)); - check(scope, (cast(x + 3) << cast(7)), u8(0), u8(255)); // Overflows - check(scope, (cast(5) << cast(1)), u8(10), u8(10)); + check(scope, (cast(x) >> make_one(UInt(8))), u8(0), u8(5)); + check(scope, (make_const(UInt(8), 10) >> make_one(UInt(8))), u8(5), u8(5)); + check(scope, (cast(x + 3) << make_one(UInt(8))), u8(6), u8(26)); + check(scope, (cast(x + 3) << make_const(UInt(8), 7)), u8(0), u8(255)); // Overflows + check(scope, (make_const(UInt(8), 5) << make_one(UInt(8))), u8(10), u8(10)); check(scope, (x << 12), 0, 10 << 12); check(scope, x & 4095, 0, 10); // LHS known to be positive check(scope, x & 123, 0, 10); // Doesn't have to be a precise bitmask @@ -3712,7 +3712,7 @@ void bounds_test() { u16(0), u16(4095)); check(scope, - cast(clamp(cast(x ^ y), cast(0), cast(128))), + cast(clamp(cast(x ^ y), make_zero(UInt(16)), make_const(UInt(16), 128))), u8(0), u8(128)); Expr u8_1 = cast(Load::make(Int(8), "buf", x, Buffer<>(), Parameter(), const_true(), ModulusRemainder())); @@ -3720,19 +3720,19 @@ void bounds_test() { check(scope, cast(u8_1) + cast(u8_2), u16(0), u16(255 * 2)); - check(scope, saturating_cast(clamp(x, 5, 10)), cast(5), cast(10)); + check(scope, saturating_cast(clamp(x, 5, 10)), make_const(UInt(8), 5), make_const(UInt(8), 10)); { scope.push("x", Interval(UInt(32).min(), UInt(32).max())); - check(scope, saturating_cast(max(cast(x), cast(5))), cast(5), Int(32).max()); + check(scope, saturating_cast(max(cast(x), make_const(UInt(32), 5))), make_const(Int(32), 5), Int(32).max()); scope.pop("x"); } { Expr z = Variable::make(Float(32), "z"); - scope.push("z", Interval(cast(-1), cast(1))); - check(scope, saturating_cast(z), cast(-1), cast(1)); - check(scope, saturating_cast(z), cast(-1), cast(1)); - check(scope, saturating_cast(z), cast(-1), cast(1)); - check(scope, saturating_cast(z), cast(0), cast(1)); + scope.push("z", Interval(make_const(Float(32), -1), make_one(Float(32)))); + check(scope, saturating_cast(z), make_const(Int(32), -1), make_one(Int(32))); + check(scope, saturating_cast(z), make_const(Float(64), -1), make_one(Float(64))); + check(scope, saturating_cast(z), make_const(Float(16), -1), make_one(Float(16))); + check(scope, saturating_cast(z), make_zero(UInt(8)), make_one(UInt(8))); scope.pop("z"); } { diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 065dcebd1a64..78d1cbe0ed47 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -170,7 +170,7 @@ Stmt acquire_hvx_context(Stmt stmt, const Target &target) { // Modify the stmt to add a call to halide_qurt_hvx_lock, and // register a destructor to call halide_qurt_hvx_unlock. Stmt check_hvx_lock = call_halide_qurt_hvx_lock(target); - Expr dummy_obj = reinterpret(Handle(), cast(1)); + Expr dummy_obj = reinterpret(Handle(), make_one(UInt(64))); Expr hvx_unlock = Call::make(Handle(), Call::register_destructor, {Expr("halide_qurt_hvx_unlock_as_destructor"), dummy_obj}, diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 300dfa096a1e..07d8cbb31a08 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -3989,7 +3989,7 @@ void CodeGen_LLVM::codegen_asserts(const vector &asserts) { // Mix all the conditions together into a bitmask - Expr bitmask = cast(1) << 63; + Expr bitmask = make_const(UInt(64), ((uint64_t)1) << 63); for (size_t i = 0; i < asserts.size(); i++) { bitmask = bitmask | (cast(!asserts[i]->condition) << i); } diff --git a/src/Generator.cpp b/src/Generator.cpp index 1b60980418c1..b28bea0975c2 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -2167,7 +2167,7 @@ void generator_test() { const std::vector a = {1, 2, 3, 4}; Var x; Func fn_typed, fn_untyped; - fn_typed(x) = cast(38); + fn_typed(x) = make_const(Int(16), 38); fn_untyped(x) = 32.f; const std::vector fn_array = {fn_untyped, fn_untyped}; diff --git a/src/Generator.h b/src/Generator.h index 26af0aabaa51..8aaa6f1486a9 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2065,7 +2065,7 @@ class GeneratorInput_Scalar : public GeneratorInputImpl { void set_estimate(const TBase &value) { this->check_gio_access(); user_assert(value == nullptr) << "nullptr is the only valid estimate for Input"; - Expr e = reinterpret(type_of(), cast(0)); + Expr e = reinterpret(type_of(), make_zero(UInt(64))); for (Parameter &p : this->parameters_) { p.set_estimate(e); } diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index 11b8c3ccecf3..376dcd0b9949 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -200,18 +200,18 @@ class InjectGpuOffload : public IRMutator { arg_types_or_sizes.emplace_back(cast(target_size_t_type, i.is_buffer ? 8 : i.type.bytes())); } - arg_is_buffer.emplace_back(cast(i.is_buffer)); + arg_is_buffer.emplace_back(make_const(UInt(8), (int)i.is_buffer)); } // nullptr-terminate the lists - args.emplace_back(reinterpret(Handle(), cast(0))); + args.emplace_back(reinterpret(Handle(), make_zero(UInt(64)))); if (runtime_run_takes_types) { internal_assert(sizeof(halide_type_t) == sizeof(uint32_t)); - arg_types_or_sizes.emplace_back(cast(0)); + arg_types_or_sizes.emplace_back(make_zero(UInt(32))); } else { arg_types_or_sizes.emplace_back(cast(target_size_t_type, 0)); } - arg_is_buffer.emplace_back(cast(0)); + arg_is_buffer.emplace_back(make_zero(UInt(8))); debug(3) << "bounds.num_blocks[0] = " << bounds.num_blocks[0] << "\n"; debug(3) << "bounds.num_blocks[1] = " << bounds.num_blocks[1] << "\n"; diff --git a/src/Profiling.cpp b/src/Profiling.cpp index c8054e83544e..a6bf13483751 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -201,7 +201,7 @@ class InjectProfiling : public IRMutator { Stmt unconditionally_set_current_func(int id) { Stmt s = Evaluate::make(Call::make(Int(32), "halide_profiler_set_current_func", - {profiler_instance, id, reinterpret(Handle(), cast(0))}, Call::Extern)); + {profiler_instance, id, reinterpret(Handle(), make_zero(UInt(64)))}, Call::Extern)); return s; } @@ -210,7 +210,7 @@ class InjectProfiling : public IRMutator { return Evaluate::make(0); } most_recently_set_func = id; - Expr last_arg = in_leaf_task ? profiler_local_sampling_token : reinterpret(Handle(), cast(0)); + Expr last_arg = in_leaf_task ? profiler_local_sampling_token : reinterpret(Handle(), make_zero(UInt(64))); // This call gets inlined and becomes a single store instruction. Stmt s = Evaluate::make(Call::make(Int(32), "halide_profiler_set_current_func", {profiler_instance, id, last_arg}, Call::Extern));