[API] TensorFlow Graph Transformer #39
Conversation
    since a new session is created inside this transformer.
    """

    graphFunction = Param(Params._dummy(), "graphFunction",
The intention is to provide this as the Transformer API and provide utilities to convert any TensorFlow/Keras graph (or checkpoints) into a GraphFunction object.
GraphFunction and IsolatedSession were added as internal classes not to be exposed to users, only to be used within the internals of sparkdl, so we shouldn't have GraphFunction be a param the user defines. These are foreign concepts to TensorFlow users. We should stick with what users know as much as possible so that this functionality is as easy for them to use as possible. Making this API match the one the TFoS folks are building (https://github.com/yahoo/TensorFlowOnSpark/pull/114/files) would be better. If you think the API proposed there is not as good, we should discuss.
We may not need all the params the other API has, but if that is the case, we should have this API start with params that are a strict subset of the other API.
Force-pushed from 0f083c4 to b731e55
Codecov Report
@@            Coverage Diff             @@
##           master      #39      +/-   ##
==========================================
+ Coverage   82.82%   84.64%   +1.81%
==========================================
  Files          23       26       +3
  Lines        1217     1485     +268
  Branches        5        5
==========================================
+ Hits         1008     1257     +249
- Misses        209      228      +19
Continue to review full report at Codecov.
Force-pushed from 5108164 to 0c90493
Left some high-level comments - we can discuss in person if you don't agree with them.
    self.getInputCol(), orig_in_name).select(orig_in_name)
output_df = tfs.map_blocks(fetches, input_df)
orig_out_name = tfx.op_name(issn.graph, fetches[0])
final_df = output_df.withColumn(
You can rename the column instead.
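For example (a one-line sketch reusing the names from the quoted snippet, assuming the transformer mixes in HasOutputCol as elsewhere in this diff):
final_df = output_df.withColumnRenamed(orig_out_name, self.getOutputCol())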
Force-pushed from 2039cd5 to 8e33759
I am designing the API around TensorFlow Graph + Input/Output Tensors.
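For illustration, usage might then look like this (a hypothetical sketch; the class and param names are taken from the diffs in this thread, while the graph, tensor, and column names are made up):
transformer = TFOneDimTensorTransformer(tfGraph=graph,
                                        inputTensor='inputPlaceholder:0',
                                        outputTensor='tnsrOut:0',
                                        inputCol='features',
                                        outputCol='predictions')
output_df = transformer.transform(input_df)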
python/sparkdl/transformers/param.py
Outdated
# New in sparkdl

class SparkDLTypeConverters(object):

    @staticmethod
    def toTFGraph(value):
Directly importing a tf.GraphDef without the input tensor mapping does not work. Thus the user either provides a TensorFlow Graph or a GraphFunction (for developers).
left some more comments. Let's sync offline about the API.
python/sparkdl/transformers/param.py
Outdated
@@ -45,6 +46,8 @@ def wrapper(self, *args, **kwargs):
    return wrapper


""" TensorFlow Specific Parameters """
hasInputCol is not TensorFlow specific. In fact, it was copied from pyspark...!
assert len(feeds) == 1, 'only support single input TF model'
assert len(fetches) == 1, 'only support single output TF model'

# Change the column name for TensorFrames
It's not much code, so why not just put it in the transformer and skip having to convert from/to GraphFunction? We can also keep using the IsolatedSession there then.
    return final_df


class TFOneDimTensorTransformer(Transformer, HasInputCol, HasOutputCol,
I'd just name it TFTransformer or TFModel. We can add to the doc that it only supports 1-d arrays currently.
        self.assertTrue(np.allclose(out_ref, out_tgt))


    def test_simple_graph_function(self):
what does this test? it seems like the difference between this and test_simple() is that this one uses GraphFunction and IsolatedSession? I don't think that test belongs in this class (and the functionality should have been tested in the Suites for GraphFunction/IsolatedSession anyway).
python/sparkdl/transformers/param.py
Outdated
        return self.getOrDefault(self.outputTensor)


class HasInputTensor(Params):
Normally, we need to handle multiple (1+) input tensors. See TensorFlow Serving as a reference.
Also for output tensors. See https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py#L109-L114
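For reference, the linked example exposes multiple named inputs and outputs through a SignatureDef; a minimal TF 1.x sketch of that pattern (the tensor names and shapes here are illustrative):
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784], name='images')
keep_prob = tf.placeholder(tf.float32, shape=[], name='keep_prob')
y = tf.layers.dense(x, 10, name='scores')

# A SignatureDef carries named maps of tensors, so a model can declare
# several inputs and several outputs.
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'images': tf.saved_model.utils.build_tensor_info(x),
            'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)},
    outputs={'scores': tf.saved_model.utils.build_tensor_info(y)},
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)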
    @keyword_only
    def setParams(self, tfGraph=None, inputTensor=None, outputTensor=None,
                  inputCol=None, outputCol=None):
We need to handle multiple input columns.
Also for output columns.
I am wondering whether we want the tensors to be in separate columns or in fields of a nested column.
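To illustrate the two layouts as hypothetical PySpark sketches (the column names are made up):
from pyspark.sql import functions as F

# Option A: one flat column per output tensor
flat_df = out_df.select('scores', 'embeddings')

# Option B: all output tensors as fields of a single nested struct column
nested_df = out_df.select(F.struct('scores', 'embeddings').alias('tensors'))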
logger = logging.getLogger('sparkdl')


class TFTensorTransformer(Transformer, HasTFGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
TFTensorTransformer is redundant, since it spells out to TensorFlowTensorTransformer. Why not just TFTransformer or TFModel?
analyzed_df = tfs.analyze(df)

with IsolatedSession() as issn:
    _, fetches = issn.importGraphFunction(gfn, prefix='')
Lines 83-33: why do you need to create a GraphFunction and then importGraphFunction to get fetches, vs. just using the fetches from above? Somewhat related: analyzed_df = tfs.analyze(df) (line 85) doesn't have to be called right there, right? It could be called before the first IsolatedSession here?
The use of two sessions is due to the fact that exporting to GraphFunction modifies the graph, and it has to be imported in another session.
Oh, I see... if we just need strip_and_freeze, we should call that here directly, because this is hiding the fact that that is what (and all) we are doing.
It also seems that importing graph_def directly is a bit flaky. importGraphFunction attaches the names and tensors correctly. It might just save us a few lines of code.
@phi-dbq what do you mean by flaky? I do not understand why simply using import_graph_def would not work here.
Sometimes one might use a session without resetting the default graph.
with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    sess.run(...)
In case there is a default graph and there is a name conflict, the tensors in the imported graph will be renamed automatically.
(If we pass an existing graph, as in with tf.Session(graph=existing_graph) as sess, it seems to use that graph properly inside the session.)
But in this case there isn't another graph in the session, right? So there is no need to have the complex logic around the graph / graphdef. Could lines 70-92 just be:
analyzed_df = tfs.analyze(df)
# prune the graph for tensorframes
with tf.Session(graph=self.getTFGraph()) as sess:
    fetches = [tfx.get_tensor(sess.graph, tnsr) for tnsr, output_colname in self.getOutputMapping()]
    gdef = tfx.strip_and_freeze_until(fetches, sess.graph, sess)
with tf.Session() as sess2:
    output_names = [tfx.validated_output(sess.graph, elem) for elem in fetches]
    tf.import_graph_def(gdef,
                        input_map=None,
                        return_elements=output_names,
                        name='')
    final_fetches = [tfx.get_tensor(sess2.graph, name) for name in output_names]
    out_df = tfs.map_blocks(final_fetches, analyzed_df)
    for tnsr, output_colname in self.getOutputMapping():
        out_df = out_df.withColumnRenamed(tfx.op_name(sess.graph, tnsr), output_colname)
?
With IsolatedSession and GraphFunction, it's hard for anyone not intimately familiar with those classes to understand what's going on here. It makes it seem like you are just creating a GraphFunction object then getting back exactly what you put in. With more explicit code like the above, it's easier to tell what exactly we're doing (pruning the graph into a graphdef and re-importing). Also, it seems shorter...
A few questions:
1/ I thought the tensorframes changes were done to remove the necessity for freezing+reimporting? But I may be wrong there (didn't really look at the tensorframes PR).
2/ In https://github.com/databricks/spark-deep-learning/blob/master/python/sparkdl/graph/builder.py#L91 (and potentially in the code I wrote above), do we need to check that no output has been stripped from the graph? Or is that not possible because the outputs were fed into the strip/freeze function and would be kept even if an output was "floating" and would have been stripped normally?
Force-pushed from 1365723 to bee8d18
# Using the Transformer
transformer = TFModelTransformer(tfGraph=gfn,
                                 inputMapping={
                                     input_col: gfn.input_names[0]
We can provide either names or tensors for the mappings.
Force-pushed from 976ac08 to bab6724
# We want to clear device assignment in order to run it anywhere we want
with IsolatedSession() as issn:
    saver = tf.train.import_meta_graph('{}.meta'.format(saved_path), clear_devices=True)
    saver.restore(issn.sess, saved_path)
We have to export this checkpoint as a GraphFunction in the same session where it is restored, since the variables are only initialized in this session.
Also, it seems that we have to import the meta_graph first, or there won't be any graph in this session (tried with either IsolatedSession or standard tf.Session).
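A plain-TF sketch of that constraint (import the meta graph, then restore and freeze in the same session); this uses tf.graph_util.convert_variables_to_constants rather than the GraphFunction export, and saved_path plus the tnsrOut output name come from the snippets in this thread:
import tensorflow as tf

graph = tf.Graph()
with graph.as_default(), tf.Session(graph=graph) as sess:
    # Import the meta graph first; otherwise the session's graph is empty.
    saver = tf.train.import_meta_graph('{}.meta'.format(saved_path), clear_devices=True)
    # Restore the variables; their values only live in this session.
    saver.restore(sess, saved_path)
    # Freeze while still inside the same session.
    frozen_gdef = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_node_names=['tnsrOut'])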
z = tf.reduce_mean(x * w, axis=1, name='tnsrOut')
sess.run(w.initializer)
saver = tf.train.Saver(var_list=[w])
saved_path = saver.save(sess, ckpt_dir, global_step=2702)
Save the model and record the current global_step. Then we'd expect files like these:
model.ckpt-2702.index
model.ckpt-2702.meta
model.ckpt-2702.data-00000-of-00001
                         [tfx.get_tensor(issn.graph, 'tnsrOut')])

transformer = TFModelTransformer(tfGraph=gfn,
                                 inputMapping={
We would be better off reversing the input mapping to match feed_dict in tf.Session.run and TensorFrames map_row.
I have not finished the review, but I have a couple of comments already. I will have to finish on Monday.
for colname, tnsr in self.getInputMapping().items():
    feeds.append(tfx.get_tensor(issn.graph, tnsr))
    new_colname = tfx.op_name(issn.graph, tnsr)
    df = df.withColumnRenamed(colname, new_colname)
This may not always work because there may be conflicts with existing columns, but this is rare enough that I think we do not need to handle it currently.
python/sparkdl/transformers/param.py
Outdated
# New in sparkdl

class HasOutputMapping(Params):
This class is too complicated: I do not understand what is inside that parameter (actually, I do, but I had to reverse-engineer the logic to figure it out). Here is what we should do: store a Seq[(String, String)] for that parameter so that the order is clear, and only strings. Right now, you accept all sorts of objects such as tensors, which makes it hard to reason about. Most important, document what each element means, at least as a code comment. The reason for doing that is that you convert to primitive types, and the type of the param becomes much clearer.
python/sparkdl/transformers/param.py
Outdated
        super(HasOutputMapping, self).__init__()

    def setOutputMapping(self, value):
        return self._set(outputMapping=value)
Here, instead of being this simple, do some conversions: dicts should be converted to lists, sorted by key order; tensors should be converted to their string names, etc.
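A minimal sketch of such a setter; the _as_tensor_name helper is hypothetical (sparkdl.graph.utils has similar converters, per the replies below):
import tensorflow as tf

def _as_tensor_name(tnsr):
    # Hypothetical helper: normalize a tensor/operation/string to a string name.
    if isinstance(tnsr, tf.Tensor):
        return tnsr.name
    if isinstance(tnsr, tf.Operation):
        return tnsr.name + ':0'
    return str(tnsr)

# Method of HasOutputMapping: normalize before storing, so the param holds
# only an ordered list of (tensor name, column name) string pairs.
def setOutputMapping(self, value):
    pairs = sorted(value.items()) if isinstance(value, dict) else list(value)
    pairs = [(_as_tensor_name(tnsr), str(colname)) for tnsr, colname in pairs]
    return self._set(outputMapping=pairs)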
python/sparkdl/transformers/param.py
Outdated
    def setOutputMapping(self, value):
        return self._set(outputMapping=value)

    def getOutputMapping(self):
Say what it returns (a list of tensor name -> output name pairs).
python/sparkdl/transformers/param.py
Outdated
    def setOutputMapping(self, value):
        return self._set(outputMapping=value)
It would be nice to check that the tensors are part of the graph, but this is probably too fancy given the code structure.
python/sparkdl/transformers/param.py
Outdated
        raise TypeError("Could not convert %s to TensorFlow Graph" % type(value))

    @staticmethod
    def asColumnToTensorMap(value):
Based on my comment below, you do not need that function. It can be written in just a few lines during the conversion.
Also, instead of making something static and public, just make a small helper function that is not exported.
I am using a converter function in sparkdl.graph.utils to convert the tensor/operation object into the corresponding operation name. If the operation is named inputPlaceholder, the tensor name is inputPlaceholder:0. TensorFrames expects the operation name.
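As a tiny TF 1.x illustration of that naming convention:
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3], name='inputPlaceholder')
print(x.op.name)  # inputPlaceholder    (operation name, what TensorFrames expects)
print(x.name)     # inputPlaceholder:0  (tensor name)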
python/sparkdl/transformers/param.py
Outdated
        raise TypeError("Could not convert %s to TensorFlow Tensor" % type(value))

    @staticmethod
    def asTensorToColumnMap(value):
you do not need this separate function at all, based on my quick analysis. Just one is enough.
I take it back, you need each of the functions (input and output), but instead of reverting the dict, you can simply have a conversion function that you use like [(k, convert(tensor)) for ...] and [(convert(tensor), name) for ...].
python/sparkdl/graph/utils.py
Outdated
@@ -165,6 +165,17 @@ def validated_input(graph, tfobj_or_name):
        ('input must be Placeholder, but get', op.type)
    return name

def is_interpretable_tensor_type(tfobj_or_name):
Let's not have such a function.
Here is what you are currently doing:
- checking that something belongs to a union of types
- storing it into a structure
- then using it for lookup, which uses different branches depending on the type and does some conversions.
We can drastically simplify that:
- convert the something to just a string
- store this simple type (string)
- do a simpler lookup using this simple type
It is easier to reason about the flow of info.
python/sparkdl/transformers/param.py
Outdated
class HasInputMapping(Params):
    inputMapping = Param(Params._dummy(), "inputMapping",
Same comment as above.
python/sparkdl/transformers/param.py
Outdated
class HasTFGraph(Params):
    tfGraph = Param(Params._dummy(), "tfGraph",
Doc please.
Force-pushed from c4b6c15 to ad481a7
@phi-dbq I have a few comments. I think the core function of the code will be greatly simplified once I do a tiny change to tensorframes.
        return name
    elif hasattr(tfobj_or_name, 'graph'):
        tfobj = tfobj_or_name
        return get_tensor(tfobj.graph, tfobj).name
tensor_name
-def as_op_name(name):
+def as_op_name(tfobj_or_name):
We have 5 functions that compute some names in this file; that is at least two too many.
    gfn = issn.asGraphFunction(feeds, fetches, strip_and_freeze=True)

analyzed_df = tfs.analyze(df)
The analyze() function is not smart enough currently to know that the data is already analyzed, so it may take a substantial amount of time. This is an easy fix in tensorframes, though.
df = dataset
output_renaming = {}

with IsolatedSession(graph=self.getTFGraph()) as issn:
I see why you need this currently. There is a feed_dict argument for map_rows that does exactly what you want in tensorframes. It still needs to be implemented for map_blocks, but this is a quick fix.
Once this argument is there, this block of code will not be required anymore.
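If map_blocks gains that argument, the renaming block could collapse to something like the sketch below; note the feed_dict shape used here (placeholder op name to column name) is an assumption, not a confirmed tensorframes signature:
# Hypothetical: pass the input mapping directly instead of renaming columns.
feed_dict = {tfx.op_name(issn.graph, tnsr): colname
             for colname, tnsr in self.getInputMapping().items()}
out_df = tfs.map_blocks(fetches, analyzed_df, feed_dict=feed_dict)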
Also, errors here are going to be confusing because you are renaming the columns.
# Apply the transform
transfomer = TFTransformer(tfGraph=graph,
                           inputMapping={
                               'vec': x
why are there new lines here?
python/sparkdl/transformers/param.py
Outdated
# TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user
# is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables).
if isinstance(value, tf.Graph):
def toGraphFunction(value):
this is unused?
python/sparkdl/transformers/param.py
Outdated
@@ -92,11 +123,69 @@ def getOutputCol(self):
    return self.getOrDefault(self.outputCol)


# New in sparkdl
class HasOutputCols(Params):
unused?
@phi-dbq, left you a few more comments skimming the latest, but I'm not sure if I'm supposed to review again yet. Let me know when it's ready.
Force-pushed from 80795d7 to ee9b9e2
Force-pushed from ee9b9e2 to 9b3fe86
Force-pushed from 196244e to 111c427
Force-pushed from 111c427 to a3517d6
Force-pushed from 439df5b to 2d25c32
Force-pushed from 2d25c32 to b232b3c
Closing this PR in favor of the stacked PR (holistic view).
What changes are proposed in this pull request?
API design and reference implementation of a Spark MLlib Transformer from any TensorFlow Graph.
Notice
This implementation is only a proof-of-concept reference. We will break it up into multiple PRs with finer-grained tests and documentation.
This is not an implementation of the (attention-based) Transformer architecture (https://arxiv.org/abs/1706.03762).
How is this patch tested?