Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for lossless_cast with minor additions #5459

Merged
merged 9 commits into from
Dec 14, 2020
122 changes: 120 additions & 2 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Simplify.h"
#include "Util.h"
#include "Var.h"

Expand Down Expand Up @@ -446,8 +447,9 @@ Expr lossless_cast(Type t, Expr e) {
}

if (const Cast *c = e.as<Cast>()) {
if (t.can_represent(c->value.type())) {
// We can recurse into widening casts.
// We can recurse into widening casts.
if (c->type.can_represent(c->value.type()) &&
t.can_represent(c->value.type())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can get rid of this second part of the condition, and leave that to the recursion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right

return lossless_cast(t, c->value);
} else {
return Expr();
Expand Down Expand Up @@ -537,6 +539,7 @@ Expr lossless_cast(Type t, Expr e) {
Type narrower = reduce->value.type().with_bits(t.bits() / 2);
Expr val = lossless_cast(narrower, reduce->value);
if (val.defined()) {
val = cast(narrower.with_bits(t.bits()), val);
return VectorReduce::make(reduce->op, val, reduce->type.lanes());
}
}
Expand Down Expand Up @@ -566,6 +569,42 @@ Expr lossless_cast(Type t, Expr e) {
return Shuffle::make(vecs, shuf->indices);
}

if (const Mod *mod = e.as<Mod>()) {
Expr val;
if (e.type().is_uint()) {
val = simplify(mod->b + make_const(e.type(), -1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens here and below if mod->b is zero?

} else if (e.type().is_int()) {
// 0 <= a%b < |b|
Type wide_ty = e.type().with_bits(e.type().bits() * 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the cases where you return cast(t, e), could you just do this:

Expr x = cast(t, e);
if (can_prove(cast(e.type(), x) == e) {
  return x;
}

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe

if (e.type().can_represent(t.min()) && 
    e.type().can_represent(t.max()) && 
    can_prove(e >= cast(e.type(), t.min()) && 
    e <= cast(e.type(), t.max())) 

Basically I'm wondering if we can do something that appeals to verified code to reason about the bounds of mods and ramps, rather than having to worry about this method being correct. It's hard to get bounds right.

Expr b = cast(wide_ty, mod->b);
Expr minus_one = make_const(wide_ty, -1);
val = simplify(Max::make(b, minus_one * b) + minus_one);
}
if (lossless_cast(t, val).defined()) {
return cast(t, e);
} else {
return Expr();
}
}

if (const Ramp *ramp = e.as<Ramp>()) {
if (t.bits() > 32) {
return Expr();
}
Type ty = t.with_lanes(ramp->type.lanes() / t.lanes());
Type wide_ty = ty.with_bits(64);
Expr first = ramp->base;
// Cast to wide_ty to prevent overflows.
Expr last = simplify(cast(wide_ty, first) +
cast(wide_ty, ramp->lanes - 1) * cast(wide_ty, ramp->stride));
if (lossless_cast(ty, first).defined() &&
lossless_cast(ty, last).defined()) {
return cast(t, e);
} else {
return Expr();
}
}

return Expr();
}

Expand Down Expand Up @@ -2365,4 +2404,83 @@ Expr undef(Type t) {
Internal::Call::PureIntrinsic);
}

namespace Internal {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general rule, if it can be an external correctness test, it should be, to keep libHalide smaller. I think this code can be.


void check_lossless_cast(const Type &t, const Expr &in, const Expr &correct) {
Expr result = lossless_cast(t, in);
internal_assert(equal(result, correct))
<< "Incorrect lossless_cast result:\nlossless_cast("
<< t << ", "
<< in
<< ") gave: "
<< result
<< " but expected was: "
<< correct << "\n";
}

void lossless_cast_test() {
Expr x = Variable::make(Int(32), "x");
Type u8 = UInt(8);
Type u16 = UInt(16);
Type u32 = UInt(32);
Type i8 = Int(8);
Type i16 = Int(16);
Type i32 = Int(32);
Type u8x = UInt(8, 4);
Type u16x = UInt(16, 4);
Type u32x = UInt(32, 4);
Expr var_i8 = Variable::make(i8, "x");
Expr var_i16 = Variable::make(i16, "x");
Expr var_u8 = Variable::make(u8, "x");
Expr var_u16 = Variable::make(u16, "x");
Expr var_u8x = Variable::make(u8x, "x");

Expr e = Ramp::make(0, 1, 4);
check_lossless_cast(u8x, e, cast(u8x, e));

// Overflowing ramp
e = Ramp::make(make_const(u16, 5), make_const(u16, 32800), 3);
check_lossless_cast(u8, e, Expr());

e = x % 4;
check_lossless_cast(UInt(8), e, cast(UInt(8), e));

e = var_i8 % make_const(i8, -128);
check_lossless_cast(UInt(8), e, cast(UInt(8), e));

e = var_i16 % make_const(i16, -256);
check_lossless_cast(UInt(8), e, cast(UInt(8), e));

e = var_u16 % make_const(u16, 256);
check_lossless_cast(UInt(8), e, cast(UInt(8), e));

e = var_u8 % Variable::make(u8, "y");
check_lossless_cast(UInt(8), e, cast(UInt(8), e));

e = cast(u8, x);
check_lossless_cast(i32, e, cast(i32, e));

e = cast(u8, x);
check_lossless_cast(i32, e, cast(i32, e));

e = cast(i8, var_u16);
check_lossless_cast(u16, e, Expr());

e = cast(i16, var_u16);
check_lossless_cast(u16, e, Expr());

e = cast(u32, var_u8);
check_lossless_cast(u16, e, cast(u16, var_u8));

e = VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1);
check_lossless_cast(u16, e, cast(u16, e));

e = VectorReduce::make(VectorReduce::Add, cast(u32x, var_u8x), 1);
check_lossless_cast(u16, e, VectorReduce::make(VectorReduce::Add, cast(u16x, var_u8x), 1));

debug(0) << "lossless_cast test passed\n";
}

} // namespace Internal

} // namespace Halide
2 changes: 2 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ Expr const_false(int lanes = 1);
* Expr. */
Expr lossless_cast(Type t, Expr e);

void lossless_cast_test();

/** Coerce the two expressions to have the same type, using C-style
* casting rules. For the purposes of casting, a boolean type is
* UInt(1). We use the following procedure:
Expand Down
2 changes: 2 additions & 0 deletions test/internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "IR.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Interval.h"
#include "ModulusRemainder.h"
Expand Down Expand Up @@ -42,6 +43,7 @@ int main(int argc, const char **argv) {
generator_test();
propagate_estimate_test();
uniquify_variable_names_test();
lossless_cast_test();

printf("Success!\n");
return 0;
Expand Down