Skip to content

Commit

Permalink
addressed comments v2.0
Browse files (browse the repository at this point in the history)
  • Loading branch information
brkyvz committed May 3, 2015
1 parent d10babb commit 285b838
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private[sql] object StatFunctions {
s"with dataType ${data.get.dataType} not supported.")
}
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
seqOp = (counter, row) => {
counter.add(row.getDouble(0), row.getDouble(1))
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,35 @@ class DataFrameStatSuite extends FunSuite {
// Checks df.stat.corr(..., "pearson") against three known correlation values.
// NOTE(review): this excerpt looks like a rendered diff; where two assertions differ only in
// tolerance (1e-6 then 1e-12), the first appears to be the pre-change line and the second its
// replacement -- confirm against the real file before treating both as live code.
test("pearson correlation") {
// Ten rows with a = i, b = 2 * i, c = -i: b is an exact positive multiple of a and c is an
// exact negative multiple, so the expected correlations are +1 and -1.
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
assert(math.abs(corr1 - 1.0) < 1e-6)
assert(math.abs(corr1 - 1.0) < 1e-12)
val corr2 = df.stat.corr("a", "c", "pearson")
// corr2 + 1.0 is the distance from the expected value -1.0.
assert(math.abs(corr2 + 1.0) < 1e-6)
assert(math.abs(corr2 + 1.0) < 1e-12)
// non-trivial example. To reproduce in python, use:
// >>> from scipy.stats import pearsonr
// >>> import numpy as np
// >>> a = np.array(range(20))
// >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
// >>> pearsonr(a, b)
// (0.95723391394758572, 3.8902121417802199e-11)
// In R, use:
// > a <- 0:19
// > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
// > cor(a, b)
// [1] 0.957233913947585835
// Same data as the Python/R snippets above: a = 0..19 and b = x^2 - 2x + 3.5.
val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
val corr3 = df2.stat.corr("a", "b", "pearson")
// Expected value is the first element of the pearsonr tuple quoted above.
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
}

// Checks df.stat.cov on numeric columns and that a non-numeric column is rejected.
// NOTE(review): this excerpt looks like a rendered diff; where two assertions differ only in
// tolerance (1e-6 then 1e-12), the first appears to be the pre-change line and the second its
// replacement -- confirm against the real file before treating both as live code.
test("covariance") {
// Ten rows with singles = i and doubles = 2.0 * i; toLetter is defined outside this excerpt
// (presumably it yields a non-numeric value per row -- verify).
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

val results = df.stat.cov("singles", "doubles")
// 55.0 / 3 is the sample (n - 1 denominator) covariance of i and 2i for i = 0..9:
// sum((i - 4.5)^2) = 82.5, so 2 * 82.5 / 9 = 55 / 3.
assert(math.abs(results - 55.0 / 3) < 1e-6)
assert(math.abs(results - 55.0 / 3) < 1e-12)
intercept[IllegalArgumentException] {
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
}
// decimalData is a fixture defined outside this excerpt; the assertions below expect the
// covariance of its "a" and "b" columns to be zero within tolerance.
val decimalRes = decimalData.stat.cov("a", "b")
assert(math.abs(decimalRes) < 1e-6)
assert(math.abs(decimalRes) < 1e-12)
}
}

0 comments on commit 285b838

Please sign in to comment.