Skip to content

Commit

Permalink
Merge pull request #163 from danielhers/birnn_save
Browse files Browse the repository at this point in the history
Fix #161: support save/load BiRNNBuilder
  • Loading branch information
yoavg committed Nov 17, 2016
2 parents d5320ef + 3fa9d52 commit 587059d
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions python/dynet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,14 @@ cdef class Model: # {{{
elif isinstance(c, SimpleRNNBuilder):
saver.add_srnn_builder((<CSimpleRNNBuilder*>(<SimpleRNNBuilder>c).thisptr)[0])
fh.write("srnn_builder ")
elif isinstance(c, BiRNNBuilder):
fh.write("birnn_builder~%d " % (2 * len(c.builder_layers)))
for (f,b) in c.builder_layers:
self._save_one(f,saver,fh,pfh)
self._save_one(b,saver,fh,pfh)
elif isinstance(c, Saveable):
cs = c.get_components()
fh.write("user~%s " % len(cs))
fh.write("user~%d " % len(cs))
pickle.dump(c,pfh)
for subc in cs:
self._save_one(subc,saver,fh,pfh)
Expand Down Expand Up @@ -362,6 +367,11 @@ cdef class Model: # {{{
sb_ = SimpleRNNBuilder(0,0,0,self) # empty builder
loader.fill_srnn_builder((<CSimpleRNNBuilder *>sb_.thisptr)[0])
return sb_
elif tp.startswith("birnn_builder~"):
tp,num = tp.split("~",1)
num = int(num)
items = [self._load_one(itypes, loader, pfh) for _ in xrange(num)]
return BiRNNBuilder(None, None, None, None, None, zip(items[0::2], items[1::2]))
elif tp.startswith("user~"):
# user defiend type
tp,num = tp.split("~",1)
Expand Down Expand Up @@ -1126,24 +1136,28 @@ class BiRNNBuilder(object):
builder = BiRNNBuilder(1, 128, 100, model, LSTMBuilder)
[o1,o2,o3] = builder.transduce([i1,i2,i3])
"""
def __init__(self, num_layers, input_dim, hidden_dim, model, rnn_builder_factory):
def __init__(self, num_layers, input_dim, hidden_dim, model, rnn_builder_factory, builder_layers=None):
"""
@param num_layers: depth of the BiRNN
@param input_dim: size of the inputs
@param hidden_dim: size of the outputs (and intermediate layer representations)
@param model
@param rnn_builder_factory: RNNBuilder subclass, e.g. LSTMBuilder
@param builder_layers: list of (forward, backward) pairs of RNNBuilder instances to directly initialize layers
"""
assert num_layers > 0
assert hidden_dim % 2 == 0
self.builder_layers = []
f = rnn_builder_factory(1, input_dim, hidden_dim/2, model)
b = rnn_builder_factory(1, input_dim, hidden_dim/2, model)
self.builder_layers.append((f,b))
for _ in xrange(num_layers-1):
f = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model)
b = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model)
if builder_layers is None:
assert num_layers > 0
assert hidden_dim % 2 == 0
self.builder_layers = []
f = rnn_builder_factory(1, input_dim, hidden_dim/2, model)
b = rnn_builder_factory(1, input_dim, hidden_dim/2, model)
self.builder_layers.append((f,b))
for _ in xrange(num_layers-1):
f = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model)
b = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model)
self.builder_layers.append((f,b))
else:
self.builder_layers = builder_layers

def whoami(self): return "BiRNNBuilder"

Expand Down

0 comments on commit 587059d

Please sign in to comment.