Commit 7a352b4
Some changes to CAddTable and ConcatTable
qiuxin2012 committed Nov 2, 2016
1 parent a2d788c commit 7a352b4
Showing 6 changed files with 79 additions and 102 deletions.
19 changes: 6 additions & 13 deletions dl/src/main/scala/com/intel/analytics/sparkdl/nn/CAddTable.scala
@@ -26,23 +26,21 @@ import scala.reflect.ClassTag
 class CAddTable[@specialized(Float, Double) T: ClassTag](val inplace: Boolean = false)(
   implicit ev: TensorNumeric[T]) extends Module[Table, Tensor[T], T] {
 
-  gradInput = T()
-
   override def updateOutput(input: Table): Tensor[T] = {
-    output = if (inplace) {
-      input.get[Tensor[T]](1).get
+    if (inplace) {
+      output = input[Tensor[T]](1)
     } else {
-      val input1 = input.get[Tensor[T]](1).get
+      val input1 = input[Tensor[T]](1)
       if (null == output) {
-        input1.clone()
+        output = input1.clone()
       } else {
         output.resizeAs(input1).copy(input1)
       }
     }
 
     var i = 2
     while (i <= input.length()) {
-      output.add(input.get[Tensor[T]](i).get)
+      output.add(input[Tensor[T]](i))
       i += 1
     }
 
@@ -56,19 +54,14 @@ class CAddTable[@specialized(Float, Double) T: ClassTag](val inplace: Boolean =
         gradInput(i) = gradOutput
       } else {
         if (gradInput.contains(i)) {
-          gradInput.get[Tensor[T]](i).get.resizeAs(gradOutput).copy(gradOutput)
+          gradInput[Tensor[T]](i).resizeAs(gradOutput).copy(gradOutput)
         } else {
           gradInput.insert(i, gradOutput.clone())
         }
       }
       i += 1
     }
 
-    while(i <= gradInput.length()) {
-      gradInput.remove(i)
-      i += 1
-    }
-
     gradInput
   }

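Taken together, the two hunks tighten CAddTable without changing its contract: updateOutput sums every tensor in the input Table (reusing the first input when inplace is set), updateGradInput hands each input the gradOutput (a shared reference when inplace, otherwise a copy), and the gradInput = T() initialization (made redundant by the Activities fix further down) and the trailing clean-up loop are gone. A minimal usage sketch, assuming only constructors and helpers that appear elsewhere in this commit (T(...), Tensor, forward/backward):

    import com.intel.analytics.sparkdl.nn.CAddTable
    import com.intel.analytics.sparkdl.tensor.Tensor
    import com.intel.analytics.sparkdl.utils.T

    val add = new CAddTable[Double]()
    val a = Tensor[Double](2, 3).apply1(_ => 1.0)
    val b = Tensor[Double](2, 3).apply1(_ => 2.0)

    // Forward: element-wise sum of the table entries; every element is 3.0.
    val out = add.forward(T(a, b))

    // Backward: each entry of the returned Table is a copy of gradOutput.
    val gradOut = Tensor[Double](2, 3).apply1(_ => 0.5)
    val gradIn = add.backward(T(a, b), gradOut)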
139 changes: 58 additions & 81 deletions dl/src/main/scala/com/intel/analytics/sparkdl/nn/ConcatTable.scala
@@ -23,29 +23,23 @@ import com.intel.analytics.sparkdl.utils.{Activities, T, Table}
 
 import scala.reflect.ClassTag
 
-class ConcatTable[T : ClassTag](implicit ev: TensorNumeric[T])
-  extends Container[Activities, Activities, T] {
+class ConcatTable[A <: Activities : ClassTag, T : ClassTag]
+  (implicit ev: TensorNumeric[T]) extends Container[A, Table, T] {
 
-  output = T()
-
-  override def updateOutput(input: Activities): Activities = {
+  override def updateOutput(input: A): Table = {
     var i = 0
     while (i < modules.length) {
       val currentOutput = modules(i).updateOutput(input)
-      if (!output.toTable().contains(i + 1)) {
-        output.toTable().insert(i + 1, currentOutput)
-      } else if (currentOutput != output.toTable().get(i + 1).get) {
-        output.toTable().update(i + 1, currentOutput)
-      }
+      output.toTable()(i + 1) = currentOutput
       i += 1
     }
     output
   }
 
   /**
    * add in to out
-   * @param out
-   * @param in
+   * @param out a table
+   * @param in a table
    */
   private def addTable(out: Activities, in: Activities) : Unit = {
     if (in.isInstanceOf[Tensor[T]] && out.isInstanceOf[Tensor[T]]) {
@@ -55,118 +49,101 @@ class ConcatTable[T : ClassTag](implicit ev: TensorNumeric[T])
     } else {
       var i = 1
       while (i <= out.toTable().length()) {
-        addTable(out.toTable().get[Activities](i).get, in.toTable().get[Activities](i).get)
+        addTable(out.toTable()(i), in.toTable()(i))
         i += 1
       }
     }
   }
 
   /**
-   * copy in to out
-   * @param out
-   * @param in
+   * copy src to out
+   * @param out a table
+   * @param src a table
    */
-  private def copyTable(out: Activities, in: Activities) : Unit = {
-    if (in.isInstanceOf[Tensor[T]] && out.isInstanceOf[Tensor[T]]) {
-      out.toTensor[T]().resizeAs(in.toTensor[T]()).copy(in.toTensor[T]())
+  private def copyTable(out: Activities, src: Activities) : Unit = {
+    if (src.isInstanceOf[Tensor[T]] && out.isInstanceOf[Tensor[T]]) {
+      out.toTensor[T]().resizeAs(src.toTensor[T]()).copy(src.toTensor[T]())
     } else {
       var i = 1
       while (i <= out.toTable().length()) {
-        copyTable(out.toTable().get[Activities](i).get, in.toTable().get[Activities]().get)
+        copyTable(out.toTable()(i), src.toTable()(i))
         i += 1
       }
     }
   }
 
   /**
-   * return a clone of in
-   * @param in
-   * @return cloned table
+   * return a clone of src,
+   * Notice: this is a deep copy, while Table.clone is a shallow copy.
+   * @param src a table
+   * @return cloned table of src
    */
-  private def cloneTable(in: Activities) : Activities = {
-    if (in.isInstanceOf[Tensor[T]]) {
-      in.toTensor[T]().clone()
+  private def cloneTable(src: Activities) : Activities = {
+    if (src.isInstanceOf[Tensor[T]]) {
+      src.toTensor[T]().clone()
     } else {
       val out = T()
       var i = 1
-      while (i <= in.toTable().length()) {
-        out(i) = cloneTable(in.toTable()(i))
+      while (i <= src.toTable().length()) {
+        out(i) = cloneTable(src.toTable()(i))
         i += 1
       }
       out
     }
   }
 
-  def backward(method: String, input: Activities, gradOutput: Activities,
-    scale : Double = 1.0) : Activities = {
-
-    val isTable = input.isInstanceOf[Table]
-    val wasTable = gradInput.isInstanceOf[Table]
+  override def updateGradInput(input: A, gradOutput: Table): A = {
+    val isInputTable = input.isInstanceOf[Table]
+    val wasGradInputTable = gradInput.isInstanceOf[Table]
 
-    if (isTable) {
-      if (!wasTable) {
-        gradInput = null
-      }
+    if (isInputTable) {
       var i = 0
       while (i < modules.length) {
-        method match {
-          case "updateGradInput" =>
-            val currentGradInput = modules(i).updateGradInput(input,
-              gradOutput.toTable().get(i + 1).get)
-            require(currentGradInput.isInstanceOf[Table],
-              "currentGradInput is not a table!")
-            if (i == 0) {
-              if (null == gradInput) {
-                gradInput = cloneTable(currentGradInput)
-              } else {
-                copyTable(gradInput, currentGradInput)
-              }
-            } else {
-              addTable(gradInput, currentGradInput)
-            }
-          case "accGradParameters" =>
-            modules(i).accGradParameters(input, gradOutput.toTable().get(i + 1).get, scale)
+        val currentGradInput = modules(i).updateGradInput(input,
+          gradOutput.toTable()(i + 1))
+        require(currentGradInput.isInstanceOf[Table],
+          "currentGradInput is not a table!")
+        if (i == 0) {
+          if (!wasGradInputTable) {
+            // We need deep copy here.
+            gradInput = cloneTable(currentGradInput).asInstanceOf[A]
+          } else {
+            copyTable(gradInput, currentGradInput)
+          }
+        } else {
+          addTable(gradInput, currentGradInput)
        }
         i += 1
       }
 
     } else {
-      if (wasTable) {
-        gradInput = null
-      }
       var i = 0
       while (i < modules.length) {
-        method match {
-          case "updateGradInput" =>
-            val currentGradInput = modules(i).updateGradInput(input,
-              gradOutput.toTable().get(i + 1).get)
-            if (i == 0) {
-              if (null == gradInput) {
-                gradInput = currentGradInput.toTensor().clone()
-              } else {
-                gradInput.toTensor[T]().resizeAs(
-                  currentGradInput.toTensor[T]()).copy(currentGradInput.toTensor[T]())
-              }
-            } else {
-              gradInput.toTensor[T]().add(currentGradInput.toTensor[T]())
-            }
-          case "accGradParameters" =>
-            modules(i).accGradParameters(input, gradOutput.toTable().get(i + 1).get, scale)
+        val currentGradInput = modules(i).updateGradInput(input,
+          gradOutput.toTable()(i + 1)).toTensor[T]()
+        if (i == 0) {
+          if (wasGradInputTable) {
+            gradInput = currentGradInput.clone().asInstanceOf[A]
+          } else {
+            gradInput.toTensor[T]().resizeAs(
+              currentGradInput).copy(currentGradInput)
+          }
+        } else {
+          gradInput.toTensor[T]().add(currentGradInput)
        }
         i += 1
       }
     }
     gradInput
   }
 
-  override def updateGradInput(input: Activities, gradOutput: Activities): Activities = {
-    backward("updateGradInput", input, gradOutput)
-  }
-
-  override def accGradParameters(input: Activities, gradOutput: Activities,
-    scale: Double = 0.1): Unit = {
-
-    backward("accGradParameters", input, gradOutput)
+  override def accGradParameters(input: A, gradOutput: Table,
+    scale: Double = 1.0): Unit = {
+    var i = 0
+    while (i < modules.length) {
+      modules(i).accGradParameters(input, gradOutput.toTable()(i + 1), scale)
+      i += 1
+    }
   }
 
   override def toString(): String = {
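The net effect of the rewrite: updateGradInput now has a single job, deep-copying the first branch's gradient and accumulating the rest, so gradInput is the sum over branches j of modules(j).updateGradInput(input, gradOutput(j)), while accGradParameters loops over the branches directly instead of dispatching through the old stringly-typed backward(method, ...). A usage sketch built only from names that appear in this commit, not verified against the full repo:

    import com.intel.analytics.sparkdl.nn.{ConcatTable, Linear}
    import com.intel.analytics.sparkdl.tensor.Tensor
    import com.intel.analytics.sparkdl.utils.T

    // Two branches that both see the same 5-element input.
    val ct = new ConcatTable[Tensor[Double], Double]()
    ct.add(new Linear[Double](5, 2))
    ct.add(new Linear[Double](5, 3))

    val input = Tensor[Double](5).apply1(_ => 0.1)
    val out = ct.forward(input)        // Table: entry 1 has size 2, entry 2 has size 3
    val gradOut = T(Tensor[Double](2).apply1(_ => 1.0),
      Tensor[Double](3).apply1(_ => 1.0))
    val gradIn = ct.backward(input, gradOut)  // both branches' input gradients, summed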
@@ -40,7 +40,7 @@ object Activities
 
     if (classTag[A] == classTag[Tensor[T]]) {
       result = Tensor[T]()
-    } else if (classTag[A] == classTag[Tensor[T]]) {
+    } else if (classTag[A] == classTag[Table]) {
       result = T()
     }
 
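This one-line fix in object Activities is the hinge for the rest of the commit: the old second branch repeated classTag[Tensor[T]], so it could never match, and a Table-typed activity was never allocated. That is presumably why CAddTable and ConcatTable carried their own gradInput = T() / output = T() assignments, which the hunks above delete. A hedged sketch of the intended behavior (Activities.apply's full signature is not visible in this hunk):

    // Assumed factory shape, reconstructed from the branch bodies above.
    val asTensor = Activities[Tensor[Double], Double]()  // Tensor branch: allocates Tensor[Double]()
    val asTable  = Activities[Table, Double]()           // Table branch: now reachable, allocates T()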
@@ -18,13 +18,13 @@
 package com.intel.analytics.sparkdl.nn
 
 import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
-import com.intel.analytics.sparkdl.utils.T
+import com.intel.analytics.sparkdl.utils.{T, Table}
 import org.scalatest.{FlatSpec, Matchers}
 
 class ConcatTableSpec extends FlatSpec with Matchers {
 
   "A ConcateTable" should "return right output and grad" in {
-    val ct = new ConcatTable[Double]()
+    val ct = new ConcatTable[Table, Double]()
     ct.add(new Identity[Double]())
     ct.add(new Identity[Double]())
 
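The specs pick the new first type parameter by what they feed the module: A is ConcatTable's input type. This spec drives it with a Table (through Identity branches), while the specs below drive it with a Tensor. Both instantiations, exactly as they appear in this diff:

    val overTables  = new ConcatTable[Table, Double]()          // input is a Table
    val overTensors = new ConcatTable[Tensor[Double], Double]() // input is a Tensor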
@@ -37,7 +37,7 @@ class CAddTableSpec extends FlatSpec with BeforeAndAfter with Matchers {
     RNG.setSeed(seed)
 
     val model = new Sequential[Activities, Activities, Double]()
-    val ctable = new ConcatTable[Double]()
+    val ctable = new ConcatTable[Tensor[Double], Double]()
     ctable.add(new Linear(5, 3))
     ctable.add(new Linear(5, 3))
     model.add(ctable)
@@ -72,7 +72,7 @@ class CAddTableSpec extends FlatSpec with BeforeAndAfter with Matchers {
     RNG.setSeed(seed)
 
     val model = new Sequential[Activities, Activities, Double]()
-    val ctable = new ConcatTable[Double]()
+    val ctable = new ConcatTable[Tensor[Double], Double]()
     ctable.add(new Linear(5, 3))
     ctable.add(new Linear(5, 3))
     model.add(ctable)
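Both CAddTableSpec hunks update the same fixture: ConcatTable fans the input out to two Linear(5, 3) branches and, per the spec's name, a CAddTable presumably sums them further down, computing (W1 x + b1) + (W2 x + b2). A sketch of that fixture under that assumption (only the lines up to model.add(ctable) are visible in this diff):

    val model = new Sequential[Activities, Activities, Double]()
    val ctable = new ConcatTable[Tensor[Double], Double]()
    ctable.add(new Linear[Double](5, 3))
    ctable.add(new Linear[Double](5, 3))
    model.add(ctable)
    model.add(new CAddTable[Double]())  // assumed: not shown in the visible hunk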
@@ -36,7 +36,8 @@ class ConcatTableSpec extends FlatSpec with BeforeAndAfter with Matchers {
     val seed = 100
     RNG.setSeed(seed)
 
-    val ctable = new ConcatTable[Double]()
+    val ctable = new ConcatTable[Tensor[Double], Double]()
+    ctable.zeroGradParameters()
     ctable.add(new Linear(5, 2))
     ctable.add(new Linear(5, 3))
     val input = Tensor[Double](5).apply1(_ => Random.nextDouble())
@@ -46,27 +47,33 @@ class ConcatTableSpec extends FlatSpec with BeforeAndAfter with Matchers {
     val output = ctable.forward(input)
 
     val gradOutput = T(gradOutput1, gradOutput2)
-    val gradInput = ctable.updateGradInput(input, gradOutput)
+    val gradInput = ctable.backward(input, gradOutput)
 
     val code = "torch.manualSeed(" + seed + ")\n" +
       """module = nn.ConcatTable():add(nn.Linear(5, 2)):add(nn.Linear(5, 3))
+      module:zeroGradParameters()
       gradOutput = {gradOutput1, gradOutput2}
       output = module:forward(input)
       gradInput = module:backward(input, gradOutput)
       output1 = output[1]
       output2 = output[2]
+      parameters, gradParameters = module:getParameters()
       """
 
     val (luaTime, torchResult) = TH.run(code,
       Map("input" -> input, "gradOutput1" -> gradOutput1, "gradOutput2" -> gradOutput2),
-      Array("output1", "output2", "gradInput"))
+      Array("output1", "output2", "gradInput", "gradParameters"))
     val luaOutput1 = torchResult("output1").asInstanceOf[Tensor[Double]]
     val luaOutput2 = torchResult("output2").asInstanceOf[Tensor[Double]]
     val luaGradInput = torchResult("gradInput").asInstanceOf[Tensor[Double]]
+    val luaGradParameters = torchResult("gradParameters").asInstanceOf[Tensor[Double]]
     val luaOutput = T(luaOutput1, luaOutput2)
 
+    val gradParameters = ctable.getParameters()._2.asInstanceOf[Tensor[Double]]
+
     output should be (luaOutput)
     gradInput should be (luaGradInput)
+    gradParameters should be (luaGradParameters)
   }
 
 }
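The switch from updateGradInput to backward is what gives the new gradParameters assertion teeth: under the Torch convention these modules follow, backward both computes the input gradient and accumulates the parameter gradients. Roughly, as a sketch of the convention rather than this repo's exact Module code:

    def backward(input: A, gradOutput: B): A = {
      updateGradInput(input, gradOutput)    // d(loss)/d(input)
      accGradParameters(input, gradOutput)  // accumulate d(loss)/d(parameters)
      gradInput
    }

Calling only updateGradInput would leave the Scala side's gradParameters at zero, so the old test could never have compared them against Torch's module:backward, which does accumulate.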
