Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Extend ONNX Frontend with com.microsoft.Pad #22000

Merged
merged 20 commits into from
Apr 5, 2024

Conversation

siddhant-0707
Copy link
Contributor

@siddhant-0707 siddhant-0707 commented Jan 6, 2024

Details:

  • created implementation for com.microsoft.Pad
  • creating test

Tickets:

@siddhant-0707 siddhant-0707 requested a review from a team as a code owner January 6, 2024 11:08
@github-actions github-actions bot added the category: ONNX FE OpenVINO ONNX FrontEnd label Jan 6, 2024
@ilya-lavrenov ilya-lavrenov added ExternalPR External contributor pr: needs tests PR needs tests updating labels Jan 8, 2024
Copy link
Contributor

This PR will be closed in a week because of 2 weeks of no activity.

@github-actions github-actions bot added the Stale label Jan 29, 2024
Copy link
Contributor

This PR will be closed in a week because of 2 weeks of no activity.

@github-actions github-actions bot added the Stale label Feb 16, 2024
Copy link
Contributor

@mitruska mitruska left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @siddhant-0707, thank you for your contribution!

MS ONNX Pad and ONNX standard Pad-11 seem to be compatible. The only difference I can see is that MS Pad allows pads input to be 2D, then the first dim is required to be 1, so it could be simply resolved by Squeeze. My suggestion is to reuse set_11::pad code within a common function to avoid duplication of the main logic. Squeeze can be applied to the MS pads input if needed, before the common part.

Also this PR doesn't have any tests for the provided changes, please add them including the mentioned 2D pads scenario.
Registration of the correct pad function should be ensured.

@@ -585,6 +585,7 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "FusedConv", 1, fused_conv);
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "FusedGemm", 1, fusedgemm);
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "EmbedLayerNormalization", 1, embed_layer_normalization);
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "Pad", 1, pad);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new file src/frontends/onnx/frontend/src/op/com.microsoft/pad.hpp is not included here, so the created function is not visible.
It means that set_1::pad from /onnx/frontend/src/op/pad.hpp has been registered instead.

namespace set_1 {
ov::OutputVector pad(const ov::frontend::onnx::Node& node) {

To distinguish the pad function dedicated for conversion of MS Op, my suggestion is to put it in a different namespace or change name of the function, and register like:

Suggested change
REGISTER_OPERATOR_WITH_DOMAIN(MICROSOFT_DOMAIN, "Pad", 1, pad);
register_operator_in_custom_domain("Pad", VersionRange::single_version_for_all_opsets(), op::custom::set_1::pad, MICROSOFT_DOMAIN);

throw ngraph_error("Unsupported pad_mode in ONNX com.microsoft.Pad operator");
}

auto pads_shape = pads.get_shape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The .get_shape() will throw for dynamic shape, it should be used only if static shape is ensured,
.get_partial_shape() can be used instead.
For partial shape .is_static() or .is_dynamic() can be used and .size() is safe only for shapes with static rank.
To check whether the rank is static: partial_shape.rank().is_static().

@mitruska mitruska self-assigned this Feb 16, 2024
@github-actions github-actions bot removed the Stale label Feb 17, 2024
Copy link
Contributor

github-actions bot commented Mar 4, 2024

This PR will be closed in a week because of 2 weeks of no activity.

@github-actions github-actions bot added the Stale label Mar 4, 2024
Copy link
Contributor

This PR was closed because it has been stalled for 2 week with no activity.

@github-actions github-actions bot closed this Mar 11, 2024
@mlukasze
Copy link
Contributor

hey @siddhant-0707 :)
will you continue a work on this PR? Or should we move a task back to the open pool?

@siddhant-0707
Copy link
Contributor Author

Apologies for the delay @mlukasze, I was a little occupied with other tasks. Please reopen the PR.
@mitruska I have reused the implementation of set_11::pad (with modification in beginning to handle case with 2d pads input). I've done this because we have to assign the squeezed value back to node.get_ng_inputs()[1] (which we can't do) Is there a way to create a new node and set its inputs and pass it to set_11::pad?.
Another way this problem can be solved is creating an overload in onnx::pad that takes in pads separately and then passes it to set_11::pad. Something like ov::OutputVector pad(const Node& node, const ov::Output<ov::Node>& pads) {}. However, this would misalign code with docs I think.

@siddhant-0707
Copy link
Contributor Author

Also added 2d and 1d tests

@mitruska
Copy link
Contributor

Hello @siddhant-0707, great to see some updates! Reopened the PR as you asked.
To continue review, please push commits with the changes you have mentioned and sync branch with the latest master.

@mitruska mitruska reopened this Mar 29, 2024
@github-actions github-actions bot removed the Stale label Mar 30, 2024
@gkrivor
Copy link
Contributor

gkrivor commented Apr 3, 2024

build_jenkins

Copy link
Contributor

@mitruska mitruska left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My previous concern (#22000 (comment)) about get_shape() call has been not applied.
Please add suggested change to avoid further issues.

About creating the common pad function, it could be a helper taking ov::OutputVector as argument, and called from the onnx standard pad and custom pad (after get_ov_inputs()).
But such code unification can be considered as follow up improvement within a separate PR.

Let's focus on the last necessary change:)

src/frontends/onnx/frontend/src/op/com.microsoft/pad.cpp Outdated Show resolved Hide resolved
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
@mitruska
Copy link
Contributor

mitruska commented Apr 3, 2024

build_jenkins

@mitruska
Copy link
Contributor

mitruska commented Apr 4, 2024

build_jenkins

@mitruska mitruska removed the pr: needs tests PR needs tests updating label Apr 4, 2024
@gkrivor gkrivor added this pull request to the merge queue Apr 5, 2024
Merged via the queue into openvinotoolkit:master with commit 10c0d5b Apr 5, 2024
108 checks passed
@mlukasze mlukasze added this to the 2024.2 milestone Apr 5, 2024
bbielawx pushed a commit to bbielawx/openvino that referenced this pull request Apr 12, 2024
…#22000)

### Details:
 - created implementation for `com.microsoft.Pad`
 - creating test

### Tickets:
 - Closes openvinotoolkit#17576

---------

Co-authored-by: Georgy Krivoruchko <georgy.krivoruchko@intel.com>
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
alvoron pushed a commit to alvoron/openvino that referenced this pull request Apr 29, 2024
…#22000)

### Details:
 - created implementation for `com.microsoft.Pad`
 - creating test

### Tickets:
 - Closes openvinotoolkit#17576

---------

Co-authored-by: Georgy Krivoruchko <georgy.krivoruchko@intel.com>
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: ONNX FE OpenVINO ONNX FrontEnd ExternalPR External contributor
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Extend ONNX Frontend with com.microsoft.Pad operator
5 participants