diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index c8a6a9a640..b3f695d700 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -247,11 +247,11 @@ def _cond_is_used_in_loop_body(graph: GraphProto) -> bool: return False -class Exporter: +class _Exporter: """Class used for recursive traversal of Proto structures.""" def __init__( - self, rename: bool, use_operators: bool = False, inline_const: bool = False + self, *, rename: bool, use_operators: bool, inline_const: bool, skip_initializers: bool ) -> None: self.use_operators = use_operators if rename: @@ -266,6 +266,8 @@ def __init__( # _name_remappings: used to undo the SSA-renaming in ONNX control-flow ops. # We map the multiple SSA-variants back to the same Python variable name. self._name_remappings: list[dict[str, str]] = [] + self.skip_initializers = skip_initializers + self.skipped_initializers: dict[str, onnx.TensorProto] = {} def _handle_attrname_conflict(self, renamer): """Add ref-attr-name-conflict handling logic to renaming function.""" @@ -338,6 +340,14 @@ def _translate_graph_body(self, graph, opsets, indent=0): code = [] if hasattr(graph, "initializer"): for init in graph.initializer: + if self.skip_initializers: + init_py_name = self._translate_onnx_var(init.name) + if init_py_name in self.skipped_initializers: + raise RuntimeError( + f"Initializer {init.name!r} is already present in skipped_initializers." + ) + self.skipped_initializers[init_py_name] = init + continue node = make_node( "Constant", [], @@ -684,15 +694,61 @@ def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str]) def add(line: str) -> None: result.append(line) - add("@script()") - add(f"def {function_name}{_translate_signature(graph.input, graph.output)}") + if self.skip_initializers: + indent_level = 2 + indent = _SINGLE_INDENT + else: + indent_level = 1 + indent = "" + add(f"{indent}@script()") + add(f"{indent}def {function_name}{_translate_signature(graph.input, graph.output)}") + indent = indent + _SINGLE_INDENT doc = graph.doc_string if doc: - add(f' """{doc}"""') - add(self._translate_graph_body(graph, opsets, indent=1)) + add(f'{indent}"""{doc}"""') + add(self._translate_graph_body(graph, opsets, indent=indent_level)) return_values = ", ".join(self._translate_onnx_var(x) for x in graph.output) - add(f" return {return_values}") - return "\n".join(result) + add(f"{indent}return {return_values}") + script = "\n".join(result) + if self.skipped_initializers: + return self._substitute_initializers(script, function_name) + return script + + def _substitute_initializers(self, script: str, script_function_name: str) -> str: + init_names = self.skipped_initializers.keys() + # Formal parameters representing initializers (single level indentation) + __ = _SINGLE_INDENT + initializers_as_params = "\n".join(f"{__}{x}," for x in init_names) + + def generate_rand(name: str, value: TensorProto) -> str: + shape = ",".join(str(d) for d in value.dims) + if value.data_type != TensorProto.FLOAT: + raise NotImplementedError( + f"Unable to generate random initializer for data type {value.data_type}." + ) + return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)" + + random_initializer_values = "\n".join( + generate_rand(key, value) for key, value in self.skipped_initializers.items() + ) + # Actual parameter values for initializers (double level indentation) + indented_initializers_as_params = "\n".join(f"{__}{__}{x}," for x in init_names) + return f""" +def make_model( +{initializers_as_params} +): +{script} + +{__}model = {script_function_name}.to_model_proto() +{__}return model + +def make_model_with_random_weights(): +{random_initializer_values} +{__}model = make_model( +{indented_initializers_as_params} +{__}) +{__}return model +""" def _import_onnx_types( self, proto: onnx.ModelProto | onnx.GraphProto | onnx.FunctionProto @@ -778,9 +834,11 @@ def visit_graph(graph: onnx.GraphProto) -> None: def export2python( model_onnx, function_name: Optional[str] = None, + *, rename: bool = False, use_operators: bool = False, inline_const: bool = False, + skip_initializers: bool = False, ): """Exports an ONNX model to the *python* syntax. @@ -790,6 +848,9 @@ def export2python( function_name: main function name use_operators: use Python operators. inline_const: replace ONNX constants inline if compact + skip_initializers: generated script will not include initializers. + Instead, a function that generates the model, given initializer values, is generated, + along with one that generates random values for the initializers. Returns: python code @@ -815,5 +876,10 @@ def export2python( if not isinstance(model_onnx, (ModelProto, FunctionProto)): raise TypeError(f"The function expects a ModelProto not {type(model_onnx)!r}.") - exporter = Exporter(rename, use_operators, inline_const) + exporter = _Exporter( + rename=rename, + use_operators=use_operators, + inline_const=inline_const, + skip_initializers=skip_initializers, + ) return exporter.export(model_onnx, function_name) diff --git a/tools/onnx2external.py b/tools/onnx2external.py new file mode 100644 index 0000000000..1685458251 --- /dev/null +++ b/tools/onnx2external.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import argparse +import os + +import onnx +import onnx.external_data_helper + + +def convert2external(input_file_name: str) -> None: + dir_name = os.path.dirname(input_file_name) + base_name, _suffix = os.path.splitext(os.path.basename(input_file_name)) + model = onnx.load(input_file_name) + os.makedirs(os.path.join(dir_name, base_name), exist_ok=True) + onnx.external_data_helper.convert_model_to_external_data( + model, location="external_data.onnx", size_threshold=128 + ) + onnx.save(model, os.path.join(dir_name, base_name, "model.onnx")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert ONNX model file to external data format" + ) + parser.add_argument("input", help="ONNX model file to convert") + args = parser.parse_args() + + convert2external(args.input) diff --git a/tools/onnx2script.py b/tools/onnx2script.py index 02b220799a..7b57bf91d6 100644 --- a/tools/onnx2script.py +++ b/tools/onnx2script.py @@ -28,11 +28,14 @@ def convert2script( - input_file_name: str, output_file_name: Optional[str], verbose: bool + input_file_name: str, output_file_name: Optional[str], verbose: bool, initializers: bool ) -> None: model = onnx.load(input_file_name, load_external_data=False) python_code = onnxscript.proto2python( - model, use_operators=not verbose, inline_const=not verbose + model, + use_operators=not verbose, + inline_const=not verbose, + skip_initializers=not initializers, ) # If output file name is not provided, use the input file name with .py extension @@ -55,6 +58,13 @@ def convert2script( help="Verbose mode, suppresses use of overloaded operators and inline constants", default=False, ) + parser.add_argument( + "-i", + "--initializers", + action="store_true", + help="Include initializers in the generated script", + default=False, + ) args = parser.parse_args() - convert2script(args.input, args.output, args.verbose) + convert2script(args.input, args.output, args.verbose, args.initializers)