Skip to content

Commit

Permalink
Merge pull request halide#4452 from benoitsteiner/master
Browse files Browse the repository at this point in the history
Added a test to cover the featurization of reductions
  • Loading branch information
steven-johnson committed Dec 6, 2019
2 parents f667032 + eb83cc9 commit e54e455
Showing 1 changed file with 95 additions and 7 deletions.
102 changes: 95 additions & 7 deletions apps/autoscheduler/test_function_dag.cpp
Expand Up @@ -4,7 +4,7 @@

using namespace Halide;

extern "C" int generate_output_vals(
extern "C" int mul_by_two(
halide_buffer_t *input,
halide_buffer_t *output) {
if (input->is_bounds_query()) {
Expand All @@ -31,11 +31,7 @@ extern "C" int generate_output_vals(
return 0;
}

int main(int argc, char **argv) {
// Use a fixed target for the analysis to get consistent results from this test.
MachineParams params(32, 16000000, 40);
Target target("x86-64-linux-sse41-avx-avx2");

void test_coeff_wise(const MachineParams &params, const Target &target) {
Var x("x"), y("y");

std::ostringstream with_extern;
Expand All @@ -47,7 +43,7 @@ int main(int argc, char **argv) {
std::vector<Var> vars = {x, y};
Halide::Type input_type = Halide::Float(32);
g.define_extern(
"generate_output_vals",
"mul_by_two",
{arg},
input_type,
vars,
Expand Down Expand Up @@ -81,6 +77,98 @@ int main(int argc, char **argv) {

// Disabled for now: there is still work to do to populate the jacobian
//assert(with_extern.str() == without_extern.str());
}

extern "C" int matmul(
halide_buffer_t *input1,
halide_buffer_t *input2,
halide_buffer_t *output) {
if (input1->is_bounds_query() || input2->is_bounds_query()) {
// Bounds query: infer the input dimensions from the output dimensions.
// We leave the k dimension alone since we can't infer it from the output dimensions.
input1->dim[0].min = output->dim[0].min;
input1->dim[0].extent = output->dim[0].extent;
input2->dim[1].min = output->dim[1].min;
input2->dim[1].extent = output->dim[1].extent;
return 0;
}

// Actual computation: return input1 * input2.
const int max_i = output->dim[0].min + output->dim[0].extent;
const int max_j = output->dim[1].min + output->dim[1].extent;
for (int i = output->dim[0].min; i < max_i; ++i) {
for (int j = output->dim[1].min; j < max_j; ++j) {
int pos[2] = {i, j};
float *out = (float *)output->address_of(pos);
*out = 0.0f;
for (int k = 0; k < input1->dim[1].extent; ++k) {
int pos1[2] = {i, k};
float *in1 = (float *)input1->address_of(pos1);
int pos2[2] = {k, j};
float *in2 = (float *)input2->address_of(pos2);
(*out) += (*in1) * (*in2);
}
}
}
return 0;
}

void test_matmul(const MachineParams &params, const Target &target) {
Var x("x"), y("y"), k("k");
RDom r(0, 200);
Halide::Buffer<float> input1(200, 200);
Halide::Buffer<float> input2(200, 200);

std::ostringstream with_extern;
{
Func mm("mm"), h("h");

Halide::ExternFuncArgument arg1 = input1;
Halide::ExternFuncArgument arg2 = input2;
std::vector<Var> vars = {x, y};
Halide::Type input_type = Halide::Float(32);
mm.define_extern(
"matmul",
{arg1, arg2},
{input_type, input_type},
vars,
Halide::NameMangling::C);
mm.function().extern_definition_proxy_expr() = Halide::sum(input1(x, r) * input2(r, y));

h(x, y) = mm(x, y);

h.set_estimate(x, 0, 200).set_estimate(y, 0, 200);
std::vector<Halide::Internal::Function> v;
v.push_back(h.function());
Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target);

d.dump(with_extern);
}
std::ostringstream without_extern;
{
Func mm("mm"), h("h");
mm(x, y) = Halide::sum(input1(x, r) * input2(r, y));
h(x, y) = mm(x, y);

h.set_estimate(x, 0, 200).set_estimate(y, 0, 200);
std::vector<Halide::Internal::Function> v;
v.push_back(h.function());
Halide::Internal::Autoscheduler::FunctionDAG d(v, params, target);

d.dump(without_extern);
}

std::cout << "with_extern:\n " << with_extern.str()
<< "\n\nwithout_extern:\n " << without_extern.str() << std::endl;
}

int main(int argc, char **argv) {
// Use a fixed target for the analysis to get consistent results from this test.
MachineParams params(32, 16000000, 40);
Target target("x86-64-linux-sse41-avx-avx2");

test_coeff_wise(params, target);
test_matmul(params, target);

return 0;
}

0 comments on commit e54e455

Please sign in to comment.