Skip to content

Commit

Permalink
Fix the names in kFold
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Apr 9, 2014
1 parent c702a96 commit 2cb90b3
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,17 @@ object MLUtils {

/**
* Return a k element list of pairs of RDDs with the first element of each pair
* containing a unique 1/Kth of the data and the second element containing the complement of that.
* containing the validation data, a unique 1/Kth of the data, and the second
* element, the training data, containing the complement of that.
*/
def kFold[T : ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): List[Pair[RDD[T], RDD[T]]] = {
def kFold[T : ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
val sampler = new BernoulliSampler[T]((fold-1)/numFoldsF,fold/numFoldsF, complement = false)
val train = new PartitionwiseSampledRDD(rdd, sampler, seed)
val test = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
(train, test)
}.toList
val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
(validation, training)
}.toArray
}

/**
Expand Down

0 comments on commit 2cb90b3

Please sign in to comment.