Skip to content

Commit

Permalink
refine inteface (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
yiheng authored and i8run committed Jun 27, 2018
1 parent fc9d165 commit 358d072
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 516 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.intel.analytics.bigdl.nn.mkldnn
import breeze.linalg.dim
import com.intel.analytics.bigdl.mkl.{DataType, Memory, MklDnn, Query}
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.{Container, DynamicContainer, JoinTable}
import com.intel.analytics.bigdl.nn.{Container, DynamicContainer}
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.{MklDnnTensor, MklDnnType, Tensor}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class ConcatTable extends MklDnnContainer {
require(modules.length > 0, "empty modules of concat table")
var i = 0
while (i < modules.length) {
val currentOutput = modules(i).forward(input)
val currentOutput = modules(i).forward(
reorderManager.infer(_inputFormats, mklDnnModules(i).inputFormats(), input))
output.toTable(i + 1) = currentOutput
i += 1
}
Expand Down Expand Up @@ -60,68 +61,54 @@ class ConcatTable extends MklDnnContainer {
}
}

/**
* Compute the output formats based on the input formats
*/
override private[mkldnn] def inferShape(shapes: Array[Array[Int]]) = {
require(shapes.length == 1, "Concat only accept one tensor")
require(mklDnnModules.length > 0, "Concat should contains at least one module")

val outputShape = new ArrayBuffer[Array[Int]]()
for(i <- 0 until mklDnnModules.length) {
val outputShapes = mklDnnModules(i).inferShape(shapes)
require(outputShapes.length == 1, "submodule only output one tensor")
outputShape.append(outputShapes(0))
}
outputShape.toArray
}

override private[mkldnn] def initFwdPrimitives(runtime: MklDnnRuntime, phase: Phase) = {
require(MemoryData.noUndef(inputFormats()), "Memory formats should be inited")
override private[mkldnn] def initFwdPrimitives(inputs: Array[MemoryData], phase: Phase) = {
require(mklDnnModules != null, "You should call compile first")
require(inputs.length == 1, "Concat only accept one tensor")
val buffer = new ArrayBuffer[MemoryData]()
mklDnnModules.foreach(m => {
m.initFwdPrimitives(runtime, phase)
val out = m.outputFormats()
val (realInput, out) = m.initFwdPrimitives(inputs, phase)
require(out.length == 1, "output should be one tensor")
reorderManager.register(inputs(0), realInput(0))
buffer.append(out(0))
})
_outputFormats = buffer.toArray
_inputFormats = inputs
(inputs, _outputFormats)
}

override private[mkldnn] def initBwdPrimitives(runtime: MklDnnRuntime, phase: Phase) = {
val formats = gradOutputFormats()._1
require(MemoryData.noUndef(formats), "Memory formats should be inited")
val buffer = new ArrayBuffer[MemoryData]()
mklDnnModules.foreach(m => {
m.initBwdPrimitives(runtime, phase)
val out = m.gradInputFormats()
require(out.length == 1, "output should be one tensor")
buffer.append(out(0))
})
_gradInputFormats = buffer.toArray
override private[mkldnn] def initBwdPrimitives(grads: Array[MemoryData], phase: Phase) = {
require(grads.length == mklDnnModules.length, "grad tensor number is not correct")
val realGradsBuffer = new ArrayBuffer[MemoryData]()
for(i <- 0 until grads.length) {
val m = mklDnnModules(i)
val (realGrads, gradInput) = m.initBwdPrimitives(Array(grads(i)), phase)
require(realGrads.length == 1, "real grad length should be 1")
realGradsBuffer.append(realGrads(0))
require(gradInput.length == 1, "real grad length should be 1")
if (_gradInputFormats == null) {
_gradInputFormats = gradInput
} else {
require(_gradInputFormats(0) == gradInput(0), "reorder backward should be same")
}
}
_gradOutputFormats = realGradsBuffer.toArray
(realGradsBuffer.toArray, _gradInputFormats)
}

override private[mkldnn] def initGradWPrimitives(runtime: MklDnnRuntime, phase: Phase) = {
val formats = gradOutputFormats()._2
require(MemoryData.noUndef(formats), "Memory formats should be inited")
mklDnnModules.foreach(m => {
m.initGradWPrimitives(runtime, phase)
})
override private[mkldnn] def initGradWPrimitives(grads: Array[MemoryData], phase: Phase) = {
val realGradsBuffer = new ArrayBuffer[MemoryData]()
for(i <- 0 until grads.length) {
val m = mklDnnModules(i)
val realGradOutput = m.initGradWPrimitives(grads, phase)
require(realGradOutput.length == 1, "real grad length should be 1")
realGradsBuffer.append(realGradOutput(0))
}
_gradOutputWeightFormats = realGradsBuffer.toArray
_gradOutputWeightFormats
}

override private[mkldnn] def inputFormats() = {
if (_inputFormats == null) {
require(mklDnnModules != null, "container should be compiled")
mklDnnModules.foreach { m =>
require(m.inputFormats().length == 1, "input should be one tensor")
if (_inputFormats == null) {
_inputFormats = m.inputFormats()
} else {
require(_inputFormats(0) == m.inputFormats()(0), "input format should be same")
}
}
}
require(_inputFormats != null, "You should call initFwdPrimitives first")
_inputFormats
}

Expand All @@ -136,35 +123,17 @@ class ConcatTable extends MklDnnContainer {
}

override private[mkldnn] def gradOutputFormats() = {
if (_gradOutputFormats == null) {
require(mklDnnModules != null, "container should be compiled")
val gradBuffer = new ArrayBuffer[MemoryData]()
val gradForWeightBuffer = new ArrayBuffer[MemoryData]()
mklDnnModules.foreach { m =>
val (grad, gradForWeight) = m.gradOutputFormats()
require(grad.length == 1, "module gradOutput should be one tensor")
require(gradForWeight.length == 1, "module gradOutput should be one tensor")
gradBuffer.append(grad(0))
gradForWeightBuffer.append(gradForWeight(0))
}
_gradOutputFormats = (gradBuffer.toArray, gradForWeightBuffer.toArray)
}
require(_gradInputFormats != null, "You should call initBwdPrimitives first")
_gradOutputFormats
}

override private[mkldnn] def initMemory() = {
super.initMemory()
gradInput = gradInputFormats()(0) match {
case h: HeapData => Tensor[Float]()
case n: NativeData => DnnTensor[Float](n.shape)
case _ => throw new UnsupportedOperationException("NOt support memory format")
}
}

private var _inputFormats: Array[MemoryData] = _
private var _gradInputFormats: Array[MemoryData] = _
private var _outputFormats: Array[MemoryData] = _
private var _gradOutputFormats: (Array[MemoryData], Array[MemoryData]) = _
private var _gradOutputFormats: Array[MemoryData] = _
private var _gradOutputWeightFormats: Array[MemoryData] = _

override private[mkldnn] def gradOutputWeightFormats() = _gradOutputWeightFormats
}

object ConcatTable {
Expand Down
Loading

0 comments on commit 358d072

Please sign in to comment.