Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added graph methods for saving using the external data storage format #1088

Merged
merged 1 commit into from Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_internals.py
Expand Up @@ -226,7 +226,7 @@ def test_node_attr_onnx(self):
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
n1 = g.get_node_by_name("n1")
self.assertTrue("my_attr" in n1.attr)
self.assertTrue("my_attr" not in n1.attr_onnx)
self.assertTrue("my_attr" not in n1.get_onnx_attrs())

n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
graph_proto = helper.make_graph(
Expand All @@ -240,7 +240,7 @@ def test_node_attr_onnx(self):
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
n1 = g.get_node_by_name("n1")
self.assertTrue("my_attr" in n1.attr)
self.assertTrue("my_attr" in n1.attr_onnx)
self.assertTrue("my_attr" in n1.get_onnx_attrs())

def test_tensor_data(self):
tensors = {
Expand Down
59 changes: 45 additions & 14 deletions tf2onnx/graph.py
Expand Up @@ -28,6 +28,13 @@
# todo(pengwa): remove protected-access later
# pylint: disable=broad-except,protected-access

class ExternalTensorStorage():
    """Accumulates large tensors so they can be saved in onnx's external data format.

    An instance is passed into graph and node proto-building methods; tensors whose
    element count exceeds the size threshold are recorded here and their attrs are
    rewritten to reference the external location instead of inline raw data.
    """
    def __init__(self, external_tensor_size_threshold=1024):
        """
        Args:
            external_tensor_size_threshold: tensors with more elements than this
                are stored externally (default 1024, the previous hard-coded value).
        """
        # Maps generated external tensor names to the raw bytes to write out.
        self.name_to_tensor_data = {}
        # Monotonic counter used to generate unique external tensor names.
        self.name_counter = 0
        self.external_tensor_size_threshold = external_tensor_size_threshold
        # Cache of Node -> already-rewritten "value" attr so repeated proto
        # updates don't externalize the same tensor twice.
        self.node_to_modified_value_attr = {}

class Node(object):
"""A Node - wrapper around onnx nodes that we use for graph manipulations."""
Expand Down Expand Up @@ -88,16 +95,40 @@ def inputs(self):
def attr(self):
return self._attr

@property
def attr_onnx(self):
"""Return onnx valid attributes"""
def get_value_attr(self, external_tensor_storage=None):
    """Return onnx attr for value property of node.

    If external_tensor_storage is provided and the tensor exceeds its size
    threshold, the returned attr is a copy whose tensor points at external
    data recorded in external_tensor_storage instead of inline raw_data.
    """
    a = self._attr["value"]
    if external_tensor_storage is not None and self in external_tensor_storage.node_to_modified_value_attr:
        # Already externalized by a previous call; reuse the cached attr.
        return external_tensor_storage.node_to_modified_value_attr[self]
    if external_tensor_storage is None or a.type != AttributeProto.TENSOR:
        return a
    # np.prod replaces np.product, a deprecated alias removed in NumPy 2.0.
    # NOTE(review): threshold compares element count, not byte size -- confirm
    # that is the intended unit for external_tensor_size_threshold.
    if np.prod(a.t.dims) > external_tensor_storage.external_tensor_size_threshold:
        a = copy.copy(a)
        tensor_name = self.name + "_" + str(external_tensor_storage.name_counter)
        external_tensor_storage.name_counter += 1
        # Stash the real bytes; the proto keeps only a placeholder plus a
        # "location" external_data entry naming the external tensor.
        # NOTE(review): assumes the tensor data lives in raw_data rather than
        # the typed *_data fields -- confirm for all const tensors.
        external_tensor_storage.name_to_tensor_data[tensor_name] = a.t.raw_data
        external_tensor_storage.node_to_modified_value_attr[self] = a
        a.t.raw_data = b'__EXTERNAL'
        location = a.t.external_data.add()
        location.key = "location"
        location.value = tensor_name
        a.t.data_location = TensorProto.EXTERNAL
    return a

def get_onnx_attrs(self, external_tensor_storage=None):
    """Collect this node's attributes that are valid in onnx.

    The "value" attr is routed through get_value_attr so it can reference
    externally stored tensor data when external_tensor_storage is supplied.
    """
    schema = get_schema(self.type, self.graph.opset, self.domain)
    unknown_op = schema is None
    if unknown_op and not (self.is_const() or self.is_graph_input()):
        logger.debug("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check",
                     self.name, self.domain, self.type)
    result = {}
    for attr in self._attr.values():
        if attr.name == "value":
            # "value" is always included and may be rewritten for external storage.
            result["value"] = self.get_value_attr(external_tensor_storage)
            continue
        if unknown_op or schema.has_attribute(attr.name):
            result[attr.name] = attr
    return result

Expand Down Expand Up @@ -328,7 +359,7 @@ def set_body_graph_as_attr(self, attr_name, graph):
self.graph.contained_graphs[self.name].update({attr_name: graph})
graph.parent_graph = self.graph

def update_proto(self):
def update_proto(self, external_tensor_storage=None):
"""Update protobuf from internal structure."""
nodes = list(self._op.input)
for node in nodes:
Expand All @@ -346,10 +377,10 @@ def update_proto(self):
attr_graphs = self.get_body_graphs()
if attr_graphs:
for attr_name, sub_graph in attr_graphs.items():
graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name)
graph_proto = sub_graph.make_graph("graph for " + self.name + " " + attr_name, external_tensor_storage)
self.set_attr(attr_name, graph_proto)

attr = list(self.attr_onnx.values())
attr = list(self.get_onnx_attrs(external_tensor_storage).values())
if attr:
self._op.attribute.extend(attr)

Expand Down Expand Up @@ -743,10 +774,10 @@ def update_node_shape_dtype(self, node, override=False):
self.set_shape(output, shape)
logger.debug("Set shape of [%s] to %s", output, shape)

def update_proto(self):
def update_proto(self, external_tensor_storage=None):
    """Refresh the onnx protobuf of every node from our internal Node structure.

    Args:
        external_tensor_storage: optional storage object that accumulates large
            tensors for onnx's external data format while protos are rebuilt.
    """
    for n in self._nodes:
        n.update_proto(external_tensor_storage)

def get_nodes(self):
"""Get node list."""
Expand Down Expand Up @@ -963,7 +994,7 @@ def _get_unvisited_child(g, node, not_visited):
ret = [x for _, x in sorted(zip(label, ops))]
self.reset_nodes(ret)

def make_graph(self, doc, graph_name=None):
def make_graph(self, doc, graph_name=None, external_tensor_storage=None):
"""
Create GraphProto for onnx from internal graph.
Args:
Expand All @@ -973,7 +1004,7 @@ def make_graph(self, doc, graph_name=None):
graph_name = graph_name or self.graph_name
self.delete_unused_nodes(self.outputs)
self.topological_sort(self.get_nodes())
self.update_proto()
self.update_proto(external_tensor_storage)

# TODO: we'd want to do something like this so that transpose optimizer is active
# for all (unit) tests
Expand Down Expand Up @@ -1016,7 +1047,7 @@ def make_graph(self, doc, graph_name=None):
# not to use numpy_helper.from_array to create a new tensor
# because sometimes onnx will have a bug that only check the tensor data in specific field
# such as at upsample it only checks the float_data field.
t = op.get_attr("value")
t = op.get_value_attr(external_tensor_storage)
tensor = helper.get_attribute_value(t)
tensor.name = op.output[0]
initializers.append(tensor)
Expand Down Expand Up @@ -1045,14 +1076,14 @@ def make_graph(self, doc, graph_name=None):

return graph

def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", **kwargs):
def make_model(self, graph_doc, optimize=False, graph_name="tf2onnx", external_tensor_storage=None, **kwargs):
"""
Create final ModelProto for onnx from internal graph.
Args:
optimize: optimize graph via onnx
doc: text for doc string of the model
"""
graph = self.make_graph(graph_doc, graph_name)
graph = self.make_graph(graph_doc, graph_name, external_tensor_storage)

if "producer_name" not in kwargs:
kwargs = {"producer_name": "tf2onnx",
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/optimizer/transpose_optimizer.py
Expand Up @@ -282,7 +282,7 @@ def _handle_nhwc_tranpose(self, trans):
return False
# move transpose into branches to let Transposes can be "handled" in each branch
for n in out_nodes:
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.attr_onnx)
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
n.graph.replace_input(n, trans.output[0], branch_trans.output[0])

self._g.remove_node(trans.name)
Expand Down Expand Up @@ -407,7 +407,7 @@ def _add_handler(self, trans, node):
target_node.set_tensor_value(target_val)

conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.get_onnx_attrs())
ops = self._g.get_nodes()
trans.input[0] = utils.port_name(conv_node.name)
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
Expand Down
2 changes: 1 addition & 1 deletion tf2onnx/schemas.py
Expand Up @@ -136,7 +136,7 @@ def build_onnx_op(node):
copied_sub_graph = copy.deepcopy(sub_graph)
graph_proto = copied_sub_graph.make_graph("graph for " + node.name + " " + attr_name)
attr.append(helper.make_attribute(attr_name, graph_proto))
attr.extend(node.attr_onnx.values())
attr.extend(node.get_onnx_attrs().values())
if attr:
onnx_node.attribute.extend(attr)
return onnx_node
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/utils.py
Expand Up @@ -405,9 +405,9 @@ def is_same(node_1, node_2):
if node_1.type != node_2.type:
return False
# check onnx attributes
if node_1.attr_onnx.keys() != node_2.attr_onnx.keys():
if node_1.get_onnx_attrs().keys() != node_2.get_onnx_attrs().keys():
return False
for name in node_1.attr_onnx.keys(): # pylint: disable=consider-iterating-dictionary
for name in node_1.get_onnx_attrs().keys(): # pylint: disable=consider-iterating-dictionary
if node_1.get_attr_value(name) != node_2.get_attr_value(name):
return False
return True
Expand Down