Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support register custom OpSchema by python #5906

Merged
merged 41 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
14e0365
Support register custom OpSchema by python
OYCN Feb 5, 2024
ec31343
Optimization register check logic
OYCN Feb 6, 2024
cb2d9cd
Support deregister OpSchema
OYCN Feb 6, 2024
0b13a7b
normalized code and unittest
OYCN Feb 6, 2024
0b26ad5
fix desc error
OYCN Feb 6, 2024
5b2151f
append export symbol
OYCN Feb 6, 2024
a4c356c
add annotate and update doc
OYCN Feb 6, 2024
72a0e03
refact unittest
OYCN Feb 7, 2024
dca0d07
separate registration logic
OYCN Feb 7, 2024
5a01710
enable check for custom domain
OYCN Feb 7, 2024
c68fcbf
add domain automatically when register a cudatom op
OYCN Feb 7, 2024
59bedfc
fix py lint
OYCN Feb 7, 2024
d06560b
add case for unittest
OYCN Feb 7, 2024
aa1d9c3
modify description for `set_domain_to_version`
OYCN Feb 7, 2024
82349df
fix coding error
OYCN Feb 7, 2024
7c1983c
fix py lint
OYCN Feb 7, 2024
c778180
replace `check_custom_op` to `check_custom_domain` in checker
OYCN Feb 7, 2024
f776e62
Merge branch 'main' into main
justinchuby Feb 8, 2024
306bd6b
fix code style and annotate
OYCN Feb 8, 2024
65c61f0
fix py code style
OYCN Feb 8, 2024
0af0ac7
simplify deregister_schema binding
OYCN Feb 8, 2024
ada696e
add case to check schema accessible after deregister and fix bug
OYCN Feb 9, 2024
8b5ed1b
impl update method
OYCN Feb 9, 2024
8d7301e
append annotate for unittest case
OYCN Feb 9, 2024
ff8cdfb
Safe implementation for register
OYCN Feb 9, 2024
17e7685
fix py code style
OYCN Feb 9, 2024
d5a59f5
Revert "Safe implementation for register"
OYCN Feb 14, 2024
85ba547
using rvalue reference for register pipline
OYCN Feb 14, 2024
467f938
restore the default behavior of `AddDomainToVersion`
OYCN Feb 14, 2024
7f91274
overload `has_schema` with version param
OYCN Feb 14, 2024
2b6cf01
add multi register case in unittest
OYCN Feb 14, 2024
f2577b9
fix build error in cpp unittest
OYCN Feb 14, 2024
afe3c78
Revert "fix build error in cpp unittest"
OYCN Feb 14, 2024
2b02fde
fix build error for cpp custom register
OYCN Feb 14, 2024
8380e72
fix py code style
OYCN Feb 16, 2024
568b10d
Merge branch 'main' into main
gramalingam Feb 16, 2024
ffe9068
Merge branch 'main' into main
justinchuby Feb 17, 2024
cb0e521
Add overload impl to `RegisterSchema` for compatibility
OYCN Feb 17, 2024
a772d6f
Merge branch 'main' into dev-register-schema
OYCN Feb 21, 2024
90dbaae
Merge branch 'main' into dev-register-schema
OYCN Feb 23, 2024
ce63fbc
Merge branch 'main' into dev-register-schema
gramalingam Feb 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/docsgen/source/api/defs.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
.. autofunction:: onnx.defs.get_all_schemas_with_history

.. autofunction:: onnx.defs.get_function_ops

.. autofunction:: onnx.defs.register_schema

.. autofunction:: onnx.defs.deregister_schema
```

## class `OpSchema`
Expand Down
24 changes: 14 additions & 10 deletions onnx/checker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,17 +651,11 @@ void check_node(const NodeProto& node, const CheckerContext& ctx, const LexicalS
const auto* schema = ctx.get_schema_registry()->GetSchema(node.op_type(), domain_version, node.domain());
if (!schema) {
if (node.domain() == ONNX_DOMAIN || node.domain() == AI_ONNX_ML_DOMAIN || node.domain() == "ai.onnx" ||
node.domain() == AI_ONNX_TRAINING_DOMAIN) {
// fail the checker if op in built-in domains has no schema
node.domain() == AI_ONNX_TRAINING_DOMAIN || ctx.check_custom_domain()) {
// fail the checker if op is in built-in domains or if it has no schema when `check_custom_domain` is true
fail_check(
"No Op registered for " + node.op_type() + " with domain_version of " +
ONNX_NAMESPACE::to_string(domain_version));
} else {
// TODO: expose the registration of the op schemas appropriately in
// python, so we can load and register operators in other domains
//
// before we complete the above todo, let's skip the schema check for
// now
}
} else if (schema->Deprecated()) {
fail_check(
Expand Down Expand Up @@ -1016,7 +1010,11 @@ void check_model(const ModelProto& model, CheckerContext& ctx) {
}
}

void check_model(const std::string& model_path, bool full_check, bool skip_opset_compatibility_check) {
void check_model(
const std::string& model_path,
bool full_check,
bool skip_opset_compatibility_check,
bool check_custom_domain) {
ModelProto model;
LoadProtoFromPath(model_path, model);

Expand All @@ -1028,6 +1026,7 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
}
ctx.set_model_dir(model_dir);
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
ctx.set_check_custom_domain(check_custom_domain);
check_model(model, ctx);

if (full_check) {
Expand All @@ -1036,9 +1035,14 @@ void check_model(const std::string& model_path, bool full_check, bool skip_opset
}
}

void check_model(const ModelProto& model, bool full_check, bool skip_opset_compatibility_check) {
void check_model(
const ModelProto& model,
bool full_check,
bool skip_opset_compatibility_check,
bool check_custom_domain) {
CheckerContext ctx;
ctx.set_skip_opset_compatibility_check(skip_opset_compatibility_check);
ctx.set_check_custom_domain(check_custom_domain);
check_model(model, ctx);
if (full_check) {
ShapeInferenceOptions options{true, 1, false};
Expand Down
21 changes: 19 additions & 2 deletions onnx/checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ class CheckerContext final {
skip_opset_compatibility_check_ = value;
}

bool check_custom_domain() const {
return check_custom_domain_;
}

void set_check_custom_domain(bool value) {
check_custom_domain_ = value;
}

explicit CheckerContext() : ir_version_(-1) {}

private:
Expand All @@ -93,6 +101,7 @@ class CheckerContext final {
const ISchemaRegistry* schema_registry_ = OpSchemaRegistry::Instance();
std::string model_dir_;
bool skip_opset_compatibility_check_ = false;
bool check_custom_domain_ = false;
};

class LexicalScopeContext final {
Expand Down Expand Up @@ -158,8 +167,16 @@ void check_model_local_functions(
const CheckerContext& ctx,
const LexicalScopeContext& parent_lex);

void check_model(const ModelProto& model, bool full_check = false, bool skip_opset_compatibility_check = false);
void check_model(const std::string& model_path, bool full_check = false, bool skip_opset_compatibility_check = false);
void check_model(
const ModelProto& model,
bool full_check = false,
bool skip_opset_compatibility_check = false,
bool check_custom_domain = false);
void check_model(
const std::string& model_path,
bool full_check = false,
bool skip_opset_compatibility_check = false,
bool check_custom_domain = false);

bool check_is_experimental_op(const NodeProto& node);

Expand Down
17 changes: 15 additions & 2 deletions onnx/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def check_model(
model: ModelProto | str | bytes | os.PathLike,
full_check: bool = False,
skip_opset_compatibility_check: bool = False,
check_custom_domain: bool = False,
) -> None:
"""Check the consistency of a model.

Expand All @@ -154,10 +155,17 @@ def check_model(
full_check: If True, the function also runs shape inference check.
skip_opset_compatibility_check: If True, the function skips the check for
opset compatibility.
check_custom_domain: If True, the function will check all domains. Otherwise
only check built-in domains.
"""
# If model is a path instead of ModelProto
if isinstance(model, (str, os.PathLike)):
C.check_model_path(os.fspath(model), full_check, skip_opset_compatibility_check)
C.check_model_path(
os.fspath(model),
full_check,
skip_opset_compatibility_check,
check_custom_domain,
)
else:
protobuf_string = (
model if isinstance(model, bytes) else model.SerializeToString()
Expand All @@ -168,7 +176,12 @@ def check_model(
raise ValueError(
"This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
)
C.check_model(protobuf_string, full_check, skip_opset_compatibility_check)
C.check_model(
protobuf_string,
full_check,
skip_opset_compatibility_check,
check_custom_domain,
)


ValidationError = C.ValidationError
44 changes: 38 additions & 6 deletions onnx/cpp2py_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,34 @@
.def(
"get_all_schemas_with_history",
[]() -> const std::vector<OpSchema> { return OpSchemaRegistry::get_all_schemas_with_history(); },
"Return the schema of all existing operators and all versions.");
"Return the schema of all existing operators and all versions.")
.def(
"set_domain_to_version",
[](const std::string& domain, int min_version, int max_version, int last_release_version) {
auto& obj = OpSchemaRegistry::DomainToVersionRange::Instance();
if (obj.Map().count(domain) == 0) {
obj.AddDomainToVersion(domain, min_version, max_version, last_release_version);
} else {
obj.UpdateDomainToVersion(domain, min_version, max_version, last_release_version);
}
},
"domain"_a,
"min_version"_a,
"max_version"_a,
"last_release_version"_a = -1,
"Set the version range and last release version of the specified domain.")
.def(
"register_schema",
[](OpSchema* schema) { RegisterSchema(*schema, 0, true, true); },
"schema"_a,
"Register a user provided OpSchema.")
.def(
"deregister_schema",
&DeregisterSchema,
"op_type"_a,
"version"_a,
"domain"_a,
"Deregister the specified OpSchema.");

// Submodule `checker`
auto checker = onnx_cpp2py_export.def_submodule("checker");
Expand Down Expand Up @@ -546,21 +573,26 @@

checker.def(
"check_model",
[](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check) -> void {
[](const py::bytes& bytes, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain)
-> void {
ModelProto proto{};
ParseProtoFromPyBytes(&proto, bytes);
checker::check_model(proto, full_check, skip_opset_compatibility_check);
checker::check_model(proto, full_check, skip_opset_compatibility_check, check_custom_domain);
},
"bytes"_a,
"full_check"_a = false,
"skip_opset_compatibility_check"_a = false);
"skip_opset_compatibility_check"_a = false,
"check_custom_domain"_a = false);

checker.def(
"check_model_path",
(void (*)(const std::string& path, bool full_check, bool skip_opset_compatibility_check)) & checker::check_model,
(void (*)(

Check warning on line 589 in onnx/cpp2py_export.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Extra space before ( in function call [whitespace/parens] [4] Raw Output: onnx/cpp2py_export.cc:589: Extra space before ( in function call [whitespace/parens] [4]
const std::string& path, bool full_check, bool skip_opset_compatibility_check, bool check_custom_domain)) &
checker::check_model,
"path"_a,
"full_check"_a = false,
"skip_opset_compatibility_check"_a = false);
"skip_opset_compatibility_check"_a = false,
"check_custom_domain"_a = false);

// Submodule `version_converter`
auto version_converter = onnx_cpp2py_export.def_submodule("version_converter");
Expand Down
22 changes: 22 additions & 0 deletions onnx/defs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"ONNX_ML_DOMAIN",
"AI_ONNX_PREVIEW_TRAINING_DOMAIN",
"has",
"register_schema",
"deregister_schema",
"get_schema",
"get_all_schemas",
"get_all_schemas_with_history",
Expand All @@ -31,6 +33,7 @@
get_schema = C.get_schema
get_all_schemas = C.get_all_schemas
get_all_schemas_with_history = C.get_all_schemas_with_history
deregister_schema = C.deregister_schema


def onnx_opset_version() -> int:
Expand Down Expand Up @@ -120,3 +123,22 @@ def get_function_ops() -> List[OpSchema]:


SchemaError = C.SchemaError


def register_schema(schema: OpSchema) -> None:
"""Register a user provided OpSchema.

The function extends available operator set versions for the provided domain if necessary.

Args:
schema: The OpSchema to register.
"""
version_map = C.schema_version_map()
domain = schema.domain
version = schema.since_version
min_version, max_version = version_map.get(domain, (version, version))
if domain not in version_map or not (min_version <= version <= max_version):
min_version = min(min_version, version)
max_version = max(max_version, version)
C.set_domain_to_version(schema.domain, min_version, max_version)
C.register_schema(schema)
24 changes: 22 additions & 2 deletions onnx/defs/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,23 @@

// By default if opset_version_to_load=0, it registers all opset schema for all opset versions
// Otherwise, it only registers the latest schema according to opset_version_to_load
void RegisterSchema(OpSchema schema, int opset_version_to_load, bool fail_duplicate_schema) {
OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(schema, opset_version_to_load, fail_duplicate_schema);
void RegisterSchema(
const OpSchema& schema,
int opset_version_to_load,
bool fail_duplicate_schema,
bool fail_with_exception) {
if (fail_with_exception) {
OpSchemaRegistry::OpSchemaRegisterOnce::OpSchemaRegisterImpl(schema, opset_version_to_load, fail_duplicate_schema);
} else {
OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED registration(

Check warning on line 41 in onnx/defs/schema.cc

View workflow job for this annotation

GitHub Actions / clang-tidy-review

clang-tidy

warning: variable 'registration' of type 'OpSchemaRegistry::OpSchemaRegisterOnce' can be declared 'const' [misc-const-correctness] ```suggestion OpSchemaRegistry::OpSchemaRegisterOnce ONNX_UNUSED const registration( ```
schema, opset_version_to_load, fail_duplicate_schema);
}
}

// The (name, version, domain) must match the target exactly
// Otherwise will raise an SchemaError
void DeregisterSchema(const std::string& op_type, int version, const std::string& domain) {
OpSchemaRegistry::OpSchemaDeregister(op_type, version, domain);
}

#ifndef NDEBUG
Expand Down Expand Up @@ -919,6 +934,11 @@
// all inputs or std::numeric_limits<int>::max() (if the last input is
// variadic).

max_input_ = 0;
min_input_ = 0;
min_output_ = 0;
max_output_ = 0;

// Flag indicates whether an optional input is trailing one (there's no single
// or variadic input behind).
for (size_t i = 0; i < inputs_.size(); ++i) {
Expand Down