Symbolic shape inference support-3: more ops for data propagation #3593
Conversation
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
@@ -1595,6 +1664,28 @@ ONNX_OPERATOR_SET_SCHEMA(
          : // i - axis < q
          data_shape.dim(i - q + 1); // i < out_rank < q + r - 1
    }
  })
  .PartialDataPropagationFunction([](DataPropagationContext& ctx) {
    if (!axisIsZero(ctx, true)) {
Actually, I wonder if we should restrict ourselves to the case where input(0) is a 1-dimensional tensor? If it is multidimensional, it becomes more complicated.
agree
We do have this restriction now for every data propagation, right? (Only a 0-D or 1-D tensor can be retrieved via getInputData.)
For any other N-D tensor (where N > 1), getInputData() returns nullptr and data propagation stops.
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
onnx/defs/tensor/defs.cc
@@ -778,6 +803,36 @@ Example 2:
]
)DOC";

inline int64_t clamp(int64_t val, int64_t low, int64_t high) {
This can be converted to one line:
return (val < low) ? low : (val > high) ? high : val;
I also suggest renaming low/high to min/max for better readability.
Another suggestion: you could simply use this one line directly instead of creating an inline function.
onnx/defs/math/defs.cc
} else if (input_dim_0.has_dim_param() || input_dim_1.has_dim_param()) {
  tsp.mutable_dim()->Add()->set_dim_param("?");
} else {
  // Dim is not set by value or param
Hi, I suggest adding a dim in all cases, and setting its value only if both inputs have values, like below:
auto* new_dim = tsp.mutable_dim()->Add();
if (input_dim_0.has_dim_value() && input_dim_1.has_dim_value()) {
  new_dim->set_dim_value(
      MathOpTwoIntegers(op_type, input_dim_0.dim_value(), input_dim_1.dim_value()));
}
- (inferredShape.dim(i).has_dim_value() && expectedShape.dim(i).has_dim_value()) ||
- (inferredShape.dim(i).has_dim_param() && expectedShape.dim(i).has_dim_param()))
+ (inferredShape.dim(i).has_dim_value() == expectedShape.dim(i).has_dim_value()) &&
+ (inferredShape.dim(i).has_dim_param() == expectedShape.dim(i).has_dim_param()))
Updated here to test with a "?" shape (which has no value and no param because it is not computable) -- a dim with neither a value nor a param is still a valid dim.
Description
Support more ops for data propagation:
Also add tests for them.
Motivation and Context
Since #3551, ONNX has supported data propagation. This PR implements data propagation functions for more ops.