Skip to content

Commit

Permalink
Modify converged logic to do relative comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewuathe committed Mar 16, 2015
1 parent f7b19d5 commit 3aef0a2
Showing 1 changed file with 25 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ object GradientDescent extends Logging {

val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
// Record previous weight and current one to calculate solution vector difference

var previousWeights: Option[Vector] = None
var currentWeights: Option[Vector] = None

Expand Down Expand Up @@ -247,7 +248,10 @@ object GradientDescent extends Logging {
currentWeights = Some(weights)
}
if (previousWeights != None && currentWeights != None) {
if (solutionVecDiff(previousWeights, currentWeights) < convergenceTol) converged = true
if (isConverged(previousWeights.get, currentWeights.get,
initialWeights, convergenceTol)) {
converged = true
}
}
} else {
logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
Expand All @@ -274,13 +278,26 @@ object GradientDescent extends Logging {
GradientDescent.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations,
regParam, miniBatchFraction, initialWeights, 0.001)

// To compare with convergence tolerance
def solutionVecDiff(previousWeight: Option[Vector], currentWeight: Option[Vector]): Double = {
require(previousWeight != None)
require(currentWeight != None)
val lastWeight = currentWeight.get.toBreeze
val lastBeforeWeight = previousWeight.get.toBreeze
sum((lastBeforeWeight - lastWeight) :* (lastBeforeWeight - lastWeight)) / lastWeight.length

private def isConverged(previousWeights: Vector, currentWeights: Vector,
initialWeights: Vector, convergenceTol: Double): Boolean = {
require(previousWeights != None)
require(currentWeights != None)
// To compare with convergence tolerance
def solutionVecDiff(previousWeight: Vector,
currentWeight: Vector): Double = {

val lastWeight = currentWeight.toBreeze
val lastBeforeWeight = previousWeight.toBreeze
sum((lastBeforeWeight - lastWeight)
:* (lastBeforeWeight - lastWeight)) / lastWeight.length
}

def squareAvg(weights: Vector): Double =
sum(weights.toBreeze :* weights.toBreeze) / weights.toBreeze.length

solutionVecDiff(previousWeights, currentWeights) <
convergenceTol * squareAvg(initialWeights)
}

}

0 comments on commit 3aef0a2

Please sign in to comment.