
Symbolic shape inference support-3: more ops for data propagation #3593

Merged — 17 commits merged into onnx:master from jcwchen:jcw/symbolic-3 on Jul 21, 2021

Conversation

@jcwchen (Member) commented Jul 19, 2021

Description
Support more ops for data propagation:

  • Add
  • Sub
  • Mul
  • Gather
  • Slice
  • Concat

Also add tests for them.

Motivation and Context
Following #3551, ONNX supports data propagation. This PR implements data propagation functions for more ops.

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@jcwchen jcwchen added the shape inference Issues related to shape inference label Jul 19, 2021
@jcwchen jcwchen added this to the 1.10 milestone Jul 19, 2021
@jcwchen jcwchen requested review from a team as code owners July 19, 2021 23:58
@jcwchen jcwchen changed the title from "Symbolic shape inference support-2: more ops for data propagation" to "Symbolic shape inference support-3: more ops for data propagation" on Jul 19, 2021
@jcwchen jcwchen added the run release CIs Use this label to trigger release tests in CI label Jul 20, 2021
@@ -1595,6 +1664,28 @@ ONNX_OPERATOR_SET_SCHEMA(
: // i - axis < q
data_shape.dim(i - q + 1); // i < out_rank < q + r - 1
}
})
.PartialDataPropagationFunction([](DataPropagationContext& ctx) {
if (!axisIsZero(ctx, true)) {
Contributor:

Actually, I wonder if we should restrict ourselves to the case where input(0) is a 1-dimensional tensor? If it is multidimensional, it gets more complicated.

Contributor:

agree

Member Author:

We do have this restriction for every data propagation already, right? Only a 0-D or 1-D tensor can be retrieved via getInputData. For other N-D tensors (N > 1), getInputData() returns nullptr, which stops data propagation.

@@ -778,6 +803,36 @@ Example 2:
]
)DOC";

inline int64_t clamp(int64_t val, int64_t low, int64_t high) {
@askhade (Contributor) commented Jul 20, 2021:

This can be converted into one line:
return (val < low) ? low : (val > high) ? high : val;

Also, I suggest renaming low/high to min/max for better readability.

Another suggestion: you could simply use this one-liner inline instead of creating an inline function.

} else if (input_dim_0.has_dim_param() || input_dim_1.has_dim_param()) {
tsp.mutable_dim()->Add()->set_dim_param("?");
} else {
// Dim is not set by value or param
Contributor:

Hi, I suggest adding a dim in all cases and setting its value only if both inputs have values, like below:

auto* new_dim = tsp.mutable_dim()->Add();
if (input_dim_0.has_dim_value() && input_dim_1.has_dim_value()) {
  new_dim->set_dim_value(
      MathOpTwoIntegers(op_type, input_dim_0.dim_value(), input_dim_1.dim_value()));
}

-        (inferredShape.dim(i).has_dim_value() && expectedShape.dim(i).has_dim_value()) ||
-        (inferredShape.dim(i).has_dim_param() && expectedShape.dim(i).has_dim_param()))
+        (inferredShape.dim(i).has_dim_value() == expectedShape.dim(i).has_dim_value()) &&
+        (inferredShape.dim(i).has_dim_param() == expectedShape.dim(i).has_dim_param()))
Member Author:

Updated here to test with a "?" shape (meaning no value and no param, because the dim is not computable): if a dim has neither a value nor a param, it is still a valid dim.

@askhade askhade merged commit 623dfaa into onnx:master Jul 21, 2021
@jcwchen jcwchen deleted the jcw/symbolic-3 branch August 9, 2021 15:20