Fix several bugs on Hexagon and some cleanup #5570

Merged: 2 commits, Dec 21, 2020

9 changes: 0 additions & 9 deletions src/CodeGen_Hexagon.cpp
@@ -547,15 +547,6 @@ void CodeGen_Hexagon::compile_func(const LoweredFunc &f,
debug(2) << "Lowering after optimizing shuffles:\n"
<< body << "\n\n";

// Generating vtmpy before CSE and align_loads makes it easier to match
// patterns for vtmpy.
#if 0
// TODO(aankit): Re-enable this after fixing complexity issue.
debug(1) << "Generating vtmpy...\n";
body = vtmpy_generator(body);
debug(2) << "Lowering after generating vtmpy:\n" << body << "\n\n";
#endif

debug(1) << "Aligning loads for HVX....\n";
body = align_loads(body, target.natural_vector_size(Int(8)));
body = common_subexpression_elimination(body);
251 changes: 18 additions & 233 deletions src/HexagonOptimize.cpp
@@ -130,6 +130,12 @@ Expr bc(Expr x) {
}

// Helpers to generate horizontally reducing multiply operations.
Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v0, Expr v1, Expr c0, Expr c1) {
Expr call = Call::make(result_type, "halide.hexagon.add_2mpy" + suffix,
{std::move(v0), std::move(v1), std::move(c0), std::move(c1)}, Call::PureExtern);
return native_interleave(call);
}

Expr halide_hexagon_add_2mpy(Type result_type, const string &suffix, Expr v01, Expr c01) {
return Call::make(result_type, "halide.hexagon.add_2mpy" + suffix,
{std::move(v01), std::move(c01)}, Call::PureExtern);
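
The four-operand overload added above builds the vmpa-style form of the intrinsic, keeping the two input vectors and their scalar coefficients separate and interleaving the result; the existing two-operand overload carries operands pre-packed for vdmpy. A rough illustration of how a caller might use the two forms, with v0, v1, c0, c1, v01, c01, and lanes as hypothetical names not taken from the patch:

// Sketch only: v0/v1 are u8 vector Exprs, c0/c1 are i8 scalar Exprs,
// v01/c01 their pre-packed equivalents, lanes the vector width.
// vmpa form: v0*c0 + v1*c1 with separate operands; the helper wraps the
// call in native_interleave() because vmpa produces interleaved output.
Expr vmpa = halide_hexagon_add_2mpy(Int(16, lanes), ".vub.vub.b.b",
                                    v0, v1, c0, c1);
// vdmpy form: both input vectors interleaved into v01, both coefficients
// packed into the scalar c01.
Expr vdmpy = halide_hexagon_add_2mpy(Int(16, lanes), ".vub.b", v01, c01);
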
@@ -741,15 +747,15 @@ class OptimizePatterns : public IRMutator {
// particular order. It should be more robust... but
// this is pretty tough to do, other than simply
// trying all permutations.
Expr b0123 = Shuffle::make_interleave({mpys[0].second, mpys[1].second,
mpys[0].second, mpys[1].second});
b0123 = simplify(b0123);
b0123 = reinterpret(Type(b0123.type().code(), 32, 1), b0123);
bool is_vdmpy = (!a01.as<Shuffle>() || vmpa_suffix.empty());
string suffix = (is_vdmpy) ? vdmpy_suffix : vmpa_suffix;
Expr new_expr = halide_hexagon_add_2mpy(op->type, suffix,
a01, b0123);
new_expr = (!is_vdmpy) ? native_interleave(new_expr) : new_expr;
Expr new_expr;
if (!a01.as<Shuffle>() || vmpa_suffix.empty()) {
Expr b01 = Shuffle::make_interleave({mpys[0].second, mpys[1].second, mpys[0].second, mpys[1].second});
b01 = simplify(b01);
b01 = reinterpret(Type(b01.type().code(), 32, 1), b01);
new_expr = halide_hexagon_add_2mpy(op->type, vdmpy_suffix, a01, b01);
} else {
new_expr = halide_hexagon_add_2mpy(op->type, vmpa_suffix, mpys[0].first, mpys[1].first, mpys[0].second, mpys[1].second);
}
if (rest.defined()) {
new_expr = Add::make(new_expr, rest);
}
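
Both branches above emit the same widening two-way multiply-accumulate; they differ only in operand packing (vdmpy wants the two input vectors interleaved into one, vmpa takes them separately). A scalar model of the per-lane arithmetic, inferred from the instruction behavior rather than stated in the patch:

#include <cstdint>

// Assumed per-lane semantics of the .vub.vub.b.b variant: a widening,
// non-saturating 2-tap dot product (u8 * i8 -> i16).
int16_t add_2mpy_lane(uint8_t v0, uint8_t v1, int8_t c0, int8_t c1) {
    return (int16_t)(v0 * c0 + v1 * c1);  // assumed to wrap, not saturate
}
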
@@ -759,8 +765,8 @@

static const vector<Pattern> adds = {
// Use accumulating versions of vmpa, vdmpy, vrmpy instructions when possible.
{"halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b", wild_i16x + native_interleave(halide_hexagon_add_2mpy(Int(16, 0), ".vub.vub.b.b", wild_u8x, wild_i32)), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b", wild_i32x + native_interleave(halide_hexagon_add_2mpy(Int(32, 0), ".vh.vh.b.b", wild_i16x, wild_i32)), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.vub.b.b", wild_u8x, wild_u8x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.vh.b.b", wild_i16x, wild_i16x, wild_i8, wild_i8), Pattern::ReinterleaveOp0},
{"halide.hexagon.acc_add_2mpy.vh.vub.b", wild_i16x + halide_hexagon_add_2mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i32)},
{"halide.hexagon.acc_add_2mpy.vw.vh.b", wild_i32x + halide_hexagon_add_2mpy(Int(32, 0), ".vh.b", wild_i16x, wild_i32)},
{"halide.hexagon.acc_add_3mpy.vh.vub.b", wild_i16x + halide_hexagon_add_3mpy(Int(16, 0), ".vub.b", wild_u8x, wild_i16), Pattern::ReinterleaveOp0},
@@ -1627,7 +1633,7 @@ class EliminateInterleaves : public IRMutator {
// The let must have been dead.
internal_assert(!stmt_or_expr_uses_var(op->body, op->name))
<< "EliminateInterleaves eliminated a non-dead let.\n";
return NodeType();
return op->body;
}
}
}
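
The "return op->body;" change above fixes a genuine bug: NodeType() default-constructs an Expr or Stmt, which is an undefined handle, so the mutator was returning an undefined IR node whenever it dropped a dead let. A minimal sketch of the corrected path, assuming NodeType is the visitor's Expr/Stmt template parameter (Expr case shown):

// Sketch of the fixed dead-let path: dropping a dead let must yield the
// let's (already mutated) body, not a default-constructed node.
Expr drop_dead_let(const Let *op, const Expr &mutated_body) {
    internal_assert(!stmt_or_expr_uses_var(mutated_body, op->name))
        << "EliminateInterleaves eliminated a non-dead let.\n";
    return mutated_body;  // was `return NodeType();`, an undefined Expr
}
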
@@ -2089,219 +2095,6 @@ class OptimizeShuffles : public IRMutator {
}
};

// Attempt to generate vtmpy instructions. This requires that all lets
// be substituted prior to running, and so must be an IRGraphMutator.
class VtmpyGenerator : public IRGraphMutator {
private:
using IRMutator::visit;
typedef pair<Expr, size_t> LoadIndex;

// Check if vectors a and b point to the same buffer, with the base of a
// shifted by diff, i.e. base(a) = base(b) + diff.
bool is_base_shifted(const Expr &a, const Expr &b, int diff) {
Expr maybe_load_a = calc_load(a);
Expr maybe_load_b = calc_load(b);

if (maybe_load_a.defined() && maybe_load_b.defined()) {
const Load *load_a = maybe_load_a.as<Load>();
const Load *load_b = maybe_load_b.as<Load>();
if (load_a->name == load_b->name) {
Expr base_diff = simplify(load_a->index - load_b->index - diff);
if (is_const(base_diff, 0)) {
return true;
}
}
}
return false;
}

// Return the load expression of the first vector if all vectors in exprs
// are contiguous vectors pointing to the same buffer.
Expr are_contiguous_vectors(const vector<Expr> &exprs) {
if (exprs.empty()) {
return Expr();
}
// If the shuffle simplifies then the vectors are contiguous.
// If not, check if the bases of adjacent vectors differ by
// vector size.
Expr concat = simplify(Shuffle::make_concat(exprs));
const Shuffle *maybe_shuffle = concat.as<Shuffle>();
if (!maybe_shuffle || !maybe_shuffle->is_concat()) {
return calc_load(exprs[0]);
}
return Expr();
}

// Returns the load indicating the vector start index. If the vector is
// sliced, returns the load with its ramp shifted by the slice_begin expr.
Expr calc_load(const Expr &e) {
if (const Cast *maybe_cast = e.as<Cast>()) {
return calc_load(maybe_cast->value);
}
if (const Shuffle *maybe_shuffle = e.as<Shuffle>()) {
if (maybe_shuffle->is_slice() && maybe_shuffle->slice_stride() == 1) {
Expr maybe_load = calc_load(maybe_shuffle->vectors[0]);
if (!maybe_load.defined()) {
return Expr();
}
const Load *res = maybe_load.as<Load>();
Expr shifted_load = Load::make(res->type, res->name, res->index + maybe_shuffle->slice_begin(),
res->image, res->param, res->predicate, ModulusRemainder());
return shifted_load;
} else if (maybe_shuffle->is_concat()) {
return are_contiguous_vectors(maybe_shuffle->vectors);
}
}
if (const Load *maybe_load = e.as<Load>()) {
const Ramp *maybe_ramp = maybe_load->index.as<Ramp>();
if (maybe_ramp && is_const(maybe_ramp->stride, 1)) {
return maybe_load;
}
}
return Expr();
}

// Comparator for sorting Load Exprs of the same buffer.
static bool loads_comparator(const LoadIndex &a, const LoadIndex &b) {
if (a.first.defined() && b.first.defined()) {
const Load *load_a = a.first.as<Load>();
const Load *load_b = b.first.as<Load>();
if (load_a->name == load_b->name) {
Expr base_diff = simplify(load_b->index - load_a->index);
if (is_positive_const(base_diff)) {
return true;
}
} else {
return load_a->name < load_b->name;
}
}
return false;
}

// Vtmpy helps in sliding window ops of the form a*v0 + b*v1 + v2
// (a scalar sketch of this pattern follows the class).
// Conditions required:
// v0, v1 and v2 start indices differ by the vector stride.
// The currently supported stride is 1.
// TODO: Add support for any stride.
Expr visit(const Add *op) override {
// Find opportunities for vtmpy.
if (op && op->type.is_vector() && (op->type.bits() == 16 || op->type.bits() == 32)) {
int lanes = op->type.lanes();
vector<MulExpr> mpys;
Expr rest;
string vtmpy_suffix;

// Finding more than 100 such expressions is rare.
// Setting the limit to 100 makes sure we don't miss anything
// in most cases, while also not spending unreasonable time
// just looking for vtmpy patterns.
const int max_mpy_ops = 100;
if (op->type.bits() == 16) {
find_mpy_ops(op, UInt(8, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vub.h";
if (mpys.size() < 3) {
mpys.clear();
rest = Expr();
find_mpy_ops(op, Int(8, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vb.h";
}
} else if (op->type.bits() == 32) {
find_mpy_ops(op, Int(16, lanes), Int(8), max_mpy_ops, mpys, rest);
vtmpy_suffix = ".vh.h";
}

if (mpys.size() >= 3) {
const size_t mpy_size = mpys.size();
// Used to put loads with different buffers in different buckets.
std::unordered_map<string, vector<LoadIndex>> loads;
// To keep track of indices selected for vtmpy.
std::unordered_map<size_t, bool> vtmpy_indices;
vector<Expr> vtmpy_exprs;
Expr new_expr;

for (size_t i = 0; i < mpy_size; i++) {
Expr curr_load = calc_load(mpys[i].first);
if (curr_load.defined()) {
loads[curr_load.as<Load>()->name].emplace_back(curr_load, i);
} else {
new_expr = new_expr.defined() ? new_expr + curr_load : curr_load;
}
}

for (auto iter = loads.begin(); iter != loads.end(); iter++) {
// Sort the bucket and compare the bases of 3 adjacent vectors
// at a time. If they differ by the vector stride, we've
// found a vtmpy.
// It doesn't seem to be easy to write a comparator that implements a
// strict weak ordering. So we use stable_sort instead of sort; at the
// very least, the relative order of tied elements is not changed.
std::stable_sort(iter->second.begin(), iter->second.end(), loads_comparator);
size_t vec_size = iter->second.size();
for (size_t i = 0; i + 2 < vec_size; i++) {
Expr v0 = iter->second[i].first;
Expr v1 = iter->second[i + 1].first;
Expr v2 = iter->second[i + 2].first;
size_t v0_idx = iter->second[i].second;
size_t v1_idx = iter->second[i + 1].second;
size_t v2_idx = iter->second[i + 2].second;
if (is_const(mpys[v2_idx].second, 1) &&
is_base_shifted(v2, v1, 1) &&
is_base_shifted(v1, v0, 1)) {

vtmpy_indices[v0_idx] = true;
vtmpy_indices[v1_idx] = true;
vtmpy_indices[v2_idx] = true;

Expr dv = Shuffle::make_interleave({mpys[v0_idx].first, mpys[v2_idx].first});
Expr constant = Shuffle::make_interleave({mpys[v0_idx].second, mpys[v1_idx].second});
Expr new_expr = halide_hexagon_add_3mpy(op->type, vtmpy_suffix,
dv, constant);

vtmpy_exprs.emplace_back(new_expr);
// Skip ahead, as we cannot use the same indices again.
i = i + 2;
}
}
}
// If we found any vtmpy's then recombine Expr using
// vtmpy_expr, non_vtmpy_exprs and rest.
if (!vtmpy_exprs.empty()) {
for (size_t i = 0; i < mpy_size; i++) {
if (vtmpy_indices[i]) {
continue;
}
// We put expressions in mpys after un-broadcasting them. So, first broadcast
// then call lossless_cast.
Expr a = mpys[i].first;
Expr b = mpys[i].second;
int lanes = op->type.lanes();

if (a.type().is_scalar()) {
a = Broadcast::make(a, lanes);
}
if (b.type().is_scalar()) {
b = Broadcast::make(b, lanes);
}

Expr mpy_a = lossless_cast(op->type, a);
Expr mpy_b = lossless_cast(op->type, b);
Expr mpy_res = mpy_a * mpy_b;
new_expr = new_expr.defined() ? new_expr + mpy_res : mpy_res;
}
for (size_t i = 0; i < vtmpy_exprs.size(); i++) {
new_expr = new_expr.defined() ? new_expr + vtmpy_exprs[i] : vtmpy_exprs[i];
}
if (rest.defined()) {
new_expr = new_expr + rest;
}
return mutate(new_expr);
}
}
}
return IRMutator::visit(op);
}
};
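
For anyone skimming this removal: VtmpyGenerator searched for the 3-tap sliding-window pattern that the HVX vtmpy instruction implements, in which three overlapping unit-stride loads of the same buffer are weighted and summed with the last coefficient fixed at 1. A scalar sketch of that pattern, reconstructed from the deleted comments (the pass was already compiled out under #if 0 in CodeGen_Hexagon.cpp):

#include <cstdint>

// Scalar model of the op vtmpy accelerates, per the comments above:
// out[i] = a*in[i] + b*in[i+1] + in[i+2], with loads stepping by 1.
int16_t vtmpy_lane(const uint8_t *in, int i, int8_t a, int8_t b) {
    return (int16_t)(a * in[i] + b * in[i + 1] + in[i + 2]);
}
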

// Convert some expressions to an equivalent form which could get better
// optimized in later stages for hexagon
class RearrangeExpressions : public IRMutator {
@@ -2602,14 +2395,6 @@ Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment) {
return OptimizeShuffles(lut_alignment).mutate(s);
}

Stmt vtmpy_generator(Stmt s) {
// Generate vtmpy instruction if possible
s = substitute_in_all_lets(s);
s = VtmpyGenerator().mutate(s);
s = common_subexpression_elimination(s);
return s;
}

Stmt scatter_gather_generator(Stmt s) {
// Generate vscatter-vgather instruction if target >= v65
s = substitute_in_all_lets(s);
3 changes: 0 additions & 3 deletions src/HexagonOptimize.h
@@ -15,9 +15,6 @@ namespace Internal {
* calls. */
Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment);

/** Generate vtmpy instruction if possible */
Stmt vtmpy_generator(Stmt s);

/* Generate vscatter-vgather instructions on Hexagon using VTCM memory.
* The pass should be run before generating shuffles.
* Some expressions which generate vscatter-vgathers are:
44 changes: 44 additions & 0 deletions src/runtime/hvx_128.ll
@@ -324,6 +324,50 @@ define weak_odr <128 x i8> @halide.hexagon.shr.vb.vb(<128 x i8> %a, <128 x i8> %
ret <128 x i8> %r
}

declare <64 x i32> @llvm.hexagon.V6.vmpabus.128B(<64 x i32>, i32)
declare <64 x i32> @llvm.hexagon.V6.vmpabus.acc.128B(<64 x i32>, <64 x i32>, i32)

define weak_odr <128 x i16> @halide.hexagon.add_2mpy.vub.vub.b.b(<128 x i8> %low_v, <128 x i8> %high_v, i8 %low_c, i8 %high_c) nounwind uwtable readnone {
%const = call i32 @halide.hexagon.interleave.b.dup2.h(i8 %low_c, i8 %high_c)
%low = bitcast <128 x i8> %low_v to <32 x i32>
%high = bitcast <128 x i8> %high_v to <32 x i32>
%dv = call <64 x i32> @llvm.hexagon.V6.vcombine.128B(<32 x i32> %high, <32 x i32> %low)
%res = call <64 x i32> @llvm.hexagon.V6.vmpabus.128B(<64 x i32> %dv, i32 %const)
%ret_val = bitcast <64 x i32> %res to <128 x i16>
ret <128 x i16> %ret_val
}

define weak_odr <128 x i16> @halide.hexagon.acc_add_2mpy.vh.vub.vub.b.b(<128 x i16> %acc, <128 x i8> %low_v, <128 x i8> %high_v, i8 %low_c, i8 %high_c) nounwind uwtable readnone {
%dv0 = bitcast <128 x i16> %acc to <64 x i32>
%const = call i32 @halide.hexagon.interleave.b.dup2.h(i8 %low_c, i8 %high_c)
%low = bitcast <128 x i8> %low_v to <32 x i32>
%high = bitcast <128 x i8> %high_v to <32 x i32>
%dv1 = call <64 x i32> @llvm.hexagon.V6.vcombine.128B(<32 x i32> %high, <32 x i32> %low)
%res = call <64 x i32> @llvm.hexagon.V6.vmpabus.acc.128B(<64 x i32> %dv0, <64 x i32> %dv1, i32 %const)
%ret_val = bitcast <64 x i32> %res to <128 x i16>
ret <128 x i16> %ret_val
}

declare <64 x i32> @llvm.hexagon.V6.vmpahb.128B(<64 x i32>, i32)
declare <64 x i32> @llvm.hexagon.V6.vmpahb.acc.128B(<64 x i32>, <64 x i32>, i32)

define weak_odr <64 x i32> @halide.hexagon.add_2mpy.vh.vh.b.b(<64 x i16> %low_v, <64 x i16> %high_v, i8 %low_c, i8 %high_c) nounwind uwtable readnone {
%const = call i32 @halide.hexagon.interleave.b.dup2.h(i8 %low_c, i8 %high_c)
%low = bitcast <64 x i16> %low_v to <32 x i32>
%high = bitcast <64 x i16> %high_v to <32 x i32>
%dv = call <64 x i32> @llvm.hexagon.V6.vcombine.128B(<32 x i32> %high, <32 x i32> %low)
%res = call <64 x i32> @llvm.hexagon.V6.vmpahb.128B(<64 x i32> %dv, i32 %const)
ret <64 x i32> %res
}

define weak_odr <64 x i32> @halide.hexagon.acc_add_2mpy.vw.vh.vh.b.b(<64 x i32> %acc, <64 x i16> %low_v, <64 x i16> %high_v, i8 %low_c, i8 %high_c) nounwind uwtable readnone {
%const = call i32 @halide.hexagon.interleave.b.dup2.h(i8 %low_c, i8 %high_c)
%low = bitcast <64 x i16> %low_v to <32 x i32>
%high = bitcast <64 x i16> %high_v to <32 x i32>
%dv1 = call <64 x i32> @llvm.hexagon.V6.vcombine.128B(<32 x i32> %high, <32 x i32> %low)
%res = call <64 x i32> @llvm.hexagon.V6.vmpahb.acc.128B(<64 x i32> %acc, <64 x i32> %dv1, i32 %const)
ret <64 x i32> %res
}
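
A note on the shape of these four wrappers: each packs the two i8 coefficients into one i32 via interleave.b.dup2.h, fuses the low and high input vectors into an HVX double vector with vcombine, and issues a single vmpa (or accumulating vmpa) intrinsic. A sketch of the coefficient packing, with the byte order assumed from the dup2 name rather than verified against that helper's definition:

#include <cstdint>

// Assumed layout: the coefficient pair is duplicated across the i32 as
// (low_c, high_c, low_c, high_c) in ascending byte order, so each 16-bit
// lane of the vmpa sees the same (low_c, high_c) pair.
int32_t pack_dup2(int8_t low_c, int8_t high_c) {
    uint32_t lo = (uint8_t)low_c;
    uint32_t hi = (uint8_t)high_c;
    return (int32_t)(lo | (hi << 8) | (lo << 16) | (hi << 24));
}
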

; Define a missing saturating narrow instruction in terms of a saturating narrowing shift.
declare <32 x i32> @llvm.hexagon.V6.vasrwuhsat.128B(<32 x i32>, <32 x i32>, i32)