Skip to content

Commit

Permalink
code clean up. Show op name on the node.
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Apr 28, 2018
1 parent 51f8dac commit 5619f0a
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions tensorboardX/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from distutils.version import LooseVersion


def replace(name, scope):
return '/'.join([scope[name], name])


def parse(graph):
import torch
scope = {}
Expand All @@ -31,32 +27,38 @@ def parse(graph):
for n in graph.nodes():
attrs = {k: n[k] for k in n.attributeNames()}
attrs = str(attrs).replace("'", ' ') # singlequote will be escaped by tensorboard
if any(i.uniqueName() not in scope.keys() for i in n.inputs()): # 0.3.1 workaround
continue
inputs = [replace(i.uniqueName(), scope) for i in n.inputs()]
outputnode = next(iter(n.outputs())) # FIXME: only first output is considered
inputs = [i.uniqueName() for i in n.inputs()]
outputnode = next(iter(n.outputs())) # FIXME: only first output is considered (only Dropout)
uname = outputnode.uniqueName()
if outputnode.type().kind() == 'TensorType':
outputsize = outputnode.type().sizes()
nodes.append({'name': replace(uname, scope),
nodes.append({'name': uname,
'op': n.kind(),
'inputs': inputs,
'attr': attrs,
'outputsize': outputsize})
else:
nodes.append({'name': replace(uname, scope), 'op': n.kind(), 'inputs': inputs, 'attr': attrs})
nodes.append({'name': uname, 'op': n.kind(), 'inputs': inputs, 'attr': attrs})

for n in graph.inputs():
uname = n.uniqueName()
if uname not in scope.keys():
scope[uname] = 'unused'
outputsize = n.type().sizes()
nodes.append({'name': replace(uname, scope),
nodes.append({'name': uname,
'op': 'Parameter',
'inputs': [],
'attr': str(n.type()),
'outputsize': outputsize})

mapping = {}
for n in nodes:
mapping[n['name']] = scope[n['name']] + '/' + \
n['op'].replace('onnx::', '') + '_' + n['name']
for n in nodes:
n['name'] = mapping[n['name']]
for i, s in enumerate(n['inputs']):
n['inputs'][i] = mapping[s]
return nodes


Expand Down

0 comments on commit 5619f0a

Please sign in to comment.