Support Tensorflow model file read/write #800
Let me break the change into smaller pieces to check in.
require(
  tfTensor.getDtype == DataType.DT_FLOAT ||
  tfTensor.getDtype == DataType.DT_FLOAT ||
  tfTensor.getDtype == DataType.DT_INT32,
DT_FLOAT is checked twice.
  tmp(j) = params.get(j)
  j += 1
}
Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]]
This looks a little strange: there are five nearly identical code snippets. Does writing it this way give much better performance?
I don't have a good idea for how to refactor this. Can you give a specific example?
val posGraph = { if (direction == 0) i else graphNode.prevNodes.length - 1 - j}
val pn = patternNode.prevNodes(posPattern)
val gn = graphNode.prevNodes(posGraph)
if (patternToGraph.keySet.contains(pn)) {
Why not use patternToGraph.contains(pn) directly?
// Normal operation node
if (patternToGraph.get(patternNode).isEmpty) return (util.Collections.emptyList(), Seq())

val graphNode = patternToGraph.get(patternNode).get
Why not use patternToGraph(patternNode) directly?
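For reference, a minimal sketch of the two simplifications suggested in this thread. It uses an immutable Map with String keys and values as a stand-in for patternToGraph; the real map in the PR holds graph nodes, so the types and sample entries here are assumptions made only for illustration.

```scala
object MapLookupSketch {
  // Stand-in for patternToGraph: pattern node name -> graph node name (hypothetical data).
  val patternToGraph: Map[String, String] = Map("Mul" -> "mul_1", "Add" -> "add_3")

  def main(args: Array[String]): Unit = {
    // contains(key) on the map gives the same answer as keySet.contains(key).
    assert(patternToGraph.contains("Mul") == patternToGraph.keySet.contains("Mul"))

    // For a key that is present, apply(key) returns the same value as get(key).get.
    assert(patternToGraph("Add") == patternToGraph.get("Add").get)

    // apply throws NoSuchElementException for a missing key, so the earlier
    // isEmpty guard (or get/getOrElse) is still needed for that case.
    assert(patternToGraph.get("Conv").isEmpty)
    println("ok")
  }
}
```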
}
Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]]
} else {
throw new IllegalArgumentException("Data type ${tfTensor.getDtype} is not supported now")
The string is missing the leading "s" interpolator prefix; it should be s"Data type ${tfTensor.getDtype} is not supported now".
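A small sketch of the difference the prefix makes. The dtype value is a hypothetical stand-in for tfTensor.getDtype.

```scala
object InterpolationSketch {
  val dtype = "DT_DOUBLE" // hypothetical stand-in for tfTensor.getDtype

  // Without the s prefix the placeholder is kept as literal text.
  val plain = "Data type ${dtype} is not supported now"

  // With the s prefix the expression inside ${...} is evaluated and spliced in.
  val interpolated = s"Data type ${dtype} is not supported now"

  def main(args: Array[String]): Unit = {
    println(plain)
    println(interpolated)
  }
}
```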
val shape = tfTensor.getTensorShape.getDimList.asScala.map(_.getSize.toInt).toArray

if (shape.product == 1) {
Just a thought. Maybe it is not necessary to treat this case specially? If the tensor is a scalar and shape is an empty array, we can simply change the shape to Array(1), then the following code will be able to handle the scalar.
The code that follows cannot read the single element in that case. I'll leave a comment in the code.
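To make the shape question concrete, here is a minimal sketch of the normalization the reviewer describes. normalizeShape is a hypothetical helper, not code from the PR, and it only covers the shape handling; whether the rest of the conversion can then read the value is what the reply above disputes.

```scala
object ScalarShapeSketch {
  // Hypothetical helper: treat a rank-0 (scalar) shape as a one-element vector.
  def normalizeShape(shape: Array[Int]): Array[Int] =
    if (shape.isEmpty) Array(1) else shape

  def main(args: Array[String]): Unit = {
    // The product of an empty array is 1, so shape.product == 1 holds for a scalar...
    assert(Array.empty[Int].product == 1)
    // ...but it also holds for shapes such as [1] or [1, 1].
    assert(Array(1, 1).product == 1)

    assert(normalizeShape(Array.empty[Int]).sameElements(Array(1)))
    assert(normalizeShape(Array(2, 3)).sameElements(Array(2, 3)))
    println("ok")
  }
}
```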
 * Sort the pattern list to make sure the graph match first should not be a sub-graph of the graph
 * match later
 */
private def sortPattern() : Unit = {
There is a potential issue here. The MulTF and ElementWiseMulTF patterns have exactly the same number of nodes and edges, but MulTF should come before ElementWiseMulTF.
This has been addressed by removing the wildcards and then comparing the node counts. With the wildcards removed, MulTF has two nodes and ElementWiseMulTF has only one, so MulTF comes first.
Verified by swapping the order.
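The ordering rule described in this thread can be sketched roughly as follows. Pattern, its ops list, and the "*" wildcard marker are hypothetical simplifications, and the node lists are chosen only to reproduce the counts mentioned above (equal total size, two versus one non-wildcard nodes); the PR's actual sortPattern works on real graph patterns.

```scala
object PatternOrderSketch {
  // Hypothetical, simplified pattern: a name plus the op types of its nodes.
  // "*" marks a wildcard node.
  final case class Pattern(name: String, ops: Seq[String]) {
    def nonWildcardCount: Int = ops.count(_ != "*")
  }

  val mulTF = Pattern("MulTF", Seq("Mul", "Const", "*"))
  val elementWiseMulTF = Pattern("ElementWiseMulTF", Seq("Mul", "*", "*"))

  // Patterns with more non-wildcard nodes come first, so a pattern is tried
  // before any smaller pattern that could match part of the same graph.
  def sortPatterns(ps: Seq[Pattern]): Seq[Pattern] =
    ps.sortBy(p => -p.nonWildcardCount)

  def main(args: Array[String]): Unit = {
    // Input order is deliberately reversed, mirroring the "swap the order" check.
    println(sortPatterns(Seq(elementWiseMulTF, mulTF)).map(_.name))
  }
}
```

Because sortBy is stable, patterns with the same non-wildcard count keep their original relative order, so any remaining ties would still depend on registration order.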
val BigDLResult = model.forward(input)

tfResult.map( BigDLResult.toTensor, (v1, v2) => {
  assert(abs(v1 - v2) < 1e-7);
Maybe we should use relative error here?
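One way a relative-error check could look, as a sketch only. nearlyEqual and both tolerance defaults are assumptions for illustration, not values taken from the PR; the absolute floor is there so that comparisons near zero do not become impossibly strict.

```scala
object RelativeErrorSketch {
  // Hypothetical helper: true when a and b agree to within relTol of the larger
  // magnitude, with absTol as a floor for values near zero.
  def nearlyEqual(a: Double, b: Double,
                  relTol: Double = 1e-5, absTol: Double = 1e-7): Boolean = {
    val diff = math.abs(a - b)
    diff <= math.max(absTol, relTol * math.max(math.abs(a), math.abs(b)))
  }

  def main(args: Array[String]): Unit = {
    // A fixed absolute bound of 1e-7 rejects this pair, although the values
    // differ by less than one part in a million.
    assert(math.abs(12345.678 - 12345.679) >= 1e-7)
    assert(nearlyEqual(12345.678, 12345.679))

    // Near zero the absolute floor applies.
    assert(nearlyEqual(0.0, 5e-8))
    // A difference of about 0.1% is still rejected.
    assert(!nearlyEqual(1.0, 1.001))
    println("ok")
  }
}
```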
* add unit test of lenet backward
* add some print
* add backward test in lenet and alexnet
* separate testModel into forward and backward methods
import tempfile

import tensorflow as tf
Why import tensorflow? I don't think we want to do that. @yiheng
# limitations under the License.
#

import tensorflow as tf
Why add such an example here? It should go to unit test if needed.
return Layer.of(jmodel)

@staticmethod
def load_tensorflow(path, inputs, outputs, byte_order = "little_endian", bigdl_type="float"):
Do you have a test case to cover this function?
We need to add some examples that show how to load a TensorFlow model (maybe in the load_model example) and how to save to a TensorFlow model. In addition, we should add some utilities to convert between TensorFlow and BigDL models.
})

// These two pieces of code are all necessary
val nextNodes = n.nextNodes.filter(
Shouldn't we apply this to every node in the matched subgraph, not just n?
What changes were proposed in this pull request?
Support Tensorflow model file read/write
How was this patch tested?
manual test, unit test