add_graph for pytorch 1.3 (#508)
* change name to backward_compat_mode

* add_graph for pytorch 1.3
lanpa committed Sep 23, 2019
1 parent 9b9933a commit 23de6f0
Showing 1 changed file with 45 additions and 29 deletions.
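
Context for the change: around PyTorch 1.2/1.3, torch._C.Value.uniqueName() was renamed to debugName(), so the parser now probes the new name once and falls back to the old one (the backward_compat_mode flag in the diff below). A minimal sketch of that probe, assuming any PyTorch version that can trace a module:

    import torch

    def value_name(value):
        # torch._C.Value.uniqueName() became debugName() in newer PyTorch;
        # try the new spelling first and fall back to the old one.
        try:
            return value.debugName()
        except AttributeError:
            return value.uniqueName()

    model = torch.nn.Linear(4, 2)
    graph = torch.jit.trace(model, torch.zeros(1, 4)).graph
    print([value_name(v) for v in graph.inputs()])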
tensorboardX/pytorch_graph.py (45 additions & 29 deletions)
@@ -11,9 +11,8 @@
 
 methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
               'kind', 'outputs', 'outputsSize', 'scopeName']
-methods_IO = ['node', 'offset', 'debugName']  # 'unique' <int> , 'type' <Tensor<class 'torch._C.Type'>>
-
-backward_mode = False
+methods_IO = []
+backward_compat_mode = False
 
 class NodeBase(object):
     def __init__(self,
@@ -45,14 +44,14 @@ def __init__(self, node_cpp, valid_methods):
         super(NodePy, self).__init__(node_cpp)
         valid_methods = valid_methods[:]
         self.inputs = []
-        global backward_mode
+        global backward_compat_mode
         for m in valid_methods:
             if m == 'inputs' or m == 'outputs':
                 list_of_node = list(getattr(node_cpp, m)())
                 io_unique_names = []
                 io_tensor_sizes = []
                 for n in list_of_node:
-                    if backward_mode:
+                    if backward_compat_mode:
                         io_unique_names.append(n.uniqueName())
                     else:
                         io_unique_names.append(n.debugName())
@@ -66,24 +65,21 @@ def __init__(self, node_cpp, valid_methods):
                 setattr(self, m + 'tensor_size', io_tensor_sizes)
 
             else:
-                if m == 'debugName' and backward_mode:
+                if m == 'debugName' and backward_compat_mode:
                     setattr(self, m, getattr(node_cpp, 'uniqueName')())
                 else:
                     setattr(self, m, getattr(node_cpp, m)())
 
 
 class NodePyIO(NodePy):
-    def __init__(self, node_cpp, input_or_output=None):
+    def __init__(self, node_cpp, input_or_output=None, debugName=''):
         super(NodePyIO, self).__init__(node_cpp, methods_IO)
-        try:
-            tensor_size = node_cpp.type().sizes()
-        except RuntimeError:
-            tensor_size = [1, ]  # fails when a constant model is used.
-        self.tensor_size = tensor_size
+        self.tensor_size = []  # tensor_size
         # Kind attribute string is purely descriptive and will be shown
         # in detailed information for the node in TensorBoard's graph plugin.
-        #
-        # NodePyOP nodes get this from their kind() method.
+        self.debugName = debugName
         self.kind = 'Parameter'
         if input_or_output:
             self.input_or_output = input_or_output
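
Note on NodePy.__init__ above: it is a small reflection loop. Each whitelisted method name is invoked on the wrapped C++ node and the result is stored as a same-named Python attribute. The idea in isolation, as a generic sketch (the class name is illustrative, not from the library):

    class MethodMirror(object):
        # Call each zero-argument method named in method_names on obj and
        # store the result under an attribute of the same name.
        def __init__(self, obj, method_names):
            for name in method_names:
                setattr(self, name, getattr(obj, name)())

With obj a torch._C.Node and method_names such as ['kind', 'scopeName'], this yields plain Python attributes that are cheap to inspect, which is what NodePy does before special-casing 'inputs' and 'outputs'.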
@@ -217,28 +213,45 @@ def parse(graph, args=None, omit_useless_nodes=True):
     import torch
     n_inputs = len(args)  # not sure...
 
+    inputnodes = list(graph.inputs())
+    global backward_compat_mode
+    if not backward_compat_mode:
+        try:
+            inputnodes[0].debugName()
+        except:
+            backward_compat_mode = True
+
     nodes_py = GraphPy()
-    for i, node in enumerate(graph.inputs()):
-        global backward_mode
-        if not backward_mode:
-            try:
-                node.debugName()
-            except:
-                backward_mode = True
-        if omit_useless_nodes:
-            if len(node.uses()) == 0:  # number of users of the node (= number of outputs / fanout)
-                continue
-
-        if i < n_inputs:
-            nodes_py.append(NodePyIO(node, 'input'))
-        else:
-            nodes_py.append(NodePyIO(node))  # parameter
+    for node in graph.inputs():
+        if node.debugName() == 'self':
+            continue
+        nodes_py.append(NodePyIO(node, input_or_output='Input', debugName=node.debugName()))
 
     for node in graph.nodes():
+        # These nodes refer to parameters such as kernel size, stride, etc.
+        # The graph would be very cluttered if we included them all, so skip them.
+        # P.S. These Constants are composed by 'prim::ListConstruct' and then
+        # fed to common ops such as MaxPool, Conv, Linear.
+        # We could let the user pass a verbosity value to decide how detailed the graph is.
+        if node.kind() == 'prim::Constant':
+            continue
+
+        # By observation, prim::GetAttr nodes are parameter related. ClassType is used to decorate their scope.
+        if node.kind() == 'prim::GetAttr':
+            assert node.scopeName() == ''
+
+            # Since `populate_namespace_from_OP_to_IO` is already available, we just ignore this node.
+            # TODO: will this still work for shared parameters?
+            if " : ClassType" in node.__repr__():
+                continue
+
+            nodes_py.append(NodePyIO(node, debugName=list(node.outputs())[0].debugName()))
+            continue
+
         nodes_py.append(NodePyOP(node))
 
     for node in graph.outputs():  # must be placed last.
         NodePyIO(node, 'output')
     nodes_py.find_common_root()
     nodes_py.populate_namespace_from_OP_to_IO()
     return nodes_py.to_proto()
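
The rewritten loops reflect how traced graphs look on PyTorch 1.3: the module object itself shows up as a 'self' input, and parameters arrive through prim::GetAttr nodes instead of extra graph inputs. A rough way to observe this, assuming PyTorch >= 1.3 (details vary between versions):

    import torch

    model = torch.nn.Conv2d(3, 8, kernel_size=3)
    graph = torch.jit.trace(model, torch.zeros(1, 3, 8, 8)).graph

    # Expect the module object ('self') along with the actual tensor input.
    print([v.debugName() for v in graph.inputs()])

    for node in graph.nodes():
        if node.kind() == 'prim::Constant':
            continue  # literals such as stride/padding; parse() skips these too
        if node.kind() == 'prim::GetAttr':
            # parameter fetches, e.g. the Conv2d's weight and bias
            print('param:', list(node.outputs())[0].debugName())
            continue
        print('op:', node.kind())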
@@ -260,7 +273,10 @@ def graph(model, args, verbose=False, **kwargs):
     with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx
         try:
             trace = torch.jit.trace(model, args)
-            graph = trace.graph
+            if type(trace) == torch.jit.ScriptModule:
+                graph = trace.forward_impl.graph
+            else:
+                graph = trace.graph
 
         except RuntimeError as e:
             print(e)
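
The type check above exists because on PyTorch 1.3 torch.jit.trace can hand back a ScriptModule rather than a bare trace object. The same dispatch as a standalone helper (forward_impl is taken from the diff above; older versions expose .graph on the trace directly):

    import torch

    def traced_graph(model, args):
        # Mirrors the branch in graph() above: PyTorch 1.3 stores the traced
        # graph behind forward_impl, while older tracers expose .graph.
        trace = torch.jit.trace(model, args)
        if type(trace) == torch.jit.ScriptModule:
            return trace.forward_impl.graph
        return trace.graph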