diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 42e5cbc05e1e0..ba90888226194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.stat.FrequentItems
+import org.apache.spark.sql.execution.stat.{ContingencyTable, FrequentItems}
 
 /**
  * :: Experimental ::
@@ -27,6 +27,20 @@ import org.apache.spark.sql.execution.stat.FrequentItems
 @Experimental
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
+  /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The number of distinct values for each column should be less than Int.MaxValue. The first
+   * column of each row will be the distinct values of `col1` and the column names will be the
+   * distinct values of `col2`, sorted in lexicographical order. Counts will be returned as `Long`s.
+   *
+   * @param col1 The name of the first column.
+   * @param col2 The name of the second column.
+   * @return A local DataFrame containing the table.
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    ContingencyTable.crossTabulate(df, col1, col2)
+  }
+
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala
new file mode 100644
index 0000000000000..916df88fec957
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/ContingencyTable.scala
@@ -0,0 +1,38 @@
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+
+
+private[sql] object ContingencyTable {
+
+  /** Generate a table of frequencies for the elements of two columns. */
+  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+    val tableName = s"${col1}_$col2"
+    val distinctVals = df.select(countDistinct(col1), countDistinct(col2)).collect().head
+    val distinctCol1 = distinctVals.getLong(0)
+    val distinctCol2 = distinctVals.getLong(1)
+
+    require(distinctCol1 < Int.MaxValue, s"The number of distinct values for $col1 must be " +
+      s"less than Int.MaxValue. Currently: $distinctCol1")
+    require(distinctCol2 < Int.MaxValue, s"The number of distinct values for $col2 must be " +
+      s"less than Int.MaxValue. Currently: $distinctCol2")
+    // Aggregate the counts for the two columns
+    val allCounts =
+      df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).orderBy(col1, col2).collect()
+    // Pivot the table: the sorted rows come in groups of distinctCol2, one group per col1 value
+    val pivotedTable = allCounts.grouped(distinctCol2.toInt).toArray
+    // Get the column names (distinct values of col2)
+    val headerNames = pivotedTable.head.map(r => StructField(r.get(1).toString, LongType))
+    val schema = StructType(StructField(tableName, StringType) +: headerNames)
+    val table = pivotedTable.map { rows =>
+      // The value of col1 is the first value; the rest are the counts
+      val rowValues = rows.head.get(0).toString +: rows.map(_.getLong(2))
+      Row(rowValues: _*)
+    }
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+  }
+
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ebe96e649d940..3896ef96ffb40 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -178,6 +178,24 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  @Test
+  public void testCrosstab() {
+    DataFrame df = context.table("testData2");
+    DataFrame crosstab = df.stat().crosstab("a", "b");
+    String[] columnNames = crosstab.schema().fieldNames();
+    Assert.assertEquals(columnNames[0], "a_b");
+    Assert.assertEquals(columnNames[1], "1");
+    Assert.assertEquals(columnNames[2], "2");
+    Row[] rows = crosstab.collect();
+    Integer count = 1;
+    for (Row row : rows) {
+      Assert.assertEquals(row.get(0).toString(), count.toString());
+      Assert.assertEquals(row.getLong(1), 1L);
+      Assert.assertEquals(row.getLong(2), 1L);
+      count++;
+    }
+  }
 
   @Test
   public void testFrequentItems() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index bb1d29c71d23b..c0afef2b7fe11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,9 +24,26 @@ import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
+  import TestData._
 
   val sqlCtx = TestSQLContext
 
+  test("crosstab") {
+    val crosstab = testData2.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    assert(columnNames(1) === "1")
+    assert(columnNames(2) === "2")
+    val rows: Array[Row] = crosstab.collect()
+    var count = 1
+    rows.foreach { row =>
+      assert(row.get(0).toString === count.toString)
+      assert(row.getLong(1) === 1L)
+      assert(row.getLong(2) === 1L)
+      count += 1
+    }
+  }
+
  test("Frequent Items") {
    def toLetter(i: Int): String = (i + 96).toChar.toString
    val rows = Array.tabulate(1000) { i =>
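
Reviewer note (not part of the patch): a minimal, self-contained sketch of how the new `crosstab` API is used. The object name, app name, and sample data below are illustrative; the data mirrors `testData2` from the test suites, where every (a, b) combination occurs exactly once.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CrosstabExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local setup; any existing SQLContext works the same way.
    val sc = new SparkContext(new SparkConf().setAppName("crosstab-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Same shape as testData2: every (a, b) combination occurs exactly once.
    val df = sc.parallelize(Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))).toDF("a", "b")

    val ct = df.stat.crosstab("a", "b")
    ct.show()
    // Expected table: the first column "a_b" holds the distinct values of `a`,
    // the remaining column names ("1", "2") are the distinct values of `b`, and
    // every cell holds the co-occurrence count as a Long (1 everywhere here):
    //   a_b  1  2
    //   1    1  1
    //   2    1  1
    //   3    1  1
    sc.stop()
  }
}

Because the pivoted result is collected to the driver and wrapped in a LocalRelation, the returned DataFrame is local by construction; that is why crossTabulate guards the distinct counts against Int.MaxValue before pivoting.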