Skip to content

Conversation

@yiheng
Copy link
Contributor

@yiheng yiheng commented Apr 19, 2017

What changes were proposed in this pull request?

Support Tensorflow model file read/write

How was this patch tested?

manual test, unit test

@yiheng
Copy link
Contributor Author

yiheng commented May 26, 2017

Let me break the change into smalls to check in.

@yiheng yiheng force-pushed the tfpb branch 7 times, most recently from 3a2f024 to a873fdc Compare June 12, 2017 16:19
require(
tfTensor.getDtype == DataType.DT_FLOAT ||
tfTensor.getDtype == DataType.DT_FLOAT ||
tfTensor.getDtype == DataType.DT_INT32,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Double DT_FLOAT.

tmp(j) = params.get(j)
j += 1
}
Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a little strange, there are five similar code snippets. Does this have much higher performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an idea to refine this. Can you provide a specific example?

val posGraph = { if (direction == 0) i else graphNode.prevNodes.length - 1 - j}
val pn = patternNode.prevNodes(posPattern)
val gn = graphNode.prevNodes(posGraph)
if (patternToGraph.keySet.contains(pn)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not using patternToGraph.contains(pn) directly?

// Normal operation node
if (patternToGraph.get(patternNode).isEmpty) return (util.Collections.emptyList(), Seq())

val graphNode = patternToGraph.get(patternNode).get
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use patternToGraph(patternNode) directly?

}
Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]]
} else {
throw new IllegalArgumentException("Data type ${tfTensor.getDtype} is not supported now")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lacking the leading "s", s"Data type ${}"


val shape = tfTensor.getTensorShape.getDimList.asScala.map(_.getSize.toInt).toArray

if (shape.product == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought. Maybe it is not necessary to treat this case specially? If the tensor is a scalar and shape is an empty array, we can simply change the shape to Array(1), then the following code will be able to handle the scalar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following code can not get the only 1 element. Leave a comment.

* Sort the pattern list to make sure the graph match first should not be a sub-graph of the graph
* match later
*/
private def sortPattern() : Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a potential issue about this. MulTF and ElementWiseMulTF pattern have the exact same number of nodes and edges, but MulTF should comes before ElementWiseMulTF.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been addressed by removing the wildcards then compare the node number. After removed wildcards, MulTF has two nodes and ElementWiseMulTf only has one, so MulTF come first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verified by exchange the order.

val BigDLResult = model.forward(input)

tfResult.map( BigDLResult.toTensor, (v1, v2) => {
assert(abs(v1 - v2) < 1e-7);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should use relative error here?

@yiheng yiheng changed the title [WIP] Support Tensorflow model file read/write Support Tensorflow model file read/write Jun 15, 2017
yiheng-wang-intel and others added 2 commits June 20, 2017 16:29
* add unit test of lenet backward

* add some print

* add backward test in lenet and alexnet

* seperate testModel into forward and backward methods
@yiheng yiheng merged commit 6cf1f6a into intel:master Jun 21, 2017

import tempfile

import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why import tensorflow? I don't think we want to do that. @yiheng

# limitations under the License.
#

import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add such an example here? It should go to unit test if needed.

return Layer.of(jmodel)

@staticmethod
def load_tensorflow(path, inputs, outputs, byte_order = "little_endian", bigdl_type="float"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a test case to cover this function?

@jason-dai
Copy link
Contributor

We need to add some examples to show how to load tensorflow model (maybe in the load_model example) and save to tensorflow mode.

In addition, we should add some utilities to convert between tensorflow and bigdl models.

})

// These two pieces of code are all necessary
val nextNodes = n.nextNodes.filter(
Copy link
Contributor

@jason-dai jason-dai Jun 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we apply this every node in the match subgraph, not just n?

@yiheng yiheng deleted the tfpb branch August 23, 2017 02:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants