From 14e0365875041e81218f1cf54c9985af51994a4f Mon Sep 17 00:00:00 2001 From: oPluss Date: Mon, 5 Feb 2024 16:53:33 +0800 Subject: [PATCH 01/35] Support register custom OpSchema by python Signed-off-by: oPluss --- onnx/cpp2py_export.cc | 4 +++- onnx/defs/__init__.py | 13 +++++++++++ onnx/defs/schema.cc | 5 +++++ onnx/onnx_cpp2py_export/defs.pyi | 1 + onnx/test/schema_test.py | 38 ++++++++++++++++++++++++++++++++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 3868c127651..b3dff4f6524 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -469,7 +469,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { .def( "get_all_schemas_with_history", []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, - "Return the schema of all existing operators and all versions."); + "Return the schema of all existing operators and all versions.") + .def( + "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "Register the custom OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 55fb5cf4133..796ff454e32 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -8,6 +8,7 @@ "ONNX_ML_DOMAIN", "AI_ONNX_PREVIEW_TRAINING_DOMAIN", "has", + "register_schema", "get_schema", "get_all_schemas", "get_all_schemas_with_history", @@ -120,3 +121,15 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError + +def register_schema(op: OpSchema): + name: str = op.name + domain: str = op.domain + ver: int = op.since_version + assert ver > 0, f'OpSchema {domain}::{name} need positive version but got {ver}' + if has(name, domain): + exist_op = get_schema(name, ver, domain) + if exist_op is not None: + if exist_op.since_version == ver: + raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {exist_op.file}:{exist_op.line}') + C.register_schema(op) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 640237d1fdd..5d2751b4914 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -919,6 +919,11 @@ void OpSchema::Finalize() { // all inputs or std::numeric_limits::max() (if the last input is // variadic). + max_input_ = 0; + min_input_ = 0; + min_output_ = 0; + max_output_ = 0; + // Flag indicates whether an optional input is trailing one (there's no single // or variadic input behind). for (size_t i = 0; i < inputs_.size(); ++i) { diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index fcf4e9ccaa8..8e40e4e0c04 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -192,3 +192,4 @@ def get_schema( def get_schema(op_type: str, domain: str = "") -> OpSchema: ... def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... +def register_schema(schema: OpSchema) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 1ee42ded555..2a289a64440 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -257,6 +257,44 @@ def test_init_with_default_value(self): self.assertEqual("attr1", attribute.name) self.assertEqual("attr1 description", attribute.description) +class TestOpSchemaRegister(unittest.TestCase): + def test_register(self): + input = """ + agraph (float[N, 128] X, int32 Y) => (float[N] Z) + { + Z = CustomOp(X, Y) + } + """ + model = onnx.parser.parse_graph(input) + op_schema = defs.OpSchema( + "CustomOp", + "", + 1, + inputs=[ + defs.OpSchema.FormalParameter("input1", "T"), + defs.OpSchema.FormalParameter("input2", "int32"), + ], + outputs=[ + defs.OpSchema.FormalParameter("output1", "T"), + ], + type_constraints=[("T", ["tensor(float)"], "")], + attributes=[ + defs.OpSchema.Attribute( + "attr1", defs.OpSchema.AttrType.INTS, "attr1 description" + ) + ], + ) + onnx.defs.register_schema(op_schema) + onnx.checker.check_graph(model) + + def test_duplicited_register(self): + op_schema = defs.OpSchema( + "CustomOpDuplicited", + "", + 1, + ) + onnx.defs.register_schema(op_schema) + self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.register_schema(op_schema)) if __name__ == "__main__": unittest.main() From ec31343a2aab8b619f77ba47a80a57d0e4127466 Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 09:56:28 +0800 Subject: [PATCH 02/35] Optimization register check logic Signed-off-by: oPluss --- onnx/cpp2py_export.cc | 2 +- onnx/defs/__init__.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index b3dff4f6524..109a57ae562 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -471,7 +471,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, "Return the schema of all existing operators and all versions.") .def( - "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "Register the custom OpSchema."); + "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "op"_a, "Register the custom OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 796ff454e32..02bf79a0c94 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -123,13 +123,15 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError def register_schema(op: OpSchema): - name: str = op.name - domain: str = op.domain - ver: int = op.since_version + name = op.name + domain = op.domain + ver = op.since_version assert ver > 0, f'OpSchema {domain}::{name} need positive version but got {ver}' - if has(name, domain): + try: exist_op = get_schema(name, ver, domain) - if exist_op is not None: - if exist_op.since_version == ver: - raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {exist_op.file}:{exist_op.line}') + except SchemaError: + exist_op = None + if exist_op is not None: + if exist_op.since_version == ver: + raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {exist_op.file}:{exist_op.line}') C.register_schema(op) From cb2d9cd9f2e55a626c10ae7ef2c70f3daf3a6c3a Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 11:40:17 +0800 Subject: [PATCH 03/35] Support deregister OpSchema Signed-off-by: oPluss --- onnx/cpp2py_export.cc | 11 ++++++++++- onnx/defs/__init__.py | 1 + onnx/defs/schema.cc | 6 ++++++ onnx/defs/schema.h | 13 +++++++++++++ onnx/onnx_cpp2py_export/defs.pyi | 1 + onnx/test/schema_test.py | 9 +++++++++ 6 files changed, 40 insertions(+), 1 deletion(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 109a57ae562..9f4bcbbab6e 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -471,7 +471,16 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, "Return the schema of all existing operators and all versions.") .def( - "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "op"_a, "Register the custom OpSchema."); + "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "op"_a, "Register the custom OpSchema.") + .def( + "deregister_schema", + [](const std::string& op_type, const int version, const std::string& domain) { + DeRegisterSchema(op_type, version, domain); + }, + "op_type"_a, + "version"_a, + "domain"_a = ONNX_DOMAIN, + "Register the custom OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 02bf79a0c94..b5239829526 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -32,6 +32,7 @@ get_schema = C.get_schema get_all_schemas = C.get_all_schemas get_all_schemas_with_history = C.get_all_schemas_with_history +deregister_schema = C.deregister_schema def onnx_opset_version() -> int: diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 5d2751b4914..a76310490cc 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -34,6 +34,12 @@ void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplic OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load, fail_duplicate_schema); } +// The (name, version, domain) must exactly matches the target +// Otherwise will raise an SchemaError +void DeRegisterSchema(const std::string& name, const int version, const std::string& domain) { + OpSchemaRegistry::OpSchemaDeregister(name, version, domain); +} + #ifndef NDEBUG DbgOperatorSetTracker& DbgOperatorSetTracker::Instance() { static DbgOperatorSetTracker instance; diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index c7b2029109d..33a737db0e9 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1307,6 +1307,18 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; + static void OpSchemaDeregister(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN) { + auto& schema_map = GetMapWithoutEnsuringRegistration(); + if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(version)) { + schema_map[name][domain].erase(version); + } else { + std::stringstream err; + err << "Trying to deregister schema with name: " << name << " domain: " << domain << " version: " << version + << std::endl; + fail_schema(err.str()); + } + } + // Deregister all ONNX opset schemas from domain // Domain with default value ONNX_DOMAIN means ONNX. static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) { @@ -1424,6 +1436,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true); +void DeRegisterSchema(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index 8e40e4e0c04..26cf31fb017 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -193,3 +193,4 @@ def get_schema(op_type: str, domain: str = "") -> OpSchema: ... def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... def register_schema(schema: OpSchema) -> None: ... +def deregister_schema(op_name: str, version: int, domain: str = "") -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 2a289a64440..5f01b87efc3 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -295,6 +295,15 @@ def test_duplicited_register(self): ) onnx.defs.register_schema(op_schema) self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.register_schema(op_schema)) + + def test_deregister_opschema(self): + self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.get_schema('CustomOpSchemaTmp', 1)) + onnx.defs.register_schema(defs.OpSchema('CustomOpSchemaTmp', "", 1)) + op_schema = onnx.defs.get_schema('CustomOpSchemaTmp', 1) + self.assertIsNotNone(op_schema) + onnx.defs.deregister_schema('CustomOpSchemaTmp', 1) + self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.get_schema('CustomOpSchemaTmp', 1)) + if __name__ == "__main__": unittest.main() From 0b13a7b55358af48a7c1aa3c4b4faa3e3e8e03cb Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 12:49:57 +0800 Subject: [PATCH 04/35] normalized code and unittest Signed-off-by: oPluss --- onnx/cpp2py_export.cc | 11 +++++++---- onnx/defs/__init__.py | 20 +++++++++----------- onnx/defs/schema.cc | 4 ++-- onnx/defs/schema.h | 13 +++++++------ onnx/test/schema_test.py | 35 +++++++++++++++++++++++++++-------- 5 files changed, 52 insertions(+), 31 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 9f4bcbbab6e..c3f9bcaeb08 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -471,14 +471,17 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, "Return the schema of all existing operators and all versions.") .def( - "register_schema", [](OpSchema* op) { RegisterSchema(*op); }, "op"_a, "Register the custom OpSchema.") + "register_schema", + [](OpSchema* schema) { RegisterSchema(*schema); }, + "schema"_a, + "Register the custom OpSchema.") .def( "deregister_schema", - [](const std::string& op_type, const int version, const std::string& domain) { - DeRegisterSchema(op_type, version, domain); + [](const std::string& op_type, const int specific_version, const std::string& domain) { + DeRegisterSchema(op_type, specific_version, domain); }, "op_type"_a, - "version"_a, + "specific_version"_a, "domain"_a = ONNX_DOMAIN, "Register the custom OpSchema."); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index b5239829526..e367e49568a 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -123,16 +123,14 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError -def register_schema(op: OpSchema): - name = op.name - domain = op.domain - ver = op.since_version - assert ver > 0, f'OpSchema {domain}::{name} need positive version but got {ver}' +def register_schema(schema: OpSchema): + name = schema.name + domain = schema.domain + version = schema.since_version try: - exist_op = get_schema(name, ver, domain) + existing_schema = get_schema(name, version, domain) except SchemaError: - exist_op = None - if exist_op is not None: - if exist_op.since_version == ver: - raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {exist_op.file}:{exist_op.line}') - C.register_schema(op) + existing_schema = None + if existing_schema is not None and existing_schema.since_version == version: + raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {existing_schema.file}:{existing_schema.line}') + C.register_schema(schema) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index a76310490cc..b03c98a8a74 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -36,8 +36,8 @@ void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplic // The (name, version, domain) must exactly matches the target // Otherwise will raise an SchemaError -void DeRegisterSchema(const std::string& name, const int version, const std::string& domain) { - OpSchemaRegistry::OpSchemaDeregister(name, version, domain); +void DeRegisterSchema(const std::string& name, const int specific_version, const std::string& domain) { + OpSchemaRegistry::OpSchemaDeregister(name, specific_version, domain); } #ifndef NDEBUG diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 33a737db0e9..5a6a897a71a 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1307,14 +1307,15 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; - static void OpSchemaDeregister(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN) { + static void + OpSchemaDeregister(const std::string& name, const int specific_version, const std::string& domain = ONNX_DOMAIN) { auto& schema_map = GetMapWithoutEnsuringRegistration(); - if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(version)) { - schema_map[name][domain].erase(version); + if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(specific_version)) { + schema_map[name][domain].erase(specific_version); } else { std::stringstream err; - err << "Trying to deregister schema with name: " << name << " domain: " << domain << " version: " << version - << std::endl; + err << "Trying to deregister schema with name: " << name << " domain: " << domain + << " version: " << specific_version << std::endl; fail_schema(err.str()); } } @@ -1436,7 +1437,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true); -void DeRegisterSchema(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN); +void DeRegisterSchema(const std::string& name, const int specific_version, const std::string& domain = ONNX_DOMAIN); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 5f01b87efc3..5a8579afa80 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -284,25 +284,44 @@ def test_register(self): ) ], ) + self.assertFalse(onnx.defs.has(op_schema.name)) + with self.assertRaises(onnx.checker.ValidationError): + onnx.checker.check_graph(model) onnx.defs.register_schema(op_schema) onnx.checker.check_graph(model) + + # cleanup + onnx.defs.deregister_schema(op_schema.name, op_schema.since_version, op_schema.domain) - def test_duplicited_register(self): + def test_register_schema_raises_error_when_registering_a_schema_twice(self): op_schema = defs.OpSchema( - "CustomOpDuplicited", + "CustomOp", "", 1, ) + self.assertFalse(onnx.defs.has(op_schema.name)) onnx.defs.register_schema(op_schema) - self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.register_schema(op_schema)) + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.register_schema(op_schema) + + # cleanup + onnx.defs.deregister_schema(op_schema.name, op_schema.since_version, op_schema.domain) def test_deregister_opschema(self): - self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.get_schema('CustomOpSchemaTmp', 1)) - onnx.defs.register_schema(defs.OpSchema('CustomOpSchemaTmp', "", 1)) - op_schema = onnx.defs.get_schema('CustomOpSchemaTmp', 1) + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.get_schema('CustomOp', 1) + onnx.defs.register_schema(defs.OpSchema('CustomOp', "", 1)) + op_schema = onnx.defs.get_schema('CustomOp', 1) self.assertIsNotNone(op_schema) - onnx.defs.deregister_schema('CustomOpSchemaTmp', 1) - self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.get_schema('CustomOpSchemaTmp', 1)) + onnx.defs.deregister_schema('CustomOp', 1) + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.get_schema('CustomOp', 1) + + def test_deregister_raise_error_when_deregister_noexist_opschema(self): + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.get_schema('CustomOp', 1) + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.deregister_schema('CustomOp', 1) if __name__ == "__main__": From 0b26ad517afda2e11b5be9756dbbcc1e194ca33b Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 12:57:15 +0800 Subject: [PATCH 05/35] fix desc error Signed-off-by: oPluss --- onnx/cpp2py_export.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index c3f9bcaeb08..089d9959125 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -474,7 +474,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "register_schema", [](OpSchema* schema) { RegisterSchema(*schema); }, "schema"_a, - "Register the custom OpSchema.") + "Register a user provided OpSchema.") .def( "deregister_schema", [](const std::string& op_type, const int specific_version, const std::string& domain) { @@ -483,7 +483,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "op_type"_a, "specific_version"_a, "domain"_a = ONNX_DOMAIN, - "Register the custom OpSchema."); + "DeRegister the specified OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); From 5b2151f8463f2ab15fcd71b08af25d5d785b6cfa Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 12:59:16 +0800 Subject: [PATCH 06/35] append export symbol Signed-off-by: oPluss --- onnx/defs/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index e367e49568a..f898e680682 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -9,6 +9,7 @@ "AI_ONNX_PREVIEW_TRAINING_DOMAIN", "has", "register_schema", + "deregister_schema", "get_schema", "get_all_schemas", "get_all_schemas_with_history", From a4c356c4ef7e83e784528676f260baf836c9e77c Mon Sep 17 00:00:00 2001 From: oPluss Date: Tue, 6 Feb 2024 18:30:35 +0800 Subject: [PATCH 07/35] add annotate and update doc Signed-off-by: oPluss --- docs/docsgen/source/api/defs.md | 4 ++++ onnx/defs/__init__.py | 1 + 2 files changed, 5 insertions(+) diff --git a/docs/docsgen/source/api/defs.md b/docs/docsgen/source/api/defs.md index ed027ceca63..509c2d2f528 100644 --- a/docs/docsgen/source/api/defs.md +++ b/docs/docsgen/source/api/defs.md @@ -22,6 +22,10 @@ .. autofunction:: onnx.defs.get_all_schemas_with_history .. autofunction:: onnx.defs.get_function_ops + +.. autofunction:: onnx.defs.register_schema + +.. autofunction:: onnx.defs.deregister_schema ``` ## class `OpSchema` diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index f898e680682..eed97df293e 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -125,6 +125,7 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError def register_schema(schema: OpSchema): + """Register a user provided OpSchema.""" name = schema.name domain = schema.domain version = schema.since_version From 72a0e03a8a9b238d5cdbc2135d58f4ac3aa9ffc7 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 19:04:50 +0800 Subject: [PATCH 08/35] refact unittest Signed-off-by: opluss --- onnx/cpp2py_export.cc | 10 ++-- onnx/defs/__init__.py | 2 +- onnx/defs/schema.cc | 6 +- onnx/defs/schema.h | 13 ++--- onnx/onnx_cpp2py_export/defs.pyi | 4 +- onnx/test/schema_test.py | 98 ++++++++++++++++++++------------ 6 files changed, 78 insertions(+), 55 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 089d9959125..de0fe023569 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -477,13 +477,13 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "Register a user provided OpSchema.") .def( "deregister_schema", - [](const std::string& op_type, const int specific_version, const std::string& domain) { - DeRegisterSchema(op_type, specific_version, domain); + [](const std::string& op_type, int version, const std::string& domain) { + DeregisterSchema(op_type, version, domain); }, "op_type"_a, - "specific_version"_a, - "domain"_a = ONNX_DOMAIN, - "DeRegister the specified OpSchema."); + "version"_a, + "domain"_a, + "Deregister the specified OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index eed97df293e..a50512f7712 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -134,5 +134,5 @@ def register_schema(schema: OpSchema): except SchemaError: existing_schema = None if existing_schema is not None and existing_schema.since_version == version: - raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {existing_schema.file}:{existing_schema.line}') + raise SchemaError(f"OpSchema '{domain}::{name}' already defined in file: {existing_schema.file}:{existing_schema.line}") C.register_schema(schema) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index b03c98a8a74..c1819902d20 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -34,10 +34,10 @@ void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplic OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load, fail_duplicate_schema); } -// The (name, version, domain) must exactly matches the target +// The (name, version, domain) must match the target exactly // Otherwise will raise an SchemaError -void DeRegisterSchema(const std::string& name, const int specific_version, const std::string& domain) { - OpSchemaRegistry::OpSchemaDeregister(name, specific_version, domain); +void DeregisterSchema(const std::string& name, int version, const std::string& domain) { + OpSchemaRegistry::OpSchemaDeregister(name, version, domain); } #ifndef NDEBUG diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 5a6a897a71a..0e93914eb1d 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1307,15 +1307,14 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; - static void - OpSchemaDeregister(const std::string& name, const int specific_version, const std::string& domain = ONNX_DOMAIN) { + static void OpSchemaDeregister(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN) { auto& schema_map = GetMapWithoutEnsuringRegistration(); - if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(specific_version)) { - schema_map[name][domain].erase(specific_version); + if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(version)) { + schema_map[name][domain].erase(version); } else { std::stringstream err; - err << "Trying to deregister schema with name: " << name << " domain: " << domain - << " version: " << specific_version << std::endl; + err << "Attempting to deregister an unregistered schema with name: " << name << " domain: " << domain + << " version: " << version << std::endl; fail_schema(err.str()); } } @@ -1437,7 +1436,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true); -void DeRegisterSchema(const std::string& name, const int specific_version, const std::string& domain = ONNX_DOMAIN); +void DeregisterSchema(const std::string& name, int version, const std::string& domain = ONNX_DOMAIN); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index 26cf31fb017..5cde8a146ca 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -182,7 +182,7 @@ class OpSchema: CONSUME_ALLOWED: OpSchema.UseType = ... CONSUME_ENFORCED: OpSchema.UseType = ... -def has_schema(op_type: str) -> bool: ... +def has_schema(op_type: str, domain: str = "") -> bool: ... def schema_version_map() -> dict[str, tuple[int, int]]: ... @overload def get_schema( @@ -193,4 +193,4 @@ def get_schema(op_type: str, domain: str = "") -> OpSchema: ... def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... def register_schema(schema: OpSchema) -> None: ... -def deregister_schema(op_name: str, version: int, domain: str = "") -> None: ... +def deregister_schema(op_name: str, version: int, domain: str) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 5a8579afa80..e50eee5f371 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -257,19 +257,51 @@ def test_init_with_default_value(self): self.assertEqual("attr1", attribute.name) self.assertEqual("attr1 description", attribute.description) +@parameterized.parameterized_class( + [ + { + "op_type": "CustomOp", + "op_version": 1, + "op_domain": "", + } + ] +) class TestOpSchemaRegister(unittest.TestCase): - def test_register(self): - input = """ - agraph (float[N, 128] X, int32 Y) => (float[N] Z) - { - Z = CustomOp(X, Y) - } + op_type: str + op_version: int + op_domain: str + + def setUp(self) -> None: + # Ensure the schema is unregistered + self.assertFalse(onnx.defs.has(self.op_type, self.op_domain)) + + def tearDown(self) -> None: + # Clean up the registered schema + try: + onnx.defs.deregister_schema( + self.op_type, + self.op_version, + self.op_domain + ) + except onnx.defs.SchemaError: + pass + + def test_register_schema(self): + input = f""" + < + ir_version: 7, + opset_import: ["{self.op_domain}" : {self.op_version}] + > + agraph (float[N, 128] X, int32 Y) => (float[N] Z) + {{ + Z = {self.op_type}(X, Y) + }} """ - model = onnx.parser.parse_graph(input) + model = onnx.parser.parse_model(input) op_schema = defs.OpSchema( - "CustomOp", - "", - 1, + self.op_type, + self.op_domain, + self.op_version, inputs=[ defs.OpSchema.FormalParameter("input1", "T"), defs.OpSchema.FormalParameter("input2", "int32"), @@ -284,45 +316,37 @@ def test_register(self): ) ], ) - self.assertFalse(onnx.defs.has(op_schema.name)) with self.assertRaises(onnx.checker.ValidationError): - onnx.checker.check_graph(model) + onnx.checker.check_model(model) onnx.defs.register_schema(op_schema) - onnx.checker.check_graph(model) - - # cleanup - onnx.defs.deregister_schema(op_schema.name, op_schema.since_version, op_schema.domain) + self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) + onnx.checker.check_model(model) def test_register_schema_raises_error_when_registering_a_schema_twice(self): op_schema = defs.OpSchema( - "CustomOp", - "", - 1, + self.op_type, + self.op_domain, + self.op_version, ) - self.assertFalse(onnx.defs.has(op_schema.name)) onnx.defs.register_schema(op_schema) with self.assertRaises(onnx.defs.SchemaError): onnx.defs.register_schema(op_schema) - # cleanup + def test_deregister_schema(self): + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + self.op_version, + ) + onnx.defs.register_schema(op_schema) + self.assertTrue(onnx.defs.has(op_schema.name, op_schema.domain)) onnx.defs.deregister_schema(op_schema.name, op_schema.since_version, op_schema.domain) - - def test_deregister_opschema(self): - with self.assertRaises(onnx.defs.SchemaError): - onnx.defs.get_schema('CustomOp', 1) - onnx.defs.register_schema(defs.OpSchema('CustomOp', "", 1)) - op_schema = onnx.defs.get_schema('CustomOp', 1) - self.assertIsNotNone(op_schema) - onnx.defs.deregister_schema('CustomOp', 1) - with self.assertRaises(onnx.defs.SchemaError): - onnx.defs.get_schema('CustomOp', 1) - - def test_deregister_raise_error_when_deregister_noexist_opschema(self): - with self.assertRaises(onnx.defs.SchemaError): - onnx.defs.get_schema('CustomOp', 1) + self.assertFalse(onnx.defs.has(op_schema.name, op_schema.domain)) + + def test_deregister_raise_error_when_deregister_nonexistent_opschema(self): with self.assertRaises(onnx.defs.SchemaError): - onnx.defs.deregister_schema('CustomOp', 1) - + onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) + if __name__ == "__main__": unittest.main() From dca0d07ce8d6ea523f76221a1a8dfc152c3aae7d Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 19:25:58 +0800 Subject: [PATCH 09/35] separate registration logic Signed-off-by: opluss --- onnx/cpp2py_export.cc | 2 +- onnx/defs/__init__.py | 14 +------- onnx/defs/schema.cc | 9 +++-- onnx/defs/schema.h | 83 +++++++++++++++++++++++-------------------- 4 files changed, 54 insertions(+), 54 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index de0fe023569..6c7078059fb 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -472,7 +472,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "Return the schema of all existing operators and all versions.") .def( "register_schema", - [](OpSchema* schema) { RegisterSchema(*schema); }, + [](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); }, "schema"_a, "Register a user provided OpSchema.") .def( diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index a50512f7712..5b75c424fca 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -33,6 +33,7 @@ get_schema = C.get_schema get_all_schemas = C.get_all_schemas get_all_schemas_with_history = C.get_all_schemas_with_history +register_schema = C.register_schema deregister_schema = C.deregister_schema @@ -123,16 +124,3 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError - -def register_schema(schema: OpSchema): - """Register a user provided OpSchema.""" - name = schema.name - domain = schema.domain - version = schema.since_version - try: - existing_schema = get_schema(name, version, domain) - except SchemaError: - existing_schema = None - if existing_schema is not None and existing_schema.since_version == version: - raise SchemaError(f"OpSchema '{domain}::{name}' already defined in file: {existing_schema.file}:{existing_schema.line}") - C.register_schema(schema) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index c1819902d20..bbd74ce4900 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,8 +30,13 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load -void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema) { - OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load, fail_duplicate_schema); +void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema, bool fail_with_exception) { + if (fail_with_exception) { + OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(schema, opset_version_to_load, fail_duplicate_schema); + } else { + OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration( + schema, opset_version_to_load, fail_duplicate_schema); + } } // The (name, version, domain) must match the target exactly diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 0e93914eb1d..0ac084d990b 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1215,50 +1215,53 @@ class OpSchemaRegistry final : public ISchemaRegistry { public: OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { - op_schema.Finalize(); - auto& m = GetMapWithoutEnsuringRegistration(); - auto& op_name = op_schema.Name(); - auto& op_domain = op_schema.domain(); - auto& schema_ver_map = m[op_name][op_domain]; - auto ver = op_schema.SinceVersion(); - if (OpSchema::kUninitializedSinceVersion == ver) { - op_schema.SinceVersion(1); - ver = op_schema.SinceVersion(); + OpSchemaRegisterImpl(op_schema, opset_version_to_load, fail_duplicate_schema); + } + ONNX_CATCH(const std::exception& e) { + ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); + } + } + static void + OpSchemaRegisterImpl(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + op_schema.Finalize(); + auto& m = GetMapWithoutEnsuringRegistration(); + auto& op_name = op_schema.Name(); + auto& op_domain = op_schema.domain(); + auto& schema_ver_map = m[op_name][op_domain]; + auto ver = op_schema.SinceVersion(); + if (OpSchema::kUninitializedSinceVersion == ver) { + op_schema.SinceVersion(1); + ver = op_schema.SinceVersion(); + } + + // Stops because the exact opset_version is registered + if (schema_ver_map.count(ver)) { + if (fail_duplicate_schema) { + const auto& schema = schema_ver_map[ver]; + std::stringstream err; + err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver + << ") from file " << op_schema.file() << " line " << op_schema.line() + << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl; + fail_schema(err.str()); } + return; + } - // Stops because the exact opset_version is registered - if (schema_ver_map.count(ver)) { - if (fail_duplicate_schema) { - const auto& schema = schema_ver_map[ver]; - std::stringstream err; - err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver - << ") from file " << op_schema.file() << " line " << op_schema.line() - << ", but it is already registered from file " << schema.file() << " line " << schema.line() - << std::endl; - fail_schema(err.str()); - } + if (opset_version_to_load != 0) { + // Stops because the opset_version is higher than opset_version_to_load + if (ver > opset_version_to_load) return; - } - if (opset_version_to_load != 0) { - // Stops because the opset_version is higher than opset_version_to_load - if (ver > opset_version_to_load) + // Stops because a later version is registered within target opset version + if (!schema_ver_map.empty()) { + int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load); + if (max_registered_ver_le_target >= ver) return; - - // Stops because a later version is registered within target opset version - if (!schema_ver_map.empty()) { - int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load); - if (max_registered_ver_le_target >= ver) - return; - } } - - CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); - schema_ver_map.insert(std::pair(ver, std::move(op_schema))); - } - ONNX_CATCH(const std::exception& e) { - ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); } + + CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); + schema_ver_map.insert(std::pair(ver, std::move(op_schema))); } private: @@ -1435,7 +1438,11 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; -void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true); +void RegisterSchema( + OpSchema schema, + int opset_version_to_load = 0, + bool fail_duplicate_schema = true, + bool fail_with_exception = false); void DeregisterSchema(const std::string& name, int version, const std::string& domain = ONNX_DOMAIN); // Registers the latest opset schema before opset_version_to_load From 5a0171008f1ede0b7fdb37f81c080b1b24cb9e83 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 23:04:29 +0800 Subject: [PATCH 10/35] enable check for custom domain Signed-off-by: opluss --- onnx/checker.cc | 20 ++++++++++---------- onnx/checker.h | 13 +++++++++++-- onnx/checker.py | 7 +++++-- onnx/cpp2py_export.cc | 13 ++++++++----- onnx/onnx_cpp2py_export/checker.pyi | 4 ++-- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/onnx/checker.cc b/onnx/checker.cc index e2736e365a1..3c981fcbbf5 100644 --- a/onnx/checker.cc +++ b/onnx/checker.cc @@ -651,17 +651,11 @@ void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalS const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain()); if (!schema) { if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" || - node.domain() == AI_ONNX_TRAINING_DOMAIN) { - // fail the checker if op in built-in domains has no schema + node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_op()) { + // fail the checker if op which is in built-in domains or enable `check_custom_op` has no schema fail_check( "No Op registered for " + node.op_type() + " with domain_version of " + ONNX_NAMESPACE::to_string(domain_version)); - } else { - // TODO: expose the registration of the op schemas appropriately in - // python, so we can load and register operators in other domains - // - // before we complete the above todo, let's skip the schema check for - // now } } else if (schema->Deprecated()) { fail_check( @@ -1016,7 +1010,11 @@ void check_model(const ModelProto& model, CheckerContext& ctx) { } } -void check_model(const std::string& model_path, bool full_check, bool skip_opset_compatibility_check) { +void check_model( + const std::string& model_path, + bool full_check, + bool skip_opset_compatibility_check, + bool check_custom_op) { ModelProto model; LoadProtoFromPath(model_path, model); @@ -1028,6 +1026,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset } ctx.set_model_dir(model_dir); ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); + ctx.set_check_custom_op(check_custom_op); check_model(model, ctx); if (full_check) { @@ -1036,9 +1035,10 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset } } -void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check) { +void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op) { CheckerContext ctx; ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); + ctx.set_check_custom_op(check_custom_op); check_model(model, ctx); if (full_check) { ShapeInferenceOptions options{true, 1, false}; diff --git a/onnx/checker.h b/onnx/checker.h index 6796acab222..a2848f9804d 100644 --- a/onnx/checker.h +++ b/onnx/checker.h @@ -84,6 +84,14 @@ class CheckerContext final { skip_opset_compatibility_check_ = value; } + bool check_custom_op() const { + return check_custom_op_; + } + + void set_check_custom_op(bool value) { + check_custom_op_ = value; + } + explicit CheckerContext() : ir_version_(-1) {} private: @@ -93,6 +101,7 @@ class CheckerContext final { const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance(); std::string model_dir_; bool skip_opset_compatibility_check_ = false; + bool check_custom_op_ = false; }; class LexicalScopeContext final { @@ -158,8 +167,8 @@ void check_model_local_functions( const CheckerContext& ctx, const LexicalScopeContext& parent_lex); -void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false); -void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false); +void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_op = false); +void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_op = false); bool check_is_experimental_op(const NodeProto& node); diff --git a/onnx/checker.py b/onnx/checker.py index f624a4d7f35..c067742b337 100644 --- a/onnx/checker.py +++ b/onnx/checker.py @@ -137,6 +137,7 @@ def check_model( model: ModelProto | str | bytes | os.PathLike, full_check: bool = False, skip_opset_compatibility_check: bool = False, + check_custom_op: bool = False, ) -> None: """Check the consistency of a model. @@ -154,10 +155,12 @@ def check_model( full_check: If True, the function also runs shape inference check. skip_opset_compatibility_check: If True, the function skips the check for opset compatibility. + check_custom_op: If True, the function will check all domains. Otherwise + only check built-in domains. """ # If model is a path instead of ModelProto if isinstance(model, (str, os.PathLike)): - C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check) + C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check, check_custom_op) else: protobuf_string = ( model if isinstance(model, bytes) else model.SerializeToString() @@ -168,7 +171,7 @@ def check_model( raise ValueError( "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead." ) - C.check_model(protobuf_string, full_check, skip_opset_compatibility_check) + C.check_model(protobuf_string, full_check, skip_opset_compatibility_check, check_custom_op) ValidationError = C.ValidationError diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 6c7078059fb..99f12ba85dd 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -560,21 +560,24 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { checker.def( "check_model", - [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check) -> void { + [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op) -> void { ModelProto proto{}; ParseProtoFromPyBytes(&proto, bytes); - checker::check_model(proto, full_check, skip_opset_compatibility_check); + checker::check_model(proto, full_check, skip_opset_compatibility_check, check_custom_op); }, "bytes"_a, "full_check"_a = false, - "skip_opset_compatibility_check"_a = false); + "skip_opset_compatibility_check"_a = false, + "check_custom_op"_a = false); checker.def( "check_model_path", - (void (*)(const std::string& path, bool full_check, bool skip_opset_compatibility_check)) & checker::check_model, + (void (*)(const std::string& path, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op)) & + checker::check_model, "path"_a, "full_check"_a = false, - "skip_opset_compatibility_check"_a = false); + "skip_opset_compatibility_check"_a = false, + "check_custom_op"_a = false); // Submodule `version_converter` auto version_converter = onnx_cpp2py_export.def_submodule("version_converter"); diff --git a/onnx/onnx_cpp2py_export/checker.pyi b/onnx/onnx_cpp2py_export/checker.pyi index 83ce4650690..f7e8e8d637a 100644 --- a/onnx/onnx_cpp2py_export/checker.pyi +++ b/onnx/onnx_cpp2py_export/checker.pyi @@ -17,5 +17,5 @@ def check_attribute(bytes: bytes, checker_context: CheckerContext, lexical_scope def check_node(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_function(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_graph(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 -def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool) -> None: ... # noqa: A002 -def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool) -> None: ... +def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool, check_custom_op: bool) -> None: ... # noqa: A002 +def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool, check_custom_op: bool) -> None: ... From c68fcbf6d453b75390b80cfad96aa81bd28317e5 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 23:05:17 +0800 Subject: [PATCH 11/35] add domain automatically when register a cudatom op Signed-off-by: opluss --- onnx/cpp2py_export.cc | 11 +++++++++++ onnx/defs/__init__.py | 12 ++++++++++++ onnx/defs/schema.h | 15 +++++++++++---- onnx/onnx_cpp2py_export/defs.pyi | 1 + onnx/test/schema_test.py | 15 +++++++++++---- 5 files changed, 46 insertions(+), 8 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 99f12ba85dd..26eb344377d 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -470,6 +470,17 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "get_all_schemas_with_history", []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, "Return the schema of all existing operators and all versions.") + .def( + "set_domain_to_version", + [](const std::string& domain, int min_version, int max_version, int last_release_version) { + OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion( + domain, min_version, max_version, last_release_version, true); + }, + "domain"_a, + "min_version"_a, + "max_version"_a, + "last_release_version"_a = -1, + "Register a user provided OpSchema.") .def( "register_schema", [](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); }, diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 5b75c424fca..4430c6a6a4d 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -124,3 +124,15 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError + +def register_schema(schema: OpSchema): + """Register a user provided OpSchema. And extend domain automatically.""" + map = C.schema_version_map() + domain = schema.domain + version = schema.since_version + min_version, max_version = map.get(domain, (version, version)) + if domain not in map or not (min_version <= version <= max_version): + min_version = min(min_version, version) + max_version = max(max_version, version) + C.set_domain_to_version(schema.domain, min_version, max_version) + C.register_schema(schema) diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 0ac084d990b..736146b313c 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1184,16 +1184,23 @@ class OpSchemaRegistry final : public ISchemaRegistry { // standard ONNX domains as above). Custom-domains are free to interpret // this as appropriate (that is, as relative to releases of custom-domain // as opposed to ONNX releases). - void - AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { + void AddDomainToVersion( + const std::string& domain, + int min_version, + int max_version, + int last_release_version = -1, + bool rewrite = false) { std::lock_guard lock(mutex_); - assert(map_.end() == map_.find(domain)); + if (!rewrite) { + // Maybe we need raise an exception when domain is already exist. + assert(map_.end() == map_.find(domain)); + assert(last_release_version_map_.end() == last_release_version_map_.find(domain)); + } map_[domain] = std::make_pair(min_version, max_version); // If a last-release-version is not explicitly specified, use max as // last-release-version. if (last_release_version == -1) last_release_version = max_version; - assert(last_release_version_map_.end() == last_release_version_map_.find(domain)); last_release_version_map_[domain] = last_release_version; } diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index 5cde8a146ca..459e1531634 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -192,5 +192,6 @@ def get_schema( def get_schema(op_type: str, domain: str = "") -> OpSchema: ... def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... +def set_domain_to_version(domain: str, min_version: int, max_version: int, last_release_version: int = -1) -> None: ... def register_schema(schema: OpSchema) -> None: ... def deregister_schema(op_name: str, version: int, domain: str) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index e50eee5f371..547d85b1489 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -263,6 +263,11 @@ def test_init_with_default_value(self): "op_type": "CustomOp", "op_version": 1, "op_domain": "", + }, + { + "op_type": "CustomOp", + "op_version": 1, + "op_domain": "test", } ] ) @@ -290,11 +295,13 @@ def test_register_schema(self): input = f""" < ir_version: 7, - opset_import: ["{self.op_domain}" : {self.op_version}] + opset_import: [ + "{self.op_domain}" : {self.op_version} + ] > agraph (float[N, 128] X, int32 Y) => (float[N] Z) {{ - Z = {self.op_type}(X, Y) + Z = {self.op_domain}.{self.op_type}(X, Y) }} """ model = onnx.parser.parse_model(input) @@ -317,10 +324,10 @@ def test_register_schema(self): ], ) with self.assertRaises(onnx.checker.ValidationError): - onnx.checker.check_model(model) + onnx.checker.check_model(model, check_custom_op=True) onnx.defs.register_schema(op_schema) self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) - onnx.checker.check_model(model) + onnx.checker.check_model(model, check_custom_op=True) def test_register_schema_raises_error_when_registering_a_schema_twice(self): op_schema = defs.OpSchema( From 59bedfcb5c98a54c0ede016e6788e03a6849f9eb Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 23:17:00 +0800 Subject: [PATCH 12/35] fix py lint Signed-off-by: opluss --- onnx/test/schema_test.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 547d85b1489..990e6f6b5be 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -1,6 +1,7 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 +import contextlib import unittest from typing import Sequence @@ -257,6 +258,7 @@ def test_init_with_default_value(self): self.assertEqual("attr1", attribute.name) self.assertEqual("attr1 description", attribute.description) + @parameterized.parameterized_class( [ { @@ -268,7 +270,7 @@ def test_init_with_default_value(self): "op_type": "CustomOp", "op_version": 1, "op_domain": "test", - } + }, ] ) class TestOpSchemaRegister(unittest.TestCase): @@ -282,22 +284,16 @@ def setUp(self) -> None: def tearDown(self) -> None: # Clean up the registered schema - try: - onnx.defs.deregister_schema( - self.op_type, - self.op_version, - self.op_domain - ) - except onnx.defs.SchemaError: - pass + with contextlib.suppress(onnx.defs.SchemaError): + onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) def test_register_schema(self): input = f""" < - ir_version: 7, - opset_import: [ - "{self.op_domain}" : {self.op_version} - ] + ir_version: 7, + opset_import: [ + "{self.op_domain}" : {self.op_version} + ] > agraph (float[N, 128] X, int32 Y) => (float[N] Z) {{ @@ -337,7 +333,7 @@ def test_register_schema_raises_error_when_registering_a_schema_twice(self): ) onnx.defs.register_schema(op_schema) with self.assertRaises(onnx.defs.SchemaError): - onnx.defs.register_schema(op_schema) + onnx.defs.register_schema(op_schema) def test_deregister_schema(self): op_schema = defs.OpSchema( @@ -347,7 +343,9 @@ def test_deregister_schema(self): ) onnx.defs.register_schema(op_schema) self.assertTrue(onnx.defs.has(op_schema.name, op_schema.domain)) - onnx.defs.deregister_schema(op_schema.name, op_schema.since_version, op_schema.domain) + onnx.defs.deregister_schema( + op_schema.name, op_schema.since_version, op_schema.domain + ) self.assertFalse(onnx.defs.has(op_schema.name, op_schema.domain)) def test_deregister_raise_error_when_deregister_nonexistent_opschema(self): From d06560b81c318053ced848e54d4129e5445a1b7c Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 23:40:55 +0800 Subject: [PATCH 13/35] add case for unittest Signed-off-by: opluss --- onnx/test/schema_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 990e6f6b5be..3d6d5d785ff 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -325,6 +325,11 @@ def test_register_schema(self): self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) onnx.checker.check_model(model, check_custom_op=True) + registered_op = onnx.defs.get_schema( + op_schema.name, op_schema.since_version, op_schema.domain + ) + self.assertEqual(str(registered_op), str(op_schema)) + def test_register_schema_raises_error_when_registering_a_schema_twice(self): op_schema = defs.OpSchema( self.op_type, From aa1d9c34f0d16745c4eefe0163c288e10b14364d Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 7 Feb 2024 23:59:21 +0800 Subject: [PATCH 14/35] modify description for `set_domain_to_version` Signed-off-by: opluss --- onnx/cpp2py_export.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 26eb344377d..d6ae6130b88 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -480,7 +480,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "min_version"_a, "max_version"_a, "last_release_version"_a = -1, - "Register a user provided OpSchema.") + "Set the version range of the specified domain.") .def( "register_schema", [](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); }, From 82349df12d6a465348983604ab87f7ea7e69c8ba Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 00:00:31 +0800 Subject: [PATCH 15/35] fix coding error Signed-off-by: opluss --- onnx/defs/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 4430c6a6a4d..d49235173a7 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -33,7 +33,6 @@ get_schema = C.get_schema get_all_schemas = C.get_all_schemas get_all_schemas_with_history = C.get_all_schemas_with_history -register_schema = C.register_schema deregister_schema = C.deregister_schema From 7c1983cec312a45146257f85e52f1e3a195db4a6 Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 06:46:00 +0800 Subject: [PATCH 16/35] fix py lint Signed-off-by: opluss --- onnx/defs/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index d49235173a7..bad4e9fdf40 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -126,11 +126,11 @@ def get_function_ops() -> List[OpSchema]: def register_schema(schema: OpSchema): """Register a user provided OpSchema. And extend domain automatically.""" - map = C.schema_version_map() + version_map = C.schema_version_map() domain = schema.domain version = schema.since_version - min_version, max_version = map.get(domain, (version, version)) - if domain not in map or not (min_version <= version <= max_version): + min_version, max_version = version_map.get(domain, (version, version)) + if domain not in version_map or not (min_version <= version <= max_version): min_version = min(min_version, version) max_version = max(max_version, version) C.set_domain_to_version(schema.domain, min_version, max_version) From c778180da2101750976d83a0b27f6ca89d090458 Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 06:47:21 +0800 Subject: [PATCH 17/35] replace `check_custom_op` to `check_custom_domain` in checker Signed-off-by: opluss --- onnx/checker.cc | 12 ++++++------ onnx/checker.h | 14 +++++++------- onnx/checker.py | 8 ++++---- onnx/cpp2py_export.cc | 12 +++++++----- onnx/onnx_cpp2py_export/checker.pyi | 4 ++-- onnx/test/schema_test.py | 4 ++-- 6 files changed, 28 insertions(+), 26 deletions(-) diff --git a/onnx/checker.cc b/onnx/checker.cc index 3c981fcbbf5..04c5d0614df 100644 --- a/onnx/checker.cc +++ b/onnx/checker.cc @@ -651,8 +651,8 @@ void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalS const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain()); if (!schema) { if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" || - node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_op()) { - // fail the checker if op which is in built-in domains or enable `check_custom_op` has no schema + node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_domain()) { + // fail the checker if op which is in built-in domains or enable `check_custom_domain` has no schema fail_check( "No Op registered for " + node.op_type() + " with domain_version of " + ONNX_NAMESPACE::to_string(domain_version)); @@ -1014,7 +1014,7 @@ void check_model( const std::string& model_path, bool full_check, bool skip_opset_compatibility_check, - bool check_custom_op) { + bool check_custom_domain) { ModelProto model; LoadProtoFromPath(model_path, model); @@ -1026,7 +1026,7 @@ void check_model( } ctx.set_model_dir(model_dir); ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); - ctx.set_check_custom_op(check_custom_op); + ctx.set_check_custom_domain(check_custom_domain); check_model(model, ctx); if (full_check) { @@ -1035,10 +1035,10 @@ void check_model( } } -void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op) { +void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain) { CheckerContext ctx; ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); - ctx.set_check_custom_op(check_custom_op); + ctx.set_check_custom_domain(check_custom_domain); check_model(model, ctx); if (full_check) { ShapeInferenceOptions options{true, 1, false}; diff --git a/onnx/checker.h b/onnx/checker.h index a2848f9804d..51822fd4171 100644 --- a/onnx/checker.h +++ b/onnx/checker.h @@ -84,12 +84,12 @@ class CheckerContext final { skip_opset_compatibility_check_ = value; } - bool check_custom_op() const { - return check_custom_op_; + bool check_custom_domain() const { + return check_custom_domain_; } - void set_check_custom_op(bool value) { - check_custom_op_ = value; + void set_check_custom_domain(bool value) { + check_custom_domain_ = value; } explicit CheckerContext() : ir_version_(-1) {} @@ -101,7 +101,7 @@ class CheckerContext final { const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance(); std::string model_dir_; bool skip_opset_compatibility_check_ = false; - bool check_custom_op_ = false; + bool check_custom_domain_ = false; }; class LexicalScopeContext final { @@ -167,8 +167,8 @@ void check_model_local_functions( const CheckerContext& ctx, const LexicalScopeContext& parent_lex); -void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_op = false); -void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_op = false); +void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_domain = false); +void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_domain = false); bool check_is_experimental_op(const NodeProto& node); diff --git a/onnx/checker.py b/onnx/checker.py index c067742b337..ab4538e1074 100644 --- a/onnx/checker.py +++ b/onnx/checker.py @@ -137,7 +137,7 @@ def check_model( model: ModelProto | str | bytes | os.PathLike, full_check: bool = False, skip_opset_compatibility_check: bool = False, - check_custom_op: bool = False, + check_custom_domain: bool = False, ) -> None: """Check the consistency of a model. @@ -155,12 +155,12 @@ def check_model( full_check: If True, the function also runs shape inference check. skip_opset_compatibility_check: If True, the function skips the check for opset compatibility. - check_custom_op: If True, the function will check all domains. Otherwise + check_custom_domain: If True, the function will check all domains. Otherwise only check built-in domains. """ # If model is a path instead of ModelProto if isinstance(model, (str, os.PathLike)): - C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check, check_custom_op) + C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check, check_custom_domain) else: protobuf_string = ( model if isinstance(model, bytes) else model.SerializeToString() @@ -171,7 +171,7 @@ def check_model( raise ValueError( "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead." ) - C.check_model(protobuf_string, full_check, skip_opset_compatibility_check, check_custom_op) + C.check_model(protobuf_string, full_check, skip_opset_compatibility_check, check_custom_domain) ValidationError = C.ValidationError diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index d6ae6130b88..c7e85a50972 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -571,24 +571,26 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { checker.def( "check_model", - [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op) -> void { + [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain) + -> void { ModelProto proto{}; ParseProtoFromPyBytes(&proto, bytes); - checker::check_model(proto, full_check, skip_opset_compatibility_check, check_custom_op); + checker::check_model(proto, full_check, skip_opset_compatibility_check, check_custom_domain); }, "bytes"_a, "full_check"_a = false, "skip_opset_compatibility_check"_a = false, - "check_custom_op"_a = false); + "check_custom_domain"_a = false); checker.def( "check_model_path", - (void (*)(const std::string& path, bool full_check, bool skip_opset_compatibility_check, bool check_custom_op)) & + (void (*)( + const std::string& path, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain)) & checker::check_model, "path"_a, "full_check"_a = false, "skip_opset_compatibility_check"_a = false, - "check_custom_op"_a = false); + "check_custom_domain"_a = false); // Submodule `version_converter` auto version_converter = onnx_cpp2py_export.def_submodule("version_converter"); diff --git a/onnx/onnx_cpp2py_export/checker.pyi b/onnx/onnx_cpp2py_export/checker.pyi index f7e8e8d637a..c887293772a 100644 --- a/onnx/onnx_cpp2py_export/checker.pyi +++ b/onnx/onnx_cpp2py_export/checker.pyi @@ -17,5 +17,5 @@ def check_attribute(bytes: bytes, checker_context: CheckerContext, lexical_scope def check_node(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_function(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_graph(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 -def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool, check_custom_op: bool) -> None: ... # noqa: A002 -def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool, check_custom_op: bool) -> None: ... +def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ... # noqa: A002 +def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 3d6d5d785ff..96af6b64861 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -320,10 +320,10 @@ def test_register_schema(self): ], ) with self.assertRaises(onnx.checker.ValidationError): - onnx.checker.check_model(model, check_custom_op=True) + onnx.checker.check_model(model, check_custom_domain=True) onnx.defs.register_schema(op_schema) self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) - onnx.checker.check_model(model, check_custom_op=True) + onnx.checker.check_model(model, check_custom_domain=True) registered_op = onnx.defs.get_schema( op_schema.name, op_schema.since_version, op_schema.domain From 306bd6b683a326bfa3facd366a84d691d04d1b4d Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 10:51:20 +0800 Subject: [PATCH 18/35] fix code style and annotate Signed-off-by: opluss --- onnx/checker.cc | 8 ++++++-- onnx/defs/__init__.py | 11 +++++++++-- onnx/defs/schema.h | 9 ++++++--- onnx/onnx_cpp2py_export/defs.pyi | 2 +- onnx/test/schema_test.py | 2 +- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/onnx/checker.cc b/onnx/checker.cc index 04c5d0614df..e0ab6907c15 100644 --- a/onnx/checker.cc +++ b/onnx/checker.cc @@ -652,7 +652,7 @@ void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalS if (!schema) { if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" || node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_domain()) { - // fail the checker if op which is in built-in domains or enable `check_custom_domain` has no schema + // fail the checker if op is in built-in domains or if it has no schema when `check_custom_domain` is true fail_check( "No Op registered for " + node.op_type() + " with domain_version of " + ONNX_NAMESPACE::to_string(domain_version)); @@ -1035,7 +1035,11 @@ void check_model( } } -void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain) { +void check_model( + const ModelProto& model, + bool full_check, + bool skip_opset_compatibility_check, + bool check_custom_domain) { CheckerContext ctx; ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); ctx.set_check_custom_domain(check_custom_domain); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index bad4e9fdf40..0750b7727b8 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -124,8 +124,15 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError -def register_schema(schema: OpSchema): - """Register a user provided OpSchema. And extend domain automatically.""" + +def register_schema(schema: OpSchema) -> None: + """Register a user provided OpSchema. + + The function extends available operator set versions for the provided domain if necessary. + + Args: + schema: The OpSchema to register. + """ version_map = C.schema_version_map() domain = schema.domain version = schema.since_version diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 736146b313c..0c0f96437fd 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1199,8 +1199,9 @@ class OpSchemaRegistry final : public ISchemaRegistry { map_[domain] = std::make_pair(min_version, max_version); // If a last-release-version is not explicitly specified, use max as // last-release-version. - if (last_release_version == -1) + if (last_release_version == -1) { last_release_version = max_version; + } last_release_version_map_[domain] = last_release_version; } @@ -1256,14 +1257,16 @@ class OpSchemaRegistry final : public ISchemaRegistry { if (opset_version_to_load != 0) { // Stops because the opset_version is higher than opset_version_to_load - if (ver > opset_version_to_load) + if (ver > opset_version_to_load) { return; + } // Stops because a later version is registered within target opset version if (!schema_ver_map.empty()) { int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load); - if (max_registered_ver_le_target >= ver) + if (max_registered_ver_le_target >= ver) { return; + } } } diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index 48d756c9050..c337c567766 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -194,4 +194,4 @@ def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... def set_domain_to_version(domain: str, min_version: int, max_version: int, last_release_version: int = -1) -> None: ... def register_schema(schema: OpSchema) -> None: ... -def deregister_schema(op_name: str, version: int, domain: str) -> None: ... +def deregister_schema(op_type: str, version: int, domain: str) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 96af6b64861..f3f58f3fa5a 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -353,7 +353,7 @@ def test_deregister_schema(self): ) self.assertFalse(onnx.defs.has(op_schema.name, op_schema.domain)) - def test_deregister_raise_error_when_deregister_nonexistent_opschema(self): + def test_deregister_schema_raises_error_when_opschema_does_not_exist(self): with self.assertRaises(onnx.defs.SchemaError): onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) From 65c61f0dfc4a7af1935b165fbf13fd727a0ea36d Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 20:45:38 +0800 Subject: [PATCH 19/35] fix py code style Signed-off-by: opluss --- onnx/checker.h | 12 ++++++++++-- onnx/checker.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/onnx/checker.h b/onnx/checker.h index 51822fd4171..e06d0138dc6 100644 --- a/onnx/checker.h +++ b/onnx/checker.h @@ -167,8 +167,16 @@ void check_model_local_functions( const CheckerContext& ctx, const LexicalScopeContext& parent_lex); -void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_domain = false); -void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false, bool check_custom_domain = false); +void check_model( + const ModelProto& model, + bool full_check = false, + bool skip_opset_compatibility_check = false, + bool check_custom_domain = false); +void check_model( + const std::string& model_path, + bool full_check = false, + bool skip_opset_compatibility_check = false, + bool check_custom_domain = false); bool check_is_experimental_op(const NodeProto& node); diff --git a/onnx/checker.py b/onnx/checker.py index ab4538e1074..c89c4cec545 100644 --- a/onnx/checker.py +++ b/onnx/checker.py @@ -160,7 +160,12 @@ def check_model( """ # If model is a path instead of ModelProto if isinstance(model, (str, os.PathLike)): - C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check, check_custom_domain) + C.check_model_path( + os.fspath(model), + full_check, + skip_opset_compatibility_check, + check_custom_domain, + ) else: protobuf_string = ( model if isinstance(model, bytes) else model.SerializeToString() @@ -171,7 +176,12 @@ def check_model( raise ValueError( "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead." ) - C.check_model(protobuf_string, full_check, skip_opset_compatibility_check, check_custom_domain) + C.check_model( + protobuf_string, + full_check, + skip_opset_compatibility_check, + check_custom_domain, + ) ValidationError = C.ValidationError From 0af0ac78de7c966623000aa3d17bbf9744df063f Mon Sep 17 00:00:00 2001 From: opluss Date: Thu, 8 Feb 2024 20:49:19 +0800 Subject: [PATCH 20/35] simplify deregister_schema binding uniformly use `op_type` in OpSchema deregister Signed-off-by: opluss --- onnx/cpp2py_export.cc | 4 +--- onnx/defs/schema.cc | 4 ++-- onnx/defs/schema.h | 11 ++++++----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index c7e85a50972..ca54d412508 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -488,9 +488,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "Register a user provided OpSchema.") .def( "deregister_schema", - [](const std::string& op_type, int version, const std::string& domain) { - DeregisterSchema(op_type, version, domain); - }, + &DeregisterSchema, "op_type"_a, "version"_a, "domain"_a, diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index bbd74ce4900..5ccdace18cb 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -41,8 +41,8 @@ void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplic // The (name, version, domain) must match the target exactly // Otherwise will raise an SchemaError -void DeregisterSchema(const std::string& name, int version, const std::string& domain) { - OpSchemaRegistry::OpSchemaDeregister(name, version, domain); +void DeregisterSchema(const std::string& op_type, int version, const std::string& domain) { + OpSchemaRegistry::OpSchemaDeregister(op_type, version, domain); } #ifndef NDEBUG diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 0c0f96437fd..7d82a757459 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1320,13 +1320,14 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; - static void OpSchemaDeregister(const std::string& name, const int version, const std::string& domain = ONNX_DOMAIN) { + static void + OpSchemaDeregister(const std::string& op_type, const int version, const std::string& domain = ONNX_DOMAIN) { auto& schema_map = GetMapWithoutEnsuringRegistration(); - if (schema_map.count(name) && schema_map[name].count(domain) && schema_map[name][domain].count(version)) { - schema_map[name][domain].erase(version); + if (schema_map.count(op_type) && schema_map[op_type].count(domain) && schema_map[op_type][domain].count(version)) { + schema_map[op_type][domain].erase(version); } else { std::stringstream err; - err << "Attempting to deregister an unregistered schema with name: " << name << " domain: " << domain + err << "Attempting to deregister an unregistered schema with name: " << op_type << " domain: " << domain << " version: " << version << std::endl; fail_schema(err.str()); } @@ -1453,7 +1454,7 @@ void RegisterSchema( int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false); -void DeregisterSchema(const std::string& name, int version, const std::string& domain = ONNX_DOMAIN); +void DeregisterSchema(const std::string& op_type, int version, const std::string& domain = ONNX_DOMAIN); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions From ada696e5aaae6d31f7a40ce1ca31d38a1a482de4 Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 9 Feb 2024 09:25:59 +0800 Subject: [PATCH 21/35] add case to check schema accessible after deregister and fix bug Signed-off-by: opluss --- onnx/defs/schema.h | 4 +++- onnx/test/schema_test.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 7d82a757459..8fb31304ffb 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1442,7 +1442,9 @@ class OpSchemaRegistry final : public ISchemaRegistry { for (auto& x : map()) { for (auto& y : x.second) { auto& version2schema = y.second; - r.emplace_back(version2schema.rbegin()->second); + if (!version2schema.empty()) { + r.emplace_back(version2schema.rbegin()->second); + } } } return r; diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index f3f58f3fa5a..cedff2210b8 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -325,6 +325,7 @@ def test_register_schema(self): self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) onnx.checker.check_model(model, check_custom_domain=True) + # Also make sure the `op_schema` is accessible after register registered_op = onnx.defs.get_schema( op_schema.name, op_schema.since_version, op_schema.domain ) @@ -357,6 +358,29 @@ def test_deregister_schema_raises_error_when_opschema_does_not_exist(self): with self.assertRaises(onnx.defs.SchemaError): onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) + def test_legacy_schema_accessible_after_deregister(self): + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + self.op_version, + ) + onnx.defs.register_schema(op_schema) + schema_a = onnx.defs.get_schema( + op_schema.name, op_schema.since_version, op_schema.domain + ) + schema_b = onnx.defs.get_schema(op_schema.name, op_schema.domain) + filter_schema = lambda schemas: [op for op in schemas if op.name == op_schema.name] + schema_c = filter_schema(onnx.defs.get_all_schemas()) + schema_d = filter_schema(onnx.defs.get_all_schemas_with_history()) + self.assertEqual(len(schema_c), 1) + self.assertEqual(len(schema_d), 1) + # Avoid memory residue and access storage as much as possible + self.assertEqual(str(schema_a), str(op_schema)) + self.assertEqual(str(schema_b), str(op_schema)) + self.assertEqual(str(schema_c[0]), str(op_schema)) + self.assertEqual(str(schema_d[0]), str(op_schema)) + + if __name__ == "__main__": unittest.main() From 8b5ed1b3dbd7d0822f9e8484db95be46c3600839 Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 9 Feb 2024 10:16:23 +0800 Subject: [PATCH 22/35] impl update method Signed-off-by: opluss --- onnx/cpp2py_export.cc | 10 +++++++--- onnx/defs/schema.h | 46 +++++++++++++++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 13 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index ca54d412508..02b26365c66 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -473,14 +473,18 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { .def( "set_domain_to_version", [](const std::string& domain, int min_version, int max_version, int last_release_version) { - OpSchemaRegistry::DomainToVersionRange::Instance().AddDomainToVersion( - domain, min_version, max_version, last_release_version, true); + auto& obj = OpSchemaRegistry::DomainToVersionRange::Instance(); + if (obj.Map().count(domain) == 0) { + obj.AddDomainToVersion(domain, min_version, max_version, last_release_version); + } else { + obj.UpdateDomainToVersion(domain, min_version, max_version, last_release_version); + } }, "domain"_a, "min_version"_a, "max_version"_a, "last_release_version"_a = -1, - "Set the version range of the specified domain.") + "Set the version range and last release version of the specified domain.") .def( "register_schema", [](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); }, diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 8fb31304ffb..666c586397b 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1184,17 +1184,20 @@ class OpSchemaRegistry final : public ISchemaRegistry { // standard ONNX domains as above). Custom-domains are free to interpret // this as appropriate (that is, as relative to releases of custom-domain // as opposed to ONNX releases). - void AddDomainToVersion( - const std::string& domain, - int min_version, - int max_version, - int last_release_version = -1, - bool rewrite = false) { + void AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version) { std::lock_guard lock(mutex_); - if (!rewrite) { - // Maybe we need raise an exception when domain is already exist. - assert(map_.end() == map_.find(domain)); - assert(last_release_version_map_.end() == last_release_version_map_.find(domain)); + if (map_.count(domain) != 0) { + std::stringstream err; + err << "Trying to add a domain to DomainToVersion map, but the domain is already exist with version range (" + << map_.at(domain).first << ", " << map_.at(domain).second << "). domain: \"" << domain << "\"" + << std::endl; + fail_schema(err.str()); + } + if (last_release_version_map_.count(domain) != 0) { + std::stringstream err; + err << "Trying to add a domain to LastReleaseVersion map, but the domain is already exist with last version: " + << last_release_version_map_.at(domain) << ", domain: \"" << domain << "\"" << std::endl; + fail_schema(err.str()); } map_[domain] = std::make_pair(min_version, max_version); // If a last-release-version is not explicitly specified, use max as @@ -1205,6 +1208,29 @@ class OpSchemaRegistry final : public ISchemaRegistry { last_release_version_map_[domain] = last_release_version; } + void UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version) { + std::lock_guard lock(mutex_); + if (map_.count(domain) == 0) { + std::stringstream err; + err << "Trying to update a domain in DomainToVersion map, but the domain has not been add. domain: \"" << domain + << "\"" << std::endl; + fail_schema(err.str()); + } + if (last_release_version_map_.count(domain) == 0) { + std::stringstream err; + err << "Trying to update a domain in LastReleaseVersion map, but the domain has not been add. domain: \"" + << domain << "\"" << std::endl; + fail_schema(err.str()); + } + map_.at(domain).first = min_version; + map_.at(domain).second = max_version; + // Correspond to `AddDomainToVersion` + if (last_release_version == -1) { + last_release_version = max_version; + } + last_release_version_map_.at(domain) = last_release_version; + } + static DomainToVersionRange& Instance(); private: From 8d7301e90f41f68f16c057e93682e460b458b5f2 Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 9 Feb 2024 10:19:01 +0800 Subject: [PATCH 23/35] append annotate for unittest case Signed-off-by: opluss --- onnx/test/schema_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index cedff2210b8..86127d1c1f1 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -261,11 +261,13 @@ def test_init_with_default_value(self): @parameterized.parameterized_class( [ + # register to exist domain { "op_type": "CustomOp", "op_version": 1, "op_domain": "", }, + # register to new domain { "op_type": "CustomOp", "op_version": 1, From ff8cdfba0f17be952c733b3f2b858b2c903c288b Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 9 Feb 2024 10:25:17 +0800 Subject: [PATCH 24/35] Safe implementation for register Signed-off-by: opluss --- onnx/defs/schema.cc | 6 +++++- onnx/defs/schema.h | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 5ccdace18cb..0f6768e1951 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,7 +30,11 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load -void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema, bool fail_with_exception) { +void RegisterSchema( + const OpSchema& schema, + int opset_version_to_load, + bool fail_duplicate_schema, + bool fail_with_exception) { if (fail_with_exception) { OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(schema, opset_version_to_load, fail_duplicate_schema); } else { diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 666c586397b..814f85c46ef 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1247,7 +1247,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { class OpSchemaRegisterOnce final { public: - OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterOnce(const OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { OpSchemaRegisterImpl(op_schema, opset_version_to_load, fail_duplicate_schema); } @@ -1256,7 +1256,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { } } static void - OpSchemaRegisterImpl(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterImpl(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { op_schema.Finalize(); auto& m = GetMapWithoutEnsuringRegistration(); auto& op_name = op_schema.Name(); @@ -1478,7 +1478,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema( - OpSchema schema, + const OpSchema& schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false); From 17e76857947d0d6fbad53b83f9f2be94c60ba38a Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 9 Feb 2024 15:20:55 +0800 Subject: [PATCH 25/35] fix py code style Signed-off-by: opluss --- onnx/test/schema_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 86127d1c1f1..ebd56c08170 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -371,7 +371,10 @@ def test_legacy_schema_accessible_after_deregister(self): op_schema.name, op_schema.since_version, op_schema.domain ) schema_b = onnx.defs.get_schema(op_schema.name, op_schema.domain) - filter_schema = lambda schemas: [op for op in schemas if op.name == op_schema.name] + + def filter_schema(schemas): + return [op for op in schemas if op.name == op_schema.name] + schema_c = filter_schema(onnx.defs.get_all_schemas()) schema_d = filter_schema(onnx.defs.get_all_schemas_with_history()) self.assertEqual(len(schema_c), 1) @@ -383,6 +386,5 @@ def test_legacy_schema_accessible_after_deregister(self): self.assertEqual(str(schema_d[0]), str(op_schema)) - if __name__ == "__main__": unittest.main() From d5a59f5b32f5eb743dc5a793dcb0327fb8804a92 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 19:56:22 +0800 Subject: [PATCH 26/35] Revert "Safe implementation for register" This reverts commit ff8cdfba0f17be952c733b3f2b858b2c903c288b. Signed-off-by: opluss --- onnx/defs/schema.cc | 6 +----- onnx/defs/schema.h | 6 +++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 0f6768e1951..5ccdace18cb 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,11 +30,7 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load -void RegisterSchema( - const OpSchema& schema, - int opset_version_to_load, - bool fail_duplicate_schema, - bool fail_with_exception) { +void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema, bool fail_with_exception) { if (fail_with_exception) { OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(schema, opset_version_to_load, fail_duplicate_schema); } else { diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 814f85c46ef..666c586397b 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1247,7 +1247,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { class OpSchemaRegisterOnce final { public: - OpSchemaRegisterOnce(const OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { OpSchemaRegisterImpl(op_schema, opset_version_to_load, fail_duplicate_schema); } @@ -1256,7 +1256,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { } } static void - OpSchemaRegisterImpl(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterImpl(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { op_schema.Finalize(); auto& m = GetMapWithoutEnsuringRegistration(); auto& op_name = op_schema.Name(); @@ -1478,7 +1478,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema( - const OpSchema& schema, + OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false); From 85ba547b303ba1fce2f5532a0494181823abcadc Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 20:06:38 +0800 Subject: [PATCH 27/35] using rvalue reference for register pipline Signed-off-by: opluss --- onnx/cpp2py_export.cc | 2 +- onnx/defs/schema.cc | 11 ++++++++--- onnx/defs/schema.h | 12 ++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index 02b26365c66..e203c66de0e 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -487,7 +487,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { "Set the version range and last release version of the specified domain.") .def( "register_schema", - [](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); }, + [](OpSchema schema) { RegisterSchema(std::move(schema), 0, true, true); }, "schema"_a, "Register a user provided OpSchema.") .def( diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 5ccdace18cb..1c1531650b4 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,12 +30,17 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load -void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema, bool fail_with_exception) { +void RegisterSchema( + OpSchema&& schema, + int opset_version_to_load, + bool fail_duplicate_schema, + bool fail_with_exception) { if (fail_with_exception) { - OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(schema, opset_version_to_load, fail_duplicate_schema); + OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl( + std::forward(schema), opset_version_to_load, fail_duplicate_schema); } else { OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration( - schema, opset_version_to_load, fail_duplicate_schema); + std::forward(schema), opset_version_to_load, fail_duplicate_schema); } } diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 666c586397b..58cafe13b3f 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1247,16 +1247,16 @@ class OpSchemaRegistry final : public ISchemaRegistry { class OpSchemaRegisterOnce final { public: - OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterOnce(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { - OpSchemaRegisterImpl(op_schema, opset_version_to_load, fail_duplicate_schema); + OpSchemaRegisterImpl(std::forward(op_schema), opset_version_to_load, fail_duplicate_schema); } ONNX_CATCH(const std::exception& e) { ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); } } static void - OpSchemaRegisterImpl(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterImpl(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { op_schema.Finalize(); auto& m = GetMapWithoutEnsuringRegistration(); auto& op_name = op_schema.Name(); @@ -1297,7 +1297,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { } CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); - schema_ver_map.insert(std::pair(ver, std::move(op_schema))); + schema_ver_map.insert(std::pair(ver, std::forward(op_schema))); } private: @@ -1478,7 +1478,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { }; void RegisterSchema( - OpSchema schema, + OpSchema&& schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false); @@ -1489,7 +1489,7 @@ void DeregisterSchema(const std::string& op_type, int version, const std::string template void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) { T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) { - RegisterSchema(schema, opset_version_to_load, fail_duplicate_schema); + RegisterSchema(std::forward(schema), opset_version_to_load, fail_duplicate_schema); }); }; From 467f938bf9674981f925cec4e84cc1ddb6b1e3ef Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 20:09:46 +0800 Subject: [PATCH 28/35] restore the default behavior of `AddDomainToVersion` Signed-off-by: opluss --- onnx/defs/schema.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 58cafe13b3f..a8ce9d974b7 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1184,7 +1184,8 @@ class OpSchemaRegistry final : public ISchemaRegistry { // standard ONNX domains as above). Custom-domains are free to interpret // this as appropriate (that is, as relative to releases of custom-domain // as opposed to ONNX releases). - void AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version) { + void + AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { std::lock_guard lock(mutex_); if (map_.count(domain) != 0) { std::stringstream err; @@ -1208,7 +1209,8 @@ class OpSchemaRegistry final : public ISchemaRegistry { last_release_version_map_[domain] = last_release_version; } - void UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version) { + void + UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { std::lock_guard lock(mutex_); if (map_.count(domain) == 0) { std::stringstream err; From 7f912746b7f05dd47b05da8bdfe77e66bc843ec5 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 21:02:49 +0800 Subject: [PATCH 29/35] overload `has_schema` with version param Signed-off-by: opluss --- onnx/cpp2py_export.cc | 8 ++++++++ onnx/onnx_cpp2py_export/defs.pyi | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index e203c66de0e..c15da44794f 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -430,6 +430,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { }, "op_type"_a, "domain"_a = ONNX_DOMAIN) + .def( + "has_schema", + [](const std::string& op_type, int max_inclusive_version, const std::string& domain) -> bool { + return OpSchemaRegistry::Schema(op_type, max_inclusive_version, domain) != nullptr; + }, + "op_type"_a, + "max_inclusive_version"_a, + "domain"_a = ONNX_DOMAIN) .def( "schema_version_map", []() -> std::unordered_map> { diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index c337c567766..1ce999e3355 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -182,7 +182,12 @@ class OpSchema: CONSUME_ALLOWED: OpSchema.UseType = ... CONSUME_ENFORCED: OpSchema.UseType = ... +@overload def has_schema(op_type: str, domain: str = "") -> bool: ... +@overload +def has_schema( + op_type: str,max_inclusive_version: int, domain: str = "" +) -> bool: ... def schema_version_map() -> dict[str, tuple[int, int]]: ... @overload def get_schema( From 2b6cf01b591bd603cd55a62a2322913934256494 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 21:04:05 +0800 Subject: [PATCH 30/35] add multi register case in unittest Signed-off-by: opluss --- onnx/test/schema_test.py | 87 ++++++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 25 deletions(-) diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index ebd56c08170..5de210f4e4e 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib import unittest -from typing import Sequence +from typing import List, Sequence import parameterized @@ -264,14 +264,16 @@ def test_init_with_default_value(self): # register to exist domain { "op_type": "CustomOp", - "op_version": 1, + "op_version": 5, "op_domain": "", + "trap_op_version": [1, 2, 6, 7], }, # register to new domain { "op_type": "CustomOp", - "op_version": 1, + "op_version": 5, "op_domain": "test", + "trap_op_version": [1, 2, 6, 7], }, ] ) @@ -279,6 +281,8 @@ class TestOpSchemaRegister(unittest.TestCase): op_type: str op_version: int op_domain: str + # register some fake schema to check behavior + trap_op_version: List[int] def setUp(self) -> None: # Ensure the schema is unregistered @@ -286,10 +290,32 @@ def setUp(self) -> None: def tearDown(self) -> None: # Clean up the registered schema - with contextlib.suppress(onnx.defs.SchemaError): - onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) + for version in [*self.trap_op_version, self.op_version]: + with contextlib.suppress(onnx.defs.SchemaError): + onnx.defs.deregister_schema(self.op_type, version, self.op_domain) + + def test_register_multi_schema(self): + for version in [*self.trap_op_version, self.op_version]: + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + onnx.defs.register_schema(op_schema) + self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain)) + for version in [*self.trap_op_version, self.op_version]: + # Also make sure the `op_schema` is accessible after register + registered_op = onnx.defs.get_schema( + op_schema.name, version, op_schema.domain + ) + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + self.assertEqual(str(registered_op), str(op_schema)) - def test_register_schema(self): + def test_using_the_specified_version_in_onnx_check(self): input = f""" < ir_version: 7, @@ -324,15 +350,20 @@ def test_register_schema(self): with self.assertRaises(onnx.checker.ValidationError): onnx.checker.check_model(model, check_custom_domain=True) onnx.defs.register_schema(op_schema) - self.assertTrue(onnx.defs.has(self.op_type, self.op_domain)) + # The fake schema will raise check exception if selected in checker + for version in self.trap_op_version: + onnx.defs.register_schema( + defs.OpSchema( + self.op_type, + self.op_domain, + version, + outputs=[ + defs.OpSchema.FormalParameter("output1", "int32"), + ], + ) + ) onnx.checker.check_model(model, check_custom_domain=True) - # Also make sure the `op_schema` is accessible after register - registered_op = onnx.defs.get_schema( - op_schema.name, op_schema.since_version, op_schema.domain - ) - self.assertEqual(str(registered_op), str(op_schema)) - def test_register_schema_raises_error_when_registering_a_schema_twice(self): op_schema = defs.OpSchema( self.op_type, @@ -343,18 +374,24 @@ def test_register_schema_raises_error_when_registering_a_schema_twice(self): with self.assertRaises(onnx.defs.SchemaError): onnx.defs.register_schema(op_schema) - def test_deregister_schema(self): - op_schema = defs.OpSchema( - self.op_type, - self.op_domain, - self.op_version, - ) - onnx.defs.register_schema(op_schema) - self.assertTrue(onnx.defs.has(op_schema.name, op_schema.domain)) - onnx.defs.deregister_schema( - op_schema.name, op_schema.since_version, op_schema.domain - ) - self.assertFalse(onnx.defs.has(op_schema.name, op_schema.domain)) + def test_deregister_the_specified_schema(self): + for version in [*self.trap_op_version, self.op_version]: + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + onnx.defs.register_schema(op_schema) + self.assertTrue(onnx.defs.has(op_schema.name, version, op_schema.domain)) + onnx.defs.deregister_schema(op_schema.name, self.op_version, op_schema.domain) + for version in self.trap_op_version: + self.assertTrue(onnx.defs.has(op_schema.name, version, op_schema.domain)) + # Maybe has lesser op version in trap list + if onnx.defs.has(op_schema.name, self.op_version, op_schema.domain): + schema = onnx.defs.get_schema( + op_schema.name, self.op_version, op_schema.domain + ) + self.assertLess(schema.since_version, self.op_version) def test_deregister_schema_raises_error_when_opschema_does_not_exist(self): with self.assertRaises(onnx.defs.SchemaError): From f2577b924dfca77d0d665fdca90646481c41a904 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 21:21:41 +0800 Subject: [PATCH 31/35] fix build error in cpp unittest Signed-off-by: opluss --- onnx/test/cpp/function_context_test.cc | 8 ++++---- onnx/test/cpp/function_verify_test.cc | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnx/test/cpp/function_context_test.cc b/onnx/test/cpp/function_context_test.cc index 4dc9ea93cc2..5ce1278aad3 100644 --- a/onnx/test/cpp/function_context_test.cc +++ b/onnx/test/cpp/function_context_test.cc @@ -92,7 +92,7 @@ void RegisterCustomFuncFloatSchema() { .Output(0, "Y", "Output tensor", "T", OpSchema::Single) .TypeConstraint("T", {"tensor(float)"}, "Type of the input and output values") .SetContextDependentFunctionBodyBuilder(BuildFloatFunctionBody); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(schema)); (void)unused; } @@ -151,7 +151,7 @@ void RegisterCustomFunctionSchema() { .Output(0, "Y", "Output tensor", "T", OpSchema::Single) .TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values") .SetContextDependentFunctionBodyBuilder(BuildFunctionBody); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(schema)); (void)unused; } @@ -212,9 +212,9 @@ TEST(FunctionAPITest, VersionedFunctionBodyTest) { )ONNX", 16); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(schema_ver2); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(std::move(schema_ver2)); (void)unused2; - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(schema_ver9); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(std::move(schema_ver9)); (void)unused9; const auto* schema2 = OpSchemaRegistry::Schema("MySub", 2, ONNX_DOMAIN); diff --git a/onnx/test/cpp/function_verify_test.cc b/onnx/test/cpp/function_verify_test.cc index 0931bf1aca7..6a066177c2e 100644 --- a/onnx/test/cpp/function_verify_test.cc +++ b/onnx/test/cpp/function_verify_test.cc @@ -410,7 +410,7 @@ void RegisterFunctionSchema() { return operator_sets; }()); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(function_schema); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(function_schema)); (void)unused; } From afe3c78ddcbbcff1cde52102c68fa1d85ee77d8f Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 21:36:15 +0800 Subject: [PATCH 32/35] Revert "fix build error in cpp unittest" This reverts commit f2577b924dfca77d0d665fdca90646481c41a904. Signed-off-by: opluss --- onnx/test/cpp/function_context_test.cc | 8 ++++---- onnx/test/cpp/function_verify_test.cc | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnx/test/cpp/function_context_test.cc b/onnx/test/cpp/function_context_test.cc index 5ce1278aad3..4dc9ea93cc2 100644 --- a/onnx/test/cpp/function_context_test.cc +++ b/onnx/test/cpp/function_context_test.cc @@ -92,7 +92,7 @@ void RegisterCustomFuncFloatSchema() { .Output(0, "Y", "Output tensor", "T", OpSchema::Single) .TypeConstraint("T", {"tensor(float)"}, "Type of the input and output values") .SetContextDependentFunctionBodyBuilder(BuildFloatFunctionBody); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(schema)); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema); (void)unused; } @@ -151,7 +151,7 @@ void RegisterCustomFunctionSchema() { .Output(0, "Y", "Output tensor", "T", OpSchema::Single) .TypeConstraint("T", {"tensor(float)", "tensor(double)"}, "Type of the input and output values") .SetContextDependentFunctionBodyBuilder(BuildFunctionBody); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(schema)); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(schema); (void)unused; } @@ -212,9 +212,9 @@ TEST(FunctionAPITest, VersionedFunctionBodyTest) { )ONNX", 16); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(std::move(schema_ver2)); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused2(schema_ver2); (void)unused2; - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(std::move(schema_ver9)); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused9(schema_ver9); (void)unused9; const auto* schema2 = OpSchemaRegistry::Schema("MySub", 2, ONNX_DOMAIN); diff --git a/onnx/test/cpp/function_verify_test.cc b/onnx/test/cpp/function_verify_test.cc index 6a066177c2e..0931bf1aca7 100644 --- a/onnx/test/cpp/function_verify_test.cc +++ b/onnx/test/cpp/function_verify_test.cc @@ -410,7 +410,7 @@ void RegisterFunctionSchema() { return operator_sets; }()); - ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(std::move(function_schema)); + ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce unused(function_schema); (void)unused; } From 2b02fde86a53b9e05f1bf37e34d1d6528295e0c0 Mon Sep 17 00:00:00 2001 From: opluss Date: Wed, 14 Feb 2024 21:39:34 +0800 Subject: [PATCH 33/35] fix build error for cpp custom register Signed-off-by: opluss --- onnx/defs/schema.cc | 2 +- onnx/defs/schema.h | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 1c1531650b4..9ea7c18fc52 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -39,7 +39,7 @@ void RegisterSchema( OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl( std::forward(schema), opset_version_to_load, fail_duplicate_schema); } else { - OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration( + OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterNoExcept( std::forward(schema), opset_version_to_load, fail_duplicate_schema); } } diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index a8ce9d974b7..e8062f18747 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1249,7 +1249,12 @@ class OpSchemaRegistry final : public ISchemaRegistry { class OpSchemaRegisterOnce final { public: - OpSchemaRegisterOnce(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + // Export to cpp custom register macro + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); + } + static void + OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { OpSchemaRegisterImpl(std::forward(op_schema), opset_version_to_load, fail_duplicate_schema); } From 8380e7250bccd492fb005124d1a2deaf14c4f1e7 Mon Sep 17 00:00:00 2001 From: opluss Date: Fri, 16 Feb 2024 11:30:30 +0800 Subject: [PATCH 34/35] fix py code style Signed-off-by: opluss --- onnx/onnx_cpp2py_export/defs.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index 1ce999e3355..dbc9061600d 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -186,7 +186,7 @@ class OpSchema: def has_schema(op_type: str, domain: str = "") -> bool: ... @overload def has_schema( - op_type: str,max_inclusive_version: int, domain: str = "" + op_type: str, max_inclusive_version: int, domain: str = "" ) -> bool: ... def schema_version_map() -> dict[str, tuple[int, int]]: ... @overload From cb0e5218539632b89f7adf1c2f3cefcc7d254e0b Mon Sep 17 00:00:00 2001 From: opluss Date: Sat, 17 Feb 2024 14:10:46 +0800 Subject: [PATCH 35/35] Add overload impl to `RegisterSchema` for compatibility Use `move` instead of `forward` Remove default domain value in `DeregisterSchema` Signed-off-by: opluss --- onnx/defs/schema.cc | 11 +++++++++-- onnx/defs/schema.h | 13 +++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 9ea7c18fc52..b8fab81d22d 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,6 +30,13 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load +void RegisterSchema( + const OpSchema& schema, + int opset_version_to_load, + bool fail_duplicate_schema, + bool fail_with_exception) { + RegisterSchema(OpSchema(schema), opset_version_to_load, fail_duplicate_schema, fail_with_exception); +} void RegisterSchema( OpSchema&& schema, int opset_version_to_load, @@ -37,10 +44,10 @@ void RegisterSchema( bool fail_with_exception) { if (fail_with_exception) { OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl( - std::forward(schema), opset_version_to_load, fail_duplicate_schema); + std::move(schema), opset_version_to_load, fail_duplicate_schema); } else { OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterNoExcept( - std::forward(schema), opset_version_to_load, fail_duplicate_schema); + std::move(schema), opset_version_to_load, fail_duplicate_schema); } } diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index e8062f18747..969037858b2 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1256,7 +1256,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { static void OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { - OpSchemaRegisterImpl(std::forward(op_schema), opset_version_to_load, fail_duplicate_schema); + OpSchemaRegisterImpl(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); } ONNX_CATCH(const std::exception& e) { ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); @@ -1304,7 +1304,7 @@ class OpSchemaRegistry final : public ISchemaRegistry { } CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); - schema_ver_map.insert(std::pair(ver, std::forward(op_schema))); + schema_ver_map.insert(std::pair(ver, std::move(op_schema))); } private: @@ -1484,19 +1484,24 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; +void RegisterSchema( + const OpSchema& schema, + int opset_version_to_load = 0, + bool fail_duplicate_schema = true, + bool fail_with_exception = false); void RegisterSchema( OpSchema&& schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true, bool fail_with_exception = false); -void DeregisterSchema(const std::string& op_type, int version, const std::string& domain = ONNX_DOMAIN); +void DeregisterSchema(const std::string& op_type, int version, const std::string& domain); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions template void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) { T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) { - RegisterSchema(std::forward(schema), opset_version_to_load, fail_duplicate_schema); + RegisterSchema(std::move(schema), opset_version_to_load, fail_duplicate_schema); }); };