Skip to content

Commit

Permalink
restore from checkpoint example
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-dbq committed Aug 12, 2017
1 parent 09024f3 commit 976ac08
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions python/tests/transformers/tf_tensor_test.py
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import tempfile

from keras.layers import Conv1D, Dense, Flatten, MaxPool1D
import numpy as np
Expand All @@ -38,6 +41,54 @@ def _get_rand_vec_df(self, num_rows, vec_size):
Row(idx=idx, vec=np.random.randn(vec_size).tolist())
for idx in range(num_rows))

def test_checkpoint_reload(self):
vec_size = 17
num_vecs = 31
df = self._get_rand_vec_df(num_vecs, vec_size)
analyzed_df = tfs.analyze(df)
input_col = 'vec'
output_col = 'outputCol'

# Build the TensorFlow graph
model_temp_dir = tempfile.mkdtemp()
ckpt_dir = os.path.join(model_temp_dir, 'model_ckpt')
with tf.Session() as sess:
x = tf.placeholder(tf.float64, shape=[None, vec_size], name='tnsrIn')
w = tf.Variable(tf.random_normal([vec_size], dtype=tf.float64),
dtype=tf.float64, name='varW')
z = tf.reduce_mean(x * w, axis=1, name='tnsrOut')
sess.run(w.initializer)
saver = tf.train.Saver(var_list=[w])
saved_path = saver.save(sess, ckpt_dir, global_step=2702)

# Get the reference data
_results = []
for row in df.rdd.toLocalIterator():
arr = np.array(row.vec)[np.newaxis, :]
_results.append(sess.run(z, {x: arr}))
out_ref = np.hstack(_results)

# Load the saved model checkpoint
with IsolatedSession() as issn:
saver = tf.train.import_meta_graph('{}.meta'.format(saved_path), clear_devices=True)
saver.restore(issn.sess, saved_path)
gfn = issn.asGraphFunction(
[tfx.get_tensor(issn.graph, 'tnsrIn')],
[tfx.get_tensor(issn.graph, 'tnsrOut')])

transformer = TFModelTransformer(tfGraph=gfn,
inputMapping={
input_col: 'tnsrIn'
},
outputMapping={
'tnsrOut': output_col
})
final_df = transformer.transform(analyzed_df)
out_tgt = grab_df_arr(final_df, output_col)

shutil.rmtree(model_temp_dir, ignore_errors=True)
self.assertTrue(np.allclose(out_ref, out_tgt))

def test_simple(self):
# Build a simple input DataFrame
vec_size = 17
Expand Down

0 comments on commit 976ac08

Please sign in to comment.