Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print utility extension #4246

Merged
merged 30 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f6760af
Expose printing methods via Python
gramalingam May 19, 2022
631de1e
Use indentation in printing
gramalingam May 19, 2022
db2e22a
Remove extra space
gramalingam May 20, 2022
0cb5ca1
Encapsulate printer logic as a class
gramalingam Jun 1, 2022
891a4bb
Delete VS config files
gramalingam Jun 1, 2022
922a489
Add printer test
gramalingam Jun 1, 2022
c310660
Add print support for ModelProto
gramalingam Jun 2, 2022
6f0937b
Add to-string template
gramalingam Jun 2, 2022
e6ca615
Cleanup duplicate code
gramalingam Jun 2, 2022
b4a047e
Merge with main
gramalingam Jun 2, 2022
b763fef
Change Text to str
gramalingam Jun 2, 2022
1986589
Fix template specialization compile error
gramalingam Jun 8, 2022
6a13269
Remove unused variable
gramalingam Jun 8, 2022
29b6538
Add more print tests
gramalingam Jun 8, 2022
9df1e03
Flake8 warnings
gramalingam Jun 8, 2022
ffbfc7a
Address PR feedback
gramalingam Jun 8, 2022
46d023b
FIx mypy error
gramalingam Jun 9, 2022
09c105e
Fix mypy errors
gramalingam Jun 9, 2022
9b3b836
Fix typing
gramalingam Jun 9, 2022
53470d9
Format printer.cc
gramalingam Jun 9, 2022
f776069
Format code
gramalingam Jun 9, 2022
f8c166d
Support for escape char in strings
gramalingam Jun 9, 2022
7a6f600
Merge branch 'main' into python-print-2
gramalingam Jun 9, 2022
36bc7b5
Address PR comments
gramalingam Jun 13, 2022
459fce0
Format file
gramalingam Jun 14, 2022
614c3ba
Merge branch 'main' into python-print-2
gramalingam Jun 14, 2022
2296281
Fix python imports
gramalingam Jun 21, 2022
1599a54
Merge branch 'main' into python-print-2
gramalingam Jun 21, 2022
e7d51dc
Fix flake8 warning
gramalingam Jun 21, 2022
58665ca
Merge branch 'python-print-2' of https://github.com/gramalingam/onnx …
gramalingam Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions onnx/cpp2py_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "onnx/checker.h"
#include "onnx/defs/function.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/printer.h"
#include "onnx/defs/schema.h"
#include "onnx/py_utils.h"
#include "onnx/shape_inference/implementation.h"
Expand All @@ -31,6 +32,13 @@ static std::tuple<bool, py::bytes, py::bytes> Parse(const char* cstr) {
return std::make_tuple(status.IsOK(), py::bytes(status.ErrorMessage()), py::bytes(out));
}

template <typename ProtoType>
static std::string ProtoBytesToText(const py::bytes& bytes) {
ProtoType proto{};
ParseProtoFromPyBytes(&proto, bytes);
return ProtoToString(proto);
}

PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
onnx_cpp2py_export.doc() = "Python interface to onnx";

Expand Down Expand Up @@ -308,6 +316,15 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {

parser.def("parse_model", Parse<ModelProto>);
parser.def("parse_graph", Parse<GraphProto>);
parser.def("parse_function", Parse<FunctionProto>);

// Submodule `printer`
auto printer = onnx_cpp2py_export.def_submodule("printer");
printer.doc() = "Printer submodule";

printer.def("model_to_text", ProtoBytesToText<ModelProto>);
printer.def("function_to_text", ProtoBytesToText<FunctionProto>);
printer.def("graph_to_text", ProtoBytesToText<GraphProto>);
}

} // namespace ONNX_NAMESPACE
64 changes: 64 additions & 0 deletions onnx/defs/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,69 @@

namespace ONNX_NAMESPACE {

Status ParserBase::Parse(Literal& result) {
bool decimal_point = false;
auto nextch = NextChar();
auto from = next_;
if (nextch == '"') {
++next_;
bool has_escape = false;
while ((next_ < end_) && (*next_ != '"')) {
if (*next_ == '\\') {
has_escape = true;
++next_;
if (next_ >= end_)
return ParseError("Incomplete string literal.");
}
++next_;
}
if (next_ >= end_)
return ParseError("Incomplete string literal.");
++next_;
result.type = LiteralType::STRING_LITERAL;
if (has_escape) {
std::string& target = result.value;
target.clear();
target.reserve(next_ - from - 2); // upper bound
// *from is the starting quote. *(next_-1) is the ending quote.
// Copy what is in-between, except for the escape character
while (++from < next_ - 1) {
// Copy current char, if not escape, or next char otherwise.
target.push_back(*from != '\\' ? (*from) : *(++from));
}
} else
result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
} else if ((isdigit(nextch) || (nextch == '-'))) {
++next_;

while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
if (*next_ == '.') {
if (decimal_point)
break; // Only one decimal point allowed in numeric literal
decimal_point = true;
}
++next_;
}

if (next_ == from)
return ParseError("Value expected but not found.");

// Optional exponent syntax: (e|E)(+|-)?[0-9]+
if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
decimal_point = true; // treat as float-literal
++next_;
if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
++next_;
while ((next_ < end_) && (isdigit(*next_)))
++next_;
}

result.value = std::string(from, next_ - from);
result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
}
return Status::OK();
}

Status OnnxParser::Parse(IdList& idlist) {
idlist.Clear();
std::string id;
Expand Down Expand Up @@ -348,6 +411,7 @@ Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) {
}

Status OnnxParser::Parse(AttributeProto& attr) {
attr.Clear();
std::string name;
CHECK_PARSER_STATUS(ParseIdentifier(name));
attr.set_name(name);
Expand Down
53 changes: 10 additions & 43 deletions onnx/defs/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ class KeyWordMap {
return KeyWord::NONE;
}

static const std::string& ToString(KeyWord kw) {
static std::string undefined("undefined");
for (const auto& pair : Instance()) {
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
if (pair.second == kw)
return pair.first;
}
return undefined;
}

private:
std::unordered_map<std::string, KeyWord> map_;
};
Expand Down Expand Up @@ -263,49 +272,7 @@ class ParserBase {
std::string value;
};

Status Parse(Literal& result) {
bool decimal_point = false;
auto nextch = NextChar();
auto from = next_;
if (nextch == '"') {
++next_;
// TODO: Handle escape characters
while ((next_ < end_) && (*next_ != '"')) {
++next_;
}
++next_;
result.type = LiteralType::STRING_LITERAL;
result.value = std::string(from + 1, next_ - from - 2); // skip enclosing quotes
} else if ((isdigit(nextch) || (nextch == '-'))) {
++next_;

while ((next_ < end_) && (isdigit(*next_) || (*next_ == '.'))) {
if (*next_ == '.') {
if (decimal_point)
break; // Only one decimal point allowed in numeric literal
decimal_point = true;
}
++next_;
}

if (next_ == from)
return ParseError("Value expected but not found.");

// Optional exponent syntax: (e|E)(+|-)?[0-9]+
if ((next_ < end_) && ((*next_ == 'e') || (*next_ == 'E'))) {
decimal_point = true; // treat as float-literal
++next_;
if ((next_ < end_) && ((*next_ == '+') || (*next_ == '-')))
++next_;
while ((next_ < end_) && (isdigit(*next_)))
++next_;
}

result.value = std::string(from, next_ - from);
result.type = decimal_point ? LiteralType::FLOAT_LITERAL : LiteralType::INT_LITERAL;
}
return Status::OK();
}
Status Parse(Literal& result);

Status Parse(int64_t& val) {
Literal literal;
Expand Down
Loading