-
Notifications
You must be signed in to change notification settings - Fork 494
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
146 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright 2017 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import tensorframes as tfs | ||
|
||
from pyspark.ml import Transformer | ||
from pyspark.ml.param import Param, Params | ||
from pyspark.sql.functions import udf | ||
|
||
from sparkdl.transformers.param import ( | ||
keyword_only, HasInputCol, HasOutputCol, SparkDLTypeConverters) | ||
import sparkdl.graph.utils as tfx | ||
from sparkdl.graph.builder import IsolatedSession | ||
|
||
class TFOneDimTensorTransformer(Transformer, HasInputCol, HasOutputCol): | ||
""" | ||
Applies the TensorFlow graph to the array column in DataFrame. | ||
Restrictions of the current API: | ||
We assume that | ||
- All graphs have a "minibatch" dimension (i.e. an unknown leading | ||
dimension) in the tensor shapes. | ||
- Input DataFrame has an array column where all elements have the same length | ||
.. note:: The input tensorflow graph should have appropriate weights constantified, | ||
since a new session is created inside this transformer. | ||
""" | ||
|
||
graphFunction = Param(Params._dummy(), "graphFunction", | ||
"A TensorFlow GraphDef with input names and output names", | ||
typeConverter=SparkDLTypeConverters.toGraphFunction) | ||
|
||
@keyword_only | ||
def __init__(self, inputCol=None, outputCol=None, graphFunction=None): | ||
""" | ||
__init__(self, inputCol=None, outputCol=None, graphFunction=None) | ||
""" | ||
super(TFOneDimTensorTransformer, self).__init__() | ||
kwargs = self._input_kwargs | ||
self.setParams(**kwargs) | ||
|
||
@keyword_only | ||
def setParams(self, inputCol=None, outputCol=None, graphFunction=None): | ||
""" | ||
setParams(self, inputCol=None, outputCol=None, graphFunction=None) | ||
""" | ||
super(TFOneDimTensorTransformer, self).__init__() | ||
kwargs = self._input_kwargs | ||
return self._set(**kwargs) | ||
|
||
def setGraphFunction(self, value): | ||
return self._set(graphFunction=value) | ||
|
||
def getGraphFunction(self): | ||
return self.getOrDefault(self.graphFunction) | ||
|
||
def _transform(self, vec_df): | ||
analyzed_df = tfs.analyze(vec_df) | ||
gfn = self.getGraphFunction() | ||
|
||
with IsolatedSession() as issn: | ||
feeds, fetches = issn.importGraphFunction(gfn, prefix='') | ||
assert len(feeds) == 1, 'only support single input TF model' | ||
assert len(fetches) == 1, 'only support single output TF model' | ||
|
||
orig_in_name = tfx.op_name(issn.graph, feeds[0]) | ||
input_df = analyzed_df.withColumnRenamed( | ||
self.getInputCol(), orig_in_name).select(orig_in_name) | ||
output_df = tfs.map_blocks(fetches, input_df) | ||
orig_out_name = tfx.op_name(issn.graph, fetches[0]) | ||
final_df = output_df.withColumn( | ||
self.getOutputCol(), output_df[orig_out_name]).drop(orig_out_name) | ||
|
||
return final_df | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2017 Databricks, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from pyspark.sql.types import Row | ||
|
||
from sparkdl.graph.builder import IsolatedSession | ||
from sparkdl.transformers.tf_tensor import TFOneDimTensorTransformer | ||
|
||
from ..tests import SparkDLTestCase | ||
|
||
class TFOneDimTransformerTest(SparkDLTestCase): | ||
|
||
def test_simple(self): | ||
# Build a simple input DataFrame | ||
df = self.session.createDataFrame([ | ||
Row(idx=0, vec=np.random.randn(4).tolist()), | ||
Row(idx=1, vec=np.random.randn(4).tolist()), | ||
Row(idx=2, vec=np.random.randn(4).tolist()) | ||
]) | ||
|
||
# Build the TensorFlow graph | ||
with IsolatedSession() as issn: | ||
x = tf.placeholder(tf.float64, shape=[None, 4]) | ||
z = tf.reduce_mean(x, axis=1) | ||
gfn = issn.asGraphFunction([x], [z]) | ||
|
||
# Get the reference data | ||
_results = [] | ||
for row in df.rdd.toLocalIterator(): | ||
arr = np.array(row.vec)[np.newaxis, :] | ||
_results.append(issn.run(z, {x: arr})) | ||
out_ref = np.hstack(_results) | ||
|
||
transformer = TFOneDimTensorTransformer(graphFunction=gfn, inputCol='vec', outputCol='outCol') | ||
final_df = transformer.transform(df) | ||
|
||
out_tgt = np.array([row.outCol for row in final_df.select('outCol').collect()]) | ||
self.assertTrue(np.allclose(out_ref, out_tgt)) |