Merge branch 'master' into fork2onnx
linkerzhang committed May 14, 2018
2 parents 657b4c9 + 330fd0f commit 5d9ffa1
Showing 5 changed files with 59 additions and 17 deletions.
2 changes: 0 additions & 2 deletions docs/Changelog.md
@@ -1326,8 +1326,6 @@ This version of the operator has been available since version 1 of the default ONNX operator set.
<dd>Specify if the RNN is forward, reverse, or bidirectional. Must be one of forward (default), reverse, or bidirectional.</dd>
<dt><tt>hidden_size</tt> : int</dt>
<dd>Number of neurons in the hidden layer</dd>
<dt><tt>linear_before_reset</tt> : int</dt>
<dd>When computing the output of the hidden gate, apply the linear transformation before multiplying by the output of the reset gate.</dd>
<dt><tt>output_sequence</tt> : int</dt>
<dd>The sequence output for the hidden is optional if 0. Default 0.</dd>
</dl>
41 changes: 34 additions & 7 deletions onnx/defs/math/defs.cc
@@ -784,7 +784,28 @@ Given two equivalent values, this operator uses the indices along the axis as a tiebreaker.
"Dimension on which to do the sort. Default -1, which indicates the last"
" axis",
AttributeProto::INT,
static_cast<int64_t>(-1));
static_cast<int64_t>(-1))
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
// Type inference:
propagateElemTypeFromInputToOutput(ctx, 0, 0);
updateOutputElemType(ctx, 1, TensorProto::INT64);

// Shape inference:
if (!hasInputShape(ctx, 0))
return;
auto& input_shape = getInputShape(ctx, 0);
int64_t rank = input_shape.dim_size();
int64_t axis = getAttribute(ctx, "axis", -1);
if (axis < 0) axis += rank;
if (axis < 0 || axis >= rank) return; // erroneous attribute value
int64_t k = getAttribute(ctx, "k", -1);
if (k <= 0) return; // erroneous attribute value
// TODO: unclear what results should be if axis has less than k elements.
TensorShapeProto result_shape = input_shape;
result_shape.mutable_dim(static_cast<int>(axis))->set_dim_value(k);
updateOutputShape(ctx, 0, result_shape);
updateOutputShape(ctx, 1, result_shape);
});
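A minimal sketch of exercising this new TopK inference from Python — assuming the onnx package with the opset pinned to 9, where TopK still takes k as an attribute; the graph and names (x, y, z, topk_example) are illustrative only, and the expected shapes mirror the tests added to shape_inference_test.py below. The TopK outputs are routed through Identity so they appear as intermediate values, which infer_shapes reports in value_info:

from onnx import TensorProto, helper, shape_inference

# TopK with k=2 along axis 2 on a (3, 4, 5, 10) input.
topk = helper.make_node('TopK', ['x'], ['y', 'z'], k=2, axis=2)
ident_y = helper.make_node('Identity', ['y'], ['y_out'])
ident_z = helper.make_node('Identity', ['z'], ['z_out'])
graph = helper.make_graph(
    [topk, ident_y, ident_z], 'topk_example',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, (3, 4, 5, 10))],
    [helper.make_tensor_value_info('y_out', TensorProto.FLOAT, None),
     helper.make_tensor_value_info('z_out', TensorProto.INT64, None)])
model = helper.make_model(
    graph, opset_imports=[helper.make_operatorsetid('', 9)])
inferred = shape_inference.infer_shapes(model)
for vi in inferred.graph.value_info:
    print(vi.name, vi.type.tensor_type.elem_type,
          [d.dim_value for d in vi.type.tensor_type.shape.dim])
# Expected: y -> float, [3, 4, 2, 10]; z -> int64, [3, 4, 2, 10]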

ONNX_OPERATOR_SCHEMA(Sin)
.SinceVersion(7)
@@ -801,7 +822,8 @@ Calculates the sine of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);

ONNX_OPERATOR_SCHEMA(Cos)
.SinceVersion(7)
@@ -818,7 +840,8 @@ Calculates the cosine of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);


ONNX_OPERATOR_SCHEMA(Tan)
@@ -836,7 +859,8 @@ Calculates the tangent of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);

ONNX_OPERATOR_SCHEMA(Asin)
.SinceVersion(7)
@@ -853,7 +877,8 @@ Calculates the arcsine (inverse of sine) of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);

ONNX_OPERATOR_SCHEMA(Acos)
.SinceVersion(7)
@@ -870,7 +895,8 @@ Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);


ONNX_OPERATOR_SCHEMA(Atan)
@@ -888,4 +914,5 @@ Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise.
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.");
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput);
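All six trigonometric schemas above now share propagateShapeAndTypeFromFirstInput, which copies the first input's element type and shape to the output. A minimal sketch of the effect, under the same assumptions as above (onnx Python package, pinned opset 9, illustrative names):

from onnx import TensorProto, helper, shape_inference

sin = helper.make_node('Sin', ['x'], ['s'])
ident = helper.make_node('Identity', ['s'], ['s_out'])
graph = helper.make_graph(
    [sin, ident], 'sin_example',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, (2, 3))],
    [helper.make_tensor_value_info('s_out', TensorProto.FLOAT, None)])
model = helper.make_model(
    graph, opset_imports=[helper.make_operatorsetid('', 9)])
inferred = shape_inference.infer_shapes(model)
print(inferred.graph.value_info[0])  # 's' inferred as float with shape (2, 3)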
7 changes: 0 additions & 7 deletions onnx/defs/rnn/old.cc
@@ -171,13 +171,6 @@ Equations (Default: f=Sigmoid, g=Tanh):
"for default if not specified.",
AttributeProto::STRINGS,
OPTIONAL)
.Attr(
"linear_before_reset",
"When computing the output of the hidden gate, "
"apply the linear transformation before multiplying by the output of the "
"reset gate.",
AttributeProto::INT,
OPTIONAL)
.Input(
1,
"W",
8 changes: 7 additions & 1 deletion onnx/defs/tensor/defs.cc
@@ -848,7 +848,13 @@ For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]
.TypeConstraint(
"T1",
{"tensor(int64)"},
"Constrain repeat's type to int64 tensors.");
"Constrain repeat's type to int64 tensors.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Only rank of output can be inferred. We can do better if second input is
// a constant, but this requires extending InferenceContext interface to
// get values of constant inputs.
});
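A minimal sketch of the more limited Tile inference, under the same assumptions (onnx Python package, pinned opset 9, illustrative names): only the element type propagates, because the repeats values are not known at inference time:

from onnx import TensorProto, helper, shape_inference

tile = helper.make_node('Tile', ['x', 'repeats'], ['t'])
ident = helper.make_node('Identity', ['t'], ['t_out'])
graph = helper.make_graph(
    [tile, ident], 'tile_example',
    [helper.make_tensor_value_info('x', TensorProto.FLOAT, (2, 2)),
     helper.make_tensor_value_info('repeats', TensorProto.INT64, (2,))],
    [helper.make_tensor_value_info('t_out', TensorProto.FLOAT, None)])
model = helper.make_model(
    graph, opset_imports=[helper.make_operatorsetid('', 9)])
inferred = shape_inference.infer_shapes(model)
print(inferred.graph.value_info[0])  # 't' is float; its dim values stay unknown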

ONNX_OPERATOR_SCHEMA(Upsample)
.SinceVersion(7)
18 changes: 18 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -497,6 +497,24 @@ def test_depth_to_space(self):
[])
self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.FLOAT, (2, 3, 100, 100))])

def test_topk_default_axis(self):
graph = self._make_graph(
[('x', TensorProto.FLOAT, (3, 4, 5, 10))],
[make_node('TopK', ['x'], ['y', 'z'], k=2)],
[])
self._assert_inferred(graph,
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 5, 2)),
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 5, 2))])

def test_topk(self):
graph = self._make_graph(
[('x', TensorProto.FLOAT, (3, 4, 5, 10))],
[make_node('TopK', ['x'], ['y', 'z'], k=2, axis=2)],
[])
self._assert_inferred(graph,
[make_tensor_value_info('y', TensorProto.FLOAT, (3, 4, 2, 10)),
make_tensor_value_info('z', TensorProto.INT64, (3, 4, 2, 10))])

def test_gemm(self):
graph = self._make_graph(
[('x', TensorProto.FLOAT, (7, 5)),
