Skip to content

Commit

Permalink
Merge branch 'master' of github.com:apache/spark into viz2
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Or committed May 4, 2015
2 parents 0d7aa32 + 9646018 commit afb98e2
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 26 deletions.
26 changes: 26 additions & 0 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,27 @@ def fillna(self, value, subset=None):

return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

def corr(self, col1, col2, method=None):
"""
Calculates the correlation of two columns of a DataFrame as a double value. Currently only
supports the Pearson Correlation Coefficient.
:func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
:param col1: The name of the first column
:param col2: The name of the second column
:param method: The correlation method. Currently only supports "pearson"
"""
if not isinstance(col1, str):
raise ValueError("col1 should be a string.")
if not isinstance(col2, str):
raise ValueError("col2 should be a string.")
if not method:
method = "pearson"
if not method == "pearson":
raise ValueError("Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
return self._jdf.stat().corr(col1, col2, method)

def cov(self, col1, col2):
"""
Calculate the sample covariance for the given columns, specified by their names, as a
Expand Down Expand Up @@ -1359,6 +1380,11 @@ class DataFrameStatFunctions(object):
def __init__(self, df):
self.df = df

def corr(self, col1, col2, method=None):
return self.df.corr(col1, col2, method)

corr.__doc__ = DataFrame.corr.__doc__

def cov(self, col1, col2):
return self.df.cov(col1, col2)

Expand Down
6 changes: 6 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,12 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

def test_corr(self):
import math
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
corr = df.stat.corr("a", "b")
self.assertTrue(abs(corr - 0.95734012) < 1e-6)

def test_cov(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
cov = df.stat.cov("a", "b")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,32 @@ import org.apache.spark.sql.execution.stat._
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {

/**
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
* MLlib's Statistics.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @return The Pearson Correlation Coefficient as a Double.
*/
def corr(col1: String, col2: String, method: String): Double = {
require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
"coefficient is supported.")
StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
}

/**
* Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
*
* @param col1 the name of the column
* @param col2 the name of the column to calculate the correlation against
* @return The Pearson Correlation Coefficient as a Double.
*/
def corr(col1: String, col2: String): Double = {
corr(col1, col2, "pearson")
}

/**
* Finding frequent items for columns, possibly with false positives. Using the
* frequent element count algorithm described in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,51 @@ import org.apache.spark.sql.types.{DoubleType, NumericType}

private[sql] object StatFunctions {

/** Calculate the Pearson Correlation Coefficient for the given columns */
private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols)
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
}

/** Helper class to simplify tracking and merging counts. */
private class CovarianceCounter extends Serializable {
var xAvg = 0.0
var yAvg = 0.0
var Ck = 0.0
var count = 0L
var xAvg = 0.0 // the mean of all examples seen so far in col1
var yAvg = 0.0 // the mean of all examples seen so far in col2
var Ck = 0.0 // the co-moment after k examples
var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
var MkY = 0.0 // sum of squares of differences from the (current) mean for col1
var count = 0L // count of observed examples
// add an example to the calculation
def add(x: Double, y: Double): this.type = {
val oldX = xAvg
val deltaX = x - xAvg
val deltaY = y - yAvg
count += 1
xAvg += (x - xAvg) / count
yAvg += (y - yAvg) / count
Ck += (y - yAvg) * (x - oldX)
xAvg += deltaX / count
yAvg += deltaY / count
Ck += deltaX * (y - yAvg)
MkX += deltaX * (x - xAvg)
MkY += deltaY * (y - yAvg)
this
}
// merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
def merge(other: CovarianceCounter): this.type = {
val totalCount = count + other.count
Ck += other.Ck +
(xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
val deltaX = xAvg - other.xAvg
val deltaY = yAvg - other.yAvg
Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
count = totalCount
this
}
// return the sample covariance for the observed examples
def cov: Double = Ck / (count - 1)
}

/**
* Calculate the covariance of two numerical columns of a DataFrame.
* @param df The DataFrame
* @param cols the column names
* @return the covariance of the two columns.
*/
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
require(cols.length == 2, "Currently cov supports calculating the covariance " +
"between two columns.")
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
Expand All @@ -68,13 +76,23 @@ private[sql] object StatFunctions {
s"with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
combOp = (baseCounter, other) => {
baseCounter.merge(other)
})
})
}

/**
* Calculate the covariance of two numerical columns of a DataFrame.
* @param df The DataFrame
* @param cols the column names
* @return the covariance of the two columns.
*/
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
val counts = collectStatisticalData(df, cols)
counts.cov
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ public void testFrequentItems() {
Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
}

@Test
public void testCorrelation() {
DataFrame df = context.table("testData2");
Double pearsonCorr = df.stat().corr("a", "b", "pearson");
Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6);
}

@Test
public void testCovariance() {
DataFrame df = context.table("testData2");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite {
def toLetter(i: Int): String = (i + 97).toChar.toString

test("Frequent Items") {
val rows = Array.tabulate(1000) { i =>
val rows = Seq.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
val df = rows.toDF("numbers", "letters", "negDoubles")

val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
Expand All @@ -43,19 +43,40 @@ class DataFrameStatSuite extends FunSuite {
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}

test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
assert(math.abs(corr1 - 1.0) < 1e-12)
val corr2 = df.stat.corr("a", "c", "pearson")
assert(math.abs(corr2 + 1.0) < 1e-12)
// non-trivial example. To reproduce in python, use:
// >>> from scipy.stats import pearsonr
// >>> import numpy as np
// >>> a = np.array(range(20))
// >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
// >>> pearsonr(a, b)
// (0.95723391394758572, 3.8902121417802199e-11)
// In R, use:
// > a <- 0:19
// > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
// > cor(a, b)
// [1] 0.957233913947585835
val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
val corr3 = df2.stat.corr("a", "b", "pearson")
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

test("covariance") {
val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

val results = df.stat.cov("singles", "doubles")
assert(math.abs(results - 55.0 / 3) < 1e-6)
assert(math.abs(results - 55.0 / 3) < 1e-12)
intercept[IllegalArgumentException] {
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
}
val decimalRes = decimalData.stat.cov("a", "b")
assert(math.abs(decimalRes) < 1e-6)
assert(math.abs(decimalRes) < 1e-12)
}
}

0 comments on commit afb98e2

Please sign in to comment.