-
Notifications
You must be signed in to change notification settings - Fork 94
Check inputs num for node matching #1885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Check inputs num for node matching #1885
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1885 +/- ##
==========================================
+ Coverage 75.02% 75.08% +0.06%
==========================================
Files 252 252
Lines 27415 27417 +2
Branches 5012 3190 -1822
==========================================
+ Hits 20567 20587 +20
- Misses 5875 5880 +5
+ Partials 973 950 -23 ☔ View full report in Codecov by Sentry. |
|
Thanks for the PR. Could you provide a little bit of context on why this is needed? |
Suppose we need to remove an optional input of a node, for example, we want to replace The solusion: class DFTSimplify(pat.RewriteRuleAsClass):
@classmethod
def pattern(
cls, ope: pat.OpsetPatternBuilder, inp: pat.Var, axis: pat.Var, onesided: pat.Var
) -> pat.NodeOutputPattern:
ret = ope.DFT(inp, None, axis, onesided=onesided)
assert isinstance(ret, pat.NodeOutputPattern)
return ret
@classmethod
def rewrite(
cls, ope: pat.RewriterContext, inp: ir.Value, axis: ir.Value, onesided: ir.Attr
) -> ir.Value:
del axis
return ope.DFT(inp, dft_length=None, axis=None, onesided=onesided)
@classmethod
def check(cls, _context: None, inp: ir.Value, axis: ir.Value, onesided: ir.Attr) -> bool:
del onesided
logging.info("check")
if axis.const_value is None:
return False
value = axis.const_value.numpy()
axis_value = value.item()
if axis_value < 0:
return axis_value == -2
if inp.shape is None:
return False
return axis_value + 2 == len(inp.shape)The code will crash when matching a normal |
8b31a6f to
0dbc14c
Compare
|
Could you fix the lint errors? https://github.com/microsoft/onnxscript/actions/runs/11229454770/job/31245771279?pr=1885 Thanks |
0dbc14c to
5a62e2d
Compare
Fixed. |
Signed-off-by: xuzhenqi <xuzhenqi@didiglobal.com>
5a62e2d to
4419f27
Compare
Fix inputs num mismatch for node matching.
cc @justinchuby @gramalingam