move sampleBy to stat
mengxr committed Jun 11, 2015
1 parent 832f7cc commit 4a14834
Showing 5 changed files with 46 additions and 35 deletions.
7 changes: 6 additions & 1 deletion python/pyspark/sql/dataframe.py
@@ -477,7 +477,7 @@ def sampleBy(self, col, fractions, seed=None):
                 raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        return DataFrame(self._jdf.sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)

     @since(1.4)
     def randomSplit(self, weights, seed=None):
@@ -1353,6 +1353,11 @@ def freqItems(self, cols, support=None):

     freqItems.__doc__ = DataFrame.freqItems.__doc__

+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+

 def _test():
     import doctest
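On the JVM side, self._jdf.stat() resolves to DataFrame.stat, which returns the DataFrameStatFunctions wrapper that gains sampleBy later in this commit. As a minimal sketch (assuming a Scala DataFrame named df with a "key" column), the Python wrapper's Py4J call chain is the same one a Scala caller writes directly:

    // df.stat returns DataFrameStatFunctions; sampleBy now lives there.
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)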
30 changes: 6 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -18,8 +18,9 @@
 package org.apache.spark.sql

 import java.io.CharArrayWriter
-import java.util.{Properties, UUID}
+import java.util.Properties

+import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
@@ -32,18 +33,19 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.sources.CreateTableUsingAsSelect
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils

+
 private[sql] object DataFrame {
   def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
     new DataFrame(sqlContext, logicalPlan)
@@ -945,26 +947,6 @@ class DataFrame private[sql](
     sample(withReplacement, fraction, Utils.random.nextLong)
   }

-  /**
-   * Returns a stratified sample without replacement based on the fraction given on each stratum.
-   * @param col column that defines strata
-   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
-   *                  its fraction as zero.
-   * @param seed random seed
-   * @return a new [[DataFrame]] that represents the stratified sample
-   */
-  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
-    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
-      s"Fractions must be in [0, 1], but got $fractions.")
-    import org.apache.spark.sql.functions.rand
-    val c = Column(col)
-    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
-    val expr = fractions.toSeq.map { case (k, v) =>
-      (c === k) && (r < v)
-    }.reduce(_ || _) || false
-    this.filter(expr)
-  }
-
   /**
    * Randomly splits this [[DataFrame]] with the provided weights.
    *
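For reference, the predicate this removed method assembles (and which is reinstated almost verbatim under DataFrameStatFunctions below) can be written out by hand. The following is a hypothetical expansion for fractions = Map(0 -> 0.1, 1 -> 0.2) on a column named "key", eliding the random suffix the real code appends to the rand alias:

    import org.apache.spark.sql.functions.{col, lit, rand}

    // One disjunct per stratum: a row survives when its uniform draw falls
    // below its stratum's fraction. Rows whose key matches no entry satisfy
    // no disjunct, which is the "missing fraction is treated as zero" rule.
    val r = rand(0L)
    val keep = ((col("key") === 0) && (r < 0.1)) ||
      ((col("key") === 1) && (r < 0.2)) ||
      lit(false) // mirrors the method's trailing "|| false"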
24 changes: 24 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.sql

+import java.util.UUID
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._

@@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given on each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.rand
+    val c = Column(col)
+    val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.reduce(_ || _) || false
+    df.filter(expr)
+  }
 }
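After the move, callers reach the method through df.stat instead of DataFrame itself. A minimal usage sketch mirroring the relocated test below, assuming a SQLContext named sqlCtx is in scope:

    import org.apache.spark.sql.functions.col

    // key takes values 0, 1, 2. Keep ~10% of stratum 0 and ~20% of stratum 1;
    // stratum 2 has no entry in the fractions map, so none of its rows survive.
    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
    sampled.groupBy("key").count().show()

Because the seed is fixed, the draw is repeatable, which is why the relocated test can assert exact per-stratum counts rather than approximate proportions.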
12 changes: 10 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql

 import org.scalatest.Matchers._

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col

-class DataFrameStatSuite extends SparkFunSuite {
+class DataFrameStatSuite extends QueryTest {

   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 4), Row(1, 9)))
+  }
 }
8 changes: 0 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -635,12 +635,4 @@ class DataFrameSuite extends QueryTest {
     val res11 = ctx.range(-1).select("id")
     assert(res11.count == 0)
   }
-
-  test("sampleBy") {
-    val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
-    val sampled = df.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
-    checkAnswer(
-      sampled.groupBy("key").count().orderBy("key"),
-      Seq(Row(0, 4), Row(1, 9)))
-  }
 }
