Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix saturating add matching in associativity checking #8220

Merged
merged 4 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 44 additions & 78 deletions src/AssociativeOpsTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,6 @@ using std::vector;

namespace {

enum class RootExpr {
Add = 0,
Mul = 1,
Max = 2,
Min = 3,
Sub = 4,
Select = 5,
And = 6,
Or = 7,
Cast = 8,
Unknown = 9, // Not supported IR type
};

enum class ValType {
UInt1 = 0,
UInt8 = 1,
Expand Down Expand Up @@ -93,12 +80,12 @@ vector<ValType> convert_halide_types_to_val_types(const vector<Type> &halide_typ

struct TableKey {
vector<ValType> types;
RootExpr root;
IRNodeType root;
size_t dim;
TableKey(ValType t, RootExpr r, size_t d)
TableKey(ValType t, IRNodeType r, size_t d)
: types({t}), root(r), dim(d) {
}
TableKey(const vector<ValType> &t, RootExpr r, size_t d)
TableKey(const vector<ValType> &t, IRNodeType r, size_t d)
: types(t), root(r), dim(d) {
}

Expand Down Expand Up @@ -169,6 +156,14 @@ void populate_ops_table_single_general_select(const vector<Type> &types, vector<
declare_vars_single(types);
}

// Appends the associative patterns for a single-element update whose root
// IR node is a Call (this function is registered under
// TableKey(ValType::All, IRNodeType::Call, 1)).
//
// types: element types of the update; only types[0] is inspected here.
// table: pattern list to append to.
//
// Patterns are only added when types[0] is an unsigned integer type; for
// any other type code the table is left untouched. Two spellings of a
// saturating add are registered, each with zero_0 and the flag `true`:
// the saturating_add intrinsic itself, and a saturating_cast back to
// types[0] of a widening_add.
void populate_ops_table_single_general_call(const vector<Type> &types, vector<AssociativePattern> &table) {
    // Brings x0, y0, zero_0 (and friends) into scope for this type list.
    declare_vars_single(types);

    // Nothing to register for non-unsigned types.
    if (types[0].code() != Type::UInt) {
        return;
    }

    table.emplace_back(saturating_add(x0, y0), zero_0, true);
    table.emplace_back(saturating_cast(types[0], widening_add(x0, y0)), zero_0, true);
}

void populate_ops_table_double_general_add(const vector<Type> &types, vector<AssociativePattern> &table) {
declare_vars_double(types);
if (types[0] == types[1]) {
Expand Down Expand Up @@ -217,9 +212,9 @@ void populate_ops_table_single_uint8_cast(const vector<Type> &types, vector<Asso
Expr k0_uint16 = Variable::make(UInt(16), "k0");
Expr k0_uint32 = Variable::make(UInt(32), "k0");
Expr k0_uint64 = Variable::make(UInt(64), "k0");
table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0 + y0), k0_uint16)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint16_t>(x0) + y0, k0_uint16)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint32_t>(x0) + y0, k0_uint32)), zero_0, true);
table.emplace_back(cast<uint8_t>(min(cast<uint64_t>(x0) + y0, k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint8_select(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -232,8 +227,8 @@ void populate_ops_table_single_uint16_cast(const vector<Type> &types, vector<Ass
declare_vars_single(types);
Expr k0_uint32 = Variable::make(UInt(32), "k0");
Expr k0_uint64 = Variable::make(UInt(64), "k0");
table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0 + y0), k0_uint32)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0 + y0), k0_uint64)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint32_t>(x0) + y0, k0_uint32)), zero_0, true);
table.emplace_back(cast<uint16_t>(min(cast<uint64_t>(x0) + y0, k0_uint64)), zero_0, true);
}

void populate_ops_table_single_uint16_select(const vector<Type> &types, vector<AssociativePattern> &table) {
Expand All @@ -255,33 +250,34 @@ void populate_ops_table_single_uint32_select(const vector<Type> &types, vector<A
}

const map<TableKey, void (*)(const vector<Type> &types, vector<AssociativePattern> &)> val_type_to_populate_luts_fn = {
{TableKey(ValType::All, RootExpr::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, RootExpr::Mul, 1), &populate_ops_table_single_general_mul},
{TableKey(ValType::All, RootExpr::Max, 1), &populate_ops_table_single_general_max},
{TableKey(ValType::All, RootExpr::Min, 1), &populate_ops_table_single_general_min},
{TableKey(ValType::All, RootExpr::Sub, 1), &populate_ops_table_single_general_sub},
{TableKey(ValType::All, RootExpr::Select, 1), &populate_ops_table_single_general_select},
{TableKey(ValType::All, RootExpr::Add, 2), &populate_ops_table_double_general_add},
{TableKey(ValType::All, RootExpr::Mul, 2), &populate_ops_table_double_general_mul},
{TableKey(ValType::All, RootExpr::Max, 2), &populate_ops_table_double_general_max},
{TableKey(ValType::All, RootExpr::Min, 2), &populate_ops_table_double_general_min},
{TableKey(ValType::All, RootExpr::Sub, 2), &populate_ops_table_double_general_sub},
{TableKey(ValType::All, RootExpr::Select, 2), &populate_ops_table_double_general_select},

{TableKey(ValType::UInt1, RootExpr::And, 1), &populate_ops_table_single_uint1_and},
{TableKey(ValType::UInt1, RootExpr::Or, 1), &populate_ops_table_single_uint1_or},

{TableKey(ValType::UInt8, RootExpr::Cast, 1), &populate_ops_table_single_uint8_cast},
{TableKey(ValType::UInt8, RootExpr::Select, 1), &populate_ops_table_single_uint8_select},

{TableKey(ValType::UInt16, RootExpr::Cast, 1), &populate_ops_table_single_uint16_cast},
{TableKey(ValType::UInt16, RootExpr::Select, 1), &populate_ops_table_single_uint16_select},

{TableKey(ValType::UInt32, RootExpr::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, RootExpr::Select, 1), &populate_ops_table_single_uint32_select},
{TableKey(ValType::All, IRNodeType::Add, 1), &populate_ops_table_single_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 1), &populate_ops_table_single_general_mul},
{TableKey(ValType::All, IRNodeType::Max, 1), &populate_ops_table_single_general_max},
{TableKey(ValType::All, IRNodeType::Min, 1), &populate_ops_table_single_general_min},
{TableKey(ValType::All, IRNodeType::Sub, 1), &populate_ops_table_single_general_sub},
{TableKey(ValType::All, IRNodeType::Select, 1), &populate_ops_table_single_general_select},
{TableKey(ValType::All, IRNodeType::Call, 1), &populate_ops_table_single_general_call},
{TableKey(ValType::All, IRNodeType::Add, 2), &populate_ops_table_double_general_add},
{TableKey(ValType::All, IRNodeType::Mul, 2), &populate_ops_table_double_general_mul},
{TableKey(ValType::All, IRNodeType::Max, 2), &populate_ops_table_double_general_max},
{TableKey(ValType::All, IRNodeType::Min, 2), &populate_ops_table_double_general_min},
{TableKey(ValType::All, IRNodeType::Sub, 2), &populate_ops_table_double_general_sub},
{TableKey(ValType::All, IRNodeType::Select, 2), &populate_ops_table_double_general_select},

{TableKey(ValType::UInt1, IRNodeType::And, 1), &populate_ops_table_single_uint1_and},
{TableKey(ValType::UInt1, IRNodeType::Or, 1), &populate_ops_table_single_uint1_or},

{TableKey(ValType::UInt8, IRNodeType::Cast, 1), &populate_ops_table_single_uint8_cast},
{TableKey(ValType::UInt8, IRNodeType::Select, 1), &populate_ops_table_single_uint8_select},

{TableKey(ValType::UInt16, IRNodeType::Cast, 1), &populate_ops_table_single_uint16_cast},
{TableKey(ValType::UInt16, IRNodeType::Select, 1), &populate_ops_table_single_uint16_select},

{TableKey(ValType::UInt32, IRNodeType::Cast, 1), &populate_ops_table_single_uint32_cast},
{TableKey(ValType::UInt32, IRNodeType::Select, 1), &populate_ops_table_single_uint32_select},
};

const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, RootExpr root, size_t dim) {
const vector<AssociativePattern> &get_ops_table_helper(const vector<Type> &types, IRNodeType root, size_t dim) {
TableKey gen_key(ValType::All, root, dim);
TableKey key(convert_halide_types_to_val_types(types), root, dim);

Expand Down Expand Up @@ -336,43 +332,13 @@ const vector<AssociativePattern> &get_ops_table(const vector<Expr> &exprs) {
types[i] = exprs[i].type();
}

RootExpr root = RootExpr::Unknown;
if (exprs[0].as<Halide::Internal::Add>()) {
debug(5) << "Returning Add root table for type " << print_types(types) << "\n";
root = RootExpr::Add;
} else if (exprs[0].as<Halide::Internal::Sub>()) {
debug(5) << "Returning Sub root table for type " << print_types(types) << "\n";
root = RootExpr::Sub;
} else if (exprs[0].as<Halide::Internal::Mul>()) {
debug(5) << "Returning Mul root table for type " << print_types(types) << "\n";
root = RootExpr::Mul;
} else if (exprs[0].as<Halide::Internal::Min>()) {
debug(5) << "Returning Min root table for type " << print_types(types) << "\n";
root = RootExpr::Min;
} else if (exprs[0].as<Halide::Internal::Max>()) {
debug(5) << "Returning Max root table for type " << print_types(types) << "\n";
root = RootExpr::Max;
} else if (exprs[0].as<Halide::Internal::Select>()) {
debug(5) << "Returning Select root table for type " << print_types(types) << "\n";
root = RootExpr::Select;
} else if (exprs[0].as<Halide::Internal::And>()) {
debug(5) << "Returning And root table for type " << print_types(types) << "\n";
root = RootExpr::And;
} else if (exprs[0].as<Halide::Internal::Or>()) {
debug(5) << "Returning Or root table for type " << print_types(types) << "\n";
root = RootExpr::Or;
} else if (exprs[0].as<Halide::Internal::Cast>()) {
debug(5) << "Returning Cast root table for type " << print_types(types) << "\n";
root = RootExpr::Cast;
}

if (root != RootExpr::Unknown) {
{
// get_ops_table_helper() lazily initializes the table, so ensure
// that multiple threads can't try to do so at the same time.
static std::mutex ops_table_lock;
std::lock_guard<std::mutex> lock_guard(ops_table_lock);

const vector<AssociativePattern> &table = get_ops_table_helper(types, root, exprs.size());
const vector<AssociativePattern> &table = get_ops_table_helper(types, exprs[0].node_type(), exprs.size());
debug(7) << "Table size: " << table.size() << "\n";
for (const auto &p : table) {
debug(7) << p;
Expand Down
45 changes: 14 additions & 31 deletions src/Associativity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,37 +543,20 @@ void associativity_test() {
Expr x_idx = Variable::make(Int(32), "x_idx");
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);

// f(x) = uint8(uint16(x + y), 255)
check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), make_const(t, 255)))},
AssociativeOp(
AssociativePattern(Cast::make(UInt(8), min(Cast::make(UInt(16), x + y), make_const(t, 255))), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = uint8(uint16(x + y), uint16(255))
check_associativity("f", {x_idx}, {Cast::make(UInt(8), min(Cast::make(UInt(16), y + f_call_0), Cast::make(UInt(16), make_const(t, 255))))},
AssociativeOp(
AssociativePattern(Cast::make(UInt(8), min(Cast::make(UInt(16), x + y), make_const(t, 255))), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = select(x > 255 - y, 255, y)
check_associativity("f", {x_idx}, {select(f_call_0 > make_const(t, 255) - y, make_const(t, 255), y)},
AssociativeOp(
AssociativePattern(select(x > make_const(t, 255) - y, make_const(t, 255), y), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));

// f(x) = select(x >= -y, 255, y)
check_associativity("f", {x_idx}, {select(f_call_0 >= -y, make_const(t, 255), y)},
AssociativeOp(
AssociativePattern(select(x < -y, y, make_const(t, 255)), make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));
for (const Expr &e : {cast<uint8_t>(min(cast<uint16_t>(x) + y, 255)),
select(x > 255 - y, cast<uint8_t>(255), y),
select(x < -y, y, cast<uint8_t>(255)),
saturating_add(x, y),
saturating_add(y, x),
saturating_cast<uint8_t>(widening_add(x, y))}) {
check_associativity("f", {x_idx}, {substitute("x", f_call_0, e)},
AssociativeOp(
AssociativePattern(solve_expression(e, "x").result,
make_const(t, 0), true),
{Replacement("x", f_call_0)},
{Replacement("y", y)},
true));
}
}

{
Expand Down
31 changes: 31 additions & 0 deletions src/Solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,37 @@ class SolveExpression : public IRMutator {
// Ignore intrinsics that shouldn't affect the results.
if (Call::as_tag(op)) {
return mutate(op->args[0]);
} else if (op->is_intrinsic({Call::absd, Call::bitwise_and, Call::bitwise_or,
Call::bitwise_xor, Call::halving_add, Call::rounding_halving_add,
Call::saturating_add, Call::widening_add, Call::widening_mul})) {
// It's a commutative intrinsic. We won't try to lift uses of the
// var out of the call, but we will reorder the args if it would
// help.
internal_assert(op->args.size() == 2);
bool old_uses_var = uses_var;
uses_var = false;
bool old_failed = failed;
failed = false;
Expr a = mutate(op->args[0]);
bool a_uses_var = uses_var;
bool a_failed = failed;
uses_var = false;
failed = false;
Expr b = mutate(op->args[1]);
bool b_uses_var = uses_var;
bool b_failed = failed;
uses_var = old_uses_var || a_uses_var || b_uses_var;
failed = old_failed || a_failed || b_failed;

failed |= a_uses_var && b_uses_var;

if (b_uses_var && !a_uses_var) {
return Call::make(op->type, op->name, {b, a}, op->call_type);
} else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) {
return op;
} else {
return Call::make(op->type, op->name, {a, b}, op->call_type);
}
} else {
return IRMutator::visit(op);
}
Expand Down
Loading