Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool:
return False


class Exporter:
class _Exporter:
"""Class used for recursive traversal of Proto structures."""

def __init__(
self, rename: bool, use_operators: bool = False, inline_const: bool = False
self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool
) -> None:
self.use_operators = use_operators
if rename:
Expand All @@ -266,6 +266,8 @@ def __init__(
# _name_remappings: used to undo the SSA-renaming in ONNX control-flow ops.
# We map the multiple SSA-variants back to the same Python variable name.
self._name_remappings: list[dict[str, str]] = []
self.skip_initializers = skip_initializers
self.skipped_initializers: dict[str, onnx.TensorProto] = {}

def _handle_attrname_conflict(self, renamer):
"""Add ref-attr-name-conflict handling logic to renaming function."""
Expand Down Expand Up @@ -338,6 +340,14 @@ def _translate_graph_body(self, graph, opsets, indent=0):
code = []
if hasattr(graph, "initializer"):
for init in graph.initializer:
if self.skip_initializers:
init_py_name = self._translate_onnx_var(init.name)
if init_py_name in self.skipped_initializers:
raise RuntimeError(
f"Initializer {init.name!r} is already present in skipped_initializers."
)
self.skipped_initializers[init_py_name] = init
continue
node = make_node(
"Constant",
[],
Expand Down Expand Up @@ -684,15 +694,61 @@ def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str])
def add(line: str) -> None:
result.append(line)

add("@script()")
add(f"def {function_name}{_translate_signature(graph.input, graph.output)}")
if self.skip_initializers:
indent_level = 2
indent = _SINGLE_INDENT
else:
indent_level = 1
indent = ""
add(f"{indent}@script()")
add(f"{indent}def {function_name}{_translate_signature(graph.input, graph.output)}")
indent = indent + _SINGLE_INDENT
doc = graph.doc_string
if doc:
add(f' """{doc}"""')
add(self._translate_graph_body(graph, opsets, indent=1))
add(f'{indent}"""{doc}"""')
add(self._translate_graph_body(graph, opsets, indent=indent_level))
return_values = ", ".join(self._translate_onnx_var(x) for x in graph.output)
add(f" return {return_values}")
return "\n".join(result)
add(f"{indent}return {return_values}")
script = "\n".join(result)
if self.skipped_initializers:
return self._substitute_initializers(script, function_name)
return script

def _substitute_initializers(self, script: str, script_function_name: str) -> str:
init_names = self.skipped_initializers.keys()
# Formal parameters representing initializers (single level indentation)
__ = _SINGLE_INDENT
initializers_as_params = "\n".join(f"{__}{x}," for x in init_names)

def generate_rand(name: str, value: TensorProto) -> str:
shape = ",".join(str(d) for d in value.dims)
if value.data_type != TensorProto.FLOAT:
raise NotImplementedError(
f"Unable to generate random initializer for data type {value.data_type}."
)
return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)"

random_initializer_values = "\n".join(
generate_rand(key, value) for key, value in self.skipped_initializers.items()
)
# Actual parameter values for initializers (double level indentation)
indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names)
return f"""
def make_model(
{initializers_as_params}
):
{script}

{__}model = {script_function_name}.to_model_proto()
{__}return model

def make_model_with_random_weights():
{random_initializer_values}
{__}model = make_model(
{indented_initializers_as_params}
{__})
{__}return model
"""

def _import_onnx_types(
self, proto: onnx.ModelProto | onnx.GraphProto | onnx.FunctionProto
Expand Down Expand Up @@ -778,9 +834,11 @@ def visit_graph(graph: onnx.GraphProto) -> None:
def export2python(
model_onnx,
function_name: Optional[str] = None,
*,
rename: bool = False,
use_operators: bool = False,
inline_const: bool = False,
skip_initializers: bool = False,
):
"""Exports an ONNX model to the *python* syntax.

Expand All @@ -790,6 +848,9 @@ def export2python(
function_name: main function name
use_operators: use Python operators.
inline_const: replace ONNX constants inline if compact
skip_initializers: generated script will not include initializers.
Instead, a function that generates the model, given initializer values, is generated,
along with one that generates random values for the initializers.

Returns:
python code
Expand All @@ -815,5 +876,10 @@ def export2python(
if not isinstance(model_onnx, (ModelProto, FunctionProto)):
raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.")

exporter = Exporter(rename, use_operators, inline_const)
exporter = _Exporter(
rename=rename,
use_operators=use_operators,
inline_const=inline_const,
skip_initializers=skip_initializers,
)
return exporter.export(model_onnx, function_name)
29 changes: 29 additions & 0 deletions tools/onnx2external.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import os

import onnx
import onnx.external_data_helper


def convert2external(input_file_name: str) -> None:
dir_name = os.path.dirname(input_file_name)
base_name, _suffix = os.path.splitext(os.path.basename(input_file_name))
model = onnx.load(input_file_name)
os.makedirs(os.path.join(dir_name, base_name), exist_ok=True)
onnx.external_data_helper.convert_model_to_external_data(
model, location="external_data.onnx", size_threshold=128
)
onnx.save(model, os.path.join(dir_name, base_name, "model.onnx"))


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert ONNX model file to external data format"
)
parser.add_argument("input", help="ONNX model file to convert")
args = parser.parse_args()

convert2external(args.input)
16 changes: 13 additions & 3 deletions tools/onnx2script.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@


def convert2script(
input_file_name: str, output_file_name: Optional[str], verbose: bool
input_file_name: str, output_file_name: Optional[str], verbose: bool, initializers: bool
) -> None:
model = onnx.load(input_file_name, load_external_data=False)
python_code = onnxscript.proto2python(
model, use_operators=not verbose, inline_const=not verbose
model,
use_operators=not verbose,
inline_const=not verbose,
skip_initializers=not initializers,
)

# If output file name is not provided, use the input file name with .py extension
Expand All @@ -55,6 +58,13 @@ def convert2script(
help="Verbose mode, suppresses use of overloaded operators and inline constants",
default=False,
)
parser.add_argument(
"-i",
"--initializers",
action="store_true",
help="Include initializers in the generated script",
default=False,
)

args = parser.parse_args()
convert2script(args.input, args.output, args.verbose)
convert2script(args.input, args.output, args.verbose, args.initializers)
Loading