In [1]:
spark
spark.version

3.4.1

In [2]:
import org.apache.spark.ml.stat.ChiSquareTest
val sess = SparkSession.builder()
  .appName("MyNotebook")
  .master("local[*]")
  .config("spark.jars.packages", "org.apache.spark:spark-mllib_2.13:3.4.1")
  .getOrCreate()

import sess.implicits._
import org.apache.spark.ml.linalg.Vectors

sess = org.apache.spark.sql.SparkSession@490f0576


org.apache.spark.sql.SparkSession@490f0576

## test native chi square

In [3]:
val df = Seq(
  (0.0, Vectors.dense(1.0, 0.0)),
  (1.0, Vectors.dense(0.0, 1.0)),
  (0.0, Vectors.dense(1.0, 1.0))
).toDF("label", "features")

df = [label: double, features: vector]


[label: double, features: vector]

In [4]:
df.show()

+-----+---------+
|label| features|
+-----+---------+
|  0.0|[1.0,0.0]|
|  1.0|[0.0,1.0]|
|  0.0|[1.0,1.0]|
+-----+---------+



In [5]:
import org.apache.spark.ml.stat.ChiSquareTest

val chi    = ChiSquareTest.test(df, "features", "label")

chi = [pValues: vector, degreesOfFreedom: array<int> ... 1 more field]


[pValues: vector, degreesOfFreedom: array<int> ... 1 more field]

In [6]:
chi.show()

+--------------------+----------------+--------------------+
|             pValues|degreesOfFreedom|          statistics|
+--------------------+----------------+--------------------+
|[0.08326451666354...|          [1, 1]|[3.00000000000000...|
+--------------------+----------------+--------------------+



## test two sample tests

In [7]:
val data = Seq(
      (5.1, 0), (4.9, 0), (5.0, 0), (5.2, 0), (5.3, 0),    // treatment = 0
      (6.1, 1), (6.3, 1), (6.5, 1), (6.2, 1), (6.4, 1)     // treatment = 1
    ).toDF("response", "treatment")

data = [response: double, treatment: int]


[response: double, treatment: int]

In [8]:
data.show()

+--------+---------+
|response|treatment|
+--------+---------+
|     5.1|        0|
|     4.9|        0|
|     5.0|        0|
|     5.2|        0|
|     5.3|        0|
|     6.1|        1|
|     6.3|        1|
|     6.5|        1|
|     6.2|        1|
|     6.4|        1|
+--------+---------+



In [12]:
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types.DoubleType
import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.rdd.RDD

In [31]:
object TwoSample {

  def zeroTrimmedU(
    xRdd: RDD[Double],
    yRdd: RDD[Double],
    alpha: Double = 0.05,
    scale: Boolean = true
  ): (Double, Double, Double, (Double, Double)) = {
    // 1) Basic counts & checks
    val n0 = xRdd.count.toDouble
    val n1 = yRdd.count.toDouble
    require(n0 > 0 && n1 > 0, "Both RDDs must be non-empty")
    require(xRdd.filter(_ < 0).isEmpty(), "All x must be ≥ 0")
    require(yRdd.filter(_ < 0).isEmpty(), "All y must be ≥ 0")

    // 2) Proportions of non-zeros
    val xPlus = xRdd.filter(_ > 0)
    val yPlus = yRdd.filter(_ > 0)
    val pHat0 = xPlus.count / n0
    val pHat1 = yPlus.count / n1
    val pHat = math.max(pHat0, pHat1)

    // 3) Truncate zeros
    val nPrime0 = math.round(n0 * pHat).toInt
    val nPrime1 = math.round(n1 * pHat).toInt
    val nPlus0 = xPlus.count.toDouble
    val nPlus1 = yPlus.count.toDouble
    val pad0 = Seq.fill(nPrime0 - nPlus0.toInt)(0.0)
    val pad1 = Seq.fill(nPrime1 - nPlus1.toInt)(0.0)

    val xTrun = xRdd.sparkContext.parallelize(pad0) union xPlus
    val yTrun = yRdd.sparkContext.parallelize(pad1) union yPlus

    // 4) Compute descending‐ordinal ranks
    val tagged: RDD[(Double, Boolean)] =
      yTrun.map(v => (v, true)) union xTrun.map(v => (v, false))

    val ranks: RDD[((Double, Boolean), Long)] =
      tagged.sortBy({ case (v, _) => -v }).zipWithIndex()

    val R1: Double =
      ranks.filter { case ((_, isY), _) => isY }
        .map { case (_, idx) => (idx + 1).toDouble }
        .sum()

    // 5) Wilcoxon-style statistic
    val wPrime = - (R1 - nPrime1 * (nPrime0 + nPrime1 + 1) / 2.0)

    // 6) Variance components
    val varComp1 = (n1 * n0 * n1 * n0 / 4.0) * (pHat * pHat) * (
      (pHat0 * (1 - pHat0) / n0) + (pHat1 * (1 - pHat1) / n1)
    )
    val varComp2 = (nPlus0 * nPlus1 * (nPlus0 + nPlus1)) / 12.0
    val varW = varComp1 + varComp2

    // 7) Z and p-value
    val z = wPrime / math.sqrt(varW)
    val pValue = 2 * (1 - normalCDF(z))
    val zAlpha = normalQuantile(1 - alpha / 2)
    val confidenceInterval = (wPrime - zAlpha * math.sqrt(varW), wPrime + zAlpha * math.sqrt(varW))

    // 8) Scale the statistic to P(X' < Y')
    if (scale) {
      val locationFactor = (nPrime1.toDouble * nPrime0.toDouble) * 0.5
      val scaleFactor = 1.0 * nPrime1.toDouble * nPrime0.toDouble
      val wPrimeScaled = (wPrime + locationFactor)/scaleFactor
      val confidenceIntervalScaled = (
        (confidenceInterval._1 + locationFactor) / scaleFactor,
        (confidenceInterval._2 + locationFactor) / scaleFactor
      )
      return (z, pValue, wPrimeScaled, confidenceIntervalScaled)
    }

    (z, pValue, wPrime, confidenceInterval)
  }

  def mwU(
    xRdd: RDD[Double],
    yRdd: RDD[Double],
    alpha: Double = 0.05,
    scale: Boolean = true
  ): (Double, Double, Double, (Double, Double)) = {
    // 1) Basic counts & checks
    val n0 = xRdd.count.toDouble
    val n1 = yRdd.count.toDouble
    require(n0 > 0 && n1 > 0, "Both RDDs must be non-empty")

    // 2) Compute descending‐ordinal ranks
    val tagged: RDD[(Double, Boolean)] =
      yRdd.map(v => (v, true)) union xRdd.map(v => (v, false))

    val ranks: RDD[((Double, Boolean), Long)] =
      tagged.sortBy({ case (v, _) => -v }).zipWithIndex()

    val R1: Double =
      ranks.filter { case ((_, isY), _) => isY }
        .map { case (_, idx) => (idx + 1).toDouble }
        .sum()
    
    // 3) Wilcoxon-style statistic
    val w = - (R1 - n1 * (n0 + n1 + 1) / 2.0)

    // 4) Variance
    val varW = n0 * n1 * (n0 + n1 + 1) / 12.0

    // 5) Z and p-value
    val z = w / math.sqrt(varW)
    val pValue = 2 * (1 - normalCDF(z))
    val zAlpha = normalQuantile(1 - alpha / 2)
    val confidenceInterval = (w - zAlpha * math.sqrt(varW), w + zAlpha * math.sqrt(varW))

    // 6) Scale the statistic to P(X' < Y')
    if (scale) {
      val locationFactor = (n1 * n0) / 2.0
      val scaleFactor = n1 * n0
      val wScaled = (w + locationFactor) / scaleFactor
      val confidenceIntervalScaled = (
        (confidenceInterval._1 + locationFactor) / scaleFactor,
        (confidenceInterval._2 + locationFactor) / scaleFactor
      )
      return (z, pValue, wScaled, confidenceIntervalScaled)
    }
    
    (z, pValue, w, confidenceInterval)
  }

  def tTest(
    xRdd: RDD[Double],
    yRdd: RDD[Double],
    alpha: Double = 0.05
  ): (Double, Double, Double, (Double, Double)) = {
    // This function performs a two-sample t-test on two RDDs of doubles.
    // 1) Basic counts & checks
    val n0 = xRdd.count.toDouble
    val n1 = yRdd.count.toDouble
    require(n0 > 0 && n1 > 0, "Both RDDs must be non-empty")

    // 2) Calculate means, variances, and counts for each group
    val mean0 = xRdd.mean()
    val mean1 = yRdd.mean()
    val var0 = xRdd.variance()
    val var1 = yRdd.variance()

    // 3) Perform the t-test
    val stdErrorDifference = math.sqrt(var0 / n0 + var1 / n1)
    val z = (mean0 - mean1) / stdErrorDifference

    // 4) Calculate the p-value using the normal distribution CDF
    val pValue = 2 * (1 - normalCDF(math.abs(z)))

    // 5) Calculate the 95% confidence interval for the mean difference
    val meanDifference = mean1 - mean0
    val zAlpha = normalQuantile(1 - alpha / 2)
    val confidenceInterval = (meanDifference - zAlpha * stdErrorDifference, meanDifference + zAlpha * stdErrorDifference)

    (z, pValue, meanDifference, confidenceInterval)
  }

  def zeroTrimmedUDf(data: DataFrame, groupCol: String, valueCol: String,
    controlStr: String, treatmentStr: String, alpha: Double): (Double, Double, Double, (Double, Double)) = {
    // This test basically test P(X < Y) = 0.5, where X is a random variable from control group and Y is a random variable from treatment group
    // Filter and select the relevant data
    val filteredData = data
      .withColumn(valueCol, col(valueCol).cast(DoubleType))
      .filter(col(groupCol).isin(controlStr, treatmentStr))

    val summary = filteredData.groupBy(groupCol).agg(
      sum(when(col(valueCol) > 0, 1.0).otherwise(col(valueCol))).as("positiveCount"),
      mean(when(col(valueCol) > 0, 1.0).otherwise(col(valueCol))).as("theta"),
      count(valueCol).alias("count"))
    
    val n0Plus = summary.filter(col(groupCol) === controlStr).first().getDouble(1)
    val p0Hat = summary.filter(col(groupCol) === controlStr).first().getDouble(2)
    val n0 = summary.filter(col(groupCol) === controlStr).first().getLong(3)

    val n1Plus = summary.filter(col(groupCol) === treatmentStr).first().getDouble(1)
    val p1Hat = summary.filter(col(groupCol) === treatmentStr).first().getDouble(2)
    val n1 = summary.filter(col(groupCol) === treatmentStr).first().getLong(3)

    val pHat = if (p0Hat > p1Hat) p0Hat else p1Hat
    val samplingGrpStr = if (p0Hat > p1Hat) treatmentStr else controlStr
    val samplingSize = math.round(math.abs(p0Hat - p1Hat) * (if (p0Hat > p1Hat) n1 else n0)).toInt
    val zeroData = filteredData.filter(col(groupCol) === samplingGrpStr).filter(col(valueCol) === 0).limit(samplingSize)
    val positiveData = filteredData.filter(col(valueCol) > 0)
    val trimmedData = positiveData.union(zeroData)
    trimmedData.cache()

    val rankedData = trimmedData.withColumn("rank", row_number().over(Window.orderBy(desc(valueCol))))
      .withColumn("rankD", col("rank").cast(DoubleType))
    val r1 = rankedData.filter(col(groupCol) === treatmentStr).agg(sum("rankD")).first().getDouble(0)
    val n0Prime = trimmedData.filter(col(groupCol) === controlStr).count().toDouble
    val n1Prime = trimmedData.filter(col(groupCol) === treatmentStr).count().toDouble
    trimmedData.unpersist()

    val wPrime = - r1 + n1Prime * (n1Prime + n0Prime + 1) / 2

    val varComp1 = math.pow(n0, 2) * math.pow(n1, 2) / 4 *
      math.pow(pHat, 2) *
      ((p0Hat * (1 - p0Hat)) / n0 + (p1Hat * (1 - p1Hat)) / n1)
    val varComp2 = n1Plus * n0Plus * (n1Plus + n0Plus) / 12
    val varW = varComp1 + varComp2

    val z = wPrime / math.sqrt(varW)

    // Calculate the p-value using the normal distribution CDF
    val pValue = 2 * (1 - normalCDF(z))
    val zAlpha = normalQuantile(1 - alpha / 2)
    val confidenceInterval = (wPrime - zAlpha * math.sqrt(varW), wPrime + zAlpha * math.sqrt(varW))

    (z, pValue, wPrime, confidenceInterval)
  }

  def tTestDf(data: DataFrame, groupCol: String, valueCol: String,
    controlStr: String, treatmentStr: String, alpha: Double): (Double, Double, Double, (Double, Double)) = {
    // Filter and select the relevant data
    val filteredData = data
      .withColumn(valueCol, col(valueCol).cast(DoubleType))
      .filter(col(groupCol).isin(controlStr, treatmentStr))

    // Calculate means, variances, and counts for each group
    val summary = filteredData.groupBy(groupCol).agg(
      mean(valueCol).alias("mean"),
      variance(valueCol).alias("variance"),
      count(valueCol).alias("count")
    )

    // Extract mean, variance, and count for control and treatment
    val controlMean = summary.filter(col(groupCol) === controlStr).first().getDouble(1)
    val controlVariance = summary.filter(col(groupCol) === controlStr).first().getDouble(2)
    val controlCount = summary.filter(col(groupCol) === controlStr).first().getLong(3)

    val treatmentMean = summary.filter(col(groupCol) === treatmentStr).first().getDouble(1)
    val treatmentVariance = summary.filter(col(groupCol) === treatmentStr).first().getDouble(2)
    val treatmentCount = summary.filter(col(groupCol) === treatmentStr).first().getLong(3)

    // Perform the t-test
    val stdErrorDifference = math.sqrt(controlVariance/ controlCount + treatmentVariance / treatmentCount)
    val t = math.abs(controlMean - treatmentMean) / stdErrorDifference

    // Calculate the p-value using the normal distribution CDF
    val pValue = 2 * (1 - normalCDF(t))

    // Calculate the 95% confidence interval for the mean difference
    val meanDifference = treatmentMean - controlMean
    val zAlpha = normalQuantile(1 - alpha / 2)
    val confidenceInterval = (meanDifference - zAlpha * stdErrorDifference, meanDifference + zAlpha * stdErrorDifference)

    (t, pValue, meanDifference, confidenceInterval)
  }

  // Custom implementation of the normal distribution cumulative distribution function (CDF)
  def normalCDF(t: Double): Double = {
    val standardNormal = new NormalDistribution(0, 1)
    standardNormal.cumulativeProbability(Math.abs(t))
  }
  // Custom implementation of the normal distribution quantile function (inverse CDF)
  def normalQuantile(p: Double): Double = {
    val standardNormal = new NormalDistribution(0, 1)
    standardNormal.inverseCumulativeProbability(p)
  }
}


defined object TwoSample


In [15]:
val data2 = data.withColumn("treatment_str", col("treatment").cast("string"))

data2 = [response: double, treatment: int ... 1 more field]


[response: double, treatment: int ... 1 more field]

In [16]:
data2.show()

+--------+---------+-------------+
|response|treatment|treatment_str|
+--------+---------+-------------+
|     5.1|        0|            0|
|     4.9|        0|            0|
|     5.0|        0|            0|
|     5.2|        0|            0|
|     5.3|        0|            0|
|     6.1|        1|            1|
|     6.3|        1|            1|
|     6.5|        1|            1|
|     6.2|        1|            1|
|     6.4|        1|            1|
+--------+---------+-------------+



In [17]:
val tTest = TwoSample.tTestDf(data2, "treatment_str", "response", "0", "1", 0.05)

tTest = (11.999999999999996,0.0,1.2000000000000002,(1.0040036015459948,1.3959963984540056))


(11.999999999999996,0.0,1.2000000000000002,(1.0040036015459948,1.3959963984540056))

In [19]:
val zeroTrimU = TwoSample.zeroTrimmedUDf(data2, "treatment_str", "response", "0", "1", 0.05)

zeroTrimU = (2.7386127875258306,0.0061698993205441255,12.5,(3.5540292814142145,21.445970718585784))


(2.7386127875258306,0.0061698993205441255,12.5,(3.5540292814142145,21.445970718585784))

## more test with zero

In [20]:
val dataWithZerosUnequal = Seq(
  (0.0, "control"), (0.0, "control"), (5.0, "control"), (5.2, "control"), (5.3, "control"), // control group with zeros
  (0.0, "treatment"), (6.3, "treatment"), (6.5, "treatment"), (0.0, "treatment"), (6.4, "treatment"), (6.6, "treatment"), (6.7, "treatment") // treatment group with zeros and larger sample size
).toDF("response", "group")

// Show the generated data
dataWithZerosUnequal.show()

dataWithZerosUnequal = [response: double, group: string]


+--------+---------+
|response|    group|
+--------+---------+
|     0.0|  control|
|     0.0|  control|
|     5.0|  control|
|     5.2|  control|
|     5.3|  control|
|     0.0|treatment|
|     6.3|treatment|
|     6.5|treatment|
|     0.0|treatment|
|     6.4|treatment|
|     6.6|treatment|
|     6.7|treatment|
+--------+---------+



[response: double, group: string]

In [22]:
val zeroTrimU = TwoSample.zeroTrimmedUDf(dataWithZerosUnequal, "group", "response", "control", "treatment", 0.05)

zeroTrimU = (2.1293281415589513,0.03322712107286785,10.0,(0.7953877737928181,19.204612226207182))


(2.1293281415589513,0.03322712107286785,10.0,(0.7953877737928181,19.204612226207182))

## RDD

In [23]:
// Extract x (control group) and y (treatment group) as RDDs
val xRdd = dataWithZerosUnequal.filter(col("group") === "control").select("response").rdd.map(row => row.getDouble(0))
val yRdd = dataWithZerosUnequal.filter(col("group") === "treatment").select("response").rdd.map(row => row.getDouble(0))


xRdd = MapPartitionsRDD[192] at map at <console>:47
yRdd = MapPartitionsRDD[198] at map at <console>:48


MapPartitionsRDD[198] at map at <console>:48

In [32]:
val zeroTrimU = TwoSample.zeroTrimmedU(xRdd, yRdd)

zeroTrimU = (2.1293281415589513,0.03322712107286785,1.0,(0.5397693886896409,1.4602306113103591))


(2.1293281415589513,0.03322712107286785,1.0,(0.5397693886896409,1.4602306113103591))

In [33]:
val mwU = TwoSample.mwU(xRdd, yRdd)

mwU = (1.8675952687646453,0.06181850640046682,0.8285714285714286,(0.48374930510628805,1.173393552036569))


(1.8675952687646453,0.06181850640046682,0.8285714285714286,(0.48374930510628805,1.173393552036569))

In [34]:
val t = TwoSample.tTest(xRdd, yRdd)

t = (-0.9724839617578134,0.33080983990375046,1.5428571428571431,(-1.5666485685012623,4.652362854215548))


(-0.9724839617578134,0.33080983990375046,1.5428571428571431,(-1.5666485685012623,4.652362854215548))