Skip to content

Commit

Permalink
update doc for add_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Jan 4, 2018
1 parent e951aef commit 6d879e6
Showing 1 changed file with 4 additions and 16 deletions.
20 changes: 4 additions & 16 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,40 +376,28 @@ def add_text(self, tag, text_string, global_step=None):
def add_graph_onnx(self, prototxt):
self.file_writer.add_graph_onnx(gg(prototxt))

def add_graph(self, model, lastVar):
def add_graph(self, model, input_to_model):
# prohibit second call?
# no, let tensorboard handles it and show its warning message.
"""Add graph data to summary.
To draw the graph, you need a model ``m`` and an input variable ``t`` that have correct size for ``m``.
Say you have runned ``r = m(t)``, then you can use ``writer.add_graph(m, r)`` to save the graph.
By default, the input tensor does not require gradient, therefore it will be omitted when back tracing.
To draw the input node, pass an additional parameter ``requires_grad=True`` when creating the input tensor.
Args:
model (torch.nn.Module): model to draw.
lastVar (torch.autograd.Variable): the root node start from.
.. note::
This is experimental feature. Graph drawing is based on autograd's backward tracing.
It goes along the ``next_functions`` attribute in a variable recursively, drawing each encountered nodes.
In some cases, the result is strange. See https://github.com/lanpa/tensorboard-pytorch/issues/7 and
https://github.com/lanpa/tensorboard-pytorch/issues/9
input_to_model (torch.autograd.Variable): a variable or a tuple of variables to be fed.
The implementation will be based to onnx backend as soon as onnx is stable enough.
"""
import torch
from distutils.version import LooseVersion
if LooseVersion(torch.__version__) >= LooseVersion("0.4"):
pass
else:
if LooseVersion(torch.__version__) >= LooseVersion("0.3"):
print('add_graph() only supports PyTorch v0.2. For PyTorch>=0.3, use add_graph_onnx()')
print('You are using PyTorch==0.3, use add_graph_onnx()')
return
if not hasattr(torch.autograd.Variable, 'grad_fn'):
print('add_graph() only supports PyTorch v0.2.')
return
self.file_writer.add_graph(graph(model, lastVar))
self.file_writer.add_graph(graph(model, input_to_model))

def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, tag='default'):
"""Add embedding projector data to summary.
Expand Down

0 comments on commit 6d879e6

Please sign in to comment.