Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for lossless_cast with minor additions #5459

Merged
merged 9 commits into from
Dec 14, 2020
82 changes: 80 additions & 2 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "IRMutator.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Simplify.h"
#include "Util.h"
#include "Var.h"

Expand Down Expand Up @@ -438,6 +439,55 @@ Expr const_false(int w) {
return make_zero(UInt(1, w));
}

// Test helper: runs lossless_cast(t, in) and asserts that the outcome is equal
// to `correct`. Callers pass an undefined Expr() as `correct` for cases where
// the cast is expected to be rejected. On mismatch the assertion message
// reports the target type, the input, the actual result and the expected one.
void check(Type t, const Expr &in, const Expr &correct) {
    const Expr result = lossless_cast(t, in);
    // NOTE(review): the condition is spelled exactly as before in case the
    // assertion macro echoes the expression text in its failure output.
    internal_assert(equal(result, correct))
        << "Incorrect lossless_cast result:\nlossless_cast(" << t << ", " << in
        << ") gave: " << result
        << " but expected was: " << correct << "\n";
}

// Self-test for lossless_cast. Each case builds an expression, asks for a
// lossless cast to a target type, and uses check() to compare against the
// expected expression (an undefined Expr() when the cast must be rejected).
// Prints a confirmation through debug(0) once every case has passed.
void lossless_cast_test() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think our convention would be to put this at the bottom of the file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

    Expr x = Variable::make(Int(32), "x");
    Type u8 = UInt(8);
    Type u16 = UInt(16);
    Type u32 = UInt(32);
    Type i8 = Int(8);
    // Must be a 16-bit signed type; declaring it as Int(8) made the i16 case
    // below an accidental repeat of the i8 case.
    Type i16 = Int(16);
    Type i32 = Int(32);
    Type u8x = UInt(8, 4);

    // Ramp with base 0, stride 1 and 4 lanes: expected to cast to u8x.
    Expr e = Ramp::make(0, 1, 4);
    check(u8x, e, cast(u8x, e));

    // Overflowing ramp
    e = Ramp::make(make_const(u16, 5), make_const(u16, 32800), 3);
    check(u8, e, Expr());

    // A modulus by a constant representable in u8 is expected to cast to u8.
    e = x % 4;
    check(u8, e, cast(u8, e));

    // A u8-typed expression is expected to cast to i32.
    e = cast(u8, x);
    check(i32, e, cast(i32, e));

    // A u16 variable cast to a narrower signed type: casting the result back
    // to u16 is expected to be rejected.
    e = cast(i8, Variable::make(u16, "x"));
    check(u16, e, Expr());

    // Same width but signed: also expected to be rejected.
    e = cast(i16, Variable::make(u16, "x"));
    check(u16, e, Expr());

    // A u8 variable cast up to u32: the cast to u16 is expected to apply
    // directly to the underlying u8 variable.
    e = cast(u32, Variable::make(u8, "x"));
    check(u16, e, cast(u16, Variable::make(u8, "x")));

    debug(0) << "lossless_cast test passed\n";
}
Expr lossless_cast(Type t, Expr e) {
if (!e.defined() || t == e.type()) {
return e;
Expand All @@ -446,8 +496,9 @@ Expr lossless_cast(Type t, Expr e) {
}

if (const Cast *c = e.as<Cast>()) {
if (t.can_represent(c->value.type())) {
// We can recurse into widening casts.
// We can recurse into widening casts.
if (c->type.can_represent(c->value.type()) &&
t.can_represent(c->value.type())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can get rid of this second part of the condition, and leave that to the recursion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right

return lossless_cast(t, c->value);
} else {
return Expr();
Expand Down Expand Up @@ -537,6 +588,7 @@ Expr lossless_cast(Type t, Expr e) {
Type narrower = reduce->value.type().with_bits(t.bits() / 2);
Expr val = lossless_cast(narrower, reduce->value);
if (val.defined()) {
val = cast(narrower.with_bits(t.bits()), val);
return VectorReduce::make(reduce->op, val, reduce->type.lanes());
}
}
Expand Down Expand Up @@ -566,6 +618,32 @@ Expr lossless_cast(Type t, Expr e) {
return Shuffle::make(vecs, shuf->indices);
}

if (const Mod *mod = e.as<Mod>()) {
if (lossless_cast(t, mod->b).defined()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making some changes for handling negative numbers here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return cast(t, e);
} else {
return Expr();
}
}

if (const Ramp *ramp = e.as<Ramp>()) {
if (t.bits() > 32) {
return Expr();
}
Type ty = t.with_lanes(ramp->type.lanes() / t.lanes());
Type wide_ty = ty.with_bits(64);
Expr first = ramp->base;
// Cast to wide_ty to prevent overflows.
Expr last = simplify(cast(wide_ty, first) +
cast(wide_ty, ramp->lanes - 1) * cast(wide_ty, ramp->stride));
if (lossless_cast(ty, first).defined() &&
lossless_cast(ty, last).defined()) {
return cast(t, e);
} else {
return Expr();
}
}

return Expr();
}

Expand Down
2 changes: 2 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ Expr const_false(int lanes = 1);
* Expr. */
Expr lossless_cast(Type t, Expr e);

void lossless_cast_test();

/** Coerce the two expressions to have the same type, using C-style
* casting rules. For the purposes of casting, a boolean type is
* UInt(1). We use the following procedure:
Expand Down
2 changes: 2 additions & 0 deletions test/internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "IR.h"
#include "IREquality.h"
#include "IRMatch.h"
#include "IROperator.h"
#include "IRPrinter.h"
#include "Interval.h"
#include "ModulusRemainder.h"
Expand Down Expand Up @@ -42,6 +43,7 @@ int main(int argc, const char **argv) {
generator_test();
propagate_estimate_test();
uniquify_variable_names_test();
lossless_cast_test();

printf("Success!\n");
return 0;
Expand Down