Skip to content

Commit

Permalink
Remove cross validation [TODO in another pull request]
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Apr 9, 2014
1 parent 91eae64 commit b78804e
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 51 deletions.
29 changes: 0 additions & 29 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,35 +199,6 @@ object MLUtils {
))).toList
}


/**
* Function to perform cross validation on a single learner.
*
* @param data - input data set
* @param folds - the number of folds (must be > 1)
* @param learner - function to produce a model
* @param errorFunction - function to compute the error of a given point
*
* @return the average error on the cross validated data.
*/
def crossValidate(data: RDD[LabeledPoint], folds: Int, seed: Int,
learner: (RDD[LabeledPoint] => RegressionModel),
errorFunction: ((Double,Double) => Double) = meanSquaredError): Double = {
if (folds <= 1) {
throw new IllegalArgumentException("Cross validation requires more than one fold")
}
val rdds = kFoldRdds(data, folds, seed)
val errorRates = rdds.map{case (testData, trainingData) =>
val model = learner(trainingData)
val predictions = testData.map(data => (data.label, model.predict(data.features)))
val errors = predictions.map{case (x, y) => errorFunction(x, y)}
errors.sum()
}
val averageError = errorRates.sum / data.count.toFloat
averageError
}


/**
* Utility function to compute mean and standard deviation on a given dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,28 +136,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
new LinearRegressionModel(Array(1.0), 0)
}

test("Test cross validation with a reasonable learner") {
val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
x => LabeledPoint(x._1, Array(x._2)))
val features = data.map(_.features)
val labels = data.map(_.label)
for (seed <- 1 to 5) {
for (folds <- 2 to 5) {
val avgError = MLUtils.crossValidate(data, folds, seed, exactLearner)
avgError should equal (0)
}
}
}

test("Cross validation requires more than one fold") {
val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
x => LabeledPoint(x._1, Array(x._2)))
val thrown = intercept[java.lang.IllegalArgumentException] {
val avgError = MLUtils.crossValidate(data, 1, 1, exactLearner)
}
assert(thrown.getClass === classOf[IllegalArgumentException])
}

test("kfoldRdd") {
val data = sc.parallelize(1 to 100, 2)
val collectedData = data.collect().sorted
Expand Down

0 comments on commit b78804e

Please sign in to comment.