Add k-fold cross validation to MLlib

Commit 65b04948ba4289d48a49f2b62bfa89fe38012232 (1 parent: 84f7ca1) by @holdenk, committed Feb 5, 2014
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+import java.util.Random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.{Partition, TaskContext}
+
+private[spark]
+class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
+ override val index: Int = prev.index
+}
+
+class FoldedRDD[T: ClassTag](
+ prev: RDD[T],
+ fold: Int,
+ folds: Int,
+ seed: Int)
+ extends RDD[T](prev) {
+
+ override def getPartitions: Array[Partition] = {
+ val rg = new Random(seed)
+ firstParent[T].partitions.map(x => new FoldedRDDPartition(x, rg.nextInt))
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[FoldedRDDPartition].prev)
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[FoldedRDDPartition]
+ val rand = new Random(split.seed)
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) == fold-1))
+ }
+}
+
+/**
+ * A companion class to FoldedRDD that contains all of the elements not in the fold for the same
+ * fold/seed combination. Useful for cross validation.
+ */
+class CompositeFoldedRDD[T: ClassTag](
+ prev: RDD[T],
+ fold: Int,
+ folds: Int,
+ seed: Int)
+ extends FoldedRDD[T](prev, fold, folds, seed) {
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[FoldedRDDPartition]
+ val rand = new Random(split.seed)
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) != fold-1))
+ }
+}
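
For reference, a minimal sketch (not part of this commit) of how the two new RDDs fit together. Because FoldedRDD and CompositeFoldedRDD derive the same per-partition seeds from the shared seed, each element is assigned to exactly one fold deterministically, so a fold and its composite form a disjoint test/training split. The names data, testFold and rest are illustrative, and a live SparkContext sc is assumed:

    import org.apache.spark.rdd.{CompositeFoldedRDD, FoldedRDD}

    val data = sc.parallelize(1 to 100, 2)
    // Fold 1 of 3 with seed 42: roughly a third of the elements.
    val testFold = new FoldedRDD(data, 1, 3, 42)
    // The composite for the same fold/seed: all remaining elements.
    val rest = new CompositeFoldedRDD(data, 1, 3, 42)
    // Together the two RDDs cover the data exactly once.
    assert(testFold.union(rest).collect().sorted.sameElements(data.collect().sorted))
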
@@ -341,6 +341,17 @@ abstract class RDD[T: ClassTag](
}.toArray
}
+ /**
+ * Return a list of `folds` pairs of RDDs; the first element of each pair contains a unique,
+ * roughly 1/folds fraction of the data and the second element contains the remaining data.
+ */
+ def kFoldRdds(folds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
+ 1.to(folds).map(fold => ((
+ new FoldedRDD(this, fold, folds, seed),
+ new CompositeFoldedRDD(this, fold, folds, seed)
+ ))).toList
+ }
+
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
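
A short usage sketch for the new kFoldRdds method (again not from the commit; sc and the variable names are assumed for illustration):

    // Split an RDD into 5 (held-out fold, remaining data) pairs with a fixed seed.
    val data = sc.parallelize(1 to 100, 4)
    val pairs = data.kFoldRdds(5, 7)
    pairs.zipWithIndex.foreach { case ((test, training), i) =>
      println(s"fold ${i + 1}: test=${test.count()} training=${training.count()}")
    }
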
@@ -503,6 +503,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("kfoldRdd") {
+ val data = sc.parallelize(1 to 100, 2)
+ for (folds <- 1 to 10) {
+ for (seed <- 1 to 5) {
+ val foldedRdds = data.kFoldRdds(folds, seed)
+ assert(foldedRdds.size === folds)
+ foldedRdds.foreach{case (test, train) =>
+ assert(test.union(train).collect().sorted === data.collect().sorted,
+ "Each training+test set combined contains all of the data")
+ }
+ // K fold cross validation should only have each element in the test set exactly once
+ assert(foldedRdds.map(_._1).reduce((x,y) => x.union(y)).collect().sorted ===
+ data.collect().sorted)
+ }
+ }
+ }
+
test("runJob on an invalid partition") {
intercept[IllegalArgumentException] {
sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.jblas.DoubleMatrix
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.RegressionModel
/**
* Helper methods to load, save and pre-process data used in ML Lib.
@@ -61,6 +62,38 @@ object MLUtils {
dataStr.saveAsTextFile(dir)
}
+ def meanSquaredError(a: Double, b: Double): Double = {
+ (a-b)*(a-b)
+ }
+
+ /**
+ * Perform k-fold cross validation on a single learner.
+ *
+ * @param data - input data set
+ * @param folds - the number of folds (must be > 1)
+ * @param seed - seed used to assign points to folds
+ * @param learner - function that produces a model from the training data
+ * @param errorFunction - per-point error function (defaults to squared error)
+ * @return the average per-point error on the held-out data.
+ */
+ def crossValidate(data: RDD[LabeledPoint], folds: Int, seed: Int,
+ learner: (RDD[LabeledPoint] => RegressionModel),
+ errorFunction: ((Double,Double) => Double) = meanSquaredError): Double = {
+ if (folds <= 1) {
+ throw new IllegalArgumentException("Cross validation requires more than one fold")
+ }
+ val rdds = data.kFoldRdds(folds, seed)
+ val perFoldErrorSums = rdds.map{case (testData, trainingData) =>
+ val model = learner(trainingData)
+ val predictions = model.predict(testData.map(_.features))
+ val errors = predictions.zip(testData.map(_.label)).map{case (x, y) => errorFunction(x, y)}
+ errors.sum()
+ }
+ val averageError = perFoldErrorSums.sum / data.count
+ averageError
+ }
+
+
/**
* Utility function to compute mean and standard deviation on a given dataset.
*
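
A possible call site for the new crossValidate helper, shown with a custom per-point error function in place of the default meanSquaredError. This is a sketch, not part of the commit: it assumes LinearRegressionWithSGD from the same MLlib version, but any RDD[LabeledPoint] => RegressionModel would do:

    import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD, RegressionModel}
    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.rdd.RDD

    // Synthetic y = x data, mirroring the new MLUtilsSuite tests.
    val labeled = sc.parallelize(1 to 100).map(i => LabeledPoint(i, Array(i.toDouble)))

    // Per-point absolute error, passed instead of the default squared error.
    def absoluteError(prediction: Double, label: Double): Double = math.abs(prediction - label)

    // Any function from training data to a RegressionModel can be cross validated.
    def learner(trainingData: RDD[LabeledPoint]): RegressionModel =
      LinearRegressionWithSGD.train(trainingData, 200)

    // 10 folds, seed 42: returns the average absolute error over the held-out folds.
    val avgAbsError = MLUtils.crossValidate(labeled, 10, 42, learner, absoluteError)
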
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.mllib.regression._
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+class MLUtilsSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ // This learner always says everything is 0
+ def terribleLearner(trainingData: RDD[LabeledPoint]): RegressionModel = {
+ object AlwaysZero extends RegressionModel {
+ override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+ testData.map(_ => 0)
+ }
+ override def predict(testData: Array[Double]): Double = {
+ 0
+ }
+ }
+ AlwaysZero
+ }
+
+ // Predicts the single feature value, which equals the label in these tests
+ def exactLearner(trainingData: RDD[LabeledPoint]): RegressionModel = {
+ new LinearRegressionModel(Array(1.0), 0)
+ }
+
+ test("Test cross validation with a terrible learner") {
+ val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
+ x => LabeledPoint(x._1, Array(x._2)))
+ val expectedError = 1.to(100).map(x => x*x).sum / 100.0
+ for (seed <- 1 to 5) {
+ for (folds <- 2 to 5) {
+ val avgError = MLUtils.crossValidate(data, folds, seed, terribleLearner)
+ avgError should equal (expectedError)
+ }
+ }
+ }
+ test("Test cross validation with a reasonable learner") {
+ val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
+ x => LabeledPoint(x._1, Array(x._2)))
+ for (seed <- 1 to 5) {
+ for (folds <- 2 to 5) {
+ val avgError = MLUtils.crossValidate(data, folds, seed, exactLearner)
+ avgError should equal (0)
+ }
+ }
+ }
+
+ test("Cross validation requires more than one fold") {
+ val data = sc.parallelize(1.to(100).zip(1.to(100))).map(
+ x => LabeledPoint(x._1, Array(x._2)))
+ val thrown = intercept[java.lang.IllegalArgumentException] {
+ val avgError = MLUtils.crossValidate(data, 1, 1, terribleLearner)
+ }
+ assert(thrown.getClass === classOf[IllegalArgumentException])
+ }
+}
