Skip to content

Commit

Permalink
CR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Apr 10, 2014
1 parent 90896c7 commit 7157ae9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}
}

def cloneComplement() = new BernoulliSampler[T](lb, ub, !complement)
/**
* Return a sampler that is the complement of the range specified by the current sampler.
*
* The returned sampler is built from the same `lb` and `ub` bounds, with the `complement`
* flag negated.
*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

/**
* Return a new sampler with the same `lb`, `ub` and `complement` settings as this one.
*/
override def clone: BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, complement)
}
Expand Down
11 changes: 5 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ import scala.reflect.ClassTag

import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
squaredDistance => breezeSquaredDistance}
import org.jblas.DoubleMatrix

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
import org.apache.spark.SparkContext._
import org.apache.spark.util.random.BernoulliSampler

import org.jblas.DoubleMatrix
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}

Expand Down Expand Up @@ -179,13 +178,13 @@ object MLUtils {
}

/**
* Return a k element list of pairs of RDDs with the first element of each pair
* Return a k element array of pairs of RDDs with the first element of each pair
* containing the validation data, a unique 1/Kth of the data and the second
* element, the training data, contain the compliment of that.
* element, the training data, containing the complement of that.
*/
def kFold[T : ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
(1 to numFolds).map { fold =>
val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
Expand Down
24 changes: 12 additions & 12 deletions mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,25 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
for (seed <- 1 to 5) {
val foldedRdds = MLUtils.kFold(data, folds, seed)
assert(foldedRdds.size === folds)
foldedRdds.map { case (test, train) =>
val result = test.union(train).collect().sorted
val testSize = test.collect().size.toFloat
assert(testSize > 0, "empty test data")
foldedRdds.map { case (validation, training) =>
val result = validation.union(training).collect().sorted
val validationSize = validation.collect().size.toFloat
assert(validationSize > 0, "empty validation data")
val p = 1 / folds.toFloat
// Within 3 standard deviations of the mean
val range = 3 * math.sqrt(100 * p * (1-p))
val range = 3 * math.sqrt(100 * p * (1 - p))
val expected = 100 * p
val lowerBound = expected - range
val upperBound = expected + range
assert(testSize > lowerBound,
s"Test data ($testSize) smaller than expected ($lowerBound)" )
assert(testSize < upperBound,
s"Test data ($testSize) larger than expected ($upperBound)" )
assert(train.collect().size > 0, "empty training data")
assert(validationSize > lowerBound,
s"Validation data ($validationSize) smaller than expected ($lowerBound)" )
assert(validationSize < upperBound,
s"Validation data ($validationSize) larger than expected ($upperBound)" )
assert(training.collect().size > 0, "empty training data")
assert(result === collectedData,
"Each training+test set combined should contain all of the data.")
"Each training+validation set combined should contain all of the data.")
}
// K fold cross validation should only have each element in the test set exactly once
// K fold cross validation should only have each element in the validation set exactly once
assert(foldedRdds.map(_._1).reduce((x,y) => x.union(y)).collect().sorted ===
data.collect().sorted)
}
Expand Down

0 comments on commit 7157ae9

Please sign in to comment.