implemented crosstab
brkyvz committed May 1, 2015
1 parent 7630213 commit 27a5a81
Showing 4 changed files with 88 additions and 1 deletion.
@@ -18,7 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat.FrequentItems
import org.apache.spark.sql.execution.stat.{ContingencyTable, FrequentItems}

/**
* :: Experimental ::
@@ -27,6 +27,20 @@ import org.apache.spark.sql.execution.stat.FrequentItems
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Computes a pair-wise frequency table of the given columns, also known as a contingency
* table. The number of distinct values for each column should be less than Int.MaxValue.
* The first column of each row will be the distinct values of `col1`, and the column names
* will be the distinct values of `col2`, sorted lexicographically. Counts are returned as `Long`s.
*
* @param col1 The name of the first column.
* @param col2 The name of the second column.
* @return A local `DataFrame` containing the table.
*/
def crosstab(col1: String, col2: String): DataFrame = {
ContingencyTable.crossTabulate(df, col1, col2)
}

/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
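
For orientation, here is a minimal usage sketch of the new method. It is not part of the commit, and the DataFrame and column names ("people", "city", "plan") are hypothetical:

    import org.apache.spark.sql.DataFrame

    // Hypothetical example: "city" and "plan" are assumed to be columns with few
    // distinct values. The result is a local DataFrame whose first column, "city_plan",
    // holds the distinct cities; the remaining columns are named after the distinct
    // plans, and each cell is the co-occurrence count as a Long.
    def planCountsByCity(people: DataFrame): DataFrame =
      people.stat.crosstab("city", "plan")
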
@@ -0,0 +1,38 @@
package org.apache.spark.sql.execution.stat

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._


private[sql] object ContingencyTable {

/** Generate a table of frequencies for the elements of two columns. */
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
val tableName = s"${col1}_$col2"
val distinctVals = df.select(countDistinct(col1), countDistinct(col2)).collect().head
val distinctCol1 = distinctVals.getLong(0)
val distinctCol2 = distinctVals.getLong(1)

require(distinctCol1 < Int.MaxValue, s"The number of distinct values for $col1 can't " +
s"exceed Int.MaxValue. Currently: $distinctCol1")
require(distinctCol2 < Int.MaxValue, s"The number of distinct values for $col2 can't " +
s"exceed Int.MaxValue. Currently: $distinctCol2")
// Aggregate the counts for the two columns
val allCounts =
df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).orderBy(col1, col2).collect()
// Pivot the table
val pivotedTable = allCounts.grouped(distinctCol2.toInt).toArray
// Get the column names (distinct values of col2)
val headerNames = pivotedTable.head.map(r => StructField(r.get(1).toString, LongType))
val schema = StructType(StructField(tableName, StringType) +: headerNames)
val table = pivotedTable.map { rows =>
// the value of col1 is the first value, the rest are the counts
val rowValues = rows.head.get(0).toString +: rows.map(_.getLong(2))
Row(rowValues:_*)
}
new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
}

}
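
For readers following the pivot step, here is a plain-Scala sketch (not part of the commit) of what `grouped(distinctCol2.toInt)` does above: the counts, already ordered by (col1, col2), are chunked into one group per distinct col1 value, and each group becomes one output row. Like the implementation, the sketch assumes every (col1, col2) combination appears in the ordered counts so that the chunks line up.

    // Toy stand-in for the collected (col1 value, col2 value, count) rows,
    // already sorted by the first two fields.
    val counts: Array[(String, String, Long)] = Array(
      ("a", "x", 2L), ("a", "y", 1L),
      ("b", "x", 4L), ("b", "y", 3L))
    val distinctCol2 = 2
    // One chunk per distinct col1 value: keep the col1 value as the row label
    // and the counts as that row's cells.
    val pivoted: Array[(String, Array[Long])] =
      counts.grouped(distinctCol2).map { group =>
        (group.head._1, group.map(_._3))
      }.toArray
    // pivoted now holds ("a", Array(2, 1)) and ("b", Array(4, 3)).
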
@@ -178,6 +178,24 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}

@Test
public void testCrosstab() {
DataFrame df = context.table("testData2");
DataFrame crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals(columnNames[0], "a_b");
Assert.assertEquals(columnNames[1], "1");
Assert.assertEquals(columnNames[2], "2");
Row[] rows = crosstab.collect();
Integer count = 1;
for (Row row : rows) {
Assert.assertEquals(row.get(0).toString(), count.toString());
Assert.assertEquals(row.getLong(1), 1L);
Assert.assertEquals(row.getLong(2), 1L);
count++;
}
}

@Test
public void testFrequentItems() {
@@ -24,9 +24,26 @@ import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {
import TestData._

val sqlCtx = TestSQLContext

test("crosstab") {
val crosstab = testData2.stat.crosstab("a", "b")
val columnNames = crosstab.schema.fieldNames
assert(columnNames(0) === "a_b")
assert(columnNames(1) === "1")
assert(columnNames(2) === "2")
val rows: Array[Row] = crosstab.collect()
var count = 1
rows.foreach { row =>
assert(row.get(0).toString === count.toString)
assert(row.getLong(1) === 1L)
assert(row.getLong(2) === 1L)
count += 1
}
}

test("Frequent Items") {
def toLetter(i: Int): String = (i + 96).toChar.toString
val rows = Array.tabulate(1000) { i =>
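
The assertions in the two crosstab tests above hold under an assumption about the shared testData2 fixture that this diff does not show: one row for every (a, b) pair with a in 1 to 3 and b in 1 to 2. A hypothetical reconstruction of such a fixture, for illustration only:

    // Hypothetical fixture matching what the tests appear to expect.
    case class TestData2(a: Int, b: Int)
    val fixture = for (a <- 1 to 3; b <- 1 to 2) yield TestData2(a, b)
    // crosstab("a", "b") over this data would conceptually produce:
    //   a_b   1   2
    //   1     1   1
    //   2     1   1
    //   3     1   1
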
