// cpp2py_export.cc
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "onnx/py_utils.h"
#include "onnxsim.h"
namespace py = pybind11;
using namespace pybind11::literals;
struct PyModelExecutor : public ModelExecutor {
  using ModelExecutor::ModelExecutor;

  // Bridges the C++ executor interface to Python: serializes the model and
  // every input tensor to byte strings, forwards them to the Python-side
  // _PyRun override, then parses the returned byte strings back into
  // TensorProtos.
  std::vector<onnx::TensorProto> _Run(
      const onnx::ModelProto& model,
      const std::vector<onnx::TensorProto>& inputs) const override {
    // Serialize each input tensor so Python only ever sees opaque bytes.
    std::vector<py::bytes> serialized_inputs;
    serialized_inputs.reserve(inputs.size());
    for (const auto& tensor : inputs) {
      serialized_inputs.emplace_back(tensor.SerializeAsString());
    }

    auto raw_outputs =
        _PyRun(py::bytes(model.SerializeAsString()), serialized_inputs);

    // Decode the Python-produced byte strings back into TensorProtos.
    std::vector<onnx::TensorProto> outputs;
    outputs.reserve(raw_outputs.size());
    for (const auto& raw : raw_outputs) {
      onnx::TensorProto tensor;
      tensor.ParseFromString(std::string(raw));
      outputs.push_back(std::move(tensor));
    }
    return outputs;
  }

  // Pure virtual hook implemented in Python (see PyModelExecutorTrampoline).
  // Takes the serialized model and serialized inputs; returns serialized
  // output tensors.
  virtual std::vector<py::bytes> _PyRun(
      const py::bytes& model_bytes,
      const std::vector<py::bytes>& inputs_bytes) const = 0;
};
// pybind11 trampoline class: allows a Python subclass of ModelExecutor to
// override the pure virtual _PyRun. On the Python side the overriding method
// is named "Run" (see PYBIND11_OVERRIDE_PURE_NAME below and the
// .def("Run", ...) binding in the module definition).
struct PyModelExecutorTrampoline : public PyModelExecutor {
/* Inherit the constructors */
using PyModelExecutor::PyModelExecutor;
/* Trampoline (need one for each virtual function) */
std::vector<py::bytes> _PyRun(
const py::bytes& model_bytes,
const std::vector<py::bytes>& inputs_bytes) const override {
PYBIND11_OVERRIDE_PURE_NAME(
std::vector<py::bytes>, /* Return type */
PyModelExecutor, /* Parent class */
"Run", _PyRun, /* Name of function in C++ (must match Python name) */
model_bytes, inputs_bytes /* Argument(s) */
);
}
};
PYBIND11_MODULE(onnxsim_cpp2py_export, m) {
  m.doc() = "ONNX Simplifier";

  // simplify(model_bytes, skip_optimizers, constant_folding, shape_inference,
  //          tensor_size_threshold) -> bytes
  // Deserializes the model from bytes, simplifies it in memory, and returns
  // the simplified model re-serialized as bytes.
  // NOTE: skip_optimizers is taken by const& — taking the optional vector of
  // strings by value copied it on every call (performance-unnecessary-value-param).
  m.def("simplify",
        [](const py::bytes& model_proto_bytes,
           const std::optional<std::vector<std::string>>& skip_optimizers,
           bool constant_folding, bool shape_inference,
           size_t tensor_size_threshold) -> py::bytes {
          // force env initialization to register opset
          InitEnv();
          ONNX_NAMESPACE::ModelProto model;
          ParseProtoFromPyBytes(&model, model_proto_bytes);
          auto const result = Simplify(model, skip_optimizers, constant_folding,
                                       shape_inference, tensor_size_threshold);
          std::string out;
          result.SerializeToString(&out);
          return py::bytes(out);
        })
      // simplify_path(in_path, out_path, skip_optimizers, constant_folding,
      //               shape_inference, tensor_size_threshold) -> bool
      // File-to-file variant for models too large to round-trip through
      // Python bytes. Always returns True on completion; errors surface as
      // exceptions from SimplifyPath.
      .def("simplify_path",
           [](const std::string& in_path, const std::string& out_path,
              const std::optional<std::vector<std::string>>& skip_optimizers,
              bool constant_folding, bool shape_inference,
              size_t tensor_size_threshold) -> bool {
             // force env initialization to register opset
             InitEnv();
             SimplifyPath(in_path, out_path, skip_optimizers, constant_folding,
                          shape_inference, tensor_size_threshold);
             return true;
           })
      // Installs the process-wide executor used to evaluate subgraphs in
      // Python (e.g. via onnxruntime) during constant folding.
      .def("_set_model_executor",
           [](std::shared_ptr<PyModelExecutor> executor) {
             ModelExecutor::set_instance(std::move(executor));
           });

  // Expose the executor base class; Python subclasses override "Run", which
  // the trampoline maps onto the C++ pure virtual _PyRun.
  py::class_<PyModelExecutor, PyModelExecutorTrampoline,
             std::shared_ptr<PyModelExecutor>>(m, "ModelExecutor")
      .def(py::init<>())
      .def("Run", &PyModelExecutor::_PyRun);
}