
Add shape inference for Expand using symbolic shape input #3789

Merged (9 commits) Mar 23, 2022

Conversation

@xuzijian629 (Member) commented Oct 21, 2021

Description

Even when the shape input is not available as an initializer, we can infer the output shape using symbolic shape input.

Motivation

Existing shape inference cannot infer dim_value in some cases, for example:

agraph (float[3, 1, 2] x, float[1, 4, 2] y) => (float[3, 4, 2] z)
{
    ys = Shape(y)
    z = Expand(x, ys) 
}
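The inference this PR enables can be sketched in Python: when the shape input is produced by Shape(y), its symbolic value is known, so the output shape is just the NumPy-style broadcast of the data input's shape against it. This is a hedged analogue of the C++ logic in onnx/defs/math/defs.cc, not the actual implementation; the function name `broadcast_shapes` is illustrative.

```python
def broadcast_shapes(a, b):
    """Broadcast two dimension lists (ints, or None for an unknown dim),
    following NumPy/ONNX multidirectional broadcasting rules."""
    # Right-align the shorter shape by padding with 1s on the left.
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out = []
    for da, db in zip(a, b):
        if da == 1:
            out.append(db)
        elif db == 1 or da == db:
            out.append(da)
        elif da is None or db is None:
            out.append(None)  # unknown dim: leave it symbolic
        else:
            raise ValueError(f"incompatible dims {da} and {db}")
    return out

# The example from the PR description: z = Expand(x, Shape(y))
x_shape = [3, 1, 2]   # float[3, 1, 2] x
ys = [1, 4, 2]        # symbolic value of Shape(y)
print(broadcast_shapes(x_shape, ys))  # [3, 4, 2]
```

With the symbolic value of `ys` available, the inferred output shape is the concrete `[3, 4, 2]` rather than only a rank.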

@askhade (Contributor) commented Oct 27, 2021

#3807 also adds rank inference as a fallback. It is much simpler; could you please take a look? Thanks!

@gramalingam added the "shape inference" label Oct 27, 2021
@xuzijian629 (Member, Author)

@askhade I followed #3807 and simplified the implementation. Sorry for the very late response.

(Review comment on onnx/defs/math/defs.cc; outdated, resolved)
@jcwchen (Member) left a comment

LGTM. Thank you!

@gramalingam (Contributor)

LGTM, thanks! A minor comment: I think we would want to use the same logic for any input that denotes a shape (in any op), so extracting this as a generic utility may be useful, e.g. a function like:

   TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index);

@gramalingam (Contributor) left a comment

LGTM. Added a minor suggestion.

@xuzijian629 (Member, Author)

> LGTM, thanks! A minor comment: I think we would want to use the same logic for any input that denotes a shape (in any op). So, extracting this as a generic utility may be useful. Eg., a function like:
>
>    TensorShapeProto getShapeInput(InferenceContext& ctx, size_t input_index);

Thank you for the suggestion. I have a question about the implementation.
How do you handle the case when the shape input has neither an initializer nor symbolic data? When the shape input does not even have a shape, we cannot do anything, i.e., we cannot return a TensorShapeProto.
Maybe it's better to provide a function like setShapeInputData(InferenceContext& ctx, size_t input_index, TensorShapeProto* target_shape) to avoid that situation?

@gramalingam (Contributor)

Good question! Your suggestion is fine, but I think we will need to return a boolean value to indicate whether a value was found or not.

I think this case is not being handled correctly by the existing code, so it probably needs to be fixed anyway. The existing code uses the default value assigned to "TensorShapeProto second_shape;", which has zero dimensions, and that is incorrect. If no information is available, we should skip setting the output shape.
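The pattern being discussed can be sketched in Python: the utility reports through a boolean whether the shape input's value was actually found, so the caller can skip setting the output shape rather than falling back to a spurious zero-dimension default. This is a hedged sketch, not the C++ API; the dict-based `ctx` and the helper name `get_shape_input` are illustrative stand-ins for InferenceContext.

```python
def get_shape_input(ctx, input_index):
    """Return (found, dims) for the shape input at `input_index`.
    `ctx` is a hypothetical inference context holding either an
    initializer or symbolically propagated data per input."""
    value = ctx.get("initializers", {}).get(input_index)
    if value is None:
        value = ctx.get("symbolic_data", {}).get(input_index)
    if value is None:
        # No information at all: the caller should leave the
        # output shape unset instead of assuming zero dimensions.
        return False, []
    return True, list(value)

ctx = {"symbolic_data": {1: [1, 4, 2]}}
print(get_shape_input(ctx, 1))  # (True, [1, 4, 2])
print(get_shape_input(ctx, 0))  # (False, [])
```

Returning the found flag keeps the "no information" case explicit at every call site, which is the fix gramalingam describes above.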

@xuzijian629 (Member, Author)

@gramalingam Sorry for the late response! I implemented getShapeInput and used it in the shape inference of Expand 👍

@gramalingam (Contributor)

Hi @xuzijian629 : thanks for the update! Would it be possible to fix the conflicts, so we can merge this in? Thanks!

@xuzijian629 (Member, Author)

Hi @gramalingam,
I resolved the conflicts and merged main.

Note: 37ae312 returns early if the symbolic input does not have a dim_value. In that case, GetShapeInput returns with found == false, so the case is already handled.

@gramalingam (Contributor)

Great, thanks very much @xuzijian629

@jcwchen (Member) commented Mar 23, 2022

@gramalingam Could you please approve this PR again so we can merge it? Thanks!

@gramalingam gramalingam merged commit 4a316f1 into onnx:main Mar 23, 2022
@xuzijian629 xuzijian629 deleted the expand_inference branch March 24, 2022 00:09
wschin pushed a commit to wschin/onnx that referenced this pull request Apr 12, 2022
* Infer expand by shape input
* Add test
* Fix comments
* Generalize to getShapeInput function

Signed-off-by: Joe <joe@preferred.jp>
Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
broune pushed a commit to broune/onnx that referenced this pull request May 6, 2023
Labels: partial data propagation, shape inference
Projects: Status: Done
Linked issues: none
4 participants