Permalink
Browse files

Switch FoldedRDD to use BernoulliSampler and PartitionwiseSampledRDD

  • Loading branch information...
holdenk committed Feb 11, 2014
1 parent 7becbcb commit 969be9e4dfc5e2ce31418c564590ce2d55f759e7
@@ -24,6 +24,7 @@ import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.util.random.BernoulliSampler
private[spark]
class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
@@ -32,24 +33,10 @@ class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition w
class FoldedRDD[T: ClassTag](
prev: RDD[T],
- fold: Int,
- folds: Int,
+ fold: Float, // fold number, counted from 1: the sampler's lower bound is (fold-1)/folds, so fold 1 starts at 0
+ folds: Float, // total number of folds; a Float so that the bound divisions below are not integer divisions
seed: Int)
- extends RDD[T](prev) {
-
- override def getPartitions: Array[Partition] = {
- val rg = new Random(seed)
- firstParent[T].partitions.map(x => new FoldedRDDPartition(x, rg.nextInt))
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] =
- firstParent[T].preferredLocations(split.asInstanceOf[FoldedRDDPartition].prev)
-
- override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
- val split = splitIn.asInstanceOf[FoldedRDDPartition]
- val rand = new Random(split.seed)
- firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) == fold-1))
- }
+ extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds,fold/folds, false), seed) { // sampler bounds are (fold-1)/folds and fold/folds; presumably the slice of [0, 1) kept for this fold, with `false` meaning "not complemented" -- TODO confirm against BernoulliSampler
}
/**
@@ -58,14 +45,8 @@ class FoldedRDD[T: ClassTag](
*/
class CompositeFoldedRDD[T: ClassTag](
prev: RDD[T],
- fold: Int,
- folds: Int,
+ fold: Float, // fold number, counted from 1 (the lower bound below is (fold-1)/folds)
+ folds: Float, // total number of folds; a Float so that the bound divisions below are not integer divisions
seed: Int)
- extends FoldedRDD[T](prev, fold, folds, seed) {
-
- override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
- val split = splitIn.asInstanceOf[FoldedRDDPartition]
- val rand = new Random(split.seed)
- firstParent[T].iterator(split.prev, context).filter(x => (rand.nextInt(folds) != fold-1))
- }
+ extends PartitionwiseSampledRDD[T, T](prev, new BernoulliSampler((fold-1)/folds, fold/folds, true), seed) { // same bounds as FoldedRDD but the third sampler argument is `true`; presumably this selects the complement, i.e. everything outside the fold -- TODO confirm against BernoulliSampler
}
@@ -503,14 +503,28 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("FoldedRDD") { // checks fold sizes for a 2-fold split and that a fold is disjoint from its composite
+ val data = sc.parallelize(1 to 100, 2)
+ val lowerFoldedRdd = new FoldedRDD(data, 1, 2, 1) // fold 1 of 2, seed 1
+ val upperFoldedRdd = new FoldedRDD(data, 2, 2, 1) // fold 2 of 2, same seed
+ val lowerCompositeFoldedRdd = new CompositeFoldedRDD(data, 1, 2, 1) // counterpart of fold 1, same seed
+ assert(lowerFoldedRdd.collect().size === 50) // `===` reports both operands on failure; sorting is not needed for a size check
+ assert(lowerCompositeFoldedRdd.collect().size === 50)
+ assert(lowerFoldedRdd.subtract(lowerCompositeFoldedRdd).collect().sorted ===
+ lowerFoldedRdd.collect().sorted) // removing the composite leaves the fold unchanged, so the two share no elements
+ assert(upperFoldedRdd.collect().size === 50)
+ }
+
test("kfoldRdd") {
val data = sc.parallelize(1 to 100, 2)
- for (folds <- 1 to 10) {
+ val collectedData = data.collect().sorted
+ for (folds <- 2 to 10) {
for (seed <- 1 to 5) {
val foldedRdds = data.kFoldRdds(folds, seed)
assert(foldedRdds.size === folds)
foldedRdds.map{case (test, train) =>
- assert(test.union(train).collect().sorted === data.collect().sorted,
+ val result = test.union(train).collect().sorted
+ assert(result === collectedData,
"Each training+test set combined contains all of the data")
}
// K fold cross validation should only have each element in the test set exactly once

0 comments on commit 969be9e

Please sign in to comment.