-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bug fix for lossless_cast with minor additions #5459
Changes from 1 commit
d01f991
ddc443c
eda67db
cc4f15c
f56bc16
ae36348
2bc1d23
4ae8be0
2422c62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
#include "IRMutator.h" | ||
#include "IROperator.h" | ||
#include "IRPrinter.h" | ||
#include "Simplify.h" | ||
#include "Util.h" | ||
#include "Var.h" | ||
|
||
|
@@ -438,6 +439,55 @@ Expr const_false(int w) { | |
return make_zero(UInt(1, w)); | ||
} | ||
|
||
// Test helper: runs lossless_cast(t, in) and asserts that the outcome is
// equal() to `correct`. On mismatch the assertion message reports the target
// type, the input expression, the value obtained, and the value expected.
void check(Type t, const Expr &in, const Expr &correct) {
    Expr result = lossless_cast(t, in);
    // The condition is spelled exactly as before in case internal_assert
    // echoes its argument text in the failure report.
    internal_assert(equal(result, correct))
        << "Incorrect lossless_cast result:\nlossless_cast(" << t << ", " << in
        << ") gave: " << result
        << " but expected was: " << correct << "\n";
}
|
||
void lossless_cast_test() { | ||
Expr x = Variable::make(Int(32), "x"); | ||
Type u8 = UInt(8); | ||
Type u16 = UInt(16); | ||
Type u32 = UInt(32); | ||
Type i8 = Int(8); | ||
Type i16 = Int(8); | ||
Type i32 = Int(32); | ||
Type u8x = UInt(8, 4); | ||
|
||
Expr e = Ramp::make(0, 1, 4); | ||
check(u8x, e, cast(u8x, e)); | ||
|
||
// Overflowing ramp | ||
e = Ramp::make(make_const(u16, 5), make_const(u16, 32800), 3); | ||
check(u8, e, Expr()); | ||
|
||
e = x % 4; | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = cast(u8, x); | ||
check(i32, e, cast(i32, e)); | ||
|
||
e = cast(u8, x); | ||
check(i32, e, cast(i32, e)); | ||
|
||
e = cast(i8, Variable::make(u16, "x")); | ||
check(u16, e, Expr()); | ||
|
||
e = cast(i16, Variable::make(u16, "x")); | ||
check(u16, e, Expr()); | ||
|
||
e = cast(u32, Variable::make(u8, "x")); | ||
check(u16, e, cast(u16, Variable::make(u8, "x"))); | ||
|
||
debug(0) << "lossless_cast test passed\n"; | ||
} | ||
Expr lossless_cast(Type t, Expr e) { | ||
if (!e.defined() || t == e.type()) { | ||
return e; | ||
|
@@ -446,8 +496,9 @@ Expr lossless_cast(Type t, Expr e) { | |
} | ||
|
||
if (const Cast *c = e.as<Cast>()) { | ||
if (t.can_represent(c->value.type())) { | ||
// We can recurse into widening casts. | ||
// We can recurse into widening casts. | ||
if (c->type.can_represent(c->value.type()) && | ||
t.can_represent(c->value.type())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think you can get rid of this second part of the condition, and leave that to the recursion? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right |
||
return lossless_cast(t, c->value); | ||
} else { | ||
return Expr(); | ||
|
@@ -537,6 +588,7 @@ Expr lossless_cast(Type t, Expr e) { | |
Type narrower = reduce->value.type().with_bits(t.bits() / 2); | ||
Expr val = lossless_cast(narrower, reduce->value); | ||
if (val.defined()) { | ||
val = cast(narrower.with_bits(t.bits()), val); | ||
return VectorReduce::make(reduce->op, val, reduce->type.lanes()); | ||
} | ||
} | ||
|
@@ -566,6 +618,32 @@ Expr lossless_cast(Type t, Expr e) { | |
return Shuffle::make(vecs, shuf->indices); | ||
} | ||
|
||
if (const Mod *mod = e.as<Mod>()) { | ||
if (lossless_cast(t, mod->b).defined()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Making some changes for handling negative numbers here There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done |
||
return cast(t, e); | ||
} else { | ||
return Expr(); | ||
} | ||
} | ||
|
||
if (const Ramp *ramp = e.as<Ramp>()) { | ||
if (t.bits() > 32) { | ||
return Expr(); | ||
} | ||
Type ty = t.with_lanes(ramp->type.lanes() / t.lanes()); | ||
Type wide_ty = ty.with_bits(64); | ||
Expr first = ramp->base; | ||
// Cast to wide_ty to prevent overflows. | ||
Expr last = simplify(cast(wide_ty, first) + | ||
cast(wide_ty, ramp->lanes - 1) * cast(wide_ty, ramp->stride)); | ||
if (lossless_cast(ty, first).defined() && | ||
lossless_cast(ty, last).defined()) { | ||
return cast(t, e); | ||
} else { | ||
return Expr(); | ||
} | ||
} | ||
|
||
return Expr(); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think our convention would be to put this at the bottom of the file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done