Refine AbstractModule methods (#2262)
* make getParameters final and private

* remove updateParameters method

* remove useless override zeroGrad

* fix comments

* fix unit tests

* fix unit tests

* allocate gradWeight storage if it's not allocated

* fix unit test

* meet code review

* make zeroGrad become final

* fix compile error

* add final to module apis

* add private[bigdl] to some module methods

* reorder the method sequence

* meet code review

* remove unnecessary getParametersTable and fix unit test

* fix unit test

* fix unit test
yiheng committed Feb 9, 2018
1 parent 20013fc commit 77ad6c0
Showing 45 changed files with 425 additions and 660 deletions.
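The common thread in the hunks below: gradient bookkeeping that each layer used to reimplement (zeroGradParameters, updateParameters, getParametersTable) is now driven by the single parameters() accessor in AbstractModule, so the per-module overrides can simply be deleted. A minimal sketch of that idea, using simplified stand-in types rather than the code this commit actually adds:

// Illustrative only: simplified stand-ins for BigDL's Tensor and AbstractModule,
// not the implementation added by this commit.
trait SimpleTensor {
  def zero(): Unit                                  // fill the tensor with zeros
  def resizeAs(other: SimpleTensor): SimpleTensor   // (re)allocate storage to match another tensor's shape
}

trait SimpleModule {
  // (weights, gradients), mirroring AbstractModule.parameters()
  def parameters(): (Array[SimpleTensor], Array[SimpleTensor])

  // One shared, final implementation: every trainable module exposes its gradients
  // through parameters(), so a per-module zeroGradParameters override adds nothing.
  final def zeroGradParameters(): Unit = {
    val params = parameters()
    if (params != null) {
      params._1.zip(params._2).foreach { case (weight, grad) =>
        // allocate gradient storage lazily if it has never been sized, then clear it
        grad.resizeAs(weight).zero()
      }
    }
  }
}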
@@ -86,10 +86,6 @@ class Add[T: ClassTag](val inputSize: Int
}
}

override def zeroGradParameters(): Unit = {
gradBias.zero()
}

override def clearState() : this.type = {
super.clearState()
ones.set()
@@ -142,13 +142,6 @@ class BatchNormalization[T: ClassTag](
this
}

override def zeroGradParameters(): Unit = {
if (affine) {
gradWeight.zero()
gradBias.zero()
}
}

override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
if (affine) {
(Array(this.weight, this.bias), Array(this.gradWeight, this.gradBias))
@@ -94,26 +94,6 @@ class BiRecurrent[T : ClassTag] (
*/
override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = birnn.parameters()

override def updateParameters(learningRate: T): Unit = birnn.updateParameters(learningRate)

/**
* If the module has parameters, this will zero the accumulation of the gradients with respect
* to these parameters. Otherwise, it does nothing.
*/
override def zeroGradParameters(): Unit = birnn.zeroGradParameters()

override def training(): BiRecurrent.this.type = {
super.training()
birnn.training()
this
}

override def evaluate(): BiRecurrent.this.type = {
super.evaluate()
birnn.evaluate()
this
}

override def canEqual(other: Any): Boolean = other.isInstanceOf[BiRecurrent[T]]


@@ -194,11 +194,6 @@ class Bilinear[T: ClassTag](
}
}

override def zeroGradParameters(): Unit = {
gradWeight.zero()
gradBias.zero()
}

override def clearState(): this.type = {
super.clearState()
buff1.set()
@@ -214,15 +209,6 @@
}
}

override def getParametersTable(): Table = {
if (null == bias) {
T(getName() -> T("weight" -> weight, "gradWeight" -> gradWeight))
} else {
T(getName() -> T("weight" -> weight, "bias" -> bias,
"gradWeight" -> gradWeight, "gradBias" -> gradBias))
}
}

override def toString(): String = {
s"${getPrintName}($inputSize1, $inputSize2, $outputSize, $biasRes)"
}
@@ -368,11 +368,6 @@ class BinaryTreeLSTM[T: ClassTag](
(cp ++ lp, cg ++ lg)
}

override def updateParameters(learningRate: T): Unit = {
composer.updateParameters(learningRate)
leafModule.updateParameters(learningRate)
}

override def getParametersTable(): Table = {
val pt = T()
val t1 = composer.getParametersTable()
@@ -382,11 +377,6 @@
pt
}

override def zeroGradParameters(): Unit = {
composer.zeroGradParameters()
leafModule.zeroGradParameters()
}

override def reset(): Unit = {
composer.reset()
leafModule.reset()
@@ -141,14 +141,6 @@ class CAdd[T: ClassTag](
}
}

override def updateParameters(learningRate: T): Unit = {
bias.map(gradBias, (a, b) => ev.minus(a, ev.times(learningRate, b)))
}

override def zeroGradParameters(): Unit = {
gradBias.zero()
}

override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
(Array(this.bias), Array(this.gradBias))
}
12 changes: 0 additions & 12 deletions spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/CMul.scala
@@ -163,22 +163,10 @@ class CMul[T: ClassTag](
}
}

override def updateParameters(learningRate: T): Unit = {
weight.map(gradWeight, (a, b) => ev.minus(a, ev.times(learningRate, b)))
}

override def zeroGradParameters(): Unit = {
gradWeight.zero()
}

override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
(Array(this.weight), Array(this.gradWeight))
}

override def getParametersTable(): Table = {
T(getName() -> T("weight" -> weight, "gradWeight" -> gradWeight))
}

override def clearState(): this.type = {
super.clearState()
_repeat.set()
10 changes: 0 additions & 10 deletions spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Cell.scala
@@ -190,11 +190,6 @@ abstract class Cell[T : ClassTag](
gradInput
}

override def updateParameters(learningRate: T): Unit = {
cell.updateParameters(learningRate)
if (includePreTopology) preTopology.updateParameters(learningRate)
}

private def initAddTimes(): Unit = {
val cellTimes = cell.getTimes
if (subModules == null || subModules.length < cellTimes.length) {
@@ -264,11 +259,6 @@ abstract class Cell[T : ClassTag](
cell.resetTimes
}

override def zeroGradParameters(): Unit = {
cell.zeroGradParameters()
if (includePreTopology) preTopology.zeroGradParameters()
}

override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
val _cell = if (includePreTopology) {
Sequential().add(preTopology).add(cell)
12 changes: 0 additions & 12 deletions spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Concat.scala
@@ -257,18 +257,6 @@ class Concat[T: ClassTag](val dimension: Int)(
this.gradInput
}

// Todo: this is different from torch accUpdateGradParameters
override def updateParameters(learningRate: T): Unit = {
var offset = 1
var i = 0
while (i < this.modules.length) {
val currentOutput = this.modules(i).output.asInstanceOf[Tensor[T]]
this.modules(i).updateParameters(learningRate)
i += 1
offset += currentOutput.size(dimension)
}
}

override def equals(obj: Any): Boolean = {
if (!super.equals(obj)) {
return false
@@ -51,31 +51,23 @@ abstract class Container[A <: Activity : ClassTag,
modules.filter(!_.isCompatibleWithTorch()).length <= 0
}

override def zeroGradParameters(): Unit = {
modules.foreach(_.zeroGradParameters())
}

override def updateParameters(learningRate: T): Unit = {
modules.foreach(_.updateParameters(learningRate))
}

override def reset(): Unit = {
modules.foreach(_.reset())
}

-override def training(): this.type = {
+final override def training(): this.type = {
train = true
modules.foreach(_.training())
this
}

-override def evaluate(): this.type = {
+final override def evaluate(): this.type = {
train = false
modules.foreach(_.evaluate())
this
}

-override def checkEngineType(): this.type = {
+final override def checkEngineType(): this.type = {
modules.foreach(_.checkEngineType())
this
}
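Since training(), evaluate() and checkEngineType() are now final on Container, a mode switch always cascades to every submodule, which is why wrappers such as BiRecurrent above no longer forward these calls themselves. A small usage sketch; it assumes the ordinary BigDL nn API and is not taken from this diff:

import com.intel.analytics.bigdl.nn.{Linear, ReLU, Sequential}
import com.intel.analytics.bigdl.numeric.NumericFloat

// Build a tiny container; Sequential is a Container, so it inherits the final methods.
val model = Sequential[Float]()
  .add(Linear[Float](10, 5))
  .add(ReLU[Float]())

model.evaluate()   // flips the whole tree (Sequential, Linear, ReLU) to eval mode
model.training()   // and back to training mode, again without any per-layer code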
@@ -176,18 +176,10 @@ class Cosine[T: ClassTag](val inputSize : Int, val outputSize : Int)(
}
}

override def zeroGradParameters(): Unit = {
gradWeight.zero()
}

override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
(Array(this.weight), Array(this.gradWeight))
}

override def getParametersTable(): Table = {
T(getName() -> T("weight" -> weight, "gradWeight" -> gradWeight))
}

override def toString(): String = {
s"${getPrintName}($inputSize, $outputSize)"
}
@@ -149,10 +149,6 @@ class Euclidean[T: ClassTag](val inputSize: Int, val outputSize: Int,
s"${getPrintName}($inputSize, $outputSize)"
}

override def zeroGradParameters(): Unit = {
gradWeight.zero()
}

override def clearState() : this.type = {
super.clearState()
inputBuffer.set()
@@ -168,10 +164,6 @@
(Array(this.weight), Array(this.gradWeight))
}

override def getParametersTable(): Table = {
T(getName() -> T("weight" -> weight, "gradWeight" -> gradWeight))
}

override def canEqual(other: Any): Boolean = other.isInstanceOf[Euclidean[T]]

override def equals(other: Any): Boolean = other match {
23 changes: 0 additions & 23 deletions spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/Linear.scala
@@ -170,20 +170,6 @@ class Linear[T: ClassTag](
}
}

override def updateParameters(learningRate: T): Unit = {
weight.add(ev.negative(learningRate), gradWeight)
if (withBias) bias.add(ev.negative(learningRate), gradBias)
}

override def zeroGradParameters(): Unit = {
gradWeight.resize(outputSize, inputSize)
gradWeight.zero()
if (withBias) {
gradBias.resize(outputSize)
gradBias.zero()
}
}

override def clearState() : this.type = {
super.clearState()
addBuffer.set()
@@ -198,15 +184,6 @@
}
}

override def getParametersTable(): Table = {
if (null == bias) {
T(getName() -> T("weight" -> weight, "gradWeight" -> gradWeight))
} else {
T(getName() -> T("weight" -> weight, "bias" -> bias,
"gradWeight" -> gradWeight, "gradBias" -> gradBias))
}
}

override def equals(obj: Any): Boolean = {

if (!super.equals(obj)) {
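The updateParameters overrides removed in this commit either delegated to children or, in leaf layers like Linear above, performed the same vanilla step weight := weight - learningRate * gradient; removing the method leaves that step to the optimizer (e.g. BigDL's OptimMethod implementations). For reference, the arithmetic the leaf versions encoded, sketched over plain arrays rather than BigDL tensors:

// Not BigDL code: plain arrays stand in for tensors to show the removed update rule.
def sgdStep(weights: Array[Double], grads: Array[Double], learningRate: Double): Unit = {
  require(weights.length == grads.length, "weight and gradient sizes must match")
  var i = 0
  while (i < weights.length) {
    weights(i) -= learningRate * grads(i)   // w := w - lr * dw, as the old overrides did per tensor
    i += 1
  }
}

// Example: one step on a 3-element "weight" with unit gradients and lr = 0.1
val w = Array(1.0, 2.0, 3.0)
sgdStep(w, Array(1.0, 1.0, 1.0), 0.1)       // w is now Array(0.9, 1.9, 2.9)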
@@ -380,36 +380,11 @@ class LocallyConnected1D[T: ClassTag](val nInputFrame: Int,
}
}

override def updateParameters(learningRate: T): Unit

= {
weight.map(gradWeight, (a, b) => ev.minus(a, ev.times(learningRate, b)))
bias.map(gradBias, (a, b) => ev.minus(a, ev.times(learningRate, b)))
}

override def zeroGradParameters(): Unit

= {
gradWeight.zero()
gradBias.zero()
}

-override def parameters(): (Array[Tensor[T]], Array[Tensor[T]])
-
-= {
+override def parameters(): (Array[Tensor[T]], Array[Tensor[T]]) = {
(Array(this.weight, this.bias), Array(this.gradWeight, this.gradBias))
}

override def getParametersTable(): Table

= {
T(getName() -> T("weight" -> weight, "bias" -> bias,
"gradWeight" -> gradWeight, "gradBias" -> gradBias))
}

-override def equals(obj: Any): Boolean
-
-= {
+override def equals(obj: Any): Boolean = {
if (!super.equals(obj)) {
return false
}
@@ -432,9 +407,7 @@ class LocallyConnected1D[T: ClassTag](val nInputFrame: Int,
gradBias == other.gradBias
}

-override def hashCode(): Int
-
-= {
+override def hashCode(): Int = {
val seed = 37
var hash = super.hashCode()
hash = hash * seed + inputFrameSize.hashCode()
@@ -449,16 +422,12 @@ class LocallyConnected1D[T: ClassTag](val nInputFrame: Int,
hash
}

-override def clearState(): this.type
-
-= {
+override def clearState(): this.type = {
super.clearState()
this
}

-override def toString(): String
-
-= {
+override def toString(): String = {
s"nn.TemporalConvolution($inputFrameSize -> $outputFrameSize, $kernelW x $strideW)"
}
}
