Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyspark/bigdl/nn/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,16 @@ def load_caffe(model, defPath, modelPath, match_all=True, bigdl_type="float"):
jmodel = callBigDlFunc(bigdl_type, "loadCaffe", model, defPath, modelPath, match_all)
return Layer.of(jmodel)

@staticmethod
def load_tensorflow(path, inputs, outputs, byte_order = "little_endian", bigdl_type="float"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a test case to cover this function?

"""
Load a pre-trained Tensorflow model.
:param path: The path containing the pre-trained model.
:return: A pre-trained model.
"""
jmodel = callBigDlFunc(bigdl_type, "loadTF", path, inputs, outputs, byte_order)
return Model.of(jmodel)


class Linear(Layer):

Expand Down
88 changes: 88 additions & 0 deletions pyspark/bigdl/util/tf_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tempfile

import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why import tensorflow? I don't think we want to do that. @yiheng


from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile

from bigdl.nn.layer import Model

def convert(input_ops, output_ops, sess):
"""
Convert tensorflow model to bigdl model
:param input_ops: operation list used for input, should be placeholders
:param output_ops: operations list used for output
:param sess: current tensorflow session
:return: bigdl model
"""
input_names = map(lambda x: x.name.split(":")[0], input_ops)
output_names = map(lambda x: x.name.split(":")[0], output_ops)
temp = tempfile.mkdtemp()

saver = tf.train.Saver()
saver.save(sess, temp + '/model.chkp')
tf.train.write_graph(sess.graph, temp, 'model.pbtxt')

merge_checkpoint(temp + '/model.pbtxt',
temp + '/model.chkp',
output_names,
temp + '/model.pb', sess)
return Model.load_tensorflow(temp + '/model.pb', input_names, output_names)

def merge_checkpoint(input_graph,
checkpoint,
output_node_names,
output_graph,
sess):
"""
Get the variable values from the checkpoint file, and merge them to the GraphDef file
Args:
input_graph: the GraphDef file, doesn't contain variable values
checkpoint: the checkpoint file
output_node_names: A list of string, the output names
output_graph: String of the location and the name of the
output graph
"""
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"

input_graph_def = graph_pb2.GraphDef()
with gfile.FastGFile(input_graph, "r") as f:
text_format.Merge(f.read().decode("utf-8"), input_graph_def)

for node in input_graph_def.node:
node.device = ""

importer.import_graph_def(input_graph_def, name="")

sess.run([restore_op_name], {filename_tensor_name: checkpoint})
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names,
variable_names_blacklist=""
)

with gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
44 changes: 44 additions & 0 deletions pyspark/example/tf_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import tensorflow as tf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add such an example here? It should go to unit test if needed.

import numpy as np
from bigdl.util.tf_utils import convert

def main():
input = tf.placeholder(tf.float32, [None, 5])
weight = tf.Variable(tf.random_uniform([5, 10]))
bias = tf.Variable(tf.random_uniform([10]))
middle = tf.nn.bias_add(tf.matmul(input, weight), bias)
output= tf.nn.tanh(middle)

tensor = np.random.rand(5, 5)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
tensorflow_result = sess.run(output, {input: tensor})
bigdl_model = convert([input], [output], sess)
bigdl_result = bigdl_model.forward(tensor)

print("Tensorflow forward result is " + str(tensorflow_result))
print("BigDL forward result is " + str(bigdl_result))

np.testing.assert_almost_equal(tensorflow_result, bigdl_result, 6)
print("The results are almost equal in 6 decimals")


if __name__ == "__main__":
main()
19 changes: 19 additions & 0 deletions pyspark/test/local_integration/commands/run-tf-example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

cd "`dirname $0`"

$PYTHON_EXECUTABLE $BIGDL_HOME/pyspark/example/tf_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class Dropout[T: ClassTag](
}

object Dropout {
def apply[@specialized(Float, Double) T: ClassTag](
def apply[T: ClassTag](
initP: Double = 0.5,
inplace: Boolean = false,
scale: Boolean = true)(implicit ev: TensorNumeric[T]) : Dropout[T] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule}
import com.intel.analytics.bigdl.nn.tf.WithoutInput
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{Node, T, Table}
Expand Down Expand Up @@ -64,7 +65,7 @@ class Graph[T: ClassTag](val inputs : Seq[ModuleNode[T]],
var i = 0
while(i < executions.length) {
val node = executions(i)
inputsBP(i) = if (node.prevNodes.isEmpty) {
inputsBP(i) = if (node.prevNodes.isEmpty && !node.element.isInstanceOf[WithoutInput]) {
inputData(node, input)
} else if (node.prevNodes.length == 1) {
node.prevNodes.head.element.output.toTensor[T]
Expand Down Expand Up @@ -198,6 +199,7 @@ class Graph[T: ClassTag](val inputs : Seq[ModuleNode[T]],

private def checkRoots : Unit = {
val roots = executions.filter(_.prevNodes.size == 0)
.filter(node => !node.element.isInstanceOf[WithoutInput])
require(roots.size == inputs.length,
s"There're ${inputs.length} inputs, but graph has ${roots.size} roots")
inputs.foreach(n =>
Expand Down Expand Up @@ -325,6 +327,12 @@ class Input[T: ClassTag]()(implicit ev: TensorNumeric[T]) extends TensorModule[T
gradInput = gradOutput
gradInput
}
override def equals(other: Any): Boolean = {
if (!other.isInstanceOf[Input[_]]) return false
this.eq(other.asInstanceOf[Input[_]])
}

override def hashCode(): Int = System.identityHashCode(this)
}

object Input {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import scala.reflect.ClassTag
@SerialVersionUID(2995626598003841724L)
class Mean[T: ClassTag](
val dimension: Int = 1,
nInputDims: Int = -1,
squeeze: Boolean = true)
val nInputDims: Int = -1,
val squeeze: Boolean = true)
(implicit ev: TensorNumeric[T]) extends Sum[T](dimension, nInputDims, true, squeeze) {
override def toString: String = s"nn.Mean"
}
Expand Down
19 changes: 19 additions & 0 deletions spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Module.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
*/
package com.intel.analytics.bigdl.nn

import java.nio.ByteOrder

import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{CaffeLoader, File}
import com.intel.analytics.bigdl.utils.tf.{TensorflowDataFormat, TensorflowLoader}

import scala.reflect.ClassTag

Expand All @@ -38,6 +42,21 @@ object Module {
CaffeLoader.load[T](model, defPath, modelPath, matchAll)
}

/**
* Load tensorflow model from its saved protobuf file.
* @param file where is the protobuf model file
* @param inputs input node names
* @param outputs output node names, the output tensor order is same with the node order
* @param byteOrder byte order in the tensorflow file. The default value is little endian
* @return BigDL model
*/
def loadTF[T: ClassTag](file: String, inputs: Seq[String], outputs: Seq[String],
byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN)(
implicit ev: TensorNumeric[T]): Module[T] = {

TensorflowLoader.load(file, inputs, outputs, byteOrder)
}

def flatten[@specialized(Float, Double) T: ClassTag](parameters: Array[Tensor[T]])(
implicit ev: TensorNumeric[T]): Tensor[T] = {
val compactedTensor = isCompact(parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class Padding[T: ClassTag](
}

object Padding{
def apply[@specialized(Float, Double) T: ClassTag](
def apply[T: ClassTag](
dim: Int,
pad: Int,
nInputDim: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class Reshape[@specialized(Float, Double) T: ClassTag](
}

object Reshape {
def apply[@specialized(Float, Double) T: ClassTag](
def apply[T: ClassTag](
size: Array[Int],
batchMode: Option[Boolean] = None)(implicit ev: TensorNumeric[T]) : Reshape[T] = {
new Reshape[T](size, batchMode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Sigmoid[@specialized(Float, Double) T: ClassTag](
}

object Sigmoid {
def apply[@specialized(Float, Double) T: ClassTag]()
def apply[T: ClassTag]()
(implicit ev: TensorNumeric[T]) : Sigmoid[T] = {
new Sigmoid[T]()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ class SpatialAveragePooling[@specialized(Float, Double) T: ClassTag](
}

object SpatialAveragePooling {
def apply[@specialized(Float, Double) T: ClassTag](
def apply[T: ClassTag](
kW: Int,
kH: Int,
dW: Int = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ object SpatialMaxPooling {
def apply[@specialized(Float, Double) T: ClassTag](
kW: Int,
kH: Int,
dW: Int,
dH: Int,
dW: Int = 1,
dH: Int = 1,
padW: Int = 0,
padH: Int = 0)(implicit ev: TensorNumeric[T]): SpatialMaxPooling[T] = {
new SpatialMaxPooling[T](kW, kH, dW, dH, padW, padH)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Tanh[@specialized(Float, Double) T: ClassTag](


object Tanh {
def apply[@specialized(Float, Double) T: ClassTag]()
def apply[T: ClassTag]()
(implicit ev: TensorNumeric[T]) : Tanh[T] = {
new Tanh[T]()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn
package com.intel.analytics.bigdl.nn.tf

import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.Tensor
Expand All @@ -22,13 +22,15 @@ import com.intel.analytics.bigdl.utils.{T, Table}

import scala.reflect.ClassTag

private[bigdl] trait WithoutInput

/**
* Return a constant tensor defined by value
* @param value the constant tensor to be returned in forward
*/
@SerialVersionUID(-4008935551091949324L)
class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T])
extends AbstractModule[Activity, Tensor[T], T] {
private[bigdl] class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T])
extends AbstractModule[Activity, Tensor[T], T] with WithoutInput {

output = value

Expand Down Expand Up @@ -59,7 +61,7 @@ class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T])
}
}

object Const {
private[bigdl] object Const {
def apply[T: ClassTag](value: Tensor[T])
(implicit ev: TensorNumeric[T]): Const[T] = {
new Const[T](value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.bigdl.nn
package com.intel.analytics.bigdl.nn.tf

import com.intel.analytics.bigdl.nn.abstractnn.TensorModule
import com.intel.analytics.bigdl.tensor.Tensor
Expand All @@ -27,7 +27,7 @@ import scala.reflect.ClassTag
* @param value the scalar value to be filled.
*/
@SerialVersionUID(-471757174144422555L)
class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T])
private[bigdl] class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T])
extends TensorModule[T] {

override def updateOutput(input: Tensor[T]): Tensor[T] = {
Expand All @@ -44,7 +44,7 @@ class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T])

}

object Fill {
private[bigdl] object Fill {
def apply[T: ClassTag](value: Double)
(implicit ev: TensorNumeric[T]) : Fill[T] = {
new Fill[T](ev.fromType(value))
Expand Down
Loading