Fix shape inference for Split with split attribute #2328

Merged (12 commits) on Jan 14, 2020

Conversation

@shinh (Member) commented on Sep 18, 2019

This fixes #1735

@shinh requested review from a team as code owners (September 18, 2019, 00:27)
@CLAassistant commented on Sep 18, 2019

CLA assistant check: all committers have signed the CLA.

axis += rank;
}
const auto& splitDim = shape.dim(axis);
if (!splitDim.has_dim_value()) {
Contributor: We can still do rank inference before returning, since the ranks of the outputs will be the same as the rank of the input.

shinh (author): Done.
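
For context, here is a minimal sketch of what the rank-only propagation can look like, assuming 'shape' and 'axis' are the input shape and the normalized axis from the surrounding code; it is illustrative, not the exact committed change:

// Even without a concrete dim_value on the split axis, the output rank is
// known: copy the input shape and leave the dimension along 'axis' unset.
for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
  auto* output_shape =
      ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape();
  *output_shape = shape;                      // outputs inherit the input's rank
  output_shape->mutable_dim(axis)->Clear();   // size along 'axis' stays unknown
}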

}
const auto& splitDim = shape.dim(axis);
if (!splitDim.has_dim_value()) {
if (totalDim != splitDimValue) {
return;
Contributor: Should this be an error?

shinh (author): Done.

return;
}
int splitDimValue = static_cast<int>(splitDim.dim_value());
} else {
int chunkSize =
Contributor: Per spec, if 'split' is not provided, the input is split into equal-sized parts, so should the case where splitDimValue % ctx.getNumOutputs() != 0 be an error condition?

Contributor: Otherwise, we might have to deal with it arbitrarily, like the logic in 444-448, which is not prescribed by the spec. So I think this should be an error condition as well. To me there are three possible error conditions in this op (a sketch of these checks follows the list):

  1. When 'split' is present and its count does not match number of outputs
  2. When 'split' is present and its sum does not match input shape's corresponding value in the 'axis' dimension
  3. When 'split' is not present and input shape's corresponding value in the 'axis' dimension cannot be split evenly to the number of outputs
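
A minimal sketch of those three checks, reusing names from the surrounding diff ('split', 'splitDim') and the fail_shape_inference helper; this is illustrative, not the exact code that was merged:

if (!split.empty()) {
  // 1. 'split' is present but its count does not match the number of outputs.
  if (split.size() != ctx.getNumOutputs()) {
    fail_shape_inference("Mismatch between number of splits and outputs");
  }
  // 2. 'split' is present but its sum does not match the 'axis' dimension.
  int64_t total_dim = 0;
  for (int64_t s : split) total_dim += s;
  if (splitDim.has_dim_value() && total_dim != splitDim.dim_value()) {
    fail_shape_inference("Mismatch between the sum of 'split' and the split dimension");
  }
} else if (splitDim.has_dim_value() &&
           splitDim.dim_value() % static_cast<int64_t>(ctx.getNumOutputs()) != 0) {
  // 3. 'split' is absent and the 'axis' dimension cannot be divided evenly.
  fail_shape_inference("The input cannot be split evenly across the outputs");
}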

shinh (author): Done.

@shinh (Member, author) commented on Sep 19, 2019

Thanks for your review! I think I handled all of your comments. Please take a look again?

self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2, None, 'b')), # type: ignore
make_tensor_value_info('z', TensorProto.FLOAT, (2, None, 'b'))]) # type: ignore

def test_split_fail_with_invalid_split_attribute(self): # type: () -> None
Contributor: Nit: I don't think we explicitly add tests for failure scenarios. Seems like these 2 may not be needed.

shinh (author): Removed.

@@ -1097,6 +1097,36 @@ def test_split_negative_axis(self): # type: () -> None
self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2, 2)),
make_tensor_value_info('z', TensorProto.FLOAT, (2, 2))])

def test_split_with_split_attribute(self): # type: () -> None
Contributor: Can you please add two more tests:

  1. One where the 'split' attribute is missing and the default equal split occurs for each output
  2. One with a (valid) negative axis

shinh (author): Isn't the existing test case, test_split_negative_axis, enough to cover these two cases? Do you want me to split it into two tests?

Contributor: You are right. :) I did not notice the other test.

}
int chunkSize = splitDimValue / numOutputs;
for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); i++) {
split.push_back(chunkSize);
Contributor: Nit 1: It seems a little ugly to store the same value 'chunkSize' in the vector for each output. It would be great if the single value 'chunkSize' could just be handled below, when you actually assign the dim value.

Nit 2: 'chunkSize' doesn't seem C++ style. You can consider using 'chunk_size'.

shinh (author): As for 2, fixed.

As for 1, let me double-check whether you like the change, since I feel the current code is better. If I follow your suggestion, I think we'll have two copies of the for loop that assigns dim_value, like:

if (getRepeatedAttribute(ctx, "split", split)) {
  ...
  for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
    *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() = shape;
    ctx.getOutputType(i)
        ->mutable_tensor_type()
        ->mutable_shape()
        ->mutable_dim(axis)
        ->set_dim_value(split[i]);
  }
} else {
  ...
  int chunk_size = split_dim_value / num_outputs;
  for (size_t i = 0; i < ctx.getNumOutputs(); i++) {
    *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() = shape;
    ctx.getOutputType(i)
        ->mutable_tensor_type()
        ->mutable_shape()
        ->mutable_dim(axis)
        ->set_dim_value(chunk_size);
  }
}

Is this what you want?

Contributor: That's fine. We can just keep it the way it is for now.
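
For contrast, a rough sketch of the single-loop form the PR keeps, where 'split' has already been filled with one chunk size per output when the attribute is absent (illustrative, not the exact committed code):

// One loop covers both cases because 'split' always holds one entry per output.
for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
  *ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape() = shape;
  ctx.getOutputType(i)
      ->mutable_tensor_type()
      ->mutable_shape()
      ->mutable_dim(axis)
      ->set_dim_value(split[i]);
}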

@hariharans29 (Contributor), replying to the above: Just some minor comments to be addressed. Looks good otherwise. :)

@shinh (Member, author) commented on Sep 20, 2019

Thanks for your review! Let me check a couple of things before I actually change them.

@shinh (Member, author) commented on Sep 24, 2019

The test was randomly failing, so I pushed it several times to trigger the rebuild.

@wschin added this to the 1.7 milestone on Sep 25, 2019
@wschin added the 'operator' label (Issues related to ONNX operators) on Sep 25, 2019
@shinh (Member, author) commented on Sep 27, 2019

@hariharans29 Ping?

@hariharans29 (Contributor) left a review: LGTM. @wschin?

@shinh (Member, author) commented on Oct 4, 2019

@wschin: Friendly reminder :)

@hariharans29 (Contributor) commented on Oct 4, 2019

@gramalingam @ebarsoum @houseroad - pinging more people for reviews

@shinh - Don't worry. This PR has been tagged for release 1.7. 1.6 just got done last week. So it will be picked up when the PRs start getting merged for the next release (maybe sooner).

@shinh (Member, author) commented on Oct 4, 2019

@hariharans29: Got it, thanks! I'm worried that someone else will make a similar patch and there will be duplicated effort. It happened once with #1855 (and you were the author of #2041 :).

@hariharans29 (Contributor) commented:
@shinh - I see. Sorry, I should have checked if there was already a PR for Expand shape inference. Someone on my team needed it urgently, and I just wrote it without checking for any existing PRs.

@linkerzhang merged commit dfa4384 into onnx:master on Jan 14, 2020
shinh added a commit to shinh/onnx that referenced this pull request on Jan 15, 2020: "onnx#2328 did not fix the shape inference of the old Split."
shinh added a commit to shinh/onnx that referenced this pull request on Feb 27, 2020: "onnx#2328 did not fix the shape inference of the old Split."
jcwchen pushed a commit to jcwchen/onnx that referenced this pull request Sep 23, 2020
* Fix shape inference for Split with split attribute

This fixes onnx#1735

* Set rank even when split dimension is unknown

* Let shape inference fail for invalid split

* Fail shape inference for inequal division

* Ignore Python type error in the unittest

* Remove tests for failure scenarios

* Stop using camel case for local variables

Co-authored-by: Ke Zhang <kezhan@microsoft.com>
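
Taken together, the squashed commits listed above amount to Split shape inference roughly along the following lines. This is a hedged reconstruction for readers following the thread, not the exact merged code; it assumes the usual ONNX shape-inference helpers (propagateElemTypeFromInputToOutput, hasNInputShapes, getRepeatedAttribute, fail_shape_inference) and an InferenceContext named 'ctx'.

// Illustrative reconstruction of the inference logic this PR converges on.
for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
  propagateElemTypeFromInputToOutput(ctx, 0, i);  // element type is always known
}
if (!hasNInputShapes(ctx, 1)) {
  return;  // no input shape at all: nothing more to infer
}
const auto& shape = ctx.getInputType(0)->tensor_type().shape();
int rank = shape.dim_size();
const auto* axis_attr = ctx.getAttribute("axis");
int axis = axis_attr ? static_cast<int>(axis_attr->i()) : 0;
if (axis < -rank || axis >= rank) {
  fail_shape_inference("Invalid value of attribute 'axis'");
}
if (axis < 0) {
  axis += rank;  // negative axis counts from the back
}
const auto& split_dim = shape.dim(axis);

std::vector<int64_t> split;
if (getRepeatedAttribute(ctx, "split", split)) {
  if (split.size() != ctx.getNumOutputs()) {
    fail_shape_inference("Mismatch between number of splits and outputs");
  }
} else if (split_dim.has_dim_value()) {
  const int64_t num_outputs = static_cast<int64_t>(ctx.getNumOutputs());
  if (split_dim.dim_value() % num_outputs != 0) {
    fail_shape_inference("The input is not evenly splittable");
  }
  split.assign(ctx.getNumOutputs(), split_dim.dim_value() / num_outputs);
}

for (size_t i = 0; i < ctx.getNumOutputs(); ++i) {
  auto* out_shape = ctx.getOutputType(i)->mutable_tensor_type()->mutable_shape();
  *out_shape = shape;  // rank and non-split dims always propagate
  if (i < split.size()) {
    out_shape->mutable_dim(axis)->set_dim_value(split[i]);
  } else {
    out_shape->mutable_dim(axis)->Clear();  // split size unknown
  }
}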
shinh added a commit to shinh/onnx that referenced this pull request Dec 29, 2020
shinh added a commit to shinh/onnx that referenced this pull request Dec 29, 2020
This should be same as onnx#2328

Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>
gramalingam pushed a commit that referenced this pull request Jan 6, 2021
* Fix shape inference of Split-2 with split attr

This should be same as #2328

Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>

* Allow negative axis even in Split-2

Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>

* Update onnx/defs/tensor/old.cc

Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
Signed-off-by: Shinichiro Hamaji <shinichiro.hamaji@gmail.com>

Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
Labels: operator (Issues related to ONNX operators)
Projects: none yet
Linked issues: Shape inference fails with a Split node with an split attribute
6 participants