Skip to content

Commit

Permalink
Add graph parser (#301)
Browse files Browse the repository at this point in the history
* create TensorflowGraph class and refactor

* change the import order

* bug fix

* add graph parser

* update

* parse LSTM

* refactor and add batch major transpose

* bug fix

* bug fix

* add doc

* yapf

* add value error if node name hits multiple scopes

* bug fix

* improve var name

* add GRU

* add LSTM, RNN

* multi layers support

* bug fix

* fix GRU

* move to experiment folder

* refactor and add rnn support to cli

* update frontend handlers

* update opset version

* fix test case
  • Loading branch information
fumihwh committed Jan 4, 2019
1 parent 8a640a5 commit 7463617
Show file tree
Hide file tree
Showing 20 changed files with 982 additions and 97 deletions.
26 changes: 16 additions & 10 deletions doc/CLI.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ More information: `onnx-tf convert -h`
```
usage: onnx-tf [-h] --infile INFILE --outfile OUTFILE --convert_to {onnx,tf}
[--graph GRAPH] [--device DEVICE] [--strict STRICT]
[--output OUTPUT] [--opset OPSET]
[--ignore_unimplemented IGNORE_UNIMPLEMENTED]
[--optimizer_passes OPTIMIZER_PASSES] [--opset OPSET]
[--output OUTPUT]
[--optimizer_passes OPTIMIZER_PASSES]
[--rnn_type {GRU,LSTM,RNN}]
This is the converter for converting protocol buffer between tf and onnx.
Expand Down Expand Up @@ -68,6 +69,14 @@ backend arguments (onnx -> tf):
and AveragePool ops. (from onnx_tf.backend.prepare)
frontend arguments (tf -> onnx):
--output OUTPUT List of string or a string specifying the name of the
output graph node. (from
onnx_tf.frontend.tensorflow_graph_to_onnx_model)
--opset OPSET Opset version number, list or tuple. Default is 0
means using latest version with domain ''. List or
tuple items should be (str domain, int version
number). (from
onnx_tf.frontend.tensorflow_graph_to_onnx_model)
--ignore_unimplemented IGNORE_UNIMPLEMENTED
Convert to ONNX model and ignore all the operators
that are not currently supported by onnx-tensorflow.
Expand All @@ -80,12 +89,9 @@ frontend arguments (tf -> onnx):
x/onnx/blob/master/onnx/optimizer.py for available
optimization passes. (from
onnx_tf.frontend.tensorflow_graph_to_onnx_model)
--opset OPSET Opset version number, list or tuple. Default is 0
means using latest version with domain ''. List or
tuple items should be (str domain, int version
number). (from
onnx_tf.frontend.tensorflow_graph_to_onnx_model)
--output OUTPUT List of string or a string specifying the name of the
output graph node. (from
onnx_tf.frontend.tensorflow_graph_to_onnx_model)
EXPERIMENTAL ARGUMENTS:
--rnn_type {GRU,LSTM,RNN}
RNN graph type if using experimental feature: convert
rnn graph to onnx.
```
1 change: 0 additions & 1 deletion example/onnx_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from onnx_tf.backend import prepare


onnx_model = onnx.load("input_path") # load onnx model
tf_rep = prepare(onnx_model) # prepare tf representation
tf_rep.export_graph("output_path") # export the model
11 changes: 6 additions & 5 deletions example/tf_to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from tensorflow.core.framework import graph_pb2

from onnx_tf.common import get_output_node_names
from onnx_tf.frontend import tensorflow_graph_to_onnx_model

from onnx_tf.pb_wrapper import TensorflowGraph

graph_def = graph_pb2.GraphDef()
with open("input_path", "rb") as f: # load tf graph def
with open("input_path", "rb") as f: # load tf graph def
graph_def.ParseFromString(f.read())
output = get_output_node_names(graph_def) # get output node names
output = TensorflowGraph.get_output_node_names(
graph_def) # get output node names

model = tensorflow_graph_to_onnx_model(graph_def, output) # convert tf graph to onnx model
model = tensorflow_graph_to_onnx_model(graph_def,
output) # convert tf graph to onnx model
with open("output_path", 'wb') as f:
f.write(model.SerializeToString())
5 changes: 4 additions & 1 deletion onnx_tf/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import onnx_tf.opr_checker
import onnx_tf.optimizer


def main():
args = sys.argv[1:]
parser = argparse.ArgumentParser(
description="ONNX-Tensorflow Command Line Interface")
parser.add_argument(
"command", choices=["convert", "check", "optimize"], help="Available commands.")
"command",
choices=["convert", "check", "optimize"],
help="Available commands.")

if len(args) == 0:
parser.parse_args(["-h"])
Expand Down
11 changes: 9 additions & 2 deletions onnx_tf/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,13 @@ def supports_device(device):
return False


@deprecated("onnx_tf.common.get_outputs_names is deprecated.{} {}".format(
deprecated.MSG_WILL_REMOVE,
"Use TensorflowGraph.get_outputs_names instead."))
def get_output_node_names(graph_def):
"""Get output node names from GraphDef.
Args:
graph_def: GraphDef object.
Returns:
List of output node names.
"""
Expand All @@ -178,3 +179,9 @@ def get_output_node_names(graph_def):
nodes[node.name] = node
input_names.update(set(node.input))
return list(set(nodes) - input_names)


CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
26 changes: 21 additions & 5 deletions onnx_tf/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from tensorflow.python.tools import freeze_graph

import onnx_tf.backend as backend
from onnx_tf.common import get_output_node_names
from onnx_tf.common import get_unique_suffix
import onnx_tf.experiment.frontend as experiment_frontend
import onnx_tf.frontend as frontend
from onnx_tf.pb_wrapper import TensorflowGraph

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
Expand Down Expand Up @@ -123,6 +124,15 @@ def add_argument_group(parser, group_name, funcs):
for k, v in param_doc_dict.items():
group.add_argument("--{}".format(k), help=v["doc"], **v["params"])

def add_experimental_args(parser):
group = parser.add_argument_group("EXPERIMENTAL ARGUMENTS")
group.add_argument(
"--rnn_type",
choices=["GRU", "LSTM", "RNN"],
help=
"RNN graph type if using experimental feature: convert rnn graph to onnx."
)

# backend args
# Args must be named consistently with respect to backend.prepare.
add_argument_group(parser, "backend arguments (onnx -> tf)",
Expand Down Expand Up @@ -151,6 +161,8 @@ def add_argument_group(parser, group_name, funcs):
}
})])

add_experimental_args(parser)

return parser.parse_args(args)


Expand Down Expand Up @@ -182,7 +194,6 @@ def convert(infile, outfile, convert_to, graph=None, **kwargs):
elif ext == ".ckpt":
latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(infile))
saver = tf.train.import_meta_graph(latest_ckpt + ".meta")
output_node_names = []
temp_file_suffix = get_unique_suffix()
workdir = 'onnx-tf_workdir_{}'.format(temp_file_suffix)
with tf.Session() as sess:
Expand All @@ -192,8 +203,9 @@ def convert(infile, outfile, convert_to, graph=None, **kwargs):
])
saver.restore(sess, latest_ckpt)
# Take users' hint or deduce output node automatically.
kwargs["output"] = kwargs.get("output", None) or get_output_node_names(
sess.graph.as_graph_def())
kwargs["output"] = kwargs.get(
"output", None) or TensorflowGraph.get_output_node_names(
sess.graph.as_graph_def())

# Save the graph to disk for freezing.
tf.train.write_graph(
Expand Down Expand Up @@ -226,6 +238,10 @@ def convert(infile, outfile, convert_to, graph=None, **kwargs):
raise ValueError(
"Input file is not supported. Should be .pb or .ckpt, but get {}".
format(ext))
onnx_model = frontend.tensorflow_graph_to_onnx_model(graph_def, **kwargs)

if "rnn_type" in kwargs:
onnx_model = experiment_frontend.rnn_tf_graph_to_onnx_model(graph_def, **kwargs)
else:
onnx_model = frontend.tensorflow_graph_to_onnx_model(graph_def, **kwargs)
onnx.save(onnx_model, outfile)
logger.info("Converting completes successfully.")
Empty file added onnx_tf/experiment/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions onnx_tf/experiment/frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from onnx_tf.experiment.scope_parser import get_rnn_scope_parser
from onnx_tf.frontend import TensorflowFrontend


class ExperimentTensorflowFrontend(TensorflowFrontend):

  @classmethod
  def rnn_tf_graph_to_onnx_model(cls,
                                 graph_def,
                                 output,
                                 rnn_type,
                                 opset=0,
                                 producer_name="onnx-tensorflow",
                                 graph_name="graph",
                                 ignore_unimplemented=False,
                                 optimizer_passes=None):
    """EXPERIMENTAL
    Convert an RNN Tensorflow Graph Proto to an ONNX model.

    The graph is first wrapped as a TensorflowGraph, then the scope
    parser matching ``rnn_type`` rewrites the RNN cell sub-graph into
    nodes the regular frontend understands, and finally the standard
    ONNX model builder is run on the rewritten graph.
    DO NOT DEFINE customized scope name in tf.dynamic_rnn and RNN cell.
    :param graph_def: Tensorflow Graph Proto object.
    :param output: List of string or a string specifying the name
      of the output graph node.
    :param rnn_type: The rnn type contained in graph, should be one of GRU, LSTM, RNN.
    :param opset: Opset version number, list or tuple.
      Default is 0 means using latest version with domain ''.
      List or tuple items should be (str domain, int version number).
    :param producer_name: The name of the producer.
    :param graph_name: The name of the output ONNX Graph.
    :param ignore_unimplemented: Convert to ONNX model and ignore all the operators
      that are not currently supported by onnx-tensorflow.
      This is an experimental feature. By enabling this feature,
      the model would not be guaranteed to match the ONNX specifications.
    :param optimizer_passes: List of optimization names c.f.
      https://github.com/onnx/onnx/blob/master/onnx/optimizer.py for available
      optimization passes.
    :returns: The equivalent ONNX Model Proto object.
    """
    # Wrap the raw GraphDef, rewrite the RNN scopes in place, then
    # hand the updated graph to the ordinary model builder.
    graph = cls._make_tf_graph(graph_def, output, graph_name)
    scope_parser = get_rnn_scope_parser(rnn_type)
    graph.update_nodes(scope_parser.parse(graph.nodes))

    return cls._make_onnx_model(graph, opset, producer_name,
                                ignore_unimplemented, optimizer_passes)


rnn_tf_graph_to_onnx_model = ExperimentTensorflowFrontend.rnn_tf_graph_to_onnx_model

0 comments on commit 7463617

Please sign in to comment.