Skip to content

Commit

Permalink
Support register custom OpSchema by python
Browse files Browse the repository at this point in the history
Signed-off-by: oPluss <opluss@qq.com>
  • Loading branch information
OYCN committed Feb 5, 2024
1 parent 3d976ff commit 4683e8c
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 1 deletion.
4 changes: 3 additions & 1 deletion onnx/cpp2py_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
.def(
"get_all_schemas_with_history",
[]() -> const std::vector<OpSchema> { return OpSchemaRegistry::get_all_schemas_with_history(); },
"Return the schema of all existing operators and all versions.");
"Return the schema of all existing operators and all versions.")
.def(
"reg_schema", [](OpSchema* op) { RegisterSchema(*op); }, "Register the custom OpSchema.");

// Submodule `checker`
auto checker = onnx_cpp2py_export.def_submodule("checker");
Expand Down
13 changes: 13 additions & 0 deletions onnx/defs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"ONNX_ML_DOMAIN",
"AI_ONNX_PREVIEW_TRAINING_DOMAIN",
"has",
"reg_schema",
"get_schema",
"get_all_schemas",
"get_all_schemas_with_history",
Expand Down Expand Up @@ -120,3 +121,15 @@ def get_function_ops() -> List[OpSchema]:


SchemaError = C.SchemaError

def reg_schema(op: OpSchema):
name = op.name
domain = op.domain
ver = op.since_version
assert ver > 0, f'OpSchema {domain}::{name} need positive version but got {ver}'
if has(name, domain):
exist_op = get_schema(name, ver, domain)
if exist_op is not None:
if exist_op.since_version == ver:
raise SchemaError(f'OpSchema {domain}::{name} already exist in file: {exist_op.file}:{exist_op.line}')
C.reg_schema(op)
5 changes: 5 additions & 0 deletions onnx/defs/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,11 @@ void OpSchema::Finalize() {
// all inputs or std::numeric_limits<int>::max() (if the last input is
// variadic).

max_input_ = 0;
min_input_ = 0;
min_output_ = 0;
max_output_ = 0;

// Flag indicates whether an optional input is trailing one (there's no single
// or variadic input behind).
for (size_t i = 0; i < inputs_.size(); ++i) {
Expand Down
1 change: 1 addition & 0 deletions onnx/onnx_cpp2py_export/defs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,4 @@ def get_schema(
def get_schema(op_type: str, domain: str = "") -> OpSchema: ...
def get_all_schemas() -> Sequence[OpSchema]: ...
def get_all_schemas_with_history() -> Sequence[OpSchema]: ...
def reg_schema(schema: OpSchema) -> None: ...
38 changes: 38 additions & 0 deletions onnx/test/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,44 @@ def test_init_with_default_value(self):
self.assertEqual("attr1", attribute.name)
self.assertEqual("attr1 description", attribute.description)

class TestOpSchemaRegister(unittest.TestCase):
def test_register(self):
input = """
agraph (float[N, 128] X, int32 Y) => (float[N] Z)
{
Z = CustomOp<attr1=[1,2]>(X, Y)
}
"""
model = onnx.parser.parse_graph(input)
op_schema = defs.OpSchema(
"CustomOp",
"",
1,
inputs=[
defs.OpSchema.FormalParameter("input1", "T"),
defs.OpSchema.FormalParameter("input2", "int32"),
],
outputs=[
defs.OpSchema.FormalParameter("output1", "T"),
],
type_constraints=[("T", ["tensor(float)"], "")],
attributes=[
defs.OpSchema.Attribute(
"attr1", defs.OpSchema.AttrType.INTS, "attr1 description"
)
],
)
onnx.defs.reg_schema(op_schema)
onnx.checker.check_graph(model)

def test_duplicited_register(self):
op_schema = defs.OpSchema(
"CustomOpDuplicited",
"",
1,
)
onnx.defs.reg_schema(op_schema)
self.assertRaises(onnx.defs.SchemaError, lambda: onnx.defs.reg_schema(op_schema))

if __name__ == "__main__":
unittest.main()

0 comments on commit 4683e8c

Please sign in to comment.