Skip to content

Commit

Permalink
improves error handling, adds networks in demo
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed May 8, 2018
1 parent ee35e97 commit 989ce08
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 3 deletions.
59 changes: 59 additions & 0 deletions demo_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,62 @@ def forward(self, x):
with SummaryWriter(comment='resnet18') as w:
model = torchvision.models.resnet18()
w.add_graph(model, (dummy_input, ))



class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x*2


model = SimpleModel()
dummy_input = (torch.zeros(1, 2, 3),)

with SummaryWriter(comment='constantModel') as w:
w.add_graph(model, dummy_input)


def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)

class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
# self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = F.relu(out)
return out


dummy_input = torch.rand(1, 3, 224, 224)

with SummaryWriter(comment='basicblock') as w:
model = BasicBlock(3,3)
w.add_graph(model, (dummy_input, ))#, verbose=True)

import pytest
with pytest.raises(Exception) as e_info:
dummy_input = torch.rand(1, 1, 224, 224)
with SummaryWriter(comment='basicblock_error') as w:
w.add_graph(model, (dummy_input, )) # error

8 changes: 5 additions & 3 deletions tensorboardX/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@ def graph(model, args, verbose=False):
try:
trace, _ = torch.jit.get_trace_graph(model, args)
except RuntimeError:
print("Error occurs, checking if it's onnx problem...")
print('Error occurs, No graph saved')
_ = model(args) # don't catch, just print the error message
print("Checking if it's onnx problem...")
try:
torch.onnx.export(model, args, "/tmp/dummy.pb", verbose=True)
import tempfile
torch.onnx.export(model, args, tempfile.TemporaryFile(), verbose=True)
except RuntimeError:
print("Your model fails onnx too, please report to onnx team")
print('No graph saved')
return GraphDef(versions=VersionDef(producer=22))
if LooseVersion(torch.__version__) >= LooseVersion("0.4"):
torch.onnx._optimize_trace(trace, False)
Expand Down

0 comments on commit 989ce08

Please sign in to comment.