Fix shape inference for Split with split attribute #2328
onnx/defs/tensor/defs.cc
Outdated
axis += rank;
}
const auto& splitDim = shape.dim(axis);
if (!splitDim.has_dim_value()) {
We can still do rank inference before returning, since every output has the same rank as the input.
Done
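For readers following the thread, the rank-only fallback discussed above can be modelled in a few lines. This is a minimal sketch in plain Python, not the code in defs.cc: the names `normalize_axis` and `infer_ranks_only` are hypothetical, and `None` stands in for an unknown dimension.

```python
from typing import List, Optional, Sequence, Tuple

Dim = Optional[int]  # None stands for an unknown dimension (assumption of this sketch)


def normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly negative axis into the range [0, rank)."""
    if axis < -rank or axis >= rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return axis + rank if axis < 0 else axis


def infer_ranks_only(
    input_shape: Sequence[Dim], axis: int, num_outputs: int
) -> List[Tuple[Dim, ...]]:
    """Give every output the input's rank, leaving the split axis unknown."""
    axis = normalize_axis(axis, len(input_shape))
    out = list(input_shape)
    out[axis] = None  # the size along the split axis cannot be determined
    return [tuple(out) for _ in range(num_outputs)]
```

With an input of shape `(2, None, 5)` split along axis 1 into two outputs, both outputs keep rank 3 and the known dimensions 2 and 5.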
onnx/defs/tensor/defs.cc
Outdated
}
const auto& splitDim = shape.dim(axis);
if (!splitDim.has_dim_value()) {
if (totalDim != splitDimValue) {
return;
Should this be an error?
Done
onnx/defs/tensor/defs.cc
Outdated
return;
}
int splitDimValue = static_cast<int>(splitDim.dim_value());
} else {
int chunkSize =
The spec says that if 'split' is not provided, the input is split into equal-sized parts. So should the case where splitDimValue % ctx.getNumOutputs() != 0 be an error condition?
Otherwise, we would have to handle it arbitrarily, like the logic at 444-448, which the spec does not prescribe. So I think this should be an error condition as well. As I see it, there are three possible error conditions in this op:
- 'split' is present and its length does not match the number of outputs
- 'split' is present and its sum does not match the input shape's value in the 'axis' dimension
- 'split' is not present and the input shape's value in the 'axis' dimension cannot be divided evenly among the outputs
Done
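The three error conditions listed in this thread can be sketched as a single validation function. This is an illustrative plain-Python model under stated assumptions, not the C++ in defs.cc: `resolve_split_sizes` is a hypothetical name, and `ValueError` stands in for a shape-inference failure.

```python
from typing import List, Optional, Sequence


def resolve_split_sizes(
    dim_value: int, num_outputs: int, split: Optional[Sequence[int]] = None
) -> List[int]:
    """Return the size of each output along the split axis, or raise."""
    if num_outputs < 1:
        raise ValueError("Split needs at least one output")
    if split is not None:
        # Condition 1: one entry in 'split' per output.
        if len(split) != num_outputs:
            raise ValueError("length of 'split' does not match number of outputs")
        # Condition 2: the entries must add up to the input dimension.
        if sum(split) != dim_value:
            raise ValueError("sum of 'split' does not match the split dimension")
        return list(split)
    # Condition 3: without 'split', the dimension must divide evenly.
    if dim_value % num_outputs != 0:
        raise ValueError("split dimension is not divisible by number of outputs")
    return [dim_value // num_outputs] * num_outputs
```

For example, `resolve_split_sizes(4, 2, [1, 3])` yields `[1, 3]`, while `resolve_split_sizes(5, 2)` raises because 5 cannot be divided evenly between two outputs.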
Thanks for your review! I think I have addressed all of your comments. Could you take another look?
onnx/test/shape_inference_test.py
Outdated
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2, None, 'b')),  # type: ignore
                              make_tensor_value_info('z', TensorProto.FLOAT, (2, None, 'b'))])  # type: ignore

def test_split_fail_with_invalid_split_attribute(self):  # type: () -> None
Nit: I don't think we explicitly add tests for failure scenarios. Seems like these 2 may not be needed.
Removed
@@ -1097,6 +1097,36 @@ def test_split_negative_axis(self):  # type: () -> None
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2, 2)),
                              make_tensor_value_info('z', TensorProto.FLOAT, (2, 2))])

def test_split_with_split_attribute(self):  # type: () -> None
Can you please add two more tests:
- One where the split attribute is missing and the default 'equal' split is applied to each output
- One with a (valid) negative axis
Isn't the existing test case, test_split_negative_axis, enough to cover these two cases? Do you want me to split it into two tests?
You are right. :) I did not notice the other test.
onnx/defs/tensor/defs.cc
Outdated
}
int chunkSize = splitDimValue / numOutputs;
for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); i++) {
split.push_back(chunkSize);
nit 1: It seems a little ugly to store the same value 'chunkSize' in the vector once per output. It would be great if the single value 'chunkSize' could be handled below, where you actually assign the dim value.
nit 2: 'chunkSize' doesn't match the C++ style used here. Consider 'chunk_size'.
As for 2, fixed.
As for 1, let me double-check that you would like the change, since I feel the current code is better. If I follow your suggestion, I think we would end up with two copies of the for loop that assigns dim_value, like this:
if (getRepeatedAttribute(ctx, "split", split)) {
  ...
  for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
    *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() =
        shape;
    ctx.getOutputType(i)
        ->mutable_tensor_type()
        ->mutable_shape()
        ->mutable_dim(axis)
        ->set_dim_value(split[i]);
  }
} else {
  ...
  int chunk_size = split_dim_value / num_outputs;
  for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
    *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() =
        shape;
    ctx.getOutputType(i)
        ->mutable_tensor_type()
        ->mutable_shape()
        ->mutable_dim(axis)
        ->set_dim_value(chunk_size);
  }
}
Is this what you want?
That's fine. We can just keep it the way it is for now.
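For reference, a third arrangement is possible that neither fills a vector with repeated values nor duplicates the loop: pick the per-output size inside one loop. The sketch below models this in plain Python rather than C++; `assign_split_dims` is a hypothetical name, validation is assumed to have happened already, and this is not what the PR merged.

```python
from typing import List, Optional, Sequence, Tuple


def assign_split_dims(
    shape: Sequence[int],
    axis: int,
    num_outputs: int,
    split: Optional[Sequence[int]] = None,
) -> List[Tuple[int, ...]]:
    """Build each output shape in a single loop (inputs assumed already validated)."""
    # Only computed when 'split' is absent; assumes an even division.
    chunk_size = None if split is not None else shape[axis] // num_outputs
    outputs = []
    for i in range(num_outputs):
        out = list(shape)  # copy the input shape for this output
        out[axis] = split[i] if split is not None else chunk_size
        outputs.append(tuple(out))
    return outputs
```

The trade-off is a branch inside the loop in exchange for a single copy of the assignment code.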
Just some minor comments to be addressed. Looks good otherwise. :)
Thanks for your review! Let me check a couple of things before I actually change them.
Force-pushed from 7efbc6b to f4de1e2
Force-pushed from f4de1e2 to e52f821
Force-pushed from e52f821 to fe79aeb
The test was failing intermittently, so I pushed several times to trigger a rebuild.
@hariharans29 Ping?
LGTM. @wschin ?
@wschin: Friendly reminder :)
@gramalingam @ebarsoum @houseroad - pinging more people for reviews.
@shinh - Don't worry. This PR has been tagged for release 1.7; 1.6 was just completed last week. So it will be picked up when PRs start getting merged for the next release (maybe sooner).
@hariharans29: Got it, thanks! I was worried that someone else would write a similar patch and the effort would be duplicated. That happened once with #1855 (and you were the author of #2041 :).
@shinh - I see. Sorry, I should have checked whether there was already a PR for Expand shape inference. Someone on my team needed it urgently, and I just wrote it without checking for existing PRs.
onnx#2328 did not fix the shape inference of the old Split.
* Fix shape inference for Split with split attribute
  This fixes onnx#1735
* Set rank even when split dimension is unknown
* Let shape inference fail for invalid split
* Fail shape inference for inequal division
* Ignore Python type error in the unittest
* Remove tests for failure scenarios
* Stop using camel case for local variables
Co-authored-by: Ke Zhang <kezhan@microsoft.com>
This should be same as onnx#2328 Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>
* Fix shape inference of Split-2 with split attr
  This should be same as #2328
  Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>
* Allow negative axis even in Split-2
  Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>
* Update onnx/defs/tensor/old.cc
  Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
  Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>
Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
This fixes #1735