Skip to content

Commit

Permalink
tensorshape for graph visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Apr 5, 2018
1 parent 87dbc51 commit f425520
Showing 1 changed file with 29 additions and 7 deletions.
36 changes: 29 additions & 7 deletions tensorboardX/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,36 @@ def parse(graph):
if any(i.uniqueName() not in scope.keys() for i in n.inputs()): # 0.3.1 workaround
continue
inputs = [replace(i.uniqueName(), scope) for i in n.inputs()]
uname = next(iter(n.outputs())).uniqueName() # FIXME: only first output is considered
nodes.append({'name': replace(uname, scope), 'op': n.kind(), 'inputs': inputs, 'attr': attrs})
outputnode = next(iter(n.outputs())) # FIXME: only first output is considered
uname = outputnode.uniqueName()
if outputnode.type().kind() == 'TensorType':
outputsize = outputnode.type().sizes()
nodes.append({'name': replace(uname, scope),
'op': n.kind(),
'inputs': inputs,
'attr': attrs,
'outputsize': outputsize})
else:
nodes.append({'name': replace(uname, scope), 'op': n.kind(), 'inputs': inputs, 'attr': attrs})

for n in graph.inputs():
uname = n.uniqueName()
if uname not in scope.keys():
scope[uname] = 'unused'
nodes.append({'name': replace(uname, scope), 'op': 'Parameter', 'inputs': [], 'attr': str(n.type())})
outputsize = n.type().sizes()
nodes.append({'name': replace(uname, scope),
'op': 'Parameter',
'inputs': [],
'attr': str(n.type()),
'outputsize': outputsize})

return nodes


def graph(model, args, verbose=False):
import torch
with torch.onnx.set_training(model, False):
trace, _ = torch.jit.trace(model, args)
trace, _ = torch.jit.get_trace_graph(model, args)
if LooseVersion(torch.__version__) >= LooseVersion("0.4"):
torch.onnx._optimize_trace(trace, False)
else:
Expand All @@ -60,7 +74,15 @@ def graph(model, args, verbose=False):
list_of_nodes = parse(graph)
nodes = []
for node in list_of_nodes:
nodes.append(
NodeDef(name=node['name'], op=node['op'], input=node['inputs'],
attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8'))}))
if 'outputsize' in node.keys():
shapeproto = TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d) for d in node['outputsize']])
nodes.append(
NodeDef(name=node['name'], op=node['op'], input=node['inputs'],
attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8')),
'_output_shapes': AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))}))
else:
nodes.append(
NodeDef(name=node['name'], op=node['op'], input=node['inputs'],
attr={'lanpa': AttrValue(s=node['attr'].encode(encoding='utf_8'))}))
return GraphDef(node=nodes, versions=VersionDef(producer=22))

0 comments on commit f425520

Please sign in to comment.