diff --git a/dl/src/main/scala/com/intel/analytics/sparkdl/example/ImageNetParallel.scala b/dl/src/main/scala/com/intel/analytics/sparkdl/example/ImageNetParallel.scala
index 4b554fab969..c56046534dd 100644
--- a/dl/src/main/scala/com/intel/analytics/sparkdl/example/ImageNetParallel.scala
+++ b/dl/src/main/scala/com/intel/analytics/sparkdl/example/ImageNetParallel.scala
@@ -20,8 +20,9 @@ package com.intel.analytics.sparkdl.example
 import com.intel.analytics.sparkdl.example.ImageNetUtils._
 import com.intel.analytics.sparkdl.example.Utils._
 import com.intel.analytics.sparkdl.nn._
-import com.intel.analytics.sparkdl.optim.EpochOptimizer.Regime
 import com.intel.analytics.sparkdl.optim._
+import com.intel.analytics.sparkdl.optim.SGD
+import com.intel.analytics.sparkdl.optim.SGD.{EpochSchedule, Poly, Regime}
 import com.intel.analytics.sparkdl.ps.{AllReduceParameterManager, OneReduceParameterManager}
 import com.intel.analytics.sparkdl.tensor._
 import com.intel.analytics.sparkdl.utils.T
@@ -104,13 +105,7 @@ object ImageNetParallel {
     val workerConfig = params.workerConfig.clone()
     workerConfig("profile") = true
 
-    val regime: Array[Regime] = Array(
-      Regime(1, 18, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
-      Regime(19, 29, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4)),
-      Regime(30, 43, T("learningRate" -> 1e-3, "weightDecay" -> 0.0)),
-      Regime(44, 52, T("learningRate" -> 5e-4, "weightDecay" -> 0.0)),
-      Regime(53, 100000000, T("learningRate" -> 1e-4, "weightDecay" -> 0.0))
-    )
+    driverConfig("learningRateSchedule") = Poly(0.5, 84375)
 
     val croppedData = if (cropImage) {
       loadCroppedData(trainFiles, sc, labelsMap, classNum + 0.5).coalesce(partitionNum, true)
@@ -151,7 +146,6 @@ object ImageNetParallel {
 
     val optimizer = new GradAggEpochOptimizer[Float](model, criterion,
       getOptimMethodFloat(params.masterOptM), pm, dataSets, metrics, driverConfig)
-    optimizer.setRegimes(regime)
     optimizer.addEvaluation("top1", EvaluateMethods.calcAccuracy)
    optimizer.addEvaluation("top5", EvaluateMethods.calcTop5Accuracy)
     optimizer.setTestDataSet(testDataSets)
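In the example above, the hand-written regime table is replaced by a single `Poly` schedule stored in the driver config. A minimal sketch of the new call site, assuming a config built with the `T` table constructor; the base learning rate and weight decay values here are illustrative, not taken from the patch:

```scala
import com.intel.analytics.sparkdl.optim.SGD.Poly
import com.intel.analytics.sparkdl.utils.T

object PolyConfigSketch {
  // The schedule now travels inside the optimizer config; no
  // optimizer.setRegimes(...) call is needed any more.
  val driverConfig = T(
    "learningRate" -> 1e-2,                     // base lr read by the schedule (illustrative)
    "weightDecay" -> 2e-4,
    "learningRateSchedule" -> Poly(0.5, 84375)  // clr = -lr * (1 - iter / 84375) ^ 0.5
  )
}
```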
diff --git a/dl/src/main/scala/com/intel/analytics/sparkdl/optim/EpochOptimizer.scala b/dl/src/main/scala/com/intel/analytics/sparkdl/optim/EpochOptimizer.scala
index aebac57f4b3..d08b3f54c8d 100644
--- a/dl/src/main/scala/com/intel/analytics/sparkdl/optim/EpochOptimizer.scala
+++ b/dl/src/main/scala/com/intel/analytics/sparkdl/optim/EpochOptimizer.scala
@@ -32,10 +32,6 @@ abstract class EpochOptimizer[T](
   metrics: Metrics,
   config: Table = T()) extends Optimizer(module, criterion, dataSets) {
 
-  import EpochOptimizer._
-
-  protected var regimes: Array[Regime] = Array[Regime]()
-
   protected var maxEpoch: Option[Int] = None
 
   def setMaxEpoch(maxEpoch: Int): this.type = {
@@ -44,11 +40,6 @@ abstract class EpochOptimizer[T](
     }
     this
   }
-
-  def setRegimes(regimes: Array[Regime]): this.type = {
-    this.regimes = regimes.clone()
-    this
-  }
 }
 
 class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
@@ -75,12 +66,6 @@ class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
       logInfo(s"[Epoch $i/$epochNum] Train start")
       val epochStart = System.nanoTime()
 
-      // set optimize parameter from regime
-      for (r <- regimes) {
-        if (i >= r.startEpoch && i <= r.endEpoch) {
-          config.add(r.config)
-        }
-      }
       logInfo("config" + config)
 
       logInfo(s"[Epoch $i/$epochNum] Shuffle data")
@@ -91,6 +76,7 @@ class GradAggEpochOptimizer[@specialized(Float, Double) T: ClassTag](
           (shuffleEnd - epochStart) / 1e9
         }s")
 
+      config("epoch") = i
       while (!dataSets.epochFinished()) {
         val lossSum = sc.accumulator(0.0, "loss sum")
         val recordsNum = sc.accumulator(0, "record number")
@@ -189,21 +175,14 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
     for (i <- 1 to epochNum) {
       logInfo(s"[Epoch $i/$epochNum] Train start")
       val epochStart = System.nanoTime()
-
-      // set optimize parameter from regime
-      for (r <- regimes) {
-        if (i >= r.startEpoch && i <= r.endEpoch) {
-          config.add(r.config)
-        }
-      }
       logInfo("config" + config)
-      logInfo(s"[Epoch $i/$epochNum] Shuffle data")
       dataSets.reset()
       val shuffleEnd = System.nanoTime()
       var accumulateCount = 0
       logInfo(s"[Epoch $i/$epochNum] Shuffle data complete. Takes" +
         s" ${(shuffleEnd - epochStart) / 1e9}s")
 
+      config("epoch") = i
       while (!dataSets.epochFinished()) {
         val lossSum = sc.accumulator(0.0, "loss sum")
         val recordsNum = sc.accumulator(0, "record number")
@@ -231,6 +210,7 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
            var stacks = 0
            var tmp = System.nanoTime()
            localModule.zeroGradParameters()
+           localModule.training()
            metrics.add("init gradient time", System.nanoTime() - tmp)
            val batch = data.next()
            var recordsss = 0
@@ -292,9 +272,3 @@ class WeightAvgEpochOptimizer[@specialized(Float, Double) T: ClassTag](
     module
   }
 }
-
-object EpochOptimizer {
-
-  case class Regime(startEpoch: Int, endEpoch: Int, config: Table)
-
-}
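For code that previously called `setRegimes`, the same epoch-based behaviour is still available: wrap the regimes in an `EpochSchedule` and store it in the config, and the optimizer's new `config("epoch") = i` line above feeds it the current epoch. A sketch with illustrative regime boundaries:

```scala
import com.intel.analytics.sparkdl.optim.SGD.{EpochSchedule, Regime}
import com.intel.analytics.sparkdl.utils.T

object EpochScheduleSketch {
  // EpochSchedule looks up the regime matching config("epoch") and copies its
  // entries (learningRate, weightDecay, ...) into the config before each step.
  val schedule = EpochSchedule(Array(
    Regime(1, 18, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
    Regime(19, 29, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4)),
    Regime(30, Int.MaxValue, T("learningRate" -> 1e-3, "weightDecay" -> 0.0))
  ))

  val driverConfig = T("learningRateSchedule" -> schedule)
}
```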
diff --git a/dl/src/main/scala/com/intel/analytics/sparkdl/optim/HasCrossValidation.scala b/dl/src/main/scala/com/intel/analytics/sparkdl/optim/HasCrossValidation.scala
index 16050be2d9c..54f7bd50cd2 100644
--- a/dl/src/main/scala/com/intel/analytics/sparkdl/optim/HasCrossValidation.scala
+++ b/dl/src/main/scala/com/intel/analytics/sparkdl/optim/HasCrossValidation.scala
@@ -60,6 +60,7 @@ trait HasCrossValidation[@specialized(Float, Double) T] extends Serializable wit
       coalesce(models.partitions.length, false).
       zipPartitions(models)((data, cacheModelIter) => {
         val localModel = cacheModelIter.next().model
+        localModel.evaluate()
         val localEvaluation = evaluationBroadcast.value
         Iterator.single(data.foldLeft((0, 0))((count, t) => {
           val result = localEvaluation(localModel.forward(t._1), t._2)
current learning rate is $clr") + state("evalCounter") = nevals + 1 + config("clr") = clr + } + } + + case class Step(stepSize : Int, gamma : Double) extends LearningRateSchedule { + override def updateHyperParameter(config: Table, state: Table): Unit = { + val lr = config.get[Double]("learningRate").getOrElse(1e-3) + var clr = -lr + val nevals = state.get[Int]("evalCounter").getOrElse(0) + var i = 0 + while(i < nevals / stepSize) { + clr *= gamma + i += 1 + } + state("evalCounter") = nevals + 1 + config("clr") = clr + } + } + + case class Default() extends LearningRateSchedule { + override def updateHyperParameter(config: Table, state: Table): Unit = { + val lr = config.get[Double]("learningRate").getOrElse(1e-3) + val lrd = config.get[Double]("learningRateDecay").getOrElse(0.0) + val nevals = state.get[Int]("evalCounter").getOrElse(0) + config("clr") = -lr / (1 + nevals * lrd) + state("evalCounter") = nevals + 1 + } + } + + case class Regime(startEpoch: Int, endEpoch: Int, config: Table) +} diff --git a/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala b/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala index 3dbbb7a445d..65b31515a2e 100644 --- a/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala +++ b/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala @@ -17,7 +17,8 @@ package com.intel.analytics.sparkdl.optim -import com.intel.analytics.sparkdl.tensor.Tensor +import com.intel.analytics.sparkdl.optim.SGD._ +import com.intel.analytics.sparkdl.tensor.{Storage, Tensor} import com.intel.analytics.sparkdl.utils.T import org.scalatest.{FlatSpec, Matchers} @@ -65,4 +66,107 @@ class SGDSpec extends FlatSpec with Matchers { x(Array(1)) should be(1.0 +- 0.1) x(Array(2)) should be(1.0 +- 0.1) } + + "default learning rate decay" should "generate correct learning rates" in { + val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1, "learningRateSchedule" -> + Default()) + val optimMethod = new SGD[Double] + def feval(x: Tensor[Double]): (Double, Tensor[Double]) = { + return (0.1, Tensor[Double](Storage(Array(1.0, 1.0)))) + } + val x = Tensor[Double](Storage(Array(10.0, 10.0))) + val state = T() + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1)) + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1)) + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1)) + } + + it should "be used when we leave the learningRateSchedule empty" in { + val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1) + val optimMethod = new SGD[Double] + def feval(x: Tensor[Double]): (Double, Tensor[Double]) = { + return (0.1, Tensor[Double](Storage(Array(1.0, 1.0)))) + } + val x = Tensor[Double](Storage(Array(10.0, 10.0))) + val state = T() + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1)) + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1)) + optimMethod.optimize(feval, x, config, state) + config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1)) + } + + "step learning rate decay" should "generate correct learning rates" in { + val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Step(5, 0.1)) + val optimMethod = new SGD[Double] + def feval(x: Tensor[Double]): (Double, Tensor[Double]) = { + return (0.1, Tensor[Double](Storage(Array(1.0, 1.0)))) + } + val x = 
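The contract between a schedule and `SGD.optimize` is small: each call to `updateHyperParameter` writes the negated current learning rate into `config("clr")` (and bumps `state("evalCounter")` where relevant), and SGD then applies `x.add(clr, dfdx)`. A standalone sketch that drives `Step` directly to show the resulting decay staircase; the values are illustrative:

```scala
import com.intel.analytics.sparkdl.optim.SGD.Step
import com.intel.analytics.sparkdl.utils.T

object StepScheduleSketch {
  def main(args: Array[String]): Unit = {
    val config = T("learningRate" -> 0.1)
    val state = T()
    val schedule = Step(stepSize = 5, gamma = 0.1)

    // Expected trace: clr = -0.1 for iterations 0-4, -0.01 for 5-9, -0.001 for 10-14.
    for (iter <- 0 until 15) {
      schedule.updateHyperParameter(config, state)
      println(s"iteration $iter, clr = ${config[Double]("clr")}")
    }
  }
}
```

When no `learningRateSchedule` entry is present, `Default()` is used, which reproduces the previous `-lr / (1 + nevals * lrd)` behaviour, as the tests below verify.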
diff --git a/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala b/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala
index 3dbbb7a445d..65b31515a2e 100644
--- a/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala
+++ b/dl/src/test/scala/com/intel/analytics/sparkdl/optim/SGDSpec.scala
@@ -17,7 +17,8 @@
 
 package com.intel.analytics.sparkdl.optim
 
-import com.intel.analytics.sparkdl.tensor.Tensor
+import com.intel.analytics.sparkdl.optim.SGD._
+import com.intel.analytics.sparkdl.tensor.{Storage, Tensor}
 import com.intel.analytics.sparkdl.utils.T
 import org.scalatest.{FlatSpec, Matchers}
 
@@ -65,4 +66,107 @@ class SGDSpec extends FlatSpec with Matchers {
     x(Array(1)) should be(1.0 +- 0.1)
     x(Array(2)) should be(1.0 +- 0.1)
   }
+
+  "default learning rate decay" should "generate correct learning rates" in {
+    val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1, "learningRateSchedule" ->
+      Default())
+    val optimMethod = new SGD[Double]
+    def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
+      return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
+    }
+    val x = Tensor[Double](Storage(Array(10.0, 10.0)))
+    val state = T()
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1))
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1))
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1))
+  }
+
+  it should "be used when we leave the learningRateSchedule empty" in {
+    val config = T("learningRate" -> 0.1, "learningRateDecay" -> 0.1)
+    val optimMethod = new SGD[Double]
+    def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
+      return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
+    }
+    val x = Tensor[Double](Storage(Array(10.0, 10.0)))
+    val state = T()
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 0 * 0.1))
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 1 * 0.1))
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 / (1 + 2 * 0.1))
+  }
+
+  "step learning rate decay" should "generate correct learning rates" in {
+    val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Step(5, 0.1))
+    val optimMethod = new SGD[Double]
+    def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
+      return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
+    }
+    val x = Tensor[Double](Storage(Array(10.0, 10.0)))
+    val state = T()
+    for(i <- 1 to 5) {
+      optimMethod.optimize(feval, x, config, state)
+      config[Double]("clr") should be(-0.1 +- 1e-9)
+    }
+
+    for(i <- 1 to 5) {
+      optimMethod.optimize(feval, x, config, state)
+      config[Double]("clr") should be(-0.01 +- 1e-9)
+    }
+
+    for(i <- 1 to 5) {
+      optimMethod.optimize(feval, x, config, state)
+      config[Double]("clr") should be(-0.001 +- 1e-9)
+    }
+  }
+
+  "poly learning rate decay" should "generate correct learning rates" in {
+    val config = T("learningRate" -> 0.1, "learningRateSchedule" -> Poly(3, 100))
+    val optimMethod = new SGD[Double]
+    def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
+      return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
+    }
+    val x = Tensor[Double](Storage(Array(10.0, 10.0)))
+    val state = T()
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1)
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 * (1 - 1.0 / 100) * (1 - 1.0 / 100) * (1 - 1.0 / 100))
+    optimMethod.optimize(feval, x, config, state)
+    config[Double]("clr") should be(-0.1 * (1 - 2.0 / 100) * (1 - 2.0 / 100) * (1 - 2.0 / 100))
+  }
+
+  "epoch decay" should "generate correct learning rates" in {
+    val regimes: Array[Regime] = Array(
+      Regime(1, 3, T("learningRate" -> 1e-2, "weightDecay" -> 2e-4)),
+      Regime(4, 7, T("learningRate" -> 5e-3, "weightDecay" -> 2e-4)),
+      Regime(8, 10, T("learningRate" -> 1e-3, "weightDecay" -> 0.0))
+    )
+
+    val config = T("learningRate" -> 0.1, "learningRateSchedule" -> EpochSchedule(regimes))
+    val optimMethod = new SGD[Double]
+    def feval(x: Tensor[Double]): (Double, Tensor[Double]) = {
+      return (0.1, Tensor[Double](Storage(Array(1.0, 1.0))))
+    }
+    val x = Tensor[Double](Storage(Array(10.0, 10.0)))
+    val state = T()
+    for(e <- 1 to 10) {
+      config("epoch") = e
+      optimMethod.optimize(feval, x, config, state)
+      if(e <= 3) {
+        config[Double]("clr") should be(-1e-2)
+        config[Double]("weightDecay") should be(2e-4)
+      } else if (e <= 7) {
+        config[Double]("clr") should be(-5e-3)
+        config[Double]("weightDecay") should be(2e-4)
+      } else if (e <= 10) {
+        config[Double]("clr") should be(-1e-3)
+        config[Double]("weightDecay") should be(0.0)
+      }
+    }
+  }
 }