-
Notifications
You must be signed in to change notification settings - Fork 299
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* create TensorflowGraph class and refactor * change the import order * bug fix * add graph parser * update * parse LSTM * refactor and add batch major transpose * bug fix * bug fix * add doc * yapf * add value error if node name hits multiple scopes * bug fix * improve var name * add GRU * add LSTM, RNN * multi layers support * bug fix * fix GRU * move to experiment folder * refactor and add rnn support to cli * update frontend handlers * update opset version * fix test case
- Loading branch information
Showing
20 changed files
with
982 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,15 @@ | ||
from tensorflow.core.framework import graph_pb2 | ||
|
||
from onnx_tf.common import get_output_node_names | ||
from onnx_tf.frontend import tensorflow_graph_to_onnx_model | ||
|
||
from onnx_tf.pb_wrapper import TensorflowGraph | ||
|
||
graph_def = graph_pb2.GraphDef() | ||
with open("input_path", "rb") as f: # load tf graph def | ||
with open("input_path", "rb") as f: # load tf graph def | ||
graph_def.ParseFromString(f.read()) | ||
output = get_output_node_names(graph_def) # get output node names | ||
output = TensorflowGraph.get_output_node_names( | ||
graph_def) # get output node names | ||
|
||
model = tensorflow_graph_to_onnx_model(graph_def, output) # convert tf graph to onnx model | ||
model = tensorflow_graph_to_onnx_model(graph_def, | ||
output) # convert tf graph to onnx model | ||
with open("output_path", 'wb') as f: | ||
f.write(model.SerializeToString()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
from onnx_tf.experiment.scope_parser import get_rnn_scope_parser | ||
from onnx_tf.frontend import TensorflowFrontend | ||
|
||
|
||
class ExperimentTensorflowFrontend(TensorflowFrontend): | ||
|
||
@classmethod | ||
def rnn_tf_graph_to_onnx_model(cls, | ||
graph_def, | ||
output, | ||
rnn_type, | ||
opset=0, | ||
producer_name="onnx-tensorflow", | ||
graph_name="graph", | ||
ignore_unimplemented=False, | ||
optimizer_passes=None): | ||
"""EXPERIMENTAL | ||
Converts a RNN Tensorflow Graph Proto to an ONNX model | ||
This function converts a Tensorflow Graph proto to an equivalent | ||
representation of ONNX model. | ||
DO NOT DEFINE customized scope name in tf.dynamic_rnn and RNN cell. | ||
:param graph_def: Tensorflow Graph Proto object. | ||
:param output: List of string or a string specifying the name | ||
of the output graph node. | ||
:param opset: Opset version number, list or tuple. | ||
Default is 0 means using latest version with domain ''. | ||
List or tuple items should be (str domain, int version number). | ||
:param rnn_type: The rnn type contained in graph, should be one of GRU, LSTM, RNN. | ||
:param producer_name: The name of the producer. | ||
:param graph_name: The name of the output ONNX Graph. | ||
:param ignore_unimplemented: Convert to ONNX model and ignore all the operators | ||
that are not currently supported by onnx-tensorflow. | ||
This is an experimental feature. By enabling this feature, | ||
the model would not be guaranteed to match the ONNX specifications. | ||
:param optimizer_passes: List of optimization names c.f. | ||
https://github.com/onnx/onnx/blob/master/onnx/optimizer.py for available | ||
optimization passes. | ||
:returns: The equivalent ONNX Model Proto object. | ||
""" | ||
|
||
tf_graph = cls._make_tf_graph(graph_def, output, graph_name) | ||
parser = get_rnn_scope_parser(rnn_type) | ||
nodes = parser.parse(tf_graph.nodes) | ||
tf_graph.update_nodes(nodes) | ||
|
||
return cls._make_onnx_model(tf_graph, opset, producer_name, | ||
ignore_unimplemented, optimizer_passes) | ||
|
||
|
||
rnn_tf_graph_to_onnx_model = ExperimentTensorflowFrontend.rnn_tf_graph_to_onnx_model |
Oops, something went wrong.