Skip to content

Commit

Permalink
Bug fix for lossless_cast with minor additions (#5459)
Browse files Browse the repository at this point in the history
* Bug fix for lossless_cast with minor additions

The bug can seen for types where lossless_cast type can represent
cast->value.type() but not cast->type. For eg:

lossless_cast(UInt(16), cast(Int(8), Variable::make(UInt(16), e))) returns
(uint16)e which is incorrect.

The patch also adds lossless_cast of Mod and Ramp expressions.

* Handle Mod for negative numbers in lossless_cast.

* Add lossless_cast test for VectorReduce.

* Rename check to check_lossless_cast.

* clang-format complains

* Remove Ramp and Mod from lossless_cast.

* Minor changes

* Update test/correctness/CMakeLists.txt

Co-authored-by: Ankit Aggarwal <aankit@quicinc.com>
  • Loading branch information
aankit-ca and aankit-quic committed Dec 14, 2020
1 parent ed8f7c2 commit 9fea59f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ Expr lossless_cast(Type t, Expr e) {
}

if (const Cast *c = e.as<Cast>()) {
if (t.can_represent(c->value.type())) {
if (c->type.can_represent(c->value.type())) {
// We can recurse into widening casts.
return lossless_cast(t, c->value);
} else {
Expand Down Expand Up @@ -491,6 +491,7 @@ Expr lossless_cast(Type t, Expr e) {
Type narrower = reduce->value.type().with_bits(t.bits() / 2);
Expr val = lossless_cast(narrower, reduce->value);
if (val.defined()) {
val = cast(narrower.with_bits(t.bits()), val);
return VectorReduce::make(reduce->op, val, reduce->type.lanes());
}
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ tests(GROUPS correctness
logical.cpp
loop_invariant_extern_calls.cpp
loop_level_generator_param.cpp
lossless_cast.cpp
lots_of_dimensions.cpp
lots_of_loop_invariants.cpp
make_struct.cpp
Expand Down
65 changes: 65 additions & 0 deletions test/correctness/lossless_cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

int check_lossless_cast(const Type &t, const Expr &in, const Expr &correct) {
Expr result = lossless_cast(t, in);
if (!equal(result, correct)) {
std::cout << "Incorrect lossless_cast result:\nlossless_cast("
<< t << ", " << in << ") gave: " << result
<< " but expected was: " << correct << "\n";
return 1;
}
return 0;
}

int lossless_cast_test() {
Expr x = Variable::make(Int(32), "x");
Type u8 = UInt(8);
Type u16 = UInt(16);
Type u32 = UInt(32);
Type i8 = Int(8);
Type i16 = Int(16);
Type i32 = Int(32);
Type u8x = UInt(8, 4);
Type u16x = UInt(16, 4);
Type u32x = UInt(32, 4);
Expr var_u8 = Variable::make(u8, "x");
Expr var_u16 = Variable::make(u16, "x");
Expr var_u8x = Variable::make(u8x, "x");

int res = 0;

Expr e = cast(u8, x);
res |= check_lossless_cast(i32, e, cast(i32, e));

e = cast(u8, x);
res |= check_lossless_cast(i32, e, cast(i32, e));

e = cast(i8, var_u16);
res |= check_lossless_cast(u16, e, Expr());

e = cast(i16, var_u16);
res |= check_lossless_cast(u16, e, Expr());

e = cast(u32, var_u8);
res |= check_lossless_cast(u16, e, cast(u16, var_u8));

e = VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1);
res |= check_lossless_cast(u16, e, cast(u16, e));

e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1);
res |= check_lossless_cast(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1));

return res;
}

int main() {
if (lossless_cast_test()) {
printf("lossless_cast test failed!\n");
return 1;
}
printf("Success!\n");
return 0;
}

0 comments on commit 9fea59f

Please sign in to comment.