Skip to content

Commit

Permalink
Add reshape output for static-mean
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 605399784
  • Loading branch information
alankelly authored and xnnpack-bot committed Feb 8, 2024
1 parent c0fcf0f commit 23f386b
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ extern "C" {
/// Use transient indirection buffer to reduce memory footprint
#define XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER 0x00000020

/// Retain reduced dimensions with length 1.
#define XNN_FLAG_KEEP_DIMS 0x00000040

/// The number of entries in an array of xnn_dynamic_quantization_params that XNNPACK may read beyond array bounds.
/// The caller must allocate at least this many extra xnn_dynamic_quantization_params before passing the array to XNNPACK.
///
Expand Down Expand Up @@ -1155,6 +1158,7 @@ enum xnn_status xnn_define_static_constant_pad(
/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
/// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor.
/// @param flags - binary features of the Mean Node. The only currently supported value is XNN_FLAG_KEEP_DIMS.
enum xnn_status xnn_define_static_mean(
xnn_subgraph_t subgraph,
size_t num_reduction_axes,
Expand Down
53 changes: 51 additions & 2 deletions src/subgraph/static-mean.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,15 @@ static enum xnn_status reshape_mean_operator(
const struct xnn_value* input_value = values + input_id;
assert(input_value->type == xnn_value_type_dense_tensor);

const uint32_t output_id = opdata->outputs[0];
assert(output_id != XNN_INVALID_VALUE_ID);
assert(output_id < num_values);

const size_t old_workspace_size = opdata->workspace_size;
enum xnn_status status = xnn_status_invalid_state;
switch (opdata->operator_objects[0]->type) {
case xnn_operator_type_mean_nd_f16:
return xnn_reshape_mean_nd_f16(
status = xnn_reshape_mean_nd_f16(
opdata->operator_objects[0],
opdata->num_reduction_axes,
opdata->reduction_axes,
Expand All @@ -74,8 +80,9 @@ static enum xnn_status reshape_mean_operator(
&opdata->workspace_size,
&opdata->workspace_alignment,
threadpool);
break;
case xnn_operator_type_mean_nd_f32:
return xnn_reshape_mean_nd_f32(
status = xnn_reshape_mean_nd_f32(
opdata->operator_objects[0],
opdata->num_reduction_axes,
opdata->reduction_axes,
Expand All @@ -84,9 +91,51 @@ static enum xnn_status reshape_mean_operator(
&opdata->workspace_size,
&opdata->workspace_alignment,
threadpool);
break;
default:
XNN_UNREACHABLE;
}
struct xnn_value* output_value = values + output_id;
size_t input_num_dims = input_value->shape.num_dims;
size_t num_reduction_axes = opdata->num_reduction_axes;
if (opdata->operator_objects[0]->flags & XNN_FLAG_KEEP_DIMS) {
output_value->shape.num_dims = input_value->shape.num_dims;
for (size_t idx = 0; idx < input_num_dims; ++idx) {
bool is_axis = false;
for (size_t axis_idx = 0; axis_idx < num_reduction_axes; ++axis_idx) {
if (opdata->reduction_axes[axis_idx] == idx) {
is_axis = true;
break;
}
}
if (is_axis) {
output_value->shape.dim[idx] = 1;
} else {
output_value->shape.dim[idx] = input_value->shape.dim[idx];
}
}
} else {
size_t num_skip_axis = 0;
for (size_t idx = 0; idx < input_num_dims; ++idx) {
bool is_axis = false;
for (size_t axis_idx = 0; axis_idx < num_reduction_axes; ++axis_idx) {
if (opdata->reduction_axes[axis_idx] == idx) {
++num_skip_axis;
is_axis = true;
break;
}
}
if (!is_axis) {
output_value->shape.dim[idx - num_skip_axis] = input_value->shape.dim[idx];
}
}
}
const size_t new_size = xnn_tensor_get_size(output_value);
if (new_size > output_value->size || opdata->workspace_size > old_workspace_size) {
output_value->size = new_size;
return xnn_status_reallocation_required;
}
return status;
}

static enum xnn_status setup_mean_operator(
Expand Down
166 changes: 166 additions & 0 deletions test/static-mean.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,170 @@ TEST_F(MeanTestF32, matches_operator_api)
}
}

// Verifies that reshaping the input of a static-mean node created with
// XNN_FLAG_KEEP_DIMS propagates the new shape to the output: reduced axes
// become 1 while all other dimensions track the (reshaped) input.
TEST_F(MeanTestF32, reshape_output_keep_dims)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(xnn_status_success,
            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_shape.size(), input_shape.data(),
                                    nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(xnn_status_success,
            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_shape.size(), output_shape.data(),
                                    nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success,
            xnn_define_static_mean(
              subgraph,
              reduction_axes.size(), reduction_axes.data(),
              input_id, output_id,
              /*flags=*/XNN_FLAG_KEEP_DIMS));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);

  const std::array<xnn_external_value, 2> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output_id, subgraph_output.data()}
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Grow the input and re-run shape inference on the node.
  input_shape[0] += 2;
  input_shape[1] += 4;
  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_shape.size(), input_shape.data()));
  const struct xnn_node* node = &subgraph->nodes[0];
  // Deduplicate the reduction axes so the verification loop can walk them
  // in sorted order alongside the output dimensions.
  std::vector<size_t> unique_reduction_axes = reduction_axes;
  std::sort(unique_reduction_axes.begin(), unique_reduction_axes.end());
  auto end = std::unique(unique_reduction_axes.begin(), unique_reduction_axes.end());
  unique_reduction_axes.erase(end, unique_reduction_axes.end());
  // There are too many parameters which influence the workspace size so
  // knowing if reallocation is required or not is messy; the return status
  // of this first reshape is deliberately not checked.
  node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr);
  const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape;
  // Check every output dimension: reduced axes must be 1, all others must
  // equal the corresponding input dimension. The bounds check on
  // `current_axes` (instead of breaking out early) ensures dimensions after
  // the last reduction axis are verified too.
  size_t current_axes = 0;
  for (size_t i = 0; i < output_shape->num_dims; ++i) {
    if (current_axes < unique_reduction_axes.size() && unique_reduction_axes[current_axes] == i) {
      ASSERT_EQ(output_shape->dim[i], 1);
      ++current_axes;
    } else {
      ASSERT_EQ(output_shape->dim[i], input_shape[i]);
    }
  }

  // Shrink the input; no reallocation should be needed, so reshape must
  // report success, and the output shape must again track the input.
  input_shape[0] -= 1;
  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_shape.size(), input_shape.data()));
  ASSERT_EQ(node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr), xnn_status_success);
  current_axes = 0;
  for (size_t i = 0; i < output_shape->num_dims; ++i) {
    if (current_axes < unique_reduction_axes.size() && unique_reduction_axes[current_axes] == i) {
      ASSERT_EQ(output_shape->dim[i], 1);
      ++current_axes;
    } else {
      ASSERT_EQ(output_shape->dim[i], input_shape[i]);
    }
  }
}

// Verifies that reshaping the input of a static-mean node created WITHOUT
// XNN_FLAG_KEEP_DIMS propagates the new shape to the output: reduced axes
// are dropped and the remaining dimensions track the (reshaped) input.
TEST_F(MeanTestF32, reshape_output_no_keep_dims)
{
  ASSERT_EQ(xnn_status_success, xnn_initialize(/*allocator=*/nullptr));

  // Call subgraph API.
  xnn_subgraph_t subgraph = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_subgraph(2, /*flags=*/0, &subgraph));
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  uint32_t input_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(xnn_status_success,
            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, input_shape.size(), input_shape.data(),
                                    nullptr, /*external_id=*/0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id));
  ASSERT_NE(input_id, XNN_INVALID_NODE_ID);

  uint32_t output_id = XNN_INVALID_NODE_ID;
  ASSERT_EQ(xnn_status_success,
            xnn_define_tensor_value(subgraph, xnn_datatype_fp32, output_shape.size(), output_shape.data(),
                                    nullptr, /*external_id=*/1, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id));
  ASSERT_NE(output_id, XNN_INVALID_NODE_ID);

  ASSERT_EQ(xnn_status_success,
            xnn_define_static_mean(
              subgraph,
              reduction_axes.size(), reduction_axes.data(),
              input_id, output_id,
              /*flags=*/0));

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v3(subgraph, nullptr, nullptr, /*flags=*/0, &runtime));
  ASSERT_NE(nullptr, runtime);
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);

  const std::array<xnn_external_value, 2> external = {
    xnn_external_value{input_id, input.data()},
    xnn_external_value{output_id, subgraph_output.data()}
  };
  ASSERT_EQ(xnn_status_success, xnn_setup_runtime(runtime, external.size(), external.data()));
  ASSERT_EQ(xnn_status_success, xnn_invoke_runtime(runtime));

  // Grow the input and re-run shape inference on the node.
  input_shape[0] += 2;
  input_shape[1] += 4;
  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_shape.size(), input_shape.data()));
  const struct xnn_node* node = &subgraph->nodes[0];
  // Deduplicate the reduction axes so the verification loop can walk them
  // in sorted order alongside the input dimensions.
  std::vector<size_t> unique_reduction_axes = reduction_axes;
  std::sort(unique_reduction_axes.begin(), unique_reduction_axes.end());
  auto end = std::unique(unique_reduction_axes.begin(), unique_reduction_axes.end());
  unique_reduction_axes.erase(end, unique_reduction_axes.end());
  // There are too many parameters which influence the workspace size so
  // knowing if reallocation is required or not is messy; the return status
  // of this first reshape is deliberately not checked.
  node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr);
  const xnn_shape* output_shape = &runtime->values[node->outputs[0]].shape;
  // Walk the input dimensions, skipping reduced axes; every surviving
  // dimension must appear, in order, in the output. The bounds check on
  // `current_axes` (instead of breaking out early) ensures dimensions after
  // the last reduction axis are verified too.
  size_t current_axes = 0;
  size_t current_dim = 0;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    if (current_axes < unique_reduction_axes.size() && unique_reduction_axes[current_axes] == i) {
      ++current_axes;
    } else {
      ASSERT_EQ(output_shape->dim[current_dim], input_shape[i]);
      ++current_dim;
    }
  }

  // Shrink the input; no reallocation should be needed, so reshape must
  // report success, and the output shape must again track the input.
  input_shape[0] -= 1;
  ASSERT_EQ(xnn_status_success, xnn_reshape_external_value(runtime, input_id, input_shape.size(), input_shape.data()));
  ASSERT_EQ(node->reshape(&runtime->opdata[0], runtime->values, runtime->num_values, /*threadpool=*/nullptr), xnn_status_success);
  current_axes = 0;
  current_dim = 0;
  for (size_t i = 0; i < input_shape.size(); ++i) {
    if (current_axes < unique_reduction_axes.size() && unique_reduction_axes[current_axes] == i) {
      ++current_axes;
    } else {
      ASSERT_EQ(output_shape->dim[current_dim], input_shape[i]);
      ++current_dim;
    }
  }
}

} // namespace xnnpack

0 comments on commit 23f386b

Please sign in to comment.