diff --git a/pyspark/bigdl/nn/layer.py b/pyspark/bigdl/nn/layer.py index a69764a6d03..a6509ee1a46 100644 --- a/pyspark/bigdl/nn/layer.py +++ b/pyspark/bigdl/nn/layer.py @@ -407,6 +407,16 @@ def load_caffe(model, defPath, modelPath, match_all=True, bigdl_type="float"): jmodel = callBigDlFunc(bigdl_type, "loadCaffe", model, defPath, modelPath, match_all) return Layer.of(jmodel) + @staticmethod + def load_tensorflow(path, inputs, outputs, byte_order = "little_endian", bigdl_type="float"): + """ + Load a pre-trained Tensorflow model. + :param path: The path containing the pre-trained model. + :return: A pre-trained model. + """ + jmodel = callBigDlFunc(bigdl_type, "loadTF", path, inputs, outputs, byte_order) + return Model.of(jmodel) + class Linear(Layer): diff --git a/pyspark/bigdl/util/tf_utils.py b/pyspark/bigdl/util/tf_utils.py new file mode 100644 index 00000000000..ac57a52d862 --- /dev/null +++ b/pyspark/bigdl/util/tf_utils.py @@ -0,0 +1,88 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tempfile + +import tensorflow as tf + +from google.protobuf import text_format + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.platform import gfile + +from bigdl.nn.layer import Model + +def convert(input_ops, output_ops, sess): + """ + Convert tensorflow model to bigdl model + :param input_ops: operation list used for input, should be placeholders + :param output_ops: operations list used for output + :param sess: current tensorflow session + :return: bigdl model + """ + input_names = map(lambda x: x.name.split(":")[0], input_ops) + output_names = map(lambda x: x.name.split(":")[0], output_ops) + temp = tempfile.mkdtemp() + + saver = tf.train.Saver() + saver.save(sess, temp + '/model.chkp') + tf.train.write_graph(sess.graph, temp, 'model.pbtxt') + + merge_checkpoint(temp + '/model.pbtxt', + temp + '/model.chkp', + output_names, + temp + '/model.pb', sess) + return Model.load_tensorflow(temp + '/model.pb', input_names, output_names) + +def merge_checkpoint(input_graph, + checkpoint, + output_node_names, + output_graph, + sess): + """ + Get the variable values from the checkpoint file, and merge them to the GraphDef file + Args: + input_graph: the GraphDef file, doesn't contain variable values + checkpoint: the checkpoint file + output_node_names: A list of string, the output names + output_graph: String of the location and the name of the + output graph + """ + restore_op_name = "save/restore_all" + filename_tensor_name = "save/Const:0" + + input_graph_def = graph_pb2.GraphDef() + with gfile.FastGFile(input_graph, "r") as f: + text_format.Merge(f.read().decode("utf-8"), input_graph_def) + + for node in input_graph_def.node: + node.device = "" + + importer.import_graph_def(input_graph_def, name="") + + sess.run([restore_op_name], {filename_tensor_name: checkpoint}) + 
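+    # At this point the variable values from the checkpoint have been restored into the
+    # running session. convert_variables_to_constants (below) folds those values into the
+    # GraphDef as Const nodes, producing a "frozen" graph that can later be parsed on its
+    # own, without a separate checkpoint file.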
output_graph_def = graph_util.convert_variables_to_constants( + sess, + input_graph_def, + output_node_names, + variable_names_blacklist="" + ) + + with gfile.GFile(output_graph, "wb") as f: + f.write(output_graph_def.SerializeToString()) \ No newline at end of file diff --git a/pyspark/example/tf_example.py b/pyspark/example/tf_example.py new file mode 100644 index 00000000000..b341d0adab6 --- /dev/null +++ b/pyspark/example/tf_example.py @@ -0,0 +1,44 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorflow as tf +import numpy as np +from bigdl.util.tf_utils import convert + +def main(): + input = tf.placeholder(tf.float32, [None, 5]) + weight = tf.Variable(tf.random_uniform([5, 10])) + bias = tf.Variable(tf.random_uniform([10])) + middle = tf.nn.bias_add(tf.matmul(input, weight), bias) + output= tf.nn.tanh(middle) + + tensor = np.random.rand(5, 5) + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + tensorflow_result = sess.run(output, {input: tensor}) + bigdl_model = convert([input], [output], sess) + bigdl_result = bigdl_model.forward(tensor) + + print("Tensorflow forward result is " + str(tensorflow_result)) + print("BigDL forward result is " + str(bigdl_result)) + + np.testing.assert_almost_equal(tensorflow_result, bigdl_result, 6) + print("The results are almost equal in 6 decimals") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyspark/test/local_integration/commands/run-tf-example.sh b/pyspark/test/local_integration/commands/run-tf-example.sh new file mode 100755 index 00000000000..66f331a929b --- /dev/null +++ b/pyspark/test/local_integration/commands/run-tf-example.sh @@ -0,0 +1,19 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +cd "`dirname $0`" + +$PYTHON_EXECUTABLE $BIGDL_HOME/pyspark/example/tf_example.py \ No newline at end of file diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Dropout.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Dropout.scala index 71de6e7ba4d..0d1ffb923e3 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Dropout.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Dropout.scala @@ -207,7 +207,7 @@ class Dropout[T: ClassTag]( } object Dropout { - def apply[@specialized(Float, Double) T: ClassTag]( + def apply[T: ClassTag]( initP: Double = 0.5, inplace: Boolean = false, scale: Boolean = true)(implicit ev: TensorNumeric[T]) : Dropout[T] = { diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Graph.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Graph.scala index e8cc86bae31..bad75142b13 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Graph.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Graph.scala @@ -17,6 +17,7 @@ package com.intel.analytics.bigdl.nn import com.intel.analytics.bigdl.nn.Graph.ModuleNode import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, TensorModule} +import com.intel.analytics.bigdl.nn.tf.WithoutInput import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import com.intel.analytics.bigdl.utils.{Node, T, Table} @@ -64,7 +65,7 @@ class Graph[T: ClassTag](val inputs : Seq[ModuleNode[T]], var i = 0 while(i < executions.length) { val node = executions(i) - inputsBP(i) = if (node.prevNodes.isEmpty) { + inputsBP(i) = if (node.prevNodes.isEmpty && !node.element.isInstanceOf[WithoutInput]) { inputData(node, input) } else if (node.prevNodes.length == 1) { node.prevNodes.head.element.output.toTensor[T] @@ -198,6 +199,7 @@ class Graph[T: ClassTag](val inputs : Seq[ModuleNode[T]], private def checkRoots : Unit = { val roots = executions.filter(_.prevNodes.size == 0) + .filter(node => !node.element.isInstanceOf[WithoutInput]) require(roots.size == inputs.length, s"There're ${inputs.length} inputs, but graph has ${roots.size} roots") inputs.foreach(n => @@ -325,6 +327,12 @@ class Input[T: ClassTag]()(implicit ev: TensorNumeric[T]) extends TensorModule[T gradInput = gradOutput gradInput } + override def equals(other: Any): Boolean = { + if (!other.isInstanceOf[Input[_]]) return false + this.eq(other.asInstanceOf[Input[_]]) + } + + override def hashCode(): Int = System.identityHashCode(this) } object Input { diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Mean.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Mean.scala index cbe8c7c9f29..90999b78714 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Mean.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Mean.scala @@ -40,8 +40,8 @@ import scala.reflect.ClassTag @SerialVersionUID(2995626598003841724L) class Mean[T: ClassTag]( val dimension: Int = 1, - nInputDims: Int = -1, - squeeze: Boolean = true) + val nInputDims: Int = -1, + val squeeze: Boolean = true) (implicit ev: TensorNumeric[T]) extends Sum[T](dimension, nInputDims, true, squeeze) { override def toString: String = s"nn.Mean" } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Module.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Module.scala index 34a7b7b88b3..294401b2b7b 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Module.scala +++ 
b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Module.scala @@ -15,11 +15,15 @@ */ package com.intel.analytics.bigdl.nn +import java.nio.ByteOrder + +import com.intel.analytics.bigdl.Module import com.intel.analytics.bigdl.nn.abstractnn.Activity import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import com.intel.analytics.bigdl.utils.{CaffeLoader, File} +import com.intel.analytics.bigdl.utils.tf.{TensorflowDataFormat, TensorflowLoader} import scala.reflect.ClassTag @@ -38,6 +42,21 @@ object Module { CaffeLoader.load[T](model, defPath, modelPath, matchAll) } + /** + * Load tensorflow model from its saved protobuf file. + * @param file where is the protobuf model file + * @param inputs input node names + * @param outputs output node names, the output tensor order is same with the node order + * @param byteOrder byte order in the tensorflow file. The default value is little endian + * @return BigDL model + */ + def loadTF[T: ClassTag](file: String, inputs: Seq[String], outputs: Seq[String], + byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN)( + implicit ev: TensorNumeric[T]): Module[T] = { + + TensorflowLoader.load(file, inputs, outputs, byteOrder) + } + def flatten[@specialized(Float, Double) T: ClassTag](parameters: Array[Tensor[T]])( implicit ev: TensorNumeric[T]): Tensor[T] = { val compactedTensor = isCompact(parameters) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Padding.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Padding.scala index 7004f68de58..b67377d61dc 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Padding.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Padding.scala @@ -116,7 +116,7 @@ class Padding[T: ClassTag]( } object Padding{ - def apply[@specialized(Float, Double) T: ClassTag]( + def apply[T: ClassTag]( dim: Int, pad: Int, nInputDim: Int, diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Reshape.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Reshape.scala index 2c44ee825e9..48babd206b0 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Reshape.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Reshape.scala @@ -128,7 +128,7 @@ class Reshape[@specialized(Float, Double) T: ClassTag]( } object Reshape { - def apply[@specialized(Float, Double) T: ClassTag]( + def apply[T: ClassTag]( size: Array[Int], batchMode: Option[Boolean] = None)(implicit ev: TensorNumeric[T]) : Reshape[T] = { new Reshape[T](size, batchMode) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Sigmoid.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Sigmoid.scala index 4e6a77b764a..d9f4ba477ec 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Sigmoid.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Sigmoid.scala @@ -47,7 +47,7 @@ class Sigmoid[@specialized(Float, Double) T: ClassTag]( } object Sigmoid { - def apply[@specialized(Float, Double) T: ClassTag]() + def apply[T: ClassTag]() (implicit ev: TensorNumeric[T]) : Sigmoid[T] = { new Sigmoid[T]() } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialAveragePooling.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialAveragePooling.scala index c61a3bd6e73..9957e0a10a4 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialAveragePooling.scala +++ 
b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialAveragePooling.scala @@ -473,7 +473,7 @@ class SpatialAveragePooling[@specialized(Float, Double) T: ClassTag]( } object SpatialAveragePooling { - def apply[@specialized(Float, Double) T: ClassTag]( + def apply[T: ClassTag]( kW: Int, kH: Int, dW: Int = 1, diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialMaxPooling.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialMaxPooling.scala index 31230ad575a..b7d155a158f 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialMaxPooling.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SpatialMaxPooling.scala @@ -290,8 +290,8 @@ object SpatialMaxPooling { def apply[@specialized(Float, Double) T: ClassTag]( kW: Int, kH: Int, - dW: Int, - dH: Int, + dW: Int = 1, + dH: Int = 1, padW: Int = 0, padH: Int = 0)(implicit ev: TensorNumeric[T]): SpatialMaxPooling[T] = { new SpatialMaxPooling[T](kW, kH, dW, dH, padW, padH) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Tanh.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Tanh.scala index 244a01d1b2c..caba8c4b04b 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Tanh.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Tanh.scala @@ -48,7 +48,7 @@ class Tanh[@specialized(Float, Double) T: ClassTag]( object Tanh { - def apply[@specialized(Float, Double) T: ClassTag]() + def apply[T: ClassTag]() (implicit ev: TensorNumeric[T]) : Tanh[T] = { new Tanh[T]() } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Const.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Const.scala similarity index 89% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Const.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Const.scala index 8b582b83425..2e46da2c5b3 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Const.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Const.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} import com.intel.analytics.bigdl.tensor.Tensor @@ -22,13 +22,15 @@ import com.intel.analytics.bigdl.utils.{T, Table} import scala.reflect.ClassTag +private[bigdl] trait WithoutInput + /** * Return a constant tensor defined by value * @param value the constant tensor to be returned in forward */ @SerialVersionUID(-4008935551091949324L) -class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T]) - extends AbstractModule[Activity, Tensor[T], T] { +private[bigdl] class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T]) + extends AbstractModule[Activity, Tensor[T], T] with WithoutInput { output = value @@ -59,7 +61,7 @@ class Const[T: ClassTag](value: Tensor[T])(implicit ev: TensorNumeric[T]) } } -object Const { +private[bigdl] object Const { def apply[T: ClassTag](value: Tensor[T]) (implicit ev: TensorNumeric[T]): Const[T] = { new Const[T](value) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Fill.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Fill.scala similarity index 91% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Fill.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Fill.scala index 4e2a65dbe22..295989e1788 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Fill.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Fill.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf import com.intel.analytics.bigdl.nn.abstractnn.TensorModule import com.intel.analytics.bigdl.tensor.Tensor @@ -27,7 +27,7 @@ import scala.reflect.ClassTag * @param value the scalar value to be filled. */ @SerialVersionUID(-471757174144422555L) -class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T]) +private[bigdl] class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T]) extends TensorModule[T] { override def updateOutput(input: Tensor[T]): Tensor[T] = { @@ -44,7 +44,7 @@ class Fill[T: ClassTag](value: T) (implicit ev: TensorNumeric[T]) } -object Fill { +private[bigdl] object Fill { def apply[T: ClassTag](value: Double) (implicit ev: TensorNumeric[T]) : Fill[T] = { new Fill[T](ev.fromType(value)) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Shape.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Shape.scala similarity index 86% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Shape.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Shape.scala index 406564b0247..0a052487652 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Shape.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/Shape.scala @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf -import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, TensorModule} +import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric @@ -25,7 +25,7 @@ import scala.reflect.ClassTag * Given input, return the shape of this input as a 1-D tensor */ @SerialVersionUID(-907995771209831179L) -class Shape[T: ClassTag](implicit ev: TensorNumeric[T]) +private[bigdl] class Shape[T: ClassTag](implicit ev: TensorNumeric[T]) extends AbstractModule[Tensor[T], Tensor[T], T] { override def updateOutput(input: Tensor[T]): Tensor[T] = { @@ -40,7 +40,7 @@ class Shape[T: ClassTag](implicit ev: TensorNumeric[T]) } } -object Shape { +private[bigdl] object Shape { def apply[T: ClassTag]()(implicit ev: TensorNumeric[T]): Shape[T] = { new Shape[T]() } diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SplitAndSelect.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelect.scala similarity index 92% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SplitAndSelect.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelect.scala index e832e67a63c..2706cdb8a7c 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/SplitAndSelect.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelect.scala @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf import com.intel.analytics.bigdl.nn.abstractnn.TensorModule -import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric import scala.reflect.ClassTag @@ -26,7 +26,7 @@ import scala.reflect.ClassTag * then select the [[index]]th one */ @SerialVersionUID(-9096120159559947483L) -class SplitAndSelect[T: ClassTag](dimension: Int, index: Int, numSplit: Int) +private[bigdl] class SplitAndSelect[T: ClassTag](dimension: Int, index: Int, numSplit: Int) (implicit ev: TensorNumeric[T]) extends TensorModule[T] { override def updateOutput(input: Tensor[T]): Tensor[T] = { @@ -53,7 +53,7 @@ class SplitAndSelect[T: ClassTag](dimension: Int, index: Int, numSplit: Int) } } -object SplitAndSelect { +private[bigdl] object SplitAndSelect { def apply[T: ClassTag]( dimension: Int, index: Int, diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/StrideSlice.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/StrideSlice.scala similarity index 92% rename from spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/StrideSlice.scala rename to spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/StrideSlice.scala index 86ac676ceaa..7338777a889 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/StrideSlice.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/tf/StrideSlice.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf import com.intel.analytics.bigdl.nn.abstractnn.TensorModule import com.intel.analytics.bigdl.tensor.Tensor @@ -26,7 +26,7 @@ import scala.reflect.ClassTag * @param sliceSpecs Array(dim, begin_index, end_index, stride) */ @SerialVersionUID(4436600172725317184L) -class StrideSlice[T: ClassTag](sliceSpecs: Array[(Int, Int, Int, Int)]) +private[bigdl] class StrideSlice[T: ClassTag](sliceSpecs: Array[(Int, Int, Int, Int)]) (implicit ev: TensorNumeric[T]) extends TensorModule[T] { require(sliceSpecs.map(_._4 == 1).reduce(_ && _), "only support stride 1 for now") @@ -57,7 +57,7 @@ class StrideSlice[T: ClassTag](sliceSpecs: Array[(Int, Int, Int, Int)]) } -object StrideSlice { +private[bigdl] object StrideSlice { def apply[T: ClassTag](sliceSpecs: Array[(Int, Int, Int, Int)]) (implicit ev: TensorNumeric[T]) : StrideSlice[T] = { new StrideSlice[T](sliceSpecs) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala index 9a5bebaa46a..d3489da0f3f 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala @@ -32,8 +32,10 @@ import com.intel.analytics.bigdl.nn.Zeros import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import java.lang.{Integer, Boolean => JBoolean} +import java.nio.ByteOrder import com.intel.analytics.bigdl.nn.Graph._ +import com.intel.analytics.bigdl.nn.tf.{Const, Fill, Shape, SplitAndSelect} import scala.collection.JavaConverters._ import scala.language.existentials @@ -1280,6 +1282,16 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab Module.loadCaffe[T](model, defPath, modelPath, matchAll) } + def loadTF(path: String, inputs: JList[String], outputs: JList[String], + byteOrder: String): AbstractModule[Activity, Activity, T] = { + val order = byteOrder match { + case "little_endian" => ByteOrder.LITTLE_ENDIAN + case "big_endian" => ByteOrder.BIG_ENDIAN + case _ => throw new IllegalArgumentException(s"No support byte order $byteOrder") + } + Module.loadTF[T](path, inputs.asScala, outputs.asScala, order) + } + def modelPredictRDD(model: AbstractModule[Activity, Activity, T], dataRdd: JavaRDD[Sample]): JavaRDD[JTensor] = { val tensorRDD = model.predict(dataRdd.rdd.map(toSample(_))) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/DenseTensor.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/DenseTensor.scala index 2e1f4cc0c3d..aae6d6707e2 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/DenseTensor.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/DenseTensor.scala @@ -1906,6 +1906,8 @@ private[tensor] class DenseTensor[@specialized(Float, Double) T: ClassTag]( "corresponding module, please keep them same.") } } + + override def getTensorNumeric(): TensorNumeric[T] = ev } object DenseTensor { diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/Tensor.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/Tensor.scala index 1d092b68a6b..ed4d5305f97 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/Tensor.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/Tensor.scala @@ -646,6 +646,31 @@ trait Tensor[T] extends Serializable with TensorMath[T] with Activity { * @return false */ override 
def isTable: Boolean = false + + /** + * Return tensor numeric + * @return + */ + def getTensorNumeric(): TensorNumeric[T] + + /** + * Compare with other tensor. The shape of the other tensor must be same with this tensor. + * If element wise difference is less than delta, return true. + * @param other + * @param delta + * @return + */ + def almostEqual(other: Tensor[T], delta : Double): Boolean = { + var result = true + this.map(other, (a, b) => { + val tn = getTensorNumeric() + if (tn.isGreater(tn.abs(tn.minus(a, b)), tn.fromType(delta))) { + result = false + } + a + }) + return result + } } /** diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/BigDLToTensorflow.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/BigDLToTensorflow.scala new file mode 100644 index 00000000000..74e736ea835 --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/BigDLToTensorflow.scala @@ -0,0 +1,318 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.utils.tf + +import java.nio.ByteOrder + +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule +import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.utils.T +import Tensorflow._ +import BigDLToTensorflow._ +import org.tensorflow.framework.{DataType, NodeDef} + +import scala.collection.mutable.ArrayBuffer + +/** + * Wrapper of logic to convert module to tensorflow node definition + */ +trait BigDLToTensorflow { + + /** + * Convert the module to a tensorflow nodedef + * @return Mapped nodedef list, the first is the output node + */ + def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] +} + +object BigDLToTensorflow { + /** + * This method is just for test purpose. Do not use the bigdl.saveNHWC for real use case + * @return + */ + private[tf] def processSaveDim(dim: Int): Int = { + if (System.getProperty("bigdl.enableNHWC", "false").toBoolean) { + if (dim == 2) return 4 + if (dim == 3) return 2 + if (dim == 4) return 3 + dim + } else { + dim + } + } + + /** + * This method is just for test purpose. 
Do not use the bigdl.enableNHWC for real use case + * @return + */ + private[tf] def getDataFormat(): TensorflowDataFormat = { + if (System.getProperty("bigdl.enableNHWC", "false").toBoolean) { + TensorflowDataFormat.NHWC + } else { + TensorflowDataFormat.NCHW + } + } +} + +object InputToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Input only accept one input") + + Seq(identity(inputs(0), module.getName())) + } +} + +object ReLUToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Relu only accept one input") + + Seq(relu(inputs(0), module.getName())) + } +} + +object LinearToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Linear only accept one input") + val linear = module.asInstanceOf[Linear[_]] + val weight = const(linear.weight.t().contiguous(), linear.getName() + "/weight", byteOrder) + val weightReader = identity(weight, linear.getName() + "/weightReader") + val mm = matmul(inputs(0), weightReader, linear.getName() + "/matmul") + val bias = const(linear.bias, linear.getName() + "/bias", byteOrder) + val biasReader = identity(bias, linear.getName() + "/biasReader") + val add = biasAdd(mm, biasReader, getDataFormat(), linear.getName() + "/biasAdd") + Seq(add, biasReader, bias, mm, weightReader, weight) + } +} + +object SpatialConvolutionToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "SpatialConvolution only accept one input") + val spatialConv = module.asInstanceOf[SpatialConvolution[_]] + // squeeze will modify the weight tensor + // GOIHW -> HWIO + require(spatialConv.weight.size(1) == 1, "convolution group is not supported") + val filterTensor = spatialConv.weight.select(1, 1) + .transpose(2, 3).transpose(3, 4).transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous() + + val filter = const(filterTensor, spatialConv.getName() + "/filter", byteOrder) + val filterReader = identity(filter, spatialConv.getName() + "/filterReader") + val conv = conv2D(inputs(0), filterReader, spatialConv.strideW, spatialConv.strideH, + spatialConv.kernelW, spatialConv.kernelH, spatialConv.padW, spatialConv.padH, + getDataFormat(), spatialConv.getName() + "/conv2D") + val bias = const(spatialConv.bias, spatialConv.getName() + "/bias", byteOrder) + val biasReader = identity(bias, spatialConv.getName() + "/biasReader") + val add = biasAdd(conv, biasReader, getDataFormat(), + spatialConv.getName() + "/biasAdd") + Seq(add, biasReader, bias, conv, filterReader, filter) + } +} + +object SqueezeToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Squeeze only accept one input") + val sq = module.asInstanceOf[Squeeze[_]] + Seq(squeeze(inputs(0), sq.dims.map(processSaveDim(_) - 1), sq.getName())) + } +} + +object TanhToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Tanh only accept one input") + Seq(tanh(inputs(0), 
module.getName())) + } +} + +object ReshapeToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Reshape only accept one input") + val rh = module.asInstanceOf[Reshape[_]] + val size = Tensor[Float](rh.size.length) + var i = 0 + while(i < rh.size.length) { + size.setValue(i + 1, rh.size(i)) + i += 1 + } + val shape = const(size, rh.getName() + "/shape", byteOrder, false, DataType.DT_INT32) + val reshapeNode = reshape(inputs(0), shape, rh.getName()) + Seq(reshapeNode, shape) + } +} + +object ViewToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Reshape only accept one input") + val viewLayer = module.asInstanceOf[View[_]] + val size = Tensor[Float](viewLayer.sizes.length) + var i = 0 + while(i < viewLayer.sizes.length) { + size.setValue(i + 1, viewLayer.sizes(i)) + i += 1 + } + val shape = const(size, viewLayer.getName() + "/shape", byteOrder, false, DataType.DT_INT32) + val reshapeNode = reshape(inputs(0), shape, viewLayer.getName()) + Seq(reshapeNode, shape) + } +} + +object MaxpoolToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Maxpool only accept one input") + val layer = module.asInstanceOf[SpatialMaxPooling[_]] + Seq(maxPool(inputs(0), layer.kW, layer.kH, layer.padW, layer.padH, + layer.dW, layer.dH, getDataFormat(), layer.getName())) + } +} + +object PaddingToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Padding only accept one input") + val layer = module.asInstanceOf[Padding[_]] + require(layer.nIndex == 1, "only support padding nIndex == 1") + require(layer.nInputDim > 0, "nInputDim must be explicit specified") + val padding = Tensor[Float](layer.nInputDim, 2).zero() + if (layer.pad < 0) { + padding.setValue(layer.dim, 1, -layer.pad) + } + else { + padding.setValue(layer.dim, 2, layer.pad) + } + val paddingsNode = const(padding, layer.getName() + "/padding", byteOrder, + false, DataType.DT_INT32) + val padNode = pad(inputs(0), paddingsNode, layer.getName() + "/output") + Seq(padNode, paddingsNode) + } +} + +object AvgpoolToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Avgpool only accept one input") + val layer = module.asInstanceOf[SpatialAveragePooling[_]] + Seq(avgPool(inputs(0), layer.kW, layer.kH, layer.padW, layer.padH, + layer.dW, layer.dH, getDataFormat(), layer.getName())) + } +} + +object SigmoidToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Sigmoid only accept one input") + Seq(sigmoid(inputs(0), module.getName())) + } +} + +object DropoutToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Dropout only accept one input") + val layer = module.asInstanceOf[Dropout[_]] + require(layer.isTraining() == false, "only support evaluating mode dropout") + 
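+    // With BigDL's default scaled ("inverted") dropout, an evaluation-mode Dropout layer is
+    // a pass-through at inference time, so it is exported here simply as a TF Identity node
+    // wired to its single input.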
require(inputs.length == 1, "require only one tensor input") + Seq(identity(inputs(0), layer.getName())) + } +} + +object CAddTableToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + Seq(addN(inputs, module.getName())) + } +} + +object CMultTableToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 2, "Tensorflow only support two tensor multiply together") + + Seq(multiply(inputs(0), inputs(1), module.getName())) + } +} + +object JoinTableToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + val layer = module.asInstanceOf[JoinTable[_]] + val axis = const(Tensor[Float](T((layer.dimension - 1).toFloat)), layer.getName() + "/axis", + byteOrder, true, DataType.DT_INT32) + val updateInputs = new ArrayBuffer[NodeDef]() + updateInputs ++= inputs.reverse + updateInputs.append(axis) + Seq(concat(updateInputs, layer.dimension - 1, layer.getName()), axis) + } +} + +object MeanToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Mean only accept one input") + val layer = module.asInstanceOf[Mean[_]] + require(layer.squeeze == true, "Mean must squeeze input") + val dimsTensor = Tensor[Float](layer.dimension) + dimsTensor.setValue(1, layer.dimension - 1) + + val dims = const(dimsTensor, layer.getName() + "/dims", byteOrder, false, DataType.DT_INT32) + val mean = reduceMean(inputs(0), dims, false, layer.getName() + "/output") + Seq(mean, dims) + } +} + +object SoftMaxToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "Softmax only accept one input") + Seq(softmax(inputs(0), module.getName())) + } +} + +object LogSoftMaxToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "LogSoftmax only accept one input") + Seq(logSoftmax(inputs(0), module.getName())) + } +} + +object BatchNorm2DToTF extends BigDLToTensorflow { + override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef], + byteOrder: ByteOrder): Seq[NodeDef] = { + require(inputs.length == 1, "BatchNorm only accept one input") + val layer = module.asInstanceOf[SpatialBatchNormalization[_]] + require(!layer.isTraining(), "Only support evaluate mode batch norm") + val varNode = const(layer.runningVar, layer.getName() + "/std", byteOrder) + val mean = const(layer.runningMean, layer.getName() + "/mean", byteOrder) + val scale = const(layer.weight, layer.getName() + "/scale", byteOrder) + val offset = const(layer.bias, layer.getName() + "/offset", byteOrder) + val sqrtVar = rsqrt(varNode, layer.getName() + "/stdvar") + val mul0 = multiply(scale, sqrtVar, layer.getName() + "/mul0") + val mul1 = multiply(inputs(0), mul0, layer.getName() + "/mul1") + val mul2 = multiply(mean, mul0, layer.getName() + "/mul2") + val sub = subtract(offset, mul2, layer.getName() + "/sub") + val output = add(mul1, sub, layer.getName() + "/output") + Seq(output, sub, mul2, mul1, mul0, offset, scale, mean, sqrtVar, varNode) + } +} diff --git 
a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/Tensorflow.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/Tensorflow.scala new file mode 100644 index 00000000000..a6f0dc5e876 --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/Tensorflow.scala @@ -0,0 +1,593 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.utils.tf + +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.charset.Charset + +import com.google.protobuf.ByteString +import com.intel.analytics.bigdl.tensor.{DoubleType, FloatType, Tensor, TensorDataType} +import org.tensorflow.framework.AttrValue.ListValue +import org.tensorflow.framework._ +import org.tensorflow.framework.TensorShapeProto.Dim + +import scala.reflect.{ClassTag, classTag} + +/** + * Tensorflow data format. It is mostly applied in processing image type data + */ +sealed trait TensorflowDataFormat { + def value : AttrValue +} + +object TensorflowDataFormat { + /** + * Store the image data in tensor as batch, height, width, channel + */ + object NHWC extends TensorflowDataFormat { + private val v = AttrValue.newBuilder().setS(ByteString + .copyFrom("NHWC", Charset.defaultCharset())).build() + + override def value: AttrValue = v + } + + /** + * Store the image data in tensor as batch, channel, height, width + */ + object NCHW extends TensorflowDataFormat { + private val v = AttrValue.newBuilder().setS(ByteString + .copyFrom("NCHW", Charset.defaultCharset())).build() + + override def value: AttrValue = v + } +} + +/** + * Tensorflow padding type + */ +sealed trait PaddingType { + def value : AttrValue +} + +object PaddingType { + + object PADDING_SAME extends PaddingType { + private val v = AttrValue.newBuilder().setS(ByteString + .copyFrom("SAME", Charset.defaultCharset())).build() + + override def value: AttrValue = v + } + + object PADDING_VALID extends PaddingType { + private val v = AttrValue.newBuilder().setS(ByteString + .copyFrom("VALID", Charset.defaultCharset())).build() + + override def value: AttrValue = v + } +} + +object Tensorflow { + /** + * Generate a placeholder tensorflow protobuf node + * @param dtype numeric type + * @param shape shape + * @param name node name + * @return + */ + def placeholder(dtype: TensorDataType, shape: Seq[Int], name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Placeholder") + .putAttr("dtype", typeAttr(dtype)) + .putAttr("shape", shapeAttr(shape)) + .build() + } + + /** + * Generate a const tensorflow protobuf node + * @param value + * @param name + * @return + */ + def const[T: ClassTag](value : Tensor[T], name : String, byteOrder: ByteOrder, + isScalar: Boolean = false, dataType: DataType = null): NodeDef = { + val dtype = if (dataType == null) { + if (value.getType() == DoubleType) { + DataType.DT_DOUBLE + } else { + DataType.DT_FLOAT + } + } else { + dataType + } + + NodeDef.newBuilder() + .setName(name) + .setOp("Const") + .putAttr("dtype", 
AttrValue.newBuilder().setType(dtype).build()) + .putAttr("value", tensorAttr(value, dtype, byteOrder, isScalar)) + .build() + } + + /** + * Generate an identity tensorflow protobuf node + * @param input + * @param name + * @return + */ + def identity(input : NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Identity") + .addInput(input.getName) + .putAttr("T", getDataType(input)) + .build() + } + + /** + * Generate a matmul tensorflow protobuf node + * @param a + * @param b + * @param name + * @param transposeA + * @param transposeB + * @return + */ + def matmul(a: NodeDef, b: NodeDef, name: String, + transposeA: Boolean = false, transposeB: Boolean = false): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("MatMul") + .addInput(a.getName) + .addInput(b.getName) + .putAttr("T", getDataType(a)) + .putAttr("transpose_a", booleanAttr(transposeA)) + .putAttr("transpose_b", booleanAttr(transposeB)) + .build() + } + + /** + * Generate a biasAdd tensorflow protobuf node + * @param value + * @param bias + * @param dataFormat + * @param name + * @return + */ + def biasAdd(value: NodeDef, bias: NodeDef, dataFormat: TensorflowDataFormat, + name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("BiasAdd") + .addInput(value.getName) + .addInput(bias.getName) + .putAttr("T", getDataType(value)) + .putAttr("data_format", dataFormat.value) + .build() + } + + /** + * Generate a relu tensorflow protobuf node + * @param features + * @param name + * @return + */ + def relu(features: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Relu") + .addInput(features.getName) + .putAttr("T", getDataType(features)) + .build() + } + + def conv2D(input: NodeDef, filter: NodeDef, sW: Int, sH: Int, kW: Int, kH: Int, pW: Int, pH: Int, + dataFormat: TensorflowDataFormat, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Conv2D") + .addInput(input.getName) + .addInput(filter.getName) + .putAttr("T", getDataType(input)) + .putAttr("data_format", dataFormat.value) + .putAttr("padding", getPaddingType(pW, pH, kW, kH, sW, sH).value) + .putAttr("strides", strideAttr(sW, sH, dataFormat)) + .build() + } + + def squeeze(input: NodeDef, axis: Seq[Int], name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Squeeze") + .addInput(input.getName) + .putAttr("T", getDataType(input)) + .putAttr("squeeze_dims", listIntAttr(axis)) + .build() + } + + def tanh(input: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Tanh") + .addInput(input.getName) + .putAttr("T", getDataType(input)) + .build() + } + + def reshape(tensor: NodeDef, shape: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Reshape") + .addInput(tensor.getName) + .addInput(shape.getName) + .putAttr("T", getDataType(tensor)) + .putAttr("Tshape", getDataType(shape)) + .build() + + } + + def maxPool(value: NodeDef, kW: Int, kH: Int, pW: Int, pH: Int, sW: Int, sH: Int, + dataFormat: TensorflowDataFormat, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("MaxPool") + .addInput(value.getName) + .putAttr("T", getDataType(value)) + .putAttr("data_format", dataFormat.value) + .putAttr("ksize", kernelAttr(kW, kH, dataFormat)) + .putAttr("padding", getPaddingType(pW, pH, kW, kH, sW, sH).value) + .putAttr("strides", strideAttr(sW, sH, dataFormat)) + .build() + } + + def avgPool(value: NodeDef, kW: Int, kH: Int, pW: Int, pH: Int, sW: Int, sH: 
Int, + dataFormat: TensorflowDataFormat, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("AvgPool") + .putAttr("T", getDataType(value)) + .addInput(value.getName) + .putAttr("data_format", dataFormat.value) + .putAttr("ksize", kernelAttr(kW, kH, dataFormat)) + .putAttr("padding", getPaddingType(pW, pH, kW, kH, sW, sH).value) + .putAttr("strides", strideAttr(sW, sH, dataFormat)) + .build() + } + + def sigmoid(x: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Sigmoid") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .build() + } + + def multiply(x: NodeDef, y: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Mul") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .addInput(y.getName) + .build() + } + + def floor(x: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Floor") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .build() + } + + def add(x: NodeDef, y: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Add") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .addInput(y.getName) + .build() + } + + def realdiv(x: NodeDef, y: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("RealDiv") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .addInput(y.getName) + .build() + } + + def subtract(x: NodeDef, y: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Sub") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .addInput(y.getName) + .build() + } + + def shape(input: NodeDef, name: String, outType: DataType = DataType.DT_INT32): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Shape") + .putAttr("T", getDataType(input)) + .putAttr("out_type", AttrValue.newBuilder().setType(outType).build()) + .build() + } + + def randomUniform(shape: NodeDef, name: String, dtype: DataType = DataType.DT_FLOAT, + seed: Int = 0): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("RandomUniform") + .putAttr("T", getDataType(shape)) + .putAttr("dtype", AttrValue.newBuilder().setType(dtype).build()) + .putAttr("seed", intAttr(seed)) + .putAttr("seed2", intAttr(seed)) + .addInput(shape.getName) + .build() + } + + def addN(inputs: Seq[NodeDef], name: String): NodeDef = { + require(inputs.length >= 2, "at least two inputs for addN") + val node = NodeDef.newBuilder() + .setName(name) + .putAttr("N", intAttr(inputs.length)) + .putAttr("T", getDataType(inputs(0))) + .setOp("AddN") + inputs.foreach(i => node.addInput(i.getName)) + node.build() + } + + def concat(inputs: Seq[NodeDef], axis: Int, name: String): NodeDef = { + require(inputs.length >= 1, "at least one inputs for addN") + + val node = NodeDef.newBuilder() + .setName(name) + .setOp("ConcatV2") + .putAttr("N", intAttr(inputs.length - 1)) + .putAttr("T", getDataType(inputs(0))) + .putAttr("Tidx", AttrValue.newBuilder().setType(DataType.DT_INT32).build()) + + inputs.foreach(i => node.addInput(i.getName)) + + node.build() + } + + def pad(tensor: NodeDef, paddings: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Pad") + .putAttr("T", getDataType(tensor)) + .putAttr("Tpaddings", getDataType(paddings)) + .addInput(tensor.getName) + .addInput(paddings.getName) + .build() + } + + def reduceMean(inputTensor: NodeDef, axis: NodeDef, keepDim: Boolean, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + 
.setOp("Mean") + .putAttr("T", getDataType(inputTensor)) + .putAttr("Tidx", getDataType(axis)) + .putAttr("keep_dims", booleanAttr(keepDim)) + .addInput(inputTensor.getName) + .addInput(axis.getName) + .build() + } + + def softmax(logits: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Softmax") + .putAttr("T", getDataType(logits)) + .addInput(logits.getName) + .build() + } + + def logSoftmax(logits: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("LogSoftmax") + .putAttr("T", getDataType(logits)) + .addInput(logits.getName) + .build() + } + + def rsqrt(x: NodeDef, name: String): NodeDef = { + NodeDef.newBuilder() + .setName(name) + .setOp("Rsqrt") + .putAttr("T", getDataType(x)) + .addInput(x.getName) + .build() + } + + private def booleanAttr(value: Boolean): AttrValue = { + AttrValue.newBuilder().setB(value).build() + } + + private def intAttr(value: Int): AttrValue = { + AttrValue.newBuilder().setI(value).build() + } + + private def listIntAttr(value: Seq[Int]): AttrValue = { + val list = ListValue.newBuilder() + value.foreach(list.addI(_)) + AttrValue.newBuilder().setList(list).build() + } + + private def tensorAttr[T: ClassTag](value: Tensor[T], dtype: DataType, + byteOrder: ByteOrder, isScalar: Boolean): AttrValue = { + val shape = TensorShapeProto.newBuilder() + if (!isScalar) { + value.size().foreach(dim => { + shape.addDim(Dim.newBuilder().setSize(dim)) + }) + } + require(value.isContiguous(), "only support save a contiguous tensor") + + val content = if (value.getType() == DoubleType) { + val array = value.asInstanceOf[Tensor[Double]].storage().array() + val offset = value.storageOffset() - 1 + if (dtype == DataType.DT_INT32) { + val buffer = ByteBuffer.allocate(array.length * 4) + buffer.order(byteOrder) + var i = 0 + while (i < value.nElement()) { + buffer.putInt(array(i + offset).toInt) + i += 1 + } + buffer + } else if (dtype == DataType.DT_FLOAT) { + val buffer = ByteBuffer.allocate(array.length * 4) + buffer.order(byteOrder) + var i = 0 + while (i < value.nElement()) { + buffer.putFloat(array(i + offset).toFloat) + i += 1 + } + buffer + } else if (dtype == DataType.DT_DOUBLE) { + val buffer = ByteBuffer.allocate(array.length * 8) + buffer.order(byteOrder) + var i = 0 + while (i < value.nElement()) { + buffer.putDouble(array(i + offset)) + i += 1 + } + buffer + } else { + throw new UnsupportedOperationException(s"data type ${dtype} is not supported currently") + } + } else { + val array = value.asInstanceOf[Tensor[Float]].storage().array() + val offset = value.storageOffset() - 1 + if (dtype == DataType.DT_INT32) { + val buffer = ByteBuffer.allocate(array.length * 4) + buffer.order(byteOrder) + var i = 0 + while (i < value.nElement()) { + buffer.putInt(array(i + offset).toInt) + i += 1 + } + buffer + } else if (dtype == DataType.DT_FLOAT) { + val buffer = ByteBuffer.allocate(array.length * 4) + buffer.order(byteOrder) + var i = 0 + while (i < value.nElement()) { + buffer.putFloat(array(i + offset)) + i += 1 + } + buffer + } else if (dtype == DataType.DT_DOUBLE) { + throw new IllegalArgumentException(s"can not convert a float tensor to double tensor") + } else { + throw new UnsupportedOperationException(s"data type ${dtype} is not supported currently") + } + } + + AttrValue.newBuilder().setTensor( + TensorProto.newBuilder().setTensorShape(shape).setDtype(dtype) + .setTensorContent(ByteString.copyFrom(content.array())) + ).build() + } + + private def tensorAttr(value: Seq[Int]): AttrValue = { + val 
shape = TensorShapeProto.newBuilder() + shape.addDim(Dim.newBuilder().setSize(value.length)) + val dtype = DataType.DT_INT32 + AttrValue.newBuilder().setTensor( + TensorProto.newBuilder().setTensorShape(shape).setDtype(dtype) + ).build() + } + + private def typeAttr(dtype : TensorDataType): AttrValue = { + if (dtype == FloatType) { + AttrValue.newBuilder().setType(DataType.DT_FLOAT).build() + } else if (dtype == DoubleType) { + AttrValue.newBuilder().setType(DataType.DT_DOUBLE).build() + } else { + throw new NotImplementedError(s"type $dtype is not supported") + } + } + + private def shapeAttr(shape: Seq[Int]): AttrValue = { + val attr = TensorShapeProto.newBuilder() + shape.foreach(dim => { + attr.addDim(Dim.newBuilder().setSize(dim)) + }) + AttrValue.newBuilder().setShape(attr).build() + } + + private def getDataType(node: NodeDef): AttrValue = { + var attr = node.getAttrOrDefault("dtype", null) + if (attr != null) { + return attr + } + + attr = node.getAttrOrDefault("out_type", null) + if (attr != null) { + return attr + } + + attr = node.getAttrOrDefault("T", null) + if (attr != null) { + return attr + } + + throw new IllegalArgumentException("TensorflowSaver: Can not find data type") + } + + private def getPaddingType(padW: Int, padH: Int, kW: Int, kH: Int, sW: Int, sH: Int) + : PaddingType = { + if (padW == 0 && padH == 0) { + return PaddingType.PADDING_VALID + } else if (2 * padW == (kW - sW) && 2 * padH == (kH - sH)) { + return PaddingType.PADDING_SAME + } else { + throw new IllegalArgumentException( + s"Can not get padding type from given parameter " + + s"(padW: $padW, padH: $padH, kW: $kW, kH: $kH, sW: $sW, sH: $sH )") + } + } + + private def kernelAttr(kW: Int, kH: Int, dataFormat: TensorflowDataFormat): AttrValue = { + val kSize = if (dataFormat == TensorflowDataFormat.NHWC) { + Seq(1, kH, kW, 1) + } else { + Seq(1, 1, kH, kW) + } + listIntAttr(kSize) + } + + private def strideAttr(sW: Int, sH: Int, dataFormat: TensorflowDataFormat): AttrValue = { + val sSize = if (dataFormat == TensorflowDataFormat.NHWC) { + Seq(1, sH, sW, 1) + } else { + Seq(1, 1, sH, sW) + } + listIntAttr(sSize) + } +} diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoader.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoader.scala new file mode 100644 index 00000000000..f56c0d6dbc2 --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoader.scala @@ -0,0 +1,276 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.intel.analytics.bigdl.utils.tf + +import java.io.{DataInputStream, FileInputStream} +import java.nio.ByteOrder +import java.util + +import org.tensorflow.framework.{GraphDef, NodeDef} +import com.google.protobuf.CodedInputStream +import java.util.List + +import com.intel.analytics.bigdl.Module +import com.intel.analytics.bigdl.nn.Graph +import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} +import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric +import com.intel.analytics.bigdl.utils.{DirectedGraph, Node} +import com.intel.analytics.bigdl.utils.tf.TensorflowToBigDL._ + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +object TensorflowLoader{ + + type Context[T] = mutable.HashMap[NodeDef, (Tensor[T], Tensor[T])] + + /** + * Load tensorflow model from a prototxt file + * @param graphPrototxt where is the tensorflow protobuf file + * @param inputs input node names + * @param outputs output node names + * @param byteOrder file byteOrder + * @return + */ + def load[T: ClassTag](graphPrototxt: String, inputs: Seq[String], outputs: Seq[String], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): Module[T] = { + // Get node list + val nodeList = parse(graphPrototxt) + + // Construct tf node graph + val tfGraph = buildTFGraph(nodeList, outputs) + + // Build BigDL model from the tf node graph + buildBigDLModel(tfGraph, inputs, outputs, byteOrder) + } + + /** + * Parse a tensorflow model protobuf file, read a list of op nodes from it + * @param graphProtoTxt where is the tf protobuf file + * @return + */ + private[bigdl] def parse(graphProtoTxt: String) : List[NodeDef] = { + val f = new java.io.File(graphProtoTxt) + require(f.exists(), graphProtoTxt + " does not exists") + + val reader = CodedInputStream.newInstance(new DataInputStream(new FileInputStream(f))) + reader.setSizeLimit(0x7fffffff) + + val graph = GraphDef.parseFrom(reader) + graph.getNodeList + } + + /** + * Build tf ops graph from a given node list + * @param nodes + * @param outputNodeNames + * @return + */ + private[bigdl] def buildTFGraph(nodes : List[NodeDef], outputNodeNames: Seq[String]) + : DirectedGraph[NodeDef] = { + import scala.collection.JavaConverters._ + var name2Node = nodes.asScala.map(n => n.getName -> new Node(n)).toMap + + // Process node with multiple tensor output, each tensor is regarded as a node + nodes.asScala + .flatMap(_.getInputList.asScala) + .filter(_.split(TENSOR_SEPARATOR).length > 1) + .foreach { nameWithChannel => + val name = nameWithChannel.split(TENSOR_SEPARATOR).head + val tfNode = NodeDef.newBuilder(name2Node(name).element) + .setName(nameWithChannel).build() + name2Node += nameWithChannel -> new Node(tfNode) + } + + // Connect nodes + name2Node.valuesIterator.foreach(n => { + n.element.getInputList.asScala.foreach{ + input => + // It is tricky here, remove the first char in the name of control dep node + val name = if (input.charAt(0) == '^') input.substring(1) else input + name2Node(name) -> n + } + }) + + // Build graph + val outputNodes = if (outputNodeNames == null) { + name2Node.valuesIterator.filter(_.nextNodes.length == 0).toArray + } else { + val results = name2Node.valuesIterator.toArray.filter(n => + outputNodeNames.contains(n.element.getName)) + require(results.length == outputNodeNames.length, "Invalid outputNode names") + results + } + + val dummyOutput = new Node[NodeDef](null) + outputNodes.foreach(_ -> 
dummyOutput) + dummyOutput.graph(reverse = true) + } + + private[bigdl] def buildBigDLModel[T: ClassTag]( + tfGraph: DirectedGraph[NodeDef], + inputs: Seq[String], + outputs: Seq[String], + byteOrder: ByteOrder, + ctx: Option[Context[T]] = None + )(implicit ev: TensorNumeric[T]): Module[T] = { + import scala.collection.JavaConverters._ + + // Map from tensorflow node to the converted BigDL node + val convertedNode = new mutable.HashMap[Node[NodeDef], + Node[AbstractModule[Activity, Tensor[T], T]]]() + val nameToNode = + new mutable.HashMap[String, Node[AbstractModule[Activity, Tensor[T], T]]]() + val context = ctx.getOrElse(new mutable.HashMap[NodeDef, (Tensor[T], Tensor[T])]) + + // BFS to keep the input order same + tfGraph.BFS.foreach(n => { + if (n.element == null) { + // Dummy node, skip + } else if (convertedNode.get(n).isDefined) { + // converted node, skip + } else { + val (module, nodes, inputNodes) = + extract[T](n.graph(reverse = true), context, byteOrder).getOrElse( + throw new UnsupportedOperationException(s"Can not find matched graph \n${n}\n\n" + + s"Its inputs are\n ${n.prevNodes.mkString("\n")}") + ) + + val node = new Node(module) + nodes.asScala.foreach(m => { + convertedNode(m) = node + nameToNode(m.element.getName) = node + }) + + // These two pieces of code are all necessary + val nextNodes = n.nextNodes.filter( + n => n.element != null && convertedNode.contains(n) && !context.contains(n.element) + ).map(convertedNode(_)).filter(_ != node).toSet + nextNodes.foreach(node -> _) + + val preNodes = inputNodes.flatMap(_.prevNodes) + .filter(n => n.element != null && convertedNode.contains(n) + && !context.contains(n.element)) + .map(convertedNode(_)).filter(_ != node).toSet + preNodes.foreach(_ -> node) + } + }) + + val inputNodes = inputs + .map(n => nameToNode.getOrElse(n, throw new IllegalArgumentException(s"Can't find node $n"))) + val outputNodes = outputs + .map(n => nameToNode.getOrElse(n, throw new IllegalArgumentException(s"Can't find node $n"))) + + + val weights = ArrayBuffer[Tensor[T]]() + val gradients = ArrayBuffer[Tensor[T]]() + for ((weight, grad) <- context.values) { + weights += weight + gradients += grad + } + + Graph(inputNodes.toArray, outputNodes.toArray, Some((weights.toArray, gradients.toArray))) + } + + /** + * Extract one module and the corresponding node list from the given graph + * @param graph + * @return + */ + private[bigdl] def extract[T: ClassTag](graph: DirectedGraph[NodeDef], + context: Context[T], byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): Option[( + AbstractModule[Activity, Tensor[T], T], + List[Node[NodeDef]], + Seq[Node[NodeDef]] + )] = { + + var i = 0 + while(i < patterns.length) { + val (result, inputs) = matchGraph(graph, patterns(i).topology) + if (result.size != 0) { + // get model + return Some(patterns(i).layer(graph, context, byteOrder), result, inputs) + } + i += 1 + } + None + } + + private def matchGraph(graph: DirectedGraph[NodeDef], pattern: DirectedGraph[String]) + : (List[Node[NodeDef]], Seq[Node[NodeDef]]) = { + require(graph.reverse && pattern.reverse, "Must pass in reversed graph") + val patternToGraph = new mutable.HashMap[Node[String], Node[NodeDef]]() + val inputs = new ArrayBuffer[Node[NodeDef]]() + patternToGraph(pattern.source) = graph.source + + pattern.BFS.foreach(patternNode => { + if (patternNode.element != N_INPUT_PLACEHOLDER && patternNode.element != INPUT_PLACEHOLDER) { + // Normal operation node + if (patternToGraph.get(patternNode).isEmpty) return (util.Collections.emptyList(), Seq()) + 
+ val graphNode = patternToGraph(patternNode) + // Operation type should match + if (patternNode.element != graphNode.element.getOp) return ( + util.Collections.emptyList(), Seq()) + + // Prev nodes number should be same except for the Ninput case + if (patternNode.prevNodes.length != graphNode.prevNodes.length && + patternNode.prevNodes.filter(_.element == N_INPUT_PLACEHOLDER).length == 0) { + return (util.Collections.emptyList(), Seq()) + } + + var i = 0 + var direction = 0 + var j = 0 + while (i < patternNode.prevNodes.length) { + if (patternNode.prevNodes(i).element == N_INPUT_PLACEHOLDER) { + require(patternNode.prevNodes.count(_.element == N_INPUT_PLACEHOLDER) == 1, + s"only support one $N_INPUT_PLACEHOLDER ") + direction = 1 + // skip the left input nodes of graphNode, + // once we find the placeholder, we start from another side + if (!inputs.contains(graphNode)) { + inputs.append(graphNode) + } + } else if (patternNode.prevNodes(i).element == INPUT_PLACEHOLDER) { + // skip input placeholder + if (!inputs.contains(graphNode)) { + inputs.append(graphNode) + } + } else { + val posPattern = { if (direction == 0) i else patternNode.prevNodes.length - 1 - j} + val posGraph = { if (direction == 0) i else graphNode.prevNodes.length - 1 - j} + val pn = patternNode.prevNodes(posPattern) + val gn = graphNode.prevNodes(posGraph) + if (patternToGraph.contains(pn)) { + if (!patternToGraph(pn).eq(gn)) return (util.Collections.emptyList(), Seq()) + } else { + patternToGraph(pn) = gn + } + if (direction == 1) j += 1 + } + i += 1 + } + } + }) + import scala.collection.JavaConverters._ + return (patternToGraph.valuesIterator.toList.asJava, inputs) + } +} diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaver.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaver.scala new file mode 100644 index 00000000000..96b1ae27330 --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaver.scala @@ -0,0 +1,152 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.utils.tf + +import java.io.FileOutputStream +import java.nio.ByteOrder + +import com.google.protobuf.CodedOutputStream +import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.tensor.Tensor +import org.apache.log4j.Logger +import org.tensorflow.framework._ + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import com.intel.analytics.bigdl.utils.tf.Tensorflow._ + +object TensorflowSaver { + /** + * Save a graph model to protobuf files so that it can be used in tensorflow inference. + * + * When save the model, placeholders will be added to the tf model as input nodes. So you need to + * pass in the names and shape for the placeholders. BigDL model doesn't have such information. 
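The two-direction index walk in matchGraph above is easiest to see in isolation: inputs are matched left to right until the "..." placeholder is met, after which the remaining pattern inputs are matched against the graph node's inputs from the right-hand end. A plain-Scala sketch of just that bookkeeping (the helper is hypothetical, not part of the loader):

def alignInputs(pattern: Seq[String], graph: Seq[String]): List[(String, String)] = {
  val out = scala.collection.mutable.ArrayBuffer[(String, String)]()
  var direction = 0 // 0: match from the left, 1: match from the right
  var i = 0
  var j = 0
  while (i < pattern.length) {
    if (pattern(i) == "...") {
      direction = 1
    } else {
      val p = if (direction == 0) i else pattern.length - 1 - j
      val g = if (direction == 0) i else graph.length - 1 - j
      out += ((pattern(p), graph(g)))
      if (direction == 1) j += 1
    }
    i += 1
  }
  out.toList
}

// The leading Const matches the first graph input, the trailing Const the last one,
// and the two Relu inputs in the middle are absorbed by "...".
assert(alignInputs(Seq("Const", "...", "Const"), Seq("Const", "Relu", "Relu", "Const"))
  == List(("Const", "Const"), ("Const", "Const")))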
+ * The order of the placeholde information should be same as the inputs of the graph model + * + * @param model graph model instance + * @param inputs input node defs + * @param path where to save + * @param byteOrder model byte order + * @tparam T + */ + def saveGraphWitNodeDef[T]( + model : Graph[T], + inputs : Seq[NodeDef], + path: String, + byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN, + extraNodes: Set[NodeDef] = Set()): Unit = { + val inputNodeCache = + new mutable.HashMap[AbstractModule[Activity, Tensor[T], T], ArrayBuffer[NodeDef]]() + model.inputs.zip(inputs).foreach(n => { + inputNodeCache(n._1.element) = ArrayBuffer(n._2) + println() + }) + + val graphBuilder = GraphDef.newBuilder() + inputs.foreach(graphBuilder.addNode(_)) + + model.executions.foreach(n => { + val nodeDefs = maps(n.element.getClass.getName).toTFDef(n.element, inputNodeCache(n.element), + byteOrder) + nodeDefs.foreach(nDef => { + graphBuilder.addNode(nDef) + }) + n.nextNodes.foreach(n => { + val list = inputNodeCache.getOrElse(n.element, ArrayBuffer()) + list.append(nodeDefs(0)) + inputNodeCache(n.element) = list + }) + }) + + extraNodes.foreach(graphBuilder.addNode(_)) + + // Save to file + val os = new FileOutputStream(path) + val output = CodedOutputStream.newInstance(os) + val graph = graphBuilder.build() + logger.info("Graph definition is:") + logger.info(graph.toString) + graph.writeTo(output) + output.flush() + os.close() + logger.info(s"Save as tensorflow model file to $path") + } + + /** + * Save a graph model to protobuf files so that it can be used in tensorflow inference. + * + * When save the model, placeholders will be added to the tf model as input nodes. So you need to + * pass in the names and shape for the placeholders. BigDL model doesn't have such information. + * The order of the placeholde information should be same as the inputs of the graph model + * + * @param model graph model instance + * @param inputs placeholder information + * @param path where to save + * @param byteOrder model byte order + * @param dataFormat model data format + * @tparam T + */ + def saveGraph[T]( + model : Graph[T], + inputs : Seq[(String, Seq[Int])], + path: String, + byteOrder: ByteOrder = ByteOrder.LITTLE_ENDIAN, + dataFormat: TensorflowDataFormat = TensorflowDataFormat.NHWC): Unit = { + val inputNodeDefs = inputs.map(input => + placeholder(model.getNumericType(), input._2, input._1) + ) + saveGraphWitNodeDef(model, inputNodeDefs, path, byteOrder) + } + + /** + * Register a customized BigDL module saver. 
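Put together, exporting a BigDL graph with the saveGraph method above is a single call once the placeholder names and shapes are known. A hedged usage sketch (it assumes a Graph[Float] built elsewhere; the placeholder shape and output path are illustrative):

import java.nio.ByteOrder
import com.intel.analytics.bigdl.nn.Graph
import com.intel.analytics.bigdl.utils.tf.TensorflowSaver

def export(model: Graph[Float]): Unit =
  TensorflowSaver.saveGraph[Float](
    model,
    inputs = Seq(("input", Seq(4, 3, 224, 224))), // placeholder name and shape, in graph-input order
    path = "/tmp/bigdl_model.pb",
    byteOrder = ByteOrder.LITTLE_ENDIAN)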
+ * @param className class name of the BigDL module + * @param saver customized saver + */ + def register(className : String, saver: BigDLToTensorflow): Unit = { + maps(className) = saver + } + + private val logger = Logger.getLogger(getClass) + + private val maps = mutable.Map[String, BigDLToTensorflow]( + getNameFromObj(ReLU.getClass.getName) -> ReLUToTF, + getNameFromObj(Linear.getClass.getName) -> LinearToTF, + getNameFromObj(SpatialConvolution.getClass.getName) -> SpatialConvolutionToTF, + getNameFromObj(Squeeze.getClass.getName) -> SqueezeToTF, + getNameFromObj(Tanh.getClass.getName) -> TanhToTF, + getNameFromObj(Reshape.getClass.getName) -> ReshapeToTF, + getNameFromObj(View.getClass.getName) -> ViewToTF, + getNameFromObj(SpatialMaxPooling.getClass.getName) -> MaxpoolToTF, + getNameFromObj(Padding.getClass.getName) -> PaddingToTF, + getNameFromObj(SpatialAveragePooling.getClass.getName) -> AvgpoolToTF, + getNameFromObj(Sigmoid.getClass.getName) -> SigmoidToTF, + getNameFromObj(Dropout.getClass.getName) -> DropoutToTF, + getNameFromObj(CAddTable.getClass.getName) -> CAddTableToTF, + getNameFromObj(CMulTable.getClass.getName) -> CMultTableToTF, + getNameFromObj(JoinTable.getClass.getName) -> JoinTableToTF, + getNameFromObj(Mean.getClass.getName) -> MeanToTF, + getNameFromObj(SoftMax.getClass.getName) -> SoftMaxToTF, + getNameFromObj(LogSoftMax.getClass.getName) -> LogSoftMaxToTF, + getNameFromObj(SpatialBatchNormalization.getClass.getName) -> BatchNorm2DToTF, + getNameFromObj(Input.getClass.getName) -> InputToTF, + getNameFromObj(Sigmoid.getClass.getName) -> SigmoidToTF + ) + + private def getNameFromObj(name: String) : String = name.substring(0, name.length - 1) +} + diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowToBigDL.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowToBigDL.scala new file mode 100644 index 00000000000..6605bdfd0b8 --- /dev/null +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowToBigDL.scala @@ -0,0 +1,1151 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
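The registry above is keyed by the class name of each layer's companion object with the trailing '$' stripped by getNameFromObj, so that the keys line up with n.element.getClass.getName used during lookup. A quick illustration:

import com.intel.analytics.bigdl.nn.ReLU

val objName = ReLU.getClass.getName // "com.intel.analytics.bigdl.nn.ReLU$"
val key = objName.substring(0, objName.length - 1)
assert(key == "com.intel.analytics.bigdl.nn.ReLU")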
+ */ +package com.intel.analytics.bigdl.utils.tf + +import java.nio.charset.Charset +import java.nio.{ByteBuffer, ByteOrder} + +import collection.JavaConverters._ +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.tensor.{Storage, Tensor} +import org.tensorflow.framework.{DataType, NodeDef, TensorProto} +import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity} +import com.intel.analytics.bigdl.nn.tf._ +import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric +import com.intel.analytics.bigdl.utils.{DirectedGraph, Node, T} +import com.intel.analytics.bigdl.utils.tf.TensorflowLoader.Context +import com.intel.analytics.bigdl.utils.tf.TensorflowToBigDL._ + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.{ClassTag, classTag} + +/** + * Represent a mapping from tensorflow operations graph to BigDL Module + */ +trait TensorflowToBigDL { + + /** + * The topology of the tensorflow operation graph + * @return + */ + def topology: DirectedGraph[String] + + /** + * Get the BigDL model + * @param tfGraph operation graph + * @param context variables + * @return (module, input nodes, output nodes) + */ + def layer[T: ClassTag]( + tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder + )(implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] + + protected def getOrSetTensor[T: ClassTag]( + node: NodeDef, context: Context[T], byteOrder: ByteOrder)(f: Tensor[T] => Tensor[T])( + implicit ev: TensorNumeric[T]): (Tensor[T], Tensor[T]) = { + + if (context.contains(node)) { + context(node) + } else { + val weight = f(toTensor[T](node.getAttrMap.get("value").getTensor, byteOrder)).contiguous() + val gradient = Tensor[T](weight.size()) + context.put(node, (weight, gradient)) + (weight, gradient) + } + } +} + +object TensorflowToBigDL { + + /** + * Represent one input + */ + val INPUT_PLACEHOLDER: String = "*" + + /** + * Represent one or many inputs. Note this can only be the first or the last of the input names + */ + val N_INPUT_PLACEHOLDER: String = "..." + + /** + * Separate operation name and its output tensor. In tensorflow, if one operation output multiple + * tensors, the tensor will be referred as Op:n, which n is a integer. + */ + val TENSOR_SEPARATOR: String = ":" + + /** + * Get the pattern list. + * @return + */ + def patterns: Array[TensorflowToBigDL] = { + patternList.toArray + } + + /** + * Register a new mapping from tensor flow operations to BigDL layer. The mapping is defined as + * a subclass of TFToBigDL, which defines an operation topology(reversed graph) and how to get + * constructor parameters from the topology. + * @param pattern + */ + def registerPattern(pattern: TensorflowToBigDL): Unit = { + require(pattern.topology.reverse == true, "the topology should be a reversed graph") + patternList.append(pattern) + sortPattern() + } + + var dataFormat: String = "NHWC" + + def dataNCHW: Unit = dataFormat = "NCHW" + + /** + * Convert a tensorflow tensor proto to BigDL tensor + * @param tfTensor + * @return + */ + private[utils] def toTensor[T: ClassTag](tfTensor: TensorProto, endian: ByteOrder)( + implicit ev: TensorNumeric[T]): Tensor[T] = { + + require( + tfTensor.getDtype == DataType.DT_FLOAT || + tfTensor.getDtype == DataType.DT_DOUBLE || + tfTensor.getDtype == DataType.DT_INT32, + s"Data type ${tfTensor.getDtype} is not supported now") + + val shape = tfTensor.getTensorShape.getDimList.asScala.map(_.getSize.toInt).toArray + + /** + * When there's one element in the tensor. 
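A new mapping is described by a reversed topology built with the Node DSL and a layer factory, then registered through registerPattern. The topology part alone looks like this (a hedged sketch matching a BiasAdd fed by a MatMul and a constant bias; a full converter would also implement layer):

import com.intel.analytics.bigdl.utils.{DirectedGraph, Node}

// "*" matches exactly one arbitrary input and "..." matches one or more. Topologies
// are stored reversed, from the output back towards the inputs, as registerPattern requires.
val biasAdd = Node("BiasAdd")
Node("*") -> Node("MatMul") -> biasAdd
Node("Const") -> Node("Identity") -> biasAdd
val topology: DirectedGraph[String] = biasAdd.graph(reverse = true)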
You cannot get the value from byte string + */ + if (shape.product == 1) { + if (classTag[T] == classTag[Float]) { + if (tfTensor.getDtype == DataType.DT_FLOAT) { + return Tensor[Float](T(tfTensor.getFloatVal(0))).asInstanceOf[Tensor[T]] + } + + if (tfTensor.getDtype == DataType.DT_INT32) { + return Tensor[Float](T(tfTensor.getIntVal(0).toFloat)).asInstanceOf[Tensor[T]] + } + + throw new IllegalArgumentException("Can not convert double to float") + } else if (classTag[T] == classTag[Double]) { + if (tfTensor.getDtype == DataType.DT_DOUBLE) { + return Tensor[Float](T(tfTensor.getDoubleVal(0))).asInstanceOf[Tensor[T]] + } + + if (tfTensor.getDtype == DataType.DT_FLOAT) { + return Tensor[Float](T(tfTensor.getFloatVal(0).toDouble)).asInstanceOf[Tensor[T]] + } + + if (tfTensor.getDtype == DataType.DT_INT32) { + return Tensor[Float](T(tfTensor.getIntVal(0).toDouble)).asInstanceOf[Tensor[T]] + } + } + } + + val buffer = ByteBuffer.wrap(tfTensor.getTensorContent.toByteArray) + buffer.order(endian) + + if (classTag[T] == classTag[Float]) { + if (tfTensor.getDtype == DataType.DT_FLOAT) { + val params = buffer.asFloatBuffer + val tmp = new Array[Float](params.capacity()) + var j = 0 + while (j < params.capacity()) { + tmp(j) = params.get(j) + j += 1 + } + Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]] + } else if (tfTensor.getDtype == DataType.DT_INT32) { + val params = buffer.asIntBuffer + val tmp = new Array[Float](params.capacity()) + var j = 0 + while (j < params.capacity()) { + tmp(j) = params.get(j) + j += 1 + } + Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]] + } else { + throw new IllegalArgumentException("Can not convert double to float") + } + } else if (classTag[T] == classTag[Double]) { + if (tfTensor.getDtype == DataType.DT_FLOAT) { + val params = buffer.asFloatBuffer + val tmp = new Array[Double](params.capacity()) + var j = 0 + while (j < params.capacity()) { + tmp(j) = params.get(j) + j += 1 + } + Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]] + } else if (tfTensor.getDtype == DataType.DT_INT32) { + val params = buffer.asIntBuffer + val tmp = new Array[Double](params.capacity()) + var j = 0 + while (j < params.capacity()) { + tmp(j) = params.get(j) + j += 1 + } + Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]] + } else if (tfTensor.getDtype == DataType.DT_DOUBLE) { + val params = buffer.asDoubleBuffer() + val tmp = new Array[Double](params.capacity()) + var j = 0 + while (j < params.capacity()) { + tmp(j) = params.get(j) + j += 1 + } + Tensor(Storage(tmp), 1, shape).asInstanceOf[Tensor[T]] + } else { + throw new IllegalArgumentException(s"Data type ${tfTensor.getDtype} is not supported now") + } + } else { + throw new IllegalArgumentException("Only support Float/Double") + } + } + + private var patternList : ArrayBuffer[TensorflowToBigDL] = { + val res = new ArrayBuffer[TensorflowToBigDL]() + // ElementWiseMulTF must be after MulTF + res.append( + FullConnectionTF, DropoutTF, AvgPoolingTF, MaxPoolingTF, ReshapeTF, InputTF, + TanhTF, ReluTF, SigmoidTF, Conv2D, Placeholder, SqueezeTF, IdentityTF, ConcatTF, + BatchNormTF, AddConstTF1, AddConstTF2, AddTF, SoftMaxTF, ElementWiseMulTF, MulTF, + SplitTF, PaddingTF, MeanTF, UnpackTF, StrideSliceTF, ShapeTF, FillTF, PackTF, ConstTF, + Flatten + ) + res + } + + sortPattern() + + /** + * Sort the pattern list to make sure the graph match first should not be a sub-graph of the graph + * match later + */ + private def sortPattern() : Unit = { + // do not calculate size and edges of a graph every time + val 
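The raw tensor_content bytes carry no endianness marker of their own, which is why the caller's ByteOrder is threaded all the way into toTensor above. The effect is easy to demonstrate with plain java.nio:

import java.nio.{ByteBuffer, ByteOrder}

val bytes = Array[Byte](0, 0, -128, 63) // the IEEE-754 bits of 1.0f, little-endian
val asLittle = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getFloat // 1.0f
val asBig = ByteBuffer.wrap(bytes).order(ByteOrder.BIG_ENDIAN).getFloat // a denormal near 4.6e-41f
assert(asLittle == 1.0f && asLittle != asBig)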
topToNNodes = patternList.map(g => { + val nodeSize = g.topology.BFS.count(n => + n.element != INPUT_PLACEHOLDER && n.element != N_INPUT_PLACEHOLDER) + g -> nodeSize + }).toMap + + val topToNEdges = patternList.map(g => { + val edgeSize = g.topology.BFS.filter(n => + n.element != INPUT_PLACEHOLDER && n.element != N_INPUT_PLACEHOLDER) + .map(_.nextNodes.length).reduce(_ + _) + g -> edgeSize + }).toMap + + patternList = patternList.sortWith((l, r) => { + if (topToNNodes(l) != topToNNodes(r)) { + // graph with more nodes comes first + topToNNodes(l) > topToNNodes(r) + } else { + // same node number, graph with more edges come first + topToNEdges(l) > topToNEdges(r) + } + }) + } + + /** + * This method is just for test purpose. Do not use the bigdl.saveNHWC for real use case + * @return + */ + private[tf] def processDims(dim: Int): Int = { + if (System.getProperty("bigdl.enableNHWC", "false").toBoolean) { + // exchange the dims as BigDL only support NCHW now + if (dim == 1) return 2 + if (dim == 2) return 3 + if (dim == 3) return 1 + dim + } else { + dim + } + } +} + +object FullConnectionTF extends TensorflowToBigDL{ + private val graph = { + val add = Node("BiasAdd") + val mul = Node("MatMul") + Node("*") -> mul + Node("Const") -> Node("Identity") -> mul -> add + Node("Const") -> Node("Identity") -> add + add.graph(reverse = true) + } + override def topology: DirectedGraph[String] = graph + + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + + val biasNode = tfGraph.source.prevNodes(1).prevNodes.head.element + val weightNode = tfGraph.source.prevNodes.head.prevNodes(1).prevNodes.head.element + val (bias, gradBias) = getOrSetTensor(biasNode, context, byteOrder)(t => t) + val (weight, gradWeight) = getOrSetTensor(weightNode, context, byteOrder) { t => + t.transpose(1, 2) + } + + Linear[T](inputSize = weight.size(2), outputSize = weight.size(1), + initWeight = weight, initGradWeight = gradWeight, initBias = bias, initGradBias = gradBias) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object SqueezeTF extends TensorflowToBigDL { + private val graph = (Node("*") -> Node("Squeeze")).graph(reverse = true) + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val dims = tfGraph.source.element.getAttrOrThrow("squeeze_dims").getList().getIList() + .asScala.map(_.toInt).toArray.map(processDims(_)) + + Squeeze[T](dims, batchMode = true).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object Conv2D extends TensorflowToBigDL{ + private val graph = { + val add = Node("BiasAdd") + val conv = Node("Conv2D") + + Node("*") -> conv + Node("Const") -> Node("Identity") -> conv -> add + Node("Const") -> Node("Identity") -> add + add.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + + val attributes = tfGraph.source.prevNodes(0).element.getAttrMap + require(attributes.get("strides").getList.getI(0).toInt == 1, s"not support strides on batch") + + val (strideH, strideW) = if (attributes.get("data_format").getS + 
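processDims above is the NHWC-to-NCHW axis shuffle that is only active when -Dbigdl.enableNHWC=true: TensorFlow's H, W, C axes (1, 2, 3 after the batch axis) become BigDL's C, H, W, so 1 maps to 2, 2 to 3 and 3 to 1. A one-line check of that permutation:

val tfToBigdlAxis = Map(1 -> 2, 2 -> 3, 3 -> 1) // H -> 2, W -> 3, C -> 1
assert(Seq(1, 2, 3).map(tfToBigdlAxis) == Seq(2, 3, 1))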
.toString(Charset.defaultCharset()) == "NHWC") { + require(System.getProperty("bigdl.enableNHWC", "false").toBoolean, "Not support NHWC") + require(attributes.get("strides").getList.getI(3).toInt == 1, s"not support strides on depth") + (attributes.get("strides").getList.getI(1).toInt, + attributes.get("strides").getList.getI(2).toInt) + } else if (attributes.get("data_format").getS.toString(Charset.defaultCharset()) == "NCHW") { + require(attributes.get("strides").getList.getI(2).toInt == 1, s"not support strides on depth") + (attributes.get("strides").getList.getI(2).toInt, + attributes.get("strides").getList.getI(3).toInt) + } else { + throw new IllegalArgumentException("no supported data format") + } + val biasNode = tfGraph.source.prevNodes(1).prevNodes.head.element + val (bias, gradBias) = getOrSetTensor(biasNode, context, byteOrder)(t => t) + + val weightNode = tfGraph.source.prevNodes.head.prevNodes(1).prevNodes.head.element + val (weights, gradWeights) = getOrSetTensor(weightNode, context, byteOrder) { t => + t.transpose(1, 4).transpose(2, 3).transpose(3, 4) + } + + val nOuputPlane = weights.size(1) + val nInputPlane = weights.size(2) + val kernelH = weights.size(3) + val kernelW = weights.size(4) + + val (pW, pH) = + if (attributes.get("padding").getS.toString(Charset.defaultCharset()) == "SAME") { + require((kernelW - strideW) % 2 == 0) + require((kernelH - strideH) % 2 == 0) + ((kernelW - strideW) / 2, (kernelH - strideH) / 2) + } else { + (0, 0) + } + + SpatialConvolution[T]( + nInputPlane = nInputPlane, nOutputPlane = nOuputPlane, + kernelW = kernelW, kernelH = kernelH, + strideW = strideW, strideH = strideH, + padW = pW, padH = pH, + initWeight = weights, + initBias = bias, + initGradWeight = gradWeights, + initGradBias = gradBias).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object ReluTF extends TensorflowToBigDL { + private val graph = { + (Node("*") -> Node("Relu")).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + ReLU[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object TanhTF extends TensorflowToBigDL{ + private val graph = { + (Node("*") -> Node("Tanh")).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + + Tanh[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object SigmoidTF extends TensorflowToBigDL{ + private val graph = { + (Node("*") -> Node("Sigmoid")).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + Sigmoid[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object ReshapeTF extends TensorflowToBigDL { + private val graph = { + val nodeReshape = Node("Reshape") + Node("*") -> nodeReshape + Node("Const") -> nodeReshape + nodeReshape.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + 
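The chained transposes in Conv2D above turn the TensorFlow kernel layout [kH, kW, inChannels, outChannels] into the [outChannels, inChannels, kH, kW] layout that SpatialConvolution expects. A plain-Scala check of the permutation on axis labels:

def swapAxes(dims: Vector[String], a: Int, b: Int): Vector[String] =
  dims.updated(a - 1, dims(b - 1)).updated(b - 1, dims(a - 1)) // 1-based, like Tensor.transpose

val tfKernel = Vector("kH", "kW", "in", "out")
val bigdlKernel = swapAxes(swapAxes(swapAxes(tfKernel, 1, 4), 2, 3), 3, 4)
assert(bigdlKernel == Vector("out", "in", "kH", "kW"))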
byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val sizes = TensorflowToBigDL.toTensor( + tfGraph.source.prevNodes(1).element.getAttrMap.get("value").getTensor, byteOrder) + + val batchMode = sizes.valueAt(1) == -1 + val arraySize = new Array[Int](if (batchMode) sizes.nElement() - 1 else sizes.nElement()) + var i = if (batchMode) 2 else 1 + var k = 0 + while(i <= sizes.nElement()) { + arraySize(k) = ev.toType[Int](sizes.valueAt(i)) + k += 1 + i += 1 + } + Reshape[T](size = arraySize, Some(batchMode)) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object MaxPoolingTF extends TensorflowToBigDL { + private val graph = { + (Node("*") -> Node("MaxPool")).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val attributes = tfGraph.source.element.getAttrMap + + val (strideH, strideW, ksizeH, ksizeW) = if (attributes.get("data_format").getS + .toString(Charset.defaultCharset()) == "NHWC") { + require(System.getProperty("bigdl.enableNHWC", "false").toBoolean, "Not support NHWC") + require(attributes.get("strides").getList.getI(3).toInt == 1, s"not support strides on depth") + ( + attributes.get("strides").getList.getI(1).toInt, + attributes.get("strides").getList.getI(2).toInt, + attributes.get("ksize").getList.getI(1).toInt, + attributes.get("ksize").getList.getI(2).toInt + ) + } else if (attributes.get("data_format").getS.toString(Charset.defaultCharset()) == "NCHW") { + require(attributes.get("strides").getList.getI(2).toInt == 1, s"not support strides on depth") + ( + attributes.get("strides").getList.getI(2).toInt, + attributes.get("strides").getList.getI(3).toInt, + attributes.get("ksize").getList.getI(2).toInt, + attributes.get("ksize").getList.getI(3).toInt + ) + } else { + throw new IllegalArgumentException("no supported data format") + } + + val (pW, pH) = + if (attributes.get("padding").getS.toString(Charset.defaultCharset()) == "SAME") { + require((ksizeW - strideW) % 2 == 0) + require((ksizeH - strideH) % 2 == 0) + ((ksizeW - strideW) / 2, (ksizeH - strideH) / 2) + } else { + (0, 0) + } + + SpatialMaxPooling[T](ksizeW, ksizeH, strideW, strideH, pW, pH) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object AvgPoolingTF extends TensorflowToBigDL{ + private val graph = { + (Node("*") -> Node("AvgPool")).graph(reverse = true) + } + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val attributes = tfGraph.source.element.getAttrMap + + val (strideH, strideW, ksizeH, ksizeW) = if (attributes.get("data_format").getS + .toString(Charset.defaultCharset()) == "NHWC") { + require(System.getProperty("bigdl.enableNHWC", "false").toBoolean, "Not support NHWC") + require(attributes.get("strides").getList.getI(3).toInt == 1, s"not support strides on depth") + ( + attributes.get("strides").getList.getI(1).toInt, + attributes.get("strides").getList.getI(2).toInt, + attributes.get("ksize").getList.getI(1).toInt, + attributes.get("ksize").getList.getI(2).toInt + ) + } else if (attributes.get("data_format").getS.toString(Charset.defaultCharset()) == "NCHW") { + 
require(attributes.get("strides").getList.getI(2).toInt == 1, s"not support strides on depth") + ( + attributes.get("strides").getList.getI(2).toInt, + attributes.get("strides").getList.getI(3).toInt, + attributes.get("ksize").getList.getI(2).toInt, + attributes.get("ksize").getList.getI(3).toInt + ) + } else { + throw new IllegalArgumentException("no supported data format") + } + + val (pW, pH) = + if (attributes.get("padding").getS.toString(Charset.defaultCharset()) == "SAME") { + require((ksizeW - strideW) % 2 == 0) + require((ksizeH - strideH) % 2 == 0) + ((ksizeW - strideW) / 2, (ksizeH - strideH) / 2) + } else { + (0, 0) + } + + SpatialAveragePooling[T](ksizeW, ksizeH, strideW, strideH, pW, pH, countIncludePad = false) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object DropoutTF extends TensorflowToBigDL{ + private val graph = { + val nodediv = Node("RealDiv") + val nodeP = Node("Const") + val nodeadd = Node("Add") + val noderandom = Node("Add") + val nodemin = Node("Const") + val nodesub = Node("Sub") + val nodemul = Node("Mul") + val nodedrop = Node("Mul") + Node("*") -> nodediv -> nodedrop + nodeP -> nodediv + nodeP -> nodeadd -> Node("Floor") -> nodedrop + Node("*") -> Node("Shape") -> Node("RandomUniform") -> nodemul -> noderandom -> nodeadd + Node("Const") -> nodesub -> nodemul + nodemin -> nodesub + nodemin -> noderandom + nodedrop.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val keepProp = tfGraph.source.prevNodes(0).prevNodes(1).element + .getAttrMap.get("value").getTensor.getFloatVal(0) + + Dropout[T](keepProp).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object Placeholder extends TensorflowToBigDL { + private val graph = Node("Placeholder").graph(reverse = true) + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + Input[T].element.asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + + +object ConstTF extends TensorflowToBigDL { + private val graph = Node("Const").graph(reverse = true) + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val value = TensorflowToBigDL + .toTensor(tfGraph.source.element.getAttrMap.get("value").getTensor, byteOrder) + Const(value).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object ShapeTF extends TensorflowToBigDL { + private val graph = { + val node = Node("Shape") + Node("*") -> node + node.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + + Shape[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object InputTF extends TensorflowToBigDL { + private val graph = (Node("Const") -> Node("Identity")).graph(reverse = true) + + override def topology: DirectedGraph[String] = graph + + override def layer[T: 
ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + Input[T].element.asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object IdentityTF extends TensorflowToBigDL { + private val graph = (Node("*") -> Node("Identity")).graph(reverse = true) + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + Input[T].element.asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object BatchNormTF extends TensorflowToBigDL{ + private val graph = { + val nodeInput = Node("*") + val nodeMean1 = Node("Mean") + val nodeStopGrad = Node("StopGradient") + val nodeSub1 = Node("Sub") + val nodeSquare = Node("SquaredDifference") + val nodeMeanss = Node("Sum") + val nodeVarss = Node("Sum") + val nodeShape = Node("Reshape") + val nodeDivisor = Node("Reciprocal") + val nodeShiftedMean = Node("Mul") + val nodeMean2 = Node("Add") + val nodeMul1 = Node("Mul") + val nodeVariance = Node("Sub") + val nodeAdd1 = Node("Add") + val nodeMul2 = Node("Mul") + val nodeMul3 = Node("Mul") + val nodeMul4 = Node("Mul") + val nodeSub2 = Node("Sub") + val nodeAdd2 = Node("Add") + + nodeInput -> nodeMul3 -> nodeAdd2 + Node("Const") -> Node("Identity") -> nodeSub2 + nodeInput -> nodeMean1 -> nodeStopGrad -> nodeShape + Node("Const") -> nodeMean1 + nodeInput -> nodeSub1 -> nodeMeanss -> nodeShiftedMean -> nodeMean2 -> nodeMul4 + nodeStopGrad -> nodeSub1 + nodeInput -> nodeSquare -> nodeVarss -> nodeMul1 -> nodeVariance + nodeStopGrad -> nodeSquare + Node("Const") -> nodeDivisor -> nodeShiftedMean -> Node("Square") -> nodeVariance -> nodeAdd1 + Node("Const") -> nodeMeanss -> nodeDivisor -> nodeMul1 + Node("Const") -> nodeVarss -> nodeDivisor + Node("Const") -> nodeAdd1 -> Node("Rsqrt") -> nodeMul2 -> nodeMul3 + Node("Const") -> Node("Identity") -> nodeMul2 -> nodeMul4 -> nodeSub2 -> nodeAdd2 + Node("Const") -> nodeShape -> nodeMean2 + nodeAdd2.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val nOutput = tfGraph.source.prevNodes(1).prevNodes(1).prevNodes(1) + .prevNodes(1).prevNodes(0).element.getAttrMap.get("value").getTensor.getIntVal(0) + + val weightNode = tfGraph.source.prevNodes(1).prevNodes.head.prevNodes.head.element + val biasNode = tfGraph.source.prevNodes(1).prevNodes(1).prevNodes(1) + .prevNodes.head.prevNodes.head.element + val (weights, gradWeights) = getOrSetTensor[T](weightNode, context, byteOrder)(t => t) + val (bias, gradBias) = getOrSetTensor[T](weightNode, context, byteOrder)(t => t) + + SpatialBatchNormalization[T]( + nOutput = nOutput, + initWeight = weights, + initBias = bias, + initGradWeight = gradWeights, + initGradBias = gradBias + ).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object FillTF extends TensorflowToBigDL{ + private val graph = { + val nodeFill = Node("Fill") + Node("*") -> nodeFill + Node("Const") -> nodeFill + nodeFill.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + 
implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val constNode = tfGraph.source.prevNodes(1) + val const = constNode.element.getAttrMap.get("value").getTensor.getFloatVal(0) + + Fill[T](const).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object PackTF extends TensorflowToBigDL{ + private val graph = { + val nodePack = Node("Pack") + Node("...") -> nodePack + nodePack.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + val dim = processDims(tfGraph.source.element.getAttrMap.get("axis").getI.toInt + 1) + + Pack[T](dim).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object UnpackTF extends TensorflowToBigDL{ + private val graph = { + val nodePack = Node("Unpack") + Node("*") -> nodePack + nodePack.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val dim = processDims(tfGraph.source.element.getAttrMap.get("axis").getI.toInt + 1) + val index = tfGraph.source.element.getName.split(":").toList match { + case _::Nil => 1 + case _::i::Nil => i.toInt + 1 + } + Select[T](dim, index).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object StrideSliceTF extends TensorflowToBigDL { + private val graph = { + val nodeSlice = Node("StridedSlice") + Node("*") -> nodeSlice + Node("Const") -> nodeSlice + Node("Const") -> nodeSlice + Node("Const") -> nodeSlice + nodeSlice.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val startNode = tfGraph.source.prevNodes(1) + val endNode = tfGraph.source.prevNodes(2) + val strideNode = tfGraph.source.prevNodes(3) + + def getIntArray(node: Node[NodeDef]) = { + node.element.getAttrMap.get("value").getTensor.getIntValList.asScala.map(_.toInt) + } + + val start = getIntArray(startNode) + val end = getIntArray(endNode) + val stride = getIntArray(strideNode) + + val specs = (start zip end zip stride).zipWithIndex + .map(elem => (elem._2 + 1, elem._1._1._1 + 1, elem._1._1._2 + 1, elem._1._2)).toArray + + + StrideSlice[T](specs).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + + +object ConcatTF extends TensorflowToBigDL{ + private val graph = { + val nodeConcat = Node("ConcatV2") + Node("...") -> nodeConcat + (Node("Const") -> nodeConcat).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val inputNumber = tfGraph.source.element.getAttrMap.get("N").getI.toInt + val nodeaxis = tfGraph.source.prevNodes(inputNumber) + val axis = processDims( + nodeaxis.element.getAttrMap.get("value").getTensor.getIntVal(0)) + val nInputDims = 4 + + JoinTable[T](dimension = axis + 1, nInputDims = -1) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object Flatten extends TensorflowToBigDL 
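UnpackTF above (and SplitTF later in this file) recover which output of a multi-output op a node refers to from the "op:n" names introduced by buildTFGraph; a bare name means output 0, and BigDL's Select and SplitAndSelect address outputs 1-based. In isolation:

def outputIndex(nodeName: String): Int = nodeName.split(":").toList match {
  case _ :: Nil => 1 // no suffix: the first output
  case _ :: i :: Nil => i.toInt + 1 // "op:n": TensorFlow output n, shifted to 1-based
  case other => sys.error(s"unexpected node name $other")
}

assert(outputIndex("Unpack") == 1)
assert(outputIndex("Unpack:2") == 3)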
{ + private val graph = { + val reshapeNode = Node("Reshape") + val concatNode = Node("ConcatV2") + val sliceNode = Node("Slice") + val expandNode = Node("ExpandDims") + val prodNode = Node("Prod") + val sliceNode1 = Node("Slice") + val shapeNode = Node("Const") + val beginNode = Node("Const") + val sizeNode = Node("Const") + val beginNode1 = Node("Const") + val sizeNode1 = Node("Const") + val constNode = Node("Const") + val dimNode = Node("Const") + val axisNode = Node("Const") + + shapeNode -> sliceNode + beginNode -> sliceNode + sizeNode -> sliceNode + + shapeNode -> sliceNode1 + beginNode1 -> sliceNode1 + sizeNode1 -> sliceNode1 + + sliceNode1 -> prodNode + constNode -> prodNode + + prodNode -> expandNode + dimNode -> expandNode + + sliceNode -> concatNode + expandNode -> concatNode + axisNode -> concatNode + + Node("*") -> reshapeNode + concatNode -> reshapeNode + reshapeNode.graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + val shapetfTensor = tfGraph.source.prevNodes(1).prevNodes(0).prevNodes(0).element + .getAttrMap.get("value").getTensor + val sizes = TensorflowToBigDL.toTensor(shapetfTensor, byteOrder) + val batchMode = false + + val arraySize = Array( + ev.toType[Int](sizes.valueAt(1)), + { + var prod = 1 + var i = 2 + while(i <= sizes.nElement()) { + prod = prod * ev.toType[Int](sizes.valueAt(i)) + i = i + 1 + } + prod + } + ) + + Reshape[T](size = arraySize, Some(batchMode)) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object AddConstTF1 extends TensorflowToBigDL{ + private val graph = { + val nodeAdd = Node("Add") + Node("Const") -> nodeAdd + (Node("*") -> nodeAdd).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + val value = tfGraph.source.prevNodes.head.element + .getAttrMap.get("value").getTensor.getFloatVal(0) + AddConstant[T](value).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object AddConstTF2 extends TensorflowToBigDL{ + private val graph = { + val nodeAdd = Node("Add") + Node("*") -> nodeAdd + (Node("Const") -> nodeAdd).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val value = tfGraph.source.prevNodes(1).element + .getAttrMap.get("value").getTensor.getFloatVal(0) + AddConstant[T](value).asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object AddTF extends TensorflowToBigDL{ + private val graph = { + val nodeAdd = Node("Add") + Node("*") -> nodeAdd + (Node("*") -> nodeAdd).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + CAddTable[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object SoftMaxTF extends TensorflowToBigDL{ + private val graph = { + (Node("*") -> 
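The Flatten pattern above collapses everything after the leading dimension into a single axis, so a recorded shape [d1, d2, ..., dn] becomes [d1, d2 * ... * dn]. For example:

val recordedShape = Seq(4, 7, 7, 64) // shape constant captured from the TF graph
val flattened = Array(recordedShape.head, recordedShape.tail.product)
assert(flattened.sameElements(Array(4, 3136)))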
Node("Softmax")).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + SoftMax[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + + +object MulTF extends TensorflowToBigDL{ + private val graph = { + val nodeMul = Node("Mul") + Node("Const") -> nodeMul + (Node("*") -> nodeMul).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val scale = TensorflowToBigDL.toTensor( + tfGraph.source.prevNodes(0).element.getAttrMap.get("value").getTensor, byteOrder) + require(scale.dim() == 1 && scale.size(1) == 1, s"scale must be one number") + val mul = MulConstant[T](ev.toType[Double](scale.valueAt(1))) + mul.asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object ElementWiseMulTF extends TensorflowToBigDL{ + private val graph = { + val nodeMul = Node("Mul") + Node("*") -> nodeMul + (Node("*") -> nodeMul).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + CMulTable[T]().asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} + +object SplitTF extends TensorflowToBigDL { + + private val graph = { + val nodeSplit = Node("Split") + Node("Const") -> nodeSplit + (Node("*") -> nodeSplit).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val numSplit = tfGraph.source.element.getAttrMap.get("num_split").getI.toInt + val dim = tfGraph.source.prevNodes.head.element + .getAttrMap.get("value").getTensor.getIntVal(0) + 1 + val index = tfGraph.source.element.getName.split(":").toList match { + case _::Nil => 1 + case _::i::Nil => i.toInt + 1 + } + SplitAndSelect[T](dim, index, numSplit) + .asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } + +} + + +object PaddingTF extends TensorflowToBigDL{ + private val graph = { + val nodePad = Node("Pad") + Node("*") -> nodePad + (Node("Const") -> nodePad).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val paddings = TensorflowToBigDL.toTensor( + tfGraph.source.prevNodes(1).element.getAttrMap.get("value").getTensor, byteOrder) + val pad = ArrayBuffer[Int]() + val padding = Sequential[T]() + + for(i <- 1 to paddings.size(1)) { + if (paddings.valueAt(i, 1) != 0 || paddings.valueAt(i, 2) != 0 ) { + val dim = processDims(i - 1) + 1 + if (paddings(Array(i, 1)) != 0) { + padding.add(Padding[T](dim, -ev.toType[Int](paddings.valueAt(i, 1)), 4)) + } + if (paddings(Array(i, 2)) != 0) { + padding.add(Padding[T](dim, ev.toType[Int](paddings.valueAt(i, 2)), 4)) + } + } + } + + padding.asInstanceOf[AbstractModule[Activity, 
Tensor[T], T]] + } +} + +object MeanTF extends TensorflowToBigDL{ + private val graph = { + val nodeMean = Node("Mean") + Node("*") -> nodeMean + (Node("Const") -> nodeMean).graph(reverse = true) + } + + override def topology: DirectedGraph[String] = graph + + override def layer[T: ClassTag](tfGraph: DirectedGraph[NodeDef], + context: Context[T], + byteOrder: ByteOrder)( + implicit ev: TensorNumeric[T]): AbstractModule[Activity, Tensor[T], T] = { + + val dims = TensorflowToBigDL.toTensor( + tfGraph.source.prevNodes(1).element.getAttrMap.get("value").getTensor, byteOrder) + val dim = ArrayBuffer[Int]() + val mean = Sequential[T]() + for (i <- 1 to dims.size(1)) { + dim += processDims(ev.toType[Int](dims.valueAt(i))) + 1 + } + dim.foreach(i => mean.add(Mean[T](i, squeeze = false))) + mean.asInstanceOf[AbstractModule[Activity, Tensor[T], T]] + } +} diff --git a/spark/dl/src/test/resources/tf/.gitignore b/spark/dl/src/test/resources/tf/.gitignore new file mode 100644 index 00000000000..9ffaa437c01 --- /dev/null +++ b/spark/dl/src/test/resources/tf/.gitignore @@ -0,0 +1,5 @@ +model/ +freeze_graph.py +log/ +*.pb +!test.pb diff --git a/spark/dl/src/test/resources/tf/models/alexnet.py b/spark/dl/src/test/resources/tf/models/alexnet.py new file mode 100644 index 00000000000..24e228072c8 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/alexnet.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import alexnet +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 224, 224 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name = 'input') + net, end_points = alexnet.alexnet_v2(inputs, is_training=False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1], 'alexnet') + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/inception_resnet_v2.py b/spark/dl/src/test/resources/tf/models/inception_resnet_v2.py new file mode 100644 index 00000000000..13ba6d504d5 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/inception_resnet_v2.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import inception_resnet_v2 +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 299, 299 + inputs = tf.Variable(tf.random_uniform((2, height, width, 3)), name='input') + net, end_points = inception_resnet_v2.inception_resnet_v2(inputs,is_training = False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split(',')) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/inception_v3.py b/spark/dl/src/test/resources/tf/models/inception_v3.py new file mode 100644 index 00000000000..1be0157040b --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/inception_v3.py @@ -0,0 +1,42 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import inception +from sys import argv + +from util import run_model + +slim = tf.contrib.slim + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 299, 299 + num_classes = 1000 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + net, end_points = inception.inception_v3(inputs, num_classes,is_training=False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/lenet.py b/spark/dl/src/test/resources/tf/models/lenet.py new file mode 100644 index 00000000000..f42d15f4d86 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/lenet.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import tensorflow as tf +from nets import lenet +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 32, 32 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + net, end_points = lenet.lenet(inputs) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1], 'LeNet') + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/overfeat.py b/spark/dl/src/test/resources/tf/models/overfeat.py new file mode 100644 index 00000000000..f3cfdd4dc38 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/overfeat.py @@ -0,0 +1,42 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import overfeat +from sys import argv + +from util import run_model + +slim = tf.contrib.slim + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 231, 231 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + with slim.arg_scope(overfeat.overfeat_arg_scope()): + net, end_points = overfeat.overfeat(inputs, is_training = False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/resnet_v1.py b/spark/dl/src/test/resources/tf/models/resnet_v1.py new file mode 100644 index 00000000000..e9eb1e8431a --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/resnet_v1.py @@ -0,0 +1,40 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import resnet_utils +from nets import resnet_v1 +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. 
git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 224, 224 + inputs = tf.Variable(tf.random_uniform((2, height, width, 3)), name='input') + net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=True) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/rnn.py b/spark/dl/src/test/resources/tf/models/rnn.py new file mode 100644 index 00000000000..99f1c7df540 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/rnn.py @@ -0,0 +1,57 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +import numpy as np +from sys import argv +from tensorflow.contrib import rnn +from util import merge_checkpoint + +def main(): + """ + Run this command to generate the pb file + 1. mkdir model + 2. python rnn.py + """ + dir = argv[1] + n_steps = 2 + n_input = 10 + n_hidden = 20 + n_output = 5 + xs = tf.Variable(tf.random_uniform([4, n_steps, n_input]) + 10, name='input', dtype=tf.float32) + weight = tf.Variable(tf.random_uniform([n_hidden, n_output]) + 10, name="weight", dtype=tf.float32) + bias = tf.Variable(tf.random_uniform([n_output]) + 10, name="bias", dtype=tf.float32) + x = tf.unstack(xs, n_steps, 1) + cell = rnn.BasicRNNCell(n_hidden) + output, states = rnn.static_rnn(cell, x, dtype=tf.float32) + final = tf.nn.bias_add(tf.matmul(output[-1], weight), bias, name='output') + output = tf.Variable(tf.random_uniform(tf.shape(final)),name='output_result') + result = tf.assign(output, final) + saver = tf.train.Saver() + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + sess.run(result) + checkpointpath = saver.save(sess, dir + '/model.chkp') + tf.train.write_graph(sess.graph, dir, 'model.pbtxt') + + input_graph = dir + "/model.pbtxt" + input_checkpoint = dir + "/model.chkp" + output_node_names= ["output", "output_result"] + output_graph = dir + "/model.pb" + + merge_checkpoint(input_graph, input_checkpoint, output_node_names, output_graph) +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/rnn_lstm.py b/spark/dl/src/test/resources/tf/models/rnn_lstm.py new file mode 100644 index 00000000000..6f371dac37f --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/rnn_lstm.py @@ -0,0 +1,63 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +import numpy as np +from sys import argv +from tensorflow.contrib import rnn +from util import merge_checkpoint + +def main(): + """ + Run this command to generate the pb file + 1. mkdir model + 2. python rnn_lstm.py + """ + dir = argv[1] + n_steps = 2 + n_input = 10 + n_hidden = 20 + n_output = 5 + # xs = tf.placeholder(tf.float32, [None, n_steps, n_input]) + xs = tf.Variable(tf.random_uniform([4, n_steps, n_input]) + 10, name='input', dtype=tf.float32) + weight = tf.Variable(tf.random_uniform([n_hidden, n_output]) + 10, name="weight", dtype=tf.float32) + bias = tf.Variable(tf.random_uniform([n_output]) + 10, name="bias", dtype=tf.float32) + + x = tf.unstack(xs, n_steps, 1) + + cell = rnn.BasicLSTMCell(n_hidden) + + output, states = rnn.static_rnn(cell, x, dtype=tf.float32) + + final = tf.nn.bias_add(tf.matmul(output[-1], weight), bias, name='output') + + output = tf.Variable(tf.random_uniform(tf.shape(final)),name='output_result') + result = tf.assign(output, final) + saver = tf.train.Saver() + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + sess.run(result) + checkpointpath = saver.save(sess, dir + '/model.chkp') + tf.train.write_graph(sess.graph, dir, 'model.pbtxt') + + input_graph = dir + "/model.pbtxt" + input_checkpoint = dir + "/model.chkp" + output_node_names= ["output", "output_result"] + output_graph = dir + "/model.pb" + + merge_checkpoint(input_graph, input_checkpoint, output_node_names, output_graph) +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/share_weight.py b/spark/dl/src/test/resources/tf/models/share_weight.py new file mode 100644 index 00000000000..cd90d1f942e --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/share_weight.py @@ -0,0 +1,55 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +import numpy as np +from sys import argv +from util import merge_checkpoint + +def main(): + """ + Run this command to generate the pb file + 1. mkdir model + 2. python test.py + 3. wget https://raw.githubusercontent.com/tensorflow/tensorflow/v1.0.0/tensorflow/python/tools/freeze_graph.py + 4. 
python freeze_graph.py --input_graph model/share_weight.pbtxt --input_checkpoint model/share_weight.chkp --output_node_names=output --output_graph "share_weight.pb" + """ + xs = tf.placeholder(tf.float32, [None, 10]) + W1 = tf.Variable(tf.random_normal([10,10])) + b1 = tf.Variable(tf.random_normal([10])) + Wx_plus_b1 = tf.nn.bias_add(tf.matmul(xs,W1), b1) + output= tf.nn.tanh(Wx_plus_b1) + + Wx_plus_b2 = tf.nn.bias_add(tf.matmul(output,W1), b1) + W2 = tf.Variable(tf.random_normal([10, 1])) + b2 = tf.Variable(tf.random_normal([1])) + final = tf.nn.bias_add(tf.matmul(Wx_plus_b2, W2), b2, name='output') + dir = argv[1] + saver = tf.train.Saver() + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + checkpointpath = saver.save(sess, dir + '/model.chkp') + tf.train.write_graph(sess.graph, dir, 'model.pbtxt') + + input_graph = dir + "/model.pbtxt" + input_checkpoint = dir + "/model.chkp" + output_node_names = "output" + output_graph = dir + "/model.pb" + + merge_checkpoint(input_graph, input_checkpoint, [output_node_names], output_graph) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/util.py b/spark/dl/src/test/resources/tf/models/util.py new file mode 100644 index 00000000000..0581d8bda05 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/util.py @@ -0,0 +1,112 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from google.protobuf import text_format + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import importer +from tensorflow.python.platform import gfile +import tensorflow as tf + +def merge_checkpoint(input_graph, + input_checkpoint, + output_node_names, + output_graph): + """ + merge the checkpoint file with the non-binary graph file to + generate one GraphDef file with the variable values + Args: + input_graph: the GraphDef file, not in the binary form + input_checkpoint: the checkpoint file + output_node_names: A string of name of the output names, + use comma to seperate multi outputs + output_graph: String of the location and the name of the + output graph + """ + restore_op_name = "save/restore_all" + filename_tensor_name = "save/Const:0" + + input_graph_def = graph_pb2.GraphDef() + mode = "r" + with gfile.FastGFile(input_graph, mode) as f: + text_format.Merge(f.read().decode("utf-8"), input_graph_def) + for node in input_graph_def.node: + node.device = "" + _ = importer.import_graph_def(input_graph_def, name="") + with session.Session() as sess: + sess.run([restore_op_name], {filename_tensor_name: input_checkpoint}) + output_graph_def = graph_util.convert_variables_to_constants( + sess, + input_graph_def, + output_node_names, + variable_names_blacklist="") + with gfile.GFile(output_graph, "wb") as f: + f.write(output_graph_def.SerializeToString()) + +def run_model(end_points, output_path, model_scope=None): + outputs = [] + results = [] + grad_inputs = [] + grad_vars = [] + grad_results = [] + trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=model_scope) + i = 0 + opt = tf.train.GradientDescentOptimizer(0.01) + for end_point in end_points: + output = tf.Variable(tf.random_uniform(tf.shape(end_point)), name='output' + str(i)) + outputs.append(output) + results.append(tf.assign(output, end_point, name = 'assign' + str(i))) + + # set up backward variables + # filter None tensor + tmp_vars = filter(lambda x: tf.gradients(end_point, x) != [None], trainable_vars) + # set up random gradient input + grad_input = tf.Variable(tf.random_uniform(tf.shape(end_point)), name='grad_input' + str(i)) + grad_inputs.append(grad_input) + # compute gradients with random input + backward = opt.compute_gradients(end_point, var_list=tmp_vars, grad_loss=grad_input) + j = 0 + for gradients, tensor in backward: + grad_var = tf.Variable(tf.random_uniform(tf.shape(tensor)), + name='{}_grad{}'.format(tensor.name[:-2], i)) + grad_vars.append(grad_var) + grad_result = tf.assign(grad_var, gradients, name='grad_assign' + str((i+1)*j)) + grad_results.append(grad_result) + j = j + 1 + i = i + 1 + + saver = tf.train.Saver() + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + sess.run(results) + sess.run(grad_results) + saver.save(sess, output_path + '/model.chkp') + tf.train.write_graph(sess.graph, output_path, 'model.pbtxt') + # tf.summary.FileWriter(output_path + '/log', sess.graph) + + input_graph = output_path + "/model.pbtxt" + input_checkpoint = output_path + "/model.chkp" + output_file = output_path + "/model.pb" + + output_nodes = map(lambda x: 'assign' + str(x), range(len(end_points))) + grades_nodes = map(lambda x: 'grad_assign' + str(x), range(len(grad_results))) + output_nodes.extend(grades_nodes) + + # merge_checkpoint(input_graph, input_checkpoint, map(lambda x: 'assign' + str(x), range(len(end_points))), 
output_file) + merge_checkpoint(input_graph, input_checkpoint, output_nodes, output_file) + diff --git a/spark/dl/src/test/resources/tf/models/vgg16.py b/spark/dl/src/test/resources/tf/models/vgg16.py new file mode 100644 index 00000000000..11cdd940ee0 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/vgg16.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import vgg +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 224, 224 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + net, end_points = vgg.vgg_16(inputs, is_training = False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/vgg19.py b/spark/dl/src/test/resources/tf/models/vgg19.py new file mode 100644 index 00000000000..0b29cda9bc7 --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/vgg19.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import vgg +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 224, 224 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + net, end_points = vgg.vgg_19(inputs, is_training = False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/models/vgga.py b/spark/dl/src/test/resources/tf/models/vgga.py new file mode 100644 index 00000000000..9417740d7cd --- /dev/null +++ b/spark/dl/src/test/resources/tf/models/vgga.py @@ -0,0 +1,39 @@ +# +# Copyright 2016 The BigDL Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +from nets import vgg +from sys import argv + +from util import run_model + +def main(): + """ + You can also run these commands manually to generate the pb file + 1. git clone https://github.com/tensorflow/models.git + 2. export PYTHONPATH=Path_to_your_model_folder + 3. python alexnet.py + """ + height, width = 224, 224 + inputs = tf.Variable(tf.random_uniform((1, height, width, 3)), name='input') + net, end_points = vgg.vgg_a(inputs, is_training = False) + print("nodes in the graph") + for n in end_points: + print(n + " => " + str(end_points[n])) + net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split()) + run_model(net_outputs, argv[1]) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/save_test.py b/spark/dl/src/test/resources/tf/save_test.py new file mode 100644 index 00000000000..d42840692f6 --- /dev/null +++ b/spark/dl/src/test/resources/tf/save_test.py @@ -0,0 +1,43 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +import numpy as np +import os +from tensorflow.python.platform import gfile +from sys import argv + +def main(): + with gfile.FastGFile(argv[1],'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='') + sess = tf.Session() + output_suffix = '' + if len(argv) == 3: + output_suffix = argv[2] + output = graph.get_tensor_by_name('output' + output_suffix + ':0') + target = graph.get_tensor_by_name('target:0') + tf_output = sess.run(output) + bigdl_output = sess.run(target) + print("Tensorflow output is:") + print(tf_output) + print("BigDL output is:") + print(bigdl_output) + np.testing.assert_almost_equal(tf_output, bigdl_output, 4) + +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/resources/tf/test.pb b/spark/dl/src/test/resources/tf/test.pb new file mode 100644 index 00000000000..839ceff1b14 Binary files /dev/null and b/spark/dl/src/test/resources/tf/test.pb differ diff --git a/spark/dl/src/test/resources/tf/test.py b/spark/dl/src/test/resources/tf/test.py new file mode 100644 index 00000000000..c0f882304d0 --- /dev/null +++ b/spark/dl/src/test/resources/tf/test.py @@ -0,0 +1,45 @@ +# +# Copyright 2016 The BigDL Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tensorflow as tf +import numpy as np +import os + +def main(): + """ + Run this command to generate the pb file + 1. mkdir model + 2. python test.py + 3. wget https://raw.githubusercontent.com/tensorflow/tensorflow/v1.0.0/tensorflow/python/tools/freeze_graph.py + 4. python freeze_graph.py --input_graph model/test.pbtxt --input_checkpoint model/test.chkp --output_node_names=output --output_graph "test.pb" + """ + dir = os.path.dirname(os.path.realpath(__file__)) + xs = tf.placeholder(tf.float32, [None, 1]) + W1 = tf.Variable(tf.zeros([1,10])+0.2) + b1 = tf.Variable(tf.zeros([10])+0.1) + Wx_plus_b1 = tf.nn.bias_add(tf.matmul(xs,W1), b1) + output= tf.nn.tanh(Wx_plus_b1) + + W2 = tf.Variable(tf.zeros([10,1])+0.2) + b2 = tf.Variable(tf.zeros([1])+0.1) + Wx_plus_b2 = tf.nn.bias_add(tf.matmul(output,W2), b2, name='output') + saver = tf.train.Saver() + with tf.Session() as sess: + init = tf.global_variables_initializer() + sess.run(init) + checkpointpath = saver.save(sess, dir + '/model/test.chkp') + tf.train.write_graph(sess.graph, dir + '/model', 'test.pbtxt') +if __name__ == "__main__": + main() diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConstSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ConstSpec.scala similarity index 98% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConstSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ConstSpec.scala index 05f19457982..01d0aefc87e 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ConstSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ConstSpec.scala @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf -import org.scalatest.{FlatSpec, Matchers} import com.intel.analytics.bigdl.numeric.NumericFloat import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.utils.T +import org.scalatest.{FlatSpec, Matchers} class ConstSpec extends FlatSpec with Matchers { "Const forward tensor" should "be correct" in { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/FillSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/FillSpec.scala similarity index 97% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/FillSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/FillSpec.scala index 7712460f215..a874fd17fd8 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/FillSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/FillSpec.scala @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf -import org.scalatest.{FlatSpec, Matchers} import com.intel.analytics.bigdl.numeric.NumericFloat import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.utils.T +import org.scalatest.{FlatSpec, Matchers} class FillSpec extends FlatSpec with Matchers { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ShapeSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ShapeSpec.scala similarity index 97% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ShapeSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ShapeSpec.scala index d8bc8d15d37..876beb10745 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/ShapeSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/ShapeSpec.scala @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf -import org.scalatest.{FlatSpec, Matchers} import com.intel.analytics.bigdl.numeric.NumericFloat import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.utils.T +import org.scalatest.{FlatSpec, Matchers} class ShapeSpec extends FlatSpec with Matchers { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SplitAndSelectSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelectSpec.scala similarity index 97% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SplitAndSelectSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelectSpec.scala index dec39acf224..75053cf393f 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/SplitAndSelectSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/SplitAndSelectSpec.scala @@ -13,12 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf -import org.scalatest.{FlatSpec, Matchers} import com.intel.analytics.bigdl.numeric.NumericFloat import com.intel.analytics.bigdl.tensor.Tensor import com.intel.analytics.bigdl.utils.T +import org.scalatest.{FlatSpec, Matchers} class SplitAndSelectSpec extends FlatSpec with Matchers { "SplitAndSelect forward" should "be correct" in { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/StrideSliceSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/StrideSliceSpec.scala similarity index 98% rename from spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/StrideSliceSpec.scala rename to spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/StrideSliceSpec.scala index 8bfc01afd0d..99fad493310 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/StrideSliceSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/tf/StrideSliceSpec.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.intel.analytics.bigdl.nn +package com.intel.analytics.bigdl.nn.tf import com.intel.analytics.bigdl.tensor.Tensor import org.scalatest.{FlatSpec, Matchers} diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoaderSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoaderSpec.scala new file mode 100644 index 00000000000..9cfe7b6db97 --- /dev/null +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowLoaderSpec.scala @@ -0,0 +1,543 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.utils.tf + +import java.io.{File => JFile} +import java.nio.ByteOrder +import java.util.UUID + +import com.intel.analytics.bigdl.dataset.{DistributedDataSet, MiniBatch} +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.optim.{DistriOptimizer, Trigger} +import com.intel.analytics.bigdl.tensor.{Storage, Tensor} +import com.intel.analytics.bigdl.utils._ +import org.apache.log4j.{Level, Logger} +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import com.intel.analytics.bigdl.numeric.NumericFloat +import org.tensorflow.framework.NodeDef + +import scala.collection.mutable +import scala.sys.process._ +import scala.math._ +import scala.reflect.ClassTag + +object TensorflowLoaderSpec { + private val data1 = Array(0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.1f) + private val data2 = Array(0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.1f) + private val input1: Tensor[Float] = Tensor[Float](Storage[Float](data1)) + private val input2: Tensor[Float] = Tensor[Float](Storage[Float](data2)) + private val nodeNumber = 4 + private val coreNumber = 4 + + Engine.init(nodeNumber, coreNumber, true) + + private val batchSize = 2 * coreNumber + + private val prepareData: Int => (MiniBatch[Float]) = index => { + val input = Tensor[Float]().resize(batchSize, 10) + val target = Tensor[Float]().resize(batchSize) + var i = 0 + while (i < batchSize) { + if (i % 2 == 0) { + target.setValue(i + 1, 0.0f) + input.select(1, i + 1).copy(input1) + } else { + target.setValue(i + 1, 0.1f) + input.select(1, i + 1).copy(input2) + } + i += 1 + } + MiniBatch(input, target) + } +} + +class TensorflowLoaderSpec extends TensorflowSpecHelper{ + + Logger.getLogger("org").setLevel(Level.WARN) + Logger.getLogger("akka").setLevel(Level.WARN) + + import TensorflowLoaderSpec._ + + var sc: SparkContext = null + + var dataSet: DistributedDataSet[MiniBatch[Float]] = null + + before { + sc = new SparkContext("local[1]", "RDDOptimizerSpec") + + val rdd = sc.parallelize(1 to (256 * 4), 4).map(prepareData) + + dataSet = new DistributedDataSet[MiniBatch[Float]] { + override def originRDD(): RDD[_] = rdd + + override def data(train : Boolean): RDD[MiniBatch[Float]] = rdd + + override def size(): Long = 256 * nodeNumber + + override def shuffle(): Unit = {} + } + + Engine.model.setPoolSize(1) + + 
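// TensorFlow graphs default to the NHWC data layout
+   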
System.setProperty("bigdl.enableNHWC", "true") + } + + after { + if (sc != null) { + sc.stop() + } + System.setProperty("bigdl.enableNHWC", "false") + } + + "TensorFlow loader" should "read a list of nodes from pb file" in { + val resource = getClass().getClassLoader().getResource("tf") + val path = processPath(resource.getPath()) + JFile.separator + "test.pb" + val results = TensorflowLoader.parse(path) + results.size() should be(14) + } + + "TensorFlow loader" should "be able to build a TF graph" in { + val resource = getClass().getClassLoader().getResource("tf") + val path = processPath(resource.getPath()) + JFile.separator + "test.pb" + val results = TensorflowLoader.parse(path) + val tfGraph = TensorflowLoader.buildTFGraph(results, Seq("output")) + tfGraph.size should be(15) // there's a dummy output + val topSort = tfGraph.topologySort// It can do topology sort + topSort.length should be(15) + topSort(0).element should be(null) + topSort(1).element.getName should be("output") + topSort(2).element.getName should be("MatMul_1") + topSort(3).element.getName should be("Variable_3/read") + topSort(4).element.getName should be("Variable_3") + topSort(5).element.getName should be("Tanh") + topSort(6).element.getName should be("Variable_2/read") + topSort(7).element.getName should be("Variable_2") + topSort(8).element.getName should be("BiasAdd") + topSort(9).element.getName should be("MatMul") + topSort(10).element.getName should be("Variable_1/read") + topSort(11).element.getName should be("Variable_1") + topSort(12).element.getName should be("Placeholder") + topSort(13).element.getName should be("Variable/read") + topSort(14).element.getName should be("Variable") + } + + "TensorFlow loader" should "be able to build a BigDL graph" in { + val resource = getClass().getClassLoader().getResource("tf") + val path = processPath(resource.getPath()) + JFile.separator + "test.pb" + val model = TensorflowLoader.load(path, Seq("Placeholder"), Seq("output"), + ByteOrder.LITTLE_ENDIAN) + val container = model.asInstanceOf[Graph[Float]] + container.modules.length should be(4) + RandomGenerator.RNG.setSeed(100) + val input = Tensor[Float](4, 1).rand() + val output1 = container.forward(input) + + val model2 = Sequential[Float]() + val fc1 = Linear[Float](1, 10) + fc1.parameters()._1(0).fill(0.2f) + fc1.parameters()._1(1).fill(0.1f) + model2.add(fc1).add(Tanh()) + + val fc2 = Linear[Float](10, 1) + fc2.parameters()._1(0).fill(0.2f) + fc2.parameters()._1(1).fill(0.1f) + model2.add(fc2) + + val output2 = model2.forward(input) + output1 should be(output2) + } + + "Shared weights" should "be the same instance" in { + tfCheck() + val modelName = "share_weight" + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation"), "error when run the model script") + + // Load the model and input/output tensors + val modelFile = tmpLocation + s + "model.pb" + val model = TensorflowLoader.load(modelFile, Seq("Placeholder"), Seq("output"), + ByteOrder.LITTLE_ENDIAN) + val container = model.asInstanceOf[Graph[Float]] + val l1 = container.modules(1).asInstanceOf[Linear[Float]] + val l2 = container.modules(3).asInstanceOf[Linear[Float]] + 
assert(l1.weight eq l2.weight) + assert(l1.bias eq l2.bias) + } + + "Shared weights" should "be the same after running optimizer" in { + tfCheck() + val modelName = "share_weight" + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation"), "error when run the model script") + + // Load the model and input/output tensors + val modelFile = tmpLocation + s + "model.pb" + val model = TensorflowLoader.load(modelFile, Seq("Placeholder"), Seq("output"), + ByteOrder.LITTLE_ENDIAN) + val container = model.asInstanceOf[Graph[Float]] + + val optimizer = new DistriOptimizer[Float](container, dataSet, new MSECriterion[Float]()) + .setState(T("learningRate" -> 20.0)) + .setEndWhen(Trigger.maxEpoch(5)) + optimizer.optimize() + + val l1 = container.modules(1).asInstanceOf[Linear[Float]] + val l2 = container.modules(3).asInstanceOf[Linear[Float]] + assert(l1.weight == l2.weight) + assert(l1.bias == l2.bias) + } + + "static simple rnn " should "have the same inference result as tensorflow" in { + System.setProperty("bigdl.enableNHWC", "false") + tfCheck() + val modelName = "rnn" + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation"), "error when run the model script") + + // Load the model and input/output tensors + val modelFile = tmpLocation + s + "model.pb" + + + val results = TensorflowLoader.parse(modelFile) + val tfGraph = TensorflowLoader.buildTFGraph(results, Seq("output")) + val model = TensorflowLoader.buildBigDLModel(tfGraph, Seq("input"), + Seq("output"), + ByteOrder.LITTLE_ENDIAN) + val input = TensorflowToBigDL.toTensor(results.get(0).getAttrMap.get("value").getTensor, + ByteOrder.LITTLE_ENDIAN).contiguous() + val tfResult = TensorflowToBigDL.toTensor(results.get(results.size()-1) + .getAttrMap.get("value").getTensor, ByteOrder.LITTLE_ENDIAN) + val bigDLResult = model.forward(input) + tfResult.almostEqual(bigDLResult.toTensor, 1e-6) + } + + "static lstm rnn " should "have the same inference result as tensorflow" in { + tfCheck() + System.setProperty("bigdl.enableNHWC", "false") + val modelName = "rnn_lstm" + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation"), "error when run the model script") + + // Load the model and input/output tensors + val modelFile = tmpLocation + s + "model.pb" + + val results = TensorflowLoader.parse(modelFile) + val tfGraph = TensorflowLoader.buildTFGraph(results.subList(0, results.size()-1), Seq("output")) + val model = 
TensorflowLoader.buildBigDLModel(tfGraph, Seq("input"), + Seq("output"), + ByteOrder.LITTLE_ENDIAN) + val input = TensorflowToBigDL.toTensor(results.get(0).getAttrMap.get("value").getTensor, + ByteOrder.LITTLE_ENDIAN).contiguous() + val tfResult = TensorflowToBigDL.toTensor(results.get(results.size()-1) + .getAttrMap.get("value").getTensor, ByteOrder.LITTLE_ENDIAN) + val bigDLResult = model.forward(input) + tfResult.almostEqual(bigDLResult.toTensor, 1e-5) + } + + "Tensorflow lenet" should "be load correctly" in { + testModelForward("lenet", Seq("LeNet/pool2/MaxPool:0"), true).foreach { + case(tf, bigdl) => + val transpose = bigdl.transpose(2, 3).transpose(3, 4) + tf.almostEqual(transpose, 1e-6) should be(true) + } + testModelBackward("lenet", Seq("LeNet/pool2/MaxPool:0"), true, + Seq((4, 3), (3, 2))).foreach { + case(tf, bigdl) => + if (tf.dim() == 4) { + val trans = tf.transpose(1, 4).transpose(2, 3).transpose(3, 4).contiguous() + trans.almostEqual(bigdl, 1e-4) should be(true) + } + else { + tf.almostEqual(bigdl, 1e-4) should be(true) + } + } + } + + "Tensorflow Alexnet" should "be load correctly" in { + testModelForward("alexnet", Seq("alexnet_v2/fc8/squeezed:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + testModelBackward("alexnet", Seq("alexnet_v2/fc8/squeezed:0"), true, + Seq.empty).foreach { + case(tf, bigdl) => + if (tf.dim() == 4) { + val trans = tf.transpose(1, 4).transpose(2, 3).transpose(3, 4).contiguous() + trans.almostEqual(bigdl, 1e-4) should be(true) + } + else { + tf.almostEqual(bigdl, 1e-4) should be(true) + } + } + } + + "TensorFlow vgg_a" should "be load correctly" in { + testModelForward("vgga", Seq("vgg_a/fc8/squeezed:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + "TensorFlow vgg_16" should "be load correctly" in { + testModelForward("vgg16", Seq("vgg_16/fc8/squeezed:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + "TensorFlow vgg_19" should "be load correctly" in { + testModelForward("vgg19", Seq("vgg_19/fc8/squeezed:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + "TensorFlow overfeat" should "be load correctly" in { + testModelForward("overfeat", Seq("overfeat/fc8/squeezed:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + "TensorFlow inception_v3" should "be load correctly" in { + testModelForward("inception_v3", Seq("InceptionV3/Logits/SpatialSqueeze:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + "TensorFlow resnet_v1" should "be load correctly" in { + testModelForward("resnet_v1", Seq("resnet_v1_101/SpatialSqueeze:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-6) should be(true) + } + } + + "TensorFlow inception_resnet_v2" should "be load correctly" in { + testModelForward("inception_resnet_v2", Seq("InceptionResnetV2/Logits/Logits/BiasAdd:0", + "InceptionResnetV2/AuxLogits/Logits/BiasAdd:0"), true).foreach { + case(tf, bigdl) => + tf.almostEqual(bigdl, 1e-7) should be(true) + } + } + + + private def testModelForward(modelName: String, endPoints: Seq[String], transInput: Boolean) + : Seq[(Tensor[Float], Tensor[Float])] = { + + tfCheck() + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = 
modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation ${endPoints.mkString(",")}"), + "error when run the model script") + + // Load the model and input/output tensors + import collection.JavaConverters._ + val modelFile = tmpLocation + s + "model.pb" + val tfNodes = TensorflowLoader.parse(modelFile) + + // filter node for gradient computing + val tfGraph = TensorflowLoader.buildTFGraph(tfNodes, endPoints.map(_.split(":")(0))) + val context = new mutable.HashMap[NodeDef, (Tensor[Float], Tensor[Float])] + val model = TensorflowLoader.buildBigDLModel(tfGraph, Seq("input"), + endPoints.map(_.split(":")(0)), ByteOrder.LITTLE_ENDIAN, Some(context)) + + // Compare the tensor contents + val tfInputTensor = tfNodes.asScala.filter(_.getName == "input")(0) + .getAttrMap.get("value").getTensor + + val tfOutputTensors = (0 until endPoints.length).map( + i => tfNodes.asScala.filter(_.getName == s"output$i")(0).getAttrMap.get("value").getTensor) + val input = TensorflowToBigDL.toTensor(tfInputTensor, + ByteOrder.LITTLE_ENDIAN) + + val transposeInput = if (transInput) { + input.transpose(2, 4).transpose(3, 4).contiguous() + } else { + input + } + + val bigdlOutputs = if (endPoints.length == 1) { + Seq(model.forward(transposeInput).toTensor) + } else { + val t = model.forward(transposeInput).toTable + (1 to endPoints.length).map(t[Tensor[Float]](_)) + } + + val comparePair = tfOutputTensors.zip(bigdlOutputs).map{ + x => + val tensor = TensorflowToBigDL.toTensor(x._1, ByteOrder.LITTLE_ENDIAN) + (tensor, x._2) + } + tmpLocation.deleteOnExit() + comparePair + } + + private def testModelBackward( + modelName: String, + endPoints: Seq[String], + transInput: Boolean, + transOutputSeq: Seq[(Int, Int)]): Seq[(Tensor[Float], Tensor[Float])] = { + + tfCheck() + // Generate command and prepare the temp folder + val s = JFile.separator + val modelsFolder = processPath(getClass().getClassLoader().getResource("tf").getPath()) + + s + "models" + val modelScript = modelsFolder + s + s"$modelName.py" + val tmpLocation = java.io.File.createTempFile("tensorflowLoaderTest" + UUID.randomUUID(), + modelName) + tmpLocation.delete() + tmpLocation.mkdir() + + require(runPython(s"$modelScript $tmpLocation ${endPoints.mkString(",")}"), + "error when run the model script") + + // Load the model and input/output tensors + import collection.JavaConverters._ + val modelFile = tmpLocation + s + "model.pb" + val tfNodes = TensorflowLoader.parse(modelFile) + + // filter node for gradient computing + val tfGraph = TensorflowLoader.buildTFGraph(tfNodes, endPoints.map(_.split(":")(0))) + val context = new mutable.HashMap[NodeDef, (Tensor[Float], Tensor[Float])] + val model = TensorflowLoader.buildBigDLModel(tfGraph, Seq("input"), + endPoints.map(_.split(":")(0)), ByteOrder.LITTLE_ENDIAN, Some(context)) + + // Compare the tensor contents + val tfInputTensor = tfNodes.asScala.filter(_.getName == "input")(0) + .getAttrMap.get("value").getTensor + + val input = TensorflowToBigDL.toTensor(tfInputTensor, + ByteOrder.LITTLE_ENDIAN) + + val transposeInput = if (transInput) { + input.transpose(2, 4).transpose(3, 4).contiguous() + } else { + input + } + + val bigdlOutputs = if (endPoints.length == 1) { + Seq(model.forward(transposeInput).toTensor) + } else { + val t = model.forward(transposeInput).toTable + (1 to endPoints.length).map(t[Tensor[Float]](_)) + } + + // get 
gradient input of tensorflow + val gradInputs = (0 until endPoints.length).map{ + i => + val t = tfNodes.asScala.filter(_.getName == s"grad_input$i")(0) + .getAttrMap.get("value").getTensor + var tensor = TensorflowToBigDL.toTensor(t, ByteOrder.LITTLE_ENDIAN) + for (trans <- transOutputSeq) { + tensor = tensor.transpose(trans._1, trans._2) + } + tensor.contiguous() + } + + // check shape equality here + for (i <- 0 until endPoints.length) { + bigdlOutputs(i).size() should be(gradInputs(i).size()) + } + + // find all gradients tensor in tensorflow graph + val tfGradTensorsMap = context.keySet.map{ + node => + val t = tfNodes.asScala.filter(_.getName.contains(node.getName + "_grad"))(0) + t.getName -> + TensorflowToBigDL.toTensor(t.getAttrMap.get("value").getTensor, ByteOrder.LITTLE_ENDIAN) + }.toMap + + + val comparePair = new mutable.ArrayBuffer[(Tensor[Float], Tensor[Float])]() + + // do backward for each output and its corresponding gradient input + for (i <- 0 until gradInputs.length) { + // println(s"grad $i") + model.backward(transposeInput, gradInputs(i)) + val pairs = context.keySet.map{ + x => + val name = s"${x.getName}_grad$i" + // if (tfGradTensorsMap.contains(name)) { + // println(x.getName) + // context(x)._2.size().foreach(println(_)) + // println(name) + // tfGradTensorsMap(name).size().foreach(println(_)) + // } + (tfGradTensorsMap.get(name).getOrElse(null), context(x)._2) + }.toSeq.filter(_._2 != null) + comparePair ++= pairs + } + println(s"Compare ${comparePair.length} pairs of gradient vars in this graph") + tmpLocation.deleteOnExit() + comparePair + } + + + private def processPath(path: String): String = { + if (path.contains(":")) { + path.substring(1) + } else { + path + } + } +} diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaverSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaverSpec.scala new file mode 100644 index 00000000000..d39d5a342dd --- /dev/null +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaverSpec.scala @@ -0,0 +1,295 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.intel.analytics.bigdl.utils.tf + + +import java.nio.ByteOrder +import java.util.UUID + +import com.intel.analytics.bigdl.nn.abstractnn.AbstractModule +import com.intel.analytics.bigdl.nn._ +import com.intel.analytics.bigdl.numeric.NumericFloat +import com.intel.analytics.bigdl.tensor.Tensor +import com.intel.analytics.bigdl.utils.{T, Table} +import org.apache.log4j.Logger + +class TensorflowSaverSpec extends TensorflowSpecHelper { + + private val logger = Logger.getLogger(getClass) + + before { + System.setProperty("bigdl.enableNHWC", "true") + } + + after { + System.setProperty("bigdl.enableNHWC", "false") + } + + "ReLU layer" should "be correctly saved" in { + val inputTensor = Tensor[Float](T( + T(1.0f, 2.0f, 5.0f, 6.0f), + T(-3.0f, -4.0f, -7.0f, -8.0f) + )) + test(ReLU[Float](), inputTensor) should be(true) + } + + "Linear layer" should "be correctly saved" in { + val layer = Linear[Float](3, 4, + initWeight = Tensor(T( + T(1.0f, 2.0f, 3.0f), + T(4.0f, 5.0f, 6.0f), + T(7.0f, 8.0f, 9.0f), + T(10.0f, 11.0f, 12.0f) + )), + initBias = Tensor(T(1.0f, 2.0f, 3.0f, 4.0f)) + ) + val input = Tensor[Float](T( + T(1.0f, 2.0f, 5.0f), + T(-3.0f, -4.0f, -7.0f) + )) + test(layer, input, false, "/biasAdd") should be(true) + } + + "AvgPooling" should "be correctly saved" in { + val layer = SpatialAveragePooling(2, 2) + val input = Tensor[Float](T(T( + T( + T(1.0f, 2.0f, 5.0f), + T(-3.0f, -4.0f, -7.0f), + T(-4.0f, -2.0f, -1.0f) + ), + T( + T(-1.0f, -2.0f, -5.0f), + T(3.0f, 4.0f, 7.0f), + T(4.0f, 2.0f, 1.0f) + ) + ))) + test(layer, input, true) should be(true) + } + + "MaxPooling" should "be correctly saved" in { + val layer = SpatialMaxPooling(2, 2) + val input = Tensor[Float](T(T( + T( + T(1.0f, 2.0f, 5.0f), + T(-3.0f, -4.0f, -7.0f), + T(-4.0f, -2.0f, -1.0f) + ), + T( + T(-1.0f, -2.0f, -5.0f), + T(3.0f, 4.0f, 7.0f), + T(4.0f, 2.0f, 1.0f) + ) + ))) + test(layer, input, true) should be(true) + } + + "Tanh" should "be correctly saved" in { + val layer = Tanh() + val input = Tensor[Float](4).rand() + test(layer, input) should be(true) + } + + "Squeeze" should "be correctly saved" in { + System.setProperty("bigdl.enableNHWC", "false") + val layer = Squeeze(3) + val input = Tensor[Float](4, 2, 1, 2).rand() + test(layer, input, false) should be(true) + } + + "CAddTableToTF" should "be correct" in { + val layer = CAddTable[Float]() + val input1 = Tensor[Float](4, 2, 2).rand() + val input2 = Tensor[Float](4, 2, 2).rand() + testMultiInput(layer, Seq(input1, input2), false) should be(true) + } + + "CMultToTF" should "be correct" in { + val layer = CMulTable[Float]() + val input1 = Tensor[Float](4, 2, 2).rand() + val input2 = Tensor[Float](4, 2, 2).rand() + testMultiInput(layer, Seq(input1, input2), false) should be(true) + } + + "JoinTableToTF" should "be correct" in { + val layer = JoinTable[Float](3, -1) + val input1 = Tensor[Float](4, 2, 2).rand() + val input2 = Tensor[Float](4, 2, 2).rand() + testMultiInput(layer, Seq(input1, input2), false) should be(true) + } + + "LogSoftMax" should "be correctly saved" in { + val layer = LogSoftMax() + val input = Tensor[Float](4, 5).rand() + test(layer, input, false) should be(true) + } + + "SoftMax" should "be correctly saved" in { + val layer = SoftMax() + val input = Tensor[Float](4, 5).rand() + test(layer, input, false) should be(true) + } + + "Sigmoid" should "be correctly saved" in { + val layer = Sigmoid() + val input = Tensor[Float](4, 5).rand() + test(layer, input, false) should be(true) + } + + "SpatialConvolution" should "be correctly saved" in 
{ + val layer = SpatialConvolution(3, 5, 2, 2) + val input = Tensor[Float](4, 3, 5, 5).rand() + test(layer, input, true, "/biasAdd") should be(true) + } + + "Mean" should "be correctly saved" in { + val layer = Mean(1, -1, true) + val input = Tensor[Float](4, 5).rand() + test(layer, input, false, "/output") should be(true) + } + + "Padding" should "be correctly saved" in { + val layer = Padding(1, 2, 2) + val input = Tensor[Float](4, 5).rand() + test(layer, input, false, "/output") should be(true) + } + + "Batch Norm2D" should "be correctly saved" in { + val layer = SpatialBatchNormalization(2) + layer.evaluate() + layer.weight.rand(10.0, 20.0) + layer.bias.rand() + layer.runningVar.rand(0.9, 1.1) + layer.runningMean.rand() + val input = Tensor[Float](3, 2, 4, 5).rand() + test(layer, input, true, "/output") should be(true) + } + + "Dropout" should "be correctly saved" in { + val layer = Dropout() + layer.evaluate() + val input = Tensor[Float](3, 2).rand() + test(layer, input, false) should be(true) + } + + "View" should "be correctly saved" in { + val layer = View(2, 4) + val input = Tensor[Float](2, 2, 2).rand() + test(layer, input, false) should be(true) + } + + "Reshape" should "be correctly saved" in { + val layer = Reshape(Array(2, 4)) + val input = Tensor[Float](2, 2, 2).rand() + test(layer, input, false) should be(true) + } + + "lenet" should "be correctly saved" in { + tfCheck() + val conv1 = SpatialConvolution(1, 6, 5, 5).setName("conv1").apply() + val tanh1 = Tanh().setName("tanh1").apply(conv1) + val pool1 = SpatialMaxPooling(2, 2, 2, 2).setName("pool1").apply(tanh1) + val tanh2 = Tanh().setName("tanh2").apply(pool1) + val conv2 = SpatialConvolution(6, 12, 5, 5).setName("conv2").apply(tanh2) + val pool2 = SpatialMaxPooling(2, 2, 2, 2).setName("output").apply(conv2) + + val funcModel = Graph(conv1, pool2) + val inputData = Tensor(4, 1, 28, 28).rand() + val transInput = inputData.transpose(2, 3).transpose(3, 4).contiguous() + val outputData = funcModel.forward(inputData).toTensor + + val tmpFile = java.io.File.createTempFile("tensorflowSaverTest" + UUID.randomUUID(), "lenet") + TensorflowSaver.saveGraphWitNodeDef( + funcModel, + Seq(Tensorflow.const(transInput, "input", ByteOrder.LITTLE_ENDIAN)), + tmpFile.getPath, + ByteOrder.LITTLE_ENDIAN, + Set(Tensorflow.const(outputData.transpose(2, 3).transpose(3, 4).contiguous(), + "target", ByteOrder.LITTLE_ENDIAN)) + ) + + runPythonSaveTest(tmpFile.getPath, "") should be(true) + } + + private def test(layer: AbstractModule[Tensor[Float], Tensor[Float], Float], + inputTensor: Tensor[Float], + convertNHWC: Boolean = false, + outputSuffix: String = "") : Boolean = { + tfCheck() + val layerNode = layer.setName("output").apply() + val graph = Graph(layerNode, layerNode) + val outputTensor = layer.forward(inputTensor) + + val tmpFile = java.io.File.createTempFile("tensorflowSaverTest" + UUID.randomUUID(), "Layer") + logger.info(s"Save model to ${tmpFile}") + val tfTensor = if (convertNHWC) { + inputTensor.transpose(2, 3).transpose(3, 4).contiguous() + } else { + inputTensor + } + val outputSave = if (convertNHWC) { + outputTensor.transpose(2, 3).transpose(3, 4).contiguous() + } else { + outputTensor + } + TensorflowSaver.saveGraphWitNodeDef( + graph, + Seq(Tensorflow.const(tfTensor, "input", ByteOrder.LITTLE_ENDIAN)), + tmpFile.getPath, + ByteOrder.LITTLE_ENDIAN, + Set(Tensorflow.const(outputSave, "target", ByteOrder.LITTLE_ENDIAN)) + ) + runPythonSaveTest(tmpFile.getPath, outputSuffix) + } + + private def testMultiInput(layer: 
AbstractModule[Table, Tensor[Float], Float], + inputTensors: Seq[Tensor[Float]], + convertNHWC: Boolean = false, + outputSuffix: String = "") : Boolean = { + tfCheck() + val layerNode = layer.setName("output").apply() + val inputNodes = inputTensors.map(_ => Input[Float]()).toArray + inputNodes.foreach(_ -> layerNode) + inputNodes.zipWithIndex.foreach(n => n._1.element.setName("inputNode" + n._2)) + val graph = Graph(inputNodes, layerNode) + val inputTable = T() + inputTensors.foreach(inputTable.insert(_)) + val outputTensor = layer.forward(inputTable) + + val tmpFile = java.io.File.createTempFile("tensorflowSaverTest" + UUID.randomUUID(), "Layer") + logger.info(s"Save model to ${tmpFile}") + val tfTensors = if (convertNHWC) { + inputTensors.map(t => t.transpose(2, 3).transpose(3, 4).contiguous()) + } else { + inputTensors + } + val outputSave = if (convertNHWC) { + outputTensor.transpose(2, 3).transpose(3, 4).contiguous() + } else { + outputTensor + } + + TensorflowSaver.saveGraphWitNodeDef( + graph, + tfTensors.zipWithIndex.map(t => + Tensorflow.const(t._1, "input" + t._2, ByteOrder.LITTLE_ENDIAN)), + tmpFile.getPath, + ByteOrder.LITTLE_ENDIAN, + Set(Tensorflow.const(outputSave, "target", ByteOrder.LITTLE_ENDIAN)) + ) + runPythonSaveTest(tmpFile.getPath, outputSuffix) + } +} diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSpecHelper.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSpecHelper.scala new file mode 100644 index 00000000000..d42fb516eea --- /dev/null +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSpecHelper.scala @@ -0,0 +1,60 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.intel.analytics.bigdl.utils.tf + +import com.intel.analytics.bigdl.utils.TestUtils.processPath +import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} +import java.io.{File => JFile} + +import org.apache.log4j.Logger + +import scala.sys.process._ + +class TensorflowSpecHelper extends FlatSpec with Matchers with BeforeAndAfter { + + private val logger = Logger.getLogger(getClass) + + protected def tfCheck(): Unit = { + var exitValue : String = "" + try { + exitValue = ((Seq("python", "-c", "import sys; print ','.join(sys.path)"))!!) + ((Seq("python", "-c", "import tensorflow"))!!) 
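+      // `!!` throws on a non-zero exit status, so a missing python or tensorflow lands in the catch below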
+ } catch { + case _: Throwable => cancel("python or tensorflow is not installed") + } + + if (!exitValue.contains("models")) { + cancel("Tensorflow models path is not exported") + } + } + + protected def runPython(cmd: String): Boolean = { + try { + logger.info("run command\n" + cmd) + val proc = s"python $cmd".run + return proc.exitValue() == 0 + } catch { + case _: Throwable => false + } + } + + protected def runPythonSaveTest(graphPath: String, outputSuffix: String) : Boolean = { + val resource = getClass().getClassLoader().getResource("tf") + val path = processPath(resource.getPath()) + JFile.separator + + s"save_test.py $graphPath $outputSuffix" + runPython(path) + } +}
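
Note: the slim-based model scripts added above (lenet, vgg16/vgg19/vgga, overfeat, inception_v3, inception_resnet_v2, resnet_v1) share one calling contract that TensorflowLoaderSpec relies on: argv[1] names a directory that receives model.chkp, model.pbtxt and the frozen model.pb, while argv[2] lists the endpoint tensor names (for example "LeNet/pool2/MaxPool:0") that util.run_model snapshots through its assign*/grad_assign* nodes. The following is a minimal sketch of that pattern only; the two-layer toy graph and the tensor name toy_output are hypothetical placeholders, not part of this patch.

import tensorflow as tf
from sys import argv

from util import run_model

def main():
    # Toy two-layer graph standing in for a slim model. The 'input' Variable is
    # named so that the Scala spec can locate it and replay its value in BigDL.
    inputs = tf.Variable(tf.random_uniform((1, 10)), name='input')
    weight = tf.Variable(tf.random_uniform((10, 5)))
    bias = tf.Variable(tf.random_uniform([5]))
    net = tf.nn.tanh(tf.nn.bias_add(tf.matmul(inputs, weight), bias), name='toy_output')

    # argv[2] carries the endpoint tensor names, e.g. "toy_output:0"
    net_outputs = map(lambda x: tf.get_default_graph().get_tensor_by_name(x), argv[2].split(','))
    # argv[1] is the folder that receives model.chkp, model.pbtxt and model.pb
    run_model(net_outputs, argv[1])

if __name__ == "__main__":
    main()

TensorflowLoaderSpec drives such a script as "python <model script> <temp dir> <endpoint names>" and then loads the generated model.pb through TensorflowLoader before comparing the TensorFlow and BigDL outputs.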