
Commit

Merge pull request #5 from i8run/open-source-mkl-dnn-for-cherry
fix: revert linear
zhangxiaoli73 committed Feb 6, 2018
2 parents 152f8dc + f006c82 commit 5dcf617
Showing 6 changed files with 728 additions and 393 deletions.
@@ -566,8 +566,6 @@ object ResNet {
case spatialBatchNormalization
if (spatialBatchNormalization.isInstanceOf[SpatialBatchNormalization[Float]]) =>
val curModel = spatialBatchNormalization.asInstanceOf[SpatialBatchNormalization[Float]]
- curModel.weight.apply1(_ => 1.0f)
- curModel.bias.apply1(_ => 0.0f)
case linear if (linear.isInstanceOf[Linear[Float]]) =>
linear.asInstanceOf[Linear[Float]].bias.apply1(_ => 0.0f)
case _ => Unit
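Note: the two removed lines had forced every batch-norm layer's scale to 1 and shift to 0 before training; after this revert those parameters keep whatever the layer's own initializer produced. A minimal sketch of the reverted idiom, assuming BigDL's in-place elementwise Tensor.apply1 and the usual SpatialBatchNormalization(nOutput) constructor (both are assumptions, not quoted from this commit):

  // Hypothetical re-creation of the reverted initialization.
  // apply1 applies the given function to every element in place.
  val bn = SpatialBatchNormalization[Float](64)
  bn.weight.apply1(_ => 1.0f) // gamma: identity scaling
  bn.bias.apply1(_ => 0.0f)   // beta: zero shift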
@@ -102,8 +102,8 @@ class Linear[T: ClassTag](
val (dim, size) = if (tensor.dim() == 1 && (format == MklDnn.MemoryFormat.nc ||
format == MklDnn.MemoryFormat.oi)) {
(2, Array(1) ++ tensor.size())
- } else if (tensor.dim() == 2 && (format == MklDnn.MemoryFormat.oihw)) {
-   (4, tensor.size() ++ Array(1, 1))
+ // } else if (tensor.dim() == 2 && (format == MklDnn.MemoryFormat.oihw)) {
+ //   (4, tensor.size() ++ Array(1, 1))
} else {
(tensor.dim(), tensor.size())
}
@@ -116,6 +116,15 @@
primitive
}

+ private def init4(dim: Int, size: Array[Int], dataType: Int, format: Int, engine: Long): Long = {
+   val desc = MklDnn.MemoryDescInit(dim, size, dataType, format)
+   val primDesc = MklDnn.MemoryPrimitiveDescCreate(desc, engine)
+   val primitive = MklDnn.PrimitiveCreate0(primDesc)
+
+   MklDnn.PrimitiveDescDestroy(primDesc)
+   primitive
+ }
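Note: the new init4 helper mirrors initUser but takes the rank and size array explicitly instead of reading them from a tensor, so a user primitive can describe a shape the tensor itself does not carry (this is used for diffWeight in the backward-weights pass below). A hedged usage sketch; the sizes are illustrative, not taken from this commit:

  // Describe a 4-d oihw weight region of 1000 x 512 x 7 x 7.
  // The MklDnn.* calls follow the signatures visible in this diff.
  val weightPrim = init4(4, Array(1000, 512, 7, 7),
    MklDnn.DataType.f32, MklDnn.MemoryFormat.oihw, engine)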

def initUser(tensor: Tensor[T], dataType: Int, format: Int, engine: Long): Long = {
val primDesc = tensor.getPrimitiveDesc()
val primitive = if (primDesc != 0L) { // if the tensor comes from mkldnn layer
@@ -144,8 +153,9 @@

var _shouldConvert: Boolean = true
def shouldConvert: Boolean = _shouldConvert
- def setShouldConvert(v: Boolean): Unit = {
+ def setShouldConvert(v: Boolean): this.type = {
_shouldConvert = v
+   this
}
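Note: returning this.type instead of Unit makes the setter fluent, so it can be chained directly off construction. A small sketch, assuming the usual (inputSize, outputSize) Linear constructor of this package:

  // Chaining works because setShouldConvert now returns the instance.
  val fc = new Linear[Float](512, 1000).setShouldConvert(false)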

@transient var inputUserPrim = 0L
@@ -190,9 +200,37 @@

forwardPrimBuffer = ArrayBuffer.empty[Long]
forwardReorderPrimBuffer = ArrayBuffer.empty[Long]

+ // val srcMemDesc = if (input.getPrimitiveDesc() == 0L) {
+ //   if (input.dim() == 1) {
+ //     MklDnn.MemoryDescInit(input.dim() + 1, Array(1) ++ input.size(),
+ //       MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
+ //   } else {
+ //     MklDnn.MemoryDescInit(input.dim(), input.size(),
+ //       MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
+ //   }
+ // } else {
+ //   MklDnnOps.primitiveDescQueryMemory(input.getPrimitiveDesc())
+ // }

+ val srcMemDesc = if (input.dim() == 1) {
+   MklDnn.MemoryDescInit(input.dim() + 1, Array(1) ++ input.size(),
+     MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
+ } else {
+   MklDnn.MemoryDescInit(input.dim(), input.size(),
+     MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
+ }
val weightMemDesc = if (input.dim() == 4) {
- MklDnn.MemoryDescInit(weight.dim(),
-   weight.size(),
+ // val format = if (MklDnn.getFormat(srcMemDesc) == MklDnn.MemoryFormat.nChw8c) {
+ //   MklDnn.MemoryFormat.oIhw8i
+ // } else if (MklDnn.getFormat(srcMemDesc) == MklDnn.MemoryFormat.nChw16c) {
+ //   MklDnn.MemoryFormat.oIhw16i
+ // } else {
+ //   MklDnn.MemoryFormat.any
+ // }
+ MklDnn.MemoryDescInit(4,
+   // Array(outputSize) ++ input.size().slice(1, input.dim()),
+   weight.size() ++ Array(1, 1),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
} else {
MklDnn.MemoryDescInit(weight.dim(), weight.size(),
@@ -209,14 +247,6 @@
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
}
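Note: for a 4-d input (a conv feature map feeding the fully connected layer), the 2-d weight is now described as oihw with 1x1 spatial dims so MKL-DNN can pick a blocked layout. The shape arithmetic, with example sizes that are assumptions rather than values from this commit:

  // weight.size() == Array(outputSize, inputSize), e.g. Array(1000, 512)
  // weight.size() ++ Array(1, 1) == Array(1000, 512, 1, 1), i.e. rank-4 oihw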

- val srcMemDesc = if (input.dim() == 1) {
-   MklDnn.MemoryDescInit(input.dim() + 1, Array(1) ++ input.size(),
-     MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
- } else {
-   MklDnn.MemoryDescInit(input.dim(), input.size(),
-     MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
- }

val format = input.dim() match {
case 1 => MklDnn.MemoryFormat.nc
case 2 => MklDnn.MemoryFormat.nc
@@ -239,7 +269,7 @@
MklDnn.Query.src_pd)
inputReorderMemoryPrim = i1._1
inputReorderPrim = i1._2
- weightUserPrim = initUser(weight, MklDnn.DataType.f32, format, engine)
+ weightUserPrim = initUser(weight, MklDnn.DataType.f32, weightFormat, engine)
val w1 = initInternal(weightUserPrim, opPrimDesc,
MklDnn.Query.weights_pd)
weightReorderMemoryPrim = w1._1
@@ -284,20 +314,20 @@
inputPtr = MklDnn.MemorySetDataHandle(inputUserPrim,
input.storage().array().asInstanceOf[Array[Float]],
input.storageOffset() - 1)
- Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.ptr, 0)
} else {
Memory.SetDataHandle(inputUserPrim,
-   input.asInstanceOf[MklDnnTensor[T]].nativeStorage,
+   input.asInstanceOf[MklDnnTensor[T]].ptr,
0)
- Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.ptr, 0)
}
} else {
if (input.getTensorType == DenseType) {
MklDnnTensor.syncFromHeap(internalInput, input.storage().array(), input.storageOffset() - 1)
- Memory.SetDataHandle(inputUserPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputUserPrim, internalInput.ptr, 0)
} else if (input.getTensorType == MklDnnType) {
Memory.SetDataHandle(inputUserPrim,
-   input.asInstanceOf[MklDnnTensor[T]].nativeStorage, 0)
+   input.asInstanceOf[MklDnnTensor[T]].ptr, 0)
}
}
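Note: both branches end with the same native buffer bound to the primitive; they differ in how the data reaches it. A dense JVM tensor is either pinned via MemorySetDataHandle or copied into native memory with syncFromHeap, while an MklDnnTensor already owns native memory exposed through ptr. A schematic of the dispatch, condensed from the surrounding code rather than quoted from it:

  // Schematic only -- the real code also wires up the reorder primitive.
  input.getTensorType match {
    case DenseType  => // heap -> native copy before binding
      MklDnnTensor.syncFromHeap(internalInput,
        input.storage().array(), input.storageOffset() - 1)
    case MklDnnType => // zero-copy: bind the tensor's existing native pointer
      Memory.SetDataHandle(inputUserPrim,
        input.asInstanceOf[MklDnnTensor[T]].ptr, 0)
  }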

@@ -306,20 +336,19 @@
weightPtr = MklDnn.MemorySetDataHandle(weightUserPrim,
weight.storage().array().asInstanceOf[Array[Float]],
weight.storageOffset() - 1)
- Memory.SetDataHandle(weightReorderPrim, prvWeight.nativeStorage, 0)
+ Memory.SetDataHandle(weightReorderMemoryPrim, prvWeight.ptr, 0)
} else {
MklDnnTensor.syncFromHeap(prvWeight, weight.storage().array(), weight.storageOffset() - 1)
- Memory.SetDataHandle(weightUserPrim, prvWeight.nativeStorage, 0)
+ Memory.SetDataHandle(weightUserPrim, prvWeight.ptr, 0)
}

MklDnnTensor.syncFromHeap(prvBias, bias.storage().array(), bias.storageOffset() - 1)

- Memory.SetDataHandle(biasUserPrim, prvBias.nativeStorage, 0)
- Memory.SetDataHandle(outputUserPrim, output.asInstanceOf[MklDnnTensor[T]].nativeStorage, 0)
+ Memory.SetDataHandle(biasUserPrim, prvBias.ptr, 0)
+ Memory.SetDataHandle(outputUserPrim, output.asInstanceOf[MklDnnTensor[T]].ptr, 0)
if (forwardReorderPrimBuffer.nonEmpty) {
MklDnn.StreamSubmit(stream, forwardReorderPrimBuffer.length, forwardReorderPrimBuffer.toArray)
}
- MklDnn.StreamSubmit(stream, forwardPrimBuffer.length, forwardPrimBuffer.toArray)

if (inputReorderPrim != 0L) {
if (input.getTensorType == DenseType && inputPtr != 0) {
@@ -334,6 +363,8 @@
}
}

+ MklDnn.StreamSubmit(stream, forwardPrimBuffer.length, forwardPrimBuffer.toArray)

if (shouldConvert) {
output.asInstanceOf[MklDnnTensor[T]].syncToHeap()
}
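Note: the forward StreamSubmit moved below the handle-release block, so the order is now: submit the reorders, release the borrowed heap handles (in the collapsed hunk above), submit the inner-product forward, then optionally sync the native output back to the heap. Roughly, assuming StreamSubmit executes primitives in submission order:

  // 1. StreamSubmit(forwardReorderPrimBuffer) -- user layout -> internal
  // 2. release borrowed heap data handles      -- collapsed hunk above
  // 3. StreamSubmit(forwardPrimBuffer)         -- the inner product itself
  // 4. output.syncToHeap()                     -- only when shouldConvert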
@@ -375,7 +406,9 @@
}

val weightMemDesc = if (input.dim() == 4) {
- MklDnn.MemoryDescInit(weight.dim(), weight.size(),
+ MklDnn.MemoryDescInit(4,
+   // Array(outputSize) ++ input.size().slice(1, input.dim()),
+   weight.size() ++ Array(1, 1),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
} else {
MklDnn.MemoryDescInit(weight.dim(), weight.size(),
@@ -386,7 +419,7 @@
MklDnn.MemoryDescInit(gradOutput.dim() + 1, Array(1) ++ gradOutput.size(),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
} else {
- MklDnn.MemoryDescInit(output.dim(), output.size(),
+ MklDnn.MemoryDescInit(gradOutput.dim(), gradOutput.size(),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
}

Expand All @@ -406,7 +439,7 @@ class Linear[T: ClassTag](
diffDstMemDesc)
val opPrimDesc = MklDnn.PrimitiveDescCreate(opDesc, engine, forwardPrimDesc)

- gradOutputUserPrim = initUser(gradOutput, MklDnn.DataType.f32, format, engine)
+ gradOutputUserPrim = initUser(gradOutput, MklDnn.DataType.f32, MklDnn.MemoryFormat.nc, engine)
val g1 = initInternal(gradOutputUserPrim, opPrimDesc, MklDnn.Query.diff_dst_pd)
gradOutputReorderMemoryPrim = g1._1
gradOutputReorderPrim = g1._2
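Note: gradOutput of a fully connected layer is 2-d (batch x outputSize) regardless of the input's rank, so its user primitive is now created with MemoryFormat.nc unconditionally; the old rank-derived format variable would have picked a 4-d format for conv-shaped inputs. Illustrative sizes (assumptions, not from this commit):

  // input:      Array(32, 512, 7, 7) -> the format match yields a 4-d format
  // gradOutput: Array(32, 1000)      -> always 2-d, hence MemoryFormat.nc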
@@ -446,21 +479,21 @@
gradOutputPtr = MklDnn.MemorySetDataHandle(gradOutputUserPrim,
gradOutput.storage().array().asInstanceOf[Array[Float]],
gradOutput.storageOffset() - 1)
- Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.ptr, 0)
} else {
Memory.SetDataHandle(gradOutputUserPrim,
-   gradOutput.asInstanceOf[MklDnnTensor[T]].nativeStorage,
+   gradOutput.asInstanceOf[MklDnnTensor[T]].ptr,
0)
- Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.ptr, 0)
}
} else {
if (gradOutput.getTensorType == DenseType) {
MklDnnTensor.syncFromHeap(internalGradOutput, gradOutput.storage().array(),
gradOutput.storageOffset() - 1)
- Memory.SetDataHandle(gradOutputUserPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputUserPrim, internalGradOutput.ptr, 0)
} else if (gradOutput.getTensorType == MklDnnType) {
Memory.SetDataHandle(gradOutputUserPrim,
-   gradOutput.asInstanceOf[MklDnnTensor[T]].nativeStorage, 0)
+   gradOutput.asInstanceOf[MklDnnTensor[T]].ptr, 0)
}
}

@@ -469,13 +502,13 @@
weightPtr = MklDnn.MemorySetDataHandle(weightUserPrim,
weight.storage().array().asInstanceOf[Array[Float]],
weight.storageOffset() - 1)
- Memory.SetDataHandle(weightReorderPrim, prvWeight.nativeStorage, 0)
+ Memory.SetDataHandle(weightReorderPrim, prvWeight.ptr, 0)
} else {
- Memory.SetDataHandle(weightUserPrim, prvWeight.nativeStorage, 0)
+ Memory.SetDataHandle(weightUserPrim, prvWeight.ptr, 0)
}

- Memory.SetDataHandle(biasUserPrim, prvBias.nativeStorage, 0)
- Memory.SetDataHandle(gradInputUserPrim, gradInput.asInstanceOf[MklDnnTensor[T]].nativeStorage,
+ Memory.SetDataHandle(biasUserPrim, prvBias.ptr, 0)
+ Memory.SetDataHandle(gradInputUserPrim, gradInput.asInstanceOf[MklDnnTensor[T]].ptr,
0)
if (backwardDataReorderPrimBuffer.nonEmpty) {
MklDnn.StreamSubmit(stream, backwardDataReorderPrimBuffer.length,
@@ -514,7 +547,9 @@
backwardWeightPrimBuffer = ArrayBuffer.empty[Long]
backwardWeightReorderPrimBuffer = ArrayBuffer.empty[Long]
val diffWeightMemDesc = if (input.dim() == 4) {
- MklDnn.MemoryDescInit(gradWeight.dim(), gradWeight.size(),
+ MklDnn.MemoryDescInit(4,
+   // weight.size() ++ Array(1, 1),
+   Array(outputSize) ++ input.size().slice(1, input.dim()),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
} else {
MklDnn.MemoryDescInit(gradWeight.dim(), gradWeight.size(),
@@ -536,7 +571,7 @@
MklDnn.MemoryDescInit(gradOutput.dim() + 1, Array(1) ++ gradOutput.size(),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
} else {
- MklDnn.MemoryDescInit(output.dim(), output.size(),
+ MklDnn.MemoryDescInit(gradOutput.dim(), gradOutput.size(),
MklDnn.DataType.f32, MklDnn.MemoryFormat.any)
}

@@ -550,7 +585,13 @@
case 4 => MklDnn.MemoryFormat.oihw
}

- diffWeightUserPrim = initUser(diffWeight, MklDnn.DataType.f32, weightFormat, engine)
+ diffWeightUserPrim = if (input.dim() == 4) {
+   init4(4,
+     Array(outputSize) ++ input.size().slice(1, input.dim())
+     , MklDnn.DataType.f32, weightFormat, engine)
+ } else {
+   initUser(diffWeight, MklDnn.DataType.f32, weightFormat, engine)
+ }
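Note: this is where the init4 helper is needed. gradWeight is stored 2-d, so initUser would describe the wrong shape; for a 4-d input the backward-weights primitive produces an oihw result shaped by the input's channel and spatial dims, and init4 builds that descriptor explicitly. The shape it describes, with assumed example sizes:

  // input.size() == Array(32, 512, 7, 7); outputSize == 1000
  // Array(outputSize) ++ input.size().slice(1, input.dim())
  //   == Array(1000, 512, 7, 7)  -- the 4-d oihw diff-weights shape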
val d1 = initInternal(diffWeightUserPrim, opPrimDesc, MklDnn.Query.diff_weights_pd,
userToPrim = false)
diffWeightReorderMemoryPrim = d1._1
@@ -595,19 +636,19 @@
inputPtr = MklDnn.MemorySetDataHandle(inputUserPrim,
input.storage().array().asInstanceOf[Array[Float]],
input.storageOffset() - 1)
- Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.ptr, 0)
} else {
Memory.SetDataHandle(inputUserPrim,
-   input.asInstanceOf[MklDnnTensor[T]].nativeStorage,
+   input.asInstanceOf[MklDnnTensor[T]].ptr,
0)
- Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputReorderMemoryPrim, internalInput.ptr, 0)
}
} else {
if (input.getTensorType == DenseType) {
- Memory.SetDataHandle(inputUserPrim, internalInput.nativeStorage, 0)
+ Memory.SetDataHandle(inputUserPrim, internalInput.ptr, 0)
} else if (input.getTensorType == MklDnnType) {
Memory.SetDataHandle(inputUserPrim,
-   input.asInstanceOf[MklDnnTensor[T]].nativeStorage, 0)
+   input.asInstanceOf[MklDnnTensor[T]].ptr, 0)
}
}

@@ -617,27 +658,27 @@
gradOutputPtr = MklDnn.MemorySetDataHandle(gradOutputUserPrim,
gradOutput.storage().array().asInstanceOf[Array[Float]],
gradOutput.storageOffset() - 1)
- Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.ptr, 0)
} else {
Memory.SetDataHandle(gradOutputUserPrim,
-   gradOutput.asInstanceOf[MklDnnTensor[T]].nativeStorage,
+   gradOutput.asInstanceOf[MklDnnTensor[T]].ptr,
0)
- Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputReorderMemoryPrim, internalGradOutput.ptr, 0)
}
} else {
if (gradOutput.getTensorType == DenseType) {
- Memory.SetDataHandle(gradOutputUserPrim, internalGradOutput.nativeStorage, 0)
+ Memory.SetDataHandle(gradOutputUserPrim, internalGradOutput.ptr, 0)
} else if (gradOutput.getTensorType == MklDnnType) {
Memory.SetDataHandle(gradOutputUserPrim,
gradOutput.asInstanceOf[MklDnnTensor[T]].nativeStorage, 0)
gradOutput.asInstanceOf[MklDnnTensor[T]].ptr, 0)
}
}

- Memory.SetDataHandle(diffBiasUserPrim, diffBias.nativeStorage, 0)
+ Memory.SetDataHandle(diffBiasUserPrim, diffBias.ptr, 0)
if (diffWeightReorderPrim != 0) {
- Memory.SetDataHandle(diffWeightReorderPrim, diffWeight.nativeStorage, 0)
+ Memory.SetDataHandle(diffWeightReorderPrim, diffWeight.ptr, 0)
} else {
- Memory.SetDataHandle(diffWeightUserPrim, diffWeight.nativeStorage, 0)
+ Memory.SetDataHandle(diffWeightUserPrim, diffWeight.ptr, 0)
}

MklDnn.StreamSubmit(stream, backwardWeightPrimBuffer.length, backwardWeightPrimBuffer.toArray)
