Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype of multiple scattering update definitions #5553

Merged
merged 19 commits into from Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/Bounds.cpp
Expand Up @@ -1425,6 +1425,14 @@ class Bounds : public IRVisitor {
} else if (op->is_intrinsic(Call::memoize_expr)) {
internal_assert(!op->args.empty());
op->args[0].accept(this);
} else if (op->is_intrinsic(Call::scatter_gather)) {
// Take the union of the args
Interval result = Interval::nothing();
for (const Expr &e : op->args) {
e.accept(this);
result.include(interval);
}
interval = result;
} else if (op->call_type == Call::Halide) {
bounds_of_func(op->name, op->value_index, op->type);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/CSE.cpp
Expand Up @@ -314,7 +314,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
// Wrap the final expr in the lets.
for (size_t i = lets.size(); i > 0; i--) {
Expr value = lets[i - 1].second;
// Drop this variable as an acceptible replacement for this expr.
// Drop this variable as an acceptable replacement for this expr.
replacer.erase(value);
// Use containing lets in the value.
value = replacer.mutate(lets[i - 1].second);
Expand Down
1 change: 1 addition & 0 deletions src/IR.cpp
Expand Up @@ -626,6 +626,7 @@ const char *const intrinsic_op_names[] = {
"require_mask",
"return_second",
"rewrite_buffer",
"scatter_gather",
"select_mask",
"shift_left",
"shift_right",
Expand Down
1 change: 1 addition & 0 deletions src/IR.h
Expand Up @@ -538,6 +538,7 @@ struct Call : public ExprNode<Call> {
require_mask,
return_second,
rewrite_buffer,
scatter_gather,
select_mask,
shift_left,
shift_right,
Expand Down
19 changes: 19 additions & 0 deletions src/IROperator.cpp
Expand Up @@ -2361,4 +2361,23 @@ Expr undef(Type t) {
Internal::Call::PureIntrinsic);
}

namespace {
Expr make_scatter_gather(const std::vector<Expr> &args) {
// There's currently no difference in the IR between a gather and
// a scatter. They're distinct just to make code more readable.
return Halide::Internal::Call::make(args[0].type(),
Halide::Internal::Call::scatter_gather,
args,
Halide::Internal::Call::PureIntrinsic);
}
} // namespace

Expr scatter(const std::vector<Expr> &args) {
return make_scatter_gather(args);
}

Expr gather(const std::vector<Expr> &args) {
return make_scatter_gather(args);
}

} // namespace Halide
79 changes: 79 additions & 0 deletions src/IROperator.h
Expand Up @@ -1398,6 +1398,85 @@ namespace Internal {
Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max);
} // namespace Internal

/** Scatter and gather are used for update definition which must store
* multiple values to distinct locations at the same time. The
* multiple expressions on the right-hand-side are bundled together
* into a "gather", which must match a "scatter" the the same number
* of arguments on the left-hand-size. For example, to store the
abadams marked this conversation as resolved.
Show resolved Hide resolved
* values 1 and 2 to the locations (x, y, 3) and (x, y, 4),
* respectively:
*
\code
f(x, y, scatter(3, 4)) = gather(1, 2);
\endcode
*
* The result of gather or scatter can be treated as an
* expression. Any containing operations on it can be assumed to
* distribute over the elements. If two gather expressions are
* combined with an arithmetic operator (e.g. added), they combine
* element-wise. The following example stores the values 2 * x, 2 * y,
* and 2 * c to the locations (x + 1, y, c), (x, y + 3, c), and (x, y,
* c + 2) respectively:
*
\code
f(x + scatter(1, 0, 0), y + scatter(0, 3, 0), c + scatter(0, 0, 2)) = 2 * gather(x, y, c);
\endcode
*
* Repeated values in the scatter cause multiple stores to the same
* location. The stores happen in order from left to right, so the
* rightmost value wins. The following code is equivalent to f(x) = 5
*
\code
f(scatter(x, x)) = gather(3, 5);
\endcode
*
* Gathers are most useful for algorithms which require in-place
* swapping or permutation of multiple elements, or other kinds of
* in-place mutations that require loading multiple inputs, doing some
* operations to them jointly, then storing them again. The following
* update definition swaps the values of f at locations 3 and 5 if an
* input parameter p is true:
*
\code
f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5)));
\endcode
*
* For more examples of the use of scatter and gather, see
* test/correctness/multiple_scatter.cpp
*
* It is not currently possible to use scatter and gather to write an
* update definition in which the *number* of values loaded or stored
* varies, as the size of the scatter/gather packet must be fixed a
* compile-time. A workaround is to make the unwanted extra operations
* a redundant copy of the last operation, which will be
* dead-code-eliminated by the compiler. For example, the following
abadams marked this conversation as resolved.
Show resolved Hide resolved
* update definition swaps the values at locations 3 and 5 when the
* parameter p is true, and rotates the values at locations 1, 2, and 3
* when it is false. The load from 3 and store to 5 will be redundantly
* repeated:
*
\code
f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1)));
\endcode
*
* Note that in the p == true case, we redudantly load from 3 and write
* to 5 twice.
*/
//@{
Expr scatter(const std::vector<Expr> &args);
Expr gather(const std::vector<Expr> &args);

template<typename... Args>
Expr scatter(const Expr &e, Args &&... args) {
return scatter({e, std::forward<Args>(args)...});
}

template<typename... Args>
Expr gather(const Expr &e, Args &&... args) {
return gather({e, std::forward<Args>(args)...});
}
// @}

} // namespace Halide

#endif
1 change: 1 addition & 0 deletions src/Simplify_And.cpp
Expand Up @@ -57,6 +57,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) {
rewrite(!x && x, false) ||
rewrite(y <= x && x < y, false) ||
rewrite(x != c0 && x == c1, b, c0 != c1) ||
rewrite(x == c0 && x == c1, false, c0 != c1) ||
// Note: In the predicate below, if undefined overflow
// occurs, the predicate counts as false. If well-defined
// overflow occurs, the condition couldn't possibly
Expand Down
42 changes: 41 additions & 1 deletion src/Simplify_Stmts.cpp
@@ -1,5 +1,6 @@
#include "Simplify_Internal.h"

#include "ExprUsesVar.h"
#include "IRMutator.h"
#include "Substitute.h"

Expand Down Expand Up @@ -388,6 +389,9 @@ Stmt Simplify::visit(const Block *op) {
rest.as<IfThenElse>() ? rest.as<IfThenElse>() : (block_rest ? block_rest->first.as<IfThenElse>() : nullptr);
Stmt if_rest = block_rest ? block_rest->rest : Stmt();

const Store *store_first = first.as<Store>();
const Store *store_next = block_rest ? block_rest->first.as<Store>() : rest.as<Store>();

if (is_no_op(first) &&
is_no_op(rest)) {
return Evaluate::make(0);
Expand All @@ -411,11 +415,28 @@ Stmt Simplify::visit(const Block *op) {
new_block = substitute(let_rest->name, new_var, new_block);

return LetStmt::make(var_name, let_first->value, new_block);
} else if (store_first &&
store_next &&
store_first->name == store_next->name &&
equal(store_first->index, store_next->index) &&
equal(store_first->predicate, store_next->predicate) &&
is_pure(store_first->index) &&
is_pure(store_first->value) &&
is_pure(store_first->predicate) &&
!expr_uses_var(store_next->index, store_next->name) &&
!expr_uses_var(store_next->value, store_next->name) &&
!expr_uses_var(store_next->predicate, store_next->name)) {
// Second store clobbers first
if (block_rest) {
return Block::make(store_next, block_rest->rest);
} else {
return store_next;
}
} else if (if_first &&
if_next &&
equal(if_first->condition, if_next->condition) &&
is_pure(if_first->condition)) {
// Two ifs with matching conditions
// Two ifs with matching conditions.
Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case));
Stmt else_case;
if (if_first->else_case.defined() && if_next->else_case.defined()) {
Expand Down Expand Up @@ -448,6 +469,25 @@ Stmt Simplify::visit(const Block *op) {
result = Block::make(result, if_rest);
}
return result;
} else if (if_first &&
alexreinking marked this conversation as resolved.
Show resolved Hide resolved
if_next &&
is_pure(if_first->condition) &&
is_pure(if_next->condition) &&
is_const_one(mutate(!(if_first->condition && if_next->condition), nullptr))) {
// Two ifs where the first condition being true implies the
// second is false. The second if can be nested inside the
// else case of the first one, turning a block of if
// statements into an if-else chain.
Stmt then_case = if_first->then_case;
Stmt else_case = if_next;
if (if_first->else_case.defined()) {
else_case = Block::make(if_first->else_case, else_case);
}
Stmt result = mutate(IfThenElse::make(if_first->condition, then_case, else_case));
if (if_rest.defined()) {
result = Block::make(result, if_rest);
}
return result;
} else if (op->first.same_as(first) &&
op->rest.same_as(rest)) {
return op;
Expand Down