@@ -20,8 +20,9 @@ package com.intel.analytics.sparkdl.example
import com.intel.analytics.sparkdl.example.ImageNetUtils._
import com.intel.analytics.sparkdl.example.Utils._
import com.intel.analytics.sparkdl.nn._
import com.intel.analytics.sparkdl.optim.EpochOptimizer.Regime
import com.intel.analytics.sparkdl.optim._
import com.intel.analytics.sparkdl.optim.SGD
import com.intel.analytics.sparkdl.optim.SGD.{EpochSchedule, Poly, Regime}
import com.intel.analytics.sparkdl.ps.{AllReduceParameterManager, OneReduceParameterManager}
import com.intel.analytics.sparkdl.tensor._
import com.intel.analytics.sparkdl.utils.T
@@ -104,13 +105,7 @@ object ImageNetParallel {
val workerConfig = params.workerConfig.clone()
workerConfig("profile") = true

val regime: Array[Regime] = Array(
Regime(1, 18, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
Regime(19, 29, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4)),
Regime(30, 43, T("learningRate" -> 1e-3, "weightDecay" -> 0.0)),
Regime(44, 52, T("learningRate" -> 5e-4, "weightDecay" -> 0.0)),
Regime(53, 100000000, T("learningRate" -> 1e-4, "weightDecay" -> 0.0))
)
driverConfig("learningRateSchedule") = Poly(0.5, 84375)

val croppedData = if (cropImage) {
loadCroppedData(trainFiles, sc, labelsMap, classNum + 0.5).coalesce(partitionNum, true)
@@ -151,7 +146,6 @@ object ImageNetParallel {
val optimizer = new GradAggEpochOptimizer[Float](model, criterion,
getOptimMethodFloat(params.masterOptM),
pm, dataSets, metrics, driverConfig)
optimizer.setRegimes(regime)
optimizer.addEvaluation("top1", EvaluateMethods.calcAccuracy)
optimizer.addEvaluation("top5", EvaluateMethods.calcTop5Accuracy)
optimizer.setTestDataSet(testDataSets)
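Note: the hand-written regime table removed above is superseded by a polynomial decay schedule. Per the SGD.scala changes later in this diff, Poly(power, maxIteration) scales the base learning rate by (1 - iteration / maxIteration)^power and the optimizer stores the negated result as config("clr"). A minimal standalone sketch of that computation (the names baseLr and currentRate are illustrative, not part of this PR):

// Standalone sketch of the decay applied by Poly(0.5, 84375) above; not part of the PR.
val power = 0.5
val maxIteration = 84375
def currentRate(baseLr: Double, iteration: Int): Double =
  if (iteration > maxIteration) 0.0
  else baseLr * math.pow(1.0 - iteration.toDouble / maxIteration, power)

currentRate(1e-2, 0)     // 0.01 at iteration 0
currentRate(1e-2, 42187) // roughly 0.00707 about halfway through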
@@ -32,10 +32,6 @@ abstract class EpochOptimizer[T](
metrics: Metrics,
config: Table = T()) extends Optimizer(module, criterion, dataSets) {

import EpochOptimizer._

protected var regimes: Array[Regime] = Array[Regime]()

protected var maxEpoch: Option[Int] = None

def setMaxEpoch(maxEpoch: Int): this.type = {
@@ -44,11 +40,6 @@ abstract class EpochOptimizer[T](
}
this
}

def setRegimes(regimes: Array[Regime]): this.type = {
this.regimes = regimes.clone()
this
}
}

class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
@@ -75,12 +66,6 @@ class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
logInfo(s"[Epoch $i/$epochNum] Train start")
val epochStart = System.nanoTime()

// set optimize parameter from regime
for (r <- regimes) {
if (i >= r.startEpoch && i <= r.endEpoch) {
config.add(r.config)
}
}
logInfo("config" + config)

logInfo(s"[Epoch $i/$epochNum] Shuffle data")
@@ -91,6 +76,7 @@ class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
(shuffleEnd -
epochStart) / 1e9
}s")
config("epoch") = i
while (!dataSets.epochFinished()) {
val lossSum = sc.accumulator(0.0, "loss sum")
val recordsNum = sc.accumulator(0, "record number")
@@ -189,21 +175,14 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
for (i <- 1 to epochNum) {
logInfo(s"[Epoch $i/$epochNum] Train start")
val epochStart = System.nanoTime()

// set optimize parameter from regime
for (r <- regimes) {
if (i >= r.startEpoch && i <= r.endEpoch) {
config.add(r.config)
}
}
logInfo("config" + config)

logInfo(s"[Epoch $i/$epochNum] Shuffle data")
dataSets.reset()
val shuffleEnd = System.nanoTime()
var accumulateCount = 0
logInfo(s"[Epoch $i/$epochNum] Shuffle data complete. Takes" +
s" ${(shuffleEnd - epochStart) / 1e9}s")
config("epoch") = i
while (!dataSets.epochFinished()) {
val lossSum = sc.accumulator(0.0, "loss sum")
val recordsNum = sc.accumulator(0, "record number")
@@ -231,6 +210,7 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
var stacks = 0
var tmp = System.nanoTime()
localModule.zeroGradParameters()
localModule.training()
metrics.add("init gradient time", System.nanoTime() - tmp)
val batch = data.next()
var recordsss = 0
@@ -292,9 +272,3 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
module
}
}

object EpochOptimizer {

case class Regime(startEpoch: Int, endEpoch: Int, config: Table)

}
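Note: the Regime case class and the regime handling deleted above are not dropped from the project; they move into SGD's companion object as the EpochSchedule learning-rate schedule, and the epoch optimizers now forward the current epoch through config("epoch"). A hedged sketch of the equivalent setup under the new API (the two regimes shown are sample values taken from the old ImageNet table):

import com.intel.analytics.sparkdl.optim.SGD._
import com.intel.analytics.sparkdl.utils.T

// Sketch only: regimes formerly passed to setRegimes now ride in the SGD config.
val config = T("learningRateSchedule" -> EpochSchedule(Array(
  Regime(1, 18, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
  Regime(19, 29, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4))
)))
config("epoch") = 1 // the epoch optimizer sets this once per epoch before calling SGD.optimize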
@@ -60,6 +60,7 @@ trait HasCrossValidation[@specialized(Float, Double) T] extends Serializable wit
coalesce(models.partitions.length, false).
zipPartitions(models)((data, cacheModelIter) => {
val localModel = cacheModelIter.next().model
localModel.evaluate()
val localEvaluation = evaluationBroadcast.value
Iterator.single(data.foldLeft((0, 0))((count, t) => {
val result = localEvaluation(localModel.forward(t._1), t._2)
71 changes: 65 additions & 6 deletions dl/src/main/scala/com/intel/analytics/sparkdl/optim/SGD.scala
@@ -26,19 +26,21 @@ import scala.reflect.ClassTag
class SGD[@specialized(Float, Double) T: ClassTag](implicit ev: TensorNumeric[T])
extends OptimMethod[T] {

import SGD._

override def optimize(feval: (Tensor[T]) => (T, Tensor[T]), x: Tensor[T],
config: Table, state: Table = null): (Tensor[T], Array[T]) = {

val _state = if (state == null) config else state
val lr = config.get[Double]("learningRate").getOrElse(1e-3)
val lrd = config.get[Double]("learningRateDecay").getOrElse(0.0)
val lrSchedule = config.get[LearningRateSchedule]("learningRateSchedule").getOrElse(Default())
lrSchedule.updateHyperParameter(config, _state)

val wd = config.get[Double]("weightDecay").getOrElse(0.0)
val mom = config.get[Double]("momentum").getOrElse(0.0)
val damp = config.get[Double]("dampening").getOrElse(mom)
val nesterov = config.get[Boolean]("nesterov").getOrElse(false)
val lrs = config.get[Tensor[T]]("learningRates").getOrElse(null)
val wds = config.get[Tensor[T]]("weightDecays").getOrElse(null)
val nevals = _state.get[Int]("evalCounter").getOrElse(0)

require(!nesterov || (mom > 0 && damp == 0),
"Nesterov momentum requires a momentum and zero dampening")
@@ -74,8 +76,7 @@ class SGD[@specialized(Float, Double) T: ClassTag](implicit ev: TensorNumeric[T]
}
}

val clr = ev.fromType[Double](-lr / (1 + nevals * lrd))

val clr = ev.fromType(config[Double]("clr"))
if (lrs != null) {
val deltaParameters = _state.get[Tensor[T]]("deltaParameters").getOrElse({
val deltaP = Tensor[T]().resizeAs(dfdx)
@@ -88,8 +89,66 @@ class SGD[@specialized(Float, Double) T: ClassTag](implicit ev: TensorNumeric[T]
x.add(clr, dfdx)
}

_state("evalCounter") = nevals + 1

(x, Array(fx))
}
}

object SGD {
trait LearningRateSchedule {
def updateHyperParameter(config : Table, state : Table) : Unit
}

case class EpochSchedule(regimes : Array[Regime]) extends LearningRateSchedule {
override def updateHyperParameter(config: Table, state: Table): Unit = {
val epoch = config[Int]("epoch")
for (r <- regimes) {
if (epoch >= r.startEpoch && epoch <= r.endEpoch) {
config.add(r.config)
}
}
config("clr") = -config.get[Double]("learningRate").getOrElse(1e-3)
}
}
case class Poly(power : Double, maxIteration : Int) extends LearningRateSchedule {
override def updateHyperParameter(config: Table, state: Table): Unit = {
val lr = config.get[Double]("learningRate").getOrElse(1e-3)
val nevals = state.get[Int]("evalCounter").getOrElse(0)
val clr = if (nevals > maxIteration) {
0.0
} else {
-lr * math.pow(1.0 - nevals.toDouble / maxIteration, power)
}
println(s"iteration is : ${nevals}. current learning rate is $clr")
state("evalCounter") = nevals + 1
config("clr") = clr
}
}

case class Step(stepSize : Int, gamma : Double) extends LearningRateSchedule {
override def updateHyperParameter(config: Table, state: Table): Unit = {
val lr = config.get[Double]("learningRate").getOrElse(1e-3)
var clr = -lr
val nevals = state.get[Int]("evalCounter").getOrElse(0)
var i = 0
while(i < nevals / stepSize) {
clr *= gamma
i += 1
}
state("evalCounter") = nevals + 1
config("clr") = clr
}
}

case class Default() extends LearningRateSchedule {
override def updateHyperParameter(config: Table, state: Table): Unit = {
val lr = config.get[Double]("learningRate").getOrElse(1e-3)
val lrd = config.get[Double]("learningRateDecay").getOrElse(0.0)
val nevals = state.get[Int]("evalCounter").getOrElse(0)
config("clr") = -lr / (1 + nevals * lrd)
state("evalCounter") = nevals + 1
}
}

case class Regime(startEpoch: Int, endEpoch: Int, config: Table)
}
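Note: every schedule above follows the same contract: on each optimize call it reads its hyper-parameters from config and writes the effective (negative) rate into config("clr"); Default, Step and Poly also count iterations in state("evalCounter"), while EpochSchedule keys off config("epoch"). A minimal usage sketch mirroring the tests below (the constant feval is illustrative only):

import com.intel.analytics.sparkdl.optim.SGD
import com.intel.analytics.sparkdl.optim.SGD._
import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
import com.intel.analytics.sparkdl.utils.T

val optimMethod = new SGD[Double]
val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Step(5, 0.1))
val state = T()
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
def feval(x: Tensor[Double]): (Double, Tensor[Double]) =
  (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))

// Each call refreshes config("clr") via the schedule, then applies x := x + clr * gradient,
// so with Step(5, 0.1) the effective rate drops tenfold every 5 iterations.
optimMethod.optimize(feval, x, config, state)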
106 changes: 105 additions & 1 deletion dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala
@@ -17,7 +17,8 @@

package com.intel.analytics.sparkdl.optim

import com.intel.analytics.sparkdl.tensor.Tensor
import com.intel.analytics.sparkdl.optim.SGD._
import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
import com.intel.analytics.sparkdl.utils.T
import org.scalatest.{FlatSpec, Matchers}

@@ -65,4 +66,107 @@ class SGDSpec extends FlatSpec with Matchers {
x(Array(1)) should be(1.0 +- 0.1)
x(Array(2)) should be(1.0 +- 0.1)
}

"default learning rate decay" should "generate correct learning rates" in {
val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1, "learningRateSchedule" ->
Default())
val optimMethod = new SGD[Double]
def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
}
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
val state = T()
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1))
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1))
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1))
}

it should "be used when we leave the learningRateSchedule empty" in {
val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1)
val optimMethod = new SGD[Double]
def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
}
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
val state = T()
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1))
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1))
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1))
}

"step learning rate decay" should "generate correct learning rates" in {
val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Step(5, 0.1))
val optimMethod = new SGD[Double]
def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
}
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
val state = T()
for(i <- 1 to 5) {
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 +- 1e-9)
}

for(i <- 1 to 5) {
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.01 +- 1e-9)
}

for(i <- 1 to 5) {
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.001 +- 1e-9)
}
}

"ploy learning rate decay" should "generate correct learning rates" in {
val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Poly(3, 100))
val optimMethod = new SGD[Double]
def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
}
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
val state = T()
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1)
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 * (1 - 1.0 / 100) * (1 - 1.0 / 100) * (1 - 1.0 / 100))
optimMethod.optimize(feval, x, config, state)
config[Double]("clr") should be(-0.1 * (1 - 2.0 / 100) * (1 - 2.0 / 100) * (1 - 2.0 / 100))
}

"epoch decay" should "generate correct learning rates" in {
val regimes: Array[Regime] = Array(
Regime(1, 3, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
Regime(4, 7, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4)),
Regime(8, 10, T("learningRate" -> 1e-3, "weightDecay" -> 0.0))
)

val config = T("learningRate" -> 0.1, "learningRateSchedule" -> EpochSchedule(regimes))
val optimMethod = new SGD[Double]
def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
}
val x = Tensor[Double](Storage(Array(10.0, 10.0)))
val state = T()
for(e <- 1 to 10) {
config("epoch") = e
optimMethod.optimize(feval, x, config, state)
if(e <= 3) {
config[Double]("clr") should be(-1e-2)
config[Double]("weightDecay") should be(2e-4)
} else if (e <= 7) {
config[Double]("clr") should be(-5e-3)
config[Double]("weightDecay") should be(2e-4)
} else if (e <= 10) {
config[Double]("clr") should be(-1e-3)
config[Double]("weightDecay") should be(0.0)
}
}
}
}