From 3fa9d526e3ec23d22d6a1049988a648704f635aa Mon Sep 17 00:00:00 2001 From: Daniel Hershcovich Date: Sun, 13 Nov 2016 10:19:30 +0200 Subject: [PATCH] Fix #161: support save/load BiRNNBuilder --- python/dynet.pyx | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/python/dynet.pyx b/python/dynet.pyx index 1a48dc730..a4b2c39e6 100644 --- a/python/dynet.pyx +++ b/python/dynet.pyx @@ -312,9 +312,14 @@ cdef class Model: # {{{ elif isinstance(c, SimpleRNNBuilder): saver.add_srnn_builder(((c).thisptr)[0]) fh.write("srnn_builder ") + elif isinstance(c, BiRNNBuilder): + fh.write("birnn_builder~%d " % (2 * len(c.builder_layers))) + for (f,b) in c.builder_layers: + self._save_one(f,saver,fh,pfh) + self._save_one(b,saver,fh,pfh) elif isinstance(c, Saveable): cs = c.get_components() - fh.write("user~%s " % len(cs)) + fh.write("user~%d " % len(cs)) pickle.dump(c,pfh) for subc in cs: self._save_one(subc,saver,fh,pfh) @@ -362,6 +367,11 @@ cdef class Model: # {{{ sb_ = SimpleRNNBuilder(0,0,0,self) # empty builder loader.fill_srnn_builder((sb_.thisptr)[0]) return sb_ + elif tp.startswith("birnn_builder~"): + tp,num = tp.split("~",1) + num = int(num) + items = [self._load_one(itypes, loader, pfh) for _ in xrange(num)] + return BiRNNBuilder(None, None, None, None, None, zip(items[0::2], items[1::2])) elif tp.startswith("user~"): # user defiend type tp,num = tp.split("~",1) @@ -1126,24 +1136,28 @@ class BiRNNBuilder(object): builder = BiRNNBuilder(1, 128, 100, model, LSTMBuilder) [o1,o2,o3] = builder.transduce([i1,i2,i3]) """ - def __init__(self, num_layers, input_dim, hidden_dim, model, rnn_builder_factory): + def __init__(self, num_layers, input_dim, hidden_dim, model, rnn_builder_factory, builder_layers=None): """ @param num_layers: depth of the BiRNN @param input_dim: size of the inputs @param hidden_dim: size of the outputs (and intermediate layer representations) @param model @param rnn_builder_factory: RNNBuilder subclass, e.g. LSTMBuilder + @param builder_layers: list of (forward, backward) pairs of RNNBuilder instances to directly initialize layers """ - assert num_layers > 0 - assert hidden_dim % 2 == 0 - self.builder_layers = [] - f = rnn_builder_factory(1, input_dim, hidden_dim/2, model) - b = rnn_builder_factory(1, input_dim, hidden_dim/2, model) - self.builder_layers.append((f,b)) - for _ in xrange(num_layers-1): - f = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model) - b = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model) + if builder_layers is None: + assert num_layers > 0 + assert hidden_dim % 2 == 0 + self.builder_layers = [] + f = rnn_builder_factory(1, input_dim, hidden_dim/2, model) + b = rnn_builder_factory(1, input_dim, hidden_dim/2, model) self.builder_layers.append((f,b)) + for _ in xrange(num_layers-1): + f = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model) + b = rnn_builder_factory(1, hidden_dim, hidden_dim/2, model) + self.builder_layers.append((f,b)) + else: + self.builder_layers = builder_layers def whoami(self): return "BiRNNBuilder"