
Commit

Fixes Conchylicultor#32 - Embeddings were not disabled when restoring model

When initEmbeddings is used, the embeddings are removed from the
training variables. However, this was not done when restoring a
model from a checkpoint.
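For context on the mechanism the commit relies on: in TensorFlow 1.x an optimizer's minimize() only updates variables that are in the TRAINABLE_VARIABLES collection, so removing a variable from that collection effectively freezes it. The sketch below is illustrative only; the variable names and shapes are made up and are not taken from the project.

import tensorflow as tf

# Illustrative graph; shapes and names are placeholders
embedding = tf.get_variable("embedding", shape=[1000, 64])
proj = tf.get_variable("proj", shape=[64, 1000])

# Removing the embedding from the trainable collection freezes it:
# minimize() defaults to var_list=tf.trainable_variables().
trainables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
trainables.remove(embedding)

loss = tf.reduce_sum(tf.matmul(embedding, proj))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)  # updates proj only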
eschnou committed Dec 21, 2016
1 parent 88fee0e commit 7775205
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions chatbot/chatbot.py
@@ -185,10 +185,8 @@ def main(self, args=None):
         if self.args.test != Chatbot.TestMode.ALL:
             self.managePreviousModel(self.sess)
 
-        # Initialize embeddings with pre-trained word2vec vectors unless we are opening
-        # a restored model, in which case the embeddings were saved as part of the
-        # checkpoint.
-        if self.args.initEmbeddings and self.globStep == 0:
+        # Initialize embeddings with pre-trained word2vec vectors
+        if self.args.initEmbeddings:
             print("Loading pre-trained embeddings from GoogleNews-vectors-negative300.bin")
             self.loadEmbedding(self.sess)
 
@@ -375,8 +373,28 @@ def daemonClose(self):
     def loadEmbedding(self, sess):
         """ Initialize embeddings with pre-trained word2vec vectors
         Will modify the embedding weights of the current loaded model
-        Uses the GoogleNews pre-trained values (path hardcoded)
+        Uses the GoogleNews pre-trained values (path hardcoded).
+        No need to initialize the embeddings if we are restoring a model;
+        however, we must still disable their training (this is part of the model
+        and not of the checkpoint state)
         """
+
+        # Fetch embedding variables from model
+        with tf.variable_scope("embedding_rnn_seq2seq/RNN/EmbeddingWrapper", reuse=True):
+            em_in = tf.get_variable("embedding")
+        with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
+            em_out = tf.get_variable("embedding")
+
+        # Disable training for embeddings
+        variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
+        variables.remove(em_in)
+        variables.remove(em_out)
+
+        # If restoring a model, we can leave here
+        if self.globStep != 0:
+            return
+
+        # New model: load the pre-trained word2vec data and initialize the embeddings
         with open(os.path.join(self.args.rootDir, 'data/word2vec/GoogleNews-vectors-negative300.bin'), "rb", 0) as f:
             header = f.readline()
             vocab_size, vector_size = map(int, header.split())
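A note on why this hunk works: restoring a checkpoint with tf.train.Saver brings back variable values, but collection membership is rebuilt from scratch every time the graph is constructed, so the embeddings re-enter TRAINABLE_VARIABLES on each run and must be removed again. A hedged sketch of that point (it assumes the seq2seq graph has already been built, and that sess and checkpoint_path exist; only the scope and variable names come from the hunk above):

# Fetch the existing encoder embedding by reusing its scope, as above
with tf.variable_scope("embedding_rnn_seq2seq/RNN/EmbeddingWrapper", reuse=True):
    em_in = tf.get_variable("embedding")

tf.train.Saver().restore(sess, checkpoint_path)   # values now come from the checkpoint

# Restoring does not touch collections: the variable stays trainable
# until it is removed explicitly, which is what the new code does.
assert em_in in tf.trainable_variables()
tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).remove(em_in)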
@@ -403,20 +421,10 @@ def loadEmbedding(self, sess):
             S[:vector_size, :vector_size] = np.diag(s)
             initW = np.dot(U[:, :self.args.embeddingSize], S[:self.args.embeddingSize, :self.args.embeddingSize])
 
-        # Initialize input embeddings
-        with tf.variable_scope("embedding_rnn_seq2seq/RNN/EmbeddingWrapper", reuse=True):
-            em_in = tf.get_variable("embedding")
-            sess.run(em_in.assign(initW))
+        # Initialize input and output embeddings
+        sess.run(em_in.assign(initW))
+        sess.run(em_out.assign(initW))
 
-        # Initialize output embeddings
-        with tf.variable_scope("embedding_rnn_seq2seq/embedding_rnn_decoder", reuse=True):
-            em_out = tf.get_variable("embedding")
-            sess.run(em_out.assign(initW))
-
-        # Disable training for embeddings
-        variables = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
-        variables.remove(em_in)
-        variables.remove(em_out)
 
     def managePreviousModel(self, sess):
         """ Restore or reset the model, depending of the parameters
