
Commit

Merge branch 'master' of https://github.com/apache/spark into SPARK-1712_new
witgo committed May 14, 2014
2 parents 062c182 + c33b8dc commit 1d35c3c
Showing 10 changed files with 225 additions and 12 deletions.
61 changes: 61 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -156,3 +156,64 @@ class RangePartitioner[K : Ordering : ClassTag, V](
false
}
}

/**
* A [[org.apache.spark.Partitioner]] that fills each partition with up to `boundary` keys
* (1000 by default). Once every partition holds `boundary` keys, the partitioner hands out
* one key per partition in round-robin order, so the smaller partitions end up at most
* one key behind the larger ones.
*/
class BoundaryPartitioner[K : Ordering : ClassTag, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
private val boundary: Int = 1000)
extends Partitioner {

// This array tracks how many keys have been assigned to each partition:
// counts(0) is the number of keys in partition 0, and so on.
private val counts: Array[Int] = {
new Array[Int](numPartitions)
}

def numPartitions = math.abs(partitions)

/*
* Ideally this would be derived from the number of partitions and the total number of keys,
* but we avoid calling count() on the RDD so that constructing the partitioner does not
* trigger an action. Users can call count() themselves and pass in a suitable boundary.
*/
def keysPerPartition = boundary

var currPartition = 0

/*
* Keep assigning keys to the current partition until it reaches the per-partition bound,
* then start allocating to the next partition.
*
* NOTE: with, say, 2000 keys, 3 partitions and a boundary of 500, keys 1-500 go to
* partition 0, keys 501-1000 to partition 1 and keys 1001-1500 to partition 2. After that,
* keys are assigned one per partition in round-robin order: key 1501 goes to partition 0,
* 1502 to partition 1, 1503 to partition 2, and so on.
*/
def getPartition(key: Any): Int = {
val partition = currPartition
counts(partition) = counts(partition) + 1
/*
* Because a partition is filled completely before moving on to the next one (which helps
* preserve key ordering), some partitions can end up empty: with 3 partitions and a
* boundary of 500, an RDD of only 700 keys leaves one partition entirely empty.
*/
if (counts(currPartition) >= keysPerPartition) {
currPartition = (currPartition + 1) % numPartitions
}
partition
}

override def equals(other: Any): Boolean = other match {
case r: BoundaryPartitioner[_,_] =>
(r.counts.sameElements(counts) && r.boundary == boundary
&& r.currPartition == currPartition)
case _ =>
false
}
}
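
For illustration, a minimal usage sketch of the partitioner above (the variable names are made up for this example, and a SparkContext named sc is assumed to be in scope). getPartition bases its answer on how many keys it has already seen rather than on the key value, so assignments depend purely on call order:

// Illustrative sketch, not part of this commit: 4 partitions, at most 1000 keys each
// before the round-robin phase begins.
val pairs = sc.parallelize(1.to(4000)).map(x => (x, x))
val partitioner = new BoundaryPartitioner(4, pairs, 1000)
partitioner.getPartition(1)   // the first 1000 calls return partition 0
// Calls 1001-2000 return partition 1, 2001-3000 return partition 2 and 3001-4000 return
// partition 3; once every partition has reached the bound, keys are handed out one per
// partition in round-robin order.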
@@ -217,7 +217,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return approximate number of distinct values for each key in this RDD.
* The accuracy of approximation can be controlled through the relative standard deviation
* (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
* more accurate counts but increase the memory footprint and vise versa. Uses the provided
* more accurate counts but increase the memory footprint and vice versa. Uses the provided
* Partitioner to partition the output RDD.
*/
def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): RDD[(K, Long)] = {
@@ -232,7 +232,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return approximate number of distinct values for each key in this RDD.
* The accuracy of approximation can be controlled through the relative standard deviation
* (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
* more accurate counts but increase the memory footprint and vise versa. HashPartitions the
* more accurate counts but increase the memory footprint and vice versa. HashPartitions the
* output RDD into numPartitions.
*
*/
@@ -244,7 +244,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return approximate number of distinct values for each key this RDD.
* The accuracy of approximation can be controlled through the relative standard deviation
* (relativeSD) parameter, which also controls the amount of memory used. Lower values result in
* more accurate counts but increase the memory footprint and vise versa. The default value of
* more accurate counts but increase the memory footprint and vice versa. The default value of
* relativeSD is 0.05. Hash-partitions the output RDD using the existing partitioner/parallelism
* level.
*/
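
As a usage sketch of the countApproxDistinctByKey overloads documented above (the sample data is made up, and the returned counts are approximations):

// Illustrative only: approximate number of distinct values per key.
val visits = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 1), ("b", 3)))
val perKey = visits.countApproxDistinctByKey(0.05)       // existing parallelism level
// perKey is an RDD[(String, Long)]; collect() yields roughly ("a", 2) and ("b", 1)
val perKey8 = visits.countApproxDistinctByKey(0.01, 8)   // tighter relativeSD, 8 partitions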
34 changes: 34 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -66,6 +66,40 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(descendingP4 != p4)
}

test("BoundaryPartitioner equality") {
// Equality depends on the number of partitions and the boundary; partitioners that have
// not assigned any keys yet compare equal when both of those match.
val rdd = sc.parallelize(1.to(4000)).map(x => (x, x))

val p2 = new BoundaryPartitioner(2, rdd, 1000)
val p4 = new BoundaryPartitioner(4, rdd, 1000)
val anotherP4 = new BoundaryPartitioner(4, rdd)

assert(p2 === p2)
assert(p4 === p4)
assert(p2 != p4)
assert(p4 != p2)
assert(p4 === anotherP4)
assert(anotherP4 === p4)
}

test("BoundaryPartitioner getPartition") {
val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
val partitioner = new BoundaryPartitioner(4, rdd, 500)
1.to(2000).foreach { element =>
val partition = partitioner.getPartition(element)
if (element <= 500) {
assert(partition === 0)
} else if (element > 500 && element <= 1000) {
assert(partition === 1)
} else if (element > 1000 && element <= 1500) {
assert(partition === 2)
} else if (element > 1500 && element <= 2000) {
assert(partition === 3)
}
}
}
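
The round-robin behaviour described in the class comment (keys arriving after every partition has reached the boundary) is not covered above; a hypothetical follow-up check, not part of this commit, could look like this:

// Hypothetical: once every partition holds `boundary` keys, further keys are handed
// out one per partition in round-robin order.
val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
val partitioner = new BoundaryPartitioner(4, rdd, 500)
1.to(2000).foreach(k => partitioner.getPartition(k))   // fill all four partitions
assert(partitioner.getPartition(2001) === 0)
assert(partitioner.getPartition(2002) === 1)
assert(partitioner.getPartition(2003) === 2)
assert(partitioner.getPartition(2004) === 3)
assert(partitioner.getPartition(2005) === 0)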

test("RangePartitioner getPartition") {
val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
// We have different behaviour of getPartition for partitions with less than 1000 and more than
@@ -52,7 +52,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
rootDir0.deleteOnExit()
rootDir1 = Files.createTempDir()
rootDir1.deleteOnExit()
rootDirs = rootDir0.getName + "," + rootDir1.getName
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
println("Created root dirs: " + rootDirs)
}

7 changes: 5 additions & 2 deletions python/pyspark/sql.py
@@ -28,7 +28,7 @@ class SQLContext:
register L{SchemaRDD}s as tables, execute sql over tables, cache tables, and read parquet files.
"""

def __init__(self, sparkContext):
def __init__(self, sparkContext, sqlContext = None):
"""
Create a new SQLContext.
@@ -58,10 +58,13 @@ def __init__(self, sparkContext):
self._jvm = self._sc._jvm
self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

if sqlContext:
self._scala_SQLContext = sqlContext

@property
def _ssql_ctx(self):
"""
Accessor for the JVM SparkSQL context. Subclasses can overrite this property to provide
Accessor for the JVM SparkSQL context. Subclasses can override this property to provide
their own JVM Contexts.
"""
if not hasattr(self, '_scala_SQLContext'):
@@ -93,6 +93,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val AND = Keyword("AND")
protected val AS = Keyword("AS")
protected val ASC = Keyword("ASC")
protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AVG = Keyword("AVG")
protected val BY = Keyword("BY")
protected val CAST = Keyword("CAST")
@@ -318,6 +319,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
case exp => ApproxCountDistinct(exp)
} |
APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
} |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
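
For reference, the two query shapes accepted by the new grammar rules (the table and column names below are illustrative, and a SQLContext named sqlContext is assumed to be in scope):

sqlContext.sql("SELECT APPROXIMATE COUNT(DISTINCT userId) FROM events")          // default relativeSD of 0.05
sqlContext.sql("SELECT APPROXIMATE(0.01) COUNT(DISTINCT userId) FROM events")    // caller-supplied relativeSD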
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import com.clearspring.analytics.stream.cardinality.HyperLogLog

import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
@@ -146,7 +148,6 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: Row): Any = currentMax
}


case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
@@ -166,10 +167,47 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
override def references = expressions.flatMap(_.references).toSet
override def nullable = false
override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
override def newInstance() = new CountDistinctFunction(expressions, this)
}

case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
override def dataType = child.dataType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
override def newInstance() = new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
}

case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
override def dataType = IntegerType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"
override def newInstance() = new ApproxCountDistinctMergeFunction(child, this, relativeSD)
}

case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
override def dataType = IntegerType
override def toString = s"APPROXIMATE COUNT(DISTINCT $child)"

override def asPartial: SplitEvaluation = {
val partialCount =
Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()

SplitEvaluation(
ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
partialCount :: Nil)
}

override def newInstance() = new CountDistinctFunction(child :: Nil, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
@@ -269,6 +307,42 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: Row): Any = count
}

case class ApproxCountDistinctPartitionFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.

private val hyperLogLog = new HyperLogLog(relativeSD)

override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
hyperLogLog.offer(evaluatedExpr)
}
}

override def eval(input: Row): Any = hyperLogLog
}

case class ApproxCountDistinctMergeFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.

private val hyperLogLog = new HyperLogLog(relativeSD)

override def update(input: Row): Unit = {
val evaluatedExpr = expr.eval(input)
hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
}

override def eval(input: Row): Any = hyperLogLog.cardinality()
}
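
To make the partial/merge split concrete, a rough sketch of what the two functions above do with the underlying HyperLogLog sketches (the per-partition values are made up):

import com.clearspring.analytics.stream.cardinality.HyperLogLog

// Map side: ApproxCountDistinctPartitionFunction offers each non-null value into a
// per-partition HyperLogLog and emits the sketch itself.
val partition1 = new HyperLogLog(0.05)
Seq("a", "b", "c").foreach(v => partition1.offer(v))
val partition2 = new HyperLogLog(0.05)
Seq("b", "c", "d").foreach(v => partition2.offer(v))

// Reduce side: ApproxCountDistinctMergeFunction folds the partial sketches together
// and finally asks the merged sketch for its cardinality estimate.
val merged = new HyperLogLog(0.05)
merged.addAll(partition1)
merged.addAll(partition2)
merged.cardinality()   // approximately 4 distinct values across both partitions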

case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.

@@ -33,9 +33,9 @@ import org.apache.spark.util.Utils
/**
* The entry point for executing Spark SQL queries from a Java program.
*/
class JavaSQLContext(sparkContext: JavaSparkContext) {
class JavaSQLContext(val sqlContext: SQLContext) {

val sqlContext = new SQLContext(sparkContext.sc)
def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc))

/**
* Executes a query expressed in SQL, returning the result as a JavaSchemaRDD
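
A sketch of what the constructor change enables (names are illustrative): an existing Scala-side SQLContext can now be wrapped directly, while the previous JavaSparkContext-based construction keeps working:

val sqlContext = new SQLContext(sc)                              // existing Scala-side context
val fromScala = new JavaSQLContext(sqlContext)                   // new: reuse it from the Java API
val fromJavaCtx = new JavaSQLContext(new JavaSparkContext(sc))   // still supported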
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer

import scala.reflect.ClassTag

import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}

@@ -44,6 +45,8 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
new HyperLogLogSerializer)
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
@@ -81,6 +84,20 @@ private[sql] class BigDecimalSerializer extends Serializer[BigDecimal] {
}
}

private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
def write(kryo: Kryo, output: Output, hyperLogLog: HyperLogLog) {
val bytes = hyperLogLog.getBytes()
output.writeInt(bytes.length)
output.writeBytes(bytes)
}

def read(kryo: Kryo, input: Input, tpe: Class[HyperLogLog]): HyperLogLog = {
val length = input.readInt()
val bytes = input.readBytes(length)
HyperLogLog.Builder.build(bytes)
}
}
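
A standalone round-trip sketch for the serializer above, outside of Spark's own Kryo setup (purely illustrative):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

val kryo = new Kryo()
kryo.register(classOf[HyperLogLog], new HyperLogLogSerializer)

val hll = new HyperLogLog(0.05)
Seq("x", "y", "z").foreach(v => hll.offer(v))

val buffer = new ByteArrayOutputStream()
val output = new Output(buffer)
kryo.writeObject(output, hll)
output.close()

val input = new Input(new ByteArrayInputStream(buffer.toByteArray))
val restored = kryo.readObject(input, classOf[HyperLogLog])
restored.cardinality()   // same estimate as hll.cardinality()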

/**
* Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
* them as `Array[(k,v)]`.
21 changes: 19 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -96,8 +96,25 @@ class SQLQuerySuite extends QueryTest {
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
testData2.count()
)
testData2.count())
}

test("count distinct") {
checkAnswer(
sql("SELECT COUNT(DISTINCT b) FROM testData2"),
2)
}

test("approximate count distinct") {
checkAnswer(
sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
3)
}

test("approximate count distinct with user provided standard deviation") {
checkAnswer(
sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
3)
}

// No support for primitive nulls yet.
