Commit

[SPARK-3382] should compare diff inside loss history and convergence tolerance
Lewuathe committed Dec 9, 2014
1 parent 5433f71 commit b9d5e61
Showing 2 changed files with 21 additions and 12 deletions.
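For reference, the gist of the change: the old code broke out of the SGD loop as soon as the latest loss value fell below convergenceTolerance, which conflates the loss's magnitude with its rate of change. This commit instead compares the difference between the last two recorded losses against the tolerance. A minimal standalone sketch of the new check (the converged helper is illustrative only, not code from this commit):

    import scala.collection.mutable.ArrayBuffer

    // Returns true once the absolute difference between the two most recent
    // losses drops below the tolerance; with fewer than two losses recorded
    // there is nothing to compare, so the loop must keep going.
    def converged(lossHistory: ArrayBuffer[Double], tolerance: Double): Boolean =
      lossHistory.length > 1 &&
        math.abs(lossHistory.last - lossHistory(lossHistory.length - 2)) < tolerance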
GradientDescent.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.optimization
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.Breaks
 
 import breeze.linalg.{DenseVector => BDV}
 
@@ -27,7 +28,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 
-import scala.util.control.Breaks
 
 /**
  * Class used to solve an optimization problem using Gradient Descent.
@@ -41,7 +41,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
   private var numIterations: Int = 100
   private var regParam: Double = 0.0
   private var miniBatchFraction: Double = 1.0
-  private var convergenceTolerance: Double = 0.0
+  private var convergenceTolerance: Double = 0.001
 
   /**
    * Set the initial step size of SGD for the first step. Default 1.0.
@@ -80,7 +80,7 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
   }
 
   /**
-   * Set the convergence tolerance. Default 0.0
+   * Set the convergence tolerance. Default 0.001
    */
  def setConvergenceTolerance(tolerance: Double): this.type = {
    this.convergenceTolerance = tolerance
@@ -124,8 +124,8 @@ class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
       numIterations,
       regParam,
       miniBatchFraction,
-      convergenceTolerance,
-      initialWeights)
+      initialWeights,
+      convergenceTolerance)
     weights
   }
 
@@ -154,7 +154,9 @@ object GradientDescent extends Logging {
    * @param regParam - regularization parameter
    * @param miniBatchFraction - fraction of the input data set that should be used for
    *                            one iteration of SGD. Default value 1.0.
-   *
+   * @param convergenceTolerance - Minibatch iteration will end within numIterations
+   *                               if the difference between the last loss and the one
+   *                               before it is less than this value. Default value 0.001.
    * @return A tuple containing two elements. The first element is a column matrix containing
    *         weights for every feature, and the second element is an array containing the
    *         stochastic loss computed for every iteration.
@@ -167,8 +169,13 @@ object GradientDescent extends Logging {
       numIterations: Int,
       regParam: Double,
       miniBatchFraction: Double,
-      convergenceTolerance: Double,
-      initialWeights: Vector): (Vector, Array[Double]) = {
+      initialWeights: Vector,
+      convergenceTolerance: Double): (Vector, Array[Double]) = {
+
+    // convergenceTolerance should be set with non-minibatch settings
+    if (miniBatchFraction < 1.0 && convergenceTolerance > 0.0) {
+      logWarning("Testing against a convergenceTolerance can be dangerous because of the stochasticity")
+    }
 
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
@@ -223,7 +230,9 @@ object GradientDescent extends Logging {
           weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
         weights = update._1
         regVal = update._2
-        if (stochasticLossHistory.last < convergenceTolerance) b.break
+        if (stochasticLossHistory.length > 1) {
+          if (math.abs(stochasticLossHistory.last - stochasticLossHistory(stochasticLossHistory.length - 2)) < convergenceTolerance) b.break
+        }
       } else {
         logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
       }
@@ -247,6 +256,6 @@ object GradientDescent extends Logging {
       miniBatchFraction: Double,
       initialWeights: Vector): (Vector, Array[Double]) =
     GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
-      regParam, miniBatchFraction, 0.0, initialWeights)
+      regParam, miniBatchFraction, initialWeights, 0.001)
 
 }
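For context, a hypothetical call site for the reordered runMiniBatchSGD signature, where initialWeights now precedes convergenceTolerance. LeastSquaresGradient and SimpleUpdater are existing MLlib classes, but this wrapper, its train name, and the trainingData parameter are stand-ins for illustration, not part of this commit:

    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.optimization.{GradientDescent, LeastSquaresGradient, SimpleUpdater}
    import org.apache.spark.rdd.RDD

    // Hypothetical wrapper: trains least-squares weights with the new argument order.
    def train(trainingData: RDD[(Double, Vector)]): (Vector, Array[Double]) =
      GradientDescent.runMiniBatchSGD(
        trainingData,
        new LeastSquaresGradient(),
        new SimpleUpdater(),
        1.0,                      // stepSize
        100,                      // numIterations
        0.0,                      // regParam
        1.0,                      // miniBatchFraction of 1.0 keeps the tolerance check safe
        Vectors.dense(0.0, 0.0),  // initialWeights
        0.001)                    // convergenceTolerance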
GradientDescentSuite.scala
@@ -172,8 +172,8 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
       numIterations,
       regParam,
       miniBatchFrac,
-      convergenceTolerance,
-      initialWeightsWithIntercept)
+      initialWeightsWithIntercept,
+      convergenceTolerance)
 
     assert(loss.length < numIterations, "doesn't satisfy convergence tolerance")
   }
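The assertion above holds because the tolerance check can break out of the iteration loop early, so fewer than numIterations loss values get recorded. A self-contained toy driver (no MLlib dependencies; all names are illustrative) showing the scala.util.control.Breaks pattern the optimizer loop uses:

    import scala.collection.mutable.ArrayBuffer
    import scala.util.control.Breaks

    object ConvergenceDemo {
      def main(args: Array[String]): Unit = {
        val numIterations = 100
        val tolerance = 0.001
        val lossHistory = new ArrayBuffer[Double](numIterations)
        var loss = 10.0
        val b = new Breaks
        b.breakable {
          for (i <- 1 to numIterations) {
            loss *= 0.5 // stand-in for one real SGD step
            lossHistory += loss
            // same delta test as the commit: compare the last two losses
            if (lossHistory.length > 1 &&
                math.abs(lossHistory.last - lossHistory(lossHistory.length - 2)) < tolerance) {
              b.break()
            }
          }
        }
        // With the halving "loss", this stops after ~15 of 100 iterations,
        // which is exactly the condition the suite asserts on.
        println(s"stopped after ${lossHistory.length} of $numIterations iterations")
      }
    }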