Skip to content

Commit

Permalink
supports rnn graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Apr 30, 2018
1 parent 5619f0a commit 87eee66
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
43 changes: 43 additions & 0 deletions demo_rnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
from tensorboardX import SummaryWriter

class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
self.o2o = nn.Linear(hidden_size + output_size, output_size)
self.dropout = nn.Dropout(0.1)
self.softmax = nn.LogSoftmax(dim=1)

def forward(self, category, input, hidden):
input_combined = torch.cat((category, input, hidden), 1)
hidden = self.i2h(input_combined)
output = self.i2o(input_combined)
output_combined = torch.cat((hidden, output), 1)
output = self.o2o(output_combined)
output = self.dropout(output)
output = self.softmax(output)
return output, hidden

def initHidden(self):
return torch.zeros(1, self.hidden_size)


n_letters = 100
n_hidden = 128
n_categories = 10
rnn = RNN(n_letters, n_hidden, n_categories)
cat = torch.Tensor(1, n_categories)
dummy_input = torch.Tensor(1, n_letters)
hidden = torch.Tensor(1, n_hidden)


out, hidden = rnn(cat, dummy_input, hidden)
with SummaryWriter(comment='RNN') as w:
w.add_graph(rnn, (cat, dummy_input, hidden), verbose=True)
9 changes: 8 additions & 1 deletion tensorboardX/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def parse(graph):
for n in graph.nodes():
inputs = [i.uniqueName() for i in n.inputs()]
for i in range(1, len(inputs)):
scope[inputs[i]] = n.scopeName()
if inputs[i] not in scope.keys():
scope[inputs[i]] = n.scopeName()

uname = next(iter(n.outputs())).uniqueName()
assert n.scopeName() != '', '{} has empty scope name'.format(n)
Expand All @@ -24,6 +25,12 @@ def parse(graph):
scope['1'] = 'input'

nodes = []

for count, n in enumerate(graph.outputs()):
uname = 'output' + str(count)
scope[uname] = 'output'
nodes.append({'name': uname, 'op': 'output', 'inputs': [n.uniqueName()], 'attr': 'output'})

for n in graph.nodes():
attrs = {k: n[k] for k in n.attributeNames()}
attrs = str(attrs).replace("'", ' ') # singlequote will be escaped by tensorboard
Expand Down

0 comments on commit 87eee66

Please sign in to comment.