Skip to content

Commit

Permalink
handles unused module in a graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Jan 5, 2018
1 parent 6d879e6 commit a01dec2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion tensorboardX/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,21 @@ def parse(graph):

for n in graph.inputs():
uname = n.uniqueName()
if uname not in scope.keys():
scope[uname] = 'unused'
nodes.append({'name': replace(uname, scope), 'op': 'Parameter', 'inputs': [], 'attr': str(n.type())})

return nodes


def graph(model, args):
def graph(model, args, verbose=False):
import torch
with torch.onnx.set_training(model, False):
trace, _ = torch.jit.trace(model, args)
torch.onnx._optimize_trace(trace, False)
graph = trace.graph()
if verbose:
print(graph)
list_of_nodes = parse(graph)
nodes = []
for node in list_of_nodes:
Expand Down
4 changes: 2 additions & 2 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def add_text(self, tag, text_string, global_step=None):
def add_graph_onnx(self, prototxt):
self.file_writer.add_graph_onnx(gg(prototxt))

def add_graph(self, model, input_to_model):
def add_graph(self, model, input_to_model, verbose=False):
# prohibit second call?
# no, let tensorboard handles it and show its warning message.
"""Add graph data to summary.
Expand All @@ -397,7 +397,7 @@ def add_graph(self, model, input_to_model):
if not hasattr(torch.autograd.Variable, 'grad_fn'):
print('add_graph() only supports PyTorch v0.2.')
return
self.file_writer.add_graph(graph(model, input_to_model))
self.file_writer.add_graph(graph(model, input_to_model, verbose))

def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default'):
"""Add embedding projector data to summary.
Expand Down

0 comments on commit a01dec2

Please sign in to comment.