Skip to content

Commit

Permalink
[SPARK-6117] [SQL] add describe function to DataFrame for summary sta…
Browse files Browse the repository at this point in the history
…tis...

Please review my solution for SPARK-6117

Author: azagrebin <azagrebin@gmail.com>

Closes apache#5073 from azagrebin/SPARK-6117 and squashes the following commits:

f9056ac [azagrebin] [SPARK-6117] [SQL] create one aggregation and split it locally into resulting DF, colocate test data with test case
ddb3950 [azagrebin] [SPARK-6117] [SQL] simplify implementation, add test for DF without numeric columns
9daf31e [azagrebin] [SPARK-6117] [SQL] add describe function to DataFrame for summary statistics
  • Loading branch information
azagrebin authored and rxin committed Mar 26, 2015
1 parent f535802 commit 5bbcd13
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
53 changes: 52 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.types.{NumericType, StructType}
import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -751,6 +751,57 @@ class DataFrame private[sql](
select(colNames :_*)
}

/**
* Compute numerical statistics for given columns of this [[DataFrame]]:
* count, mean (avg), stddev (standard deviation), min, max.
* Each row of the resulting [[DataFrame]] contains column with statistic name
* and columns with statistic results for each given column.
* If no columns are given then computes for all numerical columns.
*
* {{{
* df.describe("age", "height")
*
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // max 92.0 192.0
* }}}
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {

def stddevExpr(expr: Expression) =
Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))

val statistics = List[(String, Expression => Expression)](
"count" -> Count,
"mean" -> Average,
"stddev" -> stddevExpr,
"min" -> Min,
"max" -> Max)

val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList

val localAgg = if (aggCols.nonEmpty) {
val aggExprs = statistics.flatMap { case (_, colToAgg) =>
aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
}

agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
.grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
Row(statistic :: aggregation.toList: _*)
}
} else {
statistics.map { case (name, _) => Row(name) }
}

val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
sqlContext.createDataFrame(rowRdd, schema)
}

/**
* Returns the first `n` rows.
* @group action
Expand Down
45 changes: 45 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,51 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
}

test("describe") {

val describeTestData = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
("David", 60, 192),
("Amy", 24, 180)).toDF("name", "age", "height")

val describeResult = Seq(
Row("count", 4, 4),
Row("mean", 33.0, 178.0),
Row("stddev", 16.583123951777, 10.0),
Row("min", 16, 164),
Row("max", 60, 192))

val emptyDescribeResult = Seq(
Row("count", 0, 0),
Row("mean", null, null),
Row("stddev", null, null),
Row("min", null, null),
Row("max", null, null))

def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq

val describeTwoCols = describeTestData.describe("age", "height")
assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
checkAnswer(describeTwoCols, describeResult)

val describeAllCols = describeTestData.describe()
assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
checkAnswer(describeAllCols, describeResult)

val describeOneCol = describeTestData.describe("age")
assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )

val describeNoCol = describeTestData.select("name").describe()
assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )

val emptyDescription = describeTestData.limit(0).describe()
assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
checkAnswer(emptyDescription, emptyDescribeResult)
}

test("apply on query results (SPARK-5462)") {
val df = testData.sqlContext.sql("select key from testData")
checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
Expand Down

0 comments on commit 5bbcd13

Please sign in to comment.