# Stratified train-test split in Spark

Dr Jose M Albornoz, June 2019

We will perform a stratified train-test split by following these steps:

* Determine how many examples of each label should go into the training set, given a train ratio.
* Shuffle the rows of the DataFrame.
* Use a window function to partition the DataFrame by label and rank each label's shuffled observations with row_number().
* Assign the rows whose rank falls within their label's quota to the training set, and all remaining rows to the test set.

In [1]:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}

Intitializing Scala interpreter ...

Spark Web UI available at http://DESKTOP-FQ2BOOJ:4040
SparkContext available as 'sc' (version = 2.4.0, master = local[*], app id = local-1560170334163)
SparkSession available as 'spark'


import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}


# 1.- Some test data

In [2]:
// Ten rows with alternating labels: even ids get label 0.0, odd ids get label 1.0.
val data: Seq[(Int, Double)] = List.tabulate(10)(i => (i, (i % 2).toDouble))
val df0 = data.toDF("id", "label")

data: Seq[(Int, Double)] = List((0,0.0), (1,1.0), (2,0.0), (3,1.0), (4,0.0), (5,1.0), (6,0.0), (7,1.0), (8,0.0), (9,1.0))
df0: org.apache.spark.sql.DataFrame = [id: int, label: double]


In [3]:
df0.show()

+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  1|  1.0|
|  2|  0.0|
|  3|  1.0|
|  4|  0.0|
|  5|  1.0|
|  6|  0.0|
|  7|  1.0|
|  8|  0.0|
|  9|  1.0|
+---+-----+



# 2.- Implementation 

In [4]:
/**
 * Computes how many rows of each binary class (label 0 and label 1) belong in the
 * training set for a given train ratio.
 *
 * Only the label values 0 and 1 are counted; rows with any other label are ignored.
 * Each quota is truncated toward zero, e.g. 5 rows at ratio 0.8 gives 4.
 *
 * @param df         input DataFrame
 * @param trainRatio fraction of each class to place in the training set, in [0, 1]
 * @param labelCol   name of the label column (defaults to "label")
 * @return map from class label (0 or 1) to its training-set quota
 * @throws IllegalArgumentException if trainRatio is outside [0, 1]
 */
def getNumExamplesPerClass(df: DataFrame, trainRatio: Double, labelCol: String = "label"): Map[Int, Long] = {
    import org.apache.spark.sql.functions.{count, when}

    require(trainRatio >= 0.0 && trainRatio <= 1.0, s"trainRatio must be in [0, 1], got $trainRatio")

    // A single aggregation job instead of two separate filter + count scans.
    // when(cond, 1) is null where cond is false, so count(...) counts only matching rows.
    val counts = df.agg(
        count(when(df(labelCol) === 0, 1)),
        count(when(df(labelCol) === 1, 1))
    ).head()

    Map(0 -> (counts.getLong(0) * trainRatio).toLong,
        1 -> (counts.getLong(1) * trainRatio).toLong)
}

getNumExamplesPerClass: (df: org.apache.spark.sql.DataFrame, trainRatio: Double)Map[Int,Long]


In [5]:
// Training-set quota per class for an 80/20 split.
getNumExamplesPerClass(df = df0, trainRatio = 0.8)

res1: Map[Int,Long] = Map(0 -> 4, 1 -> 4)


In [6]:
/**
 * Splits a DataFrame into training and test sets while preserving each label's share of rows.
 *
 * Within every distinct value of `label`, rows are ranked in a random order driven by `seed`.
 * The first floor(classSize * trainRatio) ranked rows go to the training set and the rest go
 * to the test set. Any label value is supported, not only 0/1. Rows with a null label form
 * their own group and are split the same way.
 *
 * The random order is materialised as a seeded column before the window shuffle and is not
 * left to tie-breaking. Both returned DataFrames therefore rank rows identically, provided
 * `df` itself recomputes identically; cache `df` first if it is expensive or non-deterministic.
 *
 * @param df         input DataFrame
 * @param label      name of the column to stratify on
 * @param trainRatio fraction of each label's rows assigned to the training set, in [0, 1]
 * @param seed       seed for the per-label random ordering; fix it for reproducible splits
 * @return Array(train, test), both with exactly the columns of `df`
 * @throws IllegalArgumentException if trainRatio is outside [0, 1]
 */
def stratifiedTrainTestSplit(df: DataFrame, label: String, trainRatio: Double,
                             seed: Long = scala.util.Random.nextLong()): Array[DataFrame] = {
    import org.apache.spark.sql.functions.{count, floor, lit}

    require(trainRatio >= 0.0 && trainRatio <= 1.0, s"trainRatio must be in [0, 1], got $trainRatio")

    // Temporary column names, prefixed to avoid clashing with (and later dropping) user columns.
    val randCol = "__strat_rand"
    val rowNumCol = "__strat_row_number"
    val classSizeCol = "__strat_class_size"
    val isTrainCol = "__strat_is_train"
    val tempCols = Seq(randCol, rowNumCol, classSizeCol, isTrainCol)

    val perLabel = Window.partitionBy(col(label))
    // Order each label partition by a precomputed random value. Ordering it by the label itself
    // makes every row tie, leaving row_number() arbitrary; a preceding df.sort(rand) would
    // also be discarded by the window's own shuffle.
    val perLabelShuffled = perLabel.orderBy(col(randCol))

    val marked = df
        .withColumn(randCol, rand(seed))
        .withColumn(rowNumCol, row_number().over(perLabelShuffled))
        .withColumn(classSizeCol, count(lit(1)).over(perLabel))
        // Quota truncates toward zero, e.g. 5 rows at ratio 0.8 -> 4 training rows.
        .withColumn(isTrainCol, col(rowNumCol) <= floor(col(classSizeCol) * trainRatio))

    Array(marked.where(col(isTrainCol)).drop(tempCols: _*),
          marked.where(!col(isTrainCol)).drop(tempCols: _*))
}

stratifiedTrainTestSplit: (df: org.apache.spark.sql.DataFrame, label: String, trainRatio: Double)Array[org.apache.spark.sql.DataFrame]


In [7]:
// 80/20 split, stratified on the "label" column.
val Array(train, test) = stratifiedTrainTestSplit(df = df0, label = "label", trainRatio = 0.8)

train: org.apache.spark.sql.DataFrame = [id: int, label: double]
test: org.apache.spark.sql.DataFrame = [id: int, label: double]


In [8]:
train.show()

+---+-----+
| id|label|
+---+-----+
|  8|  0.0|
|  6|  0.0|
|  4|  0.0|
|  2|  0.0|
|  1|  1.0|
|  9|  1.0|
|  7|  1.0|
|  5|  1.0|
+---+-----+



In [9]:
test.show()

+---+-----+
| id|label|
+---+-----+
|  0|  0.0|
|  3|  1.0|
+---+-----+

