I have recently been experimenting with implementing automated feature extraction and evaluation techniques in PySpark.
Some of them rely on running repeated k-fold cross-validation.
PySpark provides the [`CrossValidator`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.tuning.CrossValidator.html) object for precisely that.
However, simple random splits are not well suited to heavily imbalanced data, which is common in practice, because the minority class can end up unevenly represented across the folds.
This short post shows a way to use the [`CrossValidator`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.tuning.CrossValidator.html) for running ***stratified*** k-fold cross-validation that keeps the class distribution similar across the folds.

```python
import pyspark.sql.functions as F
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql.window import Window
```

Let's say we have completed all the preprocessing and stored the results in a DataFrame `proc` that has two columns: `target` and `features`.
The `target` in this minimal example is binary, but the same approach applies to the multiclass case.
For regression, the approach can be used after an additional step that bins the target into quantile buckets and stratifies on the bucket instead.
The `features` column holds the feature vector produced by the [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) transformer.
Its schema looks like this:

```python
proc.printSchema()
```

```
root
 |-- target: long (nullable = true)
 |-- features: vector (nullable = true)
```

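For the regression case mentioned above, the extra bucketing step could look roughly like the plain-Python sketch below. It is not Spark code: `quantile_buckets` is a hypothetical helper that cuts a continuous target at its quantiles so that each bucket holds about the same number of rows. In PySpark, `pyspark.ml.feature.QuantileDiscretizer` computes (approximate) quantile buckets on a DataFrame, and the resulting bucket column could then take the place of `target` in `Window.partitionBy`.

```python
import bisect
import statistics


def quantile_buckets(values, n_buckets):
    """Map each continuous value to a bucket index in [0, n_buckets).

    The cut points are the (n_buckets - 1) quantiles of the data, so each
    bucket receives roughly the same number of rows.
    """
    cuts = statistics.quantiles(values, n=n_buckets)
    # bisect_right counts the cut points at or below the value,
    # which is exactly the bucket index.
    return [bisect.bisect_right(cuts, v) for v in values]


# Hypothetical continuous target with 20 rows.
y = [0.5 * i for i in range(20)]
buckets = quantile_buckets(y, n_buckets=4)
print(buckets)  # five rows in each of the buckets 0..3
```

The bucket index then plays the role of the class label when assigning stratified folds.
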
The data in `proc` is imbalanced, so plain (non-stratified) k-fold cross-validation can produce folds with noticeably different class ratios and, in turn, noisy metric estimates.
However, [`CrossValidator`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.tuning.CrossValidator.html) lets us specify the folds manually, which makes it possible to run more elaborate variants of cross-validation with the same object.
The fold indices have to be stored in an integer column whose name is passed to the `foldCol` argument of the [`CrossValidator`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.tuning.CrossValidator.html) constructor; the values must lie in the range $[0, k)$, where $k$ is `numFolds`.
During the $i$th iteration, rows with value $i$ in the `foldCol` column are used as the validation set, and the remaining rows are used to train the model.
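
To make these semantics concrete, here is a plain-Python illustration of how a user-specified fold column turns into train/validation splits. It only mirrors the documented behaviour, not Spark's internals; `folds_from_column` and the toy fold assignment are hypothetical.

```python
def folds_from_column(fold_col, k):
    """Yield (train_rows, validation_rows) for each of the k iterations.

    In iteration i, rows whose fold value equals i form the validation
    set and all other rows form the training set.
    """
    for i in range(k):
        train = [row for row, f in enumerate(fold_col) if f != i]
        validation = [row for row, f in enumerate(fold_col) if f == i]
        yield train, validation


# Hypothetical fold assignment for 7 rows and k = 3.
fold = [0, 1, 2, 0, 1, 2, 0]
splits = list(folds_from_column(fold, k=3))
for i, (train, validation) in enumerate(splits):
    print(i, "train:", train, "validation:", validation)
```

Every row lands in the validation set exactly once, and never in the training set of the same iteration.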

One simple way to split the data into folds in a stratified fashion is to apply the [`ntile`](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.ntile.html) window function over a window partitioned by `target` and ordered by a random column, which shuffles the rows within each class:

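```python
k = 5
# Split each class into k near-equal tiles in random order; subtract 1
# because foldCol expects fold indices in the range [0, k).
proc = proc.withColumn(
    "fold",
    F.ntile(k).over(Window.partitionBy("target").orderBy(F.rand(seed=42))) - 1,
).cache()  # F.rand() is non-deterministic: cache so later actions reuse the same folds
```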

`ntile(k)` splits the rows of each `target` partition into $k$ groups whose sizes differ by at most one, numbered from 1 to $k$.
Subtracting one converts the tile number into the 0-based fold index that `foldCol` expects.
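
The same idea can be sketched in plain Python to show what the window expression computes. `ntile` below emulates SQL `NTILE` semantics (the first `n % k` tiles get one extra row), and shuffling each class stands in for `orderBy(F.rand())`. Both helpers and the toy labels are hypothetical illustrations, not Spark code.

```python
import random
from collections import Counter, defaultdict


def ntile(n_rows, k):
    """Tile numbers 1..k for n_rows ordered rows, like SQL NTILE(k).

    Tile sizes differ by at most one; the first n_rows % k tiles
    receive the extra row.
    """
    size, extra = divmod(n_rows, k)
    tiles = []
    for tile in range(1, k + 1):
        tiles.extend([tile] * (size + (1 if tile <= extra else 0)))
    return tiles


def stratified_folds(labels, k, seed=None):
    """Return a 0-based fold index for every row, stratified by label."""
    rng = random.Random(seed)
    rows_by_label = defaultdict(list)
    for row, label in enumerate(labels):
        rows_by_label[label].append(row)
    folds = [None] * len(labels)
    for rows in rows_by_label.values():
        rng.shuffle(rows)  # plays the role of orderBy(F.rand())
        for row, tile in zip(rows, ntile(len(rows), k)):
            folds[row] = tile - 1  # tile 1..k -> fold 0..k-1
    return folds


labels = [0] * 23 + [1] * 7  # hypothetical imbalanced labels
folds = stratified_folds(labels, k=5, seed=0)
print(sorted(Counter(zip(folds, labels)).items()))
```

Because the extra rows go to the first tiles, fold 0 can end up one row larger than the others, as in the counts below. If a class had fewer than $k$ rows, some folds would receive no rows of that class at all.
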
Here is how the data is split across the folds:

```python
proc.groupBy("fold", "target").count().show()
```

```
+----+------+------+
|fold|target| count|
+----+------+------+
|   0|     0|234162|
|   1|     0|234161|
|   2|     0|234161|
|   3|     0|234161|
|   4|     0|234161|
|   0|     1| 29987|
|   1|     1| 29987|
|   2|     1| 29987|
|   3|     1| 29987|
|   4|     1| 29987|
+----+------+------+
```

All folds have nearly identical target rates, so the cross-validation should give a more stable estimate of model performance.
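
To put a number on "nearly identical", the target rate per fold can be computed from the counts in the table. On the DataFrame itself, something like `proc.groupBy("fold").agg(F.mean("target"))` should give the same rates; the snippet below just does the arithmetic on the copied counts.

```python
# Counts per (fold, target) copied from the table above.
negatives = [234162, 234161, 234161, 234161, 234161]  # target == 0
positives = [29987, 29987, 29987, 29987, 29987]       # target == 1

# Share of positive rows in each fold and in the whole dataset.
rates = [p / (p + n) for p, n in zip(positives, negatives)]
overall = sum(positives) / (sum(positives) + sum(negatives))

for fold, rate in enumerate(rates):
    print(f"fold {fold}: target rate = {rate:.4%}")
print(f"overall: {overall:.4%}")
```

Every fold sits at about 11.35%, matching the overall rate; the single extra negative row in fold 0 changes its rate by less than one part in a million.
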
Now set up a model and an evaluator:

```python
rfc = RandomForestClassifier(
    featuresCol="features",
    labelCol="target",
    predictionCol="prediction",
    probabilityCol="probability",
    rawPredictionCol="rawPrediction",
    numTrees=350,
    maxDepth=7,
    maxBins=128,
    minInstancesPerNode=5,
)
```

```python
binary_evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction",
    labelCol="target",
    metricName="areaUnderROC",
)
```

And run the cross-validation:

```python
cv = CrossValidator(
    estimator=rfc,
    estimatorParamMaps=ParamGridBuilder().build(),
    evaluator=binary_evaluator,
    numFolds=k,
    foldCol="fold",
)
```

```python
cvm = cv.fit(proc)
```

And the result is:

```python
cvm.avgMetrics
```

```
[0.6324294190273916]
```
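
`avgMetrics` holds one value per entry of `estimatorParamMaps`: the evaluator's metric averaged over the $k$ folds. `ParamGridBuilder().build()` with no grids returns a single empty parameter map, `[{}]`, which is why the result above is a one-element list. The loop below is a plain-Python sketch of that averaging, not Spark's implementation; `cross_validate`, the toy majority-class `fit` and the accuracy `evaluate` are hypothetical stand-ins for the estimator and the evaluator.

```python
from collections import Counter
from statistics import mean


def cross_validate(param_maps, fold_col, k, fit, evaluate):
    """Return one averaged metric per parameter map.

    For each parameter map, fit on the rows outside fold i, evaluate on
    the rows in fold i, and average the k resulting metrics.
    """
    avg_metrics = []
    for params in param_maps:
        fold_metrics = []
        for i in range(k):
            train = [r for r, f in enumerate(fold_col) if f != i]
            validation = [r for r, f in enumerate(fold_col) if f == i]
            model = fit(train, params)
            fold_metrics.append(evaluate(model, validation))
        avg_metrics.append(mean(fold_metrics))
    return avg_metrics


# Hypothetical toy data: labels and a precomputed fold column.
labels = [0, 0, 0, 1, 0, 0]
fold = [0, 1, 2, 0, 1, 2]


def fit(train_rows, params):
    # Toy "model": always predict the majority label of the training rows.
    return Counter(labels[r] for r in train_rows).most_common(1)[0][0]


def evaluate(model, validation_rows):
    # Accuracy of the constant prediction on the validation rows.
    hits = sum(1 for r in validation_rows if labels[r] == model)
    return hits / len(validation_rows)


# An empty grid, like ParamGridBuilder().build(), is a single empty map.
print(cross_validate([{}], fold, k=3, fit=fit, evaluate=evaluate))
```

Adding more parameter maps to the grid would add one averaged metric per map, and the best map would be the one with the highest (or lowest, depending on the evaluator) average.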