Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-dbq committed Aug 8, 2017
1 parent a5a6e07 commit 0f083c4
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
93 changes: 93 additions & 0 deletions python/sparkdl/transformers/tf_tensor.py
@@ -0,0 +1,93 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import tensorframes as tfs

from pyspark.ml import Transformer
from pyspark.ml.param import Param, Params
from pyspark.sql.functions import udf

from sparkdl.transformers.param import (
keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters)
import sparkdl.graph.utils as tfx
from sparkdl.graph.builder import IsolatedSession

class TFOneDimTensorTransformer(Transformer, HasInputCol, HasOutputCol):
    """
    Applies the TensorFlow graph to the array column in DataFrame.

    Restrictions of the current API:

    We assume that

    - All graphs have a "minibatch" dimension (i.e. an unknown leading
      dimension) in the tensor shapes.
    - Input DataFrame has an array column where all elements have the same length.
    - The graph has exactly one input and exactly one output
      (checked in :py:meth:`_transform`).

    .. note:: The input tensorflow graph should have appropriate weights constantified,
              since a new session is created inside this transformer.
    """

    # The TensorFlow computation to apply; converted/validated by
    # SparkDLTypeConverters.toGraphFunction when the param is set.
    graphFunction = Param(Params._dummy(), "graphFunction",
                          "A TensorFlow GraphDef with input names and output names",
                          typeConverter=SparkDLTypeConverters.toGraphFunction)

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, graphFunction=None):
        """
        __init__(self, inputCol=None, outputCol=None, graphFunction=None)
        """
        super(TFOneDimTensorTransformer, self).__init__()
        # `keyword_only` stashes the explicitly passed keyword arguments here.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, graphFunction=None):
        """
        setParams(self, inputCol=None, outputCol=None, graphFunction=None)

        Sets the given params on this transformer and returns ``self``.
        """
        # NOTE: do not call the base-class ``__init__`` here. Re-running the
        # initializers on a live instance would discard params that were set
        # earlier, so a later ``setParams(outputCol=...)`` would lose
        # ``inputCol`` and ``graphFunction``.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setGraphFunction(self, value):
        """Sets the ``graphFunction`` param and returns ``self``."""
        return self._set(graphFunction=value)

    def getGraphFunction(self):
        """Returns the value of the ``graphFunction`` param (or its default)."""
        return self.getOrDefault(self.graphFunction)

    def _transform(self, vec_df):
        """
        Runs the graph function over the input column of ``vec_df``.

        :param vec_df: DataFrame containing the column named by ``inputCol``.
        :return: DataFrame produced by ``tensorframes.map_blocks`` in which the
                 graph's output column has been renamed to ``outputCol``.
        :raises AssertionError: if the graph function does not have exactly one
                                input and exactly one output.
        """
        # Presumably attaches the tensor shape metadata that TensorFrames
        # needs for array columns -- see the tensorframes documentation.
        analyzed_df = tfs.analyze(vec_df)
        gfn = self.getGraphFunction()

        with IsolatedSession() as issn:
            # Import with an empty prefix so the tensors keep their original names.
            feeds, fetches = issn.importGraphFunction(gfn, prefix='')
            assert len(feeds) == 1, 'only support single input TF model'
            assert len(fetches) == 1, 'only support single output TF model'

            # The DataFrame column is renamed to the name of the graph's input
            # op before being handed to TensorFrames, and only that column is kept.
            orig_in_name = tfx.op_name(issn.graph, feeds[0])
            input_df = analyzed_df.withColumnRenamed(
                self.getInputCol(), orig_in_name).select(orig_in_name)
            output_df = tfs.map_blocks(fetches, input_df)

            # Expose the result under the user-requested column name instead
            # of the graph's output op name.
            orig_out_name = tfx.op_name(issn.graph, fetches[0])
            final_df = output_df.withColumn(
                self.getOutputCol(), output_df[orig_out_name]).drop(orig_out_name)

        return final_df

53 changes: 53 additions & 0 deletions python/tests/transformers/tf_tensor_test.py
@@ -0,0 +1,53 @@
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import tensorflow as tf

from pyspark.sql.types import Row

from sparkdl.graph.builder import IsolatedSession
from sparkdl.transformers.tf_tensor import TFOneDimTensorTransformer

from ..tests import SparkDLTestCase

class TFOneDimTransformerTest(SparkDLTestCase):
    """Checks TFOneDimTensorTransformer against a direct TensorFlow evaluation."""

    def test_simple(self):
        """Row-wise mean computed by the transformer matches the in-session result."""
        # Three rows, each holding a random vector of length 4.
        df = self.session.createDataFrame(
            [Row(idx=i, vec=np.random.randn(4).tolist()) for i in range(3)])

        with IsolatedSession() as issn:
            # Graph under test: mean over the feature axis of a [batch, 4] input.
            x = tf.placeholder(tf.float64, shape=[None, 4])
            z = tf.reduce_mean(x, axis=1)
            gfn = issn.asGraphFunction([x], [z])

            # Reference values: feed each row on its own as a batch of one
            # while the session is still open.
            per_row = [
                issn.run(z, {x: np.array(row.vec)[np.newaxis, :]})
                for row in df.rdd.toLocalIterator()
            ]
            out_ref = np.hstack(per_row)

        transformer = TFOneDimTensorTransformer(
            graphFunction=gfn, inputCol='vec', outputCol='outCol')
        final_df = transformer.transform(df)

        collected = final_df.select('outCol').collect()
        out_tgt = np.array([row.outCol for row in collected])
        self.assertTrue(np.allclose(out_ref, out_tgt))

0 comments on commit 0f083c4

Please sign in to comment.