Skip to content

Commit

Permalink
Scan output axes (#1737)
Browse files Browse the repository at this point in the history
* add attribute scan_output_axes to scan op

* add test case and minor fix

* Generate op documentation
  • Loading branch information
gramalingam authored and houseroad committed Jan 17, 2019
1 parent 90920c0 commit ba05f26
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 87 deletions.
18 changes: 14 additions & 4 deletions docs/Changelog.md
Expand Up @@ -9197,6 +9197,9 @@ This version of the operator has been available since version 9 of the default O
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).

Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.

The scan operation returns the final values of the state_variables as well as the
scan_outputs.

Expand All @@ -9211,11 +9214,16 @@ This version of the operator has been available since version 9 of the default O
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.

The optional attribute axes specifies the axis to be scanned for each scan_input.
The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.

The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.

Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
Expand All @@ -9226,7 +9234,7 @@ This version of the operator has been available since version 9 of the default O
Scan <
num_scan_inputs = m,
body = loop-body,
axes = [axis_1, ..., axis_m]
scan_input_axes = [axis_1, ..., axis_m]
> (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:
Expand Down Expand Up @@ -9297,14 +9305,16 @@ This version of the operator has been available since version 9 of the default O
#### Attributes

<dl>
<dt><tt>axes</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the axis to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will be used as the scan axis for every scan_input.</dd>
<dt><tt>body</tt> : graph (required)</dt>
<dd>The graph run each iteration. It has N+M inputs: (loop state variables..., scan_input_elts...). It has N+K outputs: (loop state variables..., scan_output_elts...). Each scan_output is created by concatenating the value of the specified scan_output_elt value at the end of each iteration of the loop. It is an error if the dimensions of these values change across loop iterations.</dd>
<dt><tt>num_scan_inputs</tt> : int (required)</dt>
<dd>An attribute specifying the number of scan_inputs M. </dd>
<dt><tt>scan_input_axes</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the axis to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will be used as the scan axis for every scan_input.</dd>
<dt><tt>scan_input_directions</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the direction to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 indicates reverse direction. If omitted, all scan_input tensors will be scanned in the forward direction.</dd>
<dt><tt>scan_output_axes</tt> : list of ints</dt>
<dd>An optional list of K flags. The i-th element of the list specifies the axis for the i-th scan_output. The scan outputs are accumulated along the specified axis. If omitted, 0 will be used as the scan axis for every scan_output.</dd>
<dt><tt>scan_output_directions</tt> : list of ints</dt>
<dd>An optional list of K flags, one for each scan_output. The i-th element of the list specifies whether the i-th scan_output should be constructed by appending or prepending a new value in each iteration: 0 indicates appending and 1 indicates prepending. If omitted, all scan_output tensors will be produced by appending a value in each iteration.</dd>
</dl>
Expand Down
18 changes: 14 additions & 4 deletions docs/Operators.md
Expand Up @@ -9739,6 +9739,9 @@ for test_name, shape in test_cases.items():
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).

Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.

The scan operation returns the final values of the state_variables as well as the
scan_outputs.

Expand All @@ -9753,11 +9756,16 @@ for test_name, shape in test_cases.items():
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.

The optional attribute axes specifies the axis to be scanned for each scan_input.
The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.

The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.

Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
Expand All @@ -9768,7 +9776,7 @@ for test_name, shape in test_cases.items():
Scan <
num_scan_inputs = m,
body = loop-body,
axes = [axis_1, ..., axis_m]
scan_input_axes = [axis_1, ..., axis_m]
> (init_1, ..., init_n, scan_1, ..., scan_m)

is equivalent to the following pseudo-code:
Expand Down Expand Up @@ -9841,14 +9849,16 @@ Other versions of this operator: <a href="Changelog.md#Scan-8">Scan-8</a>
#### Attributes

<dl>
<dt><tt>axes</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the axis to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will be used as the scan axis for every scan_input.</dd>
<dt><tt>body</tt> : graph (required)</dt>
<dd>The graph run each iteration. It has N+M inputs: (loop state variables..., scan_input_elts...). It has N+K outputs: (loop state variables..., scan_output_elts...). Each scan_output is created by concatenating the value of the specified scan_output_elt value at the end of each iteration of the loop. It is an error if the dimensions of these values change across loop iterations.</dd>
<dt><tt>num_scan_inputs</tt> : int (required)</dt>
<dd>An attribute specifying the number of scan_inputs M. </dd>
<dt><tt>scan_input_axes</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the axis to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will be used as the scan axis for every scan_input.</dd>
<dt><tt>scan_input_directions</tt> : list of ints</dt>
<dd>An optional list of M flags. The i-th element of the list specifies the direction to be scanned for the i-th scan_input tensor: 0 indicates forward direction and 1 indicates reverse direction. If omitted, all scan_input tensors will be scanned in the forward direction.</dd>
<dt><tt>scan_output_axes</tt> : list of ints</dt>
<dd>An optional list of K flags. The i-th element of the list specifies the axis for the i-th scan_output. The scan outputs are accumulated along the specified axis. If omitted, 0 will be used as the scan axis for every scan_output.</dd>
<dt><tt>scan_output_directions</tt> : list of ints</dt>
<dd>An optional list of K flags, one for each scan_output. The i-th element of the list specifies whether the i-th scan_output should be constructed by appending or prepending a new value in each iteration: 0 indicates appending and 1 indicates prepending. If omitted, all scan_output tensors will be produced by appending a value in each iteration.</dd>
</dl>
Expand Down
117 changes: 74 additions & 43 deletions onnx/defs/controlflow/defs.cc
Expand Up @@ -10,18 +10,32 @@ void ScanInferenceFunction(InferenceContext& ctx) {
auto num_scan_inputs =
narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
auto num_loop_state_vars = num_inputs - num_scan_inputs;
auto num_outputs = ctx.getNumOutputs();
auto num_scan_outputs = num_outputs - num_loop_state_vars;

std::vector<int64_t> axes;
bool axes_specified = false;
if (getRepeatedAttribute(ctx, "axes", axes)) {
axes_specified = true;
std::vector<int64_t> axes, output_axes;
if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
if (axes.size() != num_scan_inputs)
fail_shape_inference(
"Number of axes specified (",
"Number of scan input axes specified (",
axes.size(),
") is not equal to number of scan inputs (",
num_scan_inputs,
").");
} else {
axes.insert(axes.end(), num_scan_inputs, 0);
}

if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
if (output_axes.size() != num_scan_outputs)
fail_shape_inference(
"Number of scan output axes specified (",
output_axes.size(),
") is not equal to number of scan outputs (",
num_scan_outputs,
").");
} else {
output_axes.insert(output_axes.end(), num_scan_inputs, 0);
}

std::vector<TypeProto> temporary_type_protos;
Expand Down Expand Up @@ -58,9 +72,7 @@ void ScanInferenceFunction(InferenceContext& ctx) {
// need to remove the sequence length dimensions from the shape.
if (has_shape) {
// remove sequence length dimensions and add to subgraph_input_types
int axis = (axes_specified)
? static_cast<int>(axes[i - num_loop_state_vars])
: 0;
int axis = static_cast<int>(axes[i - num_loop_state_vars]);

// update sequence_len if a value is available
const auto& shape = input_type->tensor_type().shape();
Expand Down Expand Up @@ -101,7 +113,6 @@ void ScanInferenceFunction(InferenceContext& ctx) {

// if empty(), assume inferencing was skipped
if (!output_types.empty()) {
auto num_outputs = ctx.getNumOutputs();
if (output_types.size() != num_outputs) {
fail_type_inference(
"Graph attribute inferencing returned type information for ",
Expand All @@ -115,47 +126,52 @@ void ScanInferenceFunction(InferenceContext& ctx) {
const bool is_loop_state_var = i < num_loop_state_vars;
auto* subgraph_output_type = output_types[i];
auto* scan_output_type = ctx.getOutputType(i);
auto* mutable_scan_output_tensor_type =
scan_output_type->mutable_tensor_type();

if (!subgraph_output_type->has_tensor_type()) {
fail_type_inference(
"Scan 'body' subgraph outputs should all be tensors but output ",
i,
" was not");
}
auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();

// propagate output type. loop state vars were done in the above code.
if (!is_loop_state_var) {
if (is_loop_state_var) {
// merge shape; type already propagated
mergeInShapeInfo(
subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
} else {
scan_output_type->mutable_tensor_type()->set_elem_type(
subgraph_output_type->tensor_type().elem_type());
}

// propagate shape
if (subgraph_output_type->tensor_type().has_shape()) {
// we need to add in sequence length value if
// available before merging with any existing info. Create a copy of
// the inferred type info from the subgraph to do that.
TypeProto inferred_type(*subgraph_output_type);
auto* mutable_inferred_tensor_type =
inferred_type.mutable_tensor_type();
auto* mutable_inferred_shape =
mutable_inferred_tensor_type->mutable_shape();

mutable_inferred_shape->clear_dim();

if (!is_loop_state_var) {
*mutable_inferred_shape->add_dim() = sequence_len_dim;
}

for (const auto& dim :
subgraph_output_type->tensor_type().shape().dim()) {
(*mutable_inferred_shape->add_dim()) = dim;
subgraph_output_tensor_type.elem_type());

// propagate shape
if (subgraph_output_tensor_type.has_shape()) {
// infer shape of scan-output from the shape of scan-output-element
// by adding sequence-length at the correct axis position
const TensorShapeProto& subgraph_output_shape =
subgraph_output_tensor_type.shape();
TensorShapeProto inferred_shape;

int output_axis =
static_cast<int>(output_axes[i - num_loop_state_vars]);
auto subgraph_output_rank = subgraph_output_shape.dim_size();
if (output_axis < 0 || output_axis > subgraph_output_rank)
fail_shape_inference(
"The output axis value ",
output_axis,
"specified is not consistent with the rank of subgraph output ",
subgraph_output_rank);

for (int j = 0; j < output_axis; ++j)
*(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
*(inferred_shape.add_dim()) = sequence_len_dim;
for (int j = output_axis; j < subgraph_output_rank; ++j)
*(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);

// Merge inferred shape with existing shape information
mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
}

auto* mutable_scan_output_tensor_type =
scan_output_type->mutable_tensor_type();

mergeInShapeInfo(
*mutable_inferred_tensor_type, *mutable_scan_output_tensor_type);
}
}
}
Expand Down Expand Up @@ -578,6 +594,9 @@ hidden-state values of RNN-like constructs). All the output tensors (state_varia
well as scan_output_element tensors) are required to have the same shape in each iteration
of the loop (a restriction imposed to enable efficient memory allocation).
Note that the iterated element passed to the body subgraph does not have a sequence
axis. It will have a rank one less than the rank of the corresponding scan_input.
The scan operation returns the final values of the state_variables as well as the
scan_outputs.
Expand All @@ -592,11 +611,16 @@ specifies the direction in which scan_output is constructed (by appending or pre
scan_output_element to scan_output in each iteration) for each scan_output. If this attribute
is omitted, the scan_output_element is appended to the scan_output in each iteration.
The optional attribute axes specifies the axis to be scanned for each scan_input.
The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input.
If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the
batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1.
Note that scanning a non-zero axis may be less efficient than scanning axis zero.
The optional attribute scan_output_axes specifies the axis along which the scan_outputs
are accumulated for each scan_output. For example, if axis 1 is the time axis (to be
scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis
value of 1.
Note that because of the ONNX restriction that only the last parameter of an operator can
be variadic, the initial-states and scan-inputs are listed together as one input parameter.
Similarly, the final-states and scan-outputs are listed together as one output parameter.
Expand All @@ -607,7 +631,7 @@ The behavior of
Scan <
num_scan_inputs = m,
body = loop-body,
axes = [axis_1, ..., axis_m]
scan_input_axes = [axis_1, ..., axis_m]
> (init_1, ..., init_n, scan_1, ..., scan_m)
is equivalent to the following pseudo-code:
Expand Down Expand Up @@ -725,12 +749,19 @@ ONNX_OPERATOR_SET_SCHEMA(
AttributeProto::INTS,
false)
.Attr(
"axes",
"scan_input_axes",
"An optional list of M flags. The i-th element of the list specifies the axis "
"to be scanned (the sequence axis) for the i-th scan_input. If omitted, 0 will "
"be used as the scan axis for every scan_input.",
AttributeProto::INTS,
false)
.Attr(
"scan_output_axes",
"An optional list of K flags. The i-th element of the list specifies the axis "
"for the i-th scan_output. The scan outputs are accumulated along the specified "
"axis. If omitted, 0 will be used as the scan axis for every scan_output.",
AttributeProto::INTS,
false)
.TypeConstraint("I", {"tensor(int64)"}, "Int64 tensor")
.TypeConstraint("V", OpSchema::all_tensor_types(), "All Tensor types")
.TypeAndShapeInferenceFunction(ScanInferenceFunction));
Expand Down

0 comments on commit ba05f26

Please sign in to comment.