Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
  • Loading branch information
justinchuby committed Oct 17, 2023
1 parent e4677c0 commit 13e8b6f
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions onnx/test/version_converter/automatic_upgrade_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@


class TestAutomaticUpgrade(automatic_conversion_test_base.TestAutomaticConversion):
@classmethod
def setUpClass(cls):
self.tested_ops = []

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name self.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "self" is not defined To disable, use # type: ignore[name-defined]

def _test_op_upgrade(self, op, *args, **kwargs):
self.tested_ops.append(op)
self._test_op_conversion(op, *args, **kwargs, is_upgrade=True)

def test_Abs(self) -> None:
Expand Down Expand Up @@ -1735,6 +1740,32 @@ def test_RegexFullMatch(self) -> None:
[TensorProto.BOOL],
)

def test_ops_tested(self) -> None:
# NOTE: This test is order dependent and needs to run last in this class
all_schemas = onnx.defs.get_all_schemas()
all_op_names = {schema.name for schema in all_schemas if schema.domain == ""}
excluded_ops = {
# Sequence-based and Optional-based ops disabled because
# the version converter doesn't play nicely with sequences
"ConcatFromSequence",
"SequenceAt",
"SequenceConstruct",
"SequenceEmpty",
"SequenceErase",
"SequenceInsert",
"SequenceLength",
"SequenceMap",
"SplitToSequence",
"Optional",
"OptionalGetElement",
"OptionalHasElement",
"StringSplit",
}
expected_tested_ops = all_op_names - excluded_ops

untested_ops = expected_tested_ops - set(tested_ops)

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name tested\_ops.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "tested_ops" is not defined To disable, use # type: ignore[name-defined]
self.assertEqual(untested_ops, {})


if __name__ == "__main__":
unittest.main()

0 comments on commit 13e8b6f

Please sign in to comment.