Skip to content

Commit 3a2f024

Browse files
add more writer tests
1 parent 022b00e commit 3a2f024

File tree

11 files changed

+209
-27
lines changed

11 files changed

+209
-27
lines changed

spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Dropout.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class Dropout[T: ClassTag](
207207
}
208208

209209
object Dropout {
210-
def apply[@specialized(Float, Double) T: ClassTag](
210+
def apply[T: ClassTag](
211211
initP: Double = 0.5,
212212
inplace: Boolean = false,
213213
scale: Boolean = true)(implicit ev: TensorNumeric[T]) : Dropout[T] = {

spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Graph.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,12 @@ class Input[T: ClassTag]()(implicit ev: TensorNumeric[T]) extends TensorModule[T
325325
gradInput = gradOutput
326326
gradInput
327327
}
328+
override def equals(other: Any): Boolean = {
329+
if (!other.isInstanceOf[Input[_]]) return false
330+
this.eq(other.asInstanceOf[Input[_]])
331+
}
332+
333+
override def hashCode(): Int = System.identityHashCode(this)
328334
}
329335

330336
object Input {

spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Padding.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class Padding[T: ClassTag](
116116
}
117117

118118
object Padding{
119-
def apply[@specialized(Float, Double) T: ClassTag](
119+
def apply[T: ClassTag](
120120
dim: Int,
121121
pad: Int,
122122
nInputDim: Int,

spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Reshape.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class Reshape[@specialized(Float, Double) T: ClassTag](
128128
}
129129

130130
object Reshape {
131-
def apply[@specialized(Float, Double) T: ClassTag](
131+
def apply[T: ClassTag](
132132
size: Array[Int],
133133
batchMode: Option[Boolean] = None)(implicit ev: TensorNumeric[T]) : Reshape[T] = {
134134
new Reshape[T](size, batchMode)

spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Sigmoid.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Sigmoid[@specialized(Float, Double) T: ClassTag](
4747
}
4848

4949
object Sigmoid {
50-
def apply[@specialized(Float, Double) T: ClassTag]()
50+
def apply[T: ClassTag]()
5151
(implicit ev: TensorNumeric[T]) : Sigmoid[T] = {
5252
new Sigmoid[T]()
5353
}

spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/DenseTensor.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,8 @@ private[tensor] class DenseTensor[@specialized(Float, Double) T: ClassTag](
19061906
"corresponding module, please keep them same.")
19071907
}
19081908
}
1909+
1910+
override def getTensorNumeric(): TensorNumeric[T] = ev
19091911
}
19101912

19111913
object DenseTensor {

spark/dl/src/main/scala/com/intel/analytics/bigdl/tensor/Tensor.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,12 @@ trait Tensor[T] extends Serializable with TensorMath[T] with Activity {
646646
* @return false
647647
*/
648648
override def isTable: Boolean = false
649+
650+
/**
651+
* Return tensor numeric
652+
* @return
653+
*/
654+
def getTensorNumeric(): TensorNumeric[T]
649655
}
650656

651657
/**

spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/BigDLToTensorflow.scala

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ import Tensorflow._
2525
import BigDLToTensorflow._
2626
import org.tensorflow.framework.{DataType, NodeDef}
2727

28+
import scala.collection.mutable.ArrayBuffer
29+
2830
/**
2931
* Wrapper of logic to convert module to tensorflow node definition
3032
*/
@@ -51,6 +53,15 @@ object BigDLToTensorflow {
5153
}
5254
}
5355

56+
object InputToTF extends BigDLToTensorflow {
57+
override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef],
58+
byteOrder: ByteOrder, dataFormat: TensorflowDataFormat): Seq[NodeDef] = {
59+
require(inputs.length == 1, "Input only accept one input")
60+
61+
Seq(identity(inputs(0), module.getName()))
62+
}
63+
}
64+
5465
object ReLUToTF extends BigDLToTensorflow {
5566
override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef],
5667
byteOrder: ByteOrder, dataFormat: TensorflowDataFormat): Seq[NodeDef] = {
@@ -80,10 +91,16 @@ object SpatialConvolutionToTF extends BigDLToTensorflow {
8091
byteOrder: ByteOrder, dataFormat: TensorflowDataFormat): Seq[NodeDef] = {
8192
require(inputs.length == 1, "SpatialConvolution only accept one input")
8293
val spatialConv = module.asInstanceOf[SpatialConvolution[_]]
83-
val filter = const(spatialConv.weight, spatialConv.getName() + "/filter", byteOrder)
94+
// squeeze will modify the weight tensor
95+
// GOIHW -> HWIO
96+
require(spatialConv.weight.size(1) == 1, "convolution group is not supported")
97+
val filterTensor = spatialConv.weight.select(1, 1)
98+
.transpose(2, 3).transpose(3, 4).transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous()
99+
100+
val filter = const(filterTensor, spatialConv.getName() + "/filter", byteOrder)
84101
val filterReader = identity(filter, spatialConv.getName() + "/filterReader")
85-
val conv = conv2D(inputs(0), filterReader, spatialConv.strideH, spatialConv.strideW,
86-
spatialConv.kernelW, spatialConv.kernelH, spatialConv.strideW, spatialConv.strideH,
102+
val conv = conv2D(inputs(0), filterReader, spatialConv.strideW, spatialConv.strideH,
103+
spatialConv.kernelW, spatialConv.kernelH, spatialConv.padW, spatialConv.padH,
87104
dataFormat, spatialConv.getName() + "/conv2D")
88105
val bias = const(spatialConv.bias, spatialConv.getName() + "/bias", byteOrder)
89106
val biasReader = identity(bias, spatialConv.getName() + "/biasReader")
@@ -121,7 +138,7 @@ object ReshapeToTF extends BigDLToTensorflow {
121138
size.setValue(i + 1, rh.size(i))
122139
i += 1
123140
}
124-
val shape = const(size, rh.getName() + "/shape", byteOrder, DataType.DT_INT32)
141+
val shape = const(size, rh.getName() + "/shape", byteOrder, false, DataType.DT_INT32)
125142
val reshapeNode = reshape(inputs(0), shape, rh.getName())
126143
Seq(reshapeNode, shape)
127144
}
@@ -138,7 +155,7 @@ object ViewToTF extends BigDLToTensorflow {
138155
size.setValue(i + 1, viewLayer.sizes(i))
139156
i += 1
140157
}
141-
val shape = const(size, viewLayer.getName() + "/shape", byteOrder, DataType.DT_INT32)
158+
val shape = const(size, viewLayer.getName() + "/shape", byteOrder, false, DataType.DT_INT32)
142159
val reshapeNode = reshape(inputs(0), shape, viewLayer.getName())
143160
Seq(reshapeNode, shape)
144161
}
@@ -168,7 +185,8 @@ object PaddingToTF extends BigDLToTensorflow {
168185
padding.setValue(1, 1, 0)
169186
padding.setValue(1, 2, layer.pad)
170187
}
171-
val paddingsNode = const(padding, layer.getName() + "/padding", byteOrder, DataType.DT_INT32)
188+
val paddingsNode = const(padding, layer.getName() + "/padding", byteOrder,
189+
false, DataType.DT_INT32)
172190
val padNode = pad(inputs(0), paddingsNode, layer.getName() + "/output")
173191
Seq(padNode, paddingsNode)
174192
}
@@ -234,7 +252,12 @@ object JoinTableToTF extends BigDLToTensorflow {
234252
override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef],
235253
byteOrder: ByteOrder, dataFormat: TensorflowDataFormat): Seq[NodeDef] = {
236254
val layer = module.asInstanceOf[JoinTable[_]]
237-
Seq(concat(inputs, layer.dimension - 1, layer.getName()))
255+
val axis = const(Tensor[Float](T((layer.dimension - 1).toFloat)), layer.getName() + "/axis",
256+
byteOrder, true, DataType.DT_INT32)
257+
val updateInputs = new ArrayBuffer[NodeDef]()
258+
updateInputs ++= inputs.reverse
259+
updateInputs.append(axis)
260+
Seq(concat(updateInputs, layer.dimension - 1, layer.getName()), axis)
238261
}
239262
}
240263

@@ -268,7 +291,7 @@ object LogSoftMaxToTF extends BigDLToTensorflow {
268291
}
269292
}
270293

271-
object BatchNormToTF extends BigDLToTensorflow {
294+
object BatchNorm2DToTF extends BigDLToTensorflow {
272295
override def toTFDef(module: AbstractModule[_, _, _], inputs: Seq[NodeDef],
273296
byteOrder: ByteOrder, dataFormat: TensorflowDataFormat): Seq[NodeDef] = {
274297
require(inputs.length == 1, "BatchNorm only accept one input")

spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/Tensorflow.scala

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ object Tensorflow {
103103
* @return
104104
*/
105105
def const[T: ClassTag](value : Tensor[T], name : String, byteOrder: ByteOrder,
106-
dataType: DataType = null): NodeDef = {
106+
isScalar: Boolean = false, dataType: DataType = null): NodeDef = {
107107
val dtype = if (dataType == null) {
108108
if (value.getType() == DoubleType) {
109109
DataType.DT_DOUBLE
@@ -118,7 +118,7 @@ object Tensorflow {
118118
.setName(name)
119119
.setOp("Const")
120120
.putAttr("dtype", AttrValue.newBuilder().setType(dtype).build())
121-
.putAttr("value", tensorAttr(value, dtype, byteOrder))
121+
.putAttr("value", tensorAttr(value, dtype, byteOrder, isScalar))
122122
.build()
123123
}
124124

@@ -204,7 +204,7 @@ object Tensorflow {
204204
.putAttr("T", getDataType(input))
205205
.putAttr("data_format", dataFormat.value)
206206
.putAttr("padding", getPaddingType(pW, pH, kW, kH, sW, sH).value)
207-
.putAttr("strides", listIntAttr(Seq(sH, sW)))
207+
.putAttr("strides", strideAttr(sW, sH, dataFormat))
208208
.build()
209209
}
210210

@@ -363,7 +363,7 @@ object Tensorflow {
363363
val node = NodeDef.newBuilder()
364364
.setName(name)
365365
.setOp("ConcatV2")
366-
.putAttr("N", intAttr(axis))
366+
.putAttr("N", intAttr(inputs.length - 1))
367367
.putAttr("T", getDataType(inputs(0)))
368368
.putAttr("Tidx", AttrValue.newBuilder().setType(DataType.DT_INT32).build())
369369

@@ -379,6 +379,7 @@ object Tensorflow {
379379
.putAttr("T", getDataType(tensor))
380380
.putAttr("Tpaddings", getDataType(paddings))
381381
.addInput(tensor.getName)
382+
.addInput(paddings.getName)
382383
.build()
383384
}
384385

@@ -436,11 +437,27 @@ object Tensorflow {
436437
}
437438

438439
private def tensorAttr[T: ClassTag](value: Tensor[T], dtype: DataType,
439-
byteOrder: ByteOrder): AttrValue = {
440+
byteOrder: ByteOrder, isScalar: Boolean): AttrValue = {
440441
val shape = TensorShapeProto.newBuilder()
441-
value.size().foreach(dim => {
442-
shape.addDim(Dim.newBuilder().setSize(dim))
443-
})
442+
if (!isScalar) {
443+
value.size().foreach(dim => {
444+
shape.addDim(Dim.newBuilder().setSize(dim))
445+
})
446+
}
447+
448+
/* if (value.nElement() == 1 && value.nDimension() == 1) {
449+
val tfTensor = TensorProto.newBuilder().setTensorShape(shape).setDtype(dtype)
450+
val tn = value.getTensorNumeric()
451+
dtype match {
452+
case DataType.DT_INT32 =>
453+
tfTensor.set
454+
tfTensor.setIntVal(0, tn.toType[Int](value.valueAt(1)))
455+
case _ =>
456+
throw new UnsupportedOperationException(
457+
s"$dtype is not support to write to a scalar tensor")
458+
}
459+
return AttrValue.newBuilder().setTensor(tfTensor).build()
460+
} */
444461

445462
require(value.isContiguous(), "only support save a contiguous tensor")
446463

spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/tf/TensorflowSaver.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ object TensorflowSaver {
5555
new mutable.HashMap[AbstractModule[Activity, Tensor[T], T], ArrayBuffer[NodeDef]]()
5656
model.inputs.zip(inputs).foreach(n => {
5757
inputNodeCache(n._1.element) = ArrayBuffer(n._2)
58+
println()
5859
})
5960

6061
val graphBuilder = GraphDef.newBuilder()
@@ -69,6 +70,7 @@ object TensorflowSaver {
6970
n.nextNodes.foreach(n => {
7071
val list = inputNodeCache.getOrElse(n.element, ArrayBuffer())
7172
list.append(nodeDefs(0))
73+
inputNodeCache(n.element) = list
7274
})
7375
})
7476

@@ -78,8 +80,8 @@ object TensorflowSaver {
7880
val os = new FileOutputStream(path)
7981
val output = CodedOutputStream.newInstance(os)
8082
val graph = graphBuilder.build()
81-
logger.debug("Graph definition is:")
82-
logger.debug(graph.toString)
83+
logger.info("Graph definition is:")
84+
logger.info(graph.toString)
8385
graph.writeTo(output)
8486
output.flush()
8587
os.close()
@@ -142,7 +144,9 @@ object TensorflowSaver {
142144
getNameFromObj(Mean.getClass.getName) -> MeanToTF,
143145
getNameFromObj(SoftMax.getClass.getName) -> SoftMaxToTF,
144146
getNameFromObj(LogSoftMax.getClass.getName) -> LogSoftMaxToTF,
145-
getNameFromObj(SpatialBatchNormalization.getClass.getName) -> BatchNormToTF
147+
getNameFromObj(SpatialBatchNormalization.getClass.getName) -> BatchNorm2DToTF,
148+
getNameFromObj(Input.getClass.getName) -> InputToTF,
149+
getNameFromObj(Sigmoid.getClass.getName) -> SigmoidToTF
146150
)
147151

148152
private def getNameFromObj(name: String) : String = name.substring(0, name.length - 1)

0 commit comments

Comments (0)