Skip to content

Commit

Permalink
ConcatTable CAddTable and Identity
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 committed Nov 1, 2016
1 parent 819cd84 commit a2d788c
Show file tree
Hide file tree
Showing 7 changed files with 554 additions and 0 deletions.
78 changes: 78 additions & 0 deletions dl/src/main/scala/com/intel/analytics/sparkdl/nn/CAddTable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.sparkdl.nn

import com.intel.analytics.sparkdl.tensor.Tensor
import com.intel.analytics.sparkdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.sparkdl.utils.{T, Table}

import scala.reflect.ClassTag

class CAddTable[@specialized(Float, Double) T: ClassTag](val inplace: Boolean = false)(
implicit ev: TensorNumeric[T]) extends Module[Table, Tensor[T], T] {

gradInput = T()

override def updateOutput(input: Table): Tensor[T] = {
output = if (inplace) {
input.get[Tensor[T]](1).get
} else {
val input1 = input.get[Tensor[T]](1).get
if (null == output) {
input1.clone()
} else {
output.resizeAs(input1).copy(input1)
}
}

var i = 2
while (i <= input.length()) {
output.add(input.get[Tensor[T]](i).get)
i += 1
}

output
}

override def updateGradInput(input: Table, gradOutput: Tensor[T]) : Table = {
var i = 1
while (i <= input.length()) {
if (inplace) {
gradInput(i) = gradOutput
} else {
if (gradInput.contains(i)) {
gradInput.get[Tensor[T]](i).get.resizeAs(gradOutput).copy(gradOutput)
} else {
gradInput.insert(i, gradOutput.clone())
}
}
i += 1
}

while(i <= gradInput.length()) {
gradInput.remove(i)
i += 1
}

gradInput
}

override def toString() : String = {
"nn.CAddTable"
}
}
197 changes: 197 additions & 0 deletions dl/src/main/scala/com/intel/analytics/sparkdl/nn/ConcatTable.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.sparkdl.nn

import com.intel.analytics.sparkdl.tensor.Tensor
import com.intel.analytics.sparkdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.sparkdl.utils.{Activities, T, Table}

import scala.reflect.ClassTag

class ConcatTable[T : ClassTag](implicit ev: TensorNumeric[T])
extends Container[Activities, Activities, T] {

output = T()

override def updateOutput(input: Activities): Activities = {
var i = 0
while (i < modules.length) {
val currentOutput = modules(i).updateOutput(input)
if (!output.toTable().contains(i + 1)) {
output.toTable().insert(i + 1, currentOutput)
} else if (currentOutput != output.toTable().get(i + 1).get) {
output.toTable().update(i + 1, currentOutput)
}
i += 1
}
output
}

/**
* add in to out
* @param out
* @param in
*/
private def addTable(out: Activities, in: Activities) : Unit = {
if (in.isInstanceOf[Tensor[T]] && out.isInstanceOf[Tensor[T]]) {
require(in.toTensor[T]().nElement() == out.toTensor[T]().nElement(),
"gradInput should have the same size")
out.toTensor[T]().add(in.toTensor[T]())
} else {
var i = 1
while (i <= out.toTable().length()) {
addTable(out.toTable().get[Activities](i).get, in.toTable().get[Activities](i).get)
i += 1
}
}
}

/**
* copy in to out
* @param out
* @param in
*/
private def copyTable(out: Activities, in: Activities) : Unit = {
if (in.isInstanceOf[Tensor[T]] && out.isInstanceOf[Tensor[T]]) {
out.toTensor[T]().resizeAs(in.toTensor[T]()).copy(in.toTensor[T]())
} else {
var i = 1
while (i <= out.toTable().length()) {
copyTable(out.toTable().get[Activities](i).get, in.toTable().get[Activities]().get)
i += 1
}
}
}

/**
* return a clone of in
* @param in
* @return cloned table
*/
private def cloneTable(in: Activities) : Activities = {
if (in.isInstanceOf[Tensor[T]]) {
in.toTensor[T]().clone()
} else {
val out = T()
var i = 1
while (i <= in.toTable().length()) {
out(i) = cloneTable(in.toTable()(i))
i += 1
}
out
}
}

def backward(method: String, input: Activities, gradOutput: Activities,
scale : Double = 1.0) : Activities = {

val isTable = input.isInstanceOf[Table]
val wasTable = gradInput.isInstanceOf[Table]

if (isTable) {
if (!wasTable) {
gradInput = null
}
var i = 0
while (i < modules.length) {
method match {
case "updateGradInput" =>
val currentGradInput = modules(i).updateGradInput(input,
gradOutput.toTable().get(i + 1).get)
require(currentGradInput.isInstanceOf[Table],
"currentGradInput is not a table!")
if (i == 0) {
if (null == gradInput) {
gradInput = cloneTable(currentGradInput)
} else {
copyTable(gradInput, currentGradInput)
}
} else {
addTable(gradInput, currentGradInput)
}
case "accGradParameters" =>
modules(i).accGradParameters(input, gradOutput.toTable().get(i + 1).get, scale)
}
i += 1
}

} else {
if (wasTable) {
gradInput = null
}
var i = 0
while (i < modules.length) {
method match {
case "updateGradInput" =>
val currentGradInput = modules(i).updateGradInput(input,
gradOutput.toTable().get(i + 1).get)
if (i == 0) {
if (null == gradInput) {
gradInput = currentGradInput.toTensor().clone()
} else {
gradInput.toTensor[T]().resizeAs(
currentGradInput.toTensor[T]()).copy(currentGradInput.toTensor[T]())
}
} else {
gradInput.toTensor[T]().add(currentGradInput.toTensor[T]())
}
case "accGradParameters" =>
modules(i).accGradParameters(input, gradOutput.toTable().get(i + 1).get, scale)
}
i += 1
}
}
gradInput
}

override def updateGradInput(input: Activities, gradOutput: Activities): Activities = {
backward("updateGradInput", input, gradOutput)
}

override def accGradParameters(input: Activities, gradOutput: Activities,
scale: Double = 0.1): Unit = {

backward("accGradParameters", input, gradOutput)
}

override def toString(): String = {
val tab = "\t"
val line = "\n"
val next = " |`-> "
val lastNext = " `-> "
val ext = " | "
val extlast = " "
val last = " ... -> "
var str = "nn.ConcatTable"
str = str + " {" + line + tab + "input"
var i = 1
while (i <= modules.length) {
if (i == modules.length) {
str = str + line + tab + lastNext + "(" + i + "): " +
modules(i-1).toString.replace(line, line + tab + extlast)
} else {
str = str + line + tab + next + "(" + i + "): " +
modules(i-1).toString.replace(line, line + tab + ext)
}
i += 1
}
str = str + line + tab + last + "output"
str = str + line + "}"
str
}
}
39 changes: 39 additions & 0 deletions dl/src/main/scala/com/intel/analytics/sparkdl/nn/Identity.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.sparkdl.nn

import com.intel.analytics.sparkdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.sparkdl.utils.Activities

import scala.reflect.ClassTag

class Identity[@specialized(Float, Double) T: ClassTag]()
(implicit ev: TensorNumeric[T]) extends Module[Activities, Activities, T] {

override def updateOutput(input: Activities): Activities = {
output = input
output
}

override def updateGradInput(input: Activities,
gradOutput: Activities): Activities = {

gradInput = gradOutput
gradInput
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class Table private[sparkdl](
Option(state(key).asInstanceOf[T])
}

def contains(key: Any): Boolean = {
state.contains(key)
}

def apply[T](key: Any): T = {
state(key).asInstanceOf[T]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.sparkdl.nn

import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
import com.intel.analytics.sparkdl.utils.T
import org.scalatest.{FlatSpec, Matchers}

class ConcatTableSpec extends FlatSpec with Matchers {

"A ConcateTable" should "return right output and grad" in {
val ct = new ConcatTable[Double]()
ct.add(new Identity[Double]())
ct.add(new Identity[Double]())

val input = T(Tensor[Float](
Storage(Array(1f, 2, 3))),
T(
Tensor[Float](Storage(Array(4f, 3, 2, 1)))
)
)
val output = ct.forward(input)
output should be (T(input, input))

val gradOutput1 = T(
Tensor(Storage[Float](Array(0.1f, 0.2f, 0.3f))),
T(
Tensor(Storage[Float](Array(0.4f, 0.3f, 0.2f, 0.1f)))
)
)
val gradOutput = T(gradOutput1, gradOutput1)

val gradInput = ct.updateGradInput(input, gradOutput)
ct.accGradParameters(input, gradOutput)
gradInput should be (T(
Tensor(Storage[Float](Array(0.2f, 0.4f, 0.6f))),
T(
Tensor(Storage[Float](Array(0.8f, 0.6f, 0.4f, 0.2f)))
)
))
}
}
Loading

0 comments on commit a2d788c

Please sign in to comment.