Commit b232b3c
optimize graph for inference
phi-dbq committed Sep 19, 2017
1 parent 6e46073 commit b232b3c
Showing 3 changed files with 25 additions and 7 deletions.
8 changes: 4 additions & 4 deletions python/sparkdl/graph/builder.py
@@ -47,19 +47,20 @@ def __init__(self, graph=None, using_keras=False):
         self.graph = graph or tf.Graph()
         self.sess = tf.Session(graph=self.graph)
         if using_keras:
+            self.using_keras = True
             self.keras_prev_sess = K.get_session()
         else:
+            self.using_keras = False
             self.keras_prev_sess = None
 
     def __enter__(self):
-        self.sess.as_default()
         self.sess.__enter__()
-        if self.keras_prev_sess is not None:
+        if self.using_keras:
             K.set_session(self.sess)
         return self
 
     def __exit__(self, *args):
-        if self.keras_prev_sess is not None:
+        if self.using_keras:
             K.set_session(self.keras_prev_sess)
         self.sess.__exit__(*args)

@@ -268,4 +269,3 @@ def fromList(cls, functions):
         gfn = issn.asGraphFunction(first_inputs, last_outputs)
 
         return gfn
-
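Read together, the builder.py changes replace the keras_prev_sess sentinel check with an explicit using_keras flag and drop the redundant as_default() call. A minimal, self-contained sketch of the resulting context-manager pattern follows; the class name is illustrative, not necessarily the repo's:

# Sketch only: same enter/exit behavior as the diff above, assuming the
# Keras backend is imported as K. The class name is illustrative.
import tensorflow as tf
from keras import backend as K

class ScopedTFSession(object):
    def __init__(self, graph=None, using_keras=False):
        self.graph = graph or tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        self.using_keras = using_keras
        # Remember the session Keras was using so it can be restored on exit.
        self.keras_prev_sess = K.get_session() if using_keras else None

    def __enter__(self):
        self.sess.__enter__()
        if self.using_keras:
            K.set_session(self.sess)
        return self

    def __exit__(self, *args):
        if self.using_keras:
            K.set_session(self.keras_prev_sess)
        self.sess.__exit__(*args)

Tracking the flag separately means __exit__ behaves the same whether or not the previously saved Keras session happened to be None.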
20 changes: 18 additions & 2 deletions python/sparkdl/transformers/tf_tensor.py
@@ -16,6 +16,7 @@

 import logging
 import tensorflow as tf
+from tensorflow.python.tools import optimize_for_inference_lib as infr_opt
 import tensorframes as tfs
 
 from pyspark.ml import Transformer
@@ -60,17 +61,32 @@ def setParams(self, tfInputGraph=None, inputMapping=None, outputMapping=None, tf
         # Further canonicalization, e.g. converting dict to sorted str pairs happens here
         return self._set(**kwargs)
 
-    def _transform(self, dataset):
+    def _optimize_for_inference(self):
+        """ Optimize the graph for inference """
         gin = self.getTFInputGraph()
         input_mapping = self.getInputMapping()
         output_mapping = self.getOutputMapping()
+        input_node_names = [tfx.as_op_name(tnsr_name) for _, tnsr_name in input_mapping]
+        output_node_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping]
+
+        # NOTE(phi-dbq): Spark DataFrame assumes float64 as default floating point type
+        opt_gdef = infr_opt.optimize_for_inference(gin.graph_def,
+                                                   input_node_names,
+                                                   output_node_names,
+                                                   tf.float64.as_datatype_enum)
+        return opt_gdef
+
+    def _transform(self, dataset):
+        graph_def = self._optimize_for_inference()
+        input_mapping = self.getInputMapping()
+        output_mapping = self.getOutputMapping()
 
         graph = tf.Graph()
         with tf.Session(graph=graph):
             analyzed_df = tfs.analyze(dataset)
 
             out_tnsr_op_names = [tfx.as_op_name(tnsr_name) for tnsr_name, _ in output_mapping]
-            tf.import_graph_def(graph_def=gin.graph_def, name='', return_elements=out_tnsr_op_names)
+            tf.import_graph_def(graph_def=graph_def, name='', return_elements=out_tnsr_op_names)
 
             feed_dict = dict((tfx.op_name(graph, tnsr_name), col_name)
                              for col_name, tnsr_name in input_mapping)
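For reference, the optimize_for_inference_lib utility imported above can be exercised on its own. A minimal sketch with a toy graph (the node names 'x' and 'y' are illustrative), using float64 to match the NOTE about Spark DataFrames defaulting to double precision:

# Sketch only: prune a toy GraphDef down to what is needed for inference.
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib as infr_opt

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float64, shape=[None, 3], name='x')
    w = tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float64)
    tf.matmul(x, w, name='y')

opt_gdef = infr_opt.optimize_for_inference(g.as_graph_def(),
                                           ['x'],   # input op names
                                           ['y'],   # output op names
                                           tf.float64.as_datatype_enum)

The optimized GraphDef can then be handed to tf.import_graph_def, as _transform does above.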
4 changes: 3 additions & 1 deletion python/tests/transformers/tf_tensor_test.py
@@ -157,8 +157,10 @@ def _run_test_in_tf_session(self):
             _results.append(np.ravel(curr_res))
         out_tgt = np.hstack(_results)
 
+        err_msg = 'not close => {} != {}, max_diff {}'
         self.assertTrue(np.allclose(out_ref, out_tgt),
-                        msg='not close => {} != {}'.format(out_ref.shape, out_tgt.shape))
+                        msg=err_msg.format(out_ref.shape, out_tgt.shape,
+                                           np.max(np.abs(out_ref - out_tgt))))
 
 
     def test_build_from_tf_graph(self):
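If other tests want the same diagnostic, the improved message could be factored into a small helper; a sketch, with an illustrative helper name:

import numpy as np

def assert_arrays_close(test_case, out_ref, out_tgt):
    # Report both shapes and the largest elementwise difference on failure.
    err_msg = 'not close => {} != {}, max_diff {}'
    test_case.assertTrue(np.allclose(out_ref, out_tgt),
                         msg=err_msg.format(out_ref.shape, out_tgt.shape,
                                            np.max(np.abs(out_ref - out_tgt))))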
