Factor up LM base
dpressel committed Oct 22, 2018
Parent: caf9910. Commit: 6175cef.
Showing 4 changed files with 115 additions and 68 deletions.
python/baseline/dy/lm/model.py: 27 additions & 11 deletions
@@ -2,22 +2,24 @@
 from baseline.dy.dynety import *


-@register_model(task='lm', name='default')
-class BasicLanguageModel(DynetModel, LanguageModel):
+class LanguageModelBase(DynetModel, LanguageModel):

     @classmethod
     def create(cls, embeddings, **kwargs):
         return cls(embeddings, **kwargs)

     def __init__(self, embeddings, layers=1, hsz=650, dropout=None, **kwargs):
-        super(BasicLanguageModel, self).__init__(kwargs['pc'])
+        super(LanguageModelBase, self).__init__(kwargs['pc'])
         self.tgt_key = kwargs.get('tgt_key')
         vsz = embeddings[self.tgt_key].vsz
         dsz = self.init_embed(embeddings)
-        self._rnn = dy.VanillaLSTMBuilder(layers, dsz, hsz, self.pc)
+        self.init_decode(dsz, layers, hsz, **kwargs)
         self._output = Linear(vsz, hsz, self.pc, name="output")
         self.dropout = dropout

+    def init_decode(self, dsz, layers=1, hsz=650, **kwargs):
+        pass
+
     def init_embed(self, embeddings):
         dsz = 0
         self.embeddings = embeddings
@@ -46,13 +48,7 @@ def output(self, input_):
         return [self._output(x) for x in input_]

     def decode(self, input_, state, train):
-        if train:
-            if self.dropout is not None:
-                self._rnn.set_dropout(self.dropout)
-            else:
-                self._rnn.disable_dropout()
-        transduced, last_state = rnn_forward_with_state(self._rnn, input_, None, state)
-        return transduced, last_state
+        pass

     def forward(self, input_, state=None, train=True):
         input_ = self.embed(input_)
@@ -67,3 +63,23 @@ def save(self, file_name):
     def load(self, file_name):
         self.pc.populate(file_name)
         return self
+
+
+@register_model(task='lm', name='default')
+class RNNLanguageModel(LanguageModelBase):
+
+    def __init__(self, embeddings, layers=1, hsz=650, dropout=None, **kwargs):
+        self._rnn = None
+        super(RNNLanguageModel, self).__init__(embeddings, layers, hsz, dropout, **kwargs)
+
+    def init_decode(self, dsz, layers=1, hsz=650, **kwargs):
+        self._rnn = dy.VanillaLSTMBuilder(layers, dsz, hsz, self.pc)
+
+    def decode(self, input_, state, train):
+        if train:
+            if self.dropout is not None:
+                self._rnn.set_dropout(self.dropout)
+            else:
+                self._rnn.disable_dropout()
+        transduced, last_state = rnn_forward_with_state(self._rnn, input_, None, state)
+        return transduced, last_state
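
Note (illustrative sketch, not part of this commit): with the LSTM specifics pushed down into RNNLanguageModel, an alternative decoder only has to subclass LanguageModelBase, override init_decode/decode, and register under its own name. The GRULanguageModel class, the 'gru' registry name, and the import paths below are assumptions, and DyNet's GRUBuilder stands in for VanillaLSTMBuilder.

import dynet as dy

from baseline.model import register_model
from baseline.dy.dynety import rnn_forward_with_state
from baseline.dy.lm.model import LanguageModelBase


@register_model(task='lm', name='gru')
class GRULanguageModel(LanguageModelBase):

    def __init__(self, embeddings, layers=1, hsz=650, dropout=None, **kwargs):
        self._rnn = None
        super(GRULanguageModel, self).__init__(embeddings, layers, hsz, dropout, **kwargs)

    def init_decode(self, dsz, layers=1, hsz=650, **kwargs):
        # Same hook the LSTM model fills in, but with a GRU builder
        self._rnn = dy.GRUBuilder(layers, dsz, hsz, self.pc)

    def decode(self, input_, state, train):
        # Only apply dropout at training time
        if train and self.dropout is not None:
            self._rnn.set_dropout(self.dropout)
        else:
            self._rnn.disable_dropout()
        return rnn_forward_with_state(self._rnn, input_, None, state)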
python/baseline/model.py: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def create_model_for(activity, input_, output_, **kwargs):
         return creator_fn(input_, output_, **kwargs)
     return creator_fn(input_, **kwargs)

+
 @exporter
 def create_model(embeddings, labels, **kwargs):
     return create_model_for('classify', embeddings, labels, **kwargs)
python/baseline/pytorch/lm/model.py: 34 additions & 22 deletions
@@ -5,19 +5,18 @@
 import math


-@register_model(task='lm', name='default')
-class BasicLanguageModel(nn.Module, LanguageModel):
+class LanguageModelBase(nn.Module, LanguageModel):
     def __init__(self):
-        super(BasicLanguageModel, self).__init__()
+        super(LanguageModelBase, self).__init__()

     def save(self, outname):
         torch.save(self, outname)

     def create_loss(self):
         return SequenceCriterion(LossFn=nn.CrossEntropyLoss)

-    @staticmethod
-    def load(filename, **kwargs):
+    @classmethod
+    def load(cls, filename, **kwargs):
         if not os.path.exists(filename):
             filename += '.pyt'
         model = torch.load(filename)
@@ -61,25 +60,10 @@ def init_embed(self, embeddings, **kwargs):
         return input_sz

     def init_decode(self, vsz, **kwargs):
-        pdrop = float(kwargs.get('dropout', 0.5))
-        vdrop = bool(kwargs.get('variational_dropout', False))
-        unif = float(kwargs.get('unif', 0.0))
-        if vdrop:
-            self.rnn_dropout = VariationalDropout(pdrop)
-        else:
-            self.rnn_dropout = nn.Dropout(pdrop)
-
-        self.rnn = pytorch_lstm(self.dsz, self.hsz, 'lstm', self.layers, pdrop, batch_first=True)
-        self.decoder = nn.Sequential()
-        append2seq(self.decoder, (
-            pytorch_linear(self.hsz, vsz, unif),
-        ))
+        pass

     def decode(self, emb, hidden):
-        output, hidden = self.rnn(emb, hidden)
-        output = self.rnn_dropout(output).contiguous()
-        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
-        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
+        pass

     @classmethod
     def create(cls, embeddings, **kwargs):
@@ -99,3 +83,31 @@ def create(cls, embeddings, **kwargs):
     def forward(self, input, hidden):
         emb = self.embed(input)
         return self.decode(emb, hidden)
+
+
+@register_model(task='lm', name='default')
+class RNNLanguageModel(LanguageModelBase):
+
+    def __init__(self):
+        super(RNNLanguageModel, self).__init__()
+
+    def init_decode(self, vsz, **kwargs):
+        pdrop = float(kwargs.get('dropout', 0.5))
+        vdrop = bool(kwargs.get('variational_dropout', False))
+        unif = float(kwargs.get('unif', 0.0))
+        if vdrop:
+            self.rnn_dropout = VariationalDropout(pdrop)
+        else:
+            self.rnn_dropout = nn.Dropout(pdrop)
+
+        self.rnn = pytorch_lstm(self.dsz, self.hsz, 'lstm', self.layers, pdrop, batch_first=True)
+        self.decoder = nn.Sequential()
+        append2seq(self.decoder, (
+            pytorch_linear(self.hsz, vsz, unif),
+        ))
+
+    def decode(self, emb, hidden):
+        output, hidden = self.rnn(emb, hidden)
+        output = self.rnn_dropout(output).contiguous()
+        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
+        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
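
Note (illustrative sketch, not part of this commit): the PyTorch side follows the same pattern, so a variant model again only overrides init_decode/decode. The sketch below uses plain torch modules rather than the baseline helpers; the import paths and the 'gru' registry name are assumptions, and hidden-state handling is simplified (nn.GRU takes a single hidden tensor instead of the LSTM's (h, c) tuple, so the trainer's hidden-state initialization would need to match).

import torch.nn as nn

from baseline.model import register_model
from baseline.pytorch.lm.model import LanguageModelBase


@register_model(task='lm', name='gru')
class GRULanguageModel(LanguageModelBase):

    def __init__(self):
        super(GRULanguageModel, self).__init__()

    def init_decode(self, vsz, **kwargs):
        pdrop = float(kwargs.get('dropout', 0.5))
        # self.dsz, self.hsz and self.layers are set upstream by create()/init_embed()
        self.rnn = nn.GRU(self.dsz, self.hsz, self.layers, dropout=pdrop, batch_first=True)
        self.rnn_dropout = nn.Dropout(pdrop)
        self.decoder = nn.Linear(self.hsz, vsz)

    def decode(self, emb, hidden):
        output, hidden = self.rnn(emb, hidden)
        output = self.rnn_dropout(output).contiguous()
        # Project every timestep to the vocabulary, then restore (batch, time, vocab)
        decoded = self.decoder(output.view(-1, output.size(2)))
        return decoded.view(output.size(0), output.size(1), -1), hidden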
python/baseline/tf/lm/model.py: 53 additions & 35 deletions
@@ -6,43 +6,20 @@
 from google.protobuf import text_format


-@register_model(task='lm', name='default')
-class BasicLanguageModel(LanguageModel):
+class LanguageModelBase(LanguageModel):

     def __init__(self):
-        self.layers = None
-        self.hsz = None
-        self.rnntype = 'lstm'
         self.pkeep = None
         self.saver = None
-
-    @property
-    def vdrop(self):
-        return self._vdrop
-
-    @vdrop.setter
-    def vdrop(self, value):
-        self._vdrop = value
+        self.layers = None
+        self.hsz = None
+        self.probs = None

     def save_using(self, saver):
         self.saver = saver

     def decode(self, inputs, vsz):
-
-        def cell():
-            return lstm_cell_w_dropout(self.hsz, self.pkeep, variational=self.vdrop)
-
-        rnnfwd = tf.contrib.rnn.MultiRNNCell([cell() for _ in range(self.layers)], state_is_tuple=True)
-        self.initial_state = rnnfwd.zero_state(self.batchsz, tf.float32)
-        rnnout, state = tf.nn.dynamic_rnn(rnnfwd, inputs, initial_state=self.initial_state, dtype=tf.float32)
-        output = tf.reshape(tf.concat(rnnout, 1), [-1, self.hsz])
-        vocab_w = tf.get_variable(
-            "vocab_w", [self.hsz, vsz], dtype=tf.float32)
-        vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
-
-        self.logits = tf.nn.xw_plus_b(output, vocab_w, vocab_b, name="logits")
-        self.probs = tf.nn.softmax(self.logits, name="softmax")
-        self.final_state = state
+        pass

     def save_values(self, basename):
         self.saver.save(self.sess, basename)
@@ -57,7 +34,6 @@ def save_md(self, basename):
         base = path[-1]
         outdir = '/'.join(path[:-1])

-        #state = {'hsz': self.hsz, 'batchsz': self.batchsz, 'layers': self.layers}
         embeddings_info = {}
         for k, v in self.embeddings.items():
             embeddings_info[k] = v.__class__.__name__
@@ -118,8 +94,6 @@ def create(cls, embeddings, **kwargs):
         lm.y = kwargs.get('y', tf.placeholder(tf.int32, [None, None], name="y"))
         lm.batchsz = kwargs['batchsz']
         lm.sess = kwargs.get('sess', tf.Session())
-        lm.rnntype = kwargs.get('rnntype', 'lstm')
-        lm.vdrop = kwargs.get('variational_dropout', False)
         lm.pkeep = kwargs.get('pkeep', tf.placeholder(tf.float32, name="pkeep"))
         pdrop = kwargs.get('pdrop', 0.5)
         lm.pdrop_value = pdrop
@@ -134,7 +108,7 @@ def create(cls, embeddings, **kwargs):

         inputs = lm.embed()
         lm.layers = kwargs.get('layers', 1)
-        lm.decode(inputs, embeddings[lm.tgt_key].vsz)
+        lm.decode(inputs, embeddings[lm.tgt_key].vsz, **kwargs)
         return lm

     def embed(self):
@@ -150,8 +124,8 @@ def embed(self):
         word_embeddings = tf.concat(values=all_embeddings_out, axis=2)
         return tf.nn.dropout(word_embeddings, self.pkeep)

-    @staticmethod
-    def load(basename, **kwargs):
+    @classmethod
+    def load(cls, basename, **kwargs):
         state = read_json(basename + '.state')
         if 'predict' in kwargs:
             state['predict'] = kwargs['predict']
Expand Down Expand Up @@ -182,7 +156,7 @@ def load(basename, **kwargs):
Constructor = eval(class_name)
embeddings[key] = Constructor(key, **embed_args)

model = BasicLanguageModel.create(embeddings, **state)
model = cls.create(embeddings, **state)
for prop in ls_props(model):
if prop in state:
setattr(model, prop, state[prop])
@@ -195,3 +169,47 @@ def load(basename, **kwargs):
         model.saver = tf.train.Saver()
         model.saver.restore(model.sess, basename)
         return model
+
+
+@register_model(task='lm', name='default')
+class RNNLanguageModel(LanguageModelBase):
+    def __init__(self):
+        super(RNNLanguageModel, self).__init__()
+        self.rnntype = 'lstm'
+
+    @property
+    def vdrop(self):
+        return self._vdrop
+
+    @vdrop.setter
+    def vdrop(self, value):
+        self._vdrop = value
+
+    def decode(self, inputs, vsz, **kwargs):
+
+        def cell():
+            return lstm_cell_w_dropout(self.hsz, self.pkeep, variational=self.vdrop)
+
+        self.rnntype = kwargs.get('rnntype', 'lstm')
+        self.vdrop = kwargs.get('variational_dropout', False)
+
+        rnnfwd = tf.contrib.rnn.MultiRNNCell([cell() for _ in range(self.layers)], state_is_tuple=True)
+        self.initial_state = rnnfwd.zero_state(self.batchsz, tf.float32)
+        rnnout, state = tf.nn.dynamic_rnn(rnnfwd, inputs, initial_state=self.initial_state, dtype=tf.float32)
+        h = tf.reshape(tf.concat(rnnout, 1), [-1, self.hsz])
+        self.logits = self.output(h, vsz)
+        self.probs = tf.nn.softmax(self.logits, name="softmax")
+        self.final_state = state
+
+    def output(self, h, vsz):
+        # Do weight sharing if we can
+        if self.hsz == self.embeddings[self.tgt_key].get_dsz():
+            with tf.variable_scope(self.embeddings[self.tgt_key].scope, reuse=True):
+                W = tf.get_variable("W")
+            return tf.matmul(h, W, transpose_b=True)
+        else:
+            vocab_w = tf.get_variable(
+                "vocab_w", [self.hsz, vsz], dtype=tf.float32)
+            vocab_b = tf.get_variable("vocab_b", [vsz], dtype=tf.float32)
+
+            return tf.nn.xw_plus_b(h, vocab_w, vocab_b, name="logits")
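
Note (illustrative sketch, not part of this commit): the new output() method ties the softmax projection to the target embedding matrix whenever hsz matches the embedding size (weight tying in the sense of Press & Wolf, 2016), and only falls back to a separate vocab_w/vocab_b projection otherwise. A quick numpy stand-in for the shapes in the tied branch (sizes made up):

import numpy as np

batchsz, seqlen, hsz, vsz = 2, 5, 650, 10000
dsz = hsz                                   # tying only applies when dsz == hsz

W = np.random.randn(vsz, dsz)               # embedding lookup table, one row per word
h = np.random.randn(batchsz * seqlen, hsz)  # flattened RNN outputs

logits = h @ W.T                            # same as tf.matmul(h, W, transpose_b=True)
assert logits.shape == (batchsz * seqlen, vsz)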
