Skip to content

Commit

Permalink
Merge pull request #1 from jegonzal/admm-emerson2
Browse files Browse the repository at this point in the history
refactoring
  • Loading branch information
pbailis committed Jul 18, 2014
2 parents 09caa96 + 71f0565 commit f0ec2ee
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.mllib.admm.PegasosSVM
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.classification.{SVMWithADMM, LogisticRegressionWithSGD, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}
Expand All @@ -39,7 +39,7 @@ object BinaryClassification {

object Algorithm extends Enumeration {
type Algorithm = Value
val SVM, LR, Pegasos, PegasosAsync = Value
val SVM, LR, SVMADMM, Pegasos, PegasosAsync = Value
}

object RegType extends Enumeration {
Expand Down Expand Up @@ -144,6 +144,12 @@ object BinaryClassification {
.setUpdater(updater)
.setRegParam(params.regParam)
algorithm.run(training).clearThreshold()
case SVMADMM =>
val algorithm = new SVMWithADMM()
algorithm.maxGlobalIterations = params.numIterations
algorithm.updater = updater
algorithm.regParam = params.regParam
algorithm.run(training).clearThreshold()
case Pegasos =>
val algorithm = new PegasosSVM()
algorithm.run(training)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ class PegasosSVM(val iterations: Integer = 10,
}

val weightsWithIntercept =
if(!async)
Vectors.fromBreeze(BSPADMMwithSGD.train(data, iterations, new PegasosBVGradient(lambda), initialWeights))
else
Vectors.fromBreeze(AsyncADMMwithSGD.train(data, iterations, new PegasosBVGradient(lambda), initialWeights))

if(!async) {
Vectors.fromBreeze(BSPADMMwithSGD.train(data, iterations, new PegasosBVGradient(lambda), initialWeights))
} else {
Vectors.fromBreeze(AsyncADMMwithSGD.train(data, iterations, new PegasosBVGradient(lambda), initialWeights))
}

val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
val weights =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,37 @@ class SVMWithSGD private (
}
}

/**
* Train a Support Vector Machine (SVM) using Stochastic Gradient Descent.
* NOTE: Labels used in SVM should be {0, 1}.
*/
class SVMWithADMM extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
var stepSize: Double = 1.0
var maxGlobalIterations: Int = Int.MaxValue
var maxLocalIterations: Int = Int.MaxValue
var regParam: Double = 1.0
var miniBatchFraction: Double = 2.0
var epsilon: Double = 1.0e-5
var updater: Updater = new SquaredL2Updater()


private val gradient = new HingeGradient()
private val localSolver = new SGDLocalOptimizer(gradient, updater)
localSolver.eta_0 = stepSize
localSolver.maxIterations = maxLocalIterations
localSolver.epsilon = epsilon
localSolver.miniBatchFraction = miniBatchFraction
override val optimizer = new ADMM(localSolver)
optimizer.numIterations = maxGlobalIterations
optimizer.regParam = regParam
optimizer.epsilon = epsilon
override protected val validators = List(DataValidators.binaryLabelValidator)
override protected def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
}
}


/**
* Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}.
*/
Expand Down
Loading

0 comments on commit f0ec2ee

Please sign in to comment.