diff --git a/docs/docsgen/source/api/defs.md b/docs/docsgen/source/api/defs.md index ed027ceca63..509c2d2f528 100644 --- a/docs/docsgen/source/api/defs.md +++ b/docs/docsgen/source/api/defs.md @@ -22,6 +22,10 @@ .. autofunction:: onnx.defs.get_all_schemas_with_history .. autofunction:: onnx.defs.get_function_ops + +.. autofunction:: onnx.defs.register_schema + +.. autofunction:: onnx.defs.deregister_schema ``` ## class `OpSchema` diff --git a/onnx/checker.cc b/onnx/checker.cc index a8623c6064c..39b38df01ed 100644 --- a/onnx/checker.cc +++ b/onnx/checker.cc @@ -572,17 +572,11 @@ void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalS const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain()); if (!schema) { if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" || - node.domain() == AI_ONNX_TRAINING_DOMAIN) { - // fail the checker if op in built-in domains has no schema + node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_domain()) { + // fail the checker if op is in built-in domains or if it has no schema when `check_custom_domain` is true fail_check( "No Op registered for " + node.op_type() + " with domain_version of " + ONNX_NAMESPACE::to_string(domain_version)); - } else { - // TODO: expose the registration of the op schemas appropriately in - // python, so we can load and register operators in other domains - // - // before we complete the above todo, let's skip the schema check for - // now } } else if (schema->Deprecated()) { fail_check( @@ -937,7 +931,11 @@ void check_model(const ModelProto& model, CheckerContext& ctx) { } } -void check_model(const std::string& model_path, bool full_check, bool skip_opset_compatibility_check) { +void check_model( + const std::string& model_path, + bool full_check, + bool skip_opset_compatibility_check, + bool check_custom_domain) { ModelProto model; LoadProtoFromPath(model_path, model); @@ -949,6 +947,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset } ctx.set_model_dir(model_dir); ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); + ctx.set_check_custom_domain(check_custom_domain); check_model(model, ctx); if (full_check) { @@ -957,9 +956,14 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset } } -void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check) { +void check_model( + const ModelProto& model, + bool full_check, + bool skip_opset_compatibility_check, + bool check_custom_domain) { CheckerContext ctx; ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check); + ctx.set_check_custom_domain(check_custom_domain); check_model(model, ctx); if (full_check) { ShapeInferenceOptions options{true, 1, false}; diff --git a/onnx/checker.h b/onnx/checker.h index 83012213469..9be8855ee93 100644 --- a/onnx/checker.h +++ b/onnx/checker.h @@ -84,6 +84,14 @@ class CheckerContext final { skip_opset_compatibility_check_ = value; } + bool check_custom_domain() const { + return check_custom_domain_; + } + + void set_check_custom_domain(bool value) { + check_custom_domain_ = value; + } + explicit CheckerContext() : ir_version_(-1) {} private: @@ -93,6 +101,7 @@ class CheckerContext final { const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance(); std::string model_dir_; bool skip_opset_compatibility_check_ = false; + bool check_custom_domain_ = false; }; class LexicalScopeContext final { @@ -158,8 +167,16 @@ void check_model_local_functions( const CheckerContext& ctx, const LexicalScopeContext& parent_lex); -void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false); -void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false); +void check_model( + const ModelProto& model, + bool full_check = false, + bool skip_opset_compatibility_check = false, + bool check_custom_domain = false); +void check_model( + const std::string& model_path, + bool full_check = false, + bool skip_opset_compatibility_check = false, + bool check_custom_domain = false); std::string resolve_external_data_location( const std::string& base_dir, const std::string& location, diff --git a/onnx/checker.py b/onnx/checker.py index f624a4d7f35..c89c4cec545 100644 --- a/onnx/checker.py +++ b/onnx/checker.py @@ -137,6 +137,7 @@ def check_model( model: ModelProto | str | bytes | os.PathLike, full_check: bool = False, skip_opset_compatibility_check: bool = False, + check_custom_domain: bool = False, ) -> None: """Check the consistency of a model. @@ -154,10 +155,17 @@ def check_model( full_check: If True, the function also runs shape inference check. skip_opset_compatibility_check: If True, the function skips the check for opset compatibility. + check_custom_domain: If True, the function will check all domains. Otherwise + only check built-in domains. """ # If model is a path instead of ModelProto if isinstance(model, (str, os.PathLike)): - C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check) + C.check_model_path( + os.fspath(model), + full_check, + skip_opset_compatibility_check, + check_custom_domain, + ) else: protobuf_string = ( model if isinstance(model, bytes) else model.SerializeToString() @@ -168,7 +176,12 @@ def check_model( raise ValueError( "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead." ) - C.check_model(protobuf_string, full_check, skip_opset_compatibility_check) + C.check_model( + protobuf_string, + full_check, + skip_opset_compatibility_check, + check_custom_domain, + ) ValidationError = C.ValidationError diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc index ff896da1659..2184af53dee 100644 --- a/onnx/cpp2py_export.cc +++ b/onnx/cpp2py_export.cc @@ -403,6 +403,14 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { }, "op_type"_a, "domain"_a = ONNX_DOMAIN) + .def( + "has_schema", + [](const std::string& op_type, int max_inclusive_version, const std::string& domain) -> bool { + return OpSchemaRegistry::Schema(op_type, max_inclusive_version, domain) != nullptr; + }, + "op_type"_a, + "max_inclusive_version"_a, + "domain"_a = ONNX_DOMAIN) .def( "schema_version_map", []() -> std::unordered_map> { @@ -442,7 +450,34 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { .def( "get_all_schemas_with_history", []() -> const std::vector { return OpSchemaRegistry::get_all_schemas_with_history(); }, - "Return the schema of all existing operators and all versions."); + "Return the schema of all existing operators and all versions.") + .def( + "set_domain_to_version", + [](const std::string& domain, int min_version, int max_version, int last_release_version) { + auto& obj = OpSchemaRegistry::DomainToVersionRange::Instance(); + if (obj.Map().count(domain) == 0) { + obj.AddDomainToVersion(domain, min_version, max_version, last_release_version); + } else { + obj.UpdateDomainToVersion(domain, min_version, max_version, last_release_version); + } + }, + "domain"_a, + "min_version"_a, + "max_version"_a, + "last_release_version"_a = -1, + "Set the version range and last release version of the specified domain.") + .def( + "register_schema", + [](OpSchema schema) { RegisterSchema(std::move(schema), 0, true, true); }, + "schema"_a, + "Register a user provided OpSchema.") + .def( + "deregister_schema", + &DeregisterSchema, + "op_type"_a, + "version"_a, + "domain"_a, + "Deregister the specified OpSchema."); // Submodule `checker` auto checker = onnx_cpp2py_export.def_submodule("checker"); @@ -519,21 +554,26 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { checker.def( "check_model", - [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check) -> void { + [](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain) + -> void { ModelProto proto{}; ParseProtoFromPyBytes(&proto, bytes); - checker::check_model(proto, full_check, skip_opset_compatibility_check); + checker::check_model(proto, full_check, skip_opset_compatibility_check, check_custom_domain); }, "bytes"_a, "full_check"_a = false, - "skip_opset_compatibility_check"_a = false); + "skip_opset_compatibility_check"_a = false, + "check_custom_domain"_a = false); checker.def( "check_model_path", - (void (*)(const std::string& path, bool full_check, bool skip_opset_compatibility_check)) & checker::check_model, + (void (*)( + const std::string& path, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain)) & + checker::check_model, "path"_a, "full_check"_a = false, - "skip_opset_compatibility_check"_a = false); + "skip_opset_compatibility_check"_a = false, + "check_custom_domain"_a = false); checker.def("_resolve_external_data_location", &checker::resolve_external_data_location); diff --git a/onnx/defs/__init__.py b/onnx/defs/__init__.py index 55fb5cf4133..0750b7727b8 100644 --- a/onnx/defs/__init__.py +++ b/onnx/defs/__init__.py @@ -8,6 +8,8 @@ "ONNX_ML_DOMAIN", "AI_ONNX_PREVIEW_TRAINING_DOMAIN", "has", + "register_schema", + "deregister_schema", "get_schema", "get_all_schemas", "get_all_schemas_with_history", @@ -31,6 +33,7 @@ get_schema = C.get_schema get_all_schemas = C.get_all_schemas get_all_schemas_with_history = C.get_all_schemas_with_history +deregister_schema = C.deregister_schema def onnx_opset_version() -> int: @@ -120,3 +123,22 @@ def get_function_ops() -> List[OpSchema]: SchemaError = C.SchemaError + + +def register_schema(schema: OpSchema) -> None: + """Register a user provided OpSchema. + + The function extends available operator set versions for the provided domain if necessary. + + Args: + schema: The OpSchema to register. + """ + version_map = C.schema_version_map() + domain = schema.domain + version = schema.since_version + min_version, max_version = version_map.get(domain, (version, version)) + if domain not in version_map or not (min_version <= version <= max_version): + min_version = min(min_version, version) + max_version = max(max_version, version) + C.set_domain_to_version(schema.domain, min_version, max_version) + C.register_schema(schema) diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 640237d1fdd..b8fab81d22d 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -30,8 +30,31 @@ constexpr int OpSchema::kUninitializedSinceVersion; // By default if opset_version_to_load=0, it registers all opset schema for all opset versions // Otherwise, it only registers the latest schema according to opset_version_to_load -void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema) { - OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load, fail_duplicate_schema); +void RegisterSchema( + const OpSchema& schema, + int opset_version_to_load, + bool fail_duplicate_schema, + bool fail_with_exception) { + RegisterSchema(OpSchema(schema), opset_version_to_load, fail_duplicate_schema, fail_with_exception); +} +void RegisterSchema( + OpSchema&& schema, + int opset_version_to_load, + bool fail_duplicate_schema, + bool fail_with_exception) { + if (fail_with_exception) { + OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl( + std::move(schema), opset_version_to_load, fail_duplicate_schema); + } else { + OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterNoExcept( + std::move(schema), opset_version_to_load, fail_duplicate_schema); + } +} + +// The (name, version, domain) must match the target exactly +// Otherwise will raise an SchemaError +void DeregisterSchema(const std::string& op_type, int version, const std::string& domain) { + OpSchemaRegistry::OpSchemaDeregister(op_type, version, domain); } #ifndef NDEBUG @@ -919,6 +942,11 @@ void OpSchema::Finalize() { // all inputs or std::numeric_limits::max() (if the last input is // variadic). + max_input_ = 0; + min_input_ = 0; + min_output_ = 0; + max_output_ = 0; + // Flag indicates whether an optional input is trailing one (there's no single // or variadic input behind). for (size_t i = 0; i < inputs_.size(); ++i) { diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index c7b2029109d..969037858b2 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -1187,16 +1187,52 @@ class OpSchemaRegistry final : public ISchemaRegistry { void AddDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { std::lock_guard lock(mutex_); - assert(map_.end() == map_.find(domain)); + if (map_.count(domain) != 0) { + std::stringstream err; + err << "Trying to add a domain to DomainToVersion map, but the domain is already exist with version range (" + << map_.at(domain).first << ", " << map_.at(domain).second << "). domain: \"" << domain << "\"" + << std::endl; + fail_schema(err.str()); + } + if (last_release_version_map_.count(domain) != 0) { + std::stringstream err; + err << "Trying to add a domain to LastReleaseVersion map, but the domain is already exist with last version: " + << last_release_version_map_.at(domain) << ", domain: \"" << domain << "\"" << std::endl; + fail_schema(err.str()); + } map_[domain] = std::make_pair(min_version, max_version); // If a last-release-version is not explicitly specified, use max as // last-release-version. - if (last_release_version == -1) + if (last_release_version == -1) { last_release_version = max_version; - assert(last_release_version_map_.end() == last_release_version_map_.find(domain)); + } last_release_version_map_[domain] = last_release_version; } + void + UpdateDomainToVersion(const std::string& domain, int min_version, int max_version, int last_release_version = -1) { + std::lock_guard lock(mutex_); + if (map_.count(domain) == 0) { + std::stringstream err; + err << "Trying to update a domain in DomainToVersion map, but the domain has not been add. domain: \"" << domain + << "\"" << std::endl; + fail_schema(err.str()); + } + if (last_release_version_map_.count(domain) == 0) { + std::stringstream err; + err << "Trying to update a domain in LastReleaseVersion map, but the domain has not been add. domain: \"" + << domain << "\"" << std::endl; + fail_schema(err.str()); + } + map_.at(domain).first = min_version; + map_.at(domain).second = max_version; + // Correspond to `AddDomainToVersion` + if (last_release_version == -1) { + last_release_version = max_version; + } + last_release_version_map_.at(domain) = last_release_version; + } + static DomainToVersionRange& Instance(); private: @@ -1213,52 +1249,62 @@ class OpSchemaRegistry final : public ISchemaRegistry { class OpSchemaRegisterOnce final { public: - OpSchemaRegisterOnce(OpSchema& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + // Export to cpp custom register macro + OpSchemaRegisterOnce(OpSchema op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + OpSchemaRegisterNoExcept(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); + } + static void + OpSchemaRegisterNoExcept(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { ONNX_TRY { - op_schema.Finalize(); - auto& m = GetMapWithoutEnsuringRegistration(); - auto& op_name = op_schema.Name(); - auto& op_domain = op_schema.domain(); - auto& schema_ver_map = m[op_name][op_domain]; - auto ver = op_schema.SinceVersion(); - if (OpSchema::kUninitializedSinceVersion == ver) { - op_schema.SinceVersion(1); - ver = op_schema.SinceVersion(); + OpSchemaRegisterImpl(std::move(op_schema), opset_version_to_load, fail_duplicate_schema); + } + ONNX_CATCH(const std::exception& e) { + ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); + } + } + static void + OpSchemaRegisterImpl(OpSchema&& op_schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true) { + op_schema.Finalize(); + auto& m = GetMapWithoutEnsuringRegistration(); + auto& op_name = op_schema.Name(); + auto& op_domain = op_schema.domain(); + auto& schema_ver_map = m[op_name][op_domain]; + auto ver = op_schema.SinceVersion(); + if (OpSchema::kUninitializedSinceVersion == ver) { + op_schema.SinceVersion(1); + ver = op_schema.SinceVersion(); + } + + // Stops because the exact opset_version is registered + if (schema_ver_map.count(ver)) { + if (fail_duplicate_schema) { + const auto& schema = schema_ver_map[ver]; + std::stringstream err; + err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver + << ") from file " << op_schema.file() << " line " << op_schema.line() + << ", but it is already registered from file " << schema.file() << " line " << schema.line() << std::endl; + fail_schema(err.str()); } + return; + } - // Stops because the exact opset_version is registered - if (schema_ver_map.count(ver)) { - if (fail_duplicate_schema) { - const auto& schema = schema_ver_map[ver]; - std::stringstream err; - err << "Trying to register schema with name " << op_name << " (domain: " << op_domain << " version: " << ver - << ") from file " << op_schema.file() << " line " << op_schema.line() - << ", but it is already registered from file " << schema.file() << " line " << schema.line() - << std::endl; - fail_schema(err.str()); - } + if (opset_version_to_load != 0) { + // Stops because the opset_version is higher than opset_version_to_load + if (ver > opset_version_to_load) { return; } - if (opset_version_to_load != 0) { - // Stops because the opset_version is higher than opset_version_to_load - if (ver > opset_version_to_load) + // Stops because a later version is registered within target opset version + if (!schema_ver_map.empty()) { + int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load); + if (max_registered_ver_le_target >= ver) { return; - - // Stops because a later version is registered within target opset version - if (!schema_ver_map.empty()) { - int max_registered_ver_le_target = GetMaxRegisteredVerWithinTarget(schema_ver_map, opset_version_to_load); - if (max_registered_ver_le_target >= ver) - return; } } - - CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); - schema_ver_map.insert(std::pair(ver, std::move(op_schema))); - } - ONNX_CATCH(const std::exception& e) { - ONNX_HANDLE_EXCEPTION([&]() { std::cerr << "Schema error: " << e.what() << std::endl; }); } + + CheckDomainAndVersionToRegister(op_schema, op_name, op_domain); + schema_ver_map.insert(std::pair(ver, std::move(op_schema))); } private: @@ -1307,6 +1353,19 @@ class OpSchemaRegistry final : public ISchemaRegistry { } }; + static void + OpSchemaDeregister(const std::string& op_type, const int version, const std::string& domain = ONNX_DOMAIN) { + auto& schema_map = GetMapWithoutEnsuringRegistration(); + if (schema_map.count(op_type) && schema_map[op_type].count(domain) && schema_map[op_type][domain].count(version)) { + schema_map[op_type][domain].erase(version); + } else { + std::stringstream err; + err << "Attempting to deregister an unregistered schema with name: " << op_type << " domain: " << domain + << " version: " << version << std::endl; + fail_schema(err.str()); + } + } + // Deregister all ONNX opset schemas from domain // Domain with default value ONNX_DOMAIN means ONNX. static void OpSchemaDeregisterAll(const std::string& domain = ONNX_DOMAIN) { @@ -1416,21 +1475,33 @@ class OpSchemaRegistry final : public ISchemaRegistry { for (auto& x : map()) { for (auto& y : x.second) { auto& version2schema = y.second; - r.emplace_back(version2schema.rbegin()->second); + if (!version2schema.empty()) { + r.emplace_back(version2schema.rbegin()->second); + } } } return r; } }; -void RegisterSchema(OpSchema schema, int opset_version_to_load = 0, bool fail_duplicate_schema = true); +void RegisterSchema( + const OpSchema& schema, + int opset_version_to_load = 0, + bool fail_duplicate_schema = true, + bool fail_with_exception = false); +void RegisterSchema( + OpSchema&& schema, + int opset_version_to_load = 0, + bool fail_duplicate_schema = true, + bool fail_with_exception = false); +void DeregisterSchema(const std::string& op_type, int version, const std::string& domain); // Registers the latest opset schema before opset_version_to_load // By default opset_version_to_load=0 means it will register all versions template void RegisterOpSetSchema(int opset_version_to_load = 0, bool fail_duplicate_schema = true) { T::ForEachSchema([opset_version_to_load, fail_duplicate_schema](OpSchema&& schema) { - RegisterSchema(schema, opset_version_to_load, fail_duplicate_schema); + RegisterSchema(std::move(schema), opset_version_to_load, fail_duplicate_schema); }); }; diff --git a/onnx/onnx_cpp2py_export/checker.pyi b/onnx/onnx_cpp2py_export/checker.pyi index 83ce4650690..c887293772a 100644 --- a/onnx/onnx_cpp2py_export/checker.pyi +++ b/onnx/onnx_cpp2py_export/checker.pyi @@ -17,5 +17,5 @@ def check_attribute(bytes: bytes, checker_context: CheckerContext, lexical_scope def check_node(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_function(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 def check_graph(bytes: bytes, checker_context: CheckerContext, lexical_scope_context: LexicalScopeContext) -> None: ... # noqa: A002 -def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool) -> None: ... # noqa: A002 -def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool) -> None: ... +def check_model(bytes: bytes, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ... # noqa: A002 +def check_model_path(path: str, full_check: bool, skip_opset_compatibility_check: bool, check_custom_domain: bool) -> None: ... diff --git a/onnx/onnx_cpp2py_export/defs.pyi b/onnx/onnx_cpp2py_export/defs.pyi index d3dc88e443f..dbc9061600d 100644 --- a/onnx/onnx_cpp2py_export/defs.pyi +++ b/onnx/onnx_cpp2py_export/defs.pyi @@ -182,7 +182,12 @@ class OpSchema: CONSUME_ALLOWED: OpSchema.UseType = ... CONSUME_ENFORCED: OpSchema.UseType = ... -def has_schema(op_type: str) -> bool: ... +@overload +def has_schema(op_type: str, domain: str = "") -> bool: ... +@overload +def has_schema( + op_type: str, max_inclusive_version: int, domain: str = "" +) -> bool: ... def schema_version_map() -> dict[str, tuple[int, int]]: ... @overload def get_schema( @@ -192,3 +197,6 @@ def get_schema( def get_schema(op_type: str, domain: str = "") -> OpSchema: ... def get_all_schemas() -> Sequence[OpSchema]: ... def get_all_schemas_with_history() -> Sequence[OpSchema]: ... +def set_domain_to_version(domain: str, min_version: int, max_version: int, last_release_version: int = -1) -> None: ... +def register_schema(schema: OpSchema) -> None: ... +def deregister_schema(op_type: str, version: int, domain: str) -> None: ... diff --git a/onnx/test/schema_test.py b/onnx/test/schema_test.py index 1ee42ded555..5de210f4e4e 100644 --- a/onnx/test/schema_test.py +++ b/onnx/test/schema_test.py @@ -1,8 +1,9 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 +import contextlib import unittest -from typing import Sequence +from typing import List, Sequence import parameterized @@ -258,5 +259,169 @@ def test_init_with_default_value(self): self.assertEqual("attr1 description", attribute.description) +@parameterized.parameterized_class( + [ + # register to exist domain + { + "op_type": "CustomOp", + "op_version": 5, + "op_domain": "", + "trap_op_version": [1, 2, 6, 7], + }, + # register to new domain + { + "op_type": "CustomOp", + "op_version": 5, + "op_domain": "test", + "trap_op_version": [1, 2, 6, 7], + }, + ] +) +class TestOpSchemaRegister(unittest.TestCase): + op_type: str + op_version: int + op_domain: str + # register some fake schema to check behavior + trap_op_version: List[int] + + def setUp(self) -> None: + # Ensure the schema is unregistered + self.assertFalse(onnx.defs.has(self.op_type, self.op_domain)) + + def tearDown(self) -> None: + # Clean up the registered schema + for version in [*self.trap_op_version, self.op_version]: + with contextlib.suppress(onnx.defs.SchemaError): + onnx.defs.deregister_schema(self.op_type, version, self.op_domain) + + def test_register_multi_schema(self): + for version in [*self.trap_op_version, self.op_version]: + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + onnx.defs.register_schema(op_schema) + self.assertTrue(onnx.defs.has(self.op_type, version, self.op_domain)) + for version in [*self.trap_op_version, self.op_version]: + # Also make sure the `op_schema` is accessible after register + registered_op = onnx.defs.get_schema( + op_schema.name, version, op_schema.domain + ) + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + self.assertEqual(str(registered_op), str(op_schema)) + + def test_using_the_specified_version_in_onnx_check(self): + input = f""" + < + ir_version: 7, + opset_import: [ + "{self.op_domain}" : {self.op_version} + ] + > + agraph (float[N, 128] X, int32 Y) => (float[N] Z) + {{ + Z = {self.op_domain}.{self.op_type}(X, Y) + }} + """ + model = onnx.parser.parse_model(input) + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + self.op_version, + inputs=[ + defs.OpSchema.FormalParameter("input1", "T"), + defs.OpSchema.FormalParameter("input2", "int32"), + ], + outputs=[ + defs.OpSchema.FormalParameter("output1", "T"), + ], + type_constraints=[("T", ["tensor(float)"], "")], + attributes=[ + defs.OpSchema.Attribute( + "attr1", defs.OpSchema.AttrType.INTS, "attr1 description" + ) + ], + ) + with self.assertRaises(onnx.checker.ValidationError): + onnx.checker.check_model(model, check_custom_domain=True) + onnx.defs.register_schema(op_schema) + # The fake schema will raise check exception if selected in checker + for version in self.trap_op_version: + onnx.defs.register_schema( + defs.OpSchema( + self.op_type, + self.op_domain, + version, + outputs=[ + defs.OpSchema.FormalParameter("output1", "int32"), + ], + ) + ) + onnx.checker.check_model(model, check_custom_domain=True) + + def test_register_schema_raises_error_when_registering_a_schema_twice(self): + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + self.op_version, + ) + onnx.defs.register_schema(op_schema) + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.register_schema(op_schema) + + def test_deregister_the_specified_schema(self): + for version in [*self.trap_op_version, self.op_version]: + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + version, + ) + onnx.defs.register_schema(op_schema) + self.assertTrue(onnx.defs.has(op_schema.name, version, op_schema.domain)) + onnx.defs.deregister_schema(op_schema.name, self.op_version, op_schema.domain) + for version in self.trap_op_version: + self.assertTrue(onnx.defs.has(op_schema.name, version, op_schema.domain)) + # Maybe has lesser op version in trap list + if onnx.defs.has(op_schema.name, self.op_version, op_schema.domain): + schema = onnx.defs.get_schema( + op_schema.name, self.op_version, op_schema.domain + ) + self.assertLess(schema.since_version, self.op_version) + + def test_deregister_schema_raises_error_when_opschema_does_not_exist(self): + with self.assertRaises(onnx.defs.SchemaError): + onnx.defs.deregister_schema(self.op_type, self.op_version, self.op_domain) + + def test_legacy_schema_accessible_after_deregister(self): + op_schema = defs.OpSchema( + self.op_type, + self.op_domain, + self.op_version, + ) + onnx.defs.register_schema(op_schema) + schema_a = onnx.defs.get_schema( + op_schema.name, op_schema.since_version, op_schema.domain + ) + schema_b = onnx.defs.get_schema(op_schema.name, op_schema.domain) + + def filter_schema(schemas): + return [op for op in schemas if op.name == op_schema.name] + + schema_c = filter_schema(onnx.defs.get_all_schemas()) + schema_d = filter_schema(onnx.defs.get_all_schemas_with_history()) + self.assertEqual(len(schema_c), 1) + self.assertEqual(len(schema_d), 1) + # Avoid memory residue and access storage as much as possible + self.assertEqual(str(schema_a), str(op_schema)) + self.assertEqual(str(schema_b), str(op_schema)) + self.assertEqual(str(schema_c[0]), str(op_schema)) + self.assertEqual(str(schema_d[0]), str(op_schema)) + + if __name__ == "__main__": unittest.main()