Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for lossless_cast with minor additions #5459

Merged
merged 9 commits into from
Dec 14, 2020
117 changes: 115 additions & 2 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Simplify.h"
#include "Util.h"
#include "Var.h"

Expand Down Expand Up @@ -438,6 +439,80 @@ Expr const_false(int w) {
return make_zero(UInt(1, w));
}

void check(const Type &t, const Expr &in, const Expr &correct) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename check_lossless_cast ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Expr result = lossless_cast(t, in);
internal_assert(equal(result, correct))
<< "Incorrect lossless_cast result:\nlossless_cast("
<< t << ", "
<< in
<< ") gave: "
<< result
<< " but expected was: "
<< correct << "\n";
}

void lossless_cast_test() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think our convention would be to put this at the bottom of the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Expr x = Variable::make(Int(32), "x");
Type u8 = UInt(8);
Type u16 = UInt(16);
Type u32 = UInt(32);
Type i8 = Int(8);
Type i16 = Int(16);
Type i32 = Int(32);
Type u8x = UInt(8, 4);
Type u16x = UInt(16, 4);
Type u32x = UInt(32, 4);
Expr var_i8 = Variable::make(i8, "x");
Expr var_i16 = Variable::make(i16, "x");
Expr var_u8 = Variable::make(u8, "x");
Expr var_u16 = Variable::make(u16, "x");
Expr var_u8x = Variable::make(u8x, "x");

Expr e = Ramp::make(0, 1, 4);
check(u8x, e, cast(u8x, e));

// Overflowing ramp
e = Ramp::make(make_const(u16, 5), make_const(u16, 32800), 3);
check(u8, e, Expr());

e = x % 4;
check(UInt(8), e, cast(UInt(8), e));

e = var_i8 % make_const(i8, -128);
check(UInt(8), e, cast(UInt(8), e));

e = var_i16 % make_const(i16, -256);
check(UInt(8), e, cast(UInt(8), e));

e = var_u16 % make_const(u16, 256);
check(UInt(8), e, cast(UInt(8), e));

e = var_u8 % Variable::make(u8, "y");
check(UInt(8), e, cast(UInt(8), e));

e = cast(u8, x);
check(i32, e, cast(i32, e));

e = cast(u8, x);
check(i32, e, cast(i32, e));

e = cast(i8, var_u16);
check(u16, e, Expr());

e = cast(i16, var_u16);
check(u16, e, Expr());

e = cast(u32, var_u8);
check(u16, e, cast(u16, var_u8));

e = VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1);
check(u16, e, cast(u16, e));

e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1);
check(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1));

debug(0) << "lossless_cast test passed\n";
}
Expr lossless_cast(Type t, Expr e) {
if (!e.defined() || t == e.type()) {
return e;
Expand All @@ -446,8 +521,9 @@ Expr lossless_cast(Type t, Expr e) {
}

if (const Cast *c = e.as<Cast>()) {
if (t.can_represent(c->value.type())) {
// We can recurse into widening casts.
// We can recurse into widening casts.
if (c->type.can_represent(c->value.type()) &&
t.can_represent(c->value.type())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can get rid of this second part of the condition, and leave that to the recursion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right

return lossless_cast(t, c->value);
} else {
return Expr();
Expand Down Expand Up @@ -537,6 +613,7 @@ Expr lossless_cast(Type t, Expr e) {
Type narrower = reduce->value.type().with_bits(t.bits() / 2);
Expr val = lossless_cast(narrower, reduce->value);
if (val.defined()) {
val = cast(narrower.with_bits(t.bits()), val);
return VectorReduce::make(reduce->op, val, reduce->type.lanes());
}
}
Expand Down Expand Up @@ -566,6 +643,42 @@ Expr lossless_cast(Type t, Expr e) {
return Shuffle::make(vecs, shuf->indices);
}

if (const Mod *mod = e.as<Mod>()) {
Expr val;
if (e.type().is_uint()) {
val = simplify(mod->b + make_const(e.type(), -1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here and below if mod->b is zero?

} else if (e.type().is_int()) {
// 0 <= a%b < |b|
Type wide_ty = e.type().with_bits(e.type().bits() * 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the cases where you return cast(t, e), could you just do this:

Expr x = cast(t, e);
if (can_prove(cast(e.type(), x) == e) {
  return x;
}

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe

if (e.type().can_represent(t.min()) && 
    e.type().can_represent(t.max()) && 
    can_prove(e >= cast(e.type(), t.min()) && 
    e <= cast(e.type(), t.max())) 

Basically I'm wondering if we can do something that appeals to verified code to reason about the bounds of mods and ramps, rather than having to worry about this method being correct. It's hard to get bounds right.

Expr b = cast(wide_ty, mod->b);
Expr minus_one = make_const(wide_ty, -1);
val = simplify(Max::make(b, minus_one * b) + minus_one);
}
if (lossless_cast(t, val).defined()) {
return cast(t, e);
} else {
return Expr();
}
}

if (const Ramp *ramp = e.as<Ramp>()) {
if (t.bits() > 32) {
return Expr();
}
Type ty = t.with_lanes(ramp->type.lanes() / t.lanes());
Type wide_ty = ty.with_bits(64);
Expr first = ramp->base;
// Cast to wide_ty to prevent overflows.
Expr last = simplify(cast(wide_ty, first) +
cast(wide_ty, ramp->lanes - 1) * cast(wide_ty, ramp->stride));
if (lossless_cast(ty, first).defined() &&
lossless_cast(ty, last).defined()) {
return cast(t, e);
} else {
return Expr();
}
}

return Expr();
}

Expand Down
2 changes: 2 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ Expr const_false(int lanes = 1);
* Expr. */
Expr lossless_cast(Type t, Expr e);

void lossless_cast_test();

/** Coerce the two expressions to have the same type, using C-style
* casting rules. For the purposes of casting, a boolean type is
* UInt(1). We use the following procedure:
Expand Down
2 changes: 2 additions & 0 deletions test/internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "IR.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Interval.h"
#include "ModulusRemainder.h"
Expand Down Expand Up @@ -42,6 +43,7 @@ int main(int argc, const char **argv) {
generator_test();
propagate_estimate_test();
uniquify_variable_names_test();
lossless_cast_test();

printf("Success!\n");
return 0;
Expand Down