[SPARK-7156][SQL] support RandomSplit in DataFrames
This is built on top of kaka1992's PR apache#5711, using logical plans.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes apache#5761 from brkyvz/random-sample and squashes the following commits:

a1fb0aa [Burak Yavuz] remove unrelated file
69669c3 [Burak Yavuz] fix broken test
1ddb3da [Burak Yavuz] copy base
6000328 [Burak Yavuz] added python api and fixed test
3c11d1b [Burak Yavuz] fixed broken test
f400ade [Burak Yavuz] fix build errors
2384266 [Burak Yavuz] addressed comments v0.1
e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames
brkyvz authored and nemccarthy committed Jun 19, 2015
1 parent 0529595 commit c5ad4c2
Showing 10 changed files with 130 additions and 22 deletions.
19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -407,11 +407,26 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new PartitionwiseSampledRDD[T, T](
this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
randomSampleWithRange(x(0), x(1), seed)
}.toArray
}

/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
* @param lb lower bound to use for the Bernoulli sampler
* @param ub upper bound to use for the Bernoulli sampler
* @param seed the seed for the Random number generator
* @return A random sub-sample of the RDD without replacement.
*/
private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = {
this.mapPartitionsWithIndex { case (index, partition) =>
val sampler = new BernoulliCellSampler[T](lb, ub)
sampler.setSeed(seed + index)
sampler.sample(partition)
}
}

/**
* Return a fixed-size sampled subset of this RDD in an array
*
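The per-partition seeding in randomSampleWithRange is what makes the resulting splits reproducible and mutually exclusive: every split re-scans the same RDD with the same seed, so each row receives the same uniform draw on every pass, and only the acceptance interval [lb, ub) differs from split to split. A minimal standalone sketch of the idea (plain Scala, not Spark code; the BernoulliCellSampler above behaves analogously):

```scala
import scala.util.Random

// Keep the elements whose per-element uniform draw falls in [lb, ub).
// Reusing the same seed for every range makes the "splits" disjoint and exhaustive.
def sampleWithRange[T](data: Seq[T], lb: Double, ub: Double, seed: Long): Seq[T] = {
  val rng = new Random(seed)
  data.filter { _ => val u = rng.nextDouble(); u >= lb && u < ub }
}

val data   = 1 to 20
val bounds = Seq(0.0, 0.3, 1.0)   // cumulative, normalized weights
val splits = bounds.sliding(2).map { case Seq(lb, ub) =>
  sampleWithRange(data, lb, ub, seed = 42L)   // same seed => same draw per element
}.toList
// Every element lands in exactly one split, because the intervals partition [0, 1).
```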
8 changes: 4 additions & 4 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -157,11 +157,11 @@ public void sample() {
public void randomSplit() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11);
JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
Assert.assertEquals(3, splits.length);
Assert.assertEquals(2, splits[0].count());
Assert.assertEquals(3, splits[1].count());
Assert.assertEquals(5, splits[2].count());
Assert.assertEquals(1, splits[0].count());
Assert.assertEquals(2, splits[1].count());
Assert.assertEquals(7, splits[2].count());
}

@Test
18 changes: 17 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -426,14 +426,30 @@ def distinct(self):
def sample(self, withReplacement, fraction, seed=None):
"""Returns a sampled subset of this :class:`DataFrame`.
>>> df.sample(False, 0.5, 97).count()
>>> df.sample(False, 0.5, 42).count()
1
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)

def randomSplit(self, weights, seed=None):
"""Randomly splits this :class:`DataFrame` with the provided weights.
>>> splits = df4.randomSplit([1.0, 2.0], 24)
>>> splits[0].count()
1
>>> splits[1].count()
3
"""
for w in weights:
assert w >= 0.0, "Negative weight value: %s" % w
seed = seed if seed is not None else random.randint(0, sys.maxsize)
rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed))
return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]

@property
def dtypes(self):
"""Returns all column names and their data types as a list.
@@ -278,12 +278,6 @@ package object dsl {
def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan =
Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)

def sample(
fraction: Double,
withReplacement: Boolean = true,
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
generator: Generator,
@@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
}

case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
extends UnaryNode {
/**
* Sample the dataset.
*
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the LogicalPlan
*/
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output
}
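The claim in the Scaladoc above that the expected sampled fraction is ub - lb follows directly from the sampler drawing one uniform variate per row and keeping the row when that variate falls in the half-open range (a quick check, not part of the commit):

$$ \Pr(\text{row kept}) = \Pr(lb \le U < ub) = ub - lb, \qquad U \sim \mathrm{Uniform}[0, 1). $$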
38 changes: 37 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -706,7 +706,7 @@ class DataFrame private[sql](
* @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
Sample(fraction, withReplacement, seed, logicalPlan)
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}

/**
@@ -720,6 +720,42 @@
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan))
}.toArray
}

/**
* Randomly splits this [[DataFrame]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @group dfops
*/
def randomSplit(weights: Array[Double]): Array[DataFrame] = {
randomSplit(weights, Utils.random.nextLong)
}

/**
* Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
* @group dfops
*/
def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
randomSplit(weights.toArray, seed)
}

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
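As a usage sketch of the new API (the local-mode setup and names below are illustrative, not part of this commit), a DataFrame can now be split into train/test sets directly; the weights do not need to sum to one because randomSplit normalizes them into cumulative bounds before building one Sample node per slice:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("randomSplit-demo"))
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val df = sc.parallelize(1 to 1000).toDF("id")

// Array(8.0, 2.0) is normalized to the cumulative bounds (0.0, 0.8) and (0.8, 1.0);
// each pair becomes a Sample(lowerBound, upperBound, withReplacement = false, seed, plan) node.
val Array(train, test) = df.randomSplit(Array(8.0, 2.0), seed = 12345L)

println(s"train: ${train.count()}, test: ${test.count()}")   // roughly 800 / 200
```

Calling the overload without a seed picks one via Utils.random.nextLong, as in the second method above.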
@@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Expand(projections, output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case logical.Sample(fraction, withReplacement, seed, child) =>
execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
@@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {

/**
* :: DeveloperApi ::
* Sample the dataset.
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
* will be ub - lb.
* @param withReplacement Whether to sample with replacement.
* @param seed the random seed
* @param child the QueryPlan
*/
@DeveloperApi
case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
case class Sample(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
child: SparkPlan)
extends UnaryNode
{
override def output: Seq[Attribute] = child.output

// TODO: How to pick seed?
override def execute(): RDD[Row] = {
child.execute().map(_.copy()).sample(withReplacement, fraction, seed)
if (withReplacement) {
child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
} else {
child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
}
}
}

17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("randomSplit") {
val n = 600
val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")

assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")

val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}

test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
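The "std = ..." comments in the randomSplit test above are the binomial standard deviations of the split sizes, so the ±50 tolerance sits at more than four standard deviations (a quick check, not part of the commit):

$$ s_i \sim \mathrm{Binomial}(n, p_i), \quad \sigma_i = \sqrt{n\,p_i(1 - p_i)}, \quad n = 600: \qquad
\sigma_1 = \sqrt{600 \cdot \tfrac{1}{6} \cdot \tfrac{5}{6}} \approx 9.13, \quad
\sigma_2 = \sqrt{600 \cdot \tfrac{2}{6} \cdot \tfrac{4}{6}} \approx 11.55, \quad
\sigma_3 = \sqrt{600 \cdot \tfrac{3}{6} \cdot \tfrac{3}{6}} \approx 12.25. $$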
@@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
&& fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
s"Sampling fraction ($fraction) must be on interval [0, 100]")
Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt,
relation)
case Token("TOK_TABLEBUCKETSAMPLE",
Token(numerator, Nil) ::
Token(denominator, Nil) :: Nil) =>
val fraction = numerator.toDouble / denominator.toDouble
Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation)
Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} :
