Skip to content

Commit

Permalink
Fixes tensorflow#110: Corrected neural_translation_word example to ru…
Browse files Browse the repository at this point in the history
…n and train a translation model.
  • Loading branch information
ilblackdragon committed Feb 25, 2016
1 parent a28882f commit d16868d
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 33 deletions.
56 changes: 38 additions & 18 deletions examples/neural_translation_word.py
Expand Up @@ -16,10 +16,12 @@

from __future__ import division, print_function, absolute_import

import cPickle
import itertools
import os
import numpy as np
import random

import numpy as np
import tensorflow as tf

import skflow
Expand Down Expand Up @@ -89,20 +91,29 @@ def split_lines(data):

MAX_DOCUMENT_LENGTH = 10

X_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
y_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
Xtrainff, ytrainff = Xy(read_iterator('train.data'))
print('Fitting dictionary for English...')
X_vocab_processor.fit(Xtrainff)
print('Fitting dictionary for French...')
y_vocab_processor.fit(ytrainff)
if not (os.path.exists('en.vocab') and os.path.exists('fr.vocab')):
X_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
y_vocab_processor = skflow.preprocessing.VocabularyProcessor(MAX_DOCUMENT_LENGTH,
min_frequency=5)
Xtrainff, ytrainff = Xy(read_iterator('train.data'))
print('Fitting dictionary for English...')
X_vocab_processor.fit(Xtrainff)
print('Fitting dictionary for French...')
y_vocab_processor.fit(ytrainff)
open('en.vocab', 'w').write(cPickle.dumps(X_vocab_processor))
open('fr.vocab', 'w').write(cPickle.dumps(y_vocab_processor))
else:
X_vocab_processor = cPickle.loads(open('en.vocab').read())
y_vocab_processor = cPickle.loads(open('fr.vocab').read())
print('Transforming...')
X_train = X_vocab_processor.transform(X_train)
y_train = y_vocab_processor.transform(y_train)
X_test = np.array(list(X_vocab_processor.transform(X_test))[:20])
y_test = list(y_test)[:20]
X_test = X_vocab_processor.transform(X_test)

# TODO: Expand this to use the whole test set.
X_test = np.array([X_test.next() for _ in range(1000)])
y_test = [y_test.next() for _ in range(1000)]

n_words = len(X_vocab_processor.vocabulary_)
print('Total words: %d' % n_words)
Expand All @@ -116,32 +127,41 @@ def translate_model(X, y):
word_vectors = skflow.ops.categorical_variable(X, n_classes=n_words,
embedding_size=EMBEDDING_SIZE, name='words')
in_X, in_y, out_y = skflow.ops.seq2seq_inputs(
word_list, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
cell = tf.nn.rnn_cell.OutputProjectionWrapper(tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), 256)
decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y, cell)
word_vectors, y, MAX_DOCUMENT_LENGTH, MAX_DOCUMENT_LENGTH)
encoder_cell = tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
decoder_cell = tf.nn.rnn_cell.OutputProjectionWrapper(
tf.nn.rnn_cell.GRUCell(HIDDEN_SIZE), n_words)
decoding, _, sampling_decoding, _ = skflow.ops.rnn_seq2seq(in_X, in_y,
encoder_cell, decoder_cell=decoder_cell)
return skflow.ops.sequence_classifier(decoding, out_y, sampling_decoding)


PATH = '/tmp/tf_examples/ntm_words/'

if os.path.exists(PATH):
if os.path.exists(os.path.join(PATH, 'graph.pbtxt')):
translator = skflow.TensorFlowEstimator.restore(PATH)
else:
translator = skflow.TensorFlowEstimator(model_fn=translate_model,
n_classes=n_words,
optimizer='Adam', learning_rate=0.01, batch_size=128,
continue_training=True)
continue_training=True, steps=100)

while True:
translator.fit(X_train, y_train, logdir=PATH)
translator.save(PATH)

xpred, ygold = [], []
for _ in range(10):
idx = random.randint(0, len(X_test))
xpred.append(X_test[idx])
ygold.append(y_test[idx])
xpred = np.array(xpred)
predictions = translator.predict(xpred, axis=2)
xpred_inp = X_vocab_processor.reverse(xpred)
text_outputs = y_vocab_processor.reverse(predictions)
for inp_data, input_text, pred, output_text, gold in zip(xpred, xpred_inp,
predictions, text_outputs, ygold):
print('English: %s. French (pred): %s, French (gold): %s' %
(input_text, output_text, gold.decode('utf-8')))
(input_text, output_text, gold))
print(inp_data, pred)

17 changes: 10 additions & 7 deletions skflow/ops/seq2seq_ops.py
Expand Up @@ -68,9 +68,10 @@ def seq2seq_inputs(X, y, input_length, output_length, sentinel=None, name=None):
in_X = array_ops.split_squeeze(1, input_length, X)
y = array_ops.split_squeeze(1, output_length, y)
if not sentinel:
# Set to zeros of shape of X[0]
sentinel = tf.zeros(tf.shape(in_X[0]))
sentinel.set_shape(in_X[0].get_shape())
# Set to zeros of shape of y[0], using X for batch size.
sentinel_shape = tf.pack([tf.shape(X)[0], y[0].get_shape()[1]])
sentinel = tf.zeros(sentinel_shape)
sentinel.set_shape(y[0].get_shape())
in_y = [sentinel] + y
out_y = y + [sentinel]
return in_X, in_y, out_y
Expand Down Expand Up @@ -112,20 +113,22 @@ def rnn_decoder(decoder_inputs, initial_state, cell, scope=None):
return outputs, states, sampling_outputs, sampling_states


def rnn_seq2seq(encoder_inputs, decoder_inputs, cell, dtype=tf.float32, scope=None):
def rnn_seq2seq(encoder_inputs, decoder_inputs, encoder_cell, decoder_cell=None,
dtype=tf.float32, scope=None):
"""RNN Sequence to Sequence model.
Args:
encoder_inputs: List of tensors, inputs for encoder.
decoder_inputs: List of tensors, inputs for decoder.
cell: RNN cell to use for encoder and decoder.
encoder_cell: RNN cell to use for encoder.
decoder_cell: RNN cell to use for decoder, if None encoder_cell is used.
dtype: Type to initialize encoder state with.
scope: Scope to use, if None new will be produced.
Returns:
List of tensors for outputs and states for trianing and sampling sub-graphs.
"""
with tf.variable_scope(scope or "rnn_seq2seq"):
_, enc_states = tf.nn.rnn(cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_states[-1], cell)
_, enc_states = tf.nn.rnn(encoder_cell, encoder_inputs, dtype=dtype)
return rnn_decoder(decoder_inputs, enc_states[-1], decoder_cell or encoder_cell)

16 changes: 8 additions & 8 deletions skflow/ops/tests/test_seq2seq_ops.py
Expand Up @@ -42,22 +42,22 @@ def test_sequence_classifier(self):

def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1], [1, 0]], [[1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
with self.test_session() as session:
X = tf.placeholder(tf.float32, [2, 3, 2])
y = tf.placeholder(tf.float32, [2, 2, 2])
y = tf.placeholder(tf.float32, [2, 2, 3])
in_X, in_y, out_y = ops.seq2seq_inputs(X, y, 3, 2)
enc_inp = session.run(in_X, feed_dict={X.name: inp})
dec_inp = session.run(in_y, feed_dict={X.name: inp, y.name: out})
dec_out = session.run(out_y, feed_dict={X.name: inp, y.name: out})
# Swaps from batch x len x height to list of len of batch x height.
self.assertAllEqual(enc_inp, np.swapaxes(inp, 0, 1))
self.assertAllEqual(dec_inp, [[[0, 0], [0, 0]],
[[0, 1], [1, 0]],
[[1, 0], [0, 1]]])
self.assertAllEqual(dec_out, [[[0, 1], [1, 0]],
[[1, 0], [0, 1]],
[[0, 0], [0, 0]]])
self.assertAllEqual(dec_inp, [[[0, 0, 0], [0, 0, 0]],
[[0, 1, 0], [1, 0, 0]],
[[1, 0, 0], [0, 1, 0]]])
self.assertAllEqual(dec_out, [[[0, 1, 0], [1, 0, 0]],
[[1, 0, 0], [0, 1, 0]],
[[0, 0, 0], [0, 0, 0]]])

def test_rnn_decoder(self):
with self.test_session() as session:
Expand Down

0 comments on commit d16868d

Please sign in to comment.