-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bug fix for lossless_cast with minor additions #5459
Changes from 3 commits
d01f991
ddc443c
eda67db
cc4f15c
f56bc16
ae36348
2bc1d23
4ae8be0
2422c62
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
#include "IRMutator.h" | ||
#include "IROperator.h" | ||
#include "IRPrinter.h" | ||
#include "Simplify.h" | ||
#include "Util.h" | ||
#include "Var.h" | ||
|
||
|
@@ -438,6 +439,80 @@ Expr const_false(int w) { | |
return make_zero(UInt(1, w)); | ||
} | ||
|
||
void check(const Type &t, const Expr &in, const Expr &correct) { | ||
Expr result = lossless_cast(t, in); | ||
internal_assert(equal(result, correct)) | ||
<< "Incorrect lossless_cast result:\nlossless_cast(" | ||
<< t << ", " | ||
<< in | ||
<< ") gave: " | ||
<< result | ||
<< " but expected was: " | ||
<< correct << "\n"; | ||
} | ||
|
||
void lossless_cast_test() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think our convention would be to put this at the bottom of the file. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
Expr x = Variable::make(Int(32), "x"); | ||
Type u8 = UInt(8); | ||
Type u16 = UInt(16); | ||
Type u32 = UInt(32); | ||
Type i8 = Int(8); | ||
Type i16 = Int(16); | ||
Type i32 = Int(32); | ||
Type u8x = UInt(8, 4); | ||
Type u16x = UInt(16, 4); | ||
Type u32x = UInt(32, 4); | ||
Expr var_i8 = Variable::make(i8, "x"); | ||
Expr var_i16 = Variable::make(i16, "x"); | ||
Expr var_u8 = Variable::make(u8, "x"); | ||
Expr var_u16 = Variable::make(u16, "x"); | ||
Expr var_u8x = Variable::make(u8x, "x"); | ||
|
||
Expr e = Ramp::make(0, 1, 4); | ||
check(u8x, e, cast(u8x, e)); | ||
|
||
// Overflowing ramp | ||
e = Ramp::make(make_const(u16, 5), make_const(u16, 32800), 3); | ||
check(u8, e, Expr()); | ||
|
||
e = x % 4; | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = var_i8 % make_const(i8, -128); | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = var_i16 % make_const(i16, -256); | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = var_u16 % make_const(u16, 256); | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = var_u8 % Variable::make(u8, "y"); | ||
check(UInt(8), e, cast(UInt(8), e)); | ||
|
||
e = cast(u8, x); | ||
check(i32, e, cast(i32, e)); | ||
|
||
e = cast(u8, x); | ||
check(i32, e, cast(i32, e)); | ||
|
||
e = cast(i8, var_u16); | ||
check(u16, e, Expr()); | ||
|
||
e = cast(i16, var_u16); | ||
check(u16, e, Expr()); | ||
|
||
e = cast(u32, var_u8); | ||
check(u16, e, cast(u16, var_u8)); | ||
|
||
e = VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1); | ||
check(u16, e, cast(u16, e)); | ||
|
||
e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1); | ||
check(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1)); | ||
|
||
debug(0) << "lossless_cast test passed\n"; | ||
} | ||
Expr lossless_cast(Type t, Expr e) { | ||
if (!e.defined() || t == e.type()) { | ||
return e; | ||
|
@@ -446,8 +521,9 @@ Expr lossless_cast(Type t, Expr e) { | |
} | ||
|
||
if (const Cast *c = e.as<Cast>()) { | ||
if (t.can_represent(c->value.type())) { | ||
// We can recurse into widening casts. | ||
// We can recurse into widening casts. | ||
if (c->type.can_represent(c->value.type()) && | ||
t.can_represent(c->value.type())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can get rid of this second part of the condition, and leave that to the recursion? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right |
||
return lossless_cast(t, c->value); | ||
} else { | ||
return Expr(); | ||
|
@@ -537,6 +613,7 @@ Expr lossless_cast(Type t, Expr e) { | |
Type narrower = reduce->value.type().with_bits(t.bits() / 2); | ||
Expr val = lossless_cast(narrower, reduce->value); | ||
if (val.defined()) { | ||
val = cast(narrower.with_bits(t.bits()), val); | ||
return VectorReduce::make(reduce->op, val, reduce->type.lanes()); | ||
} | ||
} | ||
|
@@ -566,6 +643,42 @@ Expr lossless_cast(Type t, Expr e) { | |
return Shuffle::make(vecs, shuf->indices); | ||
} | ||
|
||
if (const Mod *mod = e.as<Mod>()) { | ||
Expr val; | ||
if (e.type().is_uint()) { | ||
val = simplify(mod->b + make_const(e.type(), -1)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens here and below if mod->b is zero? |
||
} else if (e.type().is_int()) { | ||
// 0 <= a%b < |b| | ||
Type wide_ty = e.type().with_bits(e.type().bits() * 2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the cases where you return cast(t, e), could you just do this:
? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or maybe
Basically I'm wondering if we can do something that appeals to verified code to reason about the bounds of mods and ramps, rather than having to worry about this method being correct. It's hard to get bounds right. |
||
Expr b = cast(wide_ty, mod->b); | ||
Expr minus_one = make_const(wide_ty, -1); | ||
val = simplify(Max::make(b, minus_one * b) + minus_one); | ||
} | ||
if (lossless_cast(t, val).defined()) { | ||
return cast(t, e); | ||
} else { | ||
return Expr(); | ||
} | ||
} | ||
|
||
if (const Ramp *ramp = e.as<Ramp>()) { | ||
if (t.bits() > 32) { | ||
return Expr(); | ||
} | ||
Type ty = t.with_lanes(ramp->type.lanes() / t.lanes()); | ||
Type wide_ty = ty.with_bits(64); | ||
Expr first = ramp->base; | ||
// Cast to wide_ty to prevent overflows. | ||
Expr last = simplify(cast(wide_ty, first) + | ||
cast(wide_ty, ramp->lanes - 1) * cast(wide_ty, ramp->stride)); | ||
if (lossless_cast(ty, first).defined() && | ||
lossless_cast(ty, last).defined()) { | ||
return cast(t, e); | ||
} else { | ||
return Expr(); | ||
} | ||
} | ||
|
||
return Expr(); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename check_lossless_cast ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done