
[API] TensorFlow Graph Transformer #39

Closed
phi-dbq wants to merge 25 commits into master from tf-1d-transformer

Conversation

@phi-dbq (Contributor) commented Aug 8, 2017

What changes are proposed in this pull request?

API design and reference implementation of a Spark MLlib Transformer from any TensorFlow Graph.

Notice
This implementation is only a proof-of-concept reference. We will break it up into multiple PRs with finer-grained tests and documentation.

This is not an implementation of the (attention-based) Transformer architecture (https://arxiv.org/abs/1706.03762).

How is this patch tested?

  • Integration tests
  • Manual tests

since a new session is created inside this transformer.
"""

graphFunction = Param(Params._dummy(), "graphFunction",
@phi-dbq (Contributor, Author) commented Aug 8, 2017:

The intention is to provide this as the Transformer API and provide utilities to convert any TensorFlow/Keras graph (or checkpoints) into a GraphFunction object.

Collaborator:

GraphFunction and IsolatedSession were added as internal classes not to be exposed to users, only to be used within the internals of sparkdl, so we shouldn't have GraphFunction be a param user defines. These are foreign concepts to TensorFlow users. We should stick with what users know as much as possible so that this functionality is as easy for them to use as possible. Making this API match the one TFoS guys are building (https://github.com/yahoo/TensorFlowOnSpark/pull/114/files) would be better. If you think the proposed API in https://github.com/yahoo/TensorFlowOnSpark/pull/114/files is not as good, we should discuss.

Collaborator:

We may not need all the params the other API has, but if that is the case, we should have this API start with params that are a strict subset of the other API.

@codecov-io commented Aug 9, 2017:

Codecov Report

Merging #39 into master will increase coverage by 1.81%.
The diff coverage is 92.06%.


@@            Coverage Diff             @@
##           master      #39      +/-   ##
==========================================
+ Coverage   82.82%   84.64%   +1.81%     
==========================================
  Files          23       26       +3     
  Lines        1217     1485     +268     
  Branches        5        5              
==========================================
+ Hits         1008     1257     +249     
- Misses        209      228      +19
Impacted Files Coverage Δ
python/sparkdl/param/__init__.py 100% <100%> (ø) ⬆️
python/sparkdl/graph/builder.py 93.75% <100%> (+0.05%) ⬆️
python/sparkdl/transformers/tf_tensor.py 100% <100%> (ø)
python/sparkdl/__init__.py 100% <100%> (ø) ⬆️
python/sparkdl/graph/utils.py 89.01% <72.72%> (-6.06%) ⬇️
python/sparkdl/param/converters.py 80.3% <80.3%> (ø)
python/sparkdl/param/shared_params.py 81.25% <84.37%> (+1.04%) ⬆️
python/sparkdl/graph/input.py 99.29% <99.29%> (ø)
... and 1 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 3f668d9...b232b3c.

@sueann (Collaborator) left a comment:

Left some high-level comments - we can discuss in person if you don't agree with them.


self.getInputCol(), orig_in_name).select(orig_in_name)
output_df = tfs.map_blocks(fetches, input_df)
orig_out_name = tfx.op_name(issn.graph, fetches[0])
final_df = output_df.withColumn(
Collaborator:

You can rename the column instead.
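A minimal sketch of this suggestion, reusing output_df and orig_out_name from the surrounding code (getOutputCol is assumed to hold the desired output column name):

    # Rename the TensorFrames output column in place instead of adding
    # a second column with withColumn and carrying the original along.
    final_df = output_df.withColumnRenamed(orig_out_name, self.getOutputCol())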


@phi-dbq (Contributor, Author) commented Aug 9, 2017:

I am designing the API around a TensorFlow Graph plus input/output tensors.
The TFoS API takes a checkpoint directory for the (meta) graph. We can add a utility function to convert it to a TensorFlow Graph.
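A minimal sketch of such a utility in stock TF1, using a hypothetical name graph_from_checkpoint; note that the weights still have to be restored with the returned Saver inside a session before the graph is usable (as discussed further down in this thread):

    import tensorflow as tf

    def graph_from_checkpoint(saved_path):
        # Load the meta graph of a checkpoint into a fresh tf.Graph.
        # clear_devices=True strips device assignments so it can run anywhere.
        graph = tf.Graph()
        with graph.as_default():
            tf.train.import_meta_graph('{}.meta'.format(saved_path), clear_devices=True)
        return graph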

# New in sparkdl

class SparkDLTypeConverters(object):

    @staticmethod
    def toTFGraph(value):
@phi-dbq (Contributor, Author):

Directly importing a tf.GraphDef without the input tensor mapping does not work.
Thus the user provides either a TensorFlow Graph or a GraphFunction (for developers).

@sueann (Collaborator) left a comment:

Left some more comments. Let's sync offline about the API.

@@ -45,6 +46,8 @@ def wrapper(self, *args, **kwargs):
return wrapper


""" TensorFlow Specific Parameters """
Collaborator:

hasInputCol is not TensorFlow-specific. In fact, it was copied from pyspark...!

assert len(feeds) == 1, 'only support single input TF model'
assert len(fetches) == 1, 'only support single output TF model'

# Change the column name for TensorFrames
Collaborator:

It's not much code -- why not just put it in the transformer and skip having to convert from/to GraphFunction? Then we can also keep using the IsolatedSession there.


return final_df

class TFOneDimTensorTransformer(Transformer, HasInputCol, HasOutputCol,
Collaborator:

I'd just name it TFTransformer or TFModel. We can add to the doc that it only supports 1-d arrays currently.

self.assertTrue(np.allclose(out_ref, out_tgt))


def test_simple_graph_function(self):
Collaborator:

What does this test? It seems like the only difference between this and test_simple() is that this one uses GraphFunction and IsolatedSession. I don't think this test belongs in this class (and the functionality should have been tested in the Suites for GraphFunction/IsolatedSession anyway).

return self.getOrDefault(self.outputTensor)


class HasInputTensor(Params):
Reviewer:

Normally, we need to handle one or more input tensors. See TensorFlow Serving as a reference.



@keyword_only
def setParams(self, tfGraph=None, inputTensor=None, outputTensor=None,
inputCol=None, outputCol=None):
Reviewer:

We need to handle multiple input columns.

Reviewer:

Also for output columns

@phi-dbq (Contributor, Author):

I am wondering if we want the tensors to be in separate columns or in fields of a nested column?


logger = logging.getLogger('sparkdl')

class TFTensorTransformer(Transformer, HasTFGraph, HasTFHParams, HasInputMapping, HasOutputMapping):
Collaborator:

TFTensorTransformer is redundant since it expands to TensorFlowTensorTransformer. Why not just TFTransformer or TFModel?

analyzed_df = tfs.analyze(df)

with IsolatedSession() as issn:
    _, fetches = issn.importGraphFunction(gfn, prefix='')
Collaborator:

Lines 83-33: why do you need to create a GraphFunction and then importGraphFunction to get fetches vs just using the fetches from above?

Somewhat related: analyzed_df = tfs.analyze(df) (line 85) doesn't have to be called right there, right? Could be called before the first IsolatedSession here?

@phi-dbq (Contributor, Author):

The use of two sessions is due to the fact that exporting to GraphFunction modifies the graph, so it has to be imported in another session.

Collaborator:

Oh, I see... if we just need strip_and_freeze, we should call that here directly, because this is hiding the fact that that is what (and all) we are doing.

@phi-dbq (Contributor, Author):

It also seems that importing graph_def directly is a bit flaky. importGraphFunction attaches the names and tensors correctly. It might just save us a few lines of code.

Contributor:

@phi-dbq what do you mean by flaky? I do not understand why simply using import_graph_def would not work here.

@phi-dbq (Contributor, Author) commented Aug 22, 2017:

Sometimes one might use a session without resetting the default graph.

with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    sess.run(...)

If there is a default graph and there is a name conflict, the tensors in the imported graph will be renamed automatically.
(If we pass an existing graph, as in with tf.Session(graph=existing_graph) as sess, it seems to use that graph properly inside the session.)
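A minimal TF1 sketch of the renaming behavior described above (the names here are illustrative):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        tf.constant(1.0, name='c')
    graph_def = g.as_graph_def()

    with tf.Session() as sess:  # uses the process-wide default graph
        tf.import_graph_def(graph_def, name='')
        tf.import_graph_def(graph_def, name='')  # name conflict: 'c' becomes 'c_1'
        print([op.name for op in sess.graph.get_operations()])  # ['c', 'c_1']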

Collaborator:

But in this case there isn't another graph in the session, right? So there is no need for the complex logic around the graph / graphdef. Could lines 70-92 just be:

analyzed_df = tfs.analyze(df)

# prune the graph for tensorframes
with tf.Session(graph=self.getTFGraph()) as sess:
    fetches = [tfx.get_tensor(sess.graph, tnsr) for tnsr, output_colname in self.getOutputMapping()]
    gdef = tfx.strip_and_freeze_until(fetches, sess.graph, sess)

with tf.Session() as sess2:
    output_names = [tfx.validated_output(sess.graph, elem) for elem in fetches]
    tf.import_graph_def(gdef, input_map=None, return_elements=output_names, name='')
    final_fetches = [tfx.get_tensor(sess2.graph, name) for name in output_names]
    out_df = tfs.map_blocks(final_fetches, analyzed_df)

    for tnsr, output_colname in self.getOutputMapping():
        out_df = out_df.withColumnRenamed(tfx.op_name(sess.graph, tnsr), output_colname)

?

With IsolatedSession and GraphFunction, it's hard for anyone not intimately familiar with those classes to understand what's going on here. It makes it seem like you are just creating a GraphFunction object and then getting back exactly what you put in. With more explicit code like the above, it's easier to tell what exactly we're doing (pruning the graph into a graphdef and re-importing). Also it seems shorter...

A few questions:
1/ I thought the tensorframes changes were done to remove the necessity for freezing+reimporting? But I may be wrong there (didn't really look at the tensorframes PR).
2/ In https://github.com/databricks/spark-deep-learning/blob/master/python/sparkdl/graph/builder.py#L91 (and potentially in the code I wrote above), do we need to check that no output has been stripped from the graph? Or is that not possible because the outputs were fed into the strip/freeze function and would be kept even if an output was "floating" and would have been stripped normally?
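For reference, a minimal sketch of what strip-and-freeze amounts to in stock TF1 terms, assuming sess holds the initialized variables and tnsrOut is the output op name (this is a sketch, not sparkdl's actual strip_and_freeze_until implementation):

    from tensorflow.python.framework import graph_util

    # Bakes variable values into constants and keeps only nodes reachable
    # from the listed outputs; unreachable ("floating") nodes are dropped.
    frozen_gdef = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_node_names=['tnsrOut'])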

@phi-dbq force-pushed the tf-1d-transformer branch 2 times, most recently from 1365723 to bee8d18 (August 11, 2017)
# Using the Transformer
transformer = TFModelTransformer(tfGraph=gfn,
                                 inputMapping={
                                     input_col: gfn.input_names[0]
@phi-dbq (Contributor, Author):

We can provide either names or tensors for the mappings.

# We want to clear device assignment in order to run it anywhere we want
with IsolatedSession() as issn:
    saver = tf.train.import_meta_graph('{}.meta'.format(saved_path), clear_devices=True)
    saver.restore(issn.sess, saved_path)
@phi-dbq (Contributor, Author):

We have to export this checkpoint as a GraphFunction in the same session where it is restored, since the variables are only initialized in this session.

@phi-dbq (Contributor, Author) commented Aug 12, 2017:

Also, it seems that we have to import the meta_graph first, or there won't be any graph in this session (tried with both IsolatedSession and standard tf.Session).

z = tf.reduce_mean(x * w, axis=1, name='tnsrOut')
sess.run(w.initializer)
saver = tf.train.Saver(var_list=[w])
saved_path = saver.save(sess, ckpt_dir, global_step=2702)
@phi-dbq (Contributor, Author):

Save the model and record the current global_step. Then we'd expect files like these:

model.ckpt-2702.index
model.ckpt-2702.meta
model.ckpt-2702.data-00000-of-00001

[tfx.get_tensor(issn.graph, 'tnsrOut')])

transformer = TFModelTransformer(tfGraph=gfn,
                                 inputMapping={
@phi-dbq (Contributor, Author) commented Aug 18, 2017:

We would be better off reversing the input mapping to match feed_dict in tf.Session.run and TensorFrames map_rows.

@thunterdb (Contributor) left a comment:

I have not finished the review, but I already have a couple of comments. I will have to finish on Monday.

for colname, tnsr in self.getInputMapping().items():
    feeds.append(tfx.get_tensor(issn.graph, tnsr))
    new_colname = tfx.op_name(issn.graph, tnsr)
    df = df.withColumnRenamed(colname, new_colname)
Contributor:

this may not always work because there may be conflicts with existing columns, but this is rare enough I think that we do not need to handle it currently.


# New in sparkdl

class HasOutputMapping(Params):
Contributor:

This class is too complicated: I do not understand what is inside that parameter (actually, I do, but I had to reverse-engineer the logic to figure it out). Here is what we should do: store a Seq[(String, String)] for that parameter so that the order is clear, and only strings. Right now, you accept all sorts of objects such as tensors, which makes it hard to reason about. Most important, document what each element means, at least as a code comment.

The reason for doing this is that you convert to primitive types early, and the type of the param becomes much clearer.

super(HasOutputMapping, self).__init__()

def setOutputMapping(self, value):
    return self._set(outputMapping=value)
Contributor:

Here, instead of being this simple, do some conversions: dicts should be converted to lists, sorted by key order; tensors should be converted to their string names, etc.
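A sketch of the suggested normalization, assuming a hypothetical _as_tensor_name helper that maps tensors/operations/strings to tensor-name strings (see the sketch further down):

    def setOutputMapping(self, value):
        # dict -> list of pairs, sorted by key so the stored order is deterministic
        if isinstance(value, dict):
            value = sorted(value.items())
        # normalize tensors/operations to plain string names before storing
        value = [(_as_tensor_name(tnsr), col_name) for tnsr, col_name in value]
        return self._set(outputMapping=value)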

def setOutputMapping(self, value):
    return self._set(outputMapping=value)

def getOutputMapping(self):
Contributor:

Say what it returns (a list of tensor name -> output column name pairs).


def setOutputMapping(self, value):
    return self._set(outputMapping=value)

Contributor:

It would be nice to check that the tensors are part of the graph, but this is probably too fancy given the code structure.
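If we did want that check, a minimal sketch (hypothetical helper, assuming tensor names are stored as strings):

    def _assert_in_graph(graph, tensor_name):
        # get_tensor_by_name raises KeyError if the name is not in the graph
        try:
            graph.get_tensor_by_name(tensor_name)
        except KeyError:
            raise ValueError('tensor %s is not in the graph' % tensor_name)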

raise TypeError("Could not convert %s to TensorFlow Graph" % type(value))

@staticmethod
def asColumnToTensorMap(value):
Contributor:

Based on my comment below, you do not need this function. It can be written in just a few lines during the conversion.

Contributor:

Also, instead of making something static and public, just make a small helper function that is not exported.

@phi-dbq (Contributor, Author):

I am using a converter function in sparkdl.graph.utils to convert the tensor/operation object into the corresponding operation name. If the operation is named inputPlaceholder, the tensor name is inputPlaceholder:0. TensorFrames expects the operation name.
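A tiny TF1 illustration of this naming convention (illustrative names):

    import tensorflow as tf

    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=[None, 3], name='inputPlaceholder')

    print(x.name)     # 'inputPlaceholder:0' -- tensor name: op name plus output index
    print(x.op.name)  # 'inputPlaceholder'   -- operation name, what TensorFrames expects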

raise TypeError("Could not convert %s to TensorFlow Tensor" % type(value))

@staticmethod
def asTensorToColumnMap(value):
Contributor:

You do not need this separate function at all, based on my quick analysis. Just one is enough.

Contributor:

I take it back: you need each of the functions (input and output), but instead of reversing the dict, you can simply have a conversion function that you use like [(k, convert(tensor)) for ...] and [(convert(tensor), name) for ...].

@phi-dbq (Contributor, Author):

I am using a converter function in sparkdl.graph.utils to convert the tensor/operation object into the corresponding operation name. If the operation is named inputPlaceholder, the tensor name is inputPlaceholder:0. TensorFrames expects the operation name.

@@ -165,6 +165,17 @@ def validated_input(graph, tfobj_or_name):
('input must be Placeholder, but get', op.type)
return name

def is_interpretable_tensor_type(tfobj_or_name):
Contributor:

Let's not have such a function.

Here is what you are currently doing:

  • checking that something belongs to a union of types
  • storing it into a structure
  • then using it for lookup, which uses different branches depending on the type and does some conversions.

We can drastically simplify that:

  • convert the something to just a string
  • store this simple type (string)
  • do a simpler lookup using this simple type

It is easier to reason about the flow of info.
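A sketch of this simplification as a small non-exported helper (the hypothetical _as_tensor_name assumed in the earlier setter sketch):

    import tensorflow as tf

    def _as_tensor_name(tfobj_or_name):
        # Convert a tf.Tensor, tf.Operation, or string into a tensor-name string
        # once, up front, so only plain strings are stored and looked up.
        if isinstance(tfobj_or_name, str):
            name = tfobj_or_name
            return name if ':' in name else name + ':0'
        if isinstance(tfobj_or_name, tf.Operation):
            return tfobj_or_name.outputs[0].name
        if isinstance(tfobj_or_name, tf.Tensor):
            return tfobj_or_name.name
        raise TypeError('Could not convert %s to a tensor name' % type(tfobj_or_name))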



class HasInputMapping(Params):
    inputMapping = Param(Params._dummy(), "inputMapping",
Contributor:

Same comment as above.



class HasTFGraph(Params):
    tfGraph = Param(Params._dummy(), "tfGraph",
Contributor:

Doc please.

@phi-dbq force-pushed the tf-1d-transformer branch 4 times, most recently from c4b6c15 to ad481a7 (August 21, 2017)
@thunterdb (Contributor) left a comment:

@phi-dbq I have a few comments. I think the core function of the code will be greatly simplified once I do a tiny change to tensorframes.

    return name
elif hasattr(tfobj_or_name, 'graph'):
    tfobj = tfobj_or_name
    return get_tensor(tfobj.graph, tfobj).name
Contributor:

tensor_name


def as_op_name(name):
def as_op_name(tfobj_or_name):
Contributor:

We have five functions that compute names in this file; that is at least two too many.


gfn = issn.asGraphFunction(feeds, fetches, strip_and_freeze=True)

analyzed_df = tfs.analyze(df)
Contributor:

The analyze() function is not currently smart enough to know that the data is already analyzed, so it may take a substantial amount of time. This is an easy fix in tensorframes though.

df = dataset
output_renaming = {}

with IsolatedSession(graph=self.getTFGraph()) as issn:
Contributor:

I see why you need this currently. There is a feed_dict argument for map_rows that does exactly what you want in tensorframes. It still needs to be implemented for map_blocks, but that is a quick fix.

Contributor:

Once this argument is there, this block of code will not be required anymore.

Contributor:

Also, errors here are going to be confusing because you are renaming the columns.


# Apply the transform
transfomer = TFTransformer(tfGraph=graph,
                           inputMapping={
                               'vec': x
Contributor:

why are there new lines here?

# TODO: we may want to support tf.GraphDef in the future instead of tf.Graph since user
# is less likely to mess up using GraphDef vs Graph (e.g. constants vs variables).
if isinstance(value, tf.Graph):
def toGraphFunction(value):
Collaborator:

this is unused?

@@ -92,11 +123,69 @@ def getOutputCol(self):
return self.getOrDefault(self.outputCol)


# New in sparkdl
class HasOutputCols(Params):
Collaborator:

unused?

@sueann (Collaborator) commented Sep 13, 2017:

@phi-dbq, left you a few more comments skimming the latest but I'm not sure if I'm supposed to review again yet. Let me know when it's ready.

@phi-dbq force-pushed the tf-1d-transformer branch 4 times, most recently from 80795d7 to ee9b9e2 (September 14, 2017)
@phi-dbq force-pushed the tf-1d-transformer branch 2 times, most recently from 439df5b to 2d25c32 (September 19, 2017)
@phi-dbq (Contributor, Author) commented Oct 17, 2017:

Closing this PR in favor of the stacked PRs. For a holistic view, see phi-dbq#12.
